dacapo.experiments.tasks.losses
===============================

.. py:module:: dacapo.experiments.tasks.losses


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/dacapo/experiments/tasks/losses/affinities_loss/index
   /autoapi/dacapo/experiments/tasks/losses/dummy_loss/index
   /autoapi/dacapo/experiments/tasks/losses/hot_distance_loss/index
   /autoapi/dacapo/experiments/tasks/losses/loss/index
   /autoapi/dacapo/experiments/tasks/losses/mse_loss/index


Classes
-------

.. autoapisummary::

   dacapo.experiments.tasks.losses.DummyLoss
   dacapo.experiments.tasks.losses.MSELoss
   dacapo.experiments.tasks.losses.Loss
   dacapo.experiments.tasks.losses.AffinitiesLoss
   dacapo.experiments.tasks.losses.HotDistanceLoss


Package Contents
----------------

.. py:class:: DummyLoss

   A dummy loss function that sums the absolute differences between each prediction and its target. Inherits from :py:class:`Loss`.

   .. attribute:: name

      str
      Name of the loss function.

   .. method:: compute(prediction, target, weight=None)

      Calculate the total loss between prediction and target.

   .. note::

      The dummy loss exists to test the training loop and the loss calculation. It is not a real loss function.

   .. py:method:: compute(prediction, target, weight=None)

      Calculate the total dummy loss.

      :param prediction: The model's prediction.
      :type prediction: torch.Tensor
      :param target: The target values.
      :type target: torch.Tensor
      :param weight: The weight to apply to the loss.
      :type weight: torch.Tensor
      :returns: The total loss between prediction and target.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> dummy_loss = DummyLoss()
      >>> prediction = torch.tensor([1, 2, 3])
      >>> target = torch.tensor([4, 5, 6])
      >>> dummy_loss.compute(prediction, target)
      tensor(9)


.. py:class:: MSELoss

   The mean squared error (MSE) loss function. Inherits from :py:class:`Loss`.

   .. method:: compute(prediction, target, weight) -> torch.Tensor

      Compute the MSE loss for the given prediction and target, taking the weight into account.

   .. note::

      This is a concrete subclass of :py:class:`Loss`; it implements the abstract ``compute`` method.

   .. py:method:: compute(prediction, target, weight)

      Compute the MSE loss for the given prediction and target, taking the weight into account.

      :param prediction: The predicted tensor.
      :type prediction: torch.Tensor
      :param target: The target tensor.
      :type target: torch.Tensor
      :param weight: The weight tensor.
      :type weight: torch.Tensor
      :returns: The computed MSE loss.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> loss = MSELoss()
      >>> prediction = torch.tensor([1.0, 2.0, 3.0])
      >>> target = torch.tensor([1.0, 2.0, 3.0])
      >>> weight = torch.tensor([1.0, 1.0, 1.0])
      >>> loss.compute(prediction, target, weight)
      tensor(0.)


.. py:class:: Loss

   The abstract base class for loss functions. Every loss function class should inherit from it.

   .. method:: compute(prediction, target, weight) -> torch.Tensor

      Compute the loss for the given prediction and target, taking the weight into account.

   .. note::

      This class is abstract. Subclasses must implement the abstract methods. Once created, the values of its attributes cannot be changed.

   .. py:method:: compute(prediction: torch.Tensor, target: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor
      :abstractmethod:

      Compute the loss for the given prediction and target. If a loss weight is given, it should be taken into account. All arguments are ``torch`` tensors.

      The return value should be a ``torch`` scalar that can be used with an optimizer, just as usual when training with ``torch``.

      :param prediction: The predicted tensor.
      :param target: The target tensor.
      :param weight: The weight tensor.
      :returns: The computed loss tensor.
      :raises NotImplementedError: If the method is not implemented in the subclass.

      .. rubric:: Examples

      >>> loss = MSELoss()
      >>> prediction = torch.tensor([1.0, 2.0, 3.0])
      >>> target = torch.tensor([1.0, 2.0, 3.0])
      >>> weight = torch.tensor([1.0, 1.0, 1.0])
      >>> loss.compute(prediction, target, weight)
      tensor(0.)

      .. note::

         This method must be implemented in the subclass. It should return the computed loss tensor.


.. py:class:: AffinitiesLoss(num_affinities: int, lsds_to_affs_weight_ratio: float)

   A loss function for predictions that contain both affinities and local shape descriptors (LSDs). It combines the loss on the affinities with the loss on the LSDs.

   .. attribute:: num_affinities

      int
      The number of affinities.

   .. attribute:: lsds_to_affs_weight_ratio

      float
      The weight of the LSD loss relative to the affinities loss.

   .. method:: compute(prediction, target, weight=None)

      Calculate the total loss between prediction and target.

   .. py:attribute:: num_affinities

   .. py:attribute:: lsds_to_affs_weight_ratio

   .. py:method:: compute(prediction, target, weight)

      Calculate the total loss over the affinities and the LSDs.

      :param prediction: The model's prediction.
      :type prediction: torch.Tensor
      :param target: The target values.
      :type target: torch.Tensor
      :param weight: The weight to apply to the loss.
      :type weight: torch.Tensor
      :returns: The total loss over the affinities and the LSDs.
      :rtype: torch.Tensor
      :raises ValueError: If the number of affinities in the prediction and target does not match.

      .. rubric:: Examples

      >>> affinities_loss = AffinitiesLoss(3, 0.5)
      >>> prediction = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
      >>> target = torch.tensor([[9, 10, 11, 12], [13, 14, 15, 16]])
      >>> weight = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]])
      >>> affinities_loss.compute(prediction, target, weight)
      tensor(0.5)


.. py:class:: HotDistanceLoss

   The Hot Distance Loss function. Inherits from :py:class:`Loss`.

   The Hot Distance Loss is used for predicting hot and distance maps at the same time. The first half of the channels are the hot maps, the second half are the distance maps. The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. The model should predict twice the number of channels as the target.

   .. attribute:: hot_loss

      The Binary Cross Entropy Loss function.

   .. attribute:: distance_loss

      The Mean Square Error Loss function.

   .. method:: compute(prediction, target, weight) -> torch.Tensor

      Compute the Hot Distance Loss for the given prediction and target, taking the weight into account.

   .. method:: split(x) -> Tuple[torch.Tensor, torch.Tensor]

      Split the input tensor into two tensors.

   .. note::

      This is a concrete subclass of :py:class:`Loss`; it implements the abstract ``compute`` method.

   .. py:method:: compute(prediction, target, weight)

      Compute the Hot Distance Loss for the given prediction and target, taking the weight into account.

      :param prediction: The predicted tensor.
      :type prediction: torch.Tensor
      :param target: The target tensor.
      :type target: torch.Tensor
      :param weight: The weight tensor.
      :type weight: torch.Tensor
      :returns: The computed Hot Distance Loss.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> loss = HotDistanceLoss()
      >>> prediction = torch.tensor([1.0, 2.0, 3.0])
      >>> target = torch.tensor([1.0, 2.0, 3.0])
      >>> weight = torch.tensor([1.0, 1.0, 1.0])
      >>> loss.compute(prediction, target, weight)
      tensor(0.)

   .. py:method:: hot_loss(prediction, target, weight)

      The Binary Cross Entropy Loss function. Computes the BCELoss for the hot maps.

      :param prediction: The predicted tensor.
      :type prediction: torch.Tensor
      :param target: The target tensor.
      :type target: torch.Tensor
      :param weight: The weight tensor.
      :type weight: torch.Tensor
      :returns: The computed BCELoss.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> loss = HotDistanceLoss()
      >>> prediction = torch.tensor([1.0, 2.0, 3.0])
      >>> target = torch.tensor([1.0, 2.0, 3.0])
      >>> weight = torch.tensor([1.0, 1.0, 1.0])
      >>> loss.hot_loss(prediction, target, weight)
      tensor(0.)

   .. py:method:: distance_loss(prediction, target, weight)

      The Mean Square Error Loss function. Computes the MSELoss for the distance maps.

      :param prediction: The predicted tensor.
      :type prediction: torch.Tensor
      :param target: The target tensor.
      :type target: torch.Tensor
      :param weight: The weight tensor.
      :type weight: torch.Tensor
      :returns: The computed MSELoss.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> loss = HotDistanceLoss()
      >>> prediction = torch.tensor([1.0, 2.0, 3.0])
      >>> target = torch.tensor([1.0, 2.0, 3.0])
      >>> weight = torch.tensor([1.0, 1.0, 1.0])
      >>> loss.distance_loss(prediction, target, weight)
      tensor(0.)

   .. py:method:: split(x)

      Split the input tensor into two tensors.

      :param x: The input tensor.
      :type x: torch.Tensor
      :returns: The two split tensors.
      :rtype: Tuple[torch.Tensor, torch.Tensor]

      .. rubric:: Examples

      >>> loss = HotDistanceLoss()
      >>> x = torch.tensor([1.0, 2.0, 3.0])
      >>> loss.split(x)
      (tensor([1.0]), tensor([2.0]))
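
The channel layout described for ``HotDistanceLoss`` can be illustrated without ``torch``. The following is a minimal sketch, not dacapo's implementation: plain Python lists stand in for tensors, weights are omitted, the hot-map predictions are assumed to already be probabilities in (0, 1), the hot and distance targets are passed separately, and the helper names (``split_channels``, ``bce``, ``mse``, ``hot_distance_loss``) are hypothetical.

```python
import math


def split_channels(channels):
    """Split a list of channels into (first half, second half)."""
    if len(channels) % 2 != 0:
        raise ValueError("expected an even number of channels")
    mid = len(channels) // 2
    return channels[:mid], channels[mid:]


def flatten(channels):
    """Concatenate all channel values into one flat list."""
    return [value for channel in channels for value in channel]


def bce(pred, target, eps=1e-12):
    """Mean binary cross entropy for probabilities and 0/1 labels."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total / len(pred)


def mse(pred, target):
    """Mean squared error over two flat lists of equal length."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def hot_distance_loss(prediction, hot_target, distance_target):
    """Sum of BCE on the first half of the prediction channels (hot maps)
    and MSE on the second half (distance maps)."""
    pred_hot, pred_distance = split_channels(prediction)
    return bce(flatten(pred_hot), flatten(hot_target)) + mse(
        flatten(pred_distance), flatten(distance_target)
    )


# Two prediction channels: one hot map followed by one distance map.
prediction = [[0.5, 0.5], [1.0, 3.0]]
hot_target = [[1.0, 0.0]]
distance_target = [[1.0, 1.0]]

loss = hot_distance_loss(prediction, hot_target, distance_target)
print(loss)  # BCE term is ln(2), MSE term is 2.0
```

Here the prediction carries twice as many channels as either target, matching the class description, and the two terms are simply added without any relative weighting.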