Commit d8658a0

Merge pull request #246 from vipyrsec/close-db-sessions
Close DB sessions in functions that use FastAPI dependency injection
import-pandas-as-numpy authored Apr 11, 2024
2 parents 695a677 + 110cb82 commit d8658a0
Showing 10 changed files with 157 additions and 139 deletions.
2 changes: 1 addition & 1 deletion src/mainframe/database.py
@@ -12,7 +12,7 @@
     pool_size=mainframe_settings.db_connection_pool_persistent_size,
     max_overflow=mainframe_settings.db_connection_pool_max_size - mainframe_settings.db_connection_pool_persistent_size,
 )
-sessionmaker = sessionmaker(bind=engine, expire_on_commit=False)
+sessionmaker = sessionmaker(bind=engine, expire_on_commit=False, autobegin=False)
 
 
 def get_db() -> Generator[Session, None, None]:
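
The new factory disables SQLAlchemy's autobegin, so a session no longer opens an implicit transaction on first use: every query has to run inside an explicit session.begin() block, and the session itself is closed by using it as a context manager. A minimal, self-contained sketch of those semantics, assuming an in-memory SQLite URL and a throwaway scans table that are purely illustrative and not the project's schema:

# Sketch of the autobegin=False behaviour the new sessionmaker relies on.
# The engine URL and table below are placeholders for illustration only.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite:///:memory:")
session_factory = sessionmaker(bind=engine, expire_on_commit=False, autobegin=False)

with session_factory() as session:      # "with session" closes it on exit
    with session.begin():               # commits on success, rolls back on error
        session.execute(text("CREATE TABLE scans (name TEXT)"))
        session.execute(text("INSERT INTO scans VALUES ('example-package')"))

    # With autobegin disabled, running this query outside a session.begin()
    # block would raise InvalidRequestError instead of silently starting a
    # transaction that nothing ever commits or closes.
    with session.begin():
        names = session.execute(text("SELECT name FROM scans")).scalars().all()

print(names)  # ['example-package']
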
83 changes: 41 additions & 42 deletions src/mainframe/endpoints/job.py
@@ -40,54 +40,53 @@ def get_jobs(
     packages are always processed after newly queued packages.
     """
 
-    scans = (
-        session.scalars(
-            select(Scan)
-            .where(
-                or_(
-                    Scan.status == Status.QUEUED,
-                    and_(
-                        Scan.pending_at
-                        < datetime.now(timezone.utc) - timedelta(seconds=mainframe_settings.job_timeout),
-                        Scan.status == Status.PENDING,
-                    ),
-                )
-            )
-            .order_by(Scan.pending_at.nulls_first(), Scan.queued_at)
-            .limit(batch)
-            .options(joinedload(Scan.download_urls))
-        )
-        .unique()
-        .all()
-    )
+    with session, session.begin():
+        scans = (
+            session.scalars(
+                select(Scan)
+                .where(
+                    or_(
+                        Scan.status == Status.QUEUED,
+                        and_(
+                            Scan.pending_at
+                            < datetime.now(timezone.utc) - timedelta(seconds=mainframe_settings.job_timeout),
+                            Scan.status == Status.PENDING,
+                        ),
+                    )
+                )
+                .order_by(Scan.pending_at.nulls_first(), Scan.queued_at)
+                .limit(batch)
+                .options(joinedload(Scan.download_urls))
+            )
+            .unique()
+            .all()
+        )
 
-    response_body: list[JobResult] = []
-    for scan in scans:
-        scan.status = Status.PENDING
-        scan.pending_at = datetime.now(timezone.utc)
-        scan.pending_by = auth.subject
+        response_body: list[JobResult] = []
+        for scan in scans:
+            scan.status = Status.PENDING
+            scan.pending_at = datetime.now(timezone.utc)
+            scan.pending_by = auth.subject
 
-        logger.info(
-            "Job given and status set to pending in database",
-            package={
-                "name": scan.name,
-                "status": scan.status,
-                "pending_at": scan.pending_at,
-                "pending_by": auth.subject,
-                "version": scan.version,
-            },
-            tag="job_given",
-        )
+            logger.info(
+                "Job given and status set to pending in database",
+                package={
+                    "name": scan.name,
+                    "status": scan.status,
+                    "pending_at": scan.pending_at,
+                    "pending_by": auth.subject,
+                    "version": scan.version,
+                },
+                tag="job_given",
+            )
 
-        job_result = JobResult(
-            name=scan.name,
-            version=scan.version,
-            distributions=[dist.url for dist in scan.download_urls],
-            hash=state.rules_commit,
-        )
+            job_result = JobResult(
+                name=scan.name,
+                version=scan.version,
+                distributions=[dist.url for dist in scan.download_urls],
+                hash=state.rules_commit,
+            )
 
-        response_body.append(job_result)
+            response_body.append(job_result)
 
-    session.commit()
-
     return response_body
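
A hypothetical, self-contained FastAPI sketch of the endpoint shape these hunks converge on; the engine, the trivial SELECT 1, and this get_db body are assumptions for illustration, not the repository's code. The dependency hands the endpoint a session, the outer `with session` closes it when the handler returns, and session.begin() commits on normal exit, which is why the explicit session.commit() calls disappear from the diff.

from collections.abc import Generator
from typing import Annotated

from fastapi import Depends, FastAPI
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite:///:memory:")  # placeholder database for the sketch
session_factory = sessionmaker(bind=engine, expire_on_commit=False, autobegin=False)
app = FastAPI()


def get_db() -> Generator[Session, None, None]:
    # Illustrative stand-in for the project's get_db dependency.
    yield session_factory()


@app.get("/db-ping")
def db_ping(session: Annotated[Session, Depends(get_db)]) -> int:
    # Closes the session and commits the transaction when the block exits;
    # no explicit session.commit() or session.close() is needed.
    with session, session.begin():
        return int(session.execute(text("SELECT 1")).scalar_one())
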
84 changes: 41 additions & 43 deletions src/mainframe/endpoints/package.py
@@ -42,9 +42,10 @@ def submit_results(
     name = result.name
     version = result.version
 
-    scan = session.scalar(
-        select(Scan).where(Scan.name == name).where(Scan.version == version).options(joinedload(Scan.rules))
-    )
+    with session.begin():
+        scan = session.scalar(
+            select(Scan).where(Scan.name == name).where(Scan.version == version).options(joinedload(Scan.rules))
+        )
 
     log = logger.bind(package={"name": name, "version": version})
 
@@ -62,28 +63,29 @@
         )
         raise error
 
-    if isinstance(result, PackageScanResultFail):
-        scan.status = Status.FAILED
-        scan.fail_reason = result.reason
+    with session, session.begin():
+        if isinstance(result, PackageScanResultFail):
+            scan.status = Status.FAILED
+            scan.fail_reason = result.reason
 
-        session.commit()
-        return
+            session.commit()
+            return
 
-    scan.status = Status.FINISHED
-    scan.finished_at = dt.datetime.now(dt.timezone.utc)
-    scan.inspector_url = result.inspector_url
-    scan.score = result.score
-    scan.finished_by = auth.subject
-    scan.commit_hash = result.commit
+        scan.status = Status.FINISHED
+        scan.finished_at = dt.datetime.now(dt.timezone.utc)
+        scan.inspector_url = result.inspector_url
+        scan.score = result.score
+        scan.finished_by = auth.subject
+        scan.commit_hash = result.commit
 
-    # These are the rules that already have an entry in the database
-    rules = session.scalars(select(Rule).where(Rule.name.in_(result.rules_matched))).all()
-    rule_names = {rule.name for rule in rules}
-    scan.rules.extend(rules)
+        # These are the rules that already have an entry in the database
+        rules = session.scalars(select(Rule).where(Rule.name.in_(result.rules_matched))).all()
+        rule_names = {rule.name for rule in rules}
+        scan.rules.extend(rules)
 
-    # These are the rules that had to be created
-    new_rules = [Rule(name=rule_name) for rule_name in result.rules_matched if rule_name not in rule_names]
-    scan.rules.extend(new_rules)
+        # These are the rules that had to be created
+        new_rules = [Rule(name=rule_name) for rule_name in result.rules_matched if rule_name not in rule_names]
+        scan.rules.extend(new_rules)
 
     log.info(
         "Scan results submitted",
@@ -101,8 +103,6 @@ def submit_results(
         tag="scan_submitted",
     )
 
-    session.commit()
-
 
 @router.get(
     "/package",
@@ -166,10 +166,10 @@ def lookup_package_info(
     if nn_since:
         query = query.where(Scan.finished_at >= dt.datetime.fromtimestamp(since, tz=dt.timezone.utc))
 
-    data = session.scalars(query)
+    with session, session.begin():
+        data = session.scalars(query).unique().all()
 
     log.info("Package information queried")
-    return data.unique().all()
+    return data
 
 
 def _deduplicate_packages(packages: list[PackageSpecifier], session: Session) -> set[tuple[str, str]]:
@@ -208,22 +208,21 @@ def batch_queue_package(
     auth: Annotated[AuthenticationData, Depends(validate_token)],
    pypi_client: Annotated[PyPIServices, Depends(get_pypi_client)],
 ):
-    packages_to_check = _deduplicate_packages(packages, session)
+    with session, session.begin():
+        packages_to_check = _deduplicate_packages(packages, session)
 
-    for package_metadata in _get_packages_metadata(pypi_client, packages_to_check):
-        scan = Scan(
-            name=package_metadata.title,
-            version=package_metadata.releases[0].version,
-            status=Status.QUEUED,
-            queued_by=auth.subject,
-            download_urls=[
-                DownloadURL(url=distribution.url) for distribution in package_metadata.releases[0].distributions
-            ],
-        )
+        for package_metadata in _get_packages_metadata(pypi_client, packages_to_check):
+            scan = Scan(
+                name=package_metadata.title,
+                version=package_metadata.releases[0].version,
+                status=Status.QUEUED,
+                queued_by=auth.subject,
+                download_urls=[
+                    DownloadURL(url=distribution.url) for distribution in package_metadata.releases[0].distributions
+                ],
+            )
 
-        session.add(scan)
-
-    session.commit()
+            session.add(scan)
 
 
 @router.post(
@@ -277,10 +276,9 @@ def queue_package(
         ],
     )
 
-    session.add(new_package)
-
     try:
-        session.commit()
+        with session, session.begin():
+            session.add(new_package)
     except IntegrityError:
         log.warn(f"Package {name}@{version} already queued for scanning.", tag="already_queued")
         raise HTTPException(409, f"Package {name}@{version} is already queued for scanning")
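
One consequence of the queue_package change above is worth spelling out: with session.begin(), the INSERT is flushed and committed as the block exits, so a unique-constraint violation surfaces as an IntegrityError raised by the with statement itself, which is why the try now wraps the whole block rather than a session.commit() call. A small sketch under assumed names (a hypothetical Package model and in-memory SQLite, nothing from the project):

from sqlalchemy import create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Package(Base):
    # Hypothetical table with a unique name, standing in for the real Scan model.
    __tablename__ = "packages"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(unique=True)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
session_factory = sessionmaker(bind=engine, expire_on_commit=False, autobegin=False)

for attempt in range(2):
    try:
        with session_factory() as session, session.begin():
            session.add(Package(name="duplicate-demo"))
        # The commit (and the flush that hits the unique constraint) happens
        # as the `with` block exits, so the error is raised from the block itself.
    except IntegrityError:
        print("second insert rejected: name already queued")
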
16 changes: 10 additions & 6 deletions src/mainframe/endpoints/report.py
@@ -45,7 +45,8 @@ def _lookup_package(name: str, version: str, session: Session) -> Scan:
     log = logger.bind(package={"name": name, "version": version})
 
     query = select(Scan).where(Scan.name == name).options(joinedload(Scan.rules))
-    scans = session.scalars(query).unique().all()
+    with session.begin():
+        scans = session.scalars(query).unique().all()
 
     if not scans:
         error = HTTPException(404, detail=f"No records for package `{name}` were found in the database")
@@ -70,7 +71,8 @@ def _lookup_package(name: str, version: str, session: Session) -> Scan:
         )
         raise error
 
-    scan = session.scalar(query.where(Scan.version == version))
+    with session.begin():
+        scan = session.scalar(query.where(Scan.version == version))
     if scan is None:
         error = HTTPException(
             404, detail=f"Package `{name}` has records in the database, but none with version `{version}`"
@@ -233,6 +235,12 @@ def report_package(
 
     httpx.post(f"{mainframe_settings.reporter_url}/report/{name}", json=jsonable_encoder(report))
 
+    with session.begin():
+        scan.reported_by = auth.subject
+        scan.reported_at = dt.datetime.now(dt.timezone.utc)
+
+    session.close()
+
     log.info(
         "Sent report",
         report_data={
@@ -245,7 +253,3 @@ def report_package(
         },
         reported_by=auth.subject,
     )
-
-    scan.reported_by = auth.subject
-    scan.reported_at = dt.datetime.now(dt.timezone.utc)
-    session.commit()
11 changes: 6 additions & 5 deletions src/mainframe/endpoints/stats.py
@@ -44,8 +44,9 @@ def _get_failed_packages(session: Session) -> int:
 
 @router.get("/stats", dependencies=[Depends(validate_token)])
 def get_stats(session: Annotated[Session, Depends(get_db)]) -> StatsResponse:
-    return StatsResponse(
-        ingested=_get_package_ingest(session),
-        average_scan_time=_get_average_scan_time(session),
-        failed=_get_failed_packages(session),
-    )
+    with session, session.begin():
+        return StatsResponse(
+            ingested=_get_package_ingest(session),
+            average_scan_time=_get_average_scan_time(session),
+            failed=_get_failed_packages(session),
+        )
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -24,7 +24,7 @@
 
 @pytest.fixture(scope="session")
 def sm(engine: Engine) -> sessionmaker[Session]:
-    return sessionmaker(bind=engine, expire_on_commit=False)
+    return sessionmaker(bind=engine, expire_on_commit=False, autobegin=False)
 
 
 @pytest.fixture(scope="session")
21 changes: 12 additions & 9 deletions tests/test_job.py
@@ -10,7 +10,8 @@
 
 
 def oldest_queued_package(db_session: Session):
-    return db_session.scalar(select(func.min(Scan.queued_at)).where(Scan.status == Status.QUEUED))
+    with db_session.begin():
+        return db_session.scalar(select(func.min(Scan.queued_at)).where(Scan.status == Status.QUEUED))
 
 
 def test_min_queue_date_of_queued_rows(test_data: list[Scan], db_session: Session):
@@ -25,7 +26,8 @@ def test_min_queue_date_of_queued_rows(test_data: list[Scan], db_session: Session):
 
 
 def fetch_queue_time(name: str, version: str, db_session: Session) -> dt.datetime | None:
-    return db_session.scalar(select(Scan.queued_at).where(Scan.name == name).where(Scan.version == version))
+    with db_session.begin():
+        return db_session.scalar(select(Scan.queued_at).where(Scan.name == name).where(Scan.version == version))
 
 
 def test_fetch_queue_time(test_data: list[Scan], db_session: Session):
@@ -63,10 +65,11 @@ def test_batch_job(test_data: list[Scan], db_session: Session, auth: AuthenticationData
         assert (row.name, row.version) not in jobs
 
     # check if the database was accurately updated
-    for name, version in jobs:
-        row = db_session.scalar(select(Scan).where(Scan.name == name).where(Scan.version == version))
-
-        assert row is not None
-        assert row.status == Status.PENDING
-        assert row.pending_by is not None
-        assert row.pending_at is not None
+    with db_session.begin():
+        for name, version in jobs:
+            row = db_session.scalar(select(Scan).where(Scan.name == name).where(Scan.version == version))
+
+            assert row is not None
+            assert row.status == Status.PENDING
+            assert row.pending_by is not None
+            assert row.pending_at is not None