Pipeline
Submodules
kale.pipeline.base_nn_trainer module
Classification systems (pipelines)
This module provides neural network (NN) trainers for developing classification task models. BaseNNTrainer defines the fundamental functions and structures required, such as the optimizer, learning rate scheduler, and the training/validation/testing workflow. Specialized trainers are constructed by inheriting from BaseNNTrainer.
The structure and workflow of BaseNNTrainer are consistent with kale.pipeline.domain_adapter.BaseAdaptTrainer.
This module uses PyTorch Lightning to standardize the flow.
This module also provides a multimodal neural network trainer (MultimodalNNTrainer), which uses a separate encoder for each modality, a fusion technique to combine the modalities, and a classifier head for the final prediction. MultimodalNNTrainer handles the training, validation, and testing steps for multimodal data using the specified models, optimization algorithm, and loss function. Adapted from: https://github.com/pliang279/MultiBench/blob/main/training_structures/Supervised_Learning.py
- class kale.pipeline.base_nn_trainer.BaseNNTrainer(optimizer, max_epochs, init_lr=0.001, adapt_lr=False)
Bases:
LightningModule
Base class for neural network classification models, built on the PyTorch Lightning wrapper. The forward pass and loss computation must be implemented by new trainers inheriting from this class. The basic workflow defined in this class is as follows: every training/validation/testing procedure calls compute_loss() to compute the loss and log the output metrics, and compute_loss() calls forward() to generate the output features using the neural networks.
- Parameters:
optimizer (dict, None) – optimizer parameters.
max_epochs (int) – maximum number of epochs.
init_lr (float) – initial learning rate. Defaults to 0.001.
adapt_lr (bool) – whether to use the schedule for the learning rate. Defaults to False.
- forward(x)
Override this function to define the forward pass, which normally includes feature extraction and classification and is called in compute_loss().
- compute_loss(batch, split_name='valid')
Compute loss for a given batch.
- Parameters:
batch (tuple) – batches returned by dataloader.
split_name (str, optional) – learning stage (one of [“train”, “valid”, “test”]). Defaults to “valid” for validation. “train” is for training and “test” for testing. This is currently used only for naming the metrics used for logging.
- Returns:
loss value. log_metrics (dict): dictionary of metrics to be logged. This is needed when using PyKale logging, but not mandatory when using PyTorch Lightning logging.
- Return type:
loss (torch.Tensor)
- configure_optimizers()
Default optimizer configuration. Adam is set as the default, and SGD with cosine annealing is also provided. If other optimizers are needed, please override this function.
- training_step(train_batch, batch_idx) Tensor
Compute and return the training loss and metrics on one step. loss stores the loss value; log_metrics stores the metrics to be logged, including the loss and top-1 and/or top-5 accuracies.
Use self.log_dict(log_metrics, on_step, on_epoch, logger) to log the metrics on each step and each epoch. For training, log on each step and each epoch. For validation and testing, only log on each epoch. This avoids the need for on_training_epoch_end() and on_validation_epoch_end().
- validation_step(valid_batch, batch_idx) None
Compute and return the validation loss and metrics on one step.
- test_step(test_batch, batch_idx) None
Compute and return the testing loss and metrics on one step.
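For illustration, a minimal sketch (not part of PyKale) of a specialized trainer inheriting from BaseNNTrainer. It assumes a simple (x, y) classification batch and follows the compute_loss() contract described above; the class and attribute names are hypothetical:

    import torch.nn.functional as F

    from kale.pipeline.base_nn_trainer import BaseNNTrainer


    class SimpleClassifierTrainer(BaseNNTrainer):
        """Hypothetical trainer: a feature extractor followed by a classifier head."""

        def __init__(self, feature_extractor, classifier, **kwargs):
            super().__init__(**kwargs)  # optimizer, max_epochs, init_lr, adapt_lr
            self.feature_extractor = feature_extractor
            self.classifier = classifier

        def forward(self, x):
            # Feature extraction followed by classification.
            return self.classifier(self.feature_extractor(x))

        def compute_loss(self, batch, split_name="valid"):
            x, y = batch
            y_hat = self.forward(x)
            loss = F.cross_entropy(y_hat, y)
            acc = (y_hat.argmax(dim=1) == y).float().mean()
            # Metric names are prefixed by the split for logging.
            log_metrics = {f"{split_name}_loss": loss, f"{split_name}_acc": acc}
            return loss, log_metrics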
- class kale.pipeline.base_nn_trainer.CNNTransformerTrainer(feature_extractor, task_classifier, lr_milestones, lr_gamma, **kwargs)
Bases:
BaseNNTrainer
PyTorch Lightning trainer for the CNN-Transformer model.
- Parameters:
feature_extractor (torch.nn.Sequential, optional) – the feature extractor network.
task_classifier (torch.nn.Module) – the task classifier network.
optimizer (dict) – optimizer parameters.
lr_milestones (list) – list of epoch indices. Must be increasing.
lr_gamma (float) – multiplicative factor of learning rate decay.
- forward(x)
Forward pass for the model with a feature extractor and a classifier.
- compute_loss(batch, split_name='valid')
Compute loss, top1 and top5 accuracy for a given batch.
- configure_optimizers()
Set up an SGD optimizer and a multistep learning rate scheduler. When self._adapt_lr is True, the learning rate will be multiplied by self.lr_gamma at each milestone epoch in lr_milestones.
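A hedged usage sketch: the backbone and classifier below are toy placeholders, and the structure of the optimizer dict is an assumption:

    import pytorch_lightning as pl
    import torch.nn as nn

    from kale.pipeline.base_nn_trainer import CNNTransformerTrainer

    # Toy placeholder networks; a real use would plug in a CNN-Transformer backbone.
    feature_extractor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.ReLU())
    task_classifier = nn.Linear(128, 10)

    model = CNNTransformerTrainer(
        feature_extractor=feature_extractor,
        task_classifier=task_classifier,
        lr_milestones=[30, 60],     # epochs at which the learning rate decays
        lr_gamma=0.1,               # multiplicative decay factor
        optimizer={"type": "SGD"},  # assumed dict structure
        max_epochs=90,
        init_lr=0.01,
        adapt_lr=True,              # enable the multistep LR schedule
    )

    trainer = pl.Trainer(max_epochs=90)
    # trainer.fit(model, train_dataloader, valid_dataloader)  # dataloaders prepared elsewhere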
- class kale.pipeline.base_nn_trainer.MultimodalNNTrainer(encoders, fusion, head, variable_length_sequences=False, optim=<class 'torch.optim.sgd.SGD'>, lr=0.001, weight_decay=0.0, objective=CrossEntropyLoss())
Bases:
LightningModule
MultimodalNNTrainer serves as a PyTorch Lightning trainer for multimodal models. It is designed to handle training, validation, and testing steps for multimodal data using specified models, optimization algorithms, and loss functions.
For each training, validation, and test step, the trainer class computes the model’s loss and accuracy and logs these metrics. This trainer simplifies the process of training complex multimodal models, allowing the user to focus on model architecture and hyperparameter tuning. This trainer is flexible and can be used with various models, optimizers, and loss functions, enabling its use across a wide range of multimodal learning tasks.
- Parameters:
encoders (List[nn.Module]) – A list of PyTorch nn.Module encoders, with one encoder per modality. Each encoder maps a single modality to a high-level representation.
fusion (nn.Module) – A PyTorch nn.Module that merges the high-level representations from each modality into a single combined representation.
head (nn.Module) – A PyTorch nn.Module that takes the fused representation and outputs a class prediction.
is_packed (bool, optional) – whether the input modalities are packed in one list or not. Defaults to False, which means the input is expected as one tensor per modality.
optim (torch.optim, optional) – The optimization algorithm to use. Defaults to torch.optim.SGD.
lr (float, optional) – Learning rate for the optimizer. Defaults to 0.001.
weight_decay (float, optional) – Weight decay for the optimizer. Defaults to 0.0.
objective (torch.nn.Module, optional) – Loss function. Defaults to torch.nn.CrossEntropyLoss.
- forward(inputs)
- compute_loss(batch, split_name='valid')
- configure_optimizers()
- training_step(train_batch, batch_idx)
- validation_step(valid_batch, batch_idx)
- test_step(test_batch, batch_idx)
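A hedged sketch of assembling a two-modality classifier. The fusion interface shown (a module that receives the list of per-modality features) is an assumption based on the description above, and the layer sizes are arbitrary:

    import torch
    import torch.nn as nn

    from kale.pipeline.base_nn_trainer import MultimodalNNTrainer


    class ConcatFusion(nn.Module):
        """Toy fusion module: concatenate per-modality feature vectors."""

        def forward(self, features):
            return torch.cat(features, dim=-1)


    encoders = [nn.Linear(20, 16), nn.Linear(30, 16)]  # one encoder per modality
    fusion = ConcatFusion()
    head = nn.Linear(16 + 16, 4)                       # fused features -> 4 classes

    model = MultimodalNNTrainer(
        encoders=encoders,
        fusion=fusion,
        head=head,
        optim=torch.optim.Adam,           # any torch.optim optimizer class
        lr=1e-3,
        objective=nn.CrossEntropyLoss(),
    )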
kale.pipeline.deepdta module
- class kale.pipeline.deepdta.BaseDTATrainer(drug_encoder, target_encoder, decoder, lr=0.001, ci_metric=False, **kwargs)
Bases:
LightningModule
Base class for all drug-target encoder-decoder architecture models, built on the PyTorch Lightning wrapper; for more details about PyTorch Lightning, please check https://github.com/PyTorchLightning/pytorch-lightning. If you inherit from this class, a forward pass function must be implemented.
- Parameters:
drug_encoder – drug information encoder.
target_encoder – target information encoder.
decoder – drug-target representations decoder.
lr – learning rate. (default: 0.001)
ci_metric – whether to calculate the Concordance Index (CI) metric; the operation is time-consuming for large-scale datasets. (default: False)
- configure_optimizers()
Configure Adam as the default optimizer.
- forward(x_drug, x_target)
Same as torch.nn.Module.forward().
- training_step(train_batch, batch_idx)
Compute and return the training loss on one step.
- validation_step(valid_batch, batch_idx)
Compute and return the validation loss on one step.
- test_step(test_batch, batch_idx)
Compute and return the test loss on one step.
- class kale.pipeline.deepdta.DeepDTATrainer(drug_encoder, target_encoder, decoder, lr=0.001, ci_metric=False, **kwargs)
Bases:
BaseDTATrainer
An implementation of the DeepDTA model based on BaseDTATrainer.
- Parameters:
drug_encoder – drug CNN encoder.
target_encoder – target CNN encoder.
decoder – drug-target MLP decoder.
lr – learning rate.
- forward(x_drug, x_target)
Forward propagation in DeepDTA architecture.
- Parameters:
x_drug – drug sequence encoding.
x_target – target protein sequence encoding.
- validation_step(valid_batch, batch_idx)
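A hedged usage sketch: the encoders and decoder below are toy stand-ins for the CNN encoders and MLP decoder that PyKale provides elsewhere, and the decoder input size assumes the drug and target features are concatenated:

    import pytorch_lightning as pl
    import torch.nn as nn

    from kale.pipeline.deepdta import DeepDTATrainer

    # Toy stand-ins for the drug/target CNN encoders and the MLP decoder.
    drug_encoder = nn.Sequential(nn.Embedding(64, 32), nn.Flatten(), nn.LazyLinear(96))
    target_encoder = nn.Sequential(nn.Embedding(26, 32), nn.Flatten(), nn.LazyLinear(96))
    decoder = nn.Sequential(nn.Linear(96 + 96, 256), nn.ReLU(), nn.Linear(256, 1))

    model = DeepDTATrainer(drug_encoder, target_encoder, decoder, lr=1e-3, ci_metric=False)
    trainer = pl.Trainer(max_epochs=100)
    # trainer.fit(model, train_loader, valid_loader)  # dataloaders yielding drug/target
    #                                                 # encodings and binding affinities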
kale.pipeline.domain_adapter module
Domain adaptation systems (pipelines) with three types of architectures
This module takes individual modules as input and organises them into an architecture. This is taken directly from https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/models/architectures.py with minor changes.
This module uses PyTorch Lightning to standardize the flow.
- class kale.pipeline.domain_adapter.GradReverse(*args, **kwargs)
Bases:
Function
The gradient reversal layer (GRL)
This is defined in the DANN paper https://jmlr.org/papers/volume17/15-239/15-239.pdf
Forward pass: identity transformation. Backward propagation: flip the sign of the gradient.
From https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/models/layers.py
- static forward(ctx, x, alpha)
- static backward(ctx, grad_output)
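A short usage sketch of the GRL: the forward pass is the identity, while gradients flowing back through it are sign-flipped (scaled by alpha), which trains the feature extractor adversarially against a domain critic:

    import torch

    from kale.pipeline.domain_adapter import GradReverse

    features = torch.randn(8, 64, requires_grad=True)    # e.g. output of a feature extractor
    reversed_features = GradReverse.apply(features, 1.0)  # identity in the forward pass
    # domain_logits = critic(reversed_features)  # in backward, gradients reaching `features`
    #                                            # have their sign flipped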
- kale.pipeline.domain_adapter.set_requires_grad(model, requires_grad=True)
Configure whether gradients are required for a model
- class kale.pipeline.domain_adapter.Method(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases:
Enum
Lists the available methods. Provides a few methods that group the methods by type.
- Source = 'Source'
- DANN = 'DANN'
- CDAN = 'CDAN'
- CDAN_E = 'CDAN-E'
- FSDANN = 'FSDANN'
- MME = 'MME'
- WDGRL = 'WDGRL'
- WDGRLMod = 'WDGRLMod'
- DAN = 'DAN'
- JAN = 'JAN'
- is_mmd_method()
- is_dann_method()
- is_cdan_method()
- is_fewshot_method()
- allow_supervised()
- kale.pipeline.domain_adapter.create_mmd_based(method: Method, dataset, feature_extractor, task_classifier, **train_params)
MMD-based deep learning methods for domain adaptation: DAN and JAN
- kale.pipeline.domain_adapter.create_dann_like(method: Method, dataset, feature_extractor, task_classifier, critic, **train_params)
DANN-based deep learning methods for domain adaptation: DANN, CDAN, CDAN+E
- kale.pipeline.domain_adapter.create_fewshot_trainer(method: Method, dataset, feature_extractor, task_classifier, critic, **train_params)
DANN-based few-shot deep learning methods for domain adaptation: FSDANN, MME
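A hedged sketch of building a trainer through these factories, shown for create_dann_like with Method.DANN. The networks are toy placeholders, the dataset is assumed to be a prepared kale.loaddata.multi_domain.MultiDomainDatasets instance, and the optimizer dict follows the structure documented for BaseAdaptTrainer below:

    import torch.nn as nn

    from kale.pipeline.domain_adapter import Method, create_dann_like

    feature_extractor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.ReLU())
    task_classifier = nn.Linear(128, 10)
    critic = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 2))  # domain critic

    model = create_dann_like(
        Method.DANN,
        dataset,             # MultiDomainDatasets, prepared elsewhere
        feature_extractor,
        task_classifier,
        critic,
        nb_init_epochs=10,   # warmup epochs, source data only
        nb_adapt_epochs=50,
        init_lr=1e-3,
        optimizer={"type": "SGD", "optim_params": {"momentum": 0.9}},
    )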
- class kale.pipeline.domain_adapter.BaseAdaptTrainer(dataset, feature_extractor, task_classifier, method: str | None = None, lambda_init: float = 1.0, adapt_lambda: bool = True, adapt_lr: bool = True, nb_init_epochs: int = 10, nb_adapt_epochs: int = 50, batch_size: int = 32, num_workers: int = 1, init_lr: float = 0.001, optimizer: dict | None = None)
Bases:
LightningModule
Base class for all domain adaptation architectures.
This class implements the classic building blocks used in all the derived architectures for domain adaptation. If you inherit from this class, you will have to implement only:
a forward pass
a compute_loss function that returns the task loss \(\mathcal{L}_c\) and adaptation loss \(\mathcal{L}_a\), as well as a dictionary for summary statistics and other metrics you may want to have access to.
The default training step uses only the task loss \(\mathcal{L}_c\) during warmup, then uses the loss defined as:
\(\mathcal{L} = \mathcal{L}_c + \lambda \mathcal{L}_a\),
where \(\lambda\) will follow the schedule defined by the DANN paper:
\(\lambda_p = \frac{2}{1 + \exp{(-\gamma \cdot p)}} - 1\), where \(p\), the learning progress, changes linearly from 0 to 1.
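Written out as code, the schedule (with \(\gamma = 10\) as used in the DANN paper) is:

    import math

    def dann_lambda(p, gamma=10.0):
        """Adaptation weight as a function of the training progress p in [0, 1]."""
        return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

    # dann_lambda(0.0) == 0.0 at the start of adaptation; dann_lambda(1.0) is close to 1.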
- Parameters:
dataset (kale.loaddata.multi_domain.MultiDomainDatasets) – the multi-domain datasets to be used for training, validation, and testing.
feature_extractor (torch.nn.Module) – the feature extractor network (mapping inputs \(x\in\mathcal{X}\) to a latent space \(\mathcal{Z}\)).
task_classifier (torch.nn.Module) – the task classifier network that learns to predict labels \(y \in \mathcal{Y}\) from latent vectors.
method (Method, optional) – the method implemented by the class. Defaults to None. Mostly useful when several methods may be implemented using the same class.
lambda_init (float, optional) – weight attributed to the adaptation part of the loss. Defaults to 1.0.
adapt_lambda (bool, optional) – whether to make lambda grow from 0 to 1 following the schedule from the DANN paper. Defaults to True.
adapt_lr (bool, optional) – whether to use the schedule for the learning rate as defined in the DANN paper. Defaults to True.
nb_init_epochs (int, optional) – number of warmup epochs (during which lambda=0, training only on the source). Defaults to 10.
nb_adapt_epochs (int, optional) – number of training epochs. Defaults to 50.
batch_size (int, optional) – batch size. Defaults to 32.
num_workers (int, optional) – number of dataloader workers. Defaults to 1.
init_lr (float, optional) – initial learning rate. Defaults to 1e-3.
optimizer (dict, optional) – optimizer parameters, a dictionary with two keys: “type”, a string in (“SGD”, “Adam”, “AdamW”); and “optim_params”, kwargs for the chosen PyTorch optimizer. Defaults to None.
- property method
- forward(x)
- compute_loss(batch, split_name='valid')
Define the loss of the model
- Parameters:
batch (tuple) – batches returned by the MultiDomainLoader.
split_name (str, optional) – learning stage (one of [“train”, “valid”, “test”]). Defaults to “valid” for validation. “train” is for training and “test” for testing. This is currently used only for naming the metrics used for logging.
- Returns:
a 3-element tuple with task_loss, adv_loss, log_metrics. log_metrics should be a dictionary.
- Raises:
NotImplementedError – children of this class should implement this method.
- training_step(batch, batch_nb)
The most generic of training steps
- Parameters:
batch (tuple) – the batch as returned by the MultiDomainLoader dataloader iterator: 2 tuples ((x_source, y_source), (x_target, y_target)) in the unsupervised setting; 3 tuples ((x_source, y_source), (x_target_labeled, y_target_labeled), (x_target_unlabeled, y_target_unlabeled)) in the semi-supervised setting.
batch_nb (int) – id of the current batch.
- Returns:
- must contain a “loss” key with the loss to be used for back-propagation; see PyTorch Lightning for more details.
- Return type:
dict
- validation_step(batch, batch_nb)
- test_step(batch, batch_nb)
- configure_optimizers()
- train_dataloader()
- val_dataloader()
- test_dataloader()
- class kale.pipeline.domain_adapter.BaseDANNLike(dataset, feature_extractor, task_classifier, critic, alpha=1.0, entropy_reg=0.0, adapt_reg=True, batch_reweighting=False, **base_params)
Bases:
BaseAdaptTrainer
Common API for DANN-based methods: DANN, CDAN, CDAN+E, WDGRL, MME, FSDANN
- compute_loss(batch, split_name='valid')
- class kale.pipeline.domain_adapter.DANNTrainer(dataset, feature_extractor, task_classifier, critic, method=None, **base_params)
Bases:
BaseDANNLike
This class implements the DANN architecture from Ganin, Yaroslav, et al. “Domain-adversarial training of neural networks.” The Journal of Machine Learning Research (2016) https://arxiv.org/abs/1505.07818
- forward(x)
- class kale.pipeline.domain_adapter.CDANTrainer(dataset, feature_extractor, task_classifier, critic, use_entropy=False, use_random=False, random_dim=1024, **base_params)
Bases:
BaseDANNLike
Implements CDAN: Long, Mingsheng, et al. “Conditional adversarial domain adaptation.” Advances in Neural Information Processing Systems. 2018. https://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation.pdf
- forward(x)
- compute_loss(batch, split_name='valid')
- class kale.pipeline.domain_adapter.WDGRLTrainer(dataset, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)
Bases:
BaseDANNLike
Implements WDGRL as described in Shen, Jian, et al. “Wasserstein distance guided representation learning for domain adaptation.” Thirty-Second AAAI Conference on Artificial Intelligence. 2018. https://arxiv.org/pdf/1707.01217.pdf
This class also implements the asymmetric (\(\beta\)) variant described in: Wu, Yifan, et al. “Domain adaptation with asymmetrically-relaxed distribution alignment.” ICML (2019) https://arxiv.org/pdf/1903.01689.pdf
- forward(x)
- compute_loss(batch, split_name='valid')
- critic_update_steps(batch)
- training_step(batch, batch_id)
- configure_optimizers()
- class kale.pipeline.domain_adapter.WDGRLTrainerMod(dataset, feature_extractor, task_classifier, critic, k_critic=5, gamma=10, beta_ratio=0, **base_params)
Bases:
WDGRLTrainer
Implements a modified version of WDGRL as described in Shen, Jian, et al. “Wasserstein distance guided representation learning for domain adaptation.” Thirty-Second AAAI Conference on Artificial Intelligence. 2018. https://arxiv.org/pdf/1707.01217.pdf
This class also implements the asymmetric (\(\beta\)) variant described in: Wu, Yifan, et al. “Domain adaptation with asymmetrically-relaxed distribution alignment.” ICML (2019) https://arxiv.org/pdf/1903.01689.pdf
- critic_update_steps(batch)
- training_step(batch, batch_id)
- configure_optimizers()
- class kale.pipeline.domain_adapter.FewShotDANNTrainer(dataset, feature_extractor, task_classifier, critic, method, **base_params)
Bases:
BaseDANNLike
Implements adaptations of DANN to the semi-supervised setting
naive: the task classifier is trained on labeled target data in addition to source data. MME: implements Saito, Kuniaki, et al. “Semi-supervised domain adaptation via minimax entropy.” Proceedings of the IEEE International Conference on Computer Vision. 2019. https://arxiv.org/pdf/1904.06487.pdf
- forward(x)
- compute_loss(batch, split_name='valid')
- class kale.pipeline.domain_adapter.BaseMMDLike(dataset, feature_extractor, task_classifier, kernel_mul=2.0, kernel_num=5, **base_params)
Bases:
BaseAdaptTrainer
Common API for MMD-based deep learning DA methods: DAN, JAN
- forward(x)
- compute_loss(batch, split_name='valid')
- class kale.pipeline.domain_adapter.DANTrainer(dataset, feature_extractor, task_classifier, **base_params)
Bases:
BaseMMDLike
This is an implementation of DAN from Long, Mingsheng, et al. “Learning Transferable Features with Deep Adaptation Networks.” International Conference on Machine Learning. 2015. https://proceedings.mlr.press/v37/long15.pdf. Code based on https://github.com/thuml/Xlearn.
- class kale.pipeline.domain_adapter.JANTrainer(dataset, feature_extractor, task_classifier, kernel_mul=(2.0, 2.0), kernel_num=(5, 1), **base_params)
Bases:
BaseMMDLike
This is an implementation of JAN from Long, Mingsheng, et al. “Deep transfer learning with joint adaptation networks.” International Conference on Machine Learning. 2017. https://arxiv.org/pdf/1605.06636.pdf. Code based on https://github.com/thuml/Xlearn.
kale.pipeline.mpca_trainer module
Implementation of MPCA->Feature Selection->Linear SVM/LogisticRegression Pipeline
References
[1] Swift, A. J., Lu, H., Uthoff, J., Garg, P., Cogliano, M., Taylor, J., … & Kiely, D. G. (2020). A machine learning cardiac magnetic resonance approach to extract disease features and automate pulmonary arterial hypertension diagnosis. European Heart Journal - Cardiovascular Imaging.
[2] Song, X., Meng, L., Shi, Q., & Lu, H. (2015, October). Learning tensor-based features for whole-brain fMRI classification. In International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 613-620). Springer, Cham.
[3] Lu, H., Plataniotis, K. N., & Venetsanopoulos, A. N. (2008). MPCA: Multilinear principal component analysis of tensor objects. IEEE Transactions on Neural Networks, 19(1), 18-39.
- class kale.pipeline.mpca_trainer.MPCATrainer(classifier='svc', classifier_params='auto', classifier_param_grid=None, mpca_params=None, n_features=None, search_params=None)
Bases:
BaseEstimator, ClassifierMixin
Trainer of pipeline: MPCA->Feature selection->Classifier
- Parameters:
classifier (str, optional) – Available classifier options: {“svc”, “linear_svc”, “lr”}, where “svc” trains a support vector classifier that supports both linear and non-linear kernels and optimizes with the library “libsvm”; “linear_svc” trains a support vector classifier with a linear kernel only and optimizes with the library “liblinear”, which is supposed to be faster and better at handling a large number of samples; and “lr” trains a logistic regression classifier. Defaults to “svc”.
classifier_params (dict, optional) – Parameters of classifier. Defaults to ‘auto’.
classifier_param_grid (dict, optional) – Grids for searching the optimal hyper-parameters. Works only when classifier_params == “auto”. Defaults to None by searching from the following hyper-parameter values: 1. svc, {“kernel”: [“linear”], “C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100], “max_iter”: [50000]}, 2. linear_svc, {“C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]}, 3. lr, {“C”: [0.0001, 0.001, 0.01, 0.1, 1, 10, 100]}
mpca_params (dict, optional) – Parameters of MPCA, e.g., {“var_ratio”: 0.8}. Defaults to None, i.e., using the default parameters (https://pykale.readthedocs.io/en/latest/kale.embed.html#module-kale.embed.mpca).
n_features (int, optional) – Number of features for feature selection. Defaults to None, i.e., all features after dimension reduction will be used.
search_params (dict, optional) – Parameters of grid search, for more detail please see https://scikit-learn.org/stable/modules/grid_search.html#grid-search . Defaults to None, i.e., using the default params: {“cv”: 5}.
- fit(x, y)
Fit a pipeline with the given data x and labels y
- Parameters:
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
y (array-like) – data labels, shape (n_samples, )
- Returns:
self
- predict(x)
Predict the labels for the given data x
- Parameters:
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
- Returns:
Predicted labels, shape (n_samples, )
- Return type:
array-like
- decision_function(x)
Decision scores of each class for the given data x
- Parameters:
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
- Returns:
decision scores, shape (n_samples,) for binary case, else (n_samples, n_class)
- Return type:
array-like
- predict_proba(x)
Probability of each class for the given data x. Not supported by “linear_svc”.
- Parameters:
x (array-like tensor) – input data, shape (n_samples, I_1, I_2, …, I_N)
- Returns:
probabilities, shape (n_samples, n_class)
- Return type:
array-like
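A brief usage sketch with synthetic data (the tensor sizes and labels are arbitrary):

    import numpy as np

    from kale.pipeline.mpca_trainer import MPCATrainer

    x = np.random.random((40, 20, 25, 20))  # 40 samples, each a 20x25x20 tensor
    y = np.random.randint(0, 2, 40)         # binary labels

    trainer = MPCATrainer(classifier="linear_svc")
    trainer.fit(x, y)
    y_pred = trainer.predict(x)
    scores = trainer.decision_function(x)   # predict_proba() is not supported by "linear_svc"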
kale.pipeline.multi_domain_adapter module
kale.pipeline.multiomics_trainer module
Construct a pipeline to run the MOGONET method based on PyTorch Lightning. MOGONET is a multiomics fusion framework for cancer classification and biomarker identification that utilizes supervised graph convolutional networks for omics datasets.
This code is written by refactoring the MOGONET code (https://github.com/txWang/MOGONET/blob/main/train_test.py) within the PyTorch Lightning framework.
Reference: Wang, T., Shao, W., Huang, Z., Tang, H., Zhang, J., Ding, Z., Huang, K. (2021). MOGONET integrates multi-omics data using graph convolutional networks allowing patient classification and biomarker identification. Nature communications. https://www.nature.com/articles/s41467-021-23774-w
- class kale.pipeline.multiomics_trainer.MultiomicsTrainer(dataset: SparseMultiomicsDataset, num_modalities: int, num_classes: int, unimodal_encoder: List[MogonetGCN], unimodal_decoder: List[LinearClassifier], loss_fn: CrossEntropyLoss, multimodal_decoder: VCDN | None = None, train_multimodal_decoder: bool = True, gcn_lr: float = 0.0005, vcdn_lr: float = 0.001)
Bases:
LightningModule
The PyTorch Lightning implementation of the MOGONET method, a multiomics fusion method designed for classification tasks.
- Parameters:
dataset (SparseMultiomicsDataset) – The input dataset created in the form of a Dataset.
num_modalities (int) – The total number of modalities in the dataset.
num_classes (int) – The total number of classes in the dataset.
unimodal_encoder (List[MogonetGCN]) – The list of GCN encoders for each modality.
unimodal_decoder (List[LinearClassifier]) – The list of linear classifier decoders for each modality.
loss_fn (CrossEntropyLoss) – The loss function used to gauge the error between the prediction outputs and the provided target values.
multimodal_decoder (VCDN, optional) – The VCDN decoder used in the multiomics dataset. (default: None)
train_multimodal_decoder (bool, optional) – Whether to train the VCDN module. (default: True)
gcn_lr (float, optional) – The learning rate used in the GCN module. (default: 5e-4)
vcdn_lr (float, optional) – The learning rate used in the VCDN module. (default: 1e-3)
- configure_optimizers() Optimizer | List[Optimizer]
Return the optimizers used during training.
- forward(x: List[Tensor], adj_t: List[SparseTensor], multimodal: bool = False) Tensor | List[Tensor]
Same as torch.nn.Module.forward().
- Raises:
TypeError – If multimodal_decoder is None for multiomics datasets.
- training_step(train_batch, batch_idx: int)
Compute and return the training loss.
- Parameters:
train_batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch.
- test_step(test_batch, batch_idx: int)
Compute and return the test loss.
- Parameters:
test_batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch.
- train_dataloader() DataLoader
Return an iterable or a collection of iterables that specifies training samples in the dataset.
- test_dataloader() DataLoader
Return an iterable or a collection of iterables that specifies test samples in the dataset.
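A hedged sketch of the training loop only. The dataset and the per-modality modules (whose constructors belong to other PyKale modules and are not reproduced here) are assumed to be prepared elsewhere; because the trainer supplies its own dataloaders, fitting reduces to a standard PyTorch Lightning call:

    import pytorch_lightning as pl
    import torch.nn as nn

    from kale.pipeline.multiomics_trainer import MultiomicsTrainer

    model = MultiomicsTrainer(
        dataset=dataset,                     # SparseMultiomicsDataset, prepared elsewhere
        num_modalities=3,
        num_classes=5,
        unimodal_encoder=unimodal_encoders,  # list of MogonetGCN, one per modality
        unimodal_decoder=unimodal_decoders,  # list of LinearClassifier, one per modality
        loss_fn=nn.CrossEntropyLoss(),
        multimodal_decoder=vcdn,             # VCDN decoder for the multimodal prediction
    )

    trainer = pl.Trainer(max_epochs=100)     # epoch count is illustrative
    trainer.fit(model)                       # uses model.train_dataloader()
    trainer.test(model)                      # uses model.test_dataloader()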
kale.pipeline.video_domain_adapter module
kale.pipeline.fewshot_trainer module
This module contains the ProtoNet trainer class and related functions. It trains a prototypical network model for few-shot learning problems under \(N\)-way-\(K\)-shot settings.
ProtoNet is a few-shot learning method that can be considered a clustering method. It learns a feature space where samples from the same class are close to each other and samples from different classes are far apart. The prototypes can be seen as the cluster centers, and the feature space is learned to make the samples cluster around these prototypes. But note that ProtoNet operates in a supervised learning context, where the goal is to classify data points based on labeled training examples. Clustering is typically an unsupervised learning task, where the objective is to group data points into clusters without prior knowledge of labels.
This is a PyTorch Lightning (https://github.com/Lightning-AI/lightning) version of the original implementation (https://github.com/jakesnell/prototypical-networks) of Prototypical Networks for Few-shot Learning (https://arxiv.org/abs/1703.05175).
- class kale.pipeline.fewshot_trainer.ProtoNetTrainer(net: Module, train_num_classes: int = 30, train_num_support_samples: int = 5, train_num_query_samples: int = 15, val_num_classes: int = 5, val_num_support_samples: int = 5, val_num_query_samples: int = 15, devices: str = 'cuda', optimizer: str = 'SGD', lr: float = 0.001)
Bases:
LightningModule
ProtoNet trainer class.
This class trains a ProtoNet model for few-shot learning problems under \(N\)-way-\(K\)-shot settings. It uses the pl.LightningModule class of PyTorch Lightning to standardize the workflow. Updating other modules except kale.evaluate.metrics.protonet_loss and kale.embed.image_cnn will not affect this trainer.
\(N\)-way: The number of classes under a particular setting. The model is presented with samples from these \(N\) classes and needs to classify them. For example, 3-way means the model has to classify 3 different classes.
\(K\)-shot: The number of samples for each class in the support set. For example, in a 2-shot setting, two support samples are provided per class.
Support set: It is a small, labeled dataset used to train the model with a few samples of each class. The support set consists of \(N\) classes (\(N\)-way), with \(K\) samples (\(K\)-shot) for each class. For example, under a 3-way-2-shot setting, the support set has 3 classes with 2 samples per class, totaling 6 samples.
Query set: It evaluates the model’s ability to generalize what it has learned from the support set. It contains samples from the same \(N\) classes but not included in the support set. Continuing with the 3-way-2-shot example, the query set would include additional samples from the 3 classes, which the model must classify after learning from the support set.
- Parameters:
net (torch.nn.Module) – A feature extractor without any task-specific heads. It outputs a 1-D feature vector.
train_num_classes (int) – Number of classes in training. It could be different from \(N\) under \(N\)-way-\(K\)-shot settings in ProtoNet. Default: 30.
train_num_support_samples (int) – Number of samples per class in the support set in training. It corresponds to \(K\) under \(N\)-way-\(K\)-shot settings. Default: 5.
train_num_query_samples (int) – Number of samples per class in the query set in training. Default: 15.
val_num_classes (int) – Number of classes in validation and testing. It corresponds to \(N\) under \(N\)-way-\(K\)-shot settings. Default: 5.
val_num_support_samples (int) – Number of samples per class in the support set in validation and testing. It corresponds to \(K\) under \(N\)-way-\(K\)-shot settings. Default: 5.
val_num_query_samples (int) – Number of samples per class in the query set in validation and testing. Default: 15.
devices (str) – Devices used for training. Default: “cuda”.
optimizer (str) – Optimizer used for training. Default: “SGD”.
lr (float) – Learning rate. Default: 0.001.
- forward(x, num_support_samples, num_classes) Tensor
- compute_loss(feature_support, feature_query, mode='train') tuple
Compute loss and accuracy.
Here we use the same loss function for both training and validation, which is based on the Euclidean distance between query features and the class prototypes.
- Parameters:
feature_support (torch.Tensor) – Support features.
feature_query (torch.Tensor) – Query features.
mode (str) – Mode of the trainer, “train”, “val” or “test”. Default: “train”.
- Returns:
Loss value. return_dict (dict): Dictionary of loss and accuracy.
- Return type:
loss (torch.Tensor)
- training_step(batch: Any, batch_idx: int) Tensor
Training step.
Compute loss and accuracy, and log them by self.log_dict. For training, log on each step and each epoch. For validation and testing, only log on each epoch. This avoids the need for on_training_epoch_end() and on_validation_epoch_end().
- validation_step(batch: Any, batch_idx: int) None
Compute and return the validation loss and log_metrics on one step.
- test_step(batch: Any, batch_idx: int) None
Compute and return the test loss and log_metrics on one step.
- configure_optimizers() Optimizer
Configure optimizer for training. Can be modified to support different optimizers from torch.optim.
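A hedged usage sketch under a 5-way-5-shot validation setting. The backbone is a toy placeholder producing a 1-D feature vector, and the episodic \(N\)-way-\(K\)-shot dataloaders are assumed to be prepared elsewhere:

    import pytorch_lightning as pl
    import torch.nn as nn

    from kale.pipeline.fewshot_trainer import ProtoNetTrainer

    # Toy backbone producing a 1-D feature vector per sample.
    backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 84 * 84, 256), nn.ReLU(), nn.Linear(256, 64))

    model = ProtoNetTrainer(
        net=backbone,
        train_num_classes=30,
        train_num_support_samples=5,   # K in training episodes
        train_num_query_samples=15,
        val_num_classes=5,             # N for validation/testing
        val_num_support_samples=5,     # K for validation/testing
        val_num_query_samples=15,
        optimizer="SGD",
        lr=1e-3,
    )

    trainer = pl.Trainer(max_epochs=100)
    # trainer.fit(model, train_episode_loader, val_episode_loader)  # episodic dataloaders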