Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Stop ignoring tests/server.py #15084

Merged
merged 21 commits into from
Feb 17, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from twisted.web.server import Request, Site

from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
Expand Down Expand Up @@ -835,17 +836,17 @@ def connect_client(


class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
DATASTORE_CLASS = DataStore # type: ignore[assignment]


def setup_test_homeserver(
cleanup_func,
name="test",
config=None,
reactor=None,
cleanup_func: Callable[[Callable[[], None]], None],
name: str = "test",
config: Optional[HomeServerConfig] = None,
reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs,
):
**kwargs: Any,
) -> HomeServer:
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
Expand All @@ -860,13 +861,14 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor
from twisted.internet import reactor as _reactor

reactor = cast(ISynapseReactor, _reactor)

if config is None:
config = default_config(name, parse=True)

config.caches.resize_all_caches()
config.ldap_enabled = False

if "clock" not in kwargs:
kwargs["clock"] = MockClock()
Expand Down Expand Up @@ -917,13 +919,16 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
import psycopg2.extensions

db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
assert isinstance(db_conn, psycopg2.extensions.connection)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
Expand Down Expand Up @@ -952,14 +957,15 @@ def setup_test_homeserver(
hs.setup_background_tasks()

if isinstance(db_engine, PostgresEngine):
database = hs.get_datastores().databases[0]
database_pool = hs.get_datastores().databases[0]

# We need to do cleanup on PostgreSQL
def cleanup():
def cleanup() -> None:
import psycopg2
import psycopg2.extensions

# Close all the db pools
database._db_pool.close()
database_pool._db_pool.close()

dropped = False

Expand All @@ -971,6 +977,7 @@ def cleanup():
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()

Expand Down Expand Up @@ -1003,23 +1010,23 @@ def cleanup():
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p):
async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest()

hs.get_auth_handler().hash = hash
hs.get_auth_handler().hash = hash # type: ignore[assignment]

async def validate_hash(p, h):
async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h

hs.get_auth_handler().validate_hash = validate_hash
hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]

# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)

# Load any configured modules into the homeserver
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
for module, module_config in hs.config.modules.loaded_modules:
module(config=module_config, api=module_api)

load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
Expand Down