Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 18, 2024
1 parent 8797c6f commit 66d257a
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 26 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ jobs:
CACHIER_TEST_VS_DOCKERIZED_MYSQL: "true"
CACHIER_TEST_PYODBC_CONNECTION_STRING: "DRIVER={MySQL ODBC Driver};SERVER=localhost;PORT=3306;DATABASE=test;USER=root;PASSWORD=password;"


steps:
- uses: actions/checkout@v4

Expand Down
59 changes: 40 additions & 19 deletions cachier/cores/odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,26 @@
# Copyright (c) 2016, Shay Palachy <[email protected]>

# standard library imports
import datetime
import pickle
import time
import datetime

pyodbc = None
# third party imports
with suppress(ImportError):
import pyodbc

# local imports
from .base import _BaseCore, RecalculationNeeded
from .base import RecalculationNeeded, _BaseCore

class _OdbcCore(_BaseCore):

class _OdbcCore(_BaseCore):
def __init__(
self,
hash_func,
wait_for_calc_timeout,
connection_string,
table_name,
self,
hash_func,
wait_for_calc_timeout,
connection_string,
table_name,
):
if "pyodbc" not in sys.modules:
warnings.warn(
Expand All @@ -43,7 +43,8 @@ def __init__(
def ensure_table_exists(self):
with pyodbc.connect(self.connection_string) as conn:
cursor = conn.cursor()
cursor.execute(f"""
cursor.execute(
f"""
IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = N'{self.table_name}')
BEGIN
CREATE TABLE {self.table_name} (
Expand All @@ -54,13 +55,17 @@ def ensure_table_exists(self):
PRIMARY KEY (key)
);
END
""")
"""
)
conn.commit()

def get_entry_by_key(self, key):
with pyodbc.connect(self.connection_string) as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT value, time, being_calculated FROM {self.table_name} WHERE key = ?", key)
cursor.execute(
f"SELECT value, time, being_calculated FROM {self.table_name} WHERE key = ?",
key,
)
row = cursor.fetchone()
if row:
return {
Expand All @@ -73,34 +78,48 @@ def get_entry_by_key(self, key):
def set_entry(self, key, func_res):
with pyodbc.connect(self.connection_string) as conn:
cursor = conn.cursor()
cursor.execute(f"""
cursor.execute(
f"""
MERGE INTO {self.table_name} USING (SELECT 1 AS dummy) AS src ON (key = ?)
WHEN MATCHED THEN
UPDATE SET value = ?, time = GETDATE(), being_calculated = 0
WHEN NOT MATCHED THEN
INSERT (key, value, time, being_calculated) VALUES (?, ?, GETDATE(), 0);
""", key, pickle.dumps(func_res), key, pickle.dumps(func_res))
""",
key,
pickle.dumps(func_res),
key,
pickle.dumps(func_res),
)
conn.commit()

def mark_entry_being_calculated(self, key):
with pyodbc.connect(self.connection_string) as conn:
cursor = conn.cursor()
cursor.execute(f"UPDATE {self.table_name} SET being_calculated = 1 WHERE key = ?", key)
cursor.execute(
f"UPDATE {self.table_name} SET being_calculated = 1 WHERE key = ?",
key,
)
conn.commit()

def mark_entry_not_calculated(self, key):
with pyodbc.connect(self.connection_string) as conn:
cursor = conn.cursor()
cursor.execute(f"UPDATE {self.table_name} SET being_calculated = 0 WHERE key = ?", key)
cursor.execute(
f"UPDATE {self.table_name} SET being_calculated = 0 WHERE key = ?",
key,
)
conn.commit()

def wait_on_entry_calc(self, key):
start_time = datetime.datetime.now()
while True:
entry = self.get_entry_by_key(key)
if entry and not entry['being_calculated']:
return entry['value']
if (datetime.datetime.now() - start_time).total_seconds() > self.wait_for_calc_timeout:
if entry and not entry["being_calculated"]:
return entry["value"]
if (
datetime.datetime.now() - start_time
).total_seconds() > self.wait_for_calc_timeout:
raise RecalculationNeeded()
time.sleep(1)

Expand All @@ -113,5 +132,7 @@ def clear_cache(self):
def clear_being_calculated(self):
with pyodbc.connect(self.connection_string) as conn:
cursor = conn.cursor()
cursor.execute(f"UPDATE {self.table_name} SET being_calculated = 0")
cursor.execute(
f"UPDATE {self.table_name} SET being_calculated = 0"
)
conn.commit()
24 changes: 18 additions & 6 deletions tests/test_odbc_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

# local imports
from cachier import cachier

# from cachier.cores.base import RecalculationNeeded
# from cachier.cores.odbc import _OdbcCore


class CfgKey:
"""Configuration keys for testing."""

TEST_VS_DOCKERIZED_MYSQL = "TEST_VS_DOCKERIZED_MYSQL"
TEST_PYODBC_CONNECTION_STRING = "TEST_PYODBC_CONNECTION_STRING"

Expand All @@ -34,7 +36,7 @@ class CfgKey:
def test_odbc_entry_creation_and_retrieval(odbc_core):
"""Test inserting and retrieving an entry from ODBC cache."""

@cachier(backend='odbc', odbc_connection_string=CONCT_STR)
@cachier(backend="odbc", odbc_connection_string=CONCT_STR)
def sample_function(arg_1, arg_2):
return arg_1 + arg_2

Expand All @@ -48,19 +50,28 @@ def test_odbc_stale_after(odbc_core):
"""Test ODBC core handling stale_after parameter."""
stale_after = datetime.timedelta(seconds=1)

@cachier(backend='odbc', odbc_connection_string=CONCT_STR, stale_after=stale_after)
@cachier(
backend="odbc",
odbc_connection_string=CONCT_STR,
stale_after=stale_after,
)
def stale_test_function(arg_1, arg_2):
return arg_1 + arg_2 + datetime.datetime.now().timestamp() # Add timestamp to ensure unique values
return (
arg_1 + arg_2 + datetime.datetime.now().timestamp()
) # Add timestamp to ensure unique values

initial_value = stale_test_function(5, 10)
sleep(2) # Wait for the entry to become stale
assert stale_test_function(5, 10) != initial_value # Should recompute since stale
assert (
stale_test_function(5, 10) != initial_value
) # Should recompute since stale


@pytest.mark.odbc
def test_odbc_clear_cache(odbc_core):
"""Test clearing the ODBC cache."""
@cachier(backend='odbc', odbc_connection_string=CONCT_STR)

@cachier(backend="odbc", odbc_connection_string=CONCT_STR)
def clearable_function(arg):
return arg

Expand All @@ -74,7 +85,8 @@ def clearable_function(arg):
@pytest.mark.odbc
def test_odbc_being_calculated_flag(odbc_core):
"""Test handling of 'being_calculated' flag in ODBC core."""
@cachier(backend='odbc', odbc_connection_string=CONCT_STR)

@cachier(backend="odbc", odbc_connection_string=CONCT_STR)
def slow_function(arg):
sleep(2) # Simulate long computation
return arg * 2
Expand Down

0 comments on commit 66d257a

Please sign in to comment.