Source code for aim.sdk.sequence_collection

import logging
from abc import abstractmethod
from typing import Iterator
from typing import TYPE_CHECKING
from tqdm import tqdm

from aim.sdk.sequence import Sequence
from aim.sdk.types import QueryReportMode
from aim.sdk.query_utils import RunView, SequenceView
from aim.storage.query import RestrictedPythonQuery


if TYPE_CHECKING:
    from aim.sdk.run import Run
    from aim.sdk.repo import Repo
    from pandas import DataFrame

# Module-level logger; SingleRunSequenceCollection.iter_runs() uses it to warn
# that the collection is already bound to a single Run.
logger = logging.getLogger(__name__)


class SequenceCollection:
    """Abstract interface for collection of tracked series/sequences.

    Typically represents sequences of a same run or sequences matching given query expression.

    Subclasses set ``self._item`` to ``'run'`` or ``'sequence'``; the value selects
    how :meth:`dataframe` gathers data.
    """

    # TODO [AT]: move to a separate mixin class
    def dataframe(
        self,
        only_last: bool = False,
        include_run=True,
        include_name=True,
        include_context=True,
        include_props=True,
        include_params=True,
    ) -> 'DataFrame':
        """Collect the collection's data into a single pandas DataFrame.

        For ``_item == 'run'`` one frame per run is built via
        ``collection.run.dataframe(include_props=..., include_params=...)`` for each
        collection produced by :meth:`iter_runs`. For ``_item == 'sequence'`` one
        frame per sequence is built via ``sequence.dataframe(include_run=...,
        include_name=..., include_context=..., only_last=...)`` for each sequence
        produced by iterating the collection. The frames are joined with
        ``pandas.concat``.

        Args:
            only_last: Forwarded to each sequence's ``dataframe()`` (sequence collections only).
            include_run: Forwarded to each sequence's ``dataframe()`` (sequence collections only).
            include_name: Forwarded to each sequence's ``dataframe()`` (sequence collections only).
            include_context: Forwarded to each sequence's ``dataframe()`` (sequence collections only).
            include_props: Forwarded to each run's ``dataframe()`` (run collections only).
            include_params: Forwarded to each run's ``dataframe()`` (run collections only).

        Returns:
            The concatenated DataFrame, or ``None`` when no frames were collected
            (including when ``_item`` is neither ``'run'`` nor ``'sequence'``).
        """
        # TODO [GA]: Separate runs and sequences dataframes collection
        dfs = []
        if self._item == 'run':
            for item in self.iter_runs():
                # Query collections in QueryReportMode.PROGRESS_TUPLE mode yield
                # ``(collection, (counter, total))`` tuples, with ``collection``
                # set to None for runs that did not match the query.
                run_collection = item[0] if isinstance(item, tuple) else item
                if run_collection is None:
                    continue
                dfs.append(
                    run_collection.run.dataframe(include_props=include_props, include_params=include_params)
                )
        elif self._item == 'sequence':
            dfs = [
                metric.dataframe(include_run=include_run,
                                 include_name=include_name,
                                 include_context=include_context,
                                 only_last=only_last)
                for metric in self
            ]

        if not dfs:
            return None

        # Imported lazily so pandas is only required when a DataFrame is requested.
        import pandas as pd
        return pd.concat(dfs)

    def __iter__(self) -> Iterator['Sequence']:
        return self.iter()

    @abstractmethod
    def iter(self) -> Iterator['Sequence']:
        """Get Sequence iterator for collection's sequences.

        Yields:
            Next sequence object based on implementation.
        """
        ...

    @abstractmethod
    def iter_runs(self) -> Iterator['SequenceCollection']:
        """Get SequenceCollection iterator for collection's runs.

        Yields:
            Next run's SequenceCollection based on implementation.
        """
        ...
