Skip to content

Commit

Permalink
Refactor the configuration parsing (#513)
Browse files Browse the repository at this point in the history
* Refactor the configuration parsing

As part of #10774, this introduces a process to save the configuration
at the local, system and global levels.

The precedence of the levels is as follows:
- system
- global
- local

Local configuration overrides global and so on.

This borrows the logic of how configuration is managed in DVC.

* Return toml document by default

* Update some fixes

* Fix init
  • Loading branch information
amritghimire authored Oct 17, 2024
1 parent c6ca542 commit f29e034
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 62 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ dependencies = [
"msgpack>=1.0.4,<2",
"psutil",
"huggingface_hub",
"iterative-telemetry>=0.0.9"
"iterative-telemetry>=0.0.9",
"platformdirs"
]

[project.optional-dependencies]
Expand Down
10 changes: 3 additions & 7 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from datachain.cache import DataChainCache
from datachain.client import Client
from datachain.config import get_remote_config, read_config
from datachain.config import Config
from datachain.dataset import (
DATASET_PREFIX,
QUERY_DATASET_PREFIX,
Expand Down Expand Up @@ -1240,9 +1240,7 @@ def get_dataset(self, name: str) -> DatasetRecord:
return self.metastore.get_dataset(name)

def get_remote_dataset(self, name: str, *, remote_config=None) -> DatasetRecord:
remote_config = remote_config or get_remote_config(
read_config(DataChainDir.find().root), remote=""
)
remote_config = remote_config or Config().get_remote_config(remote="")
studio_client = StudioClient(
remote_config["url"], remote_config["username"], remote_config["token"]
)
Expand Down Expand Up @@ -1479,9 +1477,7 @@ def _instantiate_dataset():
raise ValueError("Please provide output directory for instantiation")

client_config = client_config or self.client_config
remote_config = remote_config or get_remote_config(
read_config(DataChainDir.find().root), remote=""
)
remote_config = remote_config or Config().get_remote_config(remote="")

studio_client = StudioClient(
remote_config["url"], remote_config["username"], remote_config["token"]
Expand Down
5 changes: 2 additions & 3 deletions src/datachain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
from datachain.lib.dc import DataChain
from datachain.telemetry import telemetry
from datachain.utils import DataChainDir

if TYPE_CHECKING:
from datachain.catalog import Catalog
Expand Down Expand Up @@ -679,9 +678,9 @@ def ls(
**kwargs,
):
if config is None:
from .config import get_remote_config, read_config
from .config import Config

config = get_remote_config(read_config(DataChainDir.find().root), remote=remote)
config = Config().get_remote_config(remote=remote)
remote_type = config["type"]
if remote_type == "local":
ls_local(sources, long=long, **kwargs)
Expand Down
171 changes: 120 additions & 51 deletions src/datachain/config.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,131 @@
import os
from collections.abc import Mapping
from typing import TYPE_CHECKING, Optional
from contextlib import contextmanager
from typing import Optional, Union

from tomlkit import load
from tomlkit import TOMLDocument, dump, load

if TYPE_CHECKING:
from tomlkit import TOMLDocument
from datachain.utils import DataChainDir, global_config_dir, system_config_dir


def read_config(datachain_root: str) -> Optional["TOMLDocument"]:
config_path = os.path.join(datachain_root, "config")
try:
with open(config_path, encoding="utf-8") as f:
return load(f)
except FileNotFoundError:
return None
class Config:
    """Layered DataChain configuration stored as TOML.

    Configuration is read from up to three levels, whose directories are
    resolved by :meth:`get_dir`:

    - ``"system"``: machine-wide (platformdirs site config dir)
    - ``"global"``: per-user (platformdirs user config dir)
    - ``"local"``: the project's DataChain directory

    When ``level`` is ``None``, :meth:`read` merges every level in
    ``LEVELS`` order, so later (more specific) levels override earlier
    ones — local overrides global, which overrides system.
    """

    SYSTEM_LEVELS = ("system", "global")
    LOCAL_LEVELS = ("local",)

    # In the order of precedence: entries later in the tuple override
    # entries earlier in it when the per-level documents are merged.
    LEVELS = SYSTEM_LEVELS + LOCAL_LEVELS

    CONFIG = "config"

    def __init__(
        self,
        level: Optional[str] = None,
    ):
        """``level`` selects a single config level; ``None`` merges all."""
        self.level = level

        self.init()

    @classmethod
    def get_dir(cls, level: Optional[str]) -> str:
        """Return the directory holding the config file for ``level``."""
        if level == "system":
            return system_config_dir()
        if level == "global":
            return global_config_dir()

        # "local" (and None) resolve to the project's DataChain directory.
        return DataChainDir.find().root

    def init(self) -> None:
        """Ensure the directory structure for this level exists on disk."""
        d = DataChainDir(self.get_dir(self.level))
        d.init()

    def load_one(self, level: Optional[str] = None) -> TOMLDocument:
        """Load the config document for a single level.

        Returns an empty ``TOMLDocument`` when the file does not exist,
        so callers can merge or inspect the result without special cases.
        """
        config_path = DataChainDir(self.get_dir(level)).config

        try:
            with open(config_path, encoding="utf-8") as f:
                return load(f)
        except FileNotFoundError:
            return TOMLDocument()

    def load_config_to_level(self) -> TOMLDocument:
        """Merge every level of lower precedence than ``self.level``.

        Iterates ``LEVELS`` in precedence order and stops once it reaches
        ``self.level`` (exclusive), folding each document into the result.
        """
        merged_conf = TOMLDocument()

        for merge_level in self.LEVELS:
            if merge_level == self.level:
                break
            config = self.load_one(merge_level)
            if config:
                merge(merged_conf, config)

        return merged_conf

    def read(self) -> TOMLDocument:
        """Return the effective config: one level, or all levels merged."""
        if self.level is None:
            return self.load_config_to_level()
        return self.load_one(self.level)

    @contextmanager
    def edit(self):
        """Yield this level's config document for in-place mutation.

        The (possibly modified) document is written back on normal exit;
        if the body raises, the write is skipped.
        """
        config = self.load_one(self.level)
        yield config

        self.write(config)

    def config_file(self) -> str:
        """Path of the config file for this instance's level."""
        return DataChainDir(self.get_dir(self.level)).config

    def write(self, config: TOMLDocument) -> None:
        """Serialize ``config`` to this level's config file."""
        with open(self.config_file(), "w") as f:
            dump(config, f)

    def get_remote_config(self, remote: str = "") -> Mapping[str, str]:
        """Resolve and validate the configuration for a remote.

        When ``remote`` is empty, ``core.default-remote`` is consulted;
        if that is also missing (or there is no config at all), a
        ``{"type": "local"}`` stub is returned.

        Raises when the remote section is missing, is not a mapping, has
        a type other than "local"/"http", or (for "http") lacks any of
        the url/username/token keys.
        """
        config = self.read()

        if not config:
            return {"type": "local"}
        if not remote:
            try:
                remote = config["core"]["default-remote"]  # type: ignore[index,assignment]
            except KeyError:
                return {"type": "local"}
        try:
            remote_conf: Mapping[str, str] = config["remote"][remote]  # type: ignore[assignment,index]
        except KeyError:
            raise Exception(
                f"missing config section for default remote: remote.{remote}"
            ) from None
        except Exception as exc:
            # e.g. config["remote"] is not a table at all
            raise Exception("invalid config") from exc

        if not isinstance(remote_conf, Mapping):
            raise TypeError(f"config section remote.{remote} must be a mapping")

        remote_type = remote_conf.get("type")
        if remote_type not in ("local", "http"):
            raise Exception(
                f'config section remote.{remote} must have "type" with one of: '
                '"local", "http"'
            )

        if remote_type == "http":
            # http remotes must carry full Studio credentials.
            for key in ["url", "username", "token"]:
                try:
                    remote_conf[key]
                except KeyError:
                    raise Exception(
                        f"config section remote.{remote} of type {remote_type} "
                        f"must contain key {key}"
                    ) from None
        # NOTE: after the membership guard above, remote_type is known to
        # be "local" or "http", so no further type check is needed here
        # (the original trailing `elif remote_type != "local"` raise was
        # unreachable and has been removed).
        return remote_conf


def merge(into: Union[TOMLDocument, dict], update: Union[TOMLDocument, dict]):
    """Recursively fold ``update`` into ``into``, mutating ``into``.

    Keys whose values are dicts on both sides are merged key by key;
    any other value from ``update`` overwrites the one in ``into``.
    """
    for key in update:
        incoming = update[key]
        existing = into.get(key)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merge(existing, incoming)  # type: ignore[arg-type]
        else:
            into[key] = incoming
28 changes: 28 additions & 0 deletions src/datachain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from uuid import UUID

import cloudpickle
import platformdirs
from dateutil import tz
from dateutil.parser import isoparse
from pydantic import BaseModel
Expand All @@ -25,6 +26,13 @@
NUL = b"\0"
TIME_ZERO = datetime.fromtimestamp(0, tz=timezone.utc)

APPNAME = "datachain"
APPAUTHOR = "iterative"
ENV_DATACHAIN_SYSTEM_CONFIG_DIR = "DATACHAIN_SYSTEM_CONFIG_DIR"
ENV_DATACHAIN_GLOBAL_CONFIG_DIR = "DATACHAIN_GLOBAL_CONFIG_DIR"
STUDIO_URL = "https://studio.dvc.ai"


T = TypeVar("T", bound="DataChainDir")


Expand All @@ -33,6 +41,7 @@ class DataChainDir:
CACHE = "cache"
TMP = "tmp"
DB = "db"
CONFIG = "config"
ENV_VAR = "DATACHAIN_DIR"
ENV_VAR_DATACHAIN_ROOT = "DATACHAIN_ROOT_DIR"

Expand All @@ -42,6 +51,7 @@ def __init__(
cache: Optional[str] = None,
tmp: Optional[str] = None,
db: Optional[str] = None,
config: Optional[str] = None,
) -> None:
self.root = osp.abspath(root) if root is not None else self.default_root()
self.cache = (
Expand All @@ -51,12 +61,18 @@ def __init__(
osp.abspath(tmp) if tmp is not None else osp.join(self.root, self.TMP)
)
self.db = osp.abspath(db) if db is not None else osp.join(self.root, self.DB)
self.config = (
osp.abspath(config)
if config is not None
else osp.join(self.root, self.CONFIG)
)

def init(self):
os.makedirs(self.root, exist_ok=True)
os.makedirs(self.cache, exist_ok=True)
os.makedirs(self.tmp, exist_ok=True)
os.makedirs(osp.split(self.db)[0], exist_ok=True)
os.makedirs(osp.split(self.config)[0], exist_ok=True)

@classmethod
def default_root(cls) -> str:
Expand All @@ -82,6 +98,18 @@ def find(cls: type[T], create: bool = True) -> T:
return instance


def system_config_dir():
    """Return the machine-wide (system) config directory.

    A non-empty DATACHAIN_SYSTEM_CONFIG_DIR env var overrides the
    platformdirs default.
    """
    override = os.getenv(ENV_DATACHAIN_SYSTEM_CONFIG_DIR)
    if override:
        return override
    return platformdirs.site_config_dir(APPNAME, APPAUTHOR)


def global_config_dir():
    """Return the per-user (global) config directory.

    A non-empty DATACHAIN_GLOBAL_CONFIG_DIR env var overrides the
    platformdirs default.
    """
    override = os.getenv(ENV_DATACHAIN_GLOBAL_CONFIG_DIR)
    if override:
        return override
    return platformdirs.user_config_dir(APPNAME, APPAUTHOR)


def human_time_to_int(time: str) -> Optional[int]:
if not time:
return None
Expand Down

0 comments on commit f29e034

Please sign in to comment.