Skip to content

Commit

Permalink
[feat] Add ability to define properties on Container subclasses (#2966)
Browse files Browse the repository at this point in the history
  • Loading branch information
alberttorosyan authored Sep 5, 2023
1 parent cc5ab0c commit 57c360d
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 42 deletions.
2 changes: 1 addition & 1 deletion pkgs/aimstack/asp/boards/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def merge_dicts(dict1, dict2):
params_tab, metrics_tab, audios_tab, texts_tab, images_tab, figures_tab = ui.tabs(('Params', 'Metrics', 'Audios',
'Texts', 'Images', 'Figures'))
with params_tab:
params = run.get('params')
params = run
if params is None:
ui.text('No parameters found')
else:
Expand Down
34 changes: 5 additions & 29 deletions pkgs/aimstack/asp/models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from functools import partialmethod

from aim import Container
from aim import Container, Property
from aim._sdk.utils import utc_timestamp
from aim._sdk import type_utils
from aimcore.callbacks import Caller
Expand Down Expand Up @@ -41,6 +41,10 @@

@type_utils.query_alias('run')
class Run(Container, Caller):
name = Property()
description = Property(default='')
archived = Property(default=False)

def __init__(self, hash_: Optional[str] = None, *,
repo: Optional[Union[str, 'Repo']] = None,
mode: Optional[Union[str, ContainerOpenMode]] = ContainerOpenMode.WRITE):
Expand All @@ -49,10 +53,6 @@ def __init__(self, hash_: Optional[str] = None, *,
if not self._is_readonly:
if self.name is None:
self.name = f'Run #{self.hash}'
if self.description is None:
self.description = ''
if self.archived is None:
self.archived = False

def enable_system_monitoring(self):
if not self._is_readonly:
Expand Down Expand Up @@ -81,30 +81,6 @@ def track_system_resources(self, stats: Dict[str, Any], context: Dict, **kwargs)
for resource_name, usage in stats.items():
self.sequences.typed_sequence(SystemMetric, resource_name, context).track(usage)

@property
def name(self) -> str:
return self._attrs_tree.get('name', None)

@name.setter
def name(self, val: str):
self['name'] = val

@property
def description(self) -> str:
return self._attrs_tree.get('description', None)

@description.setter
def description(self, val: str):
self['description'] = val

@property
def archived(self) -> bool:
return self._attrs_tree.get('archived', None)

@archived.setter
def archived(self, val: bool):
self['archived'] = val

@property
def creation_time(self) -> float:
return self._tree[KeyNames.INFO_PREFIX, 'creation_time']
Expand Down
26 changes: 21 additions & 5 deletions src/aimcore/web/ui/public/aim_ui_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ class WaitForQueryError(Exception):
pass


# dict forbids setting attributes, hence using derived class
class dictionary(dict):
pass


def process_properties(obj: dict):
if '$properties' in obj:
props = obj.pop('$properties')
new_obj = dictionary()
new_obj.update(obj)
for k, v in props.items():
setattr(new_obj, k, v)
return new_obj
return obj


def query_filter(type_, query="", count=None, start=None, stop=None, is_sequence=False):
query_key = f"{type_}_{query}_{count}_{start}_{stop}"

Expand All @@ -63,8 +79,8 @@ def query_filter(type_, query="", count=None, start=None, stop=None, is_sequence

if data is None:
raise WaitForQueryError()
data = json.loads(data)

data = json.loads(data, object_hook=process_properties)

query_results_cache[query_key] = data
return data
Expand All @@ -83,7 +99,7 @@ def run_function(func_name, params):

try:
res = runFunction(board_path, func_name, params)
data = json.loads(res)["value"]
data = json.loads(res, object_hook=process_properties)["value"]

query_results_cache[run_function_key] = data
return data
Expand All @@ -105,7 +121,7 @@ def find_item(type_, is_sequence=False, hash_=None, name=None, ctx=None):

try:
data = findItem(board_path, type_, is_sequence, hash_, name, ctx)
data = json.loads(data)
data = json.loads(data, object_hook=process_properties)

query_results_cache[query_key] = data
return data
Expand Down Expand Up @@ -869,7 +885,7 @@ def focused_row(self):

@property
def selected_rows_indices(self):
return self.state["selected_rows_indices"] if "selected_rows_indices" in self.state else None
return self.state["selected_rows_indices"] if "selected_rows_indices" in self.state else None

@property
def focused_row_index(self):
Expand Down
4 changes: 2 additions & 2 deletions src/python/aim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .record import Record
from .sequence import Sequence
from .container import Container
from .container import Container, Property
from .repo import Repo

from aim._ext.notebook.notebook import load_ipython_extension
Expand All @@ -11,7 +11,7 @@
from aim._ext.tracking import analytics
from aim._sdk.package_utils import register_aimstack_packages, register_package

__all__ = ['Record', 'Sequence', 'Container', 'Repo', 'register_package']
__all__ = ['Record', 'Sequence', 'Container', 'Repo', 'Property', 'register_package']
__aim_types__ = [Sequence, Container, Record]

# python_version_deprecation_check()
Expand Down
42 changes: 42 additions & 0 deletions src/python/aim/_sdk/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ def _close(self) -> None:
self._lock.release()


class Property:
PROP_NAME_BLACKLIST = ( # do not allow property names to be dict class public methods
'clear', 'copy', 'fromkeys', 'get', 'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values')

def __init__(self, default=None):
self._default = default
self._name = None # Will be set by __set_name__

def __set_name__(self, owner, name):
if name in Property.PROP_NAME_BLACKLIST:
raise RuntimeError(f'Cannot define Aim Property with name \'{name}\'.')
self._name = name

def __get__(self, instance: 'Container', owner):
if instance is None:
return self
return instance._get_property(self._name, self._default)

def __set__(self, instance: 'Container', value: Any):
instance._set_property(self._name, value)


@type_utils.query_alias('container', 'c')
@type_utils.auto_registry
class Container(ABCContainer):
Expand Down Expand Up @@ -138,6 +160,10 @@ def __init__(self, hash_: Optional[str] = None, *,
self._meta_tree[KeyNames.CONTAINERS, self.get_typename()] = 1
self[...] = {}

for attr_name, attr in self.__class__.__dict__.items():
if isinstance(attr, Property):
self._set_property(attr_name, attr._default)

self._tree[KeyNames.INFO_PREFIX, 'end_time'] = None

self._resources = ContainerAutoClean(self)
Expand Down Expand Up @@ -180,6 +206,9 @@ def __storage_init__(self):
self._meta_attrs_tree: TreeView = self._meta_tree.subtree('attrs')
self._attrs_tree: TreeView = self._tree.subtree('attrs')

self._meta_props_tree: TreeView = self._meta_tree.subtree('_props')
self._props_tree: TreeView = self._tree.subtree('_props')

self._data_loader: Callable[[], 'TreeView'] = lambda: self._sequence_data_tree
self.__sequence_data_tree: TreeView = None
self._sequence_map = ContainerSequenceMap(self, Sequence)
Expand Down Expand Up @@ -211,6 +240,19 @@ def get(self, key, default: Any = None, strict: bool = False):
except KeyError:
return default

def _set_property(self, name: str, value: Any):
self._props_tree[name] = value
self._meta_props_tree.merge(name, value)

def _get_property(self, name: str, default: Any = None) -> Any:
return self._props_tree.get(name, default)

def collect_properties(self) -> Dict:
try:
return self._props_tree.collect()
except KeyError:
return {}

def match(self, expr) -> bool:
query = RestrictedPythonQuery(expr)
query_cache = {}
Expand Down
2 changes: 1 addition & 1 deletion src/python/aim/container.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from aim._sdk.container import Container # noqa F401
from aim._sdk.container import Container, Property # noqa F401
13 changes: 9 additions & 4 deletions src/python/aim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,17 @@ def sequence_data(


def container_data(container: Container) -> Dict:
data = {
'hash': container.hash,
'params': container[...],
props = container.collect_properties()
props.update({
'container_type': container.get_typename(),
'container_full_type': container.get_full_typename(),
}
'hash': container.hash,
})
data = container[...]
data.update({
'hash': container.hash,
'$properties': props,
})
return data


Expand Down

0 comments on commit 57c360d

Please sign in to comment.