From fd4bd633531d069292c278d79ec7c7c39395122e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Pu=C5=A1i=C4=87?= Date: Mon, 20 May 2024 16:15:12 +0200 Subject: [PATCH 1/3] Add the schema and credentials parameters --- src/connection.c | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/connection.c b/src/connection.c index 8618333..2fa6838 100644 --- a/src/connection.c +++ b/src/connection.c @@ -39,15 +39,18 @@ static int execute_trust_callback(const char *hostname, const char *ip_address, static int connection_init(ConnectionObject *conn, PyObject *args, PyObject *kwargs) { - static char *kwlist[] = {"host", "address", "port", "username", - "password", "client_name", "sslmode", "sslcert", - "sslkey", "trust_callback", "lazy", NULL}; + static char *kwlist[] = { + "host", "address", "port", "scheme", "username", + "password", "credentials", "client_name", "sslmode", "sslcert", + "sslkey", "trust_callback", "lazy", NULL}; const char *host = NULL; const char *address = NULL; int port = -1; + const char *scheme = NULL; const char *username = NULL; const char *password = NULL; + const char *credentials = NULL; const char *client_name = NULL; int sslmode_int = MG_SSLMODE_DISABLE; const char *sslcert = NULL; @@ -55,10 +58,10 @@ static int connection_init(ConnectionObject *conn, PyObject *args, PyObject *trust_callback = NULL; int lazy = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$ssisssissOp", kwlist, &host, - &address, &port, &username, &password, - &client_name, &sslmode_int, &sslcert, - &sslkey, &trust_callback, &lazy)) { + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, "|$ssisssssissOp", kwlist, &host, &address, &port, + &scheme, &username, &password, &credentials, &client_name, + &sslmode_int, &sslcert, &sslkey, &trust_callback, &lazy)) { return -1; } @@ -93,8 +96,10 @@ static int connection_init(ConnectionObject *conn, PyObject *args, mg_session_params_set_host(params, host); mg_session_params_set_port(params, (uint16_t)port); mg_session_params_set_address(params, address); + mg_session_params_set_scheme(params, scheme); mg_session_params_set_username(params, username); mg_session_params_set_password(params, password); + mg_session_params_set_credentials(params, credentials); if (client_name) { mg_session_params_set_user_agent(params, client_name); } From d5a58eab725468b658a0d8284c7cd7218d3024d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Pu=C5=A1i=C4=87?= Date: Wed, 29 May 2024 15:18:47 +0200 Subject: [PATCH 2/3] Align custom auth code with mgclient --- src/connection.c | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/connection.c b/src/connection.c index 2fa6838..b4ea7db 100644 --- a/src/connection.c +++ b/src/connection.c @@ -39,10 +39,10 @@ static int execute_trust_callback(const char *hostname, const char *ip_address, static int connection_init(ConnectionObject *conn, PyObject *args, PyObject *kwargs) { - static char *kwlist[] = { - "host", "address", "port", "scheme", "username", - "password", "credentials", "client_name", "sslmode", "sslcert", - "sslkey", "trust_callback", "lazy", NULL}; + static char *kwlist[] = {"host", "address", "port", "scheme", + "username", "password", "client_name", "sslmode", + "sslcert", "sslkey", "trust_callback", "lazy", + NULL}; const char *host = NULL; const char *address = NULL; @@ -50,7 +50,6 @@ static int connection_init(ConnectionObject *conn, PyObject *args, const char *scheme = NULL; const char *username = NULL; const char *password = NULL; - const char *credentials = NULL; const char *client_name = NULL; int sslmode_int = MG_SSLMODE_DISABLE; const char *sslcert = NULL; @@ -59,8 +58,8 @@ static int connection_init(ConnectionObject *conn, PyObject *args, int lazy = 0; if (!PyArg_ParseTupleAndKeywords( - args, kwargs, "|$ssisssssissOp", kwlist, &host, &address, &port, - &scheme, &username, &password, &credentials, &client_name, + args, kwargs, "|$ssissssissOp", kwlist, &host, &address, &port, + &scheme, &username, &password, &client_name, &sslmode_int, &sslcert, &sslkey, &trust_callback, &lazy)) { return -1; } @@ -99,7 +98,6 @@ static int connection_init(ConnectionObject *conn, PyObject *args, mg_session_params_set_scheme(params, scheme); mg_session_params_set_username(params, username); mg_session_params_set_password(params, password); - mg_session_params_set_credentials(params, credentials); if (client_name) { mg_session_params_set_user_agent(params, client_name); } From 6b649ce750e0c801249728fee656bf0bb14a5d24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Pu=C5=A1i=C4=87?= Date: Wed, 19 Jun 2024 00:55:24 +0200 Subject: [PATCH 3/3] Add test --- test/auth_module/dummy_auth_module.py | 21 ++++++++++ test/common.py | 4 +- test/test_connection.py | 56 ++++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 2 deletions(-) create mode 100755 test/auth_module/dummy_auth_module.py diff --git a/test/auth_module/dummy_auth_module.py b/test/auth_module/dummy_auth_module.py new file mode 100755 index 0000000..270b3e1 --- /dev/null +++ b/test/auth_module/dummy_auth_module.py @@ -0,0 +1,21 @@ +#!/usr/bin/python3 +import io +import json + + +def authenticate(scheme: str, response: str): + return { + "authenticated": True, + "role": "architect", + "username": "andy", + } + + +if __name__ == "__main__": + # I/O with Memgraph + input_stream = io.FileIO(1000, mode="r") + output_stream = io.FileIO(1001, mode="w") + while True: + params = json.loads(input_stream.readline().decode("ascii")) + ret = authenticate(**params) + output_stream.write((json.dumps(ret) + "\n").encode("ascii")) diff --git a/test/common.py b/test/common.py index f3c481c..d24c4e5 100644 --- a/test/common.py +++ b/test/common.py @@ -70,7 +70,7 @@ def terminate(self): self.process.wait() -def start_memgraph(cert_file="", key_file=""): +def start_memgraph(cert_file="", key_file="", auth_module_mappings=""): if MEMGRAPH_HOST: use_ssl = MEMGRAPH_STARTED_WITH_SSL is not None return Memgraph(MEMGRAPH_HOST, MEMGRAPH_PORT, use_ssl, None) @@ -94,6 +94,8 @@ def start_memgraph(cert_file="", key_file=""): "--log-file", "", ] + if auth_module_mappings: + cmd.insert(-2, f"--auth-module-mappings={auth_module_mappings}") memgraph_process = subprocess.Popen(cmd) wait_for_server(MEMGRAPH_PORT) use_ssl = True if key_file.strip() else False diff --git a/test/test_connection.py b/test/test_connection.py index adfb8a1..d56f072 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile + import mgclient import pytest -import tempfile from common import start_memgraph, Memgraph, requires_ssl_enabled, requires_ssl_disabled from OpenSSL import crypto @@ -63,6 +65,42 @@ def secure_memgraph_server(): memgraph.terminate() +@pytest.fixture(scope="function") +def provide_role(): + memgraph = start_memgraph() + conn = mgclient.connect( + host=memgraph.host, + port=memgraph.port, + ) + conn.autocommit = True + cursor = conn.cursor() + cursor.execute("CREATE ROLE architect;") + memgraph.terminate() + + yield None + + memgraph = start_memgraph() + conn = mgclient.connect( + host=memgraph.host, + port=memgraph.port, + ) + conn.autocommit = True + cursor = conn.cursor() + cursor.execute("DROP ROLE architect;") + memgraph.terminate() + + +@pytest.fixture(scope="function") +def auth_module_path(): + yield os.path.normpath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "auth_module", + "dummy_auth_module.py", + ) + ) + + def test_connect_args_validation(): # bad port with pytest.raises(ValueError): @@ -82,6 +120,22 @@ def test_connect_args_validation(): ) +def test_connect_with_custom_auth_scheme(provide_role, auth_module_path): + custom_scheme = "custom_scheme" + memgraph = start_memgraph( + auth_module_mappings=f"{custom_scheme}:{auth_module_path}" + ) + conn = mgclient.connect( + host=memgraph.host, + port=memgraph.port, + scheme=custom_scheme, + username="andy", + password="dummy auth token", + ) + assert conn.status == mgclient.CONN_STATUS_READY + memgraph.terminate() + + @requires_ssl_disabled def test_connect_insecure_success(memgraph_server): host, port, sslmode, _ = memgraph_server