dacapo.utils.balance_weights
Functions
balance_weights: Balances the weights based on the label data and other parameters.
Module Contents
- dacapo.utils.balance_weights.balance_weights(label_data: numpy.ndarray, num_classes: int, masks: List[numpy.ndarray] = list(), slab=None, clipmin: float = 0.05, clipmax: float = 0.95, moving_counts: List[Dict[int, Tuple[int, int]]] | None = None)
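Balances the weights based on the label data: computes a per-element error scale in which each class is weighted inversely to its clipped frequency in the masked-in region, slab by slab, using class counts that can be carried over between calls.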
- Parameters:
label_data (np.ndarray) – The label data.
num_classes (int) – The number of classes.
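masks (List[np.ndarray], optional) – Masks that are multiplied into the error scale. Elements where any mask is 0 are excluded from the class counts and end up with a weight of 0. Defaults to an empty list.
slab (optional) – Per-axis slab size. The array is split into slabs of this shape and each slab is balanced independently (see Notes). Defaults to None.
clipmin (float, optional) – Lower bound applied to the per-class fractions before the weights are computed. Defaults to 0.05.
clipmax (float, optional) – Upper bound applied to the per-class fractions before the weights are computed. Defaults to 0.95.
moving_counts (Optional[List[Dict[int, Tuple[int, int]]]], optional) – Running class counts, one dict per slab, each mapping a class id to a (numerator, denominator) pair. Defaults to None.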
- Returns:
The balanced error scale (same shape as label_data) and the updated moving counts.
- Return type:
Tuple[np.ndarray, List[Dict[int, Tuple[int, int]]]]
- Raises:
AssertionError – If the number of unique labels is greater than the number of classes.
AssertionError – If the minimum label is less than 0 or the maximum label is greater than or equal to the number of classes; labels must lie in [0, num_classes).
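To illustrate what the slab parameter does, the sketch below reproduces the slab tiling from the Notes in isolation. It is a minimal sketch, not the library's code: slab_slices is a hypothetical helper introduced here, and the 4x4 array and slab shape (2, 4) are made-up inputs.

```python
import itertools

import numpy as np


def slab_slices(shape, slab):
    """Yield one tuple of slices per slab (hypothetical helper, not a dacapo API)."""
    # One range of start offsets per axis, stepping by the slab size on that axis.
    slab_ranges = (range(0, m, s) for m, s in zip(shape, slab))
    for start in itertools.product(*slab_ranges):
        yield tuple(slice(start[d], start[d] + slab[d]) for d in range(len(slab)))


# Made-up input: a 4x4 array split into slabs of 2 rows by 4 columns.
labels = np.arange(16).reshape(4, 4)
slices = list(slab_slices(labels.shape, slab=(2, 4)))

# Two slabs result: rows 0-1 and rows 2-3, each spanning all four columns.
first, second = (labels[s] for s in slices)
```

Each slab is a view into the original array, which is why the in-place scaling in the Notes modifies the full error scale.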
Examples
>>> label_data = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]])
>>> num_classes = 3
>>> masks = [np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])]
>>> balance_weights(label_data, num_classes, masks)
(array([[0.33333334, 0.33333334, 0.33333334],
       [0.33333334, 0.33333334, 0.33333334],
       [0.33333334, 0.33333334, 0.33333334]], dtype=float32),
 [{0: (3, 9), 1: (3, 9), 2: (3, 9)}])
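The moving_counts bookkeeping can be made concrete by replaying the count update from the Notes for a single slab over two batches. This is only a sketch under stated assumptions: update_counts is a hypothetical helper, and the starting value of (0, 0) per class is chosen for illustration, since how the function initializes the counts when moving_counts is None is not described here.

```python
import numpy as np


def update_counts(slab_counts, labels, scale):
    """Fold one batch into running (numerator, denominator) pairs, in place.

    Hypothetical helper mirroring the count update shown in the Notes.
    """
    masked_in = scale.sum()
    classes, counts = np.unique(labels[np.nonzero(scale)], return_counts=True)
    # Every class sees its denominator grow by the number of masked-in elements.
    for key, (num, den) in slab_counts.items():
        slab_counts[key] = (num, den + masked_in)
    # Only classes present in this batch see their numerator grow.
    for class_id, num in zip(classes, counts):
        old_num, den = slab_counts[class_id]
        slab_counts[class_id] = (old_num + num, den)
    return slab_counts


# Assumed starting state for illustration: nothing observed yet.
counts = {0: (0, 0), 1: (0, 0)}
ones = np.ones(4, dtype=np.float32)
update_counts(counts, np.array([0, 0, 0, 1]), ones)  # batch 1: three 0s, one 1
update_counts(counts, np.array([1, 1, 1, 1]), ones)  # batch 2: only 1s

# Running fraction of each class over both batches.
fractions = {k: float(num) / float(den) for k, (num, den) in counts.items()}
```

Because the counts accumulate, the fraction for a class reflects every batch seen so far and not just the current one, which smooths the weights when a class is rare or missing in a single batch.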
Notes
The balanced error scale is computed as:

    error_scale = np.ones(label_data.shape, dtype=np.float32)
    for mask in masks:
        error_scale = error_scale * mask
    slab_ranges = (range(0, m, s) for m, s in zip(error_scale.shape, slab))
    for ind, start in enumerate(itertools.product(*slab_ranges)):
        slab_counts = moving_counts[ind]
        slices = tuple(slice(start[d], start[d] + slab[d]) for d in range(len(slab)))
        scale_slab = error_scale[slices]
        labels_slab = label_data[slices]
        masked_in = scale_slab.sum()
        classes, counts = np.unique(labels_slab[np.nonzero(scale_slab)], return_counts=True)
        updated_fracs = []
        for key, (num, den) in slab_counts.items():
            slab_counts[key] = (num, den + masked_in)
        for class_id, num in zip(classes, counts):
            (old_num, den) = slab_counts[class_id]
            slab_counts[class_id] = (num + old_num, den)
            updated_fracs.append(slab_counts[class_id][0] / slab_counts[class_id][1])
        fracs = np.array(updated_fracs)
        if clipmin is not None or clipmax is not None:
            np.clip(fracs, clipmin, clipmax, fracs)
        total_frac = 1.0
        w_sparse = total_frac / float(num_classes) / fracs
        w = np.zeros(num_classes)
        w[classes] = w_sparse
        labels_slab = labels_slab.astype(np.int64)
        scale_slab *= np.take(w, labels_slab)
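The final step of the computation above amounts to a weight of 1 / (num_classes * clip(frac)) per class. The sketch below isolates that step with made-up fractions to show the effect of clipping. weights_from_fractions is a hypothetical helper written for this illustration and is not part of dacapo.

```python
import numpy as np


def weights_from_fractions(fracs, classes, num_classes, clipmin=0.05, clipmax=0.95):
    """Hypothetical helper: per-class weights from observed class fractions."""
    # Bound the fractions so very rare or dominant classes cannot produce extreme weights.
    fracs = np.clip(np.asarray(fracs, dtype=np.float64), clipmin, clipmax)
    w = np.zeros(num_classes)
    # Classes absent from `classes` keep a weight of 0.
    w[classes] = 1.0 / float(num_classes) / fracs
    return w


# Made-up fractions: class 0 covers 99% of the masked-in elements, class 1 only 1%.
w = weights_from_fractions([0.99, 0.01], classes=[0, 1], num_classes=2)

# Look up the weight of each element's class, as np.take does in the computation above.
labels = np.array([0, 0, 1])
scale = np.take(w, labels)
```

With the default bounds, the 1% class is treated as if it covered 5% of the elements, so its weight is capped at 1 / (2 * 0.05) instead of growing to 1 / (2 * 0.01). A perfectly balanced input, where every fraction equals 1 / num_classes, yields a weight of 1 for every class.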