dacapo.experiments.trainers.dummy_trainer
Classes
DummyTrainer – This class trains a model on dummy data and is intended for testing purposes.
Module Contents
- class dacapo.experiments.trainers.dummy_trainer.DummyTrainer(trainer_config)
This class trains a model on dummy data and is intended for testing purposes. It stores the learning rate, batch size, and mirror-augmentation setting, and provides methods to create an optimizer, iterate over the training data, build a batch provider, and check whether the trainer can train on a given data split. It also implements the context-manager protocol (`__enter__` and `__exit__`). The `iterate` method yields training iteration statistics.
- learning_rate
The learning rate to use.
- Type:
float
- batch_size
The batch size to use.
- Type:
int
- mirror_augment
Whether to apply mirror augmentation.
- Type:
bool
- __init__(self, trainer_config)
This method initializes the DummyTrainer object.
- create_optimizer(self, model)
This method creates an optimizer for the given model.
- iterate(self, num_iterations: int, model, optimizer, device)
This method iterates over the training data for the specified number of iterations.
- build_batch_provider(self, datasplit, architecture, task, snapshot_container)
This method builds a batch provider for the given data split, architecture, task, and snapshot container.
- can_train(self, datasplit)
This method checks if the trainer can train on the given data split.
- __enter__(self)
This method enters the context manager.
- __exit__(self, exc_type, exc_val, exc_tb)
This method exits the context manager.
Note
The iterate method yields TrainingIterationStats.
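A minimal, standard-library sketch of the constructor and context-manager behavior described above. `DummyTrainerConfigSketch` and `DummyTrainerSketch` are hypothetical names; the real `trainer_config` type is not documented on this page, and the assumptions that `__enter__` returns the trainer and `__exit__` performs no cleanup are ours.

```python
from dataclasses import dataclass
from types import TracebackType
from typing import Optional, Type


@dataclass(frozen=True)
class DummyTrainerConfigSketch:
    """Hypothetical config carrying only the three documented fields."""
    learning_rate: float
    batch_size: int
    mirror_augment: bool


class DummyTrainerSketch:
    """Hypothetical stand-in mirroring the documented constructor and context manager."""

    iteration = 0  # class-level default, as listed in the module contents

    def __init__(self, trainer_config: DummyTrainerConfigSketch) -> None:
        # Copy the hyperparameters off the config object.
        self.learning_rate = trainer_config.learning_rate
        self.batch_size = trainer_config.batch_size
        self.mirror_augment = trainer_config.mirror_augment

    def __enter__(self) -> "DummyTrainerSketch":
        # Assumption: nothing to acquire, so entering just returns the trainer.
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None (falsy) lets any exception raised in the with-body propagate.
        return None


config = DummyTrainerConfigSketch(learning_rate=1e-4, batch_size=2, mirror_augment=True)
with DummyTrainerSketch(config) as trainer:
    print(trainer.learning_rate, trainer.batch_size, trainer.mirror_augment)
```

Because `__exit__` returns `None` in this sketch, an exception raised inside the `with` block is not suppressed.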
- iteration = 0
- learning_rate
- batch_size
- mirror_augment
- create_optimizer(model)
Create an optimizer for the given model.
- Parameters:
model (Model) – The model to optimize.
- Returns:
The optimizer object.
- Return type:
torch.optim.Optimizer
Examples
>>> optimizer = trainer.create_optimizer(model)
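`create_optimizer` is documented as returning a `torch.optim.Optimizer`, but this page does not say which optimizer class is used. To stay dependency-free, the sketch below swaps in a hypothetical plain gradient-descent optimizer (`SGDSketch`) instead of PyTorch; every class name here is hypothetical. It only illustrates the pattern of a trainer handing its `learning_rate` and the model's parameters to a newly built optimizer.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class ParameterSketch:
    """Hypothetical scalar parameter with an accumulated gradient."""
    value: float
    grad: float = 0.0


class SGDSketch:
    """Hypothetical vanilla gradient-descent optimizer; NOT torch.optim."""

    def __init__(self, params: Iterable[ParameterSketch], lr: float) -> None:
        self.params: List[ParameterSketch] = list(params)
        self.lr = lr

    def zero_grad(self) -> None:
        # Clear gradients before the next backward pass.
        for p in self.params:
            p.grad = 0.0

    def step(self) -> None:
        # Update rule: value <- value - lr * grad
        for p in self.params:
            p.value -= self.lr * p.grad


class ModelSketch:
    """Hypothetical model with a single scalar weight."""

    def __init__(self) -> None:
        self.weight = ParameterSketch(value=1.0)

    def parameters(self) -> List[ParameterSketch]:
        return [self.weight]


class TrainerSketch:
    """Hypothetical trainer exposing only create_optimizer."""

    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate

    def create_optimizer(self, model: ModelSketch) -> SGDSketch:
        # The trainer's learning rate is passed to the optimizer it builds.
        return SGDSketch(model.parameters(), lr=self.learning_rate)
```

With PyTorch, a comparable line might be `torch.optim.SGD(model.parameters(), lr=self.learning_rate)`, though which optimizer DummyTrainer actually uses is not stated here.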
- iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer, device)
Iterate over the training data for the specified number of iterations.
- Parameters:
num_iterations (int) – The number of iterations to perform.
model (Model) – The model to train.
optimizer (torch.optim.Optimizer) – The optimizer to use.
device (torch.device) – The device to perform the computations on.
- Yields:
TrainingIterationStats – The training iteration statistics.
- Raises:
ValueError – If the number of iterations is less than or equal to zero.
Examples
>>> for stats in trainer.iterate(num_iterations, model, optimizer, device):
...     print(stats)
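A minimal, dependency-free sketch of the `iterate` contract: reject a non-positive `num_iterations`, run one dummy forward/backward/update step per iteration, and yield one stats object each time. All names are hypothetical; `model`, `optimizer`, and `device` are collapsed into a single scalar weight trained by plain gradient descent on random inputs with a zero target, and carrying the counter across calls is an assumption suggested by the class-level `iteration = 0` attribute.

```python
import random
import time
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class TrainingIterationStatsSketch:
    """Hypothetical stand-in for dacapo's TrainingIterationStats."""
    iteration: int
    loss: float
    time: float


class IterateSketch:
    """Hypothetical trainer demonstrating only the iterate() contract."""

    iteration = 0  # class-level default; the instance shadows it once training starts

    def __init__(self, learning_rate: float, seed: int = 0) -> None:
        self.learning_rate = learning_rate
        self._rng = random.Random(seed)

    def iterate(
        self, num_iterations: int, weight: List[float]
    ) -> Iterator[TrainingIterationStatsSketch]:
        # Generator body: this check runs on the first next(), not at call time.
        if num_iterations <= 0:
            raise ValueError("num_iterations must be positive")
        for _ in range(num_iterations):
            start = time.perf_counter()
            x = self._rng.gauss(0.0, 1.0)  # dummy input sample
            target = 0.0                   # dummy target
            error = weight[0] * x - target
            loss = error ** 2
            # d(loss)/dw = 2 * error * x; one gradient-descent step
            weight[0] -= self.learning_rate * 2.0 * error * x
            current = self.iteration
            self.iteration += 1  # advance before yielding so an early break keeps the count
            yield TrainingIterationStatsSketch(current, loss, time.perf_counter() - start)


trainer = IterateSketch(learning_rate=0.1)
weight = [1.0]
for stats in trainer.iterate(3, weight):
    print(stats.iteration, round(stats.loss, 4))
```

In this sketch, because `iterate` is a generator, the `ValueError` surfaces on the first `next()` call rather than when `iterate(...)` itself is called.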
- build_batch_provider(datasplit, architecture, task, snapshot_container)
Build a batch provider for the given data split, architecture, task, and snapshot container.
- Parameters:
datasplit (DataSplit) – The data split to use.
architecture (Architecture) – The architecture to use.
task (Task) – The task to perform.
snapshot_container (SnapshotContainer) – The snapshot container to use.
- Returns:
The batch provider object.
- Return type:
BatchProvider
- Raises:
ValueError – If the task loss is not set.
Examples
>>> batch_provider = trainer.build_batch_provider(datasplit, architecture, task, snapshot_container)
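A minimal sketch of the documented `ValueError` check and of what a dummy batch provider might produce. It is not DaCapo's implementation: `TaskSketch`, `BatchProviderSketch`, and `build_batch_provider_sketch` are hypothetical, the `datasplit`, `architecture`, and `snapshot_container` arguments are dropped, and samples are random 1-D lists where mirror augmentation means a 50% chance of reversal.

```python
import random
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class TaskSketch:
    """Hypothetical stand-in for a dacapo Task; only `loss` is used here."""
    loss: Optional[Callable[[float, float], float]]


class BatchProviderSketch:
    """Hypothetical provider that returns batches of random 1-D samples."""

    def __init__(self, batch_size: int, sample_length: int, mirror_augment: bool, seed: int = 0) -> None:
        self.batch_size = batch_size
        self.sample_length = sample_length
        self.mirror_augment = mirror_augment
        self._rng = random.Random(seed)

    def next_batch(self) -> List[List[float]]:
        batch = []
        for _ in range(self.batch_size):
            sample = [self._rng.random() for _ in range(self.sample_length)]
            # Mirror augmentation: flip the sample with probability 0.5.
            if self.mirror_augment and self._rng.random() < 0.5:
                sample.reverse()
            batch.append(sample)
        return batch


def build_batch_provider_sketch(
    task: TaskSketch, batch_size: int, mirror_augment: bool, sample_length: int = 8
) -> BatchProviderSketch:
    # Mirrors the documented failure mode: the task must define a loss.
    if task.loss is None:
        raise ValueError("task loss is not set")
    return BatchProviderSketch(batch_size, sample_length, mirror_augment)


provider = build_batch_provider_sketch(TaskSketch(loss=lambda p, t: (p - t) ** 2), batch_size=4, mirror_augment=True)
print(len(provider.next_batch()))
```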
- can_train(datasplit)
Check if the trainer can train on the given data split.
- Parameters:
datasplit (DataSplit) – The data split to check.
- Returns:
True if the trainer can train on the data split, False otherwise.
- Return type:
bool
- Raises:
NotImplementedError – If the method is not implemented.
Examples
>>> trainer.can_train(datasplit)
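The `NotImplementedError` entry reads like a base-class contract, while a trainer that generates its own dummy data has no requirements on the split. A minimal sketch of that division, with hypothetical class names and the assumption that the dummy trainer accepts any data split:

```python
from typing import Any


class TrainerBaseSketch:
    """Hypothetical base class: subclasses must decide what they can train on."""

    def can_train(self, datasplit: Any) -> bool:
        raise NotImplementedError(f"{type(self).__name__} does not implement can_train")


class DummyTrainerSketch(TrainerBaseSketch):
    """Hypothetical dummy trainer that ignores the split's contents."""

    def can_train(self, datasplit: Any) -> bool:
        # Assumption: dummy data is generated internally, so any split is acceptable.
        return True


print(DummyTrainerSketch().can_train(datasplit=None))
```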