class SingleRunSequenceCollection(SequenceCollection):
    """Implementation of SequenceCollection interface for a single Run.

    Method `iter()` returns Sequence iterator which yields Sequence matching query from run's sequences.
    Method `iter_runs()` logs a warning and returns an empty iterator, since the collection is bound to
    a single Run (calling ``next()`` on it raises StopIteration).

    Args:
        run (:obj:`Run`): Run object for which sequences are queried.
        seq_cls (:obj:`type`): The collection's sequence class. Sequences not matching to
            seq_cls.allowed_dtypes will be skipped. `Sequence` by default, meaning all sequences will match.
        query (:obj:`str`, optional): Query expression. If specified, method `iter()` will return iterator
            for sequences matching the query. If not, method `iter()` will return iterator for run's
            all sequences.
        runs_proxy_cache (:obj:`dict`, optional): Passed to every `RunView` built while iterating.
        timezone_offset (:obj:`int`, optional): Passed to every `RunView` built while iterating. 0 by default.
    """

    def __init__(
        self,
        run: 'Run',
        seq_cls=Sequence,
        query: str = '',
        runs_proxy_cache: dict = None,
        timezone_offset: int = 0,
    ):
        self.run: 'Run' = run
        self.seq_cls = seq_cls
        self._item = 'sequence'
        self.query = RestrictedPythonQuery(query)
        self.runs_proxy_cache = runs_proxy_cache
        self._timezone_offset = timezone_offset

    def iter_runs(self) -> Iterator['SequenceCollection']:
        """Warn and return an empty iterator; this collection is bound to one Run."""
        logger.warning('Run is already bound to the Collection')
        # Return an empty iterator instead of raising StopIteration directly: a
        # StopIteration raised by a plain function call crashes ``for`` loops and
        # turns into RuntimeError if it escapes into a generator (PEP 479).
        # ``next()`` on the returned iterator still raises StopIteration.
        return iter(())

    def iter(self) -> Iterator[Sequence]:
        """Yield the run's sequences of an allowed dtype that match the query.

        The query is evaluated with the run exposed as ``run`` and the sequence
        exposed under ``seq_cls.sequence_name()``.
        """
        allowed_dtypes = self.seq_cls.allowed_dtypes()
        seq_var = self.seq_cls.sequence_name()
        for seq_name, ctx, run in self.run.iter_sequence_info_by_type(allowed_dtypes):
            run_view = RunView(run, self.runs_proxy_cache, self._timezone_offset)
            seq_view = SequenceView(seq_name, ctx.to_dict(), run_view)
            match = self.query.check(**{'run': run_view, seq_var: seq_view})
            if not match:
                continue
            yield self.seq_cls(seq_name, ctx, run)
class QuerySequenceCollection(SequenceCollection):
    """Implementation of SequenceCollection interface for repository's sequences matching given query.

    Method `iter()` returns Sequence iterator, which yields Sequence matching query from currently
    iterated run's sequences. Once there are no sequences left in current run, repository's next run
    is considered.
    Method `iter_runs()` returns SequenceCollection iterator for repository's runs.

    Args:
        repo (:obj:`Repo`): Aim repository object.
        seq_cls (:obj:`type`): The collection's sequence class. Sequences not matching to
            seq_cls.allowed_dtypes will be skipped. `Sequence` by default, meaning all sequences will match.
        query (:obj:`str`, optional): Query expression. If specified, method `iter()` will skip sequences
            not matching the query. If not, method `iter()` will return iterator for all sequences in
            repository (that's a lot of sequences!).
        report_mode (:obj:`QueryReportMode`, optional): PROGRESS_BAR (default) shows a tqdm bar;
            PROGRESS_TUPLE makes `iter_runs()` yield ``(collection, (runs_counter, total_runs))`` tuples.
        timezone_offset (:obj:`int`, optional): Passed to each per-run collection. 0 by default.
    """

    def __init__(
        self,
        repo: 'Repo',
        seq_cls=Sequence,
        query: str = '',
        report_mode: QueryReportMode = QueryReportMode.PROGRESS_BAR,
        timezone_offset: int = 0,
    ):
        self.repo: 'Repo' = repo
        self.seq_cls = seq_cls
        self._item = 'sequence'
        self.query = query
        self.report_mode = report_mode
        # Shared by all per-run collections created in iter_runs().
        self.runs_proxy_cache = dict()
        self._timezone_offset = timezone_offset

    def iter_runs(self) -> Iterator['SequenceCollection']:
        """Yield a SingleRunSequenceCollection for every run in the repository.

        Each per-run collection applies this collection's query to that run's
        sequences. In PROGRESS_TUPLE mode each item is
        ``(collection, (runs_counter, total_runs))``; otherwise the collection is
        yielded directly, and in PROGRESS_BAR mode a tqdm bar advances once per run.
        """
        if self.repo.structured_db:
            runs_iterator = self.repo.iter_runs_from_cache()
        else:
            runs_iterator = self.repo.iter_runs()

        runs_counter = 1
        total_runs = self.repo.total_runs_count()
        progress_bar = None
        if self.report_mode == QueryReportMode.PROGRESS_BAR:
            progress_bar = tqdm(total=total_runs)
        try:
            for run in runs_iterator:
                seq_collection = SingleRunSequenceCollection(run, self.seq_cls, self.query,
                                                             runs_proxy_cache=self.runs_proxy_cache,
                                                             timezone_offset=self._timezone_offset)
                if self.report_mode == QueryReportMode.PROGRESS_TUPLE:
                    yield seq_collection, (runs_counter, total_runs)
                else:
                    if progress_bar is not None:
                        progress_bar.update(1)
                    yield seq_collection
                runs_counter += 1
        finally:
            # Release the bar when iteration completes or the generator is closed early.
            if progress_bar is not None:
                progress_bar.close()

    def iter(self) -> Iterator[Sequence]:
        """Yield matching sequences run by run, across the whole repository."""
        if self.report_mode == QueryReportMode.PROGRESS_TUPLE:
            for run_seq, _ in self.iter_runs():
                yield from run_seq
        else:
            for run_seq in self.iter_runs():
                yield from run_seq
class QueryRunSequenceCollection(SequenceCollection):
    """Implementation of SequenceCollection interface for repository's runs matching given query.

    Method `iter()` returns Sequence iterator which yields Sequence for current run's all sequences.
    Method `iter_runs()` returns SequenceCollection iterator from repository's runs matching given query.

    Args:
        repo (:obj:`Repo`): Aim repository object.
        seq_cls (:obj:`type`): The collection's sequence class. Sequences not matching to
            seq_cls.allowed_dtypes will be skipped. `Sequence` by default, meaning all sequences will match.
        query (:obj:`str`, optional): Query expression. If specified, method `iter_runs()` will skip runs
            not matching the query. If not, method `iter_runs()` will return SequenceCollection iterator for
            all runs in repository.
        paginated (:obj:`bool`, optional): Stored on the instance; not read by this class's methods.
        offset (:obj:`str`, optional): Passed to ``repo.iter_runs_from_cache`` when the repo has a
            structured DB.
        report_mode (:obj:`QueryReportMode`, optional): PROGRESS_BAR (default) shows a tqdm bar;
            PROGRESS_TUPLE makes `iter_runs()` yield ``(collection_or_None, (runs_counter, total_runs))``
            for every run, with None for runs not matching the query.
        timezone_offset (:obj:`int`, optional): Passed to the `RunView` used for query checks and to
            each per-run collection. 0 by default.
    """

    def __init__(
        self,
        repo: 'Repo',
        seq_cls=Sequence,
        query: str = '',
        paginated: bool = False,
        offset: str = None,
        report_mode: QueryReportMode = QueryReportMode.PROGRESS_BAR,
        timezone_offset: int = 0,
    ):
        self.repo: 'Repo' = repo
        self.seq_cls = seq_cls
        self._item = 'run'
        self.paginated = paginated
        self.offset = offset
        self.query = RestrictedPythonQuery(query)
        self.report_mode = report_mode
        self._timezone_offset = timezone_offset

    def iter(self) -> Iterator[Sequence]:
        """Yield all sequences of every run that matches the query."""
        if self.report_mode == QueryReportMode.PROGRESS_TUPLE:
            for run_seq, _ in self.iter_runs():
                # Non-matching runs are reported as (None, progress) so callers can
                # track progress; there is nothing to iterate for them.
                if run_seq is not None:
                    yield from run_seq
        else:
            for run_seq in self.iter_runs():
                yield from run_seq

    def iter_runs(self) -> Iterator['SequenceCollection']:
        """Yield a SingleRunSequenceCollection for each run matching the query.

        In PROGRESS_TUPLE mode every run produces an item
        ``(collection_or_None, (runs_counter, total_runs))``; otherwise only
        matching runs are yielded, and in PROGRESS_BAR mode a tqdm bar advances
        once per examined run.
        """
        if self.repo.structured_db:
            runs_iterator = self.repo.iter_runs_from_cache(offset=self.offset)
        else:
            runs_iterator = self.repo.iter_runs()

        runs_counter = 1
        total_runs = self.repo.total_runs_count()
        progress_bar = None
        if self.report_mode == QueryReportMode.PROGRESS_BAR:
            progress_bar = tqdm(total=total_runs)
        try:
            for run in runs_iterator:
                run_view = RunView(run, timezone_offset=self._timezone_offset)
                match = self.query.check(run=run_view)
                seq_collection = SingleRunSequenceCollection(
                    run, self.seq_cls, timezone_offset=self._timezone_offset
                ) if match else None
                if self.report_mode == QueryReportMode.PROGRESS_TUPLE:
                    yield seq_collection, (runs_counter, total_runs)
                else:
                    if progress_bar is not None:
                        progress_bar.update(1)
                    if match:
                        yield seq_collection
                runs_counter += 1
        finally:
            # Release the bar when iteration completes or the generator is closed early.
            if progress_bar is not None:
                progress_bar.close()