dacapo.experiments.tasks.losses.hot_distance_loss
=================================================

.. py:module:: dacapo.experiments.tasks.losses.hot_distance_loss


Classes
-------

.. autoapisummary::

   dacapo.experiments.tasks.losses.hot_distance_loss.HotDistanceLoss


Module Contents
---------------

.. py:class:: HotDistanceLoss

   Loss function for predicting hot maps and distance maps at the same time. This class inherits from the ``Loss`` class.

   The first half of the channels holds the hot maps and the second half holds the distance maps. The loss is the sum of a binary cross-entropy (BCE) loss on the hot maps and a mean squared error (MSE) loss on the distance maps. The model should predict twice as many channels as the target.

   .. method:: compute(prediction, target, weight) -> torch.Tensor

      Computes the Hot Distance Loss for the given prediction and target, weighted by ``weight``.

   .. method:: hot_loss(prediction, target, weight) -> torch.Tensor

      Computes the binary cross-entropy loss for the hot maps.

   .. method:: distance_loss(prediction, target, weight) -> torch.Tensor

      Computes the mean squared error loss for the distance maps.

   .. method:: split(x) -> Tuple[torch.Tensor, torch.Tensor]

      Splits the input tensor into its hot-map half and its distance-map half.

   .. note::

      The ``Loss`` base class is abstract; ``HotDistanceLoss`` provides concrete implementations of the methods above, so it can be instantiated directly. Once created, the values of its attributes cannot be changed.


   .. py:method:: compute(prediction, target, weight)

      Computes the Hot Distance Loss for the given prediction and target, weighted by ``weight``. The result is the sum of :py:meth:`hot_loss` on the hot-map channels and :py:meth:`distance_loss` on the distance-map channels.

      :param prediction: The predicted tensor.
      :type prediction: torch.Tensor
      :param target: The target tensor.
      :type target: torch.Tensor
      :param weight: The weight tensor.
      :type weight: torch.Tensor
      :returns: The computed Hot Distance Loss.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> import torch
      >>> from dacapo.experiments.tasks.losses.hot_distance_loss import HotDistanceLoss
      >>> loss = HotDistanceLoss()
      >>> # Assumed layout: (batch, channels, height, width); channel 0 is the hot map, channel 1 the distance map.
      >>> prediction = torch.rand(1, 2, 4, 4)
      >>> hot = torch.randint(0, 2, (1, 1, 4, 4)).float()
      >>> distance = torch.randn(1, 1, 4, 4)
      >>> target = torch.cat([hot, distance], dim=1)
      >>> weight = torch.ones(1, 2, 4, 4)
      >>> value = loss.compute(prediction, target, weight)


   .. py:method:: hot_loss(prediction, target, weight)

      Computes the binary cross-entropy (BCE) loss for the hot maps.

      :param prediction: The predicted hot maps.
      :type prediction: torch.Tensor
      :param target: The target hot maps.
      :type target: torch.Tensor
      :param weight: The weight tensor.
      :type weight: torch.Tensor
      :returns: The computed BCE loss.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> import torch
      >>> from dacapo.experiments.tasks.losses.hot_distance_loss import HotDistanceLoss
      >>> loss = HotDistanceLoss()
      >>> prediction = torch.tensor([0.9, 0.1, 0.8])
      >>> target = torch.tensor([1.0, 0.0, 1.0])  # hot-map targets are binary
      >>> weight = torch.tensor([1.0, 1.0, 1.0])
      >>> value = loss.hot_loss(prediction, target, weight)


   .. py:method:: distance_loss(prediction, target, weight)

      Computes the mean squared error (MSE) loss for the distance maps.

      :param prediction: The predicted distance maps.
      :type prediction: torch.Tensor
      :param target: The target distance maps.
      :type target: torch.Tensor
      :param weight: The weight tensor.
      :type weight: torch.Tensor
      :returns: The computed MSE loss.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> import torch
      >>> from dacapo.experiments.tasks.losses.hot_distance_loss import HotDistanceLoss
      >>> loss = HotDistanceLoss()
      >>> prediction = torch.tensor([1.0, 2.0, 3.0])
      >>> target = torch.tensor([1.0, 2.0, 3.0])  # identical to the prediction
      >>> weight = torch.tensor([1.0, 1.0, 1.0])
      >>> loss.distance_loss(prediction, target, weight)
      tensor(0.)


   .. py:method:: split(x)

      Splits the input tensor into two tensors of equal size along the channel dimension: the first half holds the hot maps and the second half holds the distance maps.

      :param x: The input tensor. Its number of channels must be even.
      :type x: torch.Tensor
      :returns: The hot-map half and the distance-map half of ``x``.
      :rtype: Tuple[torch.Tensor, torch.Tensor]

      .. rubric:: Examples

      >>> import torch
      >>> from dacapo.experiments.tasks.losses.hot_distance_loss import HotDistanceLoss
      >>> loss = HotDistanceLoss()
      >>> x = torch.zeros(1, 4, 8, 8)  # assumed layout: (batch, channels, height, width)
      >>> hot, distance = loss.split(x)
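
The arithmetic described for ``compute``, ``hot_loss``, ``distance_loss`` and ``split`` can be sketched in plain Python, without PyTorch. This is a minimal sketch, not the dacapo implementation. It assumes that prediction, target and weight share the same channel layout (hot-map channels first, distance-map channels second), that hot-map predictions are probabilities in (0, 1) rather than logits, and that the weight scales each per-element loss before an unweighted mean. The names ``split_channels``, ``weighted_bce``, ``weighted_mse`` and ``hot_distance_loss`` are hypothetical.

.. code-block:: python

   import math
   from typing import Iterator, List, Tuple

   # Hypothetical layout: a list of channels, each channel a flat list of values.
   Channels = List[List[float]]


   def split_channels(x: Channels) -> Tuple[Channels, Channels]:
       """Split channels in half: first half hot maps, second half distance maps."""
       if len(x) % 2 != 0:
           raise ValueError(f"expected an even number of channels, got {len(x)}")
       mid = len(x) // 2
       return x[:mid], x[mid:]


   def _elements(
       pred: Channels, target: Channels, weight: Channels
   ) -> Iterator[Tuple[float, float, float]]:
       """Yield aligned (prediction, target, weight) values across all channels."""
       for p_ch, t_ch, w_ch in zip(pred, target, weight):
           yield from zip(p_ch, t_ch, w_ch)


   def weighted_bce(
       pred: Channels, target: Channels, weight: Channels, eps: float = 1e-7
   ) -> float:
       """Mean of weight * binary cross-entropy, treating pred as probabilities."""
       terms = []
       for p, t, w in _elements(pred, target, weight):
           p = min(max(p, eps), 1.0 - eps)  # clamp so log() stays finite
           terms.append(-w * (t * math.log(p) + (1.0 - t) * math.log(1.0 - p)))
       return sum(terms) / len(terms)


   def weighted_mse(pred: Channels, target: Channels, weight: Channels) -> float:
       """Mean of weight * squared error."""
       terms = [w * (p - t) ** 2 for p, t, w in _elements(pred, target, weight)]
       return sum(terms) / len(terms)


   def hot_distance_loss(
       prediction: Channels, target: Channels, weight: Channels
   ) -> float:
       """BCE on the hot-map halves plus MSE on the distance-map halves."""
       pred_hot, pred_dist = split_channels(prediction)
       tgt_hot, tgt_dist = split_channels(target)
       w_hot, w_dist = split_channels(weight)
       return weighted_bce(pred_hot, tgt_hot, w_hot) + weighted_mse(
           pred_dist, tgt_dist, w_dist
       )

For example, ``hot_distance_loss([[0.9, 0.1], [1.0, 2.0]], [[1.0, 0.0], [1.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]])`` returns ``-log(0.9)``, about 0.105: the BCE term stays positive even for confident, correct hot-map predictions, while the exactly matched distance channel contributes zero MSE. Setting a weight to zero removes that element's contribution from either term.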