Skip to content

Commit 7741016

Browse files
committed
Test on both postgres and sqlite in CI
1 parent 4f7cba7 commit 7741016

13 files changed

+112
-67
lines changed

nomenklatura/db.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from contextlib import contextmanager
22
from functools import cache
33
from typing import Generator, Optional, Union
4+
import logging
45

56
from sqlalchemy import MetaData, create_engine
67
from sqlalchemy.dialects.mysql import insert as mysql_insert
@@ -13,14 +14,31 @@
1314
Conn = Connection
1415
Connish = Optional[Connection]
1516

17+
WARNED_DB_URL = False
18+
19+
logger = logging.getLogger(__name__)
20+
1621

1722
@cache
18-
def get_engine(url: str = settings.DB_URL) -> Engine:
19-
if settings.TESTING:
20-
url = "sqlite:///:memory:"
21-
# if url.lower().startswith('sqlite'):
22-
# return create_engine(url)
23-
return create_engine(url, pool_size=settings.DB_POOL_SIZE)
23+
def get_engine(url: Optional[str] = None) -> Engine:
24+
if not url:
25+
if settings.DB_URL:
26+
url = settings.DB_URL
27+
else:
28+
url = f"sqlite:///{settings.DB_PATH.as_posix()}"
29+
30+
global WARNED_DB_URL
31+
if not WARNED_DB_URL:
32+
logger.warning(f"No DB_URL set. Using {url}")
33+
WARNED_DB_URL = True
34+
35+
connect_args = {}
36+
if url.startswith("postgres"):
37+
connect_args["options"] = "-c statement_timeout=3000"
38+
39+
return create_engine(
40+
url, pool_size=settings.DB_POOL_SIZE, connect_args=connect_args
41+
)
2442

2543

2644
@cache

nomenklatura/resolver/resolver.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,19 @@ def commit(self) -> None:
9696
self._conn.close()
9797
self._conn = None
9898

99-
def rollback(self) -> None:
100-
if self._transaction is None or self._conn is None:
101-
raise RuntimeError("No transaction to rollback.")
102-
self._transaction.rollback()
103-
self._transaction = None
104-
self._conn.close()
105-
self._conn = None
99+
def rollback(self, force=False) -> None:
100+
if self._transaction is not None:
101+
self._transaction.rollback()
102+
self._transaction = None
103+
else:
104+
if not force:
105+
raise RuntimeError("No transaction to rollback.")
106+
if self._conn is not None:
107+
self._conn.close()
108+
self._conn = None
109+
else:
110+
if not force:
111+
raise RuntimeError("No connection to close.")
106112

107113
def _get_connection(self) -> Connection:
108114
if self._transaction is None or self._conn is None:
@@ -487,6 +493,11 @@ def prune(self) -> None:
487493
self._invalidate()
488494

489495
def apply_statement(self, stmt: Statement) -> Statement:
496+
"""
497+
Canonicalise the entity ID.
498+
499+
Doesn't canonicalise entity ID values.
500+
"""
490501
if stmt.entity_id is not None:
491502
stmt.canonical_id = self.get_canonical(stmt.entity_id)
492503
if stmt.prop_type == registry.entity.name:

nomenklatura/settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
TESTING = False
55

66
DB_PATH = Path("nomenklatura.db").resolve()
7-
DB_URL = env_str("NOMENKLATURA_DB_URL", f"sqlite:///{DB_PATH.as_posix()}")
7+
DB_URL = env_str("NOMENKLATURA_DB_URL", "")
88
DB_POOL_SIZE = int(env_str("NOMENKLATURA_DB_POOL_SIZE", "5"))
99

1010
REDIS_URL = env_str("NOMENKLATURA_REDIS_URL", "")

tests/conftest.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import shutil
3+
from sqlalchemy import MetaData
34
import yaml
45
import pytest
56
from pathlib import Path
@@ -21,7 +22,12 @@
2122

