Source code for aim.sdk.objects.distribution

import numpy as np

from aim.storage.object import CustomObject
from aim.storage.types import BLOB


[docs]@CustomObject.alias('aim.distribution') class Distribution(CustomObject): """Distribution object used to store distribution objects in Aim repository. Args: samples (:obj:): Array-like object of data sampled from a distribution. bin_count (:obj:`int`, optional): Optional distribution bin count for binning `samples`. 64 by default, max 512. hist (:obj:, optional): Array-like object representing bin frequency counts. Max 512 bins allowed. `samples` must not be specified. bin_range (:obj:`tuple`, optional): Tuple of (start, end) bin range. """ AIM_NAME = 'aim.distribution' def __init__(self, samples=None, bin_count=64, *, hist=None, bin_range=None): super().__init__() if samples is not None: if hist is not None: raise ValueError("hist should not be specified if samples is specified.") hist, bin_edges = np.histogram(samples, bins=bin_count) elif hist is not None: if bin_range is None: raise ValueError("Please specify bin_range of (start, end) bins.") bin_count = len(hist) bin_edges = np.linspace(bin_range[0], bin_range[-1], num=bin_count + 1) else: raise ValueError("Please specify either samples or hist.") hist = np.asanyarray(hist) bin_edges = np.asanyarray(bin_edges) self._from_np_histogram(hist, bin_edges)
[docs] @classmethod def from_histogram(cls, hist, bin_range): """Create Distribution object from histogram. Args: hist (:obj:): Array-like object representing bin frequency counts. Max 512 bins allowed. bin_range (:obj:`tuple`, optional): Tuple of (start, end) bin range. """ return cls(hist=hist, bin_range=bin_range)
[docs] @classmethod def from_samples(cls, samples, bin_count=64): """Create Distribution object from data samples. Args: samples (:obj:): Array-like object of data sampled from a distribution. bin_count (:obj:`int`, optional): Optional distribution bin count for binning `samples`. 64 by default, max 512. """ return cls(samples=samples, bin_count=bin_count)
@property def bin_count(self): """Stored distribution bin count :getter: Returns distribution bin_count. :type: string """ return self.storage['bin_count'] @property def range(self): """Stored distribution range :getter: Returns distribution range. :type: List """ return self.storage['range'] @property def weights(self): """Stored distribution weights :getter: Returns distribution weights as `np.array`. :type: np.ndarray """ return np.frombuffer(self.storage['data'].load(), dtype=self.storage['dtype'], count=self.storage['bin_count']) @property def ranges(self): """Stored distribution ranges :getter: Returns distribution ranges as `np.array`. :type: np.ndarray """ assert (len(self.range) == 2) return np.linspace(self.range[0], self.range[1], num=self.bin_count + 1)
[docs] def json(self): """Dump distribution metadata to a dict""" return { 'bin_count': self.bin_count, 'range': self.range, }
def _from_np_histogram(self, hist: np.ndarray, bin_edges: np.ndarray): bin_count = len(hist) if 1 > bin_count > 512: raise ValueError("Supported range for `bin_count` is [1, 512].") self.storage['data'] = BLOB(data=hist.tobytes()) self.storage['dtype'] = str(hist.dtype) self.storage['range'] = [bin_edges[0].item(), bin_edges[-1].item()] self.storage['bin_count'] = bin_count
[docs] def to_np_histogram(self): """Return `np.histogram` compatible format of the distribution""" return self.weights, self.ranges