avoid splitting script into feature files/scripts (#385)
skshetry authored Sep 7, 2024
1 parent a0ce094 commit ae493e7
Showing 3 changed files with 13 additions and 267 deletions.
78 changes: 13 additions & 65 deletions src/datachain/catalog/catalog.py
@@ -9,7 +9,6 @@
 import posixpath
 import subprocess
 import sys
-import tempfile
 import time
 import traceback
 from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -77,7 +76,6 @@
 )

 from .datasource import DataSource
-from .subclass import SubclassFinder

 if TYPE_CHECKING:
     from datachain.data_storage import (
@@ -92,7 +90,6 @@

 DEFAULT_DATASET_DIR = "dataset"
 DATASET_FILE_SUFFIX = ".edatachain"
-FEATURE_CLASSES = ["DataModel"]

 TTL_INT = 4 * 60 * 60

@@ -569,12 +566,6 @@ def find_column_to_str(  # noqa: PLR0911
     return ""


-def form_module_source(source_ast):
-    module = ast.Module(body=source_ast, type_ignores=[])
-    module = ast.fix_missing_locations(module)
-    return ast.unparse(module)
-
-
 class Catalog:
     def __init__(
         self,
@@ -660,29 +651,10 @@ def attach_query_wrapper(self, code_ast):
         code_ast.body[-1:] = new_expressions
         return code_ast

-    def compile_query_script(
-        self, script: str, feature_module_name: str
-    ) -> tuple[Union[str, None], str]:
+    def compile_query_script(self, script: str) -> str:
         code_ast = ast.parse(script)
         code_ast = self.attach_query_wrapper(code_ast)
-        finder = SubclassFinder(FEATURE_CLASSES)
-        finder.visit(code_ast)
-
-        if not finder.feature_class:
-            main_module = form_module_source([*finder.imports, *finder.main_body])
-            return None, main_module
-
-        feature_import = ast.ImportFrom(
-            module=feature_module_name,
-            names=[ast.alias(name="*", asname=None)],
-            level=0,
-        )
-        feature_module = form_module_source([*finder.imports, *finder.feature_class])
-        main_module = form_module_source(
-            [*finder.imports, feature_import, *finder.main_body]
-        )
-
-        return feature_module, main_module
+        return ast.unparse(code_ast)

     def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]:
         config = config or self.client_config
@@ -1863,11 +1835,6 @@ def query(
                 C.size > 1000
             )
         """
-        feature_file = tempfile.NamedTemporaryFile(  # noqa: SIM115
-            dir=os.getcwd(), suffix=".py", delete=False
-        )
-        _, feature_module = os.path.split(feature_file.name)
-
         if not job_id:
             python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
             job_id = self.metastore.create_job(
@@ -1877,23 +1844,16 @@ def query(
                 python_version=python_version,
             )

-        try:
-            lines, proc = self.run_query(
-                python_executable or sys.executable,
-                query_script,
-                envs,
-                feature_file,
-                capture_output,
-                feature_module,
-                output_hook,
-                params,
-                save,
-                job_id,
-            )
-        finally:
-            feature_file.close()
-            os.unlink(feature_file.name)
-
+        lines, proc = self.run_query(
+            python_executable or sys.executable,
+            query_script,
+            envs,
+            capture_output,
+            output_hook,
+            params,
+            save,
+            job_id,
+        )
         output = "".join(lines)

         if proc.returncode:
@@ -1947,31 +1907,19 @@ def run_query(
         python_executable: str,
         query_script: str,
         envs: Optional[Mapping[str, str]],
-        feature_file: IO[bytes],
         capture_output: bool,
-        feature_module: str,
         output_hook: Callable[[str], None],
         params: Optional[dict[str, str]],
         save: bool,
         job_id: Optional[str],
     ) -> tuple[list[str], subprocess.Popen]:
         try:
-            feature_code, query_script_compiled = self.compile_query_script(
-                query_script, feature_module[:-3]
-            )
-            if feature_code:
-                feature_file.write(feature_code.encode())
-                feature_file.flush()
-
+            query_script_compiled = self.compile_query_script(query_script)
         except Exception as exc:
             raise QueryScriptCompileError(
                 f"Query script failed to compile, reason: {exc}"
             ) from exc
         envs = dict(envs or os.environ)
-        if feature_code:
-            envs["DATACHAIN_FEATURE_CLASS_SOURCE"] = json.dumps(
-                {feature_module: feature_code}
-            )
         envs.update(
             {
                 "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
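
Note: with the feature-file split removed, compile_query_script collapses to a single AST pass: parse the script, let attach_query_wrapper rewrite the trailing expression into a datachain.query.dataset.query_wrapper(...) call, and unparse the result. A minimal standalone sketch of that pipeline, based on the wrapper call shape visible in the tests below (compile_script is a hypothetical stand-in, and the import position is simplified):

import ast

def compile_script(script: str) -> str:
    # Hypothetical stand-in for Catalog.compile_query_script after this
    # change: one module in, one compiled module out, no feature module.
    tree = ast.parse(script)

    # Approximation of attach_query_wrapper (unchanged by this commit):
    # wrap the last top-level expression in query_wrapper(...) and make
    # sure the wrapper module is imported (the real code places the
    # import after the script's existing imports).
    last = tree.body[-1]
    if isinstance(last, ast.Expr):
        wrapper = ast.parse("datachain.query.dataset.query_wrapper", mode="eval").body
        tree.body[-1] = ast.Expr(
            value=ast.Call(func=wrapper, args=[last.value], keywords=[])
        )
        tree.body.insert(
            0, ast.Import(names=[ast.alias(name="datachain.query.dataset", asname=None)])
        )

    tree = ast.fix_missing_locations(tree)
    return ast.unparse(tree)

print(compile_script('DatasetQuery("s3://bkt/dir1")'))

Because every DataModel subclass now ships inside the one compiled script, the temporary feature file and the DATACHAIN_FEATURE_CLASS_SOURCE environment variable are no longer needed.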
60 changes: 0 additions & 60 deletions src/datachain/catalog/subclass.py

This file was deleted.
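
For reference, the deleted SubclassFinder was an ast.NodeVisitor that partitioned a script's top-level statements into imports, subclasses of the names in FEATURE_CLASSES (i.e. DataModel), and the remaining main body. A rough sketch of that shape (the real implementation also resolved import aliases such as "DataModel as FromAlias", which this naive version does not):

import ast

class SubclassFinder(ast.NodeVisitor):
    """Sketch of the deleted helper: split a module into imports,
    feature-class definitions, and everything else."""

    def __init__(self, base_names):
        self.base_names = set(base_names)  # e.g. {"DataModel"}
        self.imports: list[ast.stmt] = []
        self.feature_class: list[ast.stmt] = []
        self.main_body: list[ast.stmt] = []

    def visit_Module(self, node: ast.Module) -> None:
        for stmt in node.body:
            if isinstance(stmt, (ast.Import, ast.ImportFrom)):
                self.imports.append(stmt)
            elif isinstance(stmt, ast.ClassDef) and any(
                # naive last-name match; the deleted code also chased aliases
                ast.unparse(base).split(".")[-1] in self.base_names
                for base in stmt.bases
            ):
                self.feature_class.append(stmt)
            else:
                self.main_body.append(stmt)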

142 changes: 0 additions & 142 deletions tests/unit/test_catalog.py
@@ -1,4 +1,3 @@
-from textwrap import dedent
 from typing import TYPE_CHECKING

 from datachain.catalog import Catalog
@@ -7,147 +6,6 @@
     from datachain.data_storage import AbstractWarehouse


-def test_compile_query_script_no_feature_class(catalog):
-    script = dedent(
-        """
-        from datachain.query import C, DatasetQuery, asUDF
-        DatasetQuery("s3://bkt/dir1")
-        """
-    ).strip()
-    feature, result = catalog.compile_query_script(script, "tmpfeature")
-    expected = dedent(
-        """
-        from datachain.query import C, DatasetQuery, asUDF
-        import datachain.query.dataset
-        datachain.query.dataset.query_wrapper(
-        DatasetQuery('s3://bkt/dir1'))
-        """
-    ).strip()
-    assert feature is None
-    assert result == expected
-
-
-def test_compile_query_script_with_feature_class(catalog):
-    script = dedent(
-        """
-        from datachain.query import C, DatasetQuery, asUDF
-        from datachain.lib.data_model import DataModel as FromAlias
-        from datachain.lib.data_model import DataModel
-        import datachain.lib.data_model.DataModel as DirectImportedFeature
-        import datachain
-        class NormalClass:
-            t = 1
-        class SFClass(FromAlias):
-            emb: float
-        class DirectImport(DirectImportedFeature):
-            emb: float
-        class FullImport(datachain.lib.data_model.DataModel):
-            emb: float
-        class Embedding(DataModel):
-            emb: float
-        DatasetQuery("s3://bkt/dir1")
-        """
-    ).strip()
-    feature, result = catalog.compile_query_script(script, "tmpfeature")
-    expected_feature = dedent(
-        """
-        from datachain.query import C, DatasetQuery, asUDF
-        from datachain.lib.data_model import DataModel as FromAlias
-        from datachain.lib.data_model import DataModel
-        import datachain.lib.data_model.DataModel as DirectImportedFeature
-        import datachain
-        import datachain.query.dataset
-        class SFClass(FromAlias):
-            emb: float
-        class DirectImport(DirectImportedFeature):
-            emb: float
-        class FullImport(datachain.lib.data_model.DataModel):
-            emb: float
-        class Embedding(DataModel):
-            emb: float
-        """
-    ).strip()
-    expected_result = dedent(
-        """
-        from datachain.query import C, DatasetQuery, asUDF
-        from datachain.lib.data_model import DataModel as FromAlias
-        from datachain.lib.data_model import DataModel
-        import datachain.lib.data_model.DataModel as DirectImportedFeature
-        import datachain
-        import datachain.query.dataset
-        from tmpfeature import *
-        class NormalClass:
-            t = 1
-        datachain.query.dataset.query_wrapper(
-        DatasetQuery('s3://bkt/dir1'))
-        """
-    ).strip()
-
-    assert feature == expected_feature
-    assert result == expected_result
-
-
-def test_compile_query_script_with_decorator(catalog):
-    script = dedent(
-        """
-        import os
-        from datachain.query import C, DatasetQuery, udf
-        from datachain.sql.types import Float, Float32, Int, String, Binary
-        @udf(
-            params=("name", ),
-            output={"num": Float, "bin": Binary}
-        )
-        def my_func1(name):
-            x = 3.14
-            int_example = 25
-            bin = int_example.to_bytes(2, "big")
-            return (x, bin)
-        print("Test ENV = ", os.environ['TEST_ENV'])
-        ds = DatasetQuery("s3://dql-small/*.jpg") \
-            .add_signals(my_func1)
-        ds
-        """
-    ).strip()
-    feature, result = catalog.compile_query_script(script, "tmpfeature")
-
-    expected_result = dedent(
-        """
-        import os
-        from datachain.query import C, DatasetQuery, udf
-        from datachain.sql.types import Float, Float32, Int, String, Binary
-        import datachain.query.dataset
-        @udf(params=('name',), output={'num': Float, 'bin': Binary})
-        def my_func1(name):
-            x = 3.14
-            int_example = 25
-            bin = int_example.to_bytes(2, 'big')
-            return (x, bin)
-        print('Test ENV = ', os.environ['TEST_ENV'])
-        ds = DatasetQuery('s3://dql-small/*.jpg').add_signals(my_func1)
-        datachain.query.dataset.query_wrapper(
-        ds)
-        """
-    ).strip()
-
-    assert feature is None
-    assert result == expected_result
-
-
 def test_catalog_warehouse_ready_callback(mocker, warehouse, id_generator, metastore):
     spy = mocker.spy(warehouse, "is_ready")

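
Were the first deleted test ported to the new single-return API, it would presumably shrink to something like the following (a sketch, not part of the commit; it assumes the wrapper's output formatting is unchanged):

from textwrap import dedent

def test_compile_query_script(catalog):
    script = dedent(
        """
        from datachain.query import C, DatasetQuery, asUDF
        DatasetQuery("s3://bkt/dir1")
        """
    ).strip()
    result = catalog.compile_query_script(script)  # now returns a single str
    expected = dedent(
        """
        from datachain.query import C, DatasetQuery, asUDF
        import datachain.query.dataset
        datachain.query.dataset.query_wrapper(
        DatasetQuery('s3://bkt/dir1'))
        """
    ).strip()
    assert result == expected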
