Refactor encoding based on schemas
blythed committed Feb 7, 2025
1 parent 752c01b commit cf214d7
Showing 41 changed files with 924 additions and 1,593 deletions.
25 changes: 0 additions & 25 deletions superduper/backends/base/artifacts.py
@@ -1,5 +1,4 @@
 import hashlib
-import os
 import typing as t
 from abc import ABC, abstractmethod

@@ -66,30 +65,6 @@ def drop(self, force: bool = False):
         """
         pass
 
-    @abstractmethod
-    def _exists(self, file_id: str):
-        pass
-
-    def exists(
-        self,
-        file_id: t.Optional[str] = None,
-        datatype: t.Optional[str] = None,
-        uri: t.Optional[str] = None,
-    ):
-        """Check if artifact exists in artifact store.
-        :param file_id: file id of artifact in the store
-        :param datatype: Datatype of the artifact
-        :param uri: URI of the artifact
-        """
-        if file_id is None:
-            assert uri is not None, "if file_id is None, uri can't be None"
-            file_id = _construct_file_id_from_uri(uri)
-        if self.load('datatype', datatype).directory:
-            assert datatype is not None
-            file_id = os.path.join(datatype.directory, file_id)
-        return self._exists(file_id)
-
     @abstractmethod
     def put_bytes(self, serialized: bytes, file_id: str):
         """Save bytes in artifact store.
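In short, the base `ArtifactStore` no longer offers an existence check; subclasses now implement only the byte- and file-level primitives. A minimal sketch of a store conforming to the slimmed-down contract — the in-memory class below is invented for illustration, and only the `put_bytes`/`get_bytes` signatures come from this diff:

    import typing as t

    class InMemoryArtifactStore:
        """Toy stand-in for the slimmed-down ArtifactStore contract (illustrative only)."""

        def __init__(self):
            self._store: t.Dict[str, bytes] = {}

        def put_bytes(self, serialized: bytes, file_id: str):
            # Same signature as the abstract put_bytes shown above.
            self._store[file_id] = serialized

        def get_bytes(self, file_id: str) -> bytes:
            return self._store[file_id]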
3 changes: 1 addition & 2 deletions superduper/backends/base/data_backend.py
@@ -4,7 +4,7 @@
 
 from superduper import CFG, logging
 from superduper.backends.base.query import Query
-from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES, KEY_SCHEMA
+from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES
 from superduper.base.document import Document
 
 if t.TYPE_CHECKING:
@@ -232,7 +232,6 @@ def _do_insert(self, table, documents):
             r.pop(KEY_BUILDS)
             r.pop(KEY_BLOBS)
             r.pop(KEY_FILES)
-            r.pop(KEY_SCHEMA, None)
             documents[i] = r
 
         out = self.insert(table, documents)
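For context, `_do_insert` strips the encoding bookkeeping keys (`_builds`, `_blobs`, `_files`) from each document before handing it to the backend; after this commit the `_schema` key no longer travels inside the document, so there is nothing left to pop. A toy illustration of the document shape — the field values are invented, only the key names come from `constant.py`:

    # Hypothetical encoded document as it might look just before insertion.
    r = {
        'x': 1,
        '_builds': {},  # KEY_BUILDS: nested component builds
        '_blobs': {},   # KEY_BLOBS: out-of-band binary payloads
        '_files': {},   # KEY_FILES: out-of-band file references
        # '_schema' (KEY_SCHEMA) no longer appears after this refactor.
    }
    for key in ('_builds', '_blobs', '_files'):
        r.pop(key)
    assert r == {'x': 1}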
39 changes: 27 additions & 12 deletions superduper/backends/local/artifacts.py
@@ -13,27 +13,42 @@ class FileSystemArtifactStore(ArtifactStore):
"""
Abstraction for storing large artifacts separately from primary data.
:param conn: root directory of the artifact store
:param name: subdirectory to use for this artifact store
:param flavour: Flavour of the artifact store
:param conn: Root directory of the artifact store.
:param name: Name of the artifact store.
:param flavour: Flavour of the artifact store.
:param files: Subdirectory to use for files.
:param blobs: Subdirectory to use for blobs.
"""

     def __init__(
         self,
         conn: t.Any,
         name: t.Optional[str] = None,
         flavour: t.Optional[str] = None,
+        files: str = '',
+        blobs: str = '',
     ):
         if conn.startswith('filesystem://'):
             conn = conn.split('filesystem://')[-1]
         super().__init__(conn, name, flavour)
 
         if not os.path.exists(self.conn):
             logging.info('Creating artifact store directory')
             os.makedirs(self.conn, exist_ok=True)
 
-    def _exists(self, file_id: str):
-        path = os.path.join(self.conn, file_id)
-        return os.path.exists(path)
+        self.files = os.path.join(self.conn, files) if files else self.conn
+        if self.files != self.conn and not os.path.exists(self.files):
+            logging.info('Creating file store directory')
+            os.makedirs(self.files, exist_ok=True)
+
+        self.blobs = os.path.join(self.conn, blobs) if blobs else self.conn
+        if self.blobs != self.conn and not os.path.exists(self.blobs):
+            logging.info('Creating blob store directory')
+            os.makedirs(self.blobs, exist_ok=True)
+
+    # def _exists(self, file_id: str):
+    #     path = os.path.join(self.conn, file_id)
+    #     return os.path.exists(path)
 
     def url(self):
         """Return the URL of the artifact store."""
@@ -44,7 +59,7 @@ def _delete_bytes(self, file_id: str):
         :param file_id: File id used to identify artifact in store
         """
-        path = os.path.join(self.conn, file_id)
+        path = os.path.join(self.blobs, file_id)
         if os.path.isdir(path):
             shutil.rmtree(path)
         else:
@@ -81,7 +96,7 @@ def put_bytes(
         :param serialized: The bytes to be saved.
         :param file_id: The id of the file.
         """
-        path = os.path.join(self.conn, file_id)
+        path = os.path.join(self.blobs, file_id)
         if os.path.exists(path):
             logging.warn(f"File {path} already exists")
 
@@ -95,7 +110,7 @@ def get_bytes(self, file_id: str) -> bytes:
         :param file_id: The id of the file.
         """
-        with open(os.path.join(self.conn, file_id), 'rb') as f:
+        with open(os.path.join(self.blobs, file_id), 'rb') as f:
             return f.read()
 
     def put_file(self, file_path: str, file_id: str):
@@ -108,7 +123,7 @@ def put_file(self, file_path: str, file_id: str):
         """
         path = Path(file_path)
         name = path.name
-        file_id_folder = os.path.join(self.conn, file_id)
+        file_id_folder = os.path.join(self.files, file_id)
 
         os.makedirs(file_id_folder, exist_ok=True)
         os.chmod(file_id_folder, 0o777)
@@ -126,8 +141,8 @@ def get_file(self, file_id: str) -> str:
         :param file_id: The id of the file.
         """
-        logging.info(f"Loading file {file_id} from {self.conn}")
-        path = os.path.join(self.conn, file_id)
+        logging.info(f"Loading file {file_id} from {self.files}")
+        path = os.path.join(self.files, file_id)
         files = os.listdir(path)
         assert len(files) == 1, f"Expected 1 file, got {len(files)}"
         name = files[0]
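Taken together, these changes let blobs and files live in separate subdirectories under the store root instead of being mixed directly under `self.conn`. A hedged usage sketch — the constructor parameters come from this diff, while the concrete values and assertions are invented:

    import os
    import tempfile

    from superduper.backends.local.artifacts import FileSystemArtifactStore

    root = tempfile.mkdtemp()

    # `files` and `blobs` are the new constructor parameters; each becomes
    # a subdirectory under the store root when non-empty.
    store = FileSystemArtifactStore(f'filesystem://{root}', files='files', blobs='blobs')

    store.put_bytes(b'hello', file_id='abc123')        # written under <root>/blobs
    assert store.get_bytes('abc123') == b'hello'
    assert os.path.isdir(os.path.join(root, 'files'))  # created eagerly in __init__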
1 change: 1 addition & 0 deletions superduper/backends/query_dataset.py
@@ -123,6 +123,7 @@ def __getitem__(self, item):
         return self._get_item(input)
 
 
+# TODO remove - never used
 class CachedQueryDataset(QueryDataset):
     """Cached Query Dataset for fetching documents from database.
1 change: 1 addition & 0 deletions superduper/base/code.py
@@ -1,3 +1,4 @@
+# TODO remove - never used
 import inspect
 
 from superduper import logging
1 change: 0 additions & 1 deletion superduper/base/constant.py
@@ -2,4 +2,3 @@
 KEY_BUILDS = '_builds'
 KEY_BLOBS = '_blobs'
 KEY_FILES = '_files'
-KEY_SCHEMA = '_schema'
32 changes: 12 additions & 20 deletions superduper/base/datalayer.py
@@ -17,9 +17,9 @@
 from superduper.base.cursor import SuperDuperCursor
 from superduper.base.document import Document
 from superduper.components.component import Component
+from superduper.components.datatype import LeafType
 from superduper.components.schema import Schema
 from superduper.components.table import Table
-from superduper.misc.importing import import_object
 
 DBResult = t.Any
 TaskGraph = t.Any
@@ -469,21 +469,13 @@ def load(
             info = self.metadata.get_component_by_uuid(
                 uuid=uuid, allow_hidden=allow_hidden
             )
-            try:
-                class_schema = import_object(info['_path']).build_class_schema()
-            except (KeyError, AttributeError):
-                # if defined in __main__ then the class is directly serialized
-                assert '_object' in info
-                from superduper.components.datatype import DEFAULT_SERIALIZER, Blob
-
-                bytes_ = Blob(
-                    identifier=info['_object'].split(':')[-1], db=self
-                ).unpack()
-                object = DEFAULT_SERIALIZER.decode_data(bytes_)
-                class_schema = object.build_class_schema()
-
-            c = Document.decode(info, db=self, schema=class_schema)
-            c.db = self
+            builds = info.get('_builds', {})
+            for k in builds:
+                builds[k]['identifier'] = k.split(':')[-1]
+            c = LeafType('leaf_type', db=self).decode_data(
+                {k: v for k, v in info.items() if k != '_builds'},
+                builds=info.get('_builds', {}),
+            )
             if c.cache:
                 logging.info(f'Adding {c.huuid} to cache')
                 self.cluster.cache.put(c)
@@ -506,9 +498,9 @@ def load(
                 identifier=identifier,
                 allow_hidden=allow_hidden,
             )
-            c = Document.decode(info, db=self)
-            c.db = self
-
+            c = LeafType('leaf_type', db=self).decode_data(
+                info, builds=info.get('_builds', {})
+            )
             if c.cache:
                 logging.info(f'Adding {c.huuid} to cache')
                 self.cluster.cache.put(c)
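Both `load` paths now funnel through the same schema-driven decoder: instead of special-casing `import_object` on `_path` with a `__main__` fallback, the metadata record and its `_builds` graph are handed to `LeafType.decode_data`, with each build's identifier back-filled from its key. A toy illustration of that back-filling step — the record contents are invented:

    # Hypothetical metadata record, shaped after the fields visible in this diff.
    info = {
        '_path': 'superduper.components.model.ObjectModel',  # invented example
        'identifier': 'my-model',
        '_builds': {'datatype:dill': {}},
    }

    builds = info.get('_builds', {})
    for k in builds:
        # 'datatype:dill' -> identifier 'dill'
        builds[k]['identifier'] = k.split(':')[-1]

    assert builds['datatype:dill']['identifier'] == 'dill'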
@@ -735,7 +727,7 @@ def infer_schema(
         # TODO have a slightly more user-friendly schema
         from superduper.misc.auto_schema import infer_schema
 
-        return infer_schema(data, identifier=identifier)
+        return infer_schema(data, identifier=identifier, db=self)
 
     @property
     def cfg(self) -> Config:
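This one-line change threads the datalayer itself into schema inference, presumably so inferred datatypes can be resolved against the connected backends; the caller-facing shape is unchanged. A hedged usage sketch — the sample data is invented, and `db` is assumed to be a connected datalayer:

    import numpy as np

    # db: a connected superduper datalayer (assumed).
    schema = db.infer_schema(
        {'text': 'a caption', 'vector': np.random.randn(16)},
        identifier='my-schema',
    )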