Pipeline

Submodules

kale.pipeline.deepdta module

class kale.pipeline.deepdta.BaseDTATrainer(drug_encoder, target_encoder, decoder, lr=0.001, ci_metric=False, **kwargs)

Bases: LightningModule

Base class for all drug-target encoder-decoder architecture models, built on the PyTorch Lightning wrapper. For more details about PyTorch Lightning, see https://github.com/PyTorchLightning/pytorch-lightning. Subclasses must implement a forward pass.

Parameters
  • drug_encoder – drug information encoder.

  • target_encoder – target information encoder.

  • decoder – drug-target representations decoder.

  • lr – learning rate. (default: 0.001)

  • ci_metric – whether to calculate the Concordance Index (CI) metric; the operation is time-consuming for large-scale datasets. (default: False)
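CI is slow on large datasets because it compares every pair of samples. Below is a minimal sketch of the pairwise concordance index as it is commonly defined for DeepDTA-style affinity prediction. The function name `concordance_index` is hypothetical, and this is not pykale's implementation.

```python
from itertools import combinations


def concordance_index(y_true, y_pred):
    """Pairwise concordance index (CI) between true and predicted affinities.

    Every pair (i, j) with different true values is "comparable". A comparable
    pair scores 1 if the predictions keep the true order, 0.5 if the predictions
    tie, and 0 otherwise. The loop visits all O(n^2) pairs.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(y_true)), 2):
        if y_true[i] == y_true[j]:
            continue  # pairs with tied ground truth are not comparable
        # Orient the pair so that `hi` has the larger true affinity.
        hi, lo = (i, j) if y_true[i] > y_true[j] else (j, i)
        comparable += 1
        diff = y_pred[hi] - y_pred[lo]
        if diff > 0:
            concordant += 1.0
        elif diff == 0:
            concordant += 0.5
    if comparable == 0:
        raise ValueError("CI is undefined when all true values are equal")
    return concordant / comparable
```

A perfect ranking gives 1.0, a fully reversed ranking gives 0.0, and random predictions hover around 0.5.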

configure_optimizers()

Configure Adam as the default optimizer.

forward(x_drug, x_target)

Same as torch.nn.Module.forward()

training_step(train_batch, batch_idx)

Compute and return the training loss for one step.

validation_step(valid_batch, batch_idx)

Compute and return the validation loss for one step.

test_step(test_batch, batch_idx)

Compute and return the test loss for one step.

training: bool
class kale.pipeline.deepdta.DeepDTATrainer(drug_encoder, target_encoder, decoder, lr=0.001, ci_metric=False, **kwargs)

Bases: BaseDTATrainer

An implementation of the DeepDTA model based on BaseDTATrainer.

Parameters
  • drug_encoder – drug CNN encoder.

  • target_encoder – target CNN encoder.

  • decoder – drug-target MLP decoder.

  • lr – learning rate.

forward(x_drug, x_target)

Forward propagation in DeepDTA architecture.

Parameters
  • x_drug – drug sequence encoding.

  • x_target – target protein sequence encoding.

validation_step(valid_batch, batch_idx)
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool

kale.pipeline.domain_adapter module

Domain adaptation systems (pipelines) with three types of architectures

This module takes individual modules as input and organises them into an architecture. This is taken directly from https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/models/architectures.py with minor changes.

This module uses PyTorch Lightning to standardize the flow.

class kale.pipeline.domain_adapter.GradReverse(*args, **kwargs)

Bases: Function

The gradient reversal layer (GRL)

This is defined in the DANN paper http://jmlr.org/papers/volume17/15-239/15-239.pdf

Forward pass: identity transformation. Backward propagation: flip the sign of the gradient.

From https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/models/layers.py

static forward(ctx, x, alpha)
static backward(ctx, grad_output)
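In PyTorch, a layer like this is a `torch.autograd.Function` subclass called through its `apply` method. The sketch below mimics only the two static methods in plain Python, with a dict standing in for `ctx` and lists standing in for tensors, to show what gradient the feature extractor receives. `GradReverseSketch` is a hypothetical name, not pykale's class.

```python
class GradReverseSketch:
    """Hypothetical plain-Python stand-in for a gradient reversal layer.

    It mirrors the static forward/backward pair of a torch.autograd.Function,
    but uses a dict as the context and lists as tensors, so no autograd is involved.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        ctx["alpha"] = alpha  # saved for the backward pass
        return list(x)  # identity transformation

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by alpha.
        # The second return value is the (absent) gradient for alpha.
        return [-ctx["alpha"] * g for g in grad_output], None


# Manual backward pass through: features -> GRL -> domain critic.
ctx = {}
features = [0.5, -1.0]
reversed_features = GradReverseSketch.forward(ctx, features, alpha=0.5)
critic_grad = [1.0, 2.0]  # pretend dL_critic / d(reversed_features)
feature_grad, _ = GradReverseSketch.backward(ctx, critic_grad)
```

The critic sees the features unchanged. The feature extractor, however, is pushed in the direction that increases the critic's loss, which is what makes the learned features domain-invariant.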
kale.pipeline.domain_adapter.set_requires_grad(model, requires_grad=True)

Configure whether gradients are required for a model

kale.pipeline.domain_adapter.get_aggregated_metrics(metric_name_list, metric_outputs)

Get a dictionary of the mean metric values (to log) from metric names and their values

kale.pipeline.domain_adapter.get_aggregated_metrics_from_dict(input_metric_dict)

Get a dictionary of the mean metric values (to log) from a dictionary of metric values
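The two aggregation helpers reduce per-step outputs to one mean value per metric for logging. Below is a sketch of that idea. It assumes each step output is a dict holding either a scalar or a list of per-sample values. `aggregate_mean_metrics` is a hypothetical helper, and pykale's functions operate on tensors.

```python
def aggregate_mean_metrics(metric_name_list, metric_outputs):
    """Average each named metric over a list of per-step output dicts (illustrative only)."""
    aggregated = {}
    for name in metric_name_list:
        values = []
        for output in metric_outputs:
            value = output[name]
            # Scalars (e.g. a batch loss) count once; per-sample lists are flattened,
            # so list-valued metrics are averaged per sample rather than per step.
            values.extend(value if isinstance(value, (list, tuple)) else [value])
        aggregated[name] = sum(values) / len(values)
    return aggregated


step_outputs = [{"loss": 1.0, "acc": [1, 0]}, {"loss": 3.0, "acc": [1, 1]}]
summary = aggregate_mean_metrics(["loss", "acc"], step_outputs)
```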

kale.pipeline.domain_adapter.get_metrics_from_parameter_dict(parameter_dict, device)

Get a key-value pair from the hyperparameter dictionary

class kale.pipeline.domain_adapter.Method(value)

Bases: Enum

Lists the available methods. Provides a few methods that group the methods by type.

Source = 'Source'
DANN = 'DANN'
CDAN = 'CDAN'
CDAN_E = 'CDAN-E'
FSDANN = 'FSDANN'
MME = 'MME'
WDGRL = 'WDGRL'
WDGRLMod = 'WDGRLMod'
DAN = 'DAN'
JAN = 'JAN'
is_mmd_method()
is_dann_method()
is_cdan_method()
is_fewshot_method()
allow_supervised()
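The grouping helpers let a string from a config file be turned into a method and routed to the right factory. Below is a hypothetical re-creation using the standard `enum` module. The groups follow the factory docstrings below (DAN/JAN for MMD, DANN and CDAN/CDAN-E for adversarial, FSDANN/MME for few-shot). Treating `allow_supervised` as "is few-shot" is an assumption, and pykale's exact groupings may differ.

```python
from enum import Enum


class MethodSketch(Enum):
    """Hypothetical re-creation of the Method enum and its grouping helpers."""

    Source = "Source"
    DANN = "DANN"
    CDAN = "CDAN"
    CDAN_E = "CDAN-E"
    FSDANN = "FSDANN"
    MME = "MME"
    WDGRL = "WDGRL"
    WDGRLMod = "WDGRLMod"
    DAN = "DAN"
    JAN = "JAN"

    def is_mmd_method(self):
        return self in (MethodSketch.DAN, MethodSketch.JAN)

    def is_dann_method(self):
        return self is MethodSketch.DANN

    def is_cdan_method(self):
        return self in (MethodSketch.CDAN, MethodSketch.CDAN_E)

    def is_fewshot_method(self):
        return self in (MethodSketch.FSDANN, MethodSketch.MME)

    def allow_supervised(self):
        # Assumption: only the few-shot methods consume labeled target data.
        return self.is_fewshot_method()


# Look up a member by its string value, e.g. as read from a config file.
method = MethodSketch("CDAN-E")
```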
kale.pipeline.domain_adapter.create_mmd_based(method: Method, dataset, feature_extractor, task_classifier, **train_params)

MMD-based deep learning methods for domain adaptation: DAN and JAN

kale.pipeline.domain_adapter.create_dann_like(method: Method, dataset, feature_extractor, task_classifier, critic, **train_params)

DANN-based deep learning methods for domain adaptation: DANN, CDAN, CDAN+E

kale.pipeline.domain_adapter.create_fewshot_trainer(method: Method, dataset, feature_extractor, task_classifier, critic, **train_params)

DANN-based few-shot deep learning methods for domain adaptation: FSDANN, MME

class kale.pipeline.domain_adapter.BaseAdaptTrainer(dataset, feature_extractor, task_classifier, method: Optional[str] = None, lambda_init: float = 1.0, adapt_lambda: bool = True, adapt_lr: bool = True, nb_init_epochs: int = 10, nb_adapt_epochs: int = 50, batch_size: int = 32, init_lr: float = 0.001, optimizer: Optional[dict] = None)

Bases: LightningModule

Base class for all domain adaptation architectures.

This class implements the classic building blocks used in all the derived architectures for domain adaptation. If you inherit from this class, you will have to implement only:

  • a forward pass

  • a compute_loss function that returns the task loss \(\mathcal{L}_c\) and adaptation loss \(\mathcal{L}_a\), as well as a dictionary for summary statistics and other metrics you may want to have access to.

The default training step uses only the task loss \(\mathcal{L}_c\) during warmup, then uses the loss defined as:

\(\mathcal{L} = \mathcal{L}_c + \lambda \mathcal{L}_a\),

where \(\lambda\) will follow the schedule defined by the DANN paper:

\(\lambda_p = \frac{2}{1 + \exp(-\gamma \cdot p)} - 1\), where the learning progress \(p\) changes linearly from 0 to 1.
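Below is a minimal sketch of the warmup-then-adapt objective with this schedule. It uses γ = 10, the value from the DANN paper, and rests on several assumptions: `lambda_init` scales the schedule, and progress is counted per epoch over the adaptation phase (the actual trainer may count it per batch). `dann_lambda` and `total_loss` are hypothetical names. Note that 2 / (1 + e^(-γp)) - 1 = tanh(γp/2), so λ starts at 0 and saturates near 1.

```python
import math


def dann_lambda(p, gamma=10.0):
    """Adaptation weight 2 / (1 + exp(-gamma * p)) - 1 for training progress p in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0


def total_loss(task_loss, adapt_loss, epoch, nb_init_epochs, nb_adapt_epochs,
               lambda_init=1.0, gamma=10.0):
    """Warmup-then-adapt objective L = L_c + lambda * L_a (illustrative only)."""
    if epoch < nb_init_epochs:
        return task_loss  # warmup: train on the task loss only (lambda = 0)
    # Assumption: progress p is counted over the adaptation epochs, clipped to 1.
    p = min((epoch - nb_init_epochs) / nb_adapt_epochs, 1.0)
    lam = lambda_init * dann_lambda(p, gamma)
    return task_loss + lam * adapt_loss
```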

Parameters
  • dataset (kale.loaddata.multi_domain) – the multi-domain datasets to be used for train, validation, and tests.

  • feature_extractor (torch.nn.Module) – the feature extractor network (mapping inputs \(x\in\mathcal{X}\) to a latent space \(\mathcal{Z}\)).

  • task_classifier (torch.nn.Module) – the task classifier network that learns to predict labels \(y \in \mathcal{Y}\) from latent vectors.

  • method (Method, optional) – the method implemented by the class. Defaults to None. Mostly useful when several methods may be implemented using the same class.

  • lambda_init (float, optional) – Weight attributed to the adaptation part of the loss. Defaults to 1.0.

  • adapt_lambda (bool, optional) – Whether to make lambda grow from 0 to 1 following the schedule from the DANN paper. Defaults to True.

  • adapt_lr (bool, optional) – Whether to use the schedule for the learning rate as defined in the DANN paper. Defaults to True.

  • nb_init_epochs (int, optional) – Number of warmup epochs (during which lambda=0, training only on the source). Defaults to 10.

  • nb_adapt_epochs (int, optional) – Number of training epochs. Defaults to 50.

  • batch_size (int, optional) – Batch size. Defaults to 32.

  • init_lr (float, optional) – Initial learning rate. Defaults to 1e-3.

  • optimizer (dict, optional) – Optimizer parameters: a dictionary with 2 keys, “type”, a string in (“SGD”, “Adam”, “AdamW”), and “optim_params”, keyword arguments for the corresponding PyTorch optimizer. Defaults to None.
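The sketch below pairs an example `optimizer` dictionary with the learning-rate annealing rule from the DANN paper, μ_p = μ_0 / (1 + α·p)^β with α = 10 and β = 0.75, which is what `adapt_lr` refers to. Assumptions: the values inside `optim_params` are illustrative (they are valid `torch.optim.SGD` keyword arguments), `dann_learning_rate` is a hypothetical name, and whether pykale uses exactly these constants is not confirmed here.

```python
def dann_learning_rate(p, init_lr=1e-3, alpha=10.0, beta=0.75):
    """Annealed learning rate mu_p = mu_0 / (1 + alpha * p) ** beta, with p in [0, 1]."""
    return init_lr / (1.0 + alpha * p) ** beta


# Example `optimizer` argument (values are illustrative; keys follow the description above).
optimizer = {
    "type": "SGD",
    "optim_params": {"momentum": 0.9, "weight_decay": 5e-4, "nesterov": True},
}

# Learning rate at 0%, 25%, 50%, 75% and 100% of training progress.
schedule = [dann_learning_rate(step / 4) for step in range(5)]
```

The rate decays smoothly from `init_lr` to roughly 17% of it (1 / 11^0.75) by the end of training.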

property method
get_parameters_watch_list()

Update this list of parameters to watch during training (i.e., to log with MLflow).

forward(x)
compute_loss(batch, split_name='valid')

Define the loss of the model

Parameters
  • batch (tuple) – batches returned by the MultiDomainLoader.

  • split_name (str, optional) – learning stage (one of [“train”, “valid”, “test”]). Defaults to “valid” for validation. “train” is for training and “test” for testing. This is currently used only for naming the metrics used for logging.

Returns

a 3-element tuple with task_loss, adv_loss, log_metrics. log_metrics should be a dictionary.

Raises

NotImplementedError – subclasses of this class should implement this method.

training_step(batch, batch_nb)

The most generic of training steps

Parameters
  • batch (tuple) – the batch as returned by the MultiDomainLoader dataloader iterator: 2 tuples, (x_source, y_source) and (x_target, y_target), in the unsupervised setting; 3 tuples, (x_source, y_source), (x_target_labeled, y_target_labeled) and (x_target_unlabeled, y_target_unlabeled), in the semi-supervised setting.

  • batch_nb (int) – id of the current batch.

Returns

Must contain a “loss” key with the loss to be used for back-propagation.

See PyTorch Lightning for more details.

Return type

dict

validation_step(batch, batch_nb)
validation_epoch_end(outputs)
test_step(batch, batch_nb)
test_epoch_end(outputs)
configure_optimizers()
train_dataloader()
val_dataloader()
test_dataloader()
training: bool
class kale.pipeline.domain_adapter.BaseDANNLike(dataset, feature_extractor, task_classifier, critic, alpha=1.0, entropy_reg=0.0, adapt_reg=True, batch_reweighting=False, **base_params)

Bases: BaseAdaptTrainer

Common API for DANN-based methods: DANN, CDAN, CDAN+E, WDGRL, MME, FSDANN

get_parameters_watch_list()

Update this list of parameters to watch during training (i.e., to log with MLflow).

compute_loss(batch, split_name='valid')
validation_epoch_end(outputs)
test_epoch_end(outputs)
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.DANNTrainer(dataset, feature_extractor, task_classifier, critic, method=None, **base_params)

Bases: BaseDANNLike

This class implements the DANN architecture from Ganin, Yaroslav, et al. “Domain-adversarial training of neural networks.” The Journal of Machine Learning Research (2016) https://arxiv.org/abs/1505.07818

forward(x)
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.CDANTrainer(dataset, feature_extractor, task_classifier, critic, use_entropy=False, use_random=False, random_dim=1024, **base_params)

Bases: BaseDANNLike

Implements CDAN: Long, Mingsheng, et al. “Conditional adversarial domain adaptation.” Advances in Neural Information Processing Systems. 2018. https://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation.pdf
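CDAN conditions the domain critic on the multilinear map f ⊗ g of features f and class predictions g. When the outer product is too large, it uses a randomized map. CDAN+E additionally weights samples by prediction certainty. Below is a plain-Python sketch of the three pieces. Assumptions: the random matrices are Gaussian and drawn once from a fixed seed, and the helper names are hypothetical. The `use_random`, `random_dim` and `use_entropy` arguments presumably switch these on in CDANTrainer.

```python
import math
import random


def multilinear_map(features, probs):
    """Flattened outer product f ⊗ g, the critic input used by CDAN."""
    return [f * g for f in features for g in probs]


def random_multilinear_map(features, probs, random_dim, seed=0):
    """Randomized map (R_f f) ⊙ (R_g g) / sqrt(d) with fixed Gaussian R_f, R_g."""
    rng = random.Random(seed)  # fixed seed, so the same matrices are used on every call
    r_f = [[rng.gauss(0.0, 1.0) for _ in features] for _ in range(random_dim)]
    r_g = [[rng.gauss(0.0, 1.0) for _ in probs] for _ in range(random_dim)]
    proj_f = [sum(w * f for w, f in zip(row, features)) for row in r_f]
    proj_g = [sum(w * g for w, g in zip(row, probs)) for row in r_g]
    return [a * b / math.sqrt(random_dim) for a, b in zip(proj_f, proj_g)]


def entropy_weight(probs, eps=1e-12):
    """CDAN+E sample weight 1 + exp(-H(g)): confident predictions get larger weights."""
    entropy = -sum(p * math.log(p + eps) for p in probs)
    return 1.0 + math.exp(-entropy)
```

The outer product has dimension len(f) × len(g). The randomized map keeps the critic input at `random_dim` entries regardless of feature size.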

forward(x)
compute_loss(batch, split_name='valid')
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.WDGRLTrainer(dataset, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)

Bases: BaseDANNLike

Implements WDGRL as described in Shen, Jian, et al. “Wasserstein distance guided representation learning for domain adaptation.” Thirty-Second AAAI Conference on Artificial Intelligence. 2018. https://arxiv.org/pdf/1707.01217.pdf

This class also implements the asymmetric (\(\beta\)) variant described in: Wu, Yifan, et al. “Domain adaptation with asymmetrically-relaxed distribution alignment.” ICML (2019) https://arxiv.org/pdf/1903.01689.pdf
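WDGRL trains a critic to estimate the Wasserstein distance between source and target features under a gradient penalty. The sketch below shows the critic objective. It assumes a toy linear critic h(z) = w·z + b, whose gradient with respect to z is w everywhere, so the penalty reduces to (‖w‖ - 1)². In the trainer, `k_critic` presumably sets the number of critic updates per batch and `gamma` the penalty weight. The asymmetric β variant is not covered, and all names here are hypothetical.

```python
import math


def linear_critic(z, w, b=0.0):
    """A toy linear critic h(z) = w . z + b (real critics are neural networks)."""
    return sum(wi * zi for wi, zi in zip(w, z)) + b


def wdgrl_critic_objective(source, target, w, b=0.0, gamma=10.0):
    """Return (wasserstein_estimate, critic_loss) for the toy linear critic.

    wasserstein_estimate = mean h(source) - mean h(target). The critic maximizes
    it under a gradient penalty, i.e. minimizes -W + gamma * GP. For a linear
    critic the gradient w.r.t. z is w at every point, so GP = (||w|| - 1) ** 2.
    """
    h_s = sum(linear_critic(z, w, b) for z in source) / len(source)
    h_t = sum(linear_critic(z, w, b) for z in target) / len(target)
    wasserstein = h_s - h_t
    grad_penalty = (math.sqrt(sum(wi * wi for wi in w)) - 1.0) ** 2
    return wasserstein, -wasserstein + gamma * grad_penalty
```

After the critic steps, the feature extractor minimizes the task loss plus λ times the Wasserstein estimate, pulling the two feature distributions together.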

forward(x)
compute_loss(batch, split_name='valid')
critic_update_steps(batch)
training_step(batch, batch_id)
configure_optimizers()
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.WDGRLTrainerMod(dataset, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)

Bases: WDGRLTrainer

Implements a modified version of WDGRL as described in Shen, Jian, et al. “Wasserstein distance guided representation learning for domain adaptation.” Thirty-Second AAAI Conference on Artificial Intelligence. 2018. https://arxiv.org/pdf/1707.01217.pdf

This class also implements the asymmetric (\(\beta\)) variant described in: Wu, Yifan, et al. “Domain adaptation with asymmetrically-relaxed distribution alignment.” ICML (2019) https://arxiv.org/pdf/1903.01689.pdf

critic_update_steps(batch)
training_step(batch, batch_id, optimizer_idx)
optimizer_step(current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False)
configure_optimizers()
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.FewShotDANNTrainer(dataset, feature_extractor, task_classifier, critic, method, **base_params)

Bases: BaseDANNLike

Implements adaptations of DANN to the semi-supervised setting:

  • naive: the task classifier is trained on labeled target data, in addition to source data.

  • MME: implements Saito, Kuniaki, et al. “Semi-supervised domain adaptation via minimax entropy.” Proceedings of the IEEE International Conference on Computer Vision. 2019 https://arxiv.org/pdf/1904.06487.pdf

forward(x)
compute_loss(batch, split_name='valid')
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.BaseMMDLike(dataset, feature_extractor, task_classifier, kernel_mul=2.0, kernel_num=5, **base_params)

Bases: BaseAdaptTrainer

Common API for MMD-based deep learning DA methods: DAN, JAN
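`kernel_mul` and `kernel_num` parameterize a family of Gaussian kernels used to estimate the maximum mean discrepancy (MMD) between source and target features. The sketch below follows the widely used Xlearn-style scheme: a base bandwidth from the mean pairwise squared distance, with `kernel_num` bandwidths spaced by factors of `kernel_mul` around it. It computes the biased MMD² estimate. pykale's exact implementation may differ, and the function names are hypothetical.

```python
import math


def gaussian_multi_kernel(total, kernel_mul=2.0, kernel_num=5):
    """Sum of `kernel_num` Gaussian kernels whose bandwidths grow geometrically by `kernel_mul`."""
    n = len(total)
    sq_dist = [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in total] for x in total]
    # Base bandwidth: mean squared distance over all off-diagonal pairs (guarded against 0).
    bandwidth = max(sum(map(sum, sq_dist)) / (n * n - n), 1e-12)
    # Centre the bandwidth family on the base value.
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidths = [bandwidth * kernel_mul ** i for i in range(kernel_num)]
    return [[sum(math.exp(-d / bw) for bw in bandwidths) for d in row] for row in sq_dist]


def block_mean(k, rows, cols):
    return sum(k[i][j] for i in rows for j in cols) / (len(rows) * len(cols))


def mk_mmd(source, target, kernel_mul=2.0, kernel_num=5):
    """Biased multi-kernel MMD^2 estimate between source and target feature batches."""
    k = gaussian_multi_kernel(list(source) + list(target), kernel_mul, kernel_num)
    s = range(len(source))
    t = range(len(source), len(k))
    return block_mean(k, s, s) + block_mean(k, t, t) - 2.0 * block_mean(k, s, t)
```

Because a sum of Gaussian kernels is positive semi-definite, this estimate is never negative. It is zero when the two batches coincide.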

forward(x)
compute_loss(batch, split_name='valid')
validation_epoch_end(outputs)
test_epoch_end(outputs)
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.DANTrainer(dataset, feature_extractor, task_classifier, **base_params)

Bases: BaseMMDLike

This is an implementation of DAN from Long, Mingsheng, et al. “Learning Transferable Features with Deep Adaptation Networks.” International Conference on Machine Learning. 2015. http://proceedings.mlr.press/v37/long15.pdf (code based on https://github.com/thuml/Xlearn).

training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.domain_adapter.JANTrainer(dataset, feature_extractor, task_classifier, kernel_mul=(2.0, 2.0), kernel_num=(5, 1), **base_params)

Bases: BaseMMDLike

This is an implementation of JAN from Long, Mingsheng, et al. “Deep transfer learning with joint adaptation networks.” International Conference on Machine Learning, 2017. https://arxiv.org/pdf/1605.06636.pdf (code based on https://github.com/thuml/Xlearn).
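JAN's joint MMD compares the joint distribution of several layers, typically features and softmax outputs. It multiplies the per-layer kernel matrices element-wise before computing MMD. The tuple defaults `kernel_mul=(2.0, 2.0)` and `kernel_num=(5, 1)` suggest one kernel setting per layer, and the sketch below interprets them that way. This interpretation and the function names are assumptions, not pykale's code.

```python
import math


def gaussian_multi_kernel(total, kernel_mul, kernel_num):
    """Sum of Gaussian kernels with geometrically spaced bandwidths."""
    n = len(total)
    sq_dist = [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in total] for x in total]
    bandwidth = max(sum(map(sum, sq_dist)) / (n * n - n), 1e-12)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidths = [bandwidth * kernel_mul ** i for i in range(kernel_num)]
    return [[sum(math.exp(-d / bw) for bw in bandwidths) for d in row] for row in sq_dist]


def joint_mmd(source_layers, target_layers, kernel_mul=(2.0, 2.0), kernel_num=(5, 1)):
    """Biased JMMD^2: multiply the per-layer kernel matrices element-wise, then compute MMD."""
    joint = None
    for src, tgt, mul, num in zip(source_layers, target_layers, kernel_mul, kernel_num):
        k = gaussian_multi_kernel(list(src) + list(tgt), mul, num)
        joint = k if joint is None else [[a * b for a, b in zip(r, q)] for r, q in zip(joint, k)]
    m, n = len(source_layers[0]), len(joint)

    def mean(rows, cols):
        return sum(joint[i][j] for i in rows for j in cols) / (len(rows) * len(cols))

    s, t = range(m), range(m, n)
    return mean(s, s) + mean(t, t) - 2.0 * mean(s, t)
```

An element-wise product of positive semi-definite kernels is itself positive semi-definite, so the joint estimate stays non-negative.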

training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool

kale.pipeline.mpca_trainer module

Implementation of MPCA->Feature Selection->Linear SVM/LogisticRegression Pipeline

References

[1] Swift, A. J., Lu, H., Uthoff, J., Garg, P., Cogliano, M., Taylor, J., … & Kiely, D. G. (2020). A machine learning cardiac magnetic resonance approach to extract disease features and automate pulmonary arterial hypertension diagnosis. European Heart Journal-Cardiovascular Imaging.

[2] Song, X., Meng, L., Shi, Q., & Lu, H. (2015, October). Learning tensor-based features for whole-brain fMRI classification. In International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 613-620). Springer, Cham.

[3] Lu, H., Plataniotis, K. N., & Venetsanopoulos, A. N. (2008). MPCA: Multilinear principal component analysis of tensor objects. IEEE Transactions on Neural Networks, 19(1), 18-39.

class kale.pipeline.mpca_trainer.MPCATrainer(classifier='svc', classifier_params='auto', classifier_param_grid=None, mpca_params=None, n_features=None, search_params=None)

Bases: BaseEstimator, ClassifierMixin

Trainer of pipeline: MPCA->Feature selection->Classifier

Parameters
  • classifier (str, optional) – Available classifier options: {“svc”, “linear_svc”, “lr”}. “svc” trains a support vector classifier that supports both linear and non-linear kernels and is optimized with the library “libsvm”. “linear_svc” trains a support vector classifier with a linear kernel only and is optimized with the library “liblinear”, which is supposed to be faster and better at handling a large number of samples. “lr” trains a logistic regression classifier. Defaults to “svc”.

  • classifier_params (dict, optional) – Parameters of the classifier, or “auto” to select them by grid search. Defaults to ‘auto’.

  • classifier_param_grid (dict, optional) – Grids for searching the optimal hyper-parameters. Works only when classifier_params == “auto”. Defaults to None, in which case the following hyper-parameter grids are searched: 1. svc, {“kernel”: [“linear”], “C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100], “max_iter”: [50000]}, 2. linear_svc, {“C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]}, 3. lr, {“C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]}

  • mpca_params (dict, optional) – Parameters of MPCA, e.g., {“var_ratio”: 0.8}. Defaults to None, i.e., using the default parameters (https://pykale.readthedocs.io/en/latest/kale.embed.html#module-kale.embed.mpca).

  • n_features (int, optional) – Number of features for feature selection. Defaults to None, i.e., all features after dimension reduction will be used.

  • search_params (dict, optional) – Parameters of grid search; for more details, see https://scikit-learn.org/stable/modules/grid_search.html#grid-search. Defaults to None, i.e., using the default parameters: {“cv”: 5}.

fit(x, y)

Fit a pipeline with the given data x and labels y

Parameters
  • x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)

  • y (array-like) – data labels, shape (n_samples, )

Returns

self

predict(x)

Predict the labels for the given data x

Parameters

x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)

Returns

Predicted labels, shape (n_samples, )

Return type

array-like

decision_function(x)

Decision scores of each class for the given data x

Parameters

x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)

Returns

decision scores, shape (n_samples,) for binary case, else (n_samples, n_class)

Return type

array-like

predict_proba(x)

Probability of each class for the given data x. Not supported by “linear_svc”.

Parameters

x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)

Returns

probabilities, shape (n_samples, n_class)

Return type

array-like

kale.pipeline.multi_domain_adapter module

Multi-source domain adaptation pipelines

kale.pipeline.multi_domain_adapter.create_ms_adapt_trainer(method: str, dataset, feature_extractor, task_classifier, **train_params)

Create a trainer for multi-source domain adaptation.

Parameters
  • method (str) – Multi-source domain adaptation method, M3SDA or MFSAN

  • dataset (kale.loaddata.multi_domain.MultiDomainAdapDataset) – the multi-domain datasets to be used for train, validation, and tests.

  • feature_extractor (torch.nn.Module) – feature extractor network

  • task_classifier (torch.nn.Module) – task classifier network

Returns

Multi-source domain adaptation trainer.

Return type

[pl.LightningModule]

class kale.pipeline.multi_domain_adapter.BaseMultiSourceTrainer(dataset, feature_extractor, task_classifier, n_classes: int, target_domain: str, **base_params)

Bases: BaseAdaptTrainer

Base class for all multi-source domain adaptation architectures

Parameters
  • dataset (kale.loaddata.multi_domain) – the multi-domain datasets to be used for train, validation, and tests.

  • feature_extractor (torch.nn.Module) – the feature extractor network

  • task_classifier (torch.nn.Module) – the task classifier network

  • n_classes (int) – number of classes

  • target_domain (str) – target domain name

forward(x)
compute_loss(batch, split_name='valid')
validation_epoch_end(outputs)
test_epoch_end(outputs)
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.multi_domain_adapter.M3SDATrainer(dataset, feature_extractor, task_classifier, n_classes: int, target_domain: str, k_moment: int = 3, **base_params)

Bases: BaseMultiSourceTrainer

Moment matching for multi-source domain adaptation (M3SDA).

Reference:

Peng, X., Bai, Q., Xia, X., Huang, Z., Saenko, K., & Wang, B. (2019). Moment matching for multi-source domain adaptation. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 1406-1415). https://openaccess.thecvf.com/content_ICCV_2019/html/Peng_Moment_Matching_for_Multi-Source_Domain_Adaptation_ICCV_2019_paper.html
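M3SDA aligns domains by matching feature moments. For each order k, it combines the distance between every source domain and the target with the distance between every pair of source domains. Below is a plain-Python sketch of that moment distance, summed over orders 1 to `k_moment`. Assumptions: raw (uncentered) element-wise moments and L2 distances are used, the equal weighting of the two terms is taken from the paper's formulation, and the function names are hypothetical. The trainer's exact loss may differ.

```python
import math
from itertools import combinations


def raw_moment(features, k):
    """Element-wise k-th order moment E[x^k] of a batch of feature vectors."""
    n = len(features)
    return [sum(x[d] ** k for x in features) / n for d in range(len(features[0]))]


def l2_distance(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def moment_distance(source_domains, target, k_moment=3):
    """Moment distance summed over orders 1..k_moment.

    For each order k it adds the mean source-to-target distance and the mean
    distance over all pairs of source domains.
    """
    total = 0.0
    for k in range(1, k_moment + 1):
        tgt_k = raw_moment(target, k)
        src_k = [raw_moment(domain, k) for domain in source_domains]
        total += sum(l2_distance(s, tgt_k) for s in src_k) / len(src_k)
        pairs = list(combinations(src_k, 2))
        if pairs:  # a single source domain has no source-source pairs
            total += sum(l2_distance(a, b) for a, b in pairs) / len(pairs)
    return total
```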

compute_loss(batch, split_name='valid')
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.multi_domain_adapter.MFSANTrainer(dataset, feature_extractor, task_classifier, n_classes: int, target_domain: str, domain_feat_dim: int = 100, kernel_mul: float = 2.0, kernel_num: int = 5, input_dimension: int = 2, **base_params)

Bases: BaseMultiSourceTrainer

Multiple Feature Spaces Adaptation Network (MFSAN)

Reference: Zhu, Y., Zhuang, F. and Wang, D., 2019, July. Aligning domain-specific distribution and classifier for cross-domain classification from multiple sources. In AAAI. https://ojs.aaai.org/index.php/AAAI/article/view/4551

Original implementation: https://github.com/easezyc/deep-transfer-learning/tree/master/MUDA/MFSAN

compute_loss(batch, split_name='valid')
cls_discrepancy(x)

Compute discrepancy between all classifiers’ probabilistic outputs
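In MFSAN, each source domain has its own classifier, and their predictions on target samples are encouraged to agree. Below is a plain-Python sketch of such a discrepancy: the mean absolute difference between the class probabilities of every pair of classifiers, averaged over pairs. The use of an L1 difference averaged over samples and classes is an assumption, and `cls_discrepancy_sketch` is a hypothetical name.

```python
from itertools import combinations


def cls_discrepancy_sketch(prob_outputs):
    """Mean absolute difference between every pair of classifiers' class probabilities.

    prob_outputs: one entry per classifier, each a list of per-sample probability
    vectors for the same target batch.
    """
    pairs = list(combinations(prob_outputs, 2))
    if not pairs:
        return 0.0  # a single classifier cannot disagree with itself
    total = 0.0
    for a, b in pairs:
        diffs = [abs(p - q) for sa, sb in zip(a, b) for p, q in zip(sa, sb)]
        total += sum(diffs) / len(diffs)
    return total / len(pairs)
```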

training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.multi_domain_adapter.CoIRLS(kernel='linear', kernel_kwargs=None, alpha=1.0, lambda_=1.0)

Bases: BaseEstimator, ClassifierMixin

Covariate-Independence Regularized Least Squares (CoIRLS)

Parameters
  • kernel (str, optional) – {“linear”, “rbf”, “poly”}. Kernel to use. Defaults to “linear”.

  • kernel_kwargs (dict or None, optional) – Hyperparameter for the kernel. Defaults to None.

  • alpha (float, optional) – Hyperparameter of the l2 (Ridge) penalty. Defaults to 1.0.

  • lambda_ (float, optional) – Hyperparameter of the covariate dependence. Defaults to 1.0.

Reference:
[1] Zhou, S., 2022. Interpretable Domain-Aware Learning for Neuroimage Classification (Doctoral dissertation, University of Sheffield).

[2] Zhou, S., Li, W., Cox, C.R., & Lu, H. (2020). Side Information Dependence as a Regularizer for Analyzing Human Brain Conditions across Cognitive Experiments. AAAI 2020, New York, USA.

fit(x, y, covariates)

Fit a model with input data, labels, and covariates.

Parameters
  • x (np.ndarray or tensor) – shape (n_samples, n_features)

  • y (np.ndarray or tensor) – shape (n_samples, )

  • covariates (np.ndarray or tensor) – shape (n_samples, n_covariates)

predict(x)

Predict labels for data x

Parameters

x (np.ndarray or tensor) – samples to predict, shape (n_samples, n_features)

Returns

Predicted labels, shape (n_samples, )

Return type

y (np.ndarray)

decision_function(x)

Compute decision scores for data x

Parameters

x (np.ndarray or tensor) – samples to score, shape (n_samples, n_features)

Returns

Decision scores, shape (n_samples, )

Return type

scores (np.ndarray)

kale.pipeline.video_domain_adapter module

Domain adaptation systems (pipelines) for video data, e.g., for action recognition. Most are inherited from kale.pipeline.domain_adapter.

kale.pipeline.video_domain_adapter.create_mmd_based_video(method: Method, dataset, image_modality, feature_extractor, task_classifier, **train_params)

MMD-based deep learning methods for domain adaptation on video data: DAN and JAN

kale.pipeline.video_domain_adapter.create_dann_like_video(method: Method, dataset, image_modality, feature_extractor, task_classifier, critic, **train_params)

DANN-based deep learning methods for domain adaptation on video data: DANN, CDAN, CDAN+E

class kale.pipeline.video_domain_adapter.BaseMMDLikeVideo(dataset, image_modality, feature_extractor, task_classifier, kernel_mul=2.0, kernel_num=5, **base_params)

Bases: BaseMMDLike

Common API for MMD-based domain adaptation on video data: DAN, JAN

forward(x)
compute_loss(batch, split_name='valid')
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.video_domain_adapter.DANTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, **base_params)

Bases: BaseMMDLikeVideo

This is an implementation of DAN for video data.

training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.video_domain_adapter.JANTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, kernel_mul=(2.0, 2.0), kernel_num=(5, 1), **base_params)

Bases: BaseMMDLikeVideo

This is an implementation of JAN for video data.

training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.video_domain_adapter.DANNTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, critic, method, **base_params)

Bases: DANNTrainer

This is an implementation of DANN for video data.

forward(x)
compute_loss(batch, split_name='valid')
training_step(batch, batch_nb)
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.video_domain_adapter.CDANTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, critic, use_entropy=False, use_random=False, random_dim=1024, **base_params)

Bases: CDANTrainer

This is an implementation of CDAN for video data.

forward(x)
compute_loss(batch, split_name='valid')
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
class kale.pipeline.video_domain_adapter.WDGRLTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)

Bases: WDGRLTrainer

This is an implementation of WDGRL for video data.

forward(x)
compute_loss(batch, split_name='valid')
configure_optimizers()
critic_update_steps(batch)
training: bool
precision: int
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool

Module contents