dacapo.utils.balance_weights
============================

.. py:module:: dacapo.utils.balance_weights


Functions
---------

.. autoapisummary::

   dacapo.utils.balance_weights.balance_weights


Module Contents
---------------

.. py:function:: balance_weights(label_data: numpy.ndarray, num_classes: int, masks: List[numpy.ndarray] = list(), slab=None, clipmin: float = 0.05, clipmax: float = 0.95, moving_counts: Optional[List[Dict[int, Tuple[int, int]]]] = None)

   Computes a per-element error scale that weights each class inversely to its frequency in the masked label data.

   :param label_data: Array of integer class labels.
   :type label_data: np.ndarray
   :param num_classes: The number of classes.
   :type num_classes: int
   :param masks: Masks that are multiplied into the error scale. Elements where a mask is zero are excluded from the class counts. Defaults to an empty list.
   :type masks: List[np.ndarray], optional
   :param slab: Per-dimension slab size. Class fractions and weights are computed independently within each slab. Defaults to None.
   :type slab: optional
   :param clipmin: Lower bound applied to the class fractions before they are inverted into weights. Defaults to 0.05.
   :type clipmin: float, optional
   :param clipmax: Upper bound applied to the class fractions before they are inverted into weights. Defaults to 0.95.
   :type clipmax: float, optional
   :param moving_counts: Running counts, one dict per slab, mapping each class id to a ``(numerator, denominator)`` pair. The dicts are updated in place and returned. Defaults to None.
   :type moving_counts: Optional[List[Dict[int, Tuple[int, int]]]], optional

   :returns: The balanced error scale and the updated moving counts.
   :rtype: Tuple[np.ndarray, List[Dict[int, Tuple[int, int]]]]

   :raises AssertionError: If the number of unique labels is greater than the number of classes.
   :raises AssertionError: If the minimum label is less than 0 or the maximum label is greater than the number of classes.

   .. rubric:: Examples

   >>> label_data = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]])
   >>> num_classes = 3
   >>> masks = [np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])]
   >>> balance_weights(label_data, num_classes, masks)
   (array([[0.33333334, 0.33333334, 0.33333334],
           [0.33333334, 0.33333334, 0.33333334],
           [0.33333334, 0.33333334, 0.33333334]], dtype=float32),
    [{0: (3, 9), 1: (3, 9), 2: (3, 9)}])

   .. rubric:: Notes

   The balanced error scale is computed as::

       # Start from ones and zero out masked elements.
       error_scale = np.ones(label_data.shape, dtype=np.float32)
       for mask in masks:
           error_scale = error_scale * mask

       # Visit every slab by its start coordinate.
       slab_ranges = (range(0, m, s) for m, s in zip(error_scale.shape, slab))
       for ind, start in enumerate(itertools.product(*slab_ranges)):
           slab_counts = moving_counts[ind]
           slices = tuple(slice(start[d], start[d] + slab[d]) for d in range(len(slab)))
           scale_slab = error_scale[slices]
           labels_slab = label_data[slices]

           # Count each class among the elements that survive the masks.
           masked_in = scale_slab.sum()
           classes, counts = np.unique(labels_slab[np.nonzero(scale_slab)], return_counts=True)

           # Update the running (numerator, denominator) pair of every class.
           updated_fracs = []
           for key, (num, den) in slab_counts.items():
               slab_counts[key] = (num, den + masked_in)
           for class_id, num in zip(classes, counts):
               (old_num, den) = slab_counts[class_id]
               slab_counts[class_id] = (num + old_num, den)
               updated_fracs.append(slab_counts[class_id][0] / slab_counts[class_id][1])

           # Clip the class fractions, then invert them into weights.
           fracs = np.array(updated_fracs)
           if clipmin is not None or clipmax is not None:
               np.clip(fracs, clipmin, clipmax, fracs)
           total_frac = 1.0
           w_sparse = total_frac / float(num_classes) / fracs
           w = np.zeros(num_classes)
           w[classes] = w_sparse

           # Scale each element of the slab by the weight of its class.
           labels_slab = labels_slab.astype(np.int64)
           scale_slab *= np.take(w, labels_slab)
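
The weighting rule in the Notes can be tried in isolation with a minimal sketch. It is not the library's implementation: the helper name ``single_slab_weights`` is hypothetical, the whole array is treated as one slab, a single optional mask is assumed, and the moving counts are left out, so class fractions come only from the current array.

```python
import numpy as np


def single_slab_weights(label_data, num_classes, mask=None, clipmin=0.05, clipmax=0.95):
    """Hypothetical helper: balance weights over one slab spanning the whole array."""
    # Start from ones and zero out masked elements.
    error_scale = np.ones(label_data.shape, dtype=np.float32)
    if mask is not None:
        error_scale = error_scale * mask.astype(np.float32)

    # Count each class among the elements that survive the mask.
    masked_in = error_scale.sum()
    classes, counts = np.unique(label_data[np.nonzero(error_scale)], return_counts=True)

    # Clip class fractions so rare or dominant classes do not get extreme weights.
    fracs = np.clip(counts / masked_in, clipmin, clipmax)

    # w = 1 / (num_classes * frac); classes absent from the slab keep weight 0.
    w = np.zeros(num_classes, dtype=np.float32)
    w[classes] = 1.0 / float(num_classes) / fracs

    # Scale every element by the weight of its class.
    return error_scale * np.take(w, label_data.astype(np.int64))


# Class 0 covers 6 of 8 elements, class 1 covers 2 of 8.
labels = np.array([[0, 0, 0, 1], [0, 0, 0, 1]])
weights = single_slab_weights(labels, num_classes=2)
# Class 0 elements get 1 / (2 * 0.75) = 0.667, class 1 elements get 1 / (2 * 0.25) = 2.0.
```

In this unclipped case both classes contribute the same total weight (6 × 0.667 ≈ 2 × 2.0 = 4), which is the point of the balancing. Passing a mask with zeros removes those elements from the counts and leaves their weight at zero.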