2223
@pytest.fixture(autouse=True)
2324
def wrap_test():
25+
if not settings.DB_URL:
26+
settings.DB_URL = "sqlite:///:memory:"
2427
yield
28+
# Dispose of connections to let open transactions for resources not
29+
# managed by the setup/teardown abort.
30+
get_engine().dispose()
2531
get_engine.cache_clear()
2632
get_redis.cache_clear()
2733
get_metadata.cache_clear()
@@ -55,13 +61,24 @@ def donations_json(donations_path):
5561
@pytest.fixture(scope="function")
5662
def resolver():
5763
resolver = Resolver[CompositeEntity].make_default()
58-
resolver.begin()
5964
yield resolver
60-
resolver.rollback()
65+
resolver.rollback(force=True)
66+
resolver._table.drop(resolver._engine)
67+
68+
69+
@pytest.fixture(scope="function")
70+
def other_table_resolver():
71+
engine = get_engine()
72+
meta = MetaData()
73+
resolver = Resolver(engine, meta, create=True, table_name="another_table")
74+
yield resolver
75+
resolver.rollback(force=True)
76+
resolver._table.drop(engine)
6177

6278

6379
@pytest.fixture(scope="function")
6480
def dstore(donations_path, resolver) -> SimpleMemoryStore:
81+
resolver.begin()
6582
return load_entity_file_store(donations_path, resolver)
6683

6784

tests/enrich/test_nominatim.py

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_nominatim_match():
8888

8989

9090
def test_nominatim_match_list(resolver: Resolver[CompositeEntity]):
91+
resolver.begin()
9192
enricher = load_enricher()
9293

9394
full = "Kopenhagener Str. 47, Berlin"
@@ -116,6 +117,7 @@ def test_nominatim_enrich():
116117

117118

118119
def test_nominatim_enrich_list(resolver: Resolver[CompositeEntity]):
120+
resolver.begin()
119121
enricher = load_enricher()
120122

121123
full = "Kopenhagener Str. 47, Berlin"

