# Source code for aim.sdk.objects.figure

import logging

from aim.sdk.num_utils import inst_has_typename
from aim.storage.object import CustomObject
from aim.storage.types import BLOB

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

@CustomObject.alias('aim.figure')
class Figure(CustomObject):
    """
    Figure object can be used for storing Plotly or Matplotlib figures into Aim repository.
    Core functionality is based on Plotly.

    Matplotlib figures are converted to Plotly figures before being stored.

    Args:
         obj (:obj:): plotly or matplotlib figure object.

    Raises:
        TypeError: if ``obj`` is neither a matplotlib nor a plotly figure.
    """

    # NOTE(review): this class was recovered from a text extraction that had
    # dropped the ``self.storage`` receivers and the ``plotly.*`` module paths
    # of the function-level imports; both are restored here -- confirm against
    # the upstream source.

    AIM_NAME = 'aim.figure'

    def __init__(self, obj):
        super().__init__()

        # Matplotlib figures go through a conversion step; plotly figures are
        # stored directly.
        if inst_has_typename(obj, ['matplotlib', 'Figure']):
            self._from_matplotlib_figure(obj)
        elif inst_has_typename(obj, ['plotly', 'Figure', 'BaseFigure']):
            self._prepare(obj)
        else:
            raise TypeError('Object is not a Plotly Figure instance')

    def _prepare(self, obj):
        """Store a plotly figure as a JSON blob together with its metadata.

        Writes the ``'source'``, ``'version'``, ``'format'`` and ``'data'``
        entries of the object's storage.

        Args:
            obj: object exposing ``to_json()`` (a plotly figure).

        Raises:
            TypeError: if ``obj`` has no ``to_json`` attribute.
        """
        try:
            from plotly.version import __version__ as plotly_version
        except ModuleNotFoundError:
            # The version is metadata only; a missing plotly is not fatal here.
            plotly_version = 'unknown'

        # Explicit check rather than ``assert`` so that the validation is not
        # stripped when Python runs with ``-O``.
        if not hasattr(obj, 'to_json'):
            raise TypeError("Figure object must provide a 'to_json' method")

        self.storage['source'] = 'plotly'
        self.storage['version'] = plotly_version
        self.storage['format'] = 'raw_json'
        self.storage['data'] = BLOB(data=obj.to_json())

    @property
    def data(self):
        """Load and return the payload stored under the ``'data'`` key."""
        return self.storage['data'].load()

    def _from_matplotlib_figure(self, obj):
        """Convert a matplotlib figure to a plotly figure and store the result.

        Args:
            obj: matplotlib figure object.

        Raises:
            ModuleNotFoundError: if plotly is not installed.
            ValueError: if the conversion to a plotly figure fails.
        """
        try:
            from plotly.tools import mpl_to_plotly
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError('Plotly is required to track matplotlib figure.') from err

        # Emitted outside the ``try`` below: logging is not part of the
        # conversion whose ValueError is being translated.
        logger.warning('Tracking a matplotlib object using "aim.Figure" might not behave as expected. '
                       'In such cases, consider tracking with "aim.Image".')

        try:
            for ax in obj.axes:
                for collection in ax.collections:
                    # Collections without ``get_offset_position`` get a bound
                    # fallback; presumably the plotly converter calls it --
                    # TODO confirm.
                    if not hasattr(collection, 'get_offset_position'):
                        collection.get_offset_position = matplotlib_get_offset_position.__get__(collection)
            plotly_obj = mpl_to_plotly(obj)
        except ValueError as err:
            raise ValueError(f'Failed to convert matplotlib figure to plotly figure: {err}') from err

        return self._prepare(plotly_obj)

    def json(self):
        """Dump figure metadata to a dict

        Returns:
            dict with the stored ``'source'``, ``'format'`` and ``'version'``
            values; the figure payload itself is not included.
        """
        return {
            'source': self.storage['source'],
            'format': self.storage['format'],
            'version': self.storage['version'],
        }

    def to_plotly_figure(self):
        """Rebuild a plotly figure from the stored JSON payload.

        Raises:
            ModuleNotFoundError: if plotly is not installed.
        """
        try:
            from plotly.io import from_json
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError('Could not find plotly in the installed modules.') from err

        # NOTE(review): the call's argument was cut off in the extracted
        # source; the stored payload is the only JSON this class keeps.
        return from_json(self.data)
def matplotlib_get_offset_position(self):
    """Fallback ``get_offset_position`` for matplotlib collections lacking one.

    Intended to be bound to a collection instance via ``__get__`` and
    installed as its ``get_offset_position`` attribute.

    Returns:
        The collection's ``_offset_position`` when the attribute exists,
        otherwise ``'screen'``.
    """
    # Newer matplotlib releases dropped ``_offset_position`` together with the
    # public accessor; ``'screen'`` is the documented default behaviour, so use
    # it instead of failing with AttributeError.
    return getattr(self, '_offset_position', 'screen')