diff --git a/pyproject.toml b/pyproject.toml index bf37b22d..794d1bd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,8 @@ dependencies = [ "msgpack>=1.0.4,<2", "psutil", "huggingface_hub", - "iterative-telemetry>=0.0.9" + "iterative-telemetry>=0.0.9", + "platformdirs" ] [project.optional-dependencies] diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 1ad62c18..909ec4ea 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -35,7 +35,7 @@ from datachain.cache import DataChainCache from datachain.client import Client -from datachain.config import get_remote_config, read_config +from datachain.config import Config from datachain.dataset import ( DATASET_PREFIX, QUERY_DATASET_PREFIX, @@ -1240,9 +1240,7 @@ def get_dataset(self, name: str) -> DatasetRecord: return self.metastore.get_dataset(name) def get_remote_dataset(self, name: str, *, remote_config=None) -> DatasetRecord: - remote_config = remote_config or get_remote_config( - read_config(DataChainDir.find().root), remote="" - ) + remote_config = remote_config or Config().get_remote_config(remote="") studio_client = StudioClient( remote_config["url"], remote_config["username"], remote_config["token"] ) @@ -1479,9 +1477,7 @@ def _instantiate_dataset(): raise ValueError("Please provide output directory for instantiation") client_config = client_config or self.client_config - remote_config = remote_config or get_remote_config( - read_config(DataChainDir.find().root), remote="" - ) + remote_config = remote_config or Config().get_remote_config(remote="") studio_client = StudioClient( remote_config["url"], remote_config["username"], remote_config["token"] diff --git a/src/datachain/cli.py b/src/datachain/cli.py index e89b264e..c0a30650 100644 --- a/src/datachain/cli.py +++ b/src/datachain/cli.py @@ -16,7 +16,6 @@ from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs from datachain.lib.dc import DataChain from datachain.telemetry import telemetry -from datachain.utils import DataChainDir if TYPE_CHECKING: from datachain.catalog import Catalog @@ -679,9 +678,9 @@ def ls( **kwargs, ): if config is None: - from .config import get_remote_config, read_config + from .config import Config - config = get_remote_config(read_config(DataChainDir.find().root), remote=remote) + config = Config().get_remote_config(remote=remote) remote_type = config["type"] if remote_type == "local": ls_local(sources, long=long, **kwargs) diff --git a/src/datachain/config.py b/src/datachain/config.py index a6164efa..6a997c28 100644 --- a/src/datachain/config.py +++ b/src/datachain/config.py @@ -1,62 +1,131 @@ -import os from collections.abc import Mapping -from typing import TYPE_CHECKING, Optional +from contextlib import contextmanager +from typing import Optional, Union -from tomlkit import load +from tomlkit import TOMLDocument, dump, load -if TYPE_CHECKING: - from tomlkit import TOMLDocument +from datachain.utils import DataChainDir, global_config_dir, system_config_dir -def read_config(datachain_root: str) -> Optional["TOMLDocument"]: - config_path = os.path.join(datachain_root, "config") - try: - with open(config_path, encoding="utf-8") as f: - return load(f) - except FileNotFoundError: - return None +class Config: + SYSTEM_LEVELS = ("system", "global") + LOCAL_LEVELS = ("local",) + # In the order of precedence + LEVELS = SYSTEM_LEVELS + LOCAL_LEVELS + + CONFIG = "config" + + def __init__( + self, + level: Optional[str] = None, + ): + self.level = level + + self.init() + + @classmethod + def get_dir(cls, level: Optional[str]) -> str: + if level == "system": + return system_config_dir() + if level == "global": + return global_config_dir() + + return DataChainDir.find().root + + def init(self): + d = DataChainDir(self.get_dir(self.level)) + d.init() + + def load_one(self, level: Optional[str] = None) -> TOMLDocument: + config_path = DataChainDir(self.get_dir(level)).config -def get_remote_config( - config: Optional["TOMLDocument"], remote: str = "" -) -> Mapping[str, str]: - if config is None: - return {"type": "local"} - if not remote: try: - remote = config["core"]["default-remote"] # type: ignore[index,assignment] - except KeyError: + with open(config_path, encoding="utf-8") as f: + return load(f) + except FileNotFoundError: + return TOMLDocument() + + def load_config_to_level(self) -> TOMLDocument: + merged_conf = TOMLDocument() + + for merge_level in self.LEVELS: + if merge_level == self.level: + break + config = self.load_one(merge_level) + if config: + merge(merged_conf, config) + + return merged_conf + + def read(self) -> TOMLDocument: + if self.level is None: + return self.load_config_to_level() + return self.load_one(self.level) + + @contextmanager + def edit(self): + config = self.load_one(self.level) + yield config + + self.write(config) + + def config_file(self): + return DataChainDir(self.get_dir(self.level)).config + + def write(self, config: TOMLDocument): + with open(self.config_file(), "w") as f: + dump(config, f) + + def get_remote_config(self, remote: str = "") -> Mapping[str, str]: + config = self.read() + + if not config: return {"type": "local"} - try: - remote_conf: Mapping[str, str] = config["remote"][remote] # type: ignore[assignment,index] - except KeyError: - raise Exception( - f"missing config section for default remote: remote.{remote}" - ) from None - except Exception as exc: - raise Exception("invalid config") from exc - - if not isinstance(remote_conf, Mapping): - raise TypeError(f"config section remote.{remote} must be a mapping") - - remote_type = remote_conf.get("type") - if remote_type not in ("local", "http"): - raise Exception( - f'config section remote.{remote} must have "type" with one of: ' - '"local", "http"' - ) - - if remote_type == "http": - for key in ["url", "username", "token"]: + if not remote: try: - remote_conf[key] + remote = config["core"]["default-remote"] # type: ignore[index,assignment] except KeyError: - raise Exception( - f"config section remote.{remote} of type {remote_type} " - f"must contain key {key}" - ) from None - elif remote_type != "local": - raise Exception( - f"config section remote.{remote} has invalid remote type {remote_type}" - ) - return remote_conf + return {"type": "local"} + try: + remote_conf: Mapping[str, str] = config["remote"][remote] # type: ignore[assignment,index] + except KeyError: + raise Exception( + f"missing config section for default remote: remote.{remote}" + ) from None + except Exception as exc: + raise Exception("invalid config") from exc + + if not isinstance(remote_conf, Mapping): + raise TypeError(f"config section remote.{remote} must be a mapping") + + remote_type = remote_conf.get("type") + if remote_type not in ("local", "http"): + raise Exception( + f'config section remote.{remote} must have "type" with one of: ' + '"local", "http"' + ) + + if remote_type == "http": + for key in ["url", "username", "token"]: + try: + remote_conf[key] + except KeyError: + raise Exception( + f"config section remote.{remote} of type {remote_type} " + f"must contain key {key}" + ) from None + elif remote_type != "local": + raise Exception( + f"config section remote.{remote} has invalid remote type {remote_type}" + ) + return remote_conf + + +def merge(into: Union[TOMLDocument, dict], update: Union[TOMLDocument, dict]): + """Merges second dict into first recursively""" + for key, val in update.items(): + if isinstance(into.get(key), dict) and isinstance(val, dict): + merge(into[key], val) # type: ignore[arg-type] + else: + into[key] = val diff --git a/src/datachain/utils.py b/src/datachain/utils.py index bf5dea5c..712f5084 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -15,6 +15,7 @@ from uuid import UUID import cloudpickle +import platformdirs from dateutil import tz from dateutil.parser import isoparse from pydantic import BaseModel @@ -25,6 +26,13 @@ NUL = b"\0" TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc) +APPNAME = "datachain" +APPAUTHOR = "iterative" +ENV_DATACHAIN_SYSTEM_CONFIG_DIR = "DATACHAIN_SYSTEM_CONFIG_DIR" +ENV_DATACHAIN_GLOBAL_CONFIG_DIR = "DATACHAIN_GLOBAL_CONFIG_DIR" +STUDIO_URL = "https://studio.dvc.ai" + + T = TypeVar("T", bound="DataChainDir") @@ -33,6 +41,7 @@ class DataChainDir: CACHE = "cache" TMP = "tmp" DB = "db" + CONFIG = "config" ENV_VAR = "DATACHAIN_DIR" ENV_VAR_DATACHAIN_ROOT = "DATACHAIN_ROOT_DIR" @@ -42,6 +51,7 @@ def __init__( cache: Optional[str] = None, tmp: Optional[str] = None, db: Optional[str] = None, + config: Optional[str] = None, ) -> None: self.root = osp.abspath(root) if root is not None else self.default_root() self.cache = ( @@ -51,12 +61,18 @@ def __init__( osp.abspath(tmp) if tmp is not None else osp.join(self.root, self.TMP) ) self.db = osp.abspath(db) if db is not None else osp.join(self.root, self.DB) + self.config = ( + osp.abspath(config) + if config is not None + else osp.join(self.root, self.CONFIG) + ) def init(self): os.makedirs(self.root, exist_ok=True) os.makedirs(self.cache, exist_ok=True) os.makedirs(self.tmp, exist_ok=True) os.makedirs(osp.split(self.db)[0], exist_ok=True) + os.makedirs(osp.split(self.config)[0], exist_ok=True) @classmethod def default_root(cls) -> str: @@ -82,6 +98,18 @@ def find(cls: type[T], create: bool = True) -> T: return instance +def system_config_dir(): + return os.getenv(ENV_DATACHAIN_SYSTEM_CONFIG_DIR) or platformdirs.site_config_dir( + APPNAME, APPAUTHOR + ) + + +def global_config_dir(): + return os.getenv(ENV_DATACHAIN_GLOBAL_CONFIG_DIR) or platformdirs.user_config_dir( + APPNAME, APPAUTHOR + ) + + def human_time_to_int(time: str) -> Optional[int]: if not time: return None