diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 333980f..6a40bf0 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -1,11 +1,11 @@ from __future__ import annotations as _annotations from pathlib import Path -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar -from pydantic import ConfigDict +from pydantic import ConfigDict, TypeAdapter from pydantic._internal._config import config_keys -from pydantic._internal._utils import deep_update +from pydantic._internal._utils import is_model_class from pydantic.main import BaseModel from .sources import ( @@ -20,6 +20,8 @@ SecretsSettingsSource, ) +KeyType = TypeVar('KeyType') + class SettingsConfigDict(ConfigDict, total=False): case_sensitive: bool @@ -308,7 +310,7 @@ def _settings_build_values( ) sources = (cli_settings,) + sources if sources: - return deep_update(*reversed([source() for source in sources])) + return BaseSettings._deep_update(*reversed([source() for source in sources])) else: # no one should mean to do this, but I think returning an empty dict is marginally preferable # to an informative error and much better than a confusing error @@ -343,3 +345,23 @@ def _settings_build_values( secrets_dir=None, protected_namespaces=('model_', 'settings_'), ) + + @staticmethod + def _deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]: + """Adapts logic from `pydantic._internal._utils.deep_update` to handle nested partial overrides of BaseModel derived types.""" + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for key, new_val in updating_mapping.items(): + if key in updated_mapping: + old_val = updated_mapping[key] + old_val_type = type(old_val) + if is_model_class(old_val_type) and isinstance(new_val, dict): + old_val = old_val.model_dump() + updated_mapping[key] = ( + TypeAdapter(old_val_type).validate_python(BaseSettings._deep_update(old_val, new_val)) + if isinstance(old_val, dict) and isinstance(new_val, dict) + else new_val + ) + else: + updated_mapping[key] = new_val + return updated_mapping diff --git a/tests/test_settings.py b/tests/test_settings.py index c4809e2..48bb2fa 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,6 +6,7 @@ import sys import typing import uuid +from abc import ABC, abstractmethod from datetime import datetime, timezone from enum import IntEnum from pathlib import Path @@ -713,6 +714,88 @@ def settings_customise_sources( assert s.bar == 'env setting' +def test_env_deep_override(env): + class DeepSubModel(BaseModel): + v4: str + + class SubModel(BaseModel): + v1: str + v2: bytes + v3: int + deep: DeepSubModel + + class Settings(BaseSettings, env_nested_delimiter='__'): + v0: str + sub_model: SubModel + + @classmethod + def settings_customise_sources( + cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings + ): + return env_settings, dotenv_settings, init_settings, file_secret_settings + + env.set('SUB_MODEL__DEEP__V4', 'override-v4') + + s_final = {'v0': '0', 'sub_model': {'v1': 'init-v1', 'v2': b'init-v2', 'v3': 3, 'deep': {'v4': 'override-v4'}}} + + s = Settings(v0='0', sub_model={'v1': 'init-v1', 'v2': b'init-v2', 'v3': 3, 'deep': {'v4': 'init-v4'}}) + assert s.model_dump() == s_final + + s = Settings(v0='0', sub_model=SubModel(v1='init-v1', v2=b'init-v2', v3=3, deep=DeepSubModel(v4='init-v4'))) + assert s.model_dump() == s_final + + s = Settings(v0='0', sub_model=SubModel(v1='init-v1', v2=b'init-v2', v3=3, deep={'v4': 'init-v4'})) + assert s.model_dump() == s_final + + s = Settings(v0='0', sub_model={'v1': 'init-v1', 'v2': b'init-v2', 'v3': 3, 'deep': DeepSubModel(v4='init-v4')}) + assert s.model_dump() == s_final + + +def test_env_deep_override_copy_by_reference(env): + class BaseAuth(ABC, BaseModel): + @property + @abstractmethod + def token(self) -> str: + """returns authentication token for XYZ""" + pass + + class CustomAuth(BaseAuth): + url: HttpUrl + username: str + password: SecretStr + + _token: SecretStr + + @property + def token(self): + ... # (re)fetch token + return self._token.get_secret_value() + + class Settings(BaseSettings, env_nested_delimiter='__'): + auth: BaseAuth + + @classmethod + def settings_customise_sources( + cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings + ): + return env_settings, init_settings, file_secret_settings + + auth_orig = CustomAuth(url='https://127.0.0.1', username='some-username', password='some-password') + + s = Settings(auth=auth_orig) + assert s.auth is auth_orig + + env.set('AUTH__URL', 'https://123.4.5.6') + + s = Settings(auth=auth_orig) + assert s.auth is not auth_orig + assert type(s.auth) is CustomAuth + assert s.auth.username is auth_orig.username + assert s.auth.password is auth_orig.password + assert s.auth.url is not auth_orig.url + assert s.auth.url == HttpUrl('https://123.4.5.6') + + def test_config_file_settings_nornir(env): """ See https://github.com/pydantic/pydantic/pull/341#issuecomment-450378771