dacapo.experiments.trainers.trainer
===================================

.. py:module:: dacapo.experiments.trainers.trainer


Classes
-------

.. autoapisummary::

   dacapo.experiments.trainers.trainer.Trainer


Module Contents
---------------

.. py:class:: Trainer

   Trainer Abstract Base Class

   This class serves as the blueprint for all trainer classes in the dacapo
   library. It defines the methods that every subclass must implement in order
   to train a neural network model.

   .. attribute:: iteration

      The number of training iterations.

      :type: int

   .. attribute:: batch_size

      The size of the training batch.

      :type: int

   .. attribute:: learning_rate

      The learning rate for the optimizer.

      :type: float

   .. method:: create_optimizer(model: Model) -> torch.optim.Optimizer

      Creates an optimizer for the model.

   .. method:: iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]

      Performs a number of training iterations.

   .. method:: can_train(datasets: List[Dataset]) -> bool

      Checks if the trainer can train with a specific set of datasets.

   .. method:: build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None

      Initializes the training pipeline using various components.

   .. note::

      The Trainer class is an abstract class that cannot be instantiated
      directly. It is meant to be subclassed.


   .. py:attribute:: iteration
      :type:  int


   .. py:attribute:: batch_size
      :type:  int


   .. py:attribute:: learning_rate
      :type:  float


   .. py:method:: create_optimizer(model: dacapo.experiments.model.Model) -> torch.optim.Optimizer
      :abstractmethod:


      Creates an optimizer for the model.

      :param model: The model for which the optimizer will be created.
      :type model: Model

      :returns: The optimizer created for the model.
      :rtype: torch.optim.Optimizer

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> optimizer = trainer.create_optimizer(model)

      .. note::

         This method must be implemented by the subclass.


   .. py:method:: iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[dacapo.experiments.training_iteration_stats.TrainingIterationStats]
      :abstractmethod:


      Performs a number of training iterations.

      :param num_iterations: Number of training iterations.
      :type num_iterations: int
      :param model: The model to be trained.
      :type model: Model
      :param optimizer: The optimizer for the model.
      :type optimizer: torch.optim.Optimizer
      :param device: The device (GPU/CPU) where the model will be trained.
      :type device: torch.device

      :returns: An iterator of the training statistics.
      :rtype: Iterator[TrainingIterationStats]

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
      ...     print(iteration_stats)

      .. note::

         This method must be implemented by the subclass.


   .. py:method:: can_train(datasets: List[dacapo.experiments.datasplits.datasets.Dataset]) -> bool
      :abstractmethod:


      Checks if the trainer can train with a specific set of datasets.

      Some trainers may have specific requirements for their training datasets.

      :param datasets: The training datasets.
      :type datasets: List[Dataset]

      :returns: True if the trainer can train on the given datasets, False otherwise.
      :rtype: bool

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> can_train = trainer.can_train(datasets)

      .. note::

         This method must be implemented by the subclass.


   .. py:method:: build_batch_provider(datasets: List[dacapo.experiments.datasplits.datasets.Dataset], model: dacapo.experiments.model.Model, task: dacapo.experiments.tasks.task.Task, snapshot_container: dacapo.store.array_store.LocalContainerIdentifier) -> None
      :abstractmethod:


      Initializes the training pipeline using various components.

      This method uses the datasets, model, task, and snapshot_container to set
      up the training pipeline.

      :param datasets: The datasets to pull data from.
      :type datasets: List[Dataset]
      :param model: The model to inform the pipeline of required input/output sizes.
      :type model: Model
      :param task: The task to transform ground truth into target.
      :type task: Task
      :param snapshot_container: Defines where snapshots will be saved.
      :type snapshot_container: LocalContainerIdentifier

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> trainer.build_batch_provider(datasets, model, task, snapshot_container)

      .. note::

         This method must be implemented by the subclass.
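
The abstract-method contract documented above can be illustrated with a minimal, standard-library-only sketch. Everything in it is a stand-in and an assumption of this example: the local ``Trainer`` class only mirrors the four documented method names and is not dacapo's implementation, ``IterationStats`` is a hypothetical substitute for ``TrainingIterationStats``, ``ToyTrainer`` is an invented subclass, and plain Python objects take the place of ``Model``, ``torch.optim.Optimizer``, and ``torch.device``. The order in which the methods are called is likewise a choice made for this sketch.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterator, List


@dataclass
class IterationStats:
    """Hypothetical stand-in for TrainingIterationStats."""

    iteration: int
    loss: float


class Trainer(ABC):
    """Local stand-in mirroring the documented abstract interface (not dacapo code)."""

    iteration: int
    batch_size: int
    learning_rate: float

    @abstractmethod
    def create_optimizer(self, model: Any) -> Any: ...

    @abstractmethod
    def iterate(
        self, num_iterations: int, model: Any, optimizer: Any, device: Any
    ) -> Iterator[IterationStats]: ...

    @abstractmethod
    def can_train(self, datasets: List[Any]) -> bool: ...

    @abstractmethod
    def build_batch_provider(
        self, datasets: List[Any], model: Any, task: Any, snapshot_container: Any
    ) -> None: ...


class ToyTrainer(Trainer):
    """Invented subclass that implements every abstract method with placeholders."""

    def __init__(self, batch_size: int = 2, learning_rate: float = 0.1) -> None:
        self.iteration = 0
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.datasets: List[Any] = []

    def create_optimizer(self, model: Any) -> Any:
        # A real trainer would return a torch.optim.Optimizer; a dict stands in here.
        return {"lr": self.learning_rate}

    def iterate(self, num_iterations, model, optimizer, device):
        # Yield one stats object per iteration; the loss value is a placeholder.
        for _ in range(num_iterations):
            self.iteration += 1
            yield IterationStats(iteration=self.iteration, loss=1.0 / self.iteration)

    def can_train(self, datasets: List[Any]) -> bool:
        # Toy requirement: at least one dataset must be supplied.
        return len(datasets) > 0

    def build_batch_provider(self, datasets, model, task, snapshot_container) -> None:
        # A real trainer would assemble its data pipeline here.
        self.datasets = list(datasets)


trainer = ToyTrainer()
if trainer.can_train(["dataset_a"]):
    trainer.build_batch_provider(["dataset_a"], model=None, task=None, snapshot_container=None)
    optimizer = trainer.create_optimizer(model=None)
    for stats in trainer.iterate(3, model=None, optimizer=optimizer, device="cpu"):
        print(stats)
```

Instantiating the stand-in ``Trainer`` directly raises ``TypeError`` because its abstract methods are unimplemented, which matches the note that the class is meant to be subclassed.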