From 154318ee7d5126cb4b0a44c2a6ddf96fb3017647 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 3 Jan 2025 16:32:50 +0530 Subject: [PATCH] AIP-72: Add missing client tests for connection operations (#45374) --- task_sdk/tests/api/test_client.py | 52 +++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 31580892dc27f..ff686537a96ed 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -25,8 +25,9 @@ import uuid6 from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError -from airflow.sdk.api.datamodels._generated import VariableResponse, XComResponse -from airflow.sdk.execution_time.comms import DeferTask, RescheduleTask +from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse, XComResponse +from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.comms import DeferTask, ErrorResponse, RescheduleTask from airflow.utils import timezone from airflow.utils.state import TerminalTIState @@ -538,3 +539,50 @@ def handle_request(request: httpx.Request) -> httpx.Response: map_index=2, ) assert result == {"ok": True} + + +class TestConnectionOperations: + """ + Test that the TestConnectionOperations class works as expected. While the operations are simple, it + still catches the basic functionality of the client for connections including endpoint and + response parsing. + """ + + def test_connection_get_success(self): + # Simulate a successful response from the server with a connection + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/connections/test_conn": + return httpx.Response( + status_code=200, + json={ + "conn_id": "test_conn", + "conn_type": "mysql", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.connections.get(conn_id="test_conn") + + assert isinstance(result, ConnectionResponse) + assert result.conn_id == "test_conn" + assert result.conn_type == "mysql" + + def test_connection_get_404_not_found(self): + # Simulate a successful response from the server with a connection + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/connections/test_conn": + return httpx.Response( + status_code=404, + json={ + "reason": "not_found", + "message": "Connection with ID test_conn not found", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.connections.get(conn_id="test_conn") + + assert isinstance(result, ErrorResponse) + assert result.error == ErrorType.CONNECTION_NOT_FOUND