Source code for aim._sdk.sequence

import logging

from typing import TypeVar, Generic, Any, Optional, Dict, Union, Tuple, List, Set, Iterator, Callable

from aim._sdk import type_utils
from aim._sdk.utils import utc_timestamp
from aim._sdk.interfaces.sequence import Sequence as ABCSequence
from aim._sdk.query_utils import SequenceQueryProxy, ContainerQueryProxy
from aim._sdk.constants import KeyNames

from aim._sdk.context import Context, cached_context
from aim._sdk.query import RestrictedPythonQuery

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from aim._core.storage.treeview import TreeView
    from aim._sdk.container import Container
    from aim._sdk.collections import SequenceCollection
    from aim._sdk.storage_engine import StorageEngine
    from aim._sdk.repo import Repo

_ContextInfo = Union[Dict, Context, int]

logger = logging.getLogger(__name__)


class _SequenceInfo:
    def __init__(self, info_tree: 'TreeView'):
        self._meta_tree = info_tree

        self._initialized = False
        self.next_step = None
        self.version = None
        self.first_step = None
        self.last_step = None
        self.creation_time = None
        self.axis_names: Set[str] = set()
        self.dtype = None
        self.empty = None

    def preload(self):
        if self._initialized:
            return

        info = self._meta_tree.subtree(KeyNames.INFO_PREFIX)

        try:
            info.preload()
            self.version = info['version']
            self.first_step = info['first_step']
            self.last_step = info['last_step']
            self.creation_time = info['creation_time']
            self.axis_names = set(info['axis'])
            self.dtype = info[KeyNames.VALUE_TYPE]
            self.stype = info[KeyNames.SEQUENCE_TYPE]
            self.empty = False
            self.next_step = self.last_step + 1
        except KeyError:
            self.empty = True
            self.next_step = 0
        finally:
            self._initialized = True


ItemType = TypeVar('ItemType')


[docs]@type_utils.query_alias('sequence', 's') @type_utils.auto_registry class Sequence(Generic[ItemType], ABCSequence): version = '1.0.0' def __init__(self, container: 'Container', *, name: str, context: _ContextInfo): """ Initializes a Sequence with the given container, name, and context. Args: container (Container): The enclosing container for the sequence. name (str): The name of the sequence. context (_ContextInfo): The context for the sequence. """ self.storage: StorageEngine = container.storage self._container: 'Container' = container self._container_hash: str = container.hash self._meta_tree = container._meta_tree self._container_tree = container._tree self._name = name self._ctx_idx: int = None self._context: Context = None self._init_context(context) self._data_loader: Callable[[], 'TreeView'] = container._data_loader self.__data = None self.__storage_init__() def __storage_init__(self): self._tree = self._container_tree.subtree((KeyNames.SEQUENCES, self._ctx_idx, self._name)) self._values = None self._info = _SequenceInfo(self._tree) self._allowed_value_types = None
[docs] @classmethod def from_storage(cls, storage, meta_tree: 'TreeView', *, hash_: str, name: str, context: _ContextInfo): """ Creates a Sequence instance from the provided storage and metadata. Args: storage: The storage object where the sequence is stored. meta_tree (TreeView): The tree view of metadata. hash_ (str): The hash of the enclosing container. name (str): The name of the sequence. context (_ContextInfo): The context for the sequence. Returns: Sequence: A Sequence instance. """ self = cls.__new__(cls) self.storage = storage self._container = None self._container_hash = hash_ self._meta_tree = meta_tree self._container_tree = meta_tree.subtree('chunks').subtree(hash_) self._name = name self._ctx_idx = None self._context = None self._init_context(context) self._data_loader = lambda: storage.tree(hash_, 'seqs', read_only=True).subtree('chunks').subtree(hash_) self.__data: TreeView = None self.__storage_init__() return self
[docs] @classmethod def filter(cls, expr: str = '', repo: 'Repo' = None) -> 'SequenceCollection': """ Filters sequences based on a given expression. Args: expr (str, optional): The query expression for filtering. Defaults to an empty string. repo (Repo, optional): The repository to filter from. If not provided, uses the active repo. Returns: SequenceCollection: A collection of sequences satisfying the filter. """ if repo is None: from aim._sdk.repo import Repo repo = Repo.active_repo() return repo.sequences(query_=expr, type_=cls)
[docs] @classmethod def find(cls, hash_: str, name: str, context: Dict) -> Optional['Sequence']: """ Finds a Sequence based on container hash, name, and context. Args: hash_ (str): The hash identifier of the enclosing container. name (str): The name of the sequence. context (Dict): The context for the sequence. Returns: Sequence or None: Returns a Sequence if found, otherwise None. """ from aim._sdk.repo import Repo repo = Repo.active_repo() storage = repo.storage_engine meta_tree = repo._meta_tree seq = cls.from_storage(storage, meta_tree, hash_=hash_, name=name, context=context) if seq.is_empty: return None return seq
def _init_context(self, context: _ContextInfo): if isinstance(context, int): self._ctx_idx = context self._context = None elif isinstance(context, dict): self._context = Context(context) self._ctx_idx = self._context.idx elif isinstance(context, Context): self._context = context self._ctx_idx = self._context.idx @property def name(self): """str: Gets the name of the sequence.""" return self._name @property def context(self) -> Dict: """Dict: Gets the context dictionary of the sequence.""" if self._context is None: self._context = self._context_from_idx(ctx_idx=self._ctx_idx) return self._context.to_dict() @cached_context def _context_from_idx(self, ctx_idx) -> Context: return Context(self._meta_tree[KeyNames.CONTEXTS, ctx_idx]) @property def type(self) -> str: """str: Gets the type of the tracked items.""" self._info.preload() return self._info.dtype @property def allowed_value_types(self) -> Tuple[str]: """Tuple[str]: Gets the allowed value types for the sequence.""" if self._allowed_value_types is None: sequence_class = self.__sequence_class__ self._allowed_value_types = type_utils.get_sequence_value_types(sequence_class) return self._allowed_value_types @property def is_empty(self) -> bool: """bool: Returns True if the sequence is empty, otherwise False.""" self._info.preload() return self._info.empty @property def start(self) -> int: """int: Gets the first step of the sequence.""" self._info.preload() return self._info.first_step @property def stop(self) -> int: """int: Gets the last step of the sequence.""" self._info.preload() return self._info.last_step @property def next_step(self) -> int: """int: Gets the next available step for the sequence.""" self._info.preload() return self._info.next_step
[docs] def match(self, expr) -> bool: """ Checks if the sequence matches a given expression. Args: expr: The expression to check. Returns: bool: True if the sequence matches the expression, otherwise False. """ query = RestrictedPythonQuery(expr) query_cache = dict return self._check(query, query_cache)
@property def axis_names(self) -> Tuple[str]: """Tuple[str]: Gets the names of the axis for the sequence.""" self._info.preload() return tuple(self._info.axis_names)
[docs] def axis(self, name: str) -> Iterator[Any]: """ Gets the axis values for a given axis name. Args: name (str): The name of the axis. Returns: Iterator[Any]: An iterator over the axis values. """ return map(lambda x: x.get(name, None), self._data.reservoir().values())
[docs] def items(self) -> Iterator[Tuple[int, Any]]: """Iterator[Tuple[int, Any]]: Returns an iterator over the sequence items (steps and their values).""" return self._data.reservoir().items()
[docs] def values(self) -> Iterator[Any]: """Iterator[Any]: Returns an iterator over the sequence values.""" for _, v in self.items(): yield v['val']
[docs] def sample(self, k: int) -> List[Any]: """ Samples k items from the sequence. Args: k (int): The number of items to sample. Returns: List[Any]: A list of sampled items. """ return self[:].sample(k)
[docs] def track(self, value: Any, *, step: Optional[int] = None, **axis): """ Tracks a given value at a specific step in the sequence. Args: value (Any): The value to be tracked. step (Optional[int], optional): The step at which the value should be tracked. If not provided, it will be determined automatically. **axis: Additional axis values. Raises: ValueError: If the provided value type is not supported by the sequence. """ value_type = type_utils.get_object_typename(value) if not type_utils.is_allowed_type(value_type, self.allowed_value_types): raise ValueError(f'Cannot track value \'{value}\' of type {type(value)}. Type is not supported.') self._info.preload() if step is None: step = self._info.next_step axis_names = set(axis.keys()) with self.storage.write_batch(self._container_hash): if self._info.empty: sequence_type = self.get_full_typename() self._tree[KeyNames.INFO_PREFIX, 'creation_time'] = utc_timestamp() self._tree[KeyNames.INFO_PREFIX, 'version'] = self.version self._tree[KeyNames.INFO_PREFIX, KeyNames.OBJECT_CATEGORY] = self.object_category self._tree[KeyNames.INFO_PREFIX, 'first_step'] = self._info.first_step = step self._tree[KeyNames.INFO_PREFIX, 'last_step'] = self._info.last_step = step self._tree[KeyNames.INFO_PREFIX, 'axis'] = tuple(axis_names) self._tree[KeyNames.INFO_PREFIX, KeyNames.VALUE_TYPE] = self._info.dtype = value_type self._tree[KeyNames.INFO_PREFIX, KeyNames.SEQUENCE_TYPE] = sequence_type self._meta_tree[KeyNames.CONTEXTS, self._ctx_idx] = self._context.to_dict() self._container_tree[KeyNames.CONTEXTS, self._ctx_idx] = self._context.to_dict() for typename in sequence_type.split('->'): self._meta_tree[KeyNames.SEQUENCES, typename, self._ctx_idx, self.name] = 1 self._tree['first_value'] = value self._tree['last_value'] = value self._tree['axis_last_values'] = axis self._info.axis_names = axis_names self._info.empty = False if step > self._info.last_step: self._tree[KeyNames.INFO_PREFIX, 'last_step'] = self._info.last_step = step self._tree['last_value'] = value self._tree['axis_last_values'] = axis if step < self._info.first_step: self._tree[KeyNames.INFO_PREFIX, 'first_step'] = self._info.first_step = step self._tree['first_value'] = value if not type_utils.is_subtype(value_type, self._info.dtype): dtype = type_utils.get_common_typename((value_type, self._info.dtype)) self._tree[KeyNames.INFO_PREFIX, KeyNames.VALUE_TYPE] = self._info.dtype = dtype if not axis_names.issubset(self._info.axis_names): self._info.axis_names.update(axis_names) self._tree[KeyNames.INFO_PREFIX, 'axis'] = tuple(self._info.axis_names) if self._values is None: self._values = self._data.reservoir() val = {k: v for k, v in axis.items()} val['val'] = value self._values[step] = val self._info.next_step = self._info.last_step + 1
[docs] def get_logged_typename(self) -> str: """ Retrieves the type name of the logged sequence. Returns: str: The type name of the logged sequence. """ if self.is_empty: return self.get_full_typename() return self._tree[KeyNames.INFO_PREFIX, KeyNames.SEQUENCE_TYPE]
[docs] def __iter__(self) -> Iterator[Tuple[int, Tuple[Any, ...]]]: """ Returns an iterator over the steps and their associated values in the sequence. Returns: Iterator[Tuple[int, Tuple[Any, ...]]]: An iterator over the sequence items. """ data_iterator = zip(self.items(), zip(map(self.axis, self.axis_names))) for (step, value), axis_values in data_iterator: yield step, (value,) + axis_values
[docs] def __getitem__(self, item: Union[slice, str, Tuple[str]]) -> 'SequenceView': """ Retrieves items from the sequence by index, slice, or column name(s). Args: item (Union[slice, str, Tuple[str]]): The index, slice, or column name(s) to retrieve. Returns: SequenceView: A view on the selected sequence data. """ if isinstance(item, int): return self._data.reservoir()[item] if isinstance(item, str): item = (item,) if isinstance(item, slice): columns = self.axis_names + ('val',) return SequenceView(self, columns=columns, start=item.start, stop=item.stop) elif isinstance(item, tuple): return SequenceView(self, columns=item)
@property def _data(self) -> 'TreeView': if self.__data is None: self.__data = self._data_loader().subtree((self._ctx_idx, self._name)) return self.__data @property def __sequence_class__(self): if hasattr(self, '__orig_class__'): return self.__orig_class__ return self.__class__ def _check(self, query, query_cache, *, aliases=()) -> bool: hash_ = self._container_hash proxy = SequenceQueryProxy(self.name, self._context_from_idx, self._ctx_idx, self._tree, query_cache[hash_]) c_proxy = ContainerQueryProxy(hash_, self._container_tree, query_cache[hash_]) if isinstance(aliases, str): aliases = (aliases,) alias_names = self.default_aliases.union(aliases) if self._container is not None: container_alias_names = self._container.default_aliases else: from aim._sdk.container import Container container_alias_names = Container.default_aliases query_params = {p: proxy for p in alias_names} query_params.update({cp: c_proxy for cp in container_alias_names}) return query.check(**query_params) def delete(self): del self._data_loader()[(self._ctx_idx, self._name)] del self._container_tree[(KeyNames.SEQUENCES, self._ctx_idx, self._name)] self._info.empty = True self._info.next_step = 0 def __repr__(self) -> str: return f'<{self.get_typename()} #{hash(self)} name={self.name} context={self._ctx_idx}>'
[docs]class SequenceView(object): def __init__(self, sequence: Sequence, *, columns: Tuple[str], start: int = None, stop: int = None): self._start: int = start if start is not None else sequence._info.first_step self._stop: int = stop if stop is not None else sequence._info.next_step self._columns: Set[str] = set(columns) self._sequence = sequence @property def start(self) -> int: return self._start @property def stop(self) -> int: return self._stop @property def columns(self) -> Tuple[str]: return tuple(self._columns) def __getitem__(self, item: Union[slice, str, Tuple[str]]) -> 'SequenceView': if isinstance(item, int): return self._sequence._data.reservoir()[item] if isinstance(item, str): item = (item,) if isinstance(item, slice): if self.start is not None and item.start is not None: start = max(self.start, item.start) else: start = self.start if item.start is None else item.start if self.stop is not None and item.stop is not None: stop = min(self.stop, item.stop) else: stop = self.stop if item.stop is None else item.stop return SequenceView(self._sequence, start=start, stop=stop, columns=self.columns) elif isinstance(item, tuple): columns = tuple(self._columns.intersection(item)) return SequenceView(self._sequence, start=self.start, stop=self.stop, columns=columns) def sample(self, k: Optional[int] = None) -> List[Any]: def get_columns(item): return [item[0], {k: v for k, v in item[1].items() if k in self._columns}] if k is None: k = self.stop - self.start samples = self._sequence._data.reservoir().sample(k, begin=self.start, end=self.stop) return sorted(map(get_columns, samples), key=lambda x: x[0])