Pipeline
Submodules
kale.pipeline.deepdta module
- class kale.pipeline.deepdta.BaseDTATrainer(drug_encoder, target_encoder, decoder, lr=0.001, ci_metric=False, **kwargs)
Bases:
LightningModule
Base class for all drug-target encoder-decoder architecture models, built on the PyTorch Lightning wrapper. For more details about PyTorch Lightning, please see https://github.com/PyTorchLightning/pytorch-lightning. If you inherit from this class, you must implement a forward pass function.
- Parameters
drug_encoder – drug information encoder.
target_encoder – target information encoder.
decoder – drug-target representations decoder.
lr – learning rate. (default: 0.001)
ci_metric – whether to calculate the Concordance Index (CI) metric; this operation is time-consuming for large-scale datasets. (default: False)
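The Concordance Index compares every pair of samples, which is why it becomes expensive on large datasets. Below is a minimal pure-Python sketch of the standard definition, in which pairs with tied predictions count as half-concordant. The function name concordance_index is illustrative; this class may compute CI differently.

```python
def concordance_index(y_true, y_pred):
    """Fraction of correctly ordered pairs among pairs with distinct true values."""
    concordant, comparable = 0.0, 0
    n = len(y_true)
    for i in range(n):  # O(n^2) pairs: the source of the cost on large datasets
        for j in range(n):
            if y_true[i] > y_true[j]:
                comparable += 1
                if y_pred[i] > y_pred[j]:
                    concordant += 1.0
                elif y_pred[i] == y_pred[j]:
                    concordant += 0.5  # tied predictions count as half
    # no comparable pairs: return 0.0 by convention in this sketch
    return concordant / comparable if comparable else 0.0
```

A perfectly ordered prediction scores 1.0, a fully reversed one scores 0.0, and a constant prediction scores 0.5.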
- configure_optimizers()
Configure Adam as the default optimizer.
- forward(x_drug, x_target)
Same as torch.nn.Module.forward().
- training_step(train_batch, batch_idx)
Compute and return the training loss for one step.
- validation_step(valid_batch, batch_idx)
Compute and return the validation loss for one step.
- test_step(test_batch, batch_idx)
Compute and return the test loss for one step.
- training: bool
- class kale.pipeline.deepdta.DeepDTATrainer(drug_encoder, target_encoder, decoder, lr=0.001, ci_metric=False, **kwargs)
Bases:
BaseDTATrainer
An implementation of the DeepDTA model based on BaseDTATrainer.
- Parameters
drug_encoder – drug CNN encoder.
target_encoder – target CNN encoder.
decoder – drug-target MLP decoder.
lr – learning rate.
- forward(x_drug, x_target)
Forward propagation in DeepDTA architecture.
- Parameters
x_drug – drug sequence encoding.
x_target – target protein sequence encoding.
- validation_step(valid_batch, batch_idx)
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
kale.pipeline.domain_adapter module
Domain adaptation systems (pipelines) with three types of architectures
This module takes individual modules as input and organises them into an architecture. This is taken directly from https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/models/architectures.py with minor changes.
This module uses PyTorch Lightning to standardize the flow.
- class kale.pipeline.domain_adapter.GradReverse(*args, **kwargs)
Bases:
Function
The gradient reversal layer (GRL)
This is defined in the DANN paper http://jmlr.org/papers/volume17/15-239/15-239.pdf
Forward pass: identity transformation. Backward propagation: flip the sign of the gradient.
From https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/models/layers.py
- static forward(ctx, x, alpha)
- static backward(ctx, grad_output)
- kale.pipeline.domain_adapter.set_requires_grad(model, requires_grad=True)
Configure whether gradients are required for a model
- kale.pipeline.domain_adapter.get_aggregated_metrics(metric_name_list, metric_outputs)
Get a dictionary of the mean metric values (to log) from metric names and their values
- kale.pipeline.domain_adapter.get_aggregated_metrics_from_dict(input_metric_dict)
Get a dictionary of the mean metric values (to log) from a dictionary of metric values
- kale.pipeline.domain_adapter.get_metrics_from_parameter_dict(parameter_dict, device)
Get a key-value pair from the hyperparameter dictionary
- class kale.pipeline.domain_adapter.Method(value)
Bases:
Enum
Lists the available methods. Provides a few methods that group the methods by type.
- Source = 'Source'
- DANN = 'DANN'
- CDAN = 'CDAN'
- CDAN_E = 'CDAN-E'
- FSDANN = 'FSDANN'
- MME = 'MME'
- WDGRL = 'WDGRL'
- WDGRLMod = 'WDGRLMod'
- DAN = 'DAN'
- JAN = 'JAN'
- is_mmd_method()
- is_dann_method()
- is_cdan_method()
- is_fewshot_method()
- allow_supervised()
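Because Method is an Enum, a member can be obtained from its string value (for example Method("CDAN-E")), which is convenient when the method name comes from a configuration file. The stand-alone mirror below (the class name MethodSketch is hypothetical) shows the idea. Its grouping helpers are inferred from the factory docstrings that follow (DAN/JAN as MMD-based, FSDANN/MME as few-shot) and may not match the library's own definitions; is_dann_method and allow_supervised are omitted because their exact groupings are not documented here.

```python
from enum import Enum


class MethodSketch(Enum):
    """Hypothetical mirror of the documented Method members."""
    Source = "Source"
    DANN = "DANN"
    CDAN = "CDAN"
    CDAN_E = "CDAN-E"
    FSDANN = "FSDANN"
    MME = "MME"
    WDGRL = "WDGRL"
    WDGRLMod = "WDGRLMod"
    DAN = "DAN"
    JAN = "JAN"

    def is_mmd_method(self):
        # assumption: grouping taken from the create_mmd_based docstring
        return self in (MethodSketch.DAN, MethodSketch.JAN)

    def is_cdan_method(self):
        # assumption: the conditional variants of DANN
        return self in (MethodSketch.CDAN, MethodSketch.CDAN_E)

    def is_fewshot_method(self):
        # assumption: grouping taken from the create_fewshot_trainer docstring
        return self in (MethodSketch.FSDANN, MethodSketch.MME)


method = MethodSketch("CDAN-E")  # look up a member by its string value
```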
- kale.pipeline.domain_adapter.create_mmd_based(method: Method, dataset, feature_extractor, task_classifier, **train_params)
MMD-based deep learning methods for domain adaptation: DAN and JAN
- kale.pipeline.domain_adapter.create_dann_like(method: Method, dataset, feature_extractor, task_classifier, critic, **train_params)
DANN-based deep learning methods for domain adaptation: DANN, CDAN, CDAN+E
- kale.pipeline.domain_adapter.create_fewshot_trainer(method: Method, dataset, feature_extractor, task_classifier, critic, **train_params)
DANN-based few-shot deep learning methods for domain adaptation: FSDANN, MME
- class kale.pipeline.domain_adapter.BaseAdaptTrainer(dataset, feature_extractor, task_classifier, method: Optional[str] = None, lambda_init: float = 1.0, adapt_lambda: bool = True, adapt_lr: bool = True, nb_init_epochs: int = 10, nb_adapt_epochs: int = 50, batch_size: int = 32, init_lr: float = 0.001, optimizer: Optional[dict] = None)
Bases:
LightningModule
Base class for all domain adaptation architectures.
This class implements the classic building blocks used in all the derived architectures for domain adaptation. If you inherit from this class, you will only have to implement:
- a forward pass;
- a compute_loss function that returns the task loss \(\mathcal{L}_c\) and the adaptation loss \(\mathcal{L}_a\), as well as a dictionary for summary statistics and other metrics you may want to have access to.
The default training step uses only the task loss \(\mathcal{L}_c\) during warmup, then uses the loss defined as:
\(\mathcal{L} = \mathcal{L}_c + \lambda \mathcal{L}_a\),
where \(\lambda\) will follow the schedule defined by the DANN paper:
\(\lambda_p = \frac{2}{1 + \exp{(-\gamma \cdot p)}} - 1\), where the learning progress \(p\) changes linearly from 0 to 1.
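A small sketch of the loss described above. The value gamma = 10 is the one used in the DANN paper; whether this class uses the same value, how it measures the progress \(p\), and how lambda_init combines with the schedule are assumptions here. The helper names dann_lambda and combined_loss are hypothetical.

```python
import math


def dann_lambda(p, gamma=10.0):
    """DANN schedule: exactly 0 at p = 0, close to 1 at p = 1 for gamma = 10."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0


def combined_loss(task_loss, adapt_loss, progress, warmup, lambda_init=1.0, gamma=10.0):
    """Task loss only during warmup, then L = L_c + lambda * L_a."""
    if warmup:
        return task_loss  # lambda = 0 during the warmup epochs
    # assumption: lambda_init scales the scheduled value
    return task_loss + lambda_init * dann_lambda(progress, gamma) * adapt_loss
```

With gamma = 10 the weight is already about 0.987 halfway through training, so most of the ramp-up happens early.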
- Parameters
dataset (kale.loaddata.multi_domain) – the multi-domain datasets to be used for train, validation, and tests.
feature_extractor (torch.nn.Module) – the feature extractor network, mapping inputs \(x\in\mathcal{X}\) to a latent space \(\mathcal{Z}\).
task_classifier (torch.nn.Module) – the task classifier network that learns to predict labels \(y \in \mathcal{Y}\) from latent vectors.
method (Method, optional) – the method implemented by the class. Defaults to None. Mostly useful when several methods may be implemented using the same class.
lambda_init (float, optional) – Weight attributed to the adaptation part of the loss. Defaults to 1.0.
adapt_lambda (bool, optional) – Whether to make lambda grow from 0 to 1 following the schedule from the DANN paper. Defaults to True.
adapt_lr (bool, optional) – Whether to use the schedule for the learning rate as defined in the DANN paper. Defaults to True.
nb_init_epochs (int, optional) – Number of warmup epochs (during which lambda=0, training only on the source). Defaults to 10.
nb_adapt_epochs (int, optional) – Number of training epochs. Defaults to 50.
batch_size (int, optional) – Batch size. Defaults to 32.
init_lr (float, optional) – Initial learning rate. Defaults to 1e-3.
optimizer (dict, optional) – Optimizer parameters, a dictionary with two keys: “type”, a string in (“SGD”, “Adam”, “AdamW”); and “optim_params”, the keyword arguments for the chosen PyTorch optimizer. Defaults to None.
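An example value for the optimizer argument with the documented two-key shape, plus a validator that checks that shape. The name check_optimizer_config is hypothetical and not part of the library. The optim_params shown are ordinary torch.optim.SGD keyword arguments with arbitrary values.

```python
ALLOWED_OPTIMIZERS = ("SGD", "Adam", "AdamW")


def check_optimizer_config(optimizer):
    """Validate an optimizer dict of the documented shape (hypothetical helper)."""
    if optimizer is None:
        return  # no custom optimizer configuration supplied
    if set(optimizer) != {"type", "optim_params"}:
        raise ValueError("expected exactly the keys 'type' and 'optim_params'")
    if optimizer["type"] not in ALLOWED_OPTIMIZERS:
        raise ValueError(f"unsupported optimizer type: {optimizer['type']!r}")
    if not isinstance(optimizer["optim_params"], dict):
        raise TypeError("'optim_params' must be a dict of optimizer keyword arguments")


# keyword arguments accepted by torch.optim.SGD; the values are illustrative
sgd_config = {"type": "SGD", "optim_params": {"momentum": 0.9, "weight_decay": 0.0005, "nesterov": True}}
check_optimizer_config(sgd_config)
```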
- property method
- get_parameters_watch_list()
Update this list with the parameters to watch while training (i.e., to log with MLflow).
- forward(x)
- compute_loss(batch, split_name='valid')
Define the loss of the model
- Parameters
batch (tuple) – batches returned by the MultiDomainLoader.
split_name (str, optional) – learning stage (one of [“train”, “valid”, “test”]). Defaults to “valid” for validation. “train” is for training and “test” for testing. This is currently used only for naming the metrics used for logging.
- Returns
a 3-element tuple with task_loss, adv_loss, log_metrics. log_metrics should be a dictionary.
- Raises
NotImplementedError – children of this class should implement this method.
- training_step(batch, batch_nb)
The most generic of training steps
- Parameters
batch (tuple) – the batch as returned by the MultiDomainLoader dataloader iterator: 2 tuples, (x_source, y_source) and (x_target, y_target), in the unsupervised setting; 3 tuples, (x_source, y_source), (x_target_labeled, y_target_labeled) and (x_target_unlabeled, y_target_unlabeled), in the semi-supervised setting.
batch_nb (int) – id of the current batch.
- Returns
A dictionary that must contain a “loss” key with the loss to be used for back-propagation. See PyTorch Lightning for more details.
- Return type
dict
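To make the two batch layouts concrete, a hypothetical helper (split_domain_batch is not a library function) that names the parts of such a batch. It only inspects the tuple structure, so plain Python values stand in for tensors.

```python
def split_domain_batch(batch):
    """Name the (x, y) pairs of a MultiDomainLoader-style batch (hypothetical helper)."""
    if len(batch) == 2:
        # unsupervised: (x_source, y_source), (x_target, y_target)
        source, target = batch
        return {"source": source, "target": target}
    if len(batch) == 3:
        # semi-supervised: source, labeled target, unlabeled target
        source, target_labeled, target_unlabeled = batch
        return {
            "source": source,
            "target_labeled": target_labeled,
            "target_unlabeled": target_unlabeled,
        }
    raise ValueError(f"expected 2 or 3 (x, y) tuples, got {len(batch)}")
```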
- validation_step(batch, batch_nb)
- validation_epoch_end(outputs)
- test_step(batch, batch_nb)
- test_epoch_end(outputs)
- configure_optimizers()
- train_dataloader()
- val_dataloader()
- test_dataloader()
- training: bool
- class kale.pipeline.domain_adapter.BaseDANNLike(dataset, feature_extractor, task_classifier, critic, alpha=1.0, entropy_reg=0.0, adapt_reg=True, batch_reweighting=False, **base_params)
Bases:
BaseAdaptTrainer
Common API for DANN-based methods: DANN, CDAN, CDAN+E, WDGRL, MME, FSDANN
- get_parameters_watch_list()
Update this list with the parameters to watch while training (i.e., to log with MLflow).
- compute_loss(batch, split_name='valid')
- validation_epoch_end(outputs)
- test_epoch_end(outputs)
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.DANNTrainer(dataset, feature_extractor, task_classifier, critic, method=None, **base_params)
Bases:
BaseDANNLike
This class implements the DANN architecture from Ganin, Yaroslav, et al. “Domain-adversarial training of neural networks.” The Journal of Machine Learning Research (2016) https://arxiv.org/abs/1505.07818
- forward(x)
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.CDANTrainer(dataset, feature_extractor, task_classifier, critic, use_entropy=False, use_random=False, random_dim=1024, **base_params)
Bases:
BaseDANNLike
Implements CDAN: Long, Mingsheng, et al. “Conditional adversarial domain adaptation.” Advances in Neural Information Processing Systems. 2018. https://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation.pdf
- forward(x)
- compute_loss(batch, split_name='valid')
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.WDGRLTrainer(dataset, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)
Bases:
BaseDANNLike
Implements WDGRL as described in Shen, Jian, et al. “Wasserstein distance guided representation learning for domain adaptation.” Thirty-Second AAAI Conference on Artificial Intelligence. 2018. https://arxiv.org/pdf/1707.01217.pdf
This class also implements the asymmetric (\(\beta\)) variant described in: Wu, Yifan, et al. “Domain adaptation with asymmetrically-relaxed distribution alignment.” ICML (2019) https://arxiv.org/pdf/1903.01689.pdf
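For reference, the WDGRL paper cited above trains a domain critic \(f_w\) on representations \(h = f_g(x)\) using an empirical Wasserstein estimate and a gradient penalty. The formulation from the paper is restated below; how this class's gamma and k_critic map onto it is an assumption (gamma plausibly being the penalty weight \(\gamma\), and k_critic the number of critic updates per training step).

\[
\mathcal{L}_{wd} = \frac{1}{n_s}\sum_{x^s \in X^s} f_w\big(f_g(x^s)\big) - \frac{1}{n_t}\sum_{x^t \in X^t} f_w\big(f_g(x^t)\big)
\]
\[
\mathcal{L}_{grad} = \big(\lVert \nabla_{\hat{h}} f_w(\hat{h}) \rVert_2 - 1\big)^2
\]

Here \(\hat{h}\) is sampled on straight lines between pairs of source and target representations. The critic solves \(\max_{\theta_w}\{\mathcal{L}_{wd} - \gamma\,\mathcal{L}_{grad}\}\), after which the feature extractor and task classifier solve \(\min_{\theta_g,\theta_c}\{\mathcal{L}_c + \lambda\,\mathcal{L}_{wd}\}\).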
- forward(x)
- compute_loss(batch, split_name='valid')
- critic_update_steps(batch)
- training_step(batch, batch_id)
- configure_optimizers()
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.WDGRLTrainerMod(dataset, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)
Bases:
WDGRLTrainer
Implements a modified version of WDGRL as described in Shen, Jian, et al. “Wasserstein distance guided representation learning for domain adaptation.” Thirty-Second AAAI Conference on Artificial Intelligence. 2018. https://arxiv.org/pdf/1707.01217.pdf
This class also implements the asymmetric (\(\beta\)) variant described in: Wu, Yifan, et al. “Domain adaptation with asymmetrically-relaxed distribution alignment.” ICML (2019) https://arxiv.org/pdf/1903.01689.pdf
- critic_update_steps(batch)
- training_step(batch, batch_id, optimizer_idx)
- optimizer_step(current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False)
- configure_optimizers()
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.FewShotDANNTrainer(dataset, feature_extractor, task_classifier, critic, method, **base_params)
Bases:
BaseDANNLike
Implements adaptations of DANN to the semi-supervised setting:
- naive: the task classifier is trained on labeled target data in addition to source data.
- MME: implements Saito, Kuniaki, et al. “Semi-supervised domain adaptation via minimax entropy.” Proceedings of the IEEE International Conference on Computer Vision. 2019 https://arxiv.org/pdf/1904.06487.pdf
- forward(x)
- compute_loss(batch, split_name='valid')
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.BaseMMDLike(dataset, feature_extractor, task_classifier, kernel_mul=2.0, kernel_num=5, **base_params)
Bases:
BaseAdaptTrainer
Common API for MMD-based deep learning DA methods: DAN, JAN
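Both methods penalise a maximum mean discrepancy (MMD) between source and target features. A NumPy sketch of the multi-kernel Gaussian MMD commonly used in DAN-style code follows. It assumes that kernel_mul is the ratio between consecutive bandwidths and kernel_num the number of kernels, which is how these names are typically used. It is an illustration, not this class's compute_loss, and it requires equally sized source and target batches.

```python
import numpy as np


def gaussian_kernel_matrix(source, target, kernel_mul=2.0, kernel_num=5):
    """Sum of kernel_num Gaussian kernels over all pairs of source and target rows."""
    total = np.concatenate([source, target], axis=0)
    n = total.shape[0]
    # pairwise squared Euclidean distances, shape (n, n)
    diff = total[:, None, :] - total[None, :, :]
    l2 = np.sum(diff ** 2, axis=-1)
    # base bandwidth: mean off-diagonal squared distance, shifted so the
    # kernel_num bandwidths are spread geometrically around it
    bandwidth = l2.sum() / (n * n - n)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidths = [bandwidth * kernel_mul ** i for i in range(kernel_num)]
    return sum(np.exp(-l2 / bw) for bw in bandwidths)


def mmd(source, target, kernel_mul=2.0, kernel_num=5):
    """Biased estimate of squared MMD between two equally sized batches."""
    b = source.shape[0]
    k = gaussian_kernel_matrix(source, target, kernel_mul, kernel_num)
    xx, yy = k[:b, :b], k[b:, b:]
    xy, yx = k[:b, b:], k[b:, :b]
    return float(np.mean(xx + yy - xy - yx))
```

Identical batches give exactly zero, and a mean shift between domains increases the value.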
- forward(x)
- compute_loss(batch, split_name='valid')
- validation_epoch_end(outputs)
- test_epoch_end(outputs)
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.DANTrainer(dataset, feature_extractor, task_classifier, **base_params)
Bases:
BaseMMDLike
This is an implementation of DAN: Long, Mingsheng, et al. “Learning Transferable Features with Deep Adaptation Networks.” International Conference on Machine Learning. 2015. http://proceedings.mlr.press/v37/long15.pdf. The code is based on https://github.com/thuml/Xlearn.
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.domain_adapter.JANTrainer(dataset, feature_extractor, task_classifier, kernel_mul=(2.0, 2.0), kernel_num=(5, 1), **base_params)
Bases:
BaseMMDLike
This is an implementation of JAN: Long, Mingsheng, et al. “Deep transfer learning with joint adaptation networks.” International Conference on Machine Learning, 2017. https://arxiv.org/pdf/1605.06636.pdf. The code is based on https://github.com/thuml/Xlearn.
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
kale.pipeline.mpca_trainer module
Implementation of MPCA->Feature Selection->Linear SVM/LogisticRegression Pipeline
References
[1] Swift, A. J., Lu, H., Uthoff, J., Garg, P., Cogliano, M., Taylor, J., … & Kiely, D. G. (2020). A machine learning cardiac magnetic resonance approach to extract disease features and automate pulmonary arterial hypertension diagnosis. European Heart Journal-Cardiovascular Imaging.
[2] Song, X., Meng, L., Shi, Q., & Lu, H. (2015, October). Learning tensor-based features for whole-brain fMRI classification. In International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 613-620). Springer, Cham.
[3] Lu, H., Plataniotis, K. N., & Venetsanopoulos, A. N. (2008). MPCA: Multilinear principal component analysis of tensor objects. IEEE Transactions on Neural Networks, 19(1), 18-39.
- class kale.pipeline.mpca_trainer.MPCATrainer(classifier='svc', classifier_params='auto', classifier_param_grid=None, mpca_params=None, n_features=None, search_params=None)
Bases:
BaseEstimator, ClassifierMixin
Trainer of pipeline: MPCA->Feature selection->Classifier
- Parameters
classifier (str, optional) – Available classifier options: {“svc”, “linear_svc”, “lr”}, where “svc” trains a support vector classifier that supports both linear and non-linear kernels and optimizes with the library “libsvm”; “linear_svc” trains a support vector classifier with a linear kernel only and optimizes with the library “liblinear”, which is supposed to be faster and better at handling a large number of samples; and “lr” trains a logistic regression classifier. Defaults to “svc”.
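classifier_params (dict or str, optional) – Parameters of the classifier. Defaults to ‘auto’, in which case the hyper-parameters are selected by grid search over classifier_param_grid.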
classifier_param_grid (dict, optional) – Grids for searching the optimal hyper-parameters. Works only when classifier_params == “auto”. Defaults to None, in which case the following hyper-parameter values are searched: 1. svc, {“kernel”: [“linear”], “C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100], “max_iter”: [50000]}, 2. linear_svc, {“C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]}, 3. lr, {“C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]}
mpca_params (dict, optional) – Parameters of MPCA, e.g., {“var_ratio”: 0.8}. Defaults to None, i.e., using the default parameters (https://pykale.readthedocs.io/en/latest/kale.embed.html#module-kale.embed.mpca).
n_features (int, optional) – Number of features for feature selection. Defaults to None, i.e., all features after dimension reduction will be used.
search_params (dict, optional) – Parameters of grid search; for more details, please see https://scikit-learn.org/stable/modules/grid_search.html#grid-search. Defaults to None, i.e., using the default parameters: {“cv”: 5}.
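The first stage of this pipeline reduces each sample tensor mode by mode. As a rough illustration, the NumPy sketch below implements only the full-projection-truncation initialisation of MPCA from reference [3]: for each mode \(n\), it eigendecomposes the mode-\(n\) scatter matrix and keeps the leading eigenvectors. It omits the iterative refinement of the full algorithm. Choosing \(P_n\) by a cumulative-variance threshold is an assumption about how var_ratio is interpreted, and the function names are hypothetical.

```python
import numpy as np


def mode_n_unfold(sample, mode):
    """Matricise an N-way array so that its rows index the given mode."""
    return np.moveaxis(sample, mode, 0).reshape(sample.shape[mode], -1)


def mpca_init(x, var_ratio=0.97):
    """One-pass MPCA projection (full-projection truncation, no iterations).

    x has shape (n_samples, I_1, ..., I_N). Returns the projection matrices
    U_n of shape (I_n, P_n) and the projected data (n_samples, P_1, ..., P_N).
    """
    x = x - x.mean(axis=0)  # centre the samples
    projections = []
    for mode in range(x.ndim - 1):
        dim = x.shape[mode + 1]
        # mode-n total scatter matrix summed over all samples
        scatter = np.zeros((dim, dim))
        for sample in x:
            unfolded = mode_n_unfold(sample, mode)
            scatter += unfolded @ unfolded.T
        eigvals, eigvecs = np.linalg.eigh(scatter)
        order = np.argsort(eigvals)[::-1]  # largest eigenvalues first
        eigvals = np.clip(eigvals[order], 0.0, None)
        eigvecs = eigvecs[:, order]
        # smallest P_n whose eigenvalues keep at least var_ratio of the mode-n variance
        cumulative = np.cumsum(eigvals) / eigvals.sum()
        n_keep = min(int(np.searchsorted(cumulative, var_ratio)) + 1, dim)
        projections.append(eigvecs[:, :n_keep])
    projected = x
    for mode, u in enumerate(projections):
        # multiply along axis mode + 1 by U_n^T, then restore the axis order
        projected = np.tensordot(projected, u, axes=([mode + 1], [0]))
        projected = np.moveaxis(projected, -1, mode + 1)
    return projections, projected
```

The projected tensors would then be flattened into feature vectors for the feature-selection and classifier stages.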
- fit(x, y)
Fit a pipeline with the given data x and labels y
- Parameters
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
y (array-like) – data labels, shape (n_samples, )
- Returns
self
- predict(x)
Predict the labels for the given data x
- Parameters
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
- Returns
Predicted labels, shape (n_samples, )
- Return type
array-like
- decision_function(x)
Decision scores of each class for the given data x
- Parameters
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
- Returns
decision scores, shape (n_samples,) for binary case, else (n_samples, n_class)
- Return type
array-like
- predict_proba(x)
Probability of each class for the given data x. Not supported by “linear_svc”.
- Parameters
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
- Returns
probabilities, shape (n_samples, n_class)
- Return type
array-like
kale.pipeline.multi_domain_adapter module
Multi-source domain adaptation pipelines
- kale.pipeline.multi_domain_adapter.create_ms_adapt_trainer(method: str, dataset, feature_extractor, task_classifier, **train_params)
Methods for multi-source domain adaptation
- Parameters
method (str) – Multi-source domain adaptation method, M3SDA or MFSAN
dataset (kale.loaddata.multi_domain.MultiDomainAdapDataset) – the multi-domain datasets to be used for train, validation, and tests.
feature_extractor (torch.nn.Module) – feature extractor network
task_classifier (torch.nn.Module) – task classifier network
- Returns
Multi-source domain adaptation trainer.
- Return type
[pl.LightningModule]
- class kale.pipeline.multi_domain_adapter.BaseMultiSourceTrainer(dataset, feature_extractor, task_classifier, n_classes: int, target_domain: str, **base_params)
Bases:
BaseAdaptTrainer
Base class for all multi-source domain adaptation architectures.
- Parameters
dataset (kale.loaddata.multi_domain) – the multi-domain datasets to be used for train, validation, and tests.
feature_extractor (torch.nn.Module) – the feature extractor network
task_classifier (torch.nn.Module) – the task classifier network
n_classes (int) – number of classes
target_domain (str) – target domain name
- forward(x)
- compute_loss(batch, split_name='valid')
- validation_epoch_end(outputs)
- test_epoch_end(outputs)
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.multi_domain_adapter.M3SDATrainer(dataset, feature_extractor, task_classifier, n_classes: int, target_domain: str, k_moment: int = 3, **base_params)
Bases:
BaseMultiSourceTrainer
Moment matching for multi-source domain adaptation (M3SDA).
- Reference:
Peng, X., Bai, Q., Xia, X., Huang, Z., Saenko, K., & Wang, B. (2019). Moment matching for multi-source domain adaptation. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 1406-1415). https://openaccess.thecvf.com/content_ICCV_2019/html/Peng_Moment_Matching_for_Multi-Source_Domain_Adaptation_ICCV_2019_paper.html
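The moment matching idea can be sketched following the moment distance defined in the M3SDA paper. For each order \(k\) up to k_moment, the \(k\)-th order moments of each source domain are aligned with those of the target and with those of every other source. This NumPy version works on raw feature batches and is a hypothetical illustration, not this class's compute_loss.

```python
import itertools

import numpy as np


def moment_distance(sources, target, k_moment=3):
    """Moment distance between several source feature batches and a target batch.

    sources: list of arrays of shape (batch, dim); target: array of shape (batch, dim).
    """
    total = 0.0
    n_src = len(sources)
    pairs = list(itertools.combinations(range(n_src), 2))
    for k in range(1, k_moment + 1):
        src_moments = [np.mean(s ** k, axis=0) for s in sources]
        tgt_moment = np.mean(target ** k, axis=0)
        # source-to-target alignment, averaged over the source domains
        total += sum(np.linalg.norm(m - tgt_moment) for m in src_moments) / n_src
        # source-to-source alignment, averaged over all pairs of sources
        if pairs:
            total += sum(np.linalg.norm(src_moments[i] - src_moments[j]) for i, j in pairs) / len(pairs)
    return float(total)
```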
- compute_loss(batch, split_name='valid')
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.multi_domain_adapter.MFSANTrainer(dataset, feature_extractor, task_classifier, n_classes: int, target_domain: str, domain_feat_dim: int = 100, kernel_mul: float = 2.0, kernel_num: int = 5, input_dimension: int = 2, **base_params)
Bases:
BaseMultiSourceTrainer
Multiple Feature Spaces Adaptation Network (MFSAN)
- Reference:
Zhu, Y., Zhuang, F. and Wang, D., 2019, July. Aligning domain-specific distribution and classifier for cross-domain classification from multiple sources. In AAAI. https://ojs.aaai.org/index.php/AAAI/article/view/4551
Original implementation: https://github.com/easezyc/deep-transfer-learning/tree/master/MUDA/MFSAN
- compute_loss(batch, split_name='valid')
- cls_discrepancy(x)
Compute discrepancy between all classifiers’ probabilistic outputs
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
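cls_discrepancy measures how much the source-specific classifiers disagree. Below is a minimal NumPy sketch that assumes the mean absolute (L1) difference between softmax outputs, as in the original MFSAN implementation, averaged over all pairs of classifiers. The function names are hypothetical, and this class's own reduction may differ.

```python
import itertools

import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def classifier_discrepancy(logits_per_classifier):
    """Mean pairwise L1 gap between the classifiers' probabilistic outputs."""
    probs = [softmax(z) for z in logits_per_classifier]
    pairs = list(itertools.combinations(range(len(probs)), 2))  # needs >= 2 classifiers
    return float(sum(np.mean(np.abs(probs[i] - probs[j])) for i, j in pairs) / len(pairs))
```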
- class kale.pipeline.multi_domain_adapter.CoIRLS(kernel='linear', kernel_kwargs=None, alpha=1.0, lambda_=1.0)
Bases:
BaseEstimator, ClassifierMixin
Covariate-Independence Regularized Least Squares (CoIRLS)
- Parameters
kernel (str, optional) – {“linear”, “rbf”, “poly”}. Kernel to use. Defaults to “linear”.
kernel_kwargs (dict or None, optional) – Hyperparameters for the kernel. Defaults to None.
alpha (float, optional) – Hyperparameter of the l2 (Ridge) penalty. Defaults to 1.0.
lambda_ (float, optional) – Hyperparameter of the covariate dependence. Defaults to 1.0.
- Reference:
- [1] Zhou, S., 2022. Interpretable Domain-Aware Learning for Neuroimage Classification (Doctoral dissertation, University of Sheffield).
- [2] Zhou, S., Li, W., Cox, C.R., & Lu, H. (2020). Side Information Dependence as a Regularizer for Analyzing Human Brain Conditions across Cognitive Experiments. AAAI 2020, New York, USA.
- fit(x, y, covariates)
Fit a model with input data, labels, and covariates.
- Parameters
x (np.ndarray or tensor) – shape (n_samples, n_features)
y (np.ndarray or tensor) – shape (n_samples, )
covariates (np.ndarray or tensor) – shape (n_samples, n_covariates)
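The exact objective that CoIRLS optimises is given in the references above. As a rough illustration only, the sketch below assumes a fully labeled, linear-kernel simplification: squared loss, plus an l2 (ridge) penalty weighted by alpha, plus a covariate-dependence penalty weighted by lambda_ and measured with a linear-kernel HSIC-style term. With the centring matrix \(H\) and \(K_c = CC^\top\), that term is \(w^\top X^\top H K_c H X w/(n-1)^2\), which equals the sum of squared sample covariances between \(Xw\) and each covariate. The minimiser therefore has the closed form used below. The function names are hypothetical, and labels are assumed to be in {-1, 1}.

```python
import numpy as np


def fit_linear_coirls(x, y, covariates, alpha=1.0, lambda_=1.0):
    """Closed-form weights for a linear, fully labeled CoIRLS-style objective.

    Minimises (1/n)||y - Xw||^2 + alpha ||w||^2 + lambda_ * w^T D w, where
    w^T D w is the sum of squared sample covariances between Xw and each covariate.
    """
    n, d = x.shape
    h = np.eye(n) - np.ones((n, n)) / n  # centring matrix H
    k_c = covariates @ covariates.T      # linear kernel on the covariates
    dep = x.T @ h @ k_c @ h @ x / (n - 1) ** 2
    a = x.T @ x / n + alpha * np.eye(d) + lambda_ * dep
    return np.linalg.solve(a, x.T @ y / n)


def predict_linear_coirls(x, w):
    """Labels in {-1, 1} from the sign of the decision scores x @ w."""
    return np.where(x @ w >= 0, 1, -1)
```

Setting lambda_=0 recovers plain ridge regression on the labels; increasing it trades accuracy for decision scores that are less dependent on the covariates.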
- predict(x)
Predict labels for data x
- Parameters
x (np.ndarray or tensor) – Samples to predict, shape (n_samples, n_features)
- Returns
Predicted labels, shape (n_samples, )
- Return type
y (np.ndarray)
- decision_function(x)
Compute decision scores for data x
- Parameters
x (np.ndarray or tensor) – Samples to compute decision scores for, shape (n_samples, n_features)
- Returns
Decision scores, shape (n_samples, )
- Return type
scores (np.ndarray)
kale.pipeline.video_domain_adapter module
Domain adaptation systems (pipelines) for video data, e.g., for action recognition. Most are inherited from kale.pipeline.domain_adapter.
- kale.pipeline.video_domain_adapter.create_mmd_based_video(method: Method, dataset, image_modality, feature_extractor, task_classifier, **train_params)
MMD-based deep learning methods for domain adaptation on video data: DAN and JAN
- kale.pipeline.video_domain_adapter.create_dann_like_video(method: Method, dataset, image_modality, feature_extractor, task_classifier, critic, **train_params)
DANN-based deep learning methods for domain adaptation on video data: DANN, CDAN, CDAN+E
- class kale.pipeline.video_domain_adapter.BaseMMDLikeVideo(dataset, image_modality, feature_extractor, task_classifier, kernel_mul=2.0, kernel_num=5, **base_params)
Bases:
BaseMMDLike
Common API for MMD-based domain adaptation on video data: DAN, JAN
- forward(x)
- compute_loss(batch, split_name='valid')
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.video_domain_adapter.DANTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, **base_params)
Bases:
BaseMMDLikeVideo
This is an implementation of DAN for video data.
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.video_domain_adapter.JANTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, kernel_mul=(2.0, 2.0), kernel_num=(5, 1), **base_params)
Bases:
BaseMMDLikeVideo
This is an implementation of JAN for video data.
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.video_domain_adapter.DANNTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, critic, method, **base_params)
Bases:
DANNTrainer
This is an implementation of DANN for video data.
- forward(x)
- compute_loss(batch, split_name='valid')
- training_step(batch, batch_nb)
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.video_domain_adapter.CDANTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, critic, use_entropy=False, use_random=False, random_dim=1024, **base_params)
Bases:
CDANTrainer
This is an implementation of CDAN for video data.
- forward(x)
- compute_loss(batch, split_name='valid')
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- class kale.pipeline.video_domain_adapter.WDGRLTrainerVideo(dataset, image_modality, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)
Bases:
WDGRLTrainer
This is an implementation of WDGRL for video data.
- forward(x)
- compute_loss(batch, split_name='valid')
- configure_optimizers()
- critic_update_steps(batch)
- training: bool
- precision: int
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool