Skip to content

Commit

Permalink
Merge branch 'qual-metadata' of github.com:vector-engineering/covidcg into qual-metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
atc3 committed Nov 5, 2023
2 parents e8a6c2d + c5ae376 commit aaea018
Show file tree
Hide file tree
Showing 22 changed files with 432 additions and 54 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ data
data_genbank
data_flu
data_flu_small
data_flu_genbank
data_gisaid_flu
data_gisaid_rsv
data_genbank_rsv
Expand Down
1 change: 1 addition & 0 deletions .gcloudignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ data_az
data_ma
data_flu
data_flu_small
data_flu_genbank
data_gisaid_flu
data_gisaid_rsv
dist
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ example_data_genbank/*/lineage_treetime/*.pdf

data
data_genbank
data_flu_genbank
example_data_genbank/rsv/**
example_data_genbank/flu/**
example_data_genbank/sars2/**
Expand Down
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "covidcg",
"version": "2.7.5-pango",
"version": "2.7.6-qual-rc1",
"description": "",
"engines": {
"node": ">=8",
Expand Down
38 changes: 21 additions & 17 deletions services/server/cg_server/db_seed/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def df_to_sql(cur, df, table, index_label=None):

def seed_database(conn, schema="public"):
with conn.cursor() as cur:

cur.execute(sql.SQL("SET search_path TO {};").format(sql.Identifier(schema)))

cur.execute("DROP EXTENSION IF EXISTS intarray;")
Expand Down Expand Up @@ -195,7 +194,6 @@ def seed_database(conn, schema="public"):

mutation_fields = ["dna", "gene_aa", "protein_aa"]
for grouping in group_mutation_frequencies.keys():

# Get references
reference_names = sorted(group_mutation_frequencies[grouping].keys())

Expand Down Expand Up @@ -256,7 +254,6 @@ def seed_database(conn, schema="public"):

# Build colormaps
for grouping in config["group_cols"].keys():

# Collect unique group names
group_names = []
for reference in global_group_counts.keys():
Expand Down Expand Up @@ -296,12 +293,27 @@ def seed_database(conn, schema="public"):
print("done")

print("Writing sequence metadata...", end="", flush=True)

isolate_df = pd.read_json(data_path / "isolate_data.json")
isolate_df["collection_date"] = pd.to_datetime(isolate_df["collection_date"])
isolate_df["submission_date"] = pd.to_datetime(isolate_df["submission_date"])
# print(isolate_df.columns)

# Make a column for each metadata field
metadata_cols = []
metadata_col_defs = []
for field in list(config["metadata_cols"].keys()) + loc_levels:
metadata_col_defs.append(sql.SQL(f"{field} INTEGER NOT NULL"))
metadata_cols.append(field)

# Make columns for sequence metadata, if they exist
if "length" in isolate_df.columns:
metadata_cols.append("length")
metadata_col_defs.append(sql.SQL("length INTEGER NOT NULL"))
if "percent_ambiguous" in isolate_df.columns:
metadata_cols.append("percent_ambiguous")
metadata_col_defs.append(sql.SQL("percent_ambiguous REAL NOT NULL"))

metadata_col_defs = sql.SQL(",\n").join(metadata_col_defs)

# Make a column for each grouping
Expand Down Expand Up @@ -338,11 +350,6 @@ def seed_database(conn, schema="public"):
)
)

isolate_df = pd.read_json(data_path / "isolate_data.json")
isolate_df["collection_date"] = pd.to_datetime(isolate_df["collection_date"])
isolate_df["submission_date"] = pd.to_datetime(isolate_df["submission_date"])
# print(isolate_df.columns)

# Partition settings
min_date = isolate_df["collection_date"].min()
# Round latest sequence to the nearest partition break
Expand Down Expand Up @@ -423,8 +430,7 @@ def seed_database(conn, schema="public"):
"segments",
"accession_ids",
]
+ list(config["metadata_cols"].keys())
+ loc_levels
+ metadata_cols
+ list(
filter(lambda x: x != "subtype", config["group_cols"].keys())
) # Avoid duplicate subtype index
Expand Down Expand Up @@ -480,7 +486,9 @@ def seed_database(conn, schema="public"):
# Clean up the reference name as a SQL ident - no dots
reference_name_sql = reference_name.replace(".", "_")

reference_partition_name = f"seqmut_{mutation_field}_{reference_name_sql}"
reference_partition_name = (
f"seqmut_{mutation_field}_{reference_name_sql}"
)

# Create reference partition
cur.execute(
Expand Down Expand Up @@ -555,13 +563,11 @@ def seed_database(conn, schema="public"):
"subtype",
"reference",
]
+ list(config["metadata_cols"].keys())
+ loc_levels
+ metadata_cols
+ list(
filter(lambda x: x != "subtype", config["group_cols"].keys())
) # Avoid duplicate subtype index
):

cur.execute(
sql.SQL(
"CREATE INDEX {index_name} ON {table_name}({field});"
Expand Down Expand Up @@ -756,13 +762,11 @@ def seed_database(conn, schema="public"):
"subtype",
"reference",
]
+ list(config["metadata_cols"].keys())
+ loc_levels
+ metadata_cols
+ list(
filter(lambda x: x != "subtype", config["group_cols"].keys())
) # Avoid duplicate subtype index
):

cur.execute(
sql.SQL(
"CREATE INDEX {index_name} ON {table_name}({field});"
Expand Down
2 changes: 2 additions & 0 deletions services/server/cg_server/download/genomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def download_genomes(conn, req):
req.get("selected_metadata_fields", None),
req.get("selected_group_fields", None),
req.get("selected_reference", None),
req.get("sequence_length", None),
req.get("percent_ambiguous", None),
)

cur.execute(
Expand Down
8 changes: 6 additions & 2 deletions services/server/cg_server/download/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


def download_metadata(conn, req):

selected_reference = req.get("selected_reference", None)
if not selected_reference:
raise Exception("No reference specified")
Expand All @@ -34,6 +33,8 @@ def download_metadata(conn, req):
req.get("subm_end_date", None),
req.get("selected_metadata_fields", None),
req.get("selected_group_fields", None),
req.get("sequence_length", None),
req.get("percent_ambiguous", None),
selected_reference,
)

Expand Down Expand Up @@ -179,7 +180,10 @@ def download_metadata(conn, req):

cur.execute(query)

res_df = pd.DataFrame.from_records(cur.fetchall(), columns=sequence_cols,)
res_df = pd.DataFrame.from_records(
cur.fetchall(),
columns=sequence_cols,
)

# Replace mutation IDs with names
for mutation_field in ["dna", "gene_aa", "protein_aa"]:
Expand Down
15 changes: 11 additions & 4 deletions services/server/cg_server/query/group_mutation_frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def query_group_mutation_frequencies(conn, req):
"mutation_str",
]
if mutation_type == "gene_aa" or mutation_type == "protein_aa":
mutation_cols = ["feature",] + mutation_cols
mutation_cols = [
"feature",
] + mutation_cols

mutation_cols_expr = sql.SQL(",\n").join(
[
Expand Down Expand Up @@ -94,10 +96,14 @@ def query_group_mutation_frequencies_dynamic(conn, req):
mutation_table = "dna_mutation"
elif mutation_type == "gene_aa":
mutation_table = "gene_aa_mutation"
mutation_cols = ["feature",] + mutation_cols
mutation_cols = [
"feature",
] + mutation_cols
elif mutation_type == "protein_aa":
mutation_table = "protein_aa_mutation"
mutation_cols = ["feature",] + mutation_cols
mutation_cols = [
"feature",
] + mutation_cols

sequence_where_filter = build_sequence_location_where_filter(
group_key,
Expand All @@ -108,6 +114,8 @@ def query_group_mutation_frequencies_dynamic(conn, req):
req.get("subm_end_date", None),
req.get("selected_metadata_fields", None),
req.get("selected_group_fields", None),
req.get("sequence_length", None),
req.get("percent_ambiguous", None),
selected_reference,
)
sequence_mutation_table = "sequence_" + mutation_table
Expand Down Expand Up @@ -167,4 +175,3 @@ def query_group_mutation_frequencies_dynamic(conn, req):
)

return res.to_json(orient="records")

3 changes: 2 additions & 1 deletion services/server/cg_server/query/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

def query_metadata(conn, req):
with conn.cursor() as cur:

sequence_where_filter = build_sequence_location_where_filter(
req.get("group_key", None),
get_loc_level_ids(req),
Expand All @@ -25,6 +24,8 @@ def query_metadata(conn, req):
req.get("selected_metadata_fields", None),
req.get("selected_group_fields", None),
req.get("selected_reference", None),
req.get("sequence_length", None),
req.get("percent_ambiguous", None),
)

# Iterate over each metadata column, and aggregate counts
Expand Down
59 changes: 56 additions & 3 deletions services/server/cg_server/query/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def build_sequence_where_filter(
selected_metadata_fields=None,
selected_group_fields=None,
selected_reference=None,
sequence_length=None,
percent_ambiguous=None,
):
"""Build query for filtering sequences based on user's location/date
selection and selected metadata fields
Expand All @@ -138,6 +140,10 @@ def build_sequence_where_filter(
- Values are a list of group values, i.e., ["B.1.617.2", "BA.1"]
selected_reference: str
- Reference name (e.g., "NC_012920.1")
sequence_length: pair of integers
- (min_length, max_length)
percent_ambiguous: pair of floats
- (min_percent, max_percent)
Returns
-------
Expand Down Expand Up @@ -230,19 +236,59 @@ def build_sequence_where_filter(
else:
group_filters = sql.SQL("")

if sequence_length:
sequence_length_filter = []
if sequence_length[0] is not None:
sequence_length_filter.append(
sql.SQL('"length" >= {}').format(sql.Literal(sequence_length[0]))
)
if sequence_length[1] is not None:
sequence_length_filter.append(
sql.SQL('"length" <= {}').format(sql.Literal(sequence_length[1]))
)
sequence_length_filter = sql.Composed(
[sql.SQL(" AND "), sql.SQL(" AND ").join(sequence_length_filter)]
)
else:
sequence_length_filter = sql.SQL("")

if percent_ambiguous:
percent_ambiguous_filter = []
if percent_ambiguous[0] is not None:
percent_ambiguous_filter.append(
sql.SQL('"percent_ambiguous" >= {}').format(
sql.Literal(percent_ambiguous[0])
)
)
if percent_ambiguous[1] is not None:
percent_ambiguous_filter.append(
sql.SQL('"percent_ambiguous" <= {}').format(
sql.Literal(percent_ambiguous[1])
)
)
percent_ambiguous_filter = sql.Composed(
[sql.SQL(" AND "), sql.SQL(" AND ").join(percent_ambiguous_filter)]
)
else:
percent_ambiguous_filter = sql.SQL("")

sequence_where_filter = sql.SQL(
"""
{metadata_filters}
{group_filters}
"collection_date" >= {start_date} AND "collection_date" <= {end_date}
{submission_date_filter}
{sequence_length_filter}
{percent_ambiguous_filter}
"""
).format(
metadata_filters=metadata_filters,
group_filters=group_filters,
start_date=sql.Literal(pd.to_datetime(start_date)),
end_date=sql.Literal(pd.to_datetime(end_date)),
submission_date_filter=submission_date_filter,
sequence_length_filter=sequence_length_filter,
percent_ambiguous_filter=percent_ambiguous_filter,
)

return sequence_where_filter
Expand Down Expand Up @@ -283,7 +329,8 @@ def build_sequence_location_where_filter(group_key, loc_level_ids, *args, **kwar
continue
loc_where.append(
sql.SQL("({loc_level_col} = ANY({loc_ids}))").format(
loc_level_col=sql.Identifier(loc_level), loc_ids=sql.Literal(loc_ids),
loc_level_col=sql.Identifier(loc_level),
loc_ids=sql.Literal(loc_ids),
)
)

Expand Down Expand Up @@ -370,7 +417,10 @@ def count_coverage(

cur.execute(coverage_query)

coverage_df = pd.DataFrame.from_records(cur.fetchall(), columns=["ind", "count"],)
coverage_df = pd.DataFrame.from_records(
cur.fetchall(),
columns=["ind", "count"],
)

if dna_or_aa != constants["DNA_OR_AA"]["DNA"]:
if coordinate_mode == constants["COORDINATE_MODES"]["COORD_GENE"]:
Expand Down Expand Up @@ -417,7 +467,6 @@ def query_and_aggregate(conn, req):
selected_protein = req.get("selected_protein", None)

with conn.cursor() as cur:

main_query = []
for loc_level in constants["GEO_LEVELS"].values():
loc_ids = req.get(loc_level, None)
Expand All @@ -433,6 +482,8 @@ def query_and_aggregate(conn, req):
req.get("selected_metadata_fields", None),
req.get("selected_group_fields", None),
req.get("selected_reference", None),
req.get("sequence_length", None),
req.get("percent_ambiguous", None),
)
sequence_where_filter = sql.SQL(
"{prior} AND {loc_level_col} = ANY({loc_ids})"
Expand Down Expand Up @@ -536,6 +587,8 @@ def query_and_aggregate(conn, req):
req.get("selected_metadata_fields", None),
req.get("selected_group_fields", None),
req.get("selected_reference", None),
req.get("sequence_length", None),
req.get("percent_ambiguous", None),
)
coverage_df = count_coverage(
cur,
Expand Down
2 changes: 2 additions & 0 deletions services/server/cg_server/query/variant_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def build_variant_table(conn, req):
req.get("subm_end_date", None),
req.get("selected_metadata_fields", None),
req.get("selected_group_fields", None),
req.get("sequence_length", None),
req.get("percent_ambiguous", None),
selected_reference,
)

Expand Down
Loading

0 comments on commit aaea018

Please sign in to comment.