Skip to content

Commit

Permalink
sqlalchemy 2.0 (#685)
Browse files Browse the repository at this point in the history
* Bump sqlalchemy requirement

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
TomAugspurger and pre-commit-ci[bot] authored Feb 15, 2023
1 parent 57c5aa8 commit 5771a3f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 50 deletions.
87 changes: 44 additions & 43 deletions dask-gateway-server/dask_gateway_server/backends/db_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def to_model(self):
class Worker:
"""Information on a worker.
Not all attributes on this object are publically accessible. When writing a
Not all attributes on this object are publicly accessible. When writing a
backend, you may access the following attributes:
Attributes
Expand Down Expand Up @@ -334,53 +334,54 @@ def connect(dbapi_con, con_record):
self.id_to_cluster = {}

# Load all existing clusters into memory
for c in self.db.execute(clusters.select()):
tls_cert, tls_key = self.decode_tls_credentials(c.tls_credentials)
token = self.decode_token(c.token)
cluster = Cluster(
id=c.id,
name=c.name,
username=c.username,
token=token,
options=c.options,
config=FrozenAttrDict(c.config),
status=c.status,
target=c.target,
count=c.count,
state=c.state,
scheduler_address=c.scheduler_address,
dashboard_address=c.dashboard_address,
api_address=c.api_address,
tls_cert=tls_cert,
tls_key=tls_key,
start_time=c.start_time,
stop_time=c.stop_time,
)
self.username_to_clusters[cluster.username][cluster.name] = cluster
self.id_to_cluster[cluster.id] = cluster
self.name_to_cluster[cluster.name] = cluster

# Next load all existing workers into memory
for w in self.db.execute(workers.select()):
cluster = self.id_to_cluster[w.cluster_id]
worker = Worker(
id=w.id,
name=w.name,
status=w.status,
target=w.target,
cluster=cluster,
state=w.state,
start_time=w.start_time,
stop_time=w.stop_time,
close_expected=w.close_expected,
)
cluster.workers[worker.name] = worker
with self.db.begin() as connection:
for c in connection.execute(clusters.select()):
tls_cert, tls_key = self.decode_tls_credentials(c.tls_credentials)
token = self.decode_token(c.token)
cluster = Cluster(
id=c.id,
name=c.name,
username=c.username,
token=token,
options=c.options,
config=FrozenAttrDict(c.config),
status=c.status,
target=c.target,
count=c.count,
state=c.state,
scheduler_address=c.scheduler_address,
dashboard_address=c.dashboard_address,
api_address=c.api_address,
tls_cert=tls_cert,
tls_key=tls_key,
start_time=c.start_time,
stop_time=c.stop_time,
)
self.username_to_clusters[cluster.username][cluster.name] = cluster
self.id_to_cluster[cluster.id] = cluster
self.name_to_cluster[cluster.name] = cluster

# Next load all existing workers into memory
for w in connection.execute(workers.select()):
cluster = self.id_to_cluster[w.cluster_id]
worker = Worker(
id=w.id,
name=w.name,
status=w.status,
target=w.target,
cluster=cluster,
state=w.state,
start_time=w.start_time,
stop_time=w.stop_time,
close_expected=w.close_expected,
)
cluster.workers[worker.name] = worker

def cleanup_expired(self, max_age_in_seconds):
cutoff = timestamp() - max_age_in_seconds * 1000
with self.db.begin() as conn:
to_delete = conn.execute(
sa.select([clusters.c.id]).where(clusters.c.stop_time < cutoff)
sa.select(clusters.c.id).where(clusters.c.stop_time < cutoff)
).fetchall()

if to_delete:
Expand Down
8 changes: 4 additions & 4 deletions dask-gateway-server/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ def get_tag(self):
# pykerberos is tricky to install and requires a system package to
# successfully compile some C code, on ubuntu this is libkrb5-dev.
"kerberos": ["pykerberos"],
"jobqueue": ["sqlalchemy"],
"local": ["sqlalchemy"],
"yarn": ["sqlalchemy", "skein >= 0.7.3"],
"jobqueue": ["sqlalchemy>=2.0.0"],
"local": ["sqlalchemy>=2.0.0"],
"yarn": ["sqlalchemy>=2.0.0", "skein >= 0.7.3"],
"kubernetes": ["kubernetes_asyncio"],
"all_backends": [
"sqlalchemy",
"sqlalchemy>=2.0.0",
"skein >= 0.7.3",
"kubernetes_asyncio",
],
Expand Down
7 changes: 4 additions & 3 deletions tests/test_db_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ async def test_encryption(tmpdir):
# Check database state is encrypted
with db.db.begin() as conn:
res = conn.execute(
db_base.clusters.select(db_base.clusters.c.id == c.id)
db_base.clusters.select().where(db_base.clusters.c.id == c.id)
).fetchone()
assert res.tls_credentials != b";".join((c.tls_cert, c.tls_key))
cert, key = db.decrypt(res.tls_credentials).split(b";")
Expand Down Expand Up @@ -283,8 +283,9 @@ def check_db_consistency(db):
# Users without clusters are flushed
assert clusters

clusters = db.db.execute(db_base.clusters.select()).fetchall()
workers = db.db.execute(db_base.workers.select()).fetchall()
with db.db.begin() as conn:
clusters = conn.execute(db_base.clusters.select()).fetchall()
workers = conn.execute(db_base.workers.select()).fetchall()

# Check cluster state
for c in clusters:
Expand Down

0 comments on commit 5771a3f

Please sign in to comment.