diff --git a/sdk/src/beam/exceptions.py b/sdk/src/beam/exceptions.py index 761dfc7af..cac321a72 100644 --- a/sdk/src/beam/exceptions.py +++ b/sdk/src/beam/exceptions.py @@ -2,3 +2,7 @@ class RunnerException(SystemExit): def __init__(self, message="", *args): self.message = message super().__init__(*args) + + +class InvalidFunctionArgumentsException(RuntimeError): + pass diff --git a/sdk/src/beam/runner/common.py b/sdk/src/beam/runner/common.py index 3b21b3d16..9472c75ad 100644 --- a/sdk/src/beam/runner/common.py +++ b/sdk/src/beam/runner/common.py @@ -17,6 +17,7 @@ class Config: concurrency: Optional[int] keep_warm_seconds: Optional[int] handler: str + task_id: Optional[str] @classmethod def load_from_env(cls) -> "Config": @@ -36,6 +37,8 @@ def load_from_env(cls) -> "Config": if not handler: raise RunnerException("Invalid handler") + task_id = os.getenv("TASK_ID") + return cls( container_id=container_id, container_hostname=container_hostname, @@ -43,6 +46,7 @@ def load_from_env(cls) -> "Config": concurrency=concurrency, keep_warm_seconds=keep_warm_seconds, handler=handler, + task_id=task_id, ) diff --git a/sdk/src/beam/runner/function.py b/sdk/src/beam/runner/function.py index 4a9592716..0e1ec6961 100644 --- a/sdk/src/beam/runner/function.py +++ b/sdk/src/beam/runner/function.py @@ -12,8 +12,8 @@ ) from beam.clients.gateway import EndTaskResponse, GatewayServiceStub, StartTaskResponse from beam.config import with_runner_context -from beam.exceptions import RunnerException -from beam.runner.common import USER_CODE_VOLUME, load_handler +from beam.exceptions import InvalidFunctionArgumentsException, RunnerException +from beam.runner.common import USER_CODE_VOLUME, config, load_handler from beam.type import TaskStatus @@ -22,10 +22,10 @@ def main(channel: Channel): function_stub: FunctionServiceStub = FunctionServiceStub(channel) gateway_stub: GatewayServiceStub = GatewayServiceStub(channel) - task_id = os.getenv("TASK_ID") - container_id = os.getenv("CONTAINER_ID") - container_hostname = os.getenv("CONTAINER_HOSTNAME") - if not task_id or not container_id: + task_id = config.task_id + container_id = config.container_id + container_hostname = config.container_hostname + if not task_id: raise RunnerException("Invalid runner environment") # Start the task @@ -47,7 +47,7 @@ def main(channel: Channel): function_stub.function_get_args(task_id=task_id), ) if not get_args_resp.ok: - raise RuntimeError("invalid args") + raise InvalidFunctionArgumentsException args: dict = cloudpickle.loads(get_args_resp.args) os.chdir(USER_CODE_VOLUME)