dacapo.experiments.trainers.trainer
Module Contents
Classes
Trainer – Trainer Abstract Base Class
- class dacapo.experiments.trainers.trainer.Trainer
Trainer Abstract Base Class
This class serves as the blueprint for all trainer classes in the dacapo library. It declares the methods that every subclass must implement in order to train a neural network model.
- iteration: int
- batch_size: int
- learning_rate: float
- abstract create_optimizer(model: dacapo.experiments.model.Model) → torch.optim.Optimizer
Creates an optimizer for the model.
- Parameters:
model (Model) – The model for which the optimizer will be created.
- Returns:
The optimizer created for the model.
- Return type:
torch.optim.Optimizer
- abstract iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer: torch.optim.Optimizer, device: torch.device) → Iterator[dacapo.experiments.training_iteration_stats.TrainingIterationStats]
Performs a number of training iterations.
- Parameters:
num_iterations (int) – Number of training iterations.
model (Model) – The model to be trained.
optimizer (torch.optim.Optimizer) – The optimizer for the model.
device (torch.device) – The device (GPU/CPU) where the model will be trained.
- Returns:
An iterator of the training statistics.
- Return type:
Iterator[TrainingIterationStats]
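The iterator contract of iterate can be sketched with a plain Python generator. This is a minimal sketch under stated assumptions, not dacapo's implementation: `StatsSketch` is a hypothetical stand-in for TrainingIterationStats with assumed fields, `LoopSketch` is a hypothetical class, and the model, optimizer, and device are replaced by a caller-supplied `step` function so that the example runs without PyTorch.

```python
import time
from dataclasses import dataclass
from typing import Callable, Iterator


@dataclass
class StatsSketch:
    """Hypothetical stand-in for TrainingIterationStats (fields are assumed)."""

    iteration: int
    loss: float
    time: float


class LoopSketch:
    """Illustrates only the generator shape of an iterate method."""

    def __init__(self) -> None:
        self.iteration = 0  # mirrors the `iteration` attribute of Trainer

    def iterate(
        self, num_iterations: int, step: Callable[[int], float]
    ) -> Iterator[StatsSketch]:
        for _ in range(num_iterations):
            start = time.perf_counter()
            # A real trainer would run a forward/backward pass here.
            loss = step(self.iteration)
            elapsed = time.perf_counter() - start
            stats = StatsSketch(self.iteration, loss, elapsed)
            self.iteration += 1
            yield stats


trainer = LoopSketch()
# `step` is a toy loss function used only for illustration.
stats = list(trainer.iterate(3, step=lambda i: 1.0 / (i + 1)))
```

Because the method is a generator, no work happens until the caller consumes it, so a caller can log, validate, or checkpoint between yielded statistics.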
- abstract can_train(datasets: List[dacapo.experiments.datasplits.datasets.Dataset]) → bool
Checks if the trainer can train with a specific set of datasets.
Some trainers may have specific requirements for their training datasets.
- Parameters:
datasets (List[Dataset]) – The training datasets.
- Returns:
True if the trainer can train on the given datasets, False otherwise.
- Return type:
bool
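One possible dataset requirement is that every training dataset provides ground truth. The sketch below illustrates such a check; it is an assumption for illustration, not the rule any particular dacapo trainer applies. `DatasetSketch` and its `gt` field are hypothetical stand-ins for the real Dataset class.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DatasetSketch:
    """Hypothetical stand-in for a dacapo Dataset; `gt` is an assumed field."""

    name: str
    gt: Optional[object] = None


def can_train(datasets: List[DatasetSketch]) -> bool:
    # Assumed requirement: every training dataset provides ground truth.
    return all(dataset.gt is not None for dataset in datasets)


labeled = DatasetSketch("a", gt=[0, 1])
unlabeled = DatasetSketch("b")
```

With these definitions, `can_train([labeled])` holds, while a list containing `unlabeled` is rejected.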
- abstract build_batch_provider(datasets: List[dacapo.experiments.datasplits.datasets.Dataset], model: dacapo.experiments.model.Model, task: dacapo.experiments.tasks.task.Task, snapshot_container: dacapo.store.array_store.LocalContainerIdentifier) → None
Initializes the training pipeline.
This method uses the datasets, model, task, and snapshot_container to set up the pipeline that supplies training batches. It returns nothing.
- Parameters:
datasets (List[Dataset]) – The datasets to pull data from.
model (Model) – The model to inform the pipeline of required input/output sizes.
task (Task) – The task to transform ground truth into target.
snapshot_container (LocalContainerIdentifier) – Defines where snapshots will be saved.
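Since create_optimizer, iterate, can_train, and build_batch_provider are all abstract, a subclass that omits any of them cannot be instantiated. The sketch below shows that behaviour with Python's standard `abc` module. `TrainerSketch` and `IncompleteTrainer` are hypothetical, simplified mirrors of the interface (type annotations omitted so that the example runs without dacapo or PyTorch), not the library's own classes.

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Hypothetical, simplified mirror of the Trainer interface."""

    iteration: int
    batch_size: int
    learning_rate: float

    @abstractmethod
    def create_optimizer(self, model): ...

    @abstractmethod
    def iterate(self, num_iterations, model, optimizer, device): ...

    @abstractmethod
    def can_train(self, datasets): ...

    @abstractmethod
    def build_batch_provider(self, datasets, model, task, snapshot_container): ...


class IncompleteTrainer(TrainerSketch):
    # Only one of the four abstract methods is implemented.
    def can_train(self, datasets):
        return True


try:
    IncompleteTrainer()
    instantiated = True
except TypeError:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    instantiated = False
```

The error is raised at instantiation time, so a missing method is caught before any training starts.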