from typing import Generic, Union, Tuple, TypeVar, Dict, Iterator, Any, List
from itertools import islice
import numpy as np
import logging
from aim.sdk.tracker import STEP_HASH_FUNCTIONS
from aim.storage.treeview import TreeView
from aim.storage.arrayview import ArrayView
from aim.storage.context import Context
from aim.storage.hashing import hash_auto
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from aim.sdk.run import Run
# Module-level logger (used for debug messages while reading sequence data).
logger = logging.getLogger(__name__)
# Type parameter for the generic `Sequence(Generic[T])` class.
T = TypeVar('T')
class SequenceData:
    """Column-oriented access to the records of a single tracked sequence.

    Wraps a series tree plus a list of ``(column, dtype)`` pairs, e.g.
    ``('val', None)`` or ``('epoch', 'float64')``. Each column is exposed as an
    ``ArrayView`` in :attr:`arrays`, in the same order as ``columns``.
    Subclasses implement the storage-version specific parts (:meth:`items_list`,
    :meth:`view`, :meth:`range`, :meth:`sample`). The iteration helpers here are
    built on top of :meth:`items_list`.
    """

    def __init__(self, series_tree, version: int, columns: List[Tuple[str, str]]):
        if not columns:
            raise ValueError('Cannot create SequenceData. Please specify at least one column.')
        self.series_tree = series_tree
        self.version = version
        self.columns = columns
        self.arrays: Tuple[ArrayView, ...] = tuple(self._get_array(col, dtype) for col, dtype in columns)
        self._dtype_map = {col: dtype for col, dtype in columns}
        # The step -> key hashing scheme depends on the storage version.
        self._step_hash_fn = STEP_HASH_FUNCTIONS[self.version]

    def step_hash(self, step):
        """Return the storage key for ``step`` using this version's hash function."""
        return self._step_hash_fn(step)

    def _checked_columns(self, columns: Union[str, List[str]]) -> List[Tuple[str, str]]:
        """Resolve column names to ``(name, dtype)`` pairs.

        Accepts a single name or a list of names. Names not present in this
        sequence's columns are silently dropped.
        """
        if isinstance(columns, str):
            columns = [columns]
        return [(col, self._dtype_map[col]) for col in columns if col in self._dtype_map]

    def _get_array(self, column: str, dtype: str = None) -> 'ArrayView':
        """Return the series-tree array for ``column`` with the given dtype."""
        return self.series_tree.array(column, dtype=dtype)

    def view(self, columns: Union[str, List[str]]):
        """Return a copy restricted to ``columns`` (implemented by subclasses)."""
        raise NotImplementedError

    def range(self, start, stop) -> 'SequenceData':
        """Return a copy restricted to a step range (implemented by subclasses)."""
        raise NotImplementedError

    def sample(self, k) -> 'SequenceData':
        """Return a copy sampling ``k`` records (implemented by subclasses)."""
        raise NotImplementedError

    def __iter__(self) -> Iterator[Tuple[int, Any]]:
        yield from self.items()

    def items(self) -> Iterator[Tuple[int, Any]]:
        """Yield ``(step, (col0_value, col1_value, ...))`` for every record."""
        steps, vals_list = self.items_list()
        yield from zip(steps, zip(*vals_list))

    def values(self) -> Iterator[Any]:
        """Yield ``(col0_value, col1_value, ...)`` for every record.

        Transposes the per-column lists into per-record tuples, consistent with
        :meth:`items`. Without the ``*`` this would yield one 1-tuple per column.
        """
        yield from zip(*self.values_list())

    def indices(self) -> Iterator[int]:
        """Yield the step of every record."""
        yield from self.indices_list()

    def items_list(self) -> Tuple[List[int], List[Any]]:
        """Return ``(steps, columns)`` where ``columns`` has one list per column."""
        raise NotImplementedError

    def values_list(self) -> List[Any]:
        """Return the per-column value lists (default: via :meth:`items_list`)."""
        return self.items_list()[1]

    def indices_list(self) -> List[int]:
        """Return the list of steps (default: via :meth:`items_list`)."""
        return self.items_list()[0]

    def numpy(self) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Return steps as an ``intp`` array and each column as an array of its dtype."""
        steps, vals_list = self.items_list()
        numpy_list = [
            np.array(vals, dtype=self.arrays[col_idx].dtype)
            for col_idx, vals in enumerate(vals_list)
        ]
        return np.array(steps, np.intp), numpy_list
class SequenceV1Data(SequenceData):
    """Sequence data stored in the version-1 series layout.

    Records are read by walking the first column's ``items()``, which yields
    ``(step, value)`` pairs. The remaining columns' ``values()`` iterators are
    advanced in lockstep.

    Optional filters:

    - ``step_range``: keeps steps in ``[start, stop)``.
    - ``n_items``: when positive, thins the result to every
      ``len // n_items``-th record plus the last one.
    """

    def __init__(
        self,
        series_tree, *,
        columns: List[Tuple[str, str]],
        step_range: Tuple[int, int] = None,
        n_items: int = -1
    ):
        super().__init__(series_tree, version=1, columns=columns)
        self.step_range = step_range
        self.n_items = n_items

    def view(self, columns: List[str]) -> 'SequenceData':
        return SequenceV1Data(
            self.series_tree, columns=self._checked_columns(columns), step_range=self.step_range, n_items=self.n_items)

    def range(self, start, stop) -> 'SequenceData':
        return SequenceV1Data(self.series_tree, columns=self.columns, step_range=(start, stop), n_items=self.n_items)

    def sample(self, k) -> 'SequenceData':
        return SequenceV1Data(self.series_tree, columns=self.columns, step_range=self.step_range, n_items=k)

    def items_list(self) -> Tuple[List[int], List[Any]]:
        """Read the selected records as ``(steps, columns)``.

        ``columns`` holds one list per column, aligned with ``steps``. An empty
        or invalid ``step_range`` (negative start, or ``stop <= start``) yields
        empty lists rather than ``None``, so callers can always unpack the result.
        """
        iters = self._get_iters()
        steps: List[int] = []
        columns: List[List[Any]] = [[] for _ in iters]

        bounded = self.step_range is not None
        start, stop = self.step_range if bounded else (None, None)
        if bounded and (start < 0 or stop <= start):
            return steps, columns

        step_iter, value_iters = iters[0], iters[1:]
        for idx, val in step_iter:
            if bounded:
                if idx < start:
                    # Skip this record but keep the other columns aligned with it.
                    for it in value_iters:
                        next(it)
                    continue
                if idx >= stop:
                    # Stops at the first step past the range; assumes ascending steps.
                    break
            steps.append(idx)
            columns[0].append(val)
            for column, it in zip(columns[1:], value_iters):
                column.append(next(it))
        return self._downsample(steps, columns)

    def _downsample(self, steps: List[int], columns: List[List[Any]]) -> Tuple[List[int], List[Any]]:
        """Keep every ``len(steps) // n_items``-th record plus the final one.

        This is a no-op when ``n_items <= 0`` or when the stride works out to 1.
        """
        length = len(steps)
        stride = (length // self.n_items or 1) if self.n_items > 0 else 1
        if stride == 1:
            return steps, columns
        picked = slice(0, length, stride)
        # Append the last record when the stride would otherwise skip it.
        keep_last = (length - 1) % stride != 0

        def thin(seq: List[Any]) -> List[Any]:
            sub = seq[picked]
            return sub + [seq[-1]] if keep_last else sub

        return thin(steps), [thin(column) for column in columns]

    def _get_iters(self) -> List[Iterator[Any]]:
        """Return ``[first_array.items(), other_array.values(), ...]``."""
        iters = [self.arrays[0].items()]
        iters.extend(arr.values() for arr in self.arrays[1:])
        return iters
class SequenceV2Data(SequenceData):
    """Sequence data stored in the version-2 (reservoir-sampled) layout.

    Steps live in their own ``'step'`` array next to the value columns.
    :meth:`numpy` sorts all columns by step rather than relying on the stored
    order. Range selection is not supported for this layout.
    """

    def __init__(
        self,
        meta_tree, series_tree, *,
        columns: List[Tuple[str, str]],
        n_items: int = -1
    ):
        super().__init__(series_tree, version=2, columns=columns)
        # `SequenceV2Data` has access to both metadata and series data
        # trees that are not necessarily based on the same physical storage.
        # Therefore, we don't have a consistency guarantee between the two.
        # The implemented methods should tolerate this.
        self.meta_tree = meta_tree
        self.n_items = n_items
        self.steps: ArrayView = self._get_array('step')

    def view(self, columns: List[str]) -> 'SequenceData':
        return SequenceV2Data(
            self.meta_tree, self.series_tree, columns=self._checked_columns(columns), n_items=self.n_items)

    def range(self, start, stop) -> 'SequenceData':
        raise ValueError('Range selection cannot be applied to data stored with reservoir sampling.')

    def sample(self, k) -> 'SequenceData':
        return SequenceV2Data(self.meta_tree, self.series_tree, columns=self.columns, n_items=k)

    def items_list(self) -> Tuple[List[int], List[Any]]:
        steps, values = self.numpy()
        return steps.tolist(), [v.tolist() for v in values]

    def numpy(self) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Return ``(steps, columns)`` as numpy arrays, sorted by step.

        With ``n_items == -1`` every stored record is read. Otherwise the first
        ``n_items`` stored records are read. The last one is then replaced by the
        meta tree's ``last_step`` record when that record is present in every
        column.

        An empty result (e.g. ``n_items == 0``, or a series tree lagging behind
        the meta tree) is returned as-is instead of raising ``IndexError``.
        """
        if self.n_items == -1:  # select all
            limit = None
            last_step = None
        else:
            limit = self.n_items
            # Tolerate a meta tree that has no `last_step` recorded yet.
            last_step = self.meta_tree.get('last_step', None)
        steps = np.fromiter(islice(self.steps.values(), limit), np.intp)
        columns = [np.fromiter(islice(arr.values(), limit), arr.dtype) for arr in self.arrays]

        # sort all columns by step
        sort_indices = steps.argsort()
        steps = steps[sort_indices]
        columns = [arr[sort_indices] for arr in columns]

        if last_step is not None and steps.size > 0 and last_step != steps[-1]:
            self._replace_last_record(last_step, steps, columns)
        return steps, columns

    def _replace_last_record(self, last_step: int, steps: np.ndarray, columns: List[np.ndarray]) -> None:
        """Overwrite the final record of ``steps``/``columns`` in place with ``last_step``'s data.

        The `last_step` is provided by the meta tree, which may be out of sync
        with the series tree. If any column lacks that record, nothing changes
        and the series tree's own last record is kept.
        """
        step_hash = self.step_hash(last_step)
        try:
            last_values = [arr[step_hash] for arr in self.arrays]
        except KeyError:
            logger.debug('Last step not found in reservoir.')
            return
        # Only if all the last steps are found, we use them.
        for column, value in zip(columns, last_values):
            column[-1] = value
        steps[-1] = last_step
class Sequence(Generic[T]):
    """A single series of tracked values belonging to one run and context.

    Any tracked object series can be loaded as a plain ``Sequence`` regardless
    of the object's type; subclasses may add type-specific functionality.
    Tracked values, epochs and timestamps are exposed through the
    :obj:`aim.storage.arrayview.ArrayView` interface; column-wise and sampled
    access goes through :attr:`data`.
    """

    # Subclass lookup keyed by `sequence_name()`; filled in by `__init_subclass__`.
    registry: Dict[str, 'Sequence'] = dict()
    collections_allowed = False

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        typename = cls.sequence_name()
        if typename is None:
            # Intermediate helper classes define no name and stay unregistered.
            return
        cls.registry[typename] = cls

    def __init__(
        self,
        name: str,
        context: Context,  # TODO ?dict
        run: 'Run'
    ):
        self.name = name
        self.context = context
        self.run = run

        # Metadata subtree of this trace: ('traces', <context idx>, <name>).
        self._meta_tree: TreeView = run.meta_run_tree.subtree(('traces', context.idx, name))
        # (column, dtype) pairs handed to SequenceData; `val` has no fixed dtype.
        self._columns = [('val', None), ('epoch', 'float64'), ('time', 'float64')]

        # Lazily computed caches; see `__hash__`, `version`, `series_tree`, `data`.
        self._hash: int = None
        self._version: int = None
        self._series_tree: TreeView = None
        self._data: SequenceData = None  # use data property

    def __repr__(self) -> str:
        return f'<Sequence#{hash(self)} name=`{self.name}` context=`{self.context}` run=`{self.run}`>'

    @classmethod
    def allowed_dtypes(cls) -> Union[str, Tuple[str, ...]]:
        """Return the object types this sequence class accepts.

        Numeric sequences (a.k.a. Metric) narrow this to float and integer
        numbers. The base ``Sequence`` accepts any value and signals that by
        returning ``'*'``.
        """
        return '*'

    @classmethod
    def sequence_name(cls) -> str:
        """Return the name this sequence class is registered under.

        The base implementation returns ``None``, so classes that don't
        override it are not added to :attr:`registry`.
        """
        return None

    def _calc_hash(self):
        identity = (self.name, hash(self.context), hash(self.run))
        return hash_auto(identity)

    def __hash__(self) -> int:
        if self._hash is not None:
            return self._hash
        self._hash = self._calc_hash()
        return self._hash

    @property
    def series_tree(self) -> TreeView:
        """Series subtree ``(context idx, name)`` of the run tree for this version."""
        if self._series_tree is not None:
            return self._series_tree
        run_series = self.run.series_run_trees[self.version]
        self._series_tree = run_series.subtree((self.context.idx, self.name))
        return self._series_tree

    @property
    def data(self) -> SequenceData:
        """Column-data accessor matching the stored version, created on first use."""
        if self._data is not None:
            return self._data
        if self.version == 1:
            loaded = SequenceV1Data(self.series_tree, columns=self._columns)
        else:
            loaded = SequenceV2Data(self._meta_tree, self.series_tree, columns=self._columns)
        self._data = loaded
        return loaded

    @property
    def version(self):
        """Storage version from the trace metadata (``1`` when not recorded)."""
        if self._version is not None:
            return self._version
        self._version = self._meta_tree.get('version', 1)
        return self._version

    @property
    def values(self) -> ArrayView:
        """Tracked values as an :obj:`ArrayView`.

        :getter: Returns the ``val`` column view.
        """
        return self.data._get_array('val')

    @property
    def epochs(self) -> ArrayView:
        """Tracked epochs as an :obj:`ArrayView`.

        :getter: Returns the ``epoch`` column view (``float64``).
        """
        return self.data._get_array('epoch', dtype='float64')

    @property
    def timestamps(self) -> ArrayView:
        """Tracking timestamps as an :obj:`ArrayView`.

        :getter: Returns the ``time`` column view (``float64``).
        """
        return self.data._get_array('time', dtype='float64')

    def __bool__(self) -> bool:
        # The values view may raise ValueError; treat that as "empty".
        try:
            tracked = self.values
            return bool(tracked)
        except ValueError:
            return False

    def __len__(self) -> int:
        return len(self.values)

    def preload(self):
        """Ask the series tree to preload its data."""
        self.series_tree.preload()
class MediaSequenceBase(Sequence):
    """Common base for media sequence types.

    It does not override ``sequence_name``, so it is not registered itself.
    """

    collections_allowed = True

    def first_step(self):
        """Return the first tracked step, as recorded in the trace metadata.

        Needed for ranged and sliced data fetching.
        """
        meta = self._meta_tree
        return meta['first_step']

    def last_step(self):
        """Return the last tracked step, as recorded in the trace metadata.

        Needed for ranged and sliced data fetching.
        """
        meta = self._meta_tree
        return meta['last_step']

    def record_length(self):
        """Return the longest tracked record-list length.

        Returns ``None`` when no ``record_max_length`` is stored, which is the
        case when Text objects are tracked. Needed for ranged and sliced data
        fetching.
        """
        meta = self._meta_tree
        return meta.get('record_max_length', None)