tests/store/test_leveldb.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
def test_leveldb_store_basics(
2727
test_dataset: Dataset, resolver: Resolver[CompositeEntity]
2828
):
29+
resolver.begin()
2930
path = Path(tempfile.mkdtemp()) / "leveldb"
3031
store = LevelDBStore(test_dataset, resolver, path)
3132
entity = CompositeEntity.from_data(test_dataset, PERSON)
@@ -52,6 +53,7 @@ def test_leveldb_store_basics(
5253
def test_leveldb_graph_query(
5354
donations_path: Path, test_dataset: Dataset, resolver: Resolver[CompositeEntity]
5455
):
56+
resolver.begin()
5557
path = Path(tempfile.mkdtemp()) / "xxx"
5658
store = LevelDBStore(test_dataset, resolver, path)
5759
assert len(list(store.view(test_dataset).entities())) == 0

tests/store/test_memory.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323

2424
def test_basic_store(test_dataset: Dataset, resolver: Resolver[CompositeEntity]):
25+
resolver.begin()
2526
store = MemoryStore(test_dataset, resolver)
2627
entity = CompositeEntity.from_data(test_dataset, PERSON)
2728
entity_ext = CompositeEntity.from_data(test_dataset, PERSON_EXT)

tests/store/test_redis.py

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
def test_redis_store_basics(test_dataset: Dataset, resolver: Resolver[CompositeEntity]):
27+
resolver.begin()
2728
redis = fakeredis.FakeStrictRedis(version=6, decode_responses=False)
2829
store = RedisStore(test_dataset, resolver, db=redis)
2930
entity = CompositeEntity.from_data(test_dataset, PERSON)
@@ -50,6 +51,7 @@ def test_redis_store_basics(test_dataset: Dataset, resolver: Resolver[CompositeE
5051
def test_leveldb_graph_query(
5152
donations_path: Path, test_dataset: Dataset, resolver: Resolver[CompositeEntity]
5253
):
54+
resolver.begin()
5355
redis = fakeredis.FakeStrictRedis(version=6, decode_responses=False)
5456
store = RedisStore(test_dataset, resolver, db=redis)
5557
assert len(list(store.view(test_dataset).entities())) == 0

tests/store/test_resolved.py

+3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929

3030
def test_store_basics(test_dataset: Dataset, resolver: Resolver[CompositeEntity]):
31+
resolver.begin()
3132
redis = fakeredis.FakeStrictRedis(version=6, decode_responses=False)
3233
store = ResolvedStore(test_dataset, resolver, db=redis)
3334
entity = CompositeEntity.from_data(test_dataset, PERSON)
@@ -59,6 +60,7 @@ def test_store_basics(test_dataset: Dataset, resolver: Resolver[CompositeEntity]
5960
def test_graph_query(
6061
donations_path: Path, test_dataset: Dataset, resolver: Resolver[CompositeEntity]
6162
):
63+
resolver.begin()
6264
redis = fakeredis.FakeStrictRedis(version=6, decode_responses=False)
6365
store = ResolvedStore(test_dataset, resolver, db=redis)
6466
assert len(list(store.view(test_dataset).entities())) == 0
@@ -108,6 +110,7 @@ def test_graph_query(
108110
def test_custom_functions(
109111
donations_path: Path, test_dataset: Dataset, resolver: Resolver[CompositeEntity]
110112
):
113+
resolver.begin()
111114
redis = fakeredis.FakeStrictRedis(version=6, decode_responses=False)
112115
prefix = "test123"
113116
mem_store = MemoryStore(test_dataset, resolver)

tests/store/test_stores.py

+3
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def test_store_sql(
7474
donations_json: List[Dict[str, Any]],
7575
resolver: Resolver[CompositeEntity],
7676
):
77+
resolver.begin()
7778
uri = f"sqlite:///{tmp_path / 'test.db'}"
7879
store = SQLStore(dataset=test_dataset, linker=resolver, uri=uri)
7980
assert str(store.engine.url) == uri
@@ -85,6 +86,7 @@ def test_store_memory(
8586
donations_json: List[Dict[str, Any]],
8687
resolver: Resolver[CompositeEntity],
8788
):
89+
resolver.begin()
8890
store = SimpleMemoryStore(dataset=test_dataset, linker=resolver)
8991
assert _run_store_test(store, test_dataset, donations_json)
9092

@@ -95,6 +97,7 @@ def test_store_level(
9597
donations_json: List[Dict[str, Any]],
9698
resolver: Resolver[CompositeEntity],
9799
):
100+
resolver.begin()
98101
path = tmp_path / "level.db"
99102
store = LevelDBStore(dataset=test_dataset, linker=resolver, path=path)
100103
assert _run_store_test(store, test_dataset, donations_json)

tests/store/test_versioned.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828

2929
def test_store_basics(test_dataset: Dataset, resolver: Resolver[CompositeEntity]):
30+
resolver.begin()
3031
redis = fakeredis.FakeStrictRedis(version=6, decode_responses=False)
3132
store = VersionedRedisStore(test_dataset, resolver, db=redis)
3233
assert len(list(store.view(test_dataset).statements())) == 0
@@ -64,6 +65,7 @@ def test_store_basics(test_dataset: Dataset, resolver: Resolver[CompositeEntity]
6465
def test_graph_query(
6566
donations_path: Path, test_dataset: Dataset, resolver: Resolver[CompositeEntity]
6667
):
68+
resolver.begin()
6769
redis = fakeredis.FakeStrictRedis(version=6, decode_responses=False)
6870
store = VersionedRedisStore(test_dataset, resolver, db=redis)
6971
assert len(list(store.view(test_dataset).entities())) == 0

0 commit comments

Comments
 (0)