dacapo.experiments.trainers.gunpowder_trainer

Module Contents

Classes

GunpowderTrainer

Trainer Abstract Base Class

Attributes

logger

dacapo.experiments.trainers.gunpowder_trainer.logger

class dacapo.experiments.trainers.gunpowder_trainer.GunpowderTrainer(trainer_config)

Trainer Abstract Base Class

This class serves as the blueprint for trainer classes in the dacapo library. It defines the methods that every subclass must implement in order to train a neural network model.

iteration = 0

create_optimizer(model)

Creates an optimizer for the model.

Parameters:

model (Model) – The model for which the optimizer will be created.

Returns:

The optimizer created for the model.

Return type:

torch.optim.Optimizer

build_batch_provider(datasets, model, task, snapshot_container=None)

Initializes the training pipeline from the given datasets, model, task, and snapshot container.

Parameters:
  • datasets (List[Dataset]) – The datasets to pull data from.

  • model (Model) – The model to inform the pipeline of required input/output sizes.

  • task (Task) – The task to transform ground truth into target.

  • snapshot_container (LocalContainerIdentifier) – Defines where snapshots will be saved.

iterate(num_iterations, model, optimizer, device)

Performs a number of training iterations.

Parameters:
  • num_iterations (int) – Number of training iterations.

  • model (Model) – The model to be trained.

  • optimizer (torch.optim.Optimizer) – The optimizer for the model.

  • device (torch.device) – The device (GPU/CPU) where the model will be trained.

Returns:

An iterator of the training statistics.

Return type:

Iterator[TrainingIterationStats]
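
The call order suggested by the methods above (create an optimizer, build the batch provider, then consume the iterator returned by iterate) can be sketched with a stand-in trainer. Everything in this block is hypothetical: StubTrainer, StubStats, their fields, and the guard that requires build_batch_provider to run first are illustrative assumptions, not dacapo's implementation. Only the method names and the iterator-style return value mirror this page.

```python
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class StubStats:
    """Hypothetical stand-in for TrainingIterationStats."""
    iteration: int
    loss: float


class StubTrainer:
    """Hypothetical stand-in that mirrors the documented method names only."""

    def __init__(self) -> None:
        self.iteration = 0  # same starting value as the documented attribute
        self._pipeline_ready = False

    def create_optimizer(self, model: dict) -> dict:
        # A real trainer would return a torch.optim.Optimizer here.
        return {"params": model, "lr": 0.1}

    def build_batch_provider(self, datasets, model, task, snapshot_container=None) -> None:
        # A real trainer would assemble its data pipeline here.
        self._pipeline_ready = True

    def iterate(self, num_iterations, model, optimizer, device) -> Iterator[StubStats]:
        # Assumption for this sketch: the pipeline must exist before iterating.
        if not self._pipeline_ready:
            raise RuntimeError("call build_batch_provider() before iterate()")
        for _ in range(num_iterations):
            # Placeholder "loss" that shrinks as the iteration count grows.
            loss = 1.0 / (self.iteration + 1)
            stats = StubStats(iteration=self.iteration, loss=loss)
            self.iteration += 1
            yield stats


def run_training(trainer, datasets, model, task, num_iterations) -> List[StubStats]:
    optimizer = trainer.create_optimizer(model)
    trainer.build_batch_provider(datasets, model, task)
    # iterate() is lazy: work happens only as the iterator is consumed.
    return list(trainer.iterate(num_iterations, model, optimizer, device="cpu"))


history = run_training(
    StubTrainer(), datasets=["a", "b"], model={"w": 0.0}, task="demo", num_iterations=3
)
```

Because iterate returns an iterator rather than a list, a caller can log, validate, or stop early between iterations instead of waiting for all of them to finish.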

next()

can_train(datasets) → bool

Checks if the trainer can train with a specific set of datasets.

Some trainers may have specific requirements for their training datasets.

Parameters:

datasets (List[Dataset]) – The training datasets.

Returns:

True if the trainer can train on the given datasets, False otherwise.

Return type:

bool
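
A caller might use can_train as a pre-flight check before starting a run. The following is a minimal sketch under stated assumptions: StubDataset, its has_ground_truth field, PickyTrainer, and the rule it enforces are all hypothetical and are not taken from dacapo. The sketch only shows how a bool-returning check can gate training.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StubDataset:
    """Hypothetical stand-in for a training dataset."""
    name: str
    has_ground_truth: bool


class PickyTrainer:
    """Hypothetical trainer with an invented dataset requirement."""

    def can_train(self, datasets: List[StubDataset]) -> bool:
        # Invented rule: at least one dataset, and all must have ground truth.
        return len(datasets) > 0 and all(d.has_ground_truth for d in datasets)


def check_before_training(trainer, datasets: List[StubDataset]) -> None:
    """Raise early instead of failing partway through a training run."""
    if not trainer.can_train(datasets):
        names = ", ".join(d.name for d in datasets) or "<none>"
        raise ValueError(f"trainer cannot train on datasets: {names}")
```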