Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the configuration parsing #513

Merged
merged 4 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ dependencies = [
"msgpack>=1.0.4,<2",
"psutil",
"huggingface_hub",
"iterative-telemetry>=0.0.9"
"iterative-telemetry>=0.0.9",
"platformdirs"
]

[project.optional-dependencies]
Expand Down
10 changes: 3 additions & 7 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from datachain.cache import DataChainCache
from datachain.client import Client
from datachain.config import get_remote_config, read_config
from datachain.config import Config
from datachain.dataset import (
DATASET_PREFIX,
QUERY_DATASET_PREFIX,
Expand Down Expand Up @@ -1240,9 +1240,7 @@ def get_dataset(self, name: str) -> DatasetRecord:
return self.metastore.get_dataset(name)

def get_remote_dataset(self, name: str, *, remote_config=None) -> DatasetRecord:
remote_config = remote_config or get_remote_config(
read_config(DataChainDir.find().root), remote=""
)
remote_config = remote_config or Config().get_remote_config(remote="")
studio_client = StudioClient(
remote_config["url"], remote_config["username"], remote_config["token"]
)
Expand Down Expand Up @@ -1479,9 +1477,7 @@ def _instantiate_dataset():
raise ValueError("Please provide output directory for instantiation")

client_config = client_config or self.client_config
remote_config = remote_config or get_remote_config(
read_config(DataChainDir.find().root), remote=""
)
remote_config = remote_config or Config().get_remote_config(remote="")

studio_client = StudioClient(
remote_config["url"], remote_config["username"], remote_config["token"]
Expand Down
5 changes: 2 additions & 3 deletions src/datachain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
from datachain.lib.dc import DataChain
from datachain.telemetry import telemetry
from datachain.utils import DataChainDir

if TYPE_CHECKING:
from datachain.catalog import Catalog
Expand Down Expand Up @@ -679,9 +678,9 @@ def ls(
**kwargs,
):
if config is None:
from .config import get_remote_config, read_config
from .config import Config

config = get_remote_config(read_config(DataChainDir.find().root), remote=remote)
config = Config().get_remote_config(remote=remote)
remote_type = config["type"]
if remote_type == "local":
ls_local(sources, long=long, **kwargs)
Expand Down
171 changes: 120 additions & 51 deletions src/datachain/config.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,131 @@
import os
from collections.abc import Mapping
from typing import TYPE_CHECKING, Optional
from contextlib import contextmanager
from typing import Optional, Union

from tomlkit import load
from tomlkit import TOMLDocument, dump, load

if TYPE_CHECKING:
from tomlkit import TOMLDocument
from datachain.utils import DataChainDir, global_config_dir, system_config_dir


def read_config(datachain_root: str) -> Optional["TOMLDocument"]:
config_path = os.path.join(datachain_root, "config")
try:
with open(config_path, encoding="utf-8") as f:
return load(f)
except FileNotFoundError:
return None
class Config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, it would be great to add tests for this class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests and fixes in #514

SYSTEM_LEVELS = ("system", "global")
LOCAL_LEVELS = ("local",)

# In the order of precedence
LEVELS = SYSTEM_LEVELS + LOCAL_LEVELS

CONFIG = "config"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I understand the reason behind this attribute 👀 Can't find any usage of this.


def __init__(
self,
level: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering: would it be better, from a typing perspective, to have level as an enum, let's say? 👀

):
self.level = level

self.init()

@classmethod
def get_dir(cls, level: Optional[str]) -> str:
    """Return the directory that holds the config file for ``level``.

    ``"system"`` resolves through ``system_config_dir()`` and ``"global"``
    through ``global_config_dir()``. Any other value, including ``None``,
    resolves to the root of the directory returned by ``DataChainDir.find()``.
    """
    if level not in ("system", "global"):
        return DataChainDir.find().root

    return system_config_dir() if level == "system" else global_config_dir()

def init(self):
    """Initialize the ``DataChainDir`` rooted at this instance's level directory."""
    DataChainDir(self.get_dir(self.level)).init()

def load_one(self, level: Optional[str] = None) -> TOMLDocument:
config_path = DataChainDir(self.get_dir(level)).config

def get_remote_config(
config: Optional["TOMLDocument"], remote: str = ""
) -> Mapping[str, str]:
if config is None:
return {"type": "local"}
if not remote:
try:
remote = config["core"]["default-remote"] # type: ignore[index,assignment]
except KeyError:
with open(config_path, encoding="utf-8") as f:
return load(f)

Check warning on line 45 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L45

Added line #L45 was not covered by tests
except FileNotFoundError:
return TOMLDocument()

def load_config_to_level(self) -> TOMLDocument:
merged_conf = TOMLDocument()

for merge_level in self.LEVELS:
if merge_level == self.level:
break

Check warning on line 54 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L54

Added line #L54 was not covered by tests
config = self.load_one(merge_level)
if config:
merge(merged_conf, config)

Check warning on line 57 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L57

Added line #L57 was not covered by tests

return merged_conf

def read(self) -> TOMLDocument:
    """Return the config document for this instance.

    With an explicit ``level`` only that level's file is loaded; with
    ``level`` set to ``None`` the result of ``load_config_to_level()`` is
    returned instead.
    """
    if self.level is not None:
        return self.load_one(self.level)
    return self.load_config_to_level()

Check warning on line 64 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L64

Added line #L64 was not covered by tests

@contextmanager
def edit(self):
config = self.load_one(self.level)
yield config

Check warning on line 69 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L68-L69

Added lines #L68 - L69 were not covered by tests

self.write(config)

Check warning on line 71 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L71

Added line #L71 was not covered by tests

def config_file(self):
    """Return the path of the config file for this instance's level."""
    level_dir = self.get_dir(self.level)
    return DataChainDir(level_dir).config

Check warning on line 74 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L74

Added line #L74 was not covered by tests

def write(self, config: TOMLDocument):
    """Serialize ``config`` into this level's config file, replacing its contents.

    The file is opened as UTF-8 explicitly so that it matches ``load_one``,
    which reads with ``encoding="utf-8"``. Relying on the platform default
    encoding could make non-ASCII values unreadable on the next load.

    Args:
        config: The TOML document to dump to ``self.config_file()``.
    """
    with open(self.config_file(), "w", encoding="utf-8") as f:
        dump(config, f)

Check warning on line 78 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L77-L78

Added lines #L77 - L78 were not covered by tests

def get_remote_config(self, remote: str = "") -> Mapping[str, str]:
config = self.read()

if not config:
return {"type": "local"}
try:
remote_conf: Mapping[str, str] = config["remote"][remote] # type: ignore[assignment,index]
except KeyError:
raise Exception(
f"missing config section for default remote: remote.{remote}"
) from None
except Exception as exc:
raise Exception("invalid config") from exc

if not isinstance(remote_conf, Mapping):
raise TypeError(f"config section remote.{remote} must be a mapping")

remote_type = remote_conf.get("type")
if remote_type not in ("local", "http"):
raise Exception(
f'config section remote.{remote} must have "type" with one of: '
'"local", "http"'
)

if remote_type == "http":
for key in ["url", "username", "token"]:
if not remote:
try:
remote_conf[key]
remote = config["core"]["default-remote"] # type: ignore[index,assignment]

Check warning on line 87 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L87

Added line #L87 was not covered by tests
except KeyError:
raise Exception(
f"config section remote.{remote} of type {remote_type} "
f"must contain key {key}"
) from None
elif remote_type != "local":
raise Exception(
f"config section remote.{remote} has invalid remote type {remote_type}"
)
return remote_conf
return {"type": "local"}
try:
remote_conf: Mapping[str, str] = config["remote"][remote] # type: ignore[assignment,index]
except KeyError:
raise Exception(

Check warning on line 93 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L89-L93

Added lines #L89 - L93 were not covered by tests
f"missing config section for default remote: remote.{remote}"
) from None
except Exception as exc:
raise Exception("invalid config") from exc

Check warning on line 97 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L96-L97

Added lines #L96 - L97 were not covered by tests

if not isinstance(remote_conf, Mapping):
raise TypeError(f"config section remote.{remote} must be a mapping")

Check warning on line 100 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L100

Added line #L100 was not covered by tests

remote_type = remote_conf.get("type")

Check warning on line 102 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L102

Added line #L102 was not covered by tests
if remote_type not in ("local", "http"):
raise Exception(

Check warning on line 104 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L104

Added line #L104 was not covered by tests
f'config section remote.{remote} must have "type" with one of: '
'"local", "http"'
)

if remote_type == "http":
for key in ["url", "username", "token"]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering: should we also have team_id as part of the remote config? A user can be a member of many teams in Studio 🤔 Or will team_id be passed as a (required? optional?) param in API calls?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be part of future changes when implementing some functionality with the token access.

try:
remote_conf[key]
except KeyError:
raise Exception(

Check warning on line 114 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L111-L114

Added lines #L111 - L114 were not covered by tests
f"config section remote.{remote} of type {remote_type} "
f"must contain key {key}"
) from None
Comment on lines +113 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a comment: Studio URL may be optional 👀

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was not sure if we will continue using this TBH. Keeping this part as it is for now.

elif remote_type != "local":
raise Exception(

Check warning on line 119 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L119

Added line #L119 was not covered by tests
f"config section remote.{remote} has invalid remote type {remote_type}"
)
return remote_conf

Check warning on line 122 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L122

Added line #L122 was not covered by tests


def merge(into: Union[TOMLDocument, dict], update: Union[TOMLDocument, dict]):
"""Merges second dict into first recursively"""
for key, val in update.items():
if isinstance(into.get(key), dict) and isinstance(val, dict):
merge(into[key], val) # type: ignore[arg-type]

Check warning on line 129 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L129

Added line #L129 was not covered by tests
else:
into[key] = val

Check warning on line 131 in src/datachain/config.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/config.py#L131

Added line #L131 was not covered by tests
28 changes: 28 additions & 0 deletions src/datachain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from uuid import UUID

import cloudpickle
import platformdirs
from dateutil import tz
from dateutil.parser import isoparse
from pydantic import BaseModel
Expand All @@ -25,6 +26,13 @@
NUL = b"\0"
TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc)

# Application name/author pair handed to platformdirs when resolving the
# system and global config directories.
APPNAME = "datachain"
APPAUTHOR = "iterative"
# Names of the environment variables that, when set to a non-empty value,
# override the platformdirs-derived system/global config directories.
ENV_DATACHAIN_SYSTEM_CONFIG_DIR = "DATACHAIN_SYSTEM_CONFIG_DIR"
ENV_DATACHAIN_GLOBAL_CONFIG_DIR = "DATACHAIN_GLOBAL_CONFIG_DIR"
# NOTE(review): presumably the default Studio endpoint — no usage is visible in this diff; confirm.
STUDIO_URL = "https://studio.dvc.ai"


T = TypeVar("T", bound="DataChainDir")


Expand All @@ -33,6 +41,7 @@ class DataChainDir:
CACHE = "cache"
TMP = "tmp"
DB = "db"
CONFIG = "config"
ENV_VAR = "DATACHAIN_DIR"
ENV_VAR_DATACHAIN_ROOT = "DATACHAIN_ROOT_DIR"

Expand All @@ -42,6 +51,7 @@ def __init__(
cache: Optional[str] = None,
tmp: Optional[str] = None,
db: Optional[str] = None,
config: Optional[str] = None,
) -> None:
self.root = osp.abspath(root) if root is not None else self.default_root()
self.cache = (
Expand All @@ -51,12 +61,18 @@ def __init__(
osp.abspath(tmp) if tmp is not None else osp.join(self.root, self.TMP)
)
self.db = osp.abspath(db) if db is not None else osp.join(self.root, self.DB)
self.config = (
osp.abspath(config)
if config is not None
else osp.join(self.root, self.CONFIG)
)

def init(self):
    """Create the directory layout for this instance if it does not exist yet.

    Ensures the root, cache and tmp directories exist, as well as the parent
    directories of the db and config paths. Existing directories are left
    untouched.
    """
    targets = (
        self.root,
        self.cache,
        self.tmp,
        osp.dirname(self.db),
        osp.dirname(self.config),
    )
    for directory in targets:
        os.makedirs(directory, exist_ok=True)

@classmethod
def default_root(cls) -> str:
Expand All @@ -82,6 +98,18 @@ def find(cls: type[T], create: bool = True) -> T:
return instance


def system_config_dir():
    """Return the system-level config directory.

    A non-empty value of the environment variable named by
    ``ENV_DATACHAIN_SYSTEM_CONFIG_DIR`` takes priority; otherwise the result
    of ``platformdirs.site_config_dir(APPNAME, APPAUTHOR)`` is used.
    """
    override = os.getenv(ENV_DATACHAIN_SYSTEM_CONFIG_DIR)
    if override:
        return override
    return platformdirs.site_config_dir(APPNAME, APPAUTHOR)


def global_config_dir():
    """Return the global (per-user) config directory.

    A non-empty value of the environment variable named by
    ``ENV_DATACHAIN_GLOBAL_CONFIG_DIR`` takes priority; otherwise the result
    of ``platformdirs.user_config_dir(APPNAME, APPAUTHOR)`` is used.
    """
    override = os.getenv(ENV_DATACHAIN_GLOBAL_CONFIG_DIR)
    if override:
        return override
    return platformdirs.user_config_dir(APPNAME, APPAUTHOR)


def human_time_to_int(time: str) -> Optional[int]:
if not time:
return None
Expand Down