Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove thread local, expose printing retry errors #844

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions sdk/src/beta9/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from functools import wraps
from typing import Any, Callable
from typing import Callable, TypeVar, Union, overload


def called_on_import() -> bool:
Expand Down Expand Up @@ -28,11 +28,24 @@ def wrapper(*args, **kwargs):
wrapper()


def try_env(env: str, default: Any) -> Any:
EnvValue = TypeVar("EnvValue", str, int, float, bool)


@overload
def try_env(env: str, default: str) -> bool: ...


@overload
def try_env(env: str, default: EnvValue) -> EnvValue: ...


def try_env(env: str, default: EnvValue) -> Union[EnvValue, bool]:
"""
Tries to get an environment variable and returns the default value if it doesn't exist.
Gets an environment variable and converts it to the correct type based on
the default value.

Will also try to convert the value to the type of the default value.
The environment variable name is prefixed with the name of the project if
it is set in the settings.

Args:
env: The name of the environment variable.
Expand All @@ -52,6 +65,6 @@ def try_env(env: str, default: Any) -> Any:
if target_type is bool:
return env_val.lower() in ["true", "1", "yes", "on"]

return target_type(env_val)
return target_type(env_val) or default
except (ValueError, TypeError):
return default
28 changes: 20 additions & 8 deletions sdk/src/beta9/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from os import PathLike
from pathlib import Path
from queue import Queue
from threading import Thread, local
from threading import Thread
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -59,10 +59,10 @@
)
from .terminal import CustomProgress

_PROCESS_LOCAL: Final[local] = local()
# Value of 0 means the number of workers is calculated based on the file size
_MAX_WORKERS: Final[int] = try_env("MULTIPART_MAX_WORKERS", 0)
_REQUEST_TIMEOUT: Final[float] = try_env("MULTIPART_REQUEST_TIMEOUT", 5)
_MAX_WORKERS: Final = try_env("MULTIPART_MAX_WORKERS", 0)
_REQUEST_TIMEOUT: Final = try_env("MULTIPART_REQUEST_TIMEOUT", 5)
_DEBUG_RETRY: Final = try_env("MULTIPART_DEBUG_RETRY", False)


class ProgressCallbackType(Protocol):
Expand Down Expand Up @@ -108,6 +108,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
return func(*args, **kwargs)
except Exception as e:
if _DEBUG_RETRY:
print(e)

last_exception = e
if attempt < times - 1:
time.sleep(min(current_delay, max_delay))
Expand Down Expand Up @@ -160,19 +163,28 @@ def target():
thread.join(timeout=timeout)


# Global session for making HTTP requests
_session = None


def _get_session() -> Session:
"""
Get a requests session from the process's local storage.
Get a session for making HTTP requests.

This is not thread safe, but should be process safe.
"""
if not hasattr(_PROCESS_LOCAL, "session"):
_PROCESS_LOCAL.session = requests.Session()
return _PROCESS_LOCAL.session
global _session
if _session is None:
_session = requests.Session()
return _session


def _init():
"""
Initialize the process by setting a signal handler.
"""
_get_session()

signal.signal(signal.SIGINT, lambda *_: os.kill(os.getpid(), signal.SIGTERM))


Expand Down
38 changes: 38 additions & 0 deletions sdk/tests/test_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
from contextlib import contextmanager
from typing import Any

import pytest

from beta9 import config, env


@contextmanager
def temp_env(name: str, value: Any):
env_prefix = config.get_settings().name
name = f"{env_prefix}_{str(name)}".upper()
if value is not None:
os.environ[name] = str(value)
yield
if name in os.environ:
del os.environ[name]


@pytest.mark.parametrize(
("env_name", "env_value", "default", "expected"),
[
("VAR1", None, "a-string-id", "a-string-id"),
("VAR2", None, 123, 123),
("VAR3", None, 0.5, 0.5),
("VAR4", None, False, False),
("VAR5", "my-env-value", "", "my-env-value"),
("VAR6", "100", 0, 100),
("VAR7", "on", False, True),
("VAR8", "no", True, False),
("VAR9", "33.3", 0.1, 33.3),
],
)
def test_try_env(env_name: str, env_value: Any, default: Any, expected: Any):
with temp_env(env_name, env_value):
value = env.try_env(env_name, default)
assert value == expected
Loading