Skip to content

Commit

Permalink
API key client option (#486)
Browse files Browse the repository at this point in the history
Fixes #482
  • Loading branch information
cretz authored Mar 13, 2024
1 parent 477aa31 commit f3d1b85
Show file tree
Hide file tree
Showing 8 changed files with 414 additions and 217 deletions.
453 changes: 282 additions & 171 deletions temporalio/bridge/Cargo.lock

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions temporalio/bridge/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ crate-type = ["cdylib"]
[dependencies]
futures = "0.3"
log = "0.4"
once_cell = "1.16.0"
parking_lot = "0.12"
prost = "0.11"
prost-types = "0.11"
once_cell = "1.16"
prost = "0.12"
prost-types = "0.12"
pyo3 = { version = "0.19", features = ["extension-module", "abi3-py38"] }
pyo3-asyncio = { version = "0.19", features = ["tokio-runtime"] }
pythonize = "0.19"
Expand All @@ -23,7 +22,7 @@ temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
tokio = "1.26"
tokio-stream = "0.1"
tonic = "0.9"
tonic = "0.11"
tracing = "0.1"
url = "2.2"

Expand Down
5 changes: 5 additions & 0 deletions temporalio/bridge/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ClientConfig:

target_url: str
metadata: Mapping[str, str]
api_key: Optional[str]
identity: str
tls_config: Optional[ClientTlsConfig]
retry_config: Optional[ClientRetryConfig]
Expand Down Expand Up @@ -102,6 +103,10 @@ def update_metadata(self, metadata: Mapping[str, str]) -> None:
"""Update underlying metadata on Core client."""
self._ref.update_metadata(metadata)

def update_api_key(self, api_key: Optional[str]) -> None:
"""Update underlying API key on Core client."""
self._ref.update_api_key(api_key)

async def call(
self,
*,
Expand Down
2 changes: 1 addition & 1 deletion temporalio/bridge/sdk-core
Submodule sdk-core updated 36 files
+7 −3 Cargo.toml
+8 −5 client/Cargo.toml
+100 −31 client/src/lib.rs
+1 −1 client/src/raw.rs
+3 −3 core-api/Cargo.toml
+20 −20 core/Cargo.toml
+3 −3 core/src/core_tests/mod.rs
+1 −5 core/src/ephemeral_server/mod.rs
+1 −1 core/src/protosext/mod.rs
+21 −251 core/src/telemetry/metrics.rs
+9 −37 core/src/telemetry/mod.rs
+276 −0 core/src/telemetry/otel.rs
+34 −21 core/src/telemetry/prometheus_server.rs
+1 −1 core/src/worker/activities/local_activities.rs
+3 −4 core/src/worker/client/mocks.rs
+6 −6 core/src/worker/workflow/history_update.rs
+2 −2 core/src/worker/workflow/machines/activity_state_machine.rs
+2 −2 core/src/worker/workflow/machines/cancel_workflow_state_machine.rs
+14 −18 core/src/worker/workflow/machines/child_workflow_state_machine.rs
+6 −7 core/src/worker/workflow/machines/transition_coverage.rs
+5 −5 core/src/worker/workflow/machines/workflow_machines.rs
+2 −2 core/src/worker/workflow/machines/workflow_task_state_machine.rs
+4 −4 sdk-core-protos/Cargo.toml
+4 −4 sdk-core-protos/src/lib.rs
+3 −4 sdk/Cargo.toml
+1 −1 sdk/src/lib.rs
+3 −3 test-utils/Cargo.toml
+1 −1 test-utils/src/histfetch.rs
+1 −1 test-utils/src/lib.rs
+2 −2 tests/integ_tests/client_tests.rs
+1 −1 tests/integ_tests/ephemeral_server_tests.rs
+3 −3 tests/integ_tests/metrics_tests.rs
+2 −4 tests/integ_tests/update_tests.rs
+2 −2 tests/integ_tests/visibility_tests.rs
+1 −1 tests/integ_tests/workflow_tests/eager.rs
+2 −2 tests/main.rs
21 changes: 9 additions & 12 deletions temporalio/bridge/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use parking_lot::RwLock;
use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use temporal_client::{
ClientKeepAliveConfig as CoreClientKeepAliveConfig, ClientOptions, ClientOptionsBuilder,
Expand Down Expand Up @@ -31,6 +29,7 @@ pub struct ClientConfig {
client_name: String,
client_version: String,
metadata: HashMap<String, String>,
api_key: Option<String>,
identity: String,
tls_config: Option<ClientTlsConfig>,
retry_config: Option<ClientRetryConfig>,
Expand Down Expand Up @@ -75,20 +74,12 @@ pub fn connect_client<'a>(
runtime_ref: &runtime::RuntimeRef,
config: ClientConfig,
) -> PyResult<&'a PyAny> {
let headers = if config.metadata.is_empty() {
None
} else {
Some(Arc::new(RwLock::new(config.metadata.clone())))
};
let opts: ClientOptions = config.try_into()?;
let runtime = runtime_ref.runtime.clone();
runtime_ref.runtime.future_into_py(py, async move {
Ok(ClientRef {
retry_client: opts
.connect_no_namespace(
runtime.core.telemetry().get_temporal_metric_meter(),
headers,
)
.connect_no_namespace(runtime.core.telemetry().get_temporal_metric_meter())
.await
.map_err(|err| {
PyRuntimeError::new_err(format!("Failed client connect: {}", err))
Expand All @@ -114,6 +105,10 @@ impl ClientRef {
self.retry_client.get_client().set_headers(headers);
}

fn update_api_key(&self, api_key: Option<String>) {
self.retry_client.get_client().set_api_key(api_key);
}

fn call_workflow_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<&'p PyAny> {
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
Expand Down Expand Up @@ -396,7 +391,9 @@ impl TryFrom<ClientConfig> for ClientOptions {
opts.retry_config
.map_or(RetryConfig::default(), |c| c.into()),
)
.keep_alive(opts.keep_alive_config.map(Into::into));
.keep_alive(opts.keep_alive_config.map(Into::into))
.headers(Some(opts.metadata))
.api_key(opts.api_key);
// Builder does not allow us to set option here, so we have to make
// a conditional to even call it
if let Some(tls_config) = opts.tls_config {
Expand Down
21 changes: 21 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ async def connect(
target_host: str,
*,
namespace: str = "default",
api_key: Optional[str] = None,
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
interceptors: Sequence[Interceptor] = [],
default_workflow_query_reject_condition: Optional[
Expand All @@ -116,6 +117,9 @@ async def connect(
target_host: ``host:port`` for the Temporal server. For local
development, this is often "localhost:7233".
namespace: Namespace to use for client calls.
api_key: API key for Temporal. This becomes the "Authorization"
HTTP header with "Bearer " prepended. This is only set if RPC
metadata doesn't already have an "authorization" key.
data_converter: Data converter to use for all data conversions
to/from payloads.
interceptors: Set of interceptors that are chained together to allow
Expand Down Expand Up @@ -152,6 +156,7 @@ async def connect(
"""
connect_config = temporalio.service.ConnectConfig(
target_host=target_host,
api_key=api_key,
tls=tls,
retry_config=retry_config,
keep_alive_config=keep_alive_config,
Expand Down Expand Up @@ -261,6 +266,22 @@ def rpc_metadata(self, value: Mapping[str, str]) -> None:
self.service_client.config.rpc_metadata = value
self.service_client.update_rpc_metadata(value)

@property
def api_key(self) -> Optional[str]:
"""API key for every call made by this client."""
return self.service_client.config.api_key

@api_key.setter
def api_key(self, value: Optional[str]) -> None:
"""Update the API key for this client.
This is only set if RPCmetadata doesn't already have an "authorization"
key.
"""
# Update config and perform update
self.service_client.config.api_key = value
self.service_client.update_api_key(value)

# Overload for no-param workflow
@overload
async def start_workflow(
Expand Down
15 changes: 15 additions & 0 deletions temporalio/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class ConnectConfig:
"""Config for connecting to the server."""

target_host: str
api_key: Optional[str] = None
tls: Union[bool, TLSConfig] = False
retry_config: Optional[RetryConfig] = None
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default
Expand Down Expand Up @@ -161,6 +162,7 @@ def _to_bridge_config(self) -> temporalio.bridge.client.ClientConfig:

return temporalio.bridge.client.ClientConfig(
target_url=target_url,
api_key=self.api_key,
tls_config=tls_config,
retry_config=self.retry_config._to_bridge_config()
if self.retry_config
Expand Down Expand Up @@ -238,6 +240,11 @@ def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
"""Update service client's RPC metadata."""
raise NotImplementedError

@abstractmethod
def update_api_key(self, api_key: Optional[str]) -> None:
"""Update service client's API key."""
raise NotImplementedError

@abstractmethod
async def _rpc_call(
self,
Expand Down Expand Up @@ -740,6 +747,14 @@ def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
if self._bridge_client:
self._bridge_client.update_metadata(metadata)

def update_api_key(self, api_key: Optional[str]) -> None:
"""Update Core client API key."""
# Mutate the bridge config and then only mutate the running client
# metadata if already connected
self._bridge_config.api_key = api_key
if self._bridge_client:
self._bridge_client.update_api_key(api_key)

async def _rpc_call(
self,
rpc: str,
Expand Down
105 changes: 77 additions & 28 deletions tests/api/test_grpc_stub.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import timedelta
from typing import Mapping

from google.protobuf.empty_pb2 import Empty
from google.protobuf.timestamp_pb2 import Timestamp
Expand Down Expand Up @@ -27,12 +28,6 @@
from temporalio.client import Client


def assert_metadata(context: ServicerContext, **kwargs) -> None:
metadata = dict(context.invocation_metadata())
for k, v in kwargs.items():
assert metadata.get(k) == v


def assert_time_remaining(context: ServicerContext, expected: int) -> None:
# Give or take 5 seconds
assert expected - 5 <= context.time_remaining() <= expected + 5
Expand All @@ -41,24 +36,26 @@ def assert_time_remaining(context: ServicerContext, expected: int) -> None:
class SimpleWorkflowServer(WorkflowServiceServicer):
def __init__(self) -> None:
super().__init__()
self.expected_client_key_value = "client_value"
self.last_metadata: Mapping[str, str] = {}

def assert_last_metadata(self, expected: Mapping[str, str]) -> None:
for k, v in expected.items():
assert self.last_metadata.get(k) == v

async def GetSystemInfo( # type: ignore # https://github.com/nipunn1313/mypy-protobuf/issues/216
self,
request: GetSystemInfoRequest,
context: ServicerContext,
) -> GetSystemInfoResponse:
assert_metadata(context, client_key=self.expected_client_key_value)
self.last_metadata = dict(context.invocation_metadata())
return GetSystemInfoResponse()

async def CountWorkflowExecutions( # type: ignore # https://github.com/nipunn1313/mypy-protobuf/issues/216
self,
request: CountWorkflowExecutionsRequest,
context: ServicerContext,
) -> CountWorkflowExecutionsResponse:
assert_metadata(
context, client_key=self.expected_client_key_value, rpc_key="rpc_value"
)
self.last_metadata = dict(context.invocation_metadata())
assert_time_remaining(context, 123)
assert request.namespace == "my namespace"
assert request.query == "my query"
Expand All @@ -71,7 +68,6 @@ async def DeleteNamespace( # type: ignore # https://github.com/nipunn1313/mypy-
request: DeleteNamespaceRequest,
context: ServicerContext,
) -> DeleteNamespaceResponse:
assert_metadata(context, client_key="client_value", rpc_key="rpc_value")
assert_time_remaining(context, 123)
assert request.namespace == "my namespace"
return DeleteNamespaceResponse(deleted_namespace="my namespace response")
Expand All @@ -83,7 +79,6 @@ async def GetCurrentTime( # type: ignore # https://github.com/nipunn1313/mypy-p
request: Empty,
context: ServicerContext,
) -> GetCurrentTimeResponse:
assert_metadata(context, client_key="client_value", rpc_key="rpc_value")
assert_time_remaining(context, 123)
return GetCurrentTimeResponse(time=Timestamp(seconds=123))

Expand All @@ -101,34 +96,88 @@ async def test_python_grpc_stub():
await server.start()

# Use our client to make a call to each service
client = await Client.connect(
f"localhost:{port}", rpc_metadata={"client_key": "client_value"}
)
metadata = {"rpc_key": "rpc_value"}
client = await Client.connect(f"localhost:{port}")
timeout = timedelta(seconds=123)
count_resp = await client.workflow_service.count_workflow_executions(
CountWorkflowExecutionsRequest(namespace="my namespace", query="my query"),
metadata=metadata,
timeout=timeout,
)
assert count_resp.count == 123
del_resp = await client.operator_service.delete_namespace(
DeleteNamespaceRequest(namespace="my namespace"),
metadata=metadata,
timeout=timeout,
)
assert del_resp.deleted_namespace == "my namespace response"
time_resp = await client.test_service.get_current_time(
Empty(), metadata=metadata, timeout=timeout
)
time_resp = await client.test_service.get_current_time(Empty(), timeout=timeout)
assert time_resp.time.seconds == 123

# Make another call to get system info after changing the client-level
# header
new_metadata = dict(client.rpc_metadata)
new_metadata["client_key"] = "changed_value"
client.rpc_metadata = new_metadata
workflow_server.expected_client_key_value = "changed_value"
await server.stop(grace=None)


async def test_grpc_metadata():
# Start server
server = grpc_server()
workflow_server = SimpleWorkflowServer() # type: ignore[abstract]
add_WorkflowServiceServicer_to_server(workflow_server, server)
port = server.add_insecure_port("[::]:0")
await server.start()

# Connect and confirm metadata of get system info call
client = await Client.connect(
f"localhost:{port}",
api_key="my-api-key",
rpc_metadata={"my-meta-key": "my-meta-val"},
)
workflow_server.assert_last_metadata(
{
"authorization": "Bearer my-api-key",
"my-meta-key": "my-meta-val",
}
)

# Overwrite API key via client RPC metadata, confirm there
client.rpc_metadata = {
"authorization": "my-auth-val1",
"my-meta-key": "my-meta-val",
}
await client.workflow_service.get_system_info(GetSystemInfoRequest())
workflow_server.assert_last_metadata(
{
"authorization": "my-auth-val1",
"my-meta-key": "my-meta-val",
}
)
client.rpc_metadata = {"my-meta-key": "my-meta-val"}

# Overwrite API key via call RPC metadata, confirm there
await client.workflow_service.get_system_info(
GetSystemInfoRequest(), metadata={"authorization": "my-auth-val2"}
)
workflow_server.assert_last_metadata(
{
"authorization": "my-auth-val2",
"my-meta-key": "my-meta-val",
}
)

# Update API key, confirm updated
client.api_key = "my-new-api-key"
await client.workflow_service.get_system_info(GetSystemInfoRequest())
workflow_server.assert_last_metadata(
{
"authorization": "Bearer my-new-api-key",
"my-meta-key": "my-meta-val",
}
)

# Remove API key, confirm removed
client.api_key = None
await client.workflow_service.get_system_info(GetSystemInfoRequest())
workflow_server.assert_last_metadata(
{
"my-meta-key": "my-meta-val",
}
)
assert "authorization" not in workflow_server.last_metadata

await server.stop(grace=None)

0 comments on commit f3d1b85

Please sign in to comment.