Source code for aim.sdk.objects.distribution

import numpy as np

from typing import Tuple

from aim.storage.object import CustomObject
from aim.storage.types import BLOB


@CustomObject.alias('aim.distribution')
class Distribution(CustomObject):
    """Distribution object used to store distribution objects in Aim repository.

    The input is binned with `np.histogram`; the per-bin counts are stored as
    a binary blob together with their dtype, the bin count and the
    ``[first_edge, last_edge]`` range of the histogram.

    Args:
        distribution (:obj:): array-like object used to construct `aim.Distribution`.
        bin_count (:obj:`int`, optional): Optional distribution bin count.
            64 by default, must be within ``[1, 512]``.

    Raises:
        TypeError: if `bin_count` is not an integer, or if `np.histogram`
            raises `TypeError` for the given `distribution`.
        ValueError: if `bin_count` is outside of ``[1, 512]``.
    """

    AIM_NAME = 'aim.distribution'

    def __init__(self, distribution, bin_count=64):
        super().__init__()

        if not isinstance(bin_count, int):
            raise TypeError('`bin_count` must be an integer.')
        # NOTE: this used to be `1 > bin_count > 512`, a chained comparison
        # that is never true, so out-of-range values were silently accepted.
        if not 1 <= bin_count <= 512:
            raise ValueError('Supported range for `bin_count` is [1, 512].')
        self.storage['bin_count'] = bin_count

        # convert to np.histogram
        try:
            np_histogram = np.histogram(distribution, bins=bin_count)
        except TypeError as err:
            # Keep the original NumPy error attached as the cause.
            raise TypeError(
                f'Cannot convert to aim.Distribution. Unsupported type {type(distribution)}.'
            ) from err
        self._from_np_histogram(np_histogram)

    @property
    def bin_count(self):
        """Stored distribution bin count

            :getter: Returns distribution bin_count.
            :type: int
        """
        return self.storage['bin_count']

    @property
    def range(self):
        """Stored distribution range

            :getter: Returns distribution range as ``[first_edge, last_edge]``.
            :type: List
        """
        return self.storage['range']

    @property
    def weights(self):
        """Stored distribution weights

            :getter: Returns distribution weights as `np.array`.
            :type: np.ndarray
        """
        # Decode the stored blob using the dtype recorded at construction
        # time and read exactly `bin_count` items from it.
        return np.frombuffer(
            self.storage['data'].load(),
            dtype=self.storage['dtype'],
            count=self.storage['bin_count'],
        )

    @property
    def ranges(self):
        """Stored distribution ranges

            :getter: Returns `bin_count` evenly spaced points spanning the
                stored range (both ends included) as `np.array`.
            :type: np.ndarray
        """
        assert (len(self.range) == 2)
        # NOTE(review): `np.histogram` yields `bin_count + 1` bin edges, while
        # this produces `bin_count` points. Left unchanged to keep the
        # returned shape stable for existing callers; confirm the intent.
        return np.linspace(self.range[0], self.range[1], num=self.bin_count)

    def json(self):
        """Dump distribution metadata to a dict"""
        return {
            'bin_count': self.bin_count,
            'range': self.range,
        }

    def _from_np_histogram(self, np_histogram: Tuple[np.ndarray, np.ndarray]):
        """Populate the storage from a ``(counts, bin_edges)`` pair.

        Stores the counts as raw bytes, their dtype as a string, and the
        first and last bin edges as the distribution range.
        """
        assert isinstance(np_histogram[0], np.ndarray)
        assert isinstance(np_histogram[1], np.ndarray)
        self.storage['data'] = BLOB(data=np_histogram[0].tobytes())
        self.storage['dtype'] = str(np_histogram[0].dtype)
        # `.item()` converts the NumPy scalars to native Python numbers.
        self.storage['range'] = [np_histogram[1][0].item(), np_histogram[1][-1].item()]

    def to_np_histogram(self):
        """Return the distribution as a ``(weights, ranges)`` tuple.

        The second element comes from the `ranges` property.
        """
        return self.weights, self.ranges