dacapo.experiments.trainers
===========================

.. py:module:: dacapo.experiments.trainers


Subpackages
-----------

.. toctree::
   :maxdepth: 1

   /autoapi/dacapo/experiments/trainers/gp_augments/index
   /autoapi/dacapo/experiments/trainers/optimizers/index


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/dacapo/experiments/trainers/dummy_trainer/index
   /autoapi/dacapo/experiments/trainers/dummy_trainer_config/index
   /autoapi/dacapo/experiments/trainers/gunpowder_trainer/index
   /autoapi/dacapo/experiments/trainers/gunpowder_trainer_config/index
   /autoapi/dacapo/experiments/trainers/trainer/index
   /autoapi/dacapo/experiments/trainers/trainer_config/index


Classes
-------

.. autoapisummary::

   dacapo.experiments.trainers.Trainer
   dacapo.experiments.trainers.TrainerConfig
   dacapo.experiments.trainers.DummyTrainerConfig
   dacapo.experiments.trainers.DummyTrainer
   dacapo.experiments.trainers.GunpowderTrainerConfig
   dacapo.experiments.trainers.GunpowderTrainer
   dacapo.experiments.trainers.AugmentConfig


Package Contents
----------------

.. py:class:: Trainer

   Trainer Abstract Base Class

   This serves as the blueprint for all trainer classes in the dacapo library. It defines the methods that every subclass must implement to train a neural network model.

   .. attribute:: iteration

      The number of training iterations.

      :type: int

   .. attribute:: batch_size

      The size of the training batch.

      :type: int

   .. attribute:: learning_rate

      The learning rate for the optimizer.

      :type: float

   .. method:: create_optimizer(model: Model) -> torch.optim.Optimizer

      Creates an optimizer for the model.

   .. method:: iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]

      Performs a number of training iterations.

   .. method:: can_train(datasets: List[Dataset]) -> bool

      Checks if the trainer can train with a specific set of datasets.
   .. method:: build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None

      Initializes the training pipeline using various components.

   .. note::

      The Trainer class is an abstract class that cannot be instantiated directly. It is meant to be subclassed.


   .. py:attribute:: iteration
      :type: int


   .. py:attribute:: batch_size
      :type: int


   .. py:attribute:: learning_rate
      :type: float


   .. py:method:: create_optimizer(model: dacapo.experiments.model.Model) -> torch.optim.Optimizer
      :abstractmethod:


      Creates an optimizer for the model.

      :param model: The model for which the optimizer will be created.
      :type model: Model

      :returns: The optimizer created for the model.
      :rtype: torch.optim.Optimizer

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> optimizer = trainer.create_optimizer(model)

      .. note::

         This method must be implemented by the subclass.


   .. py:method:: iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[dacapo.experiments.training_iteration_stats.TrainingIterationStats]
      :abstractmethod:


      Performs a number of training iterations.

      :param num_iterations: Number of training iterations.
      :type num_iterations: int
      :param model: The model to be trained.
      :type model: Model
      :param optimizer: The optimizer for the model.
      :type optimizer: torch.optim.Optimizer
      :param device: The device (GPU/CPU) where the model will be trained.
      :type device: torch.device

      :returns: An iterator of the training statistics.
      :rtype: Iterator[TrainingIterationStats]

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
      ...     print(iteration_stats)

      .. note::

         This method must be implemented by the subclass.
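      .. rubric:: Sketch: implementing ``iterate`` as a generator

      Because ``iterate`` returns an iterator, a subclass can implement it as a generator that yields one statistics object per training step and advances ``iteration`` as it goes. The self-contained sketch below illustrates only that contract and is not dacapo's implementation: ``BaseTrainer``, ``StubTrainer`` and ``Stats`` are hypothetical stand-ins for ``Trainer``, a concrete subclass and ``TrainingIterationStats``; the ``model``, ``optimizer`` and ``device`` parameters are omitted; and the loss is a dummy value rather than the result of a real training step.

      .. code-block:: python

         from abc import ABC, abstractmethod
         from dataclasses import dataclass
         from typing import Iterator


         @dataclass
         class Stats:
             # Hypothetical stand-in for TrainingIterationStats.
             iteration: int
             loss: float


         class BaseTrainer(ABC):
             # Hypothetical mirror of the abstract Trainer interface (iterate only).
             iteration: int = 0

             @abstractmethod
             def iterate(self, num_iterations: int) -> Iterator[Stats]:
                 """Yield one Stats object per training iteration."""


         class StubTrainer(BaseTrainer):
             def iterate(self, num_iterations: int) -> Iterator[Stats]:
                 for _ in range(num_iterations):
                     # A real trainer would run a forward/backward pass and an
                     # optimizer step here; this dummy loss just shrinks over time.
                     loss = 1.0 / (self.iteration + 1)
                     yield Stats(iteration=self.iteration, loss=loss)
                     # Advance the counter so a later call resumes where this one stopped.
                     self.iteration += 1


         trainer = StubTrainer()
         for stats in trainer.iterate(3):
             print(stats.iteration, stats.loss)

      Instantiating ``BaseTrainer`` directly raises ``TypeError``, mirroring how ``Trainer`` itself cannot be instantiated directly.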
   .. py:method:: can_train(datasets: List[dacapo.experiments.datasplits.datasets.Dataset]) -> bool
      :abstractmethod:


      Checks if the trainer can train with a specific set of datasets.

      Some trainers may have specific requirements for their training datasets.

      :param datasets: The training datasets.
      :type datasets: List[Dataset]

      :returns: True if the trainer can train on the given datasets, False otherwise.
      :rtype: bool

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> can_train = trainer.can_train(datasets)

      .. note::

         This method must be implemented by the subclass.


   .. py:method:: build_batch_provider(datasets: List[dacapo.experiments.datasplits.datasets.Dataset], model: dacapo.experiments.model.Model, task: dacapo.experiments.tasks.task.Task, snapshot_container: dacapo.store.array_store.LocalContainerIdentifier) -> None
      :abstractmethod:


      Initializes the training pipeline using various components.

      This method uses the datasets, model, task, and snapshot_container to set up the training pipeline.

      :param datasets: The datasets to pull data from.
      :type datasets: List[Dataset]
      :param model: The model, which informs the pipeline of the required input and output sizes.
      :type model: Model
      :param task: The task, which transforms ground truth into targets.
      :type task: Task
      :param snapshot_container: Defines where snapshots will be saved.
      :type snapshot_container: LocalContainerIdentifier

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> trainer.build_batch_provider(datasets, model, task, snapshot_container)

      .. note::

         This method must be implemented by the subclass.



.. py:class:: TrainerConfig

   A class to represent trainer configurations.

   It is the base class for all trainer configs: each subclass of :class:`Trainer` should have a matching config class derived from :class:`TrainerConfig`.

   .. attribute:: name

      A unique name for this trainer.

      :type: str
   .. attribute:: batch_size

      The batch size to be used during training.

      :type: int

   .. attribute:: learning_rate

      The learning rate of the optimizer.

      :type: float

   .. method:: verify() -> Tuple[bool, str]

      Verify whether this TrainerConfig is valid or not.

   .. note::

      The TrainerConfig class is an abstract class that cannot be instantiated directly. It is meant to be subclassed.


   .. py:attribute:: name
      :type: str


   .. py:attribute:: batch_size
      :type: int


   .. py:attribute:: learning_rate
      :type: float


   .. py:method:: verify() -> Tuple[bool, str]

      Verify whether this TrainerConfig is valid or not.

      A TrainerConfig is meant to be considered valid if it has a valid batch size and learning rate. However, as the example below shows, the base implementation performs no checks and always reports the config as valid.

      :returns: A tuple containing a boolean indicating whether the TrainerConfig is valid and a message explaining why.
      :rtype: Tuple[bool, str]

      .. rubric:: Examples

      >>> valid, message = trainer_config.verify()
      >>> valid
      True
      >>> message
      'No validation for this Trainer'

      .. note::

         Subclasses should override this method to add their own validation.



.. py:class:: DummyTrainerConfig

   A dummy trainer config used for testing.

   None of its attributes has any particular meaning; it exists only to test the trainer and the trainer config.

   .. attribute:: mirror_augment

      A boolean value indicating whether to use mirror augmentation or not.

      :type: bool

   .. method:: verify() -> Tuple[bool, str]

      This method verifies the DummyTrainerConfig object.


   .. py:attribute:: trainer_type


   .. py:attribute:: mirror_augment
      :type: bool


   .. py:method:: verify() -> Tuple[bool, str]

      Verify the DummyTrainerConfig object.

      :returns: A tuple containing a boolean value indicating whether the DummyTrainerConfig object is valid and a string containing the reason why the object is invalid.
      :rtype: Tuple[bool, str]

      .. rubric:: Examples

      >>> valid, reason = trainer_config.verify()
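   .. rubric:: Sketch: the ``(bool, str)`` convention of ``verify``

   ``TrainerConfig.verify`` and ``DummyTrainerConfig.verify`` report validity as a ``(bool, str)`` pair instead of raising. The sketch below shows one way a config could follow that convention. ``SketchTrainerConfig`` is a hypothetical stand-in written as a plain dataclass (it does not use dacapo's base class), and its checks, a positive batch size and a positive learning rate, are illustrative assumptions rather than dacapo's actual rules.

   .. code-block:: python

      from dataclasses import dataclass
      from typing import Tuple


      @dataclass
      class SketchTrainerConfig:
          # Hypothetical config mirroring the TrainerConfig fields documented above.
          name: str
          batch_size: int
          learning_rate: float

          def verify(self) -> Tuple[bool, str]:
              # Return (False, reason) for the first problem found instead of
              # raising, so callers can report the reason to the user.
              if self.batch_size <= 0:
                  return False, f"batch_size must be positive, got {self.batch_size}"
              if self.learning_rate <= 0:
                  return False, f"learning_rate must be positive, got {self.learning_rate}"
              return True, "Config is valid"


      valid, reason = SketchTrainerConfig("sketch", batch_size=0, learning_rate=1e-4).verify()
      print(valid, reason)  # False batch_size must be positive, got 0

   Returning the reason alongside the flag lets a caller log or display why a config was rejected without wrapping every check in ``try``/``except``.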
.. py:class:: DummyTrainer(trainer_config)

   This class trains a model on dummy data and is used for testing purposes.

   It holds the learning rate, batch size, and mirror-augment setting, and provides methods to create an optimizer, iterate over the training data, build a batch provider, and check whether the trainer can train on a given data split. It also implements the context manager protocol (``__enter__`` and ``__exit__``). The ``iterate`` method yields training iteration statistics.

   .. attribute:: learning_rate

      The learning rate to use.

      :type: float

   .. attribute:: batch_size

      The batch size to use.

      :type: int

   .. attribute:: mirror_augment

      A boolean value indicating whether to use mirror augmentation or not.

      :type: bool

   .. method:: __init__(self, trainer_config)

      This method initializes the DummyTrainer object.

   .. method:: create_optimizer(self, model)

      This method creates an optimizer for the given model.

   .. method:: iterate(self, num_iterations: int, model, optimizer, device)

      This method iterates over the training data for the specified number of iterations.

   .. method:: build_batch_provider(self, datasplit, architecture, task, snapshot_container)

      This method builds a batch provider for the given data split, architecture, task, and snapshot container.

   .. method:: can_train(self, datasplit)

      This method checks if the trainer can train on the given data split.

   .. method:: __enter__(self)

      This method enters the context manager.

   .. method:: __exit__(self, exc_type, exc_val, exc_tb)

      This method exits the context manager.

   .. note::

      The iterate method yields TrainingIterationStats.


   .. py:attribute:: iteration
      :value: 0


   .. py:attribute:: learning_rate


   .. py:attribute:: batch_size


   .. py:attribute:: mirror_augment


   .. py:method:: create_optimizer(model)

      Create an optimizer for the given model.

      :param model: The model to optimize.
      :type model: Model

      :returns: The optimizer object.
      :rtype: torch.optim.Optimizer

      .. rubric:: Examples

      >>> optimizer = trainer.create_optimizer(model)


   .. py:method:: iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer, device)

      Iterate over the training data for the specified number of iterations.

      :param num_iterations: The number of iterations to perform.
      :type num_iterations: int
      :param model: The model to train.
      :type model: Model
      :param optimizer: The optimizer to use.
      :type optimizer: torch.optim.Optimizer
      :param device: The device to perform the computations on.
      :type device: torch.device

      :Yields: *TrainingIterationStats* -- The training iteration statistics.

      :raises ValueError: If the number of iterations is less than or equal to zero.

      .. rubric:: Examples

      >>> for stats in trainer.iterate(num_iterations, model, optimizer, device):
      ...     print(stats)


   .. py:method:: build_batch_provider(datasplit, architecture, task, snapshot_container)

      Build a batch provider for the given data split, architecture, task, and snapshot container.

      :param datasplit: The data split to use.
      :type datasplit: DataSplit
      :param architecture: The architecture to use.
      :type architecture: Architecture
      :param task: The task to perform.
      :type task: Task
      :param snapshot_container: The snapshot container to use.
      :type snapshot_container: SnapshotContainer

      :returns: The batch provider object.
      :rtype: BatchProvider

      :raises ValueError: If the task loss is not set.

      .. rubric:: Examples

      >>> batch_provider = trainer.build_batch_provider(datasplit, architecture, task, snapshot_container)


   .. py:method:: can_train(datasplit)

      Check if the trainer can train on the given data split.

      :param datasplit: The data split to check.
      :type datasplit: DataSplit

      :returns: True if the trainer can train on the data split, False otherwise.
      :rtype: bool

      :raises NotImplementedError: If the method is not implemented.

      .. rubric:: Examples

      >>> trainer.can_train(datasplit)



.. py:class:: GunpowderTrainerConfig

   This class is used to configure a :class:`GunpowderTrainer`.
   It contains attributes for the trainer type, the number of data fetchers, the augmentations to apply, the snapshot interval, the minimum masked value, and whether to clip the raw data.

   .. attribute:: trainer_type

      The type of the trainer, which is set to GunpowderTrainer by default.

      :type: class

   .. attribute:: num_data_fetchers

      The number of CPU workers dedicated to fetching and processing the data.

      :type: int

   .. attribute:: augments

      The list of augments to apply during training.

      :type: List[AugmentConfig]

   .. attribute:: snapshot_interval

      The number of iterations after which a new snapshot is saved.

      :type: Optional[int]

   .. attribute:: min_masked

      The minimum masked value.

      :type: Optional[float]

   .. attribute:: clip_raw

      A boolean value indicating whether the raw data should be clipped to the size of the GT data.

      :type: bool


   .. py:attribute:: trainer_type


   .. py:attribute:: num_data_fetchers
      :type: int


   .. py:attribute:: augments
      :type: List[dacapo.experiments.trainers.gp_augments.AugmentConfig]


   .. py:attribute:: snapshot_interval
      :type: Optional[int]


   .. py:attribute:: min_masked
      :type: Optional[float]


   .. py:attribute:: clip_raw
      :type: bool


   .. py:attribute:: gt_min_reject
      :type: Optional[float]



.. py:class:: GunpowderTrainer(trainer_config)

   GunpowderTrainer class for training a model using gunpowder, a data loading and augmentation library.

   This class is a subclass of :class:`Trainer` and implements the abstract methods that :class:`Trainer` defines. It trains a model on a dataset for a specific task.

   .. attribute:: learning_rate

      The learning rate for the optimizer.

      :type: float

   .. attribute:: batch_size

      The size of the training batch.

      :type: int

   .. attribute:: num_data_fetchers

      The number of data fetchers.

      :type: int

   .. attribute:: print_profiling

      The number of iterations after which to print profiling stats.

      :type: int

   .. attribute:: snapshot_iteration

      The number of iterations after which to save a snapshot.

      :type: int

   .. attribute:: min_masked

      The minimum value of the mask.

      :type: float

   .. attribute:: augments

      The list of augmentations to apply to the data.

      :type: List[Augment]

   .. attribute:: mask_integral_downsample_factor

      The downsample factor for the mask integral.

      :type: int

   .. attribute:: clip_raw

      Whether to clip the raw data.

      :type: bool

   .. attribute:: scheduler

      The learning rate scheduler.

      :type: torch.optim.lr_scheduler.LinearLR

   .. method:: create_optimizer(model: Model) -> torch.optim.Optimizer

      Creates an optimizer for the model.

   .. method:: build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None

      Initializes the training pipeline using various components.

   .. method:: iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]

      Performs a number of training iterations.

   .. method:: __iter__() -> Iterator[None]

      Initializes the training pipeline.

   .. method:: next() -> Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]

      Fetches the next batch of data.

   .. method:: __enter__() -> GunpowderTrainer

      Enters the context manager.

   .. method:: __exit__(exc_type, exc_val, exc_tb) -> None

      Exits the context manager.

   .. method:: can_train(datasets: List[Dataset]) -> bool

      Checks if the trainer can train with a specific set of datasets.

   .. note::

      The GunpowderTrainer class is a subclass of the Trainer class. It is used to train a model using gunpowder.


   .. py:attribute:: iteration
      :value: 0


   .. py:attribute:: learning_rate


   .. py:attribute:: batch_size


   .. py:attribute:: num_data_fetchers


   .. py:attribute:: print_profiling
      :value: 100


   .. py:attribute:: snapshot_iteration


   .. py:attribute:: min_masked


   .. py:attribute:: augments


   .. py:attribute:: mask_integral_downsample_factor
      :value: 4


   .. py:attribute:: clip_raw
   .. py:attribute:: gt_min_reject


   .. py:attribute:: scheduler
      :value: None


   .. py:method:: create_optimizer(model)

      Creates an optimizer for the model.

      :param model: The model for which the optimizer will be created.
      :type model: Model

      :returns: The optimizer created for the model.
      :rtype: torch.optim.Optimizer

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> optimizer = trainer.create_optimizer(model)


   .. py:method:: build_batch_provider(datasets, model, task, snapshot_container=None)

      Initializes the training pipeline using various components.

      :param datasets: The list of datasets.
      :type datasets: List[Dataset]
      :param model: The model to be trained.
      :type model: Model
      :param task: The task to be performed.
      :type task: Task
      :param snapshot_container: The snapshot container.
      :type snapshot_container: LocalContainerIdentifier

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> trainer.build_batch_provider(datasets, model, task, snapshot_container)


   .. py:method:: iterate(num_iterations, model, optimizer, device)

      Performs a number of training iterations.

      :param num_iterations: The number of training iterations.
      :type num_iterations: int
      :param model: The model to be trained.
      :type model: Model
      :param optimizer: The optimizer for the model.
      :type optimizer: torch.optim.Optimizer
      :param device: The device (GPU/CPU) where the model will be trained.
      :type device: torch.device

      :returns: An iterator of the training statistics.
      :rtype: Iterator[TrainingIterationStats]

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
      ...     print(iteration_stats)


   .. py:method:: next()

      Fetches the next batch of data.

      :returns: A tuple containing the raw data, ground truth data, target data, weight data, and mask data.
      :rtype: Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> raw, gt, target, weight, mask = trainer.next()


   .. py:method:: can_train(datasets) -> bool

      Checks if the trainer can train with a specific set of datasets.

      :param datasets: The list of datasets.
      :type datasets: List[Dataset]

      :returns: True if the trainer can train with the datasets, False otherwise.
      :rtype: bool

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> can_train = trainer.can_train(datasets)


   .. py:method:: visualize_pipeline(bind_address='0.0.0.0', bind_port=0)

      Visualizes the pipeline for the run, including all produced arrays.

      :param bind_address: Bind address for the Neuroglancer web server.
      :type bind_address: str
      :param bind_port: Bind port for the Neuroglancer web server.
      :type bind_port: int



.. py:class:: AugmentConfig

   Base class for gunpowder augment configurations.

   Each subclass of an ``Augment`` should have a corresponding config class derived from :class:`AugmentConfig`.

   .. attribute:: _raw_key

      Key for raw data. Not used in this implementation. Defaults to None.

   .. attribute:: _gt_key

      Key for ground truth data. Not used in this implementation. Defaults to None.

   .. attribute:: _mask_key

      Key for mask data. Not used in this implementation. Defaults to None.

   .. method:: node(_raw_key=None, _gt_key=None, _mask_key=None)

      Get a gp.Augment node.


   .. py:method:: node(raw_key: gunpowder.ArrayKey, gt_key: gunpowder.ArrayKey, mask_key: gunpowder.ArrayKey) -> gunpowder.BatchFilter
      :abstractmethod:


      Get a gunpowder augment node.

      :param raw_key: Key for raw data.
      :type raw_key: gp.ArrayKey
      :param gt_key: Key for ground truth data.
      :type gt_key: gp.ArrayKey
      :param mask_key: Key for mask data.
      :type mask_key: gp.ArrayKey

      :returns: Augmentation node which can be incorporated in the pipeline.
      :rtype: gunpowder.BatchFilter

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> node = augment_config.node(raw_key, gt_key, mask_key)
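   .. rubric:: Sketch: a concrete augment config

   A concrete augment config subclasses ``AugmentConfig`` and builds its node from the keys passed to ``node``. So that it runs without gunpowder installed, the sketch below replaces ``gp.ArrayKey`` and the returned ``gunpowder.BatchFilter`` with hypothetical stand-ins (``ArrayKey``, and ``FakeNode``, which only records the keys and parameters a real node would receive). ``AugmentConfigSketch``, ``FlipAugmentSketch`` and its ``probability`` field are likewise invented for illustration; a real subclass would return an actual gunpowder node instead.

   .. code-block:: python

      from abc import ABC, abstractmethod
      from dataclasses import dataclass
      from typing import Tuple


      @dataclass(frozen=True)
      class ArrayKey:
          # Hypothetical stand-in for gunpowder.ArrayKey.
          name: str


      @dataclass
      class FakeNode:
          # Hypothetical stand-in for a gunpowder.BatchFilter: it only records the
          # arrays a real augment node would modify and its parameters.
          arrays: Tuple[ArrayKey, ...]
          probability: float


      class AugmentConfigSketch(ABC):
          # Mirrors the abstract node() signature of AugmentConfig.
          @abstractmethod
          def node(self, raw_key: ArrayKey, gt_key: ArrayKey, mask_key: ArrayKey) -> FakeNode:
              """Return a pipeline node that augments the given arrays."""


      @dataclass
      class FlipAugmentSketch(AugmentConfigSketch):
          probability: float = 0.5

          def node(self, raw_key: ArrayKey, gt_key: ArrayKey, mask_key: ArrayKey) -> FakeNode:
              # A spatial augment such as a flip must transform raw, GT, and mask
              # identically, so the node covers all three keys.
              return FakeNode(arrays=(raw_key, gt_key, mask_key), probability=self.probability)


      config = FlipAugmentSketch(probability=0.25)
      node = config.node(ArrayKey("RAW"), ArrayKey("GT"), ArrayKey("MASK"))
      print([key.name for key in node.arrays], node.probability)  # ['RAW', 'GT', 'MASK'] 0.25

   Keeping the config as plain data and deferring node construction to ``node`` lets the same config be reused with whatever array keys a particular pipeline defines.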