From ea3dc4b1aea8116ffe5ad4b89e31ffbdcc0a8b3a Mon Sep 17 00:00:00 2001
From: Shadi
Date: Sun, 2 Jun 2024 20:43:59 -0400
Subject: [PATCH] New updates to LDMatrix + SumstatsTable data structures.

---
 CHANGELOG.md                      |  6 +++
 magenpy/GWADataLoader.py          | 78 ++++++++++++++++++++++++------
 magenpy/LDMatrix.py               | 58 +++++++++++++++++++++++
 magenpy/SumstatsTable.py          | 79 ++++++++++++++++++++++++-------
 magenpy/data/lrld_hg19_GRCh37.txt | 20 ++++++++
 magenpy/stats/ld/c_utils.pyx      | 64 +++++++++++++++++++++++++
 magenpy/utils/data_utils.py       | 16 +++++++
 7 files changed, 288 insertions(+), 33 deletions(-)
 create mode 100644 magenpy/data/lrld_hg19_GRCh37.txt

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e2ccc81..b93ae89 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,12 @@
 name (before it was returning true for inliers...).
 - Added `get_peak_memory_usage` to `system_utils` to inspect peak memory usage of a process.
 - Placeholder method to perform QC on `SumstatsTable` objects (needs to be implemented still).
+- New attached dataset for long-range LD regions.
+- New method in `SumstatsTable` to impute rsID (if missing).
+- Preliminary support for matching on CHR+POS in `SumstatsTable` (still needs more work).
+- `LDMatrix` updates:
+  - New method to filter long-range LD regions.
+  - New method to prune the LD matrix.
 - New algorithm for symmetrizing upper triangular and block diagonal LD matrices.
   - Much faster and more memory efficient than using `scipy`.
 - New `LDMatrix` class has efficient data loading in `.load_data` method.

diff --git a/magenpy/GWADataLoader.py b/magenpy/GWADataLoader.py
index ca6bb6c..a6a81fb 100644
--- a/magenpy/GWADataLoader.py
+++ b/magenpy/GWADataLoader.py
@@ -548,7 +548,10 @@ def read_summary_statistics(self,
                               desc="Reading summary statistics",
                               disable=not self.verbose or len(sumstats_files) < 2):

-            ss_tab = SumstatsTable.from_file(f, sumstats_format=sumstats_format, parser=parser, **parse_kwargs)
+            ss_tab = SumstatsTable.from_file(f,
+                                             sumstats_format=sumstats_format,
+                                             parser=parser,
+                                             **parse_kwargs)

             if drop_duplicated:
                 ss_tab.drop_duplicates()
@@ -565,6 +568,19 @@ def read_summary_statistics(self,

         self.sumstats_table.update(ss_tab.split_by_chromosome(snps_per_chrom=ref_table))

+        # If SNP rsID information is not present in the sumstats tables, try to impute it
+        # using other reference tables:
+
+        missing_snp = any('SNP' not in ss.table.columns for ss in self.sumstats_table.values())
+
+        if missing_snp and (self.genotype is not None or self.ld is not None):
+
+            ref_table = self.to_snp_table(col_subset=['CHR', 'POS', 'SNP'], per_chromosome=True)
+
+            for c, ss in self.sumstats_table.items():
+                if 'SNP' not in ss.table.columns and c in ref_table:
+                    ss.infer_snp_id(ref_table[c], allow_na=True)
+
     def read_ld(self, ld_store_paths):
         """
         Read the LD matrix files stored on-disk in Zarr array format.
@@ -672,6 +688,10 @@ def harmonize_data(self):
         However, if you read or manipulate the data sources after initialization,
         you may need to call this method again to ensure that the data sources remain aligned.

+        !!! warning
+            Harmonization currently requires that the SNP rsID be present in all data sources.
+            Hopefully, this requirement will be relaxed in the future.
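+
+        !!! example
+            A minimal sketch of the intended workflow (the file paths, and the
+            constructor argument names shown here, are illustrative assumptions):
+
+                gdl = GWADataLoader(bed_files="data/chr22.bed",
+                                    sumstats_files="data/chr22.sumstats.txt",
+                                    sumstats_format="magenpy")
+                # After manually reading or modifying any data source,
+                # re-align the variant sets and alleles:
+                gdl.harmonize_data()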
+
         """

         data_sources = (self.genotype, self.sumstats_table, self.ld, self.annotation)
@@ -705,8 +725,8 @@ def harmonize_data(self):

             else:

-                # Find the set of SNPs that are shared across all data sources:
-                common_snps = np.array(list(set.intersection(*[set(ds[c].snps)
+                # Find the set of SNPs that are shared across all data sources (exclude missing values):
+                common_snps = np.array(list(set.intersection(*[set(ds[c].snps[~pd.isnull(ds[c].snps)])
                                                                for ds in initialized_data_sources])))

                 # If necessary, filter the data sources to only have the common SNPs:
@@ -717,10 +737,17 @@ def harmonize_data(self):
             # Harmonize the summary statistics data with either genotype or LD reference.
             # This procedure checks for flips in the effect allele between data sources.
             if self.sumstats_table is not None:
+
+                id_cols = self.sumstats_table[c].identifier_cols
+
                 if self.genotype is not None:
-                    self.sumstats_table[c].match(self.genotype[c].get_snp_table(col_subset=['SNP', 'A1', 'A2']))
+                    self.sumstats_table[c].match(self.genotype[c].get_snp_table(
+                        col_subset=id_cols + ['A1', 'A2']
+                    ))
                 elif self.ld is not None:
-                    self.sumstats_table[c].match(self.ld[c].to_snp_table(col_subset=['SNP', 'A1', 'A2']))
+                    self.sumstats_table[c].match(self.ld[c].to_snp_table(
+                        col_subset=id_cols + ['A1', 'A2']
+                    ))

                 # If during the allele matching process we discover incompatibilities,
                 # we filter those SNPs:
@@ -763,7 +790,8 @@ def score(self, beta=None, standardize_genotype=False):
             try:
                 beta = {c: s.marginal_beta or s.get_snp_pseudo_corr() for c, s in self.sumstats_table.items()}
             except Exception:
-                raise ValueError("To perform linear scoring, you must provide effect size estimates (BETA)!")
+                raise ValueError("To perform linear scoring, you must "
+                                 "provide effect size estimates (BETA)!")

         # Here, we have a very ugly way of accounting for
         # the fact that the chromosomes may be coded differently between the genotype
@@ -771,7 +799,8 @@ def score(self, beta=None, standardize_genotype=False):
         common_chr_g, common_chr_b = match_chromosomes(self.genotype.keys(), beta.keys(), return_both=True)

         if len(common_chr_g) < 1:
-            raise ValueError("No common chromosomes found between the genotype and the effect size estimates!")
+            raise ValueError("No common chromosomes found between "
+                             "the genotype and the effect size estimates!")

         if self.verbose and len(common_chr_g) < 2:
             print("> Generating polygenic scores...")
@@ -831,7 +860,7 @@ def to_phenotype_table(self):

         return self.sample_table.get_phenotype_table()

-    def to_snp_table(self, col_subset=None, per_chromosome=False):
+    def to_snp_table(self, col_subset=None, per_chromosome=False, resource='auto'):
         """
         Get a dataframe of SNP data for all variants across different chromosomes.

@@ -840,21 +869,40 @@ def to_snp_table(self, col_subset=None, per_chromosome=False):
         :param per_chromosome: If True, returns a dictionary where the key is the chromosome number
             and the value is the SNP table per chromosome.
+        :param resource: The data source to extract the SNP table from. By default, the method
+            will try to extract the SNP table from the genotype matrix. If the genotype matrix is not
+            available, it will try to extract the SNP information from the LD matrix or the summary
+            statistics table. Possible values: `auto`, `genotype`, `ld`, `sumstats`.

         :return: A dataframe (or dictionary of dataframes) of SNP data.
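+
+        !!! example
+            A sketch of extracting SNP metadata from a specific data source (assuming
+            `gdl` is a hypothetical, initialized `GWADataLoader` with an LD matrix attached):
+
+                snp_df = gdl.to_snp_table(col_subset=['CHR', 'POS', 'SNP'], resource='ld')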
""" + # Sanity checks: + assert resource in ('auto', 'genotype', 'ld', 'sumstats') + if resource != 'auto': + if resource == 'genotype' and self.genotype is None: + raise ValueError("Genotype matrix is not available!") + if resource == 'ld' and self.ld is None: + raise ValueError("LD matrix is not available!") + if resource == 'sumstats' and self.sumstats_table is None: + raise ValueError("Summary statistics table is not available!") + else: + if all(ds is None for ds in (self.genotype, self.ld, self.sumstats_table)): + raise ValueError("No data sources available to extract SNP data from!") + + # Extract the SNP data: + snp_tables = {} - for c in self.chromosomes: - if self.sumstats_table is not None: - snp_tables[c] = self.sumstats_table[c].to_table(col_subset=col_subset) - elif self.genotype is not None: + if resource in ('auto', 'genotype') and self.genotype is not None: + for c in self.chromosomes: snp_tables[c] = self.genotype[c].get_snp_table(col_subset=col_subset) - elif self.ld is not None: + elif resource in ('auto', 'ld') and self.ld is not None: + for c in self.chromosomes: snp_tables[c] = self.ld[c].to_snp_table(col_subset=col_subset) - else: - raise ValueError("GWADataLoader instance is not properly initialized!") + else: + return self.to_summary_statistics_table(col_subset=col_subset, + per_chromosome=per_chromosome) if per_chromosome: return snp_tables diff --git a/magenpy/LDMatrix.py b/magenpy/LDMatrix.py index 5c8dfa5..c36451a 100644 --- a/magenpy/LDMatrix.py +++ b/magenpy/LDMatrix.py @@ -796,6 +796,40 @@ def indptr(self): else: return self._zg['matrix/indptr'] + def filter_long_range_ld_regions(self): + """ + A utility method to exclude variants that are in long-range LD regions. The + boundaries of those regions are derived from here: + + https://genome.sph.umich.edu/wiki/Regions_of_high_linkage_disequilibrium_(LD) + + Which is based on the work of + + > Anderson, Carl A., et al. "Data quality control in genetic case-control association studies." Nature protocols 5.9 (2010): 1564-1573. + + .. note :: + This method is experimental and may not work as expected for all LD matrices. + """ + + from .parsers.annotation_parsers import parse_annotation_bed_file + from .utils.data_utils import lrld_path + + bed_df = parse_annotation_bed_file(lrld_path()) + + # Filter to only regions specific to the chromosome of this matrix: + bed_df = bed_df.loc[bed_df['CHR'] == self.chromosome] + + bp_pos = self.bp_position + snp_mask = np.ones(len(bp_pos), dtype=bool) + + # Loop over the LRLD region on this chromosome and exclude the SNPs in these regions: + for _, row in bed_df.iterrows(): + start, end = row['Start'], row['End'] + snp_mask &= ~((bp_pos >= start) & (bp_pos <= end)) + + # Filter the SNP to only those not in the LRLD regions: + self.filter_snps(self.snps[snp_mask]) + def filter_snps(self, extract_snps=None, extract_file=None): """ Filter the LDMatrix to keep a subset of variants. This mainly sets @@ -859,6 +893,30 @@ def reset_mask(self): return_symmetric=self.is_symmetric, dtype=self.dtype) + def prune(self, threshold): + """ + Perform LD pruning to remove variants that are in high LD with other variants. + If two variants are in high LD, this function keeps the variant that occurs + earlier in the matrix. This behavior will be updated in the future to allow + for arbitrary ordering of variants. + + !!! note + Experimental for now. Needs further testing & improvement. + + :param threshold: The absolute value of the Pearson correlation coefficient above which to prune variants. 
+        """
+
+        from .stats.ld.c_utils import prune_ld_ut
+
+        assert 0. < threshold <= 1.
+
+        # If the matrix is quantized to an integer data type,
+        # quantize the threshold to the same scale:
+        if np.issubdtype(self.stored_dtype, np.integer):
+            threshold = quantize(np.array([threshold]), int_dtype=self.stored_dtype)[0]
+
+        return prune_ld_ut(self.indptr[:], self.data[:], threshold)
+
     def to_snp_table(self, col_subset=None):
         """
         :param col_subset: The subset of columns to add to the table. If None, it returns

diff --git a/magenpy/SumstatsTable.py b/magenpy/SumstatsTable.py
index 1ca3a58..59ffc5c 100644
--- a/magenpy/SumstatsTable.py
+++ b/magenpy/SumstatsTable.py
@@ -34,12 +34,18 @@ def __init__(self, ss_table: pd.DataFrame):
         """
         self.table: pd.DataFrame = ss_table

-        assert all([col in self.table.columns for col in ('SNP', 'A1')])
+        # Check that the table contains some of the required columns (non-exhaustive):
+
+        # Either has SNP or CHR+POS:
+        assert 'SNP' in self.table.columns or all([col in self.table.columns for col in ('CHR', 'POS')])
+        # Assert that the table has at least one of the alleles:
+        assert any([col in self.table.columns for col in ('A1', 'A2')])
+        # TODO: Add other assertions?

     @property
     def shape(self):
         """
-        :return: he shape of the summary statistics table.
+        :return: The shape of the summary statistics table.
         """
         return self.table.shape

@@ -49,7 +55,8 @@ def __len__(self):
     @property
     def chromosome(self):
         """
-        A convenience method to return the chromosome number if there is only one chromosome in the summary statistics.
+        A convenience method to return the chromosome number if there is only
+        one chromosome in the summary statistics.
         If multiple chromosomes are present, it returns None.

         :return: The chromosome number if there is only one chromosome in the summary statistics.
@@ -76,6 +83,13 @@ def m(self):
         """
         return self.n_snps

+    @property
+    def identifier_cols(self):
+        """
+        :return: The column(s) used to uniquely identify variants: `SNP` (rsID)
+            if present, otherwise `CHR` and `POS`.
+        """
+        if 'SNP' in self.table.columns:
+            return ['SNP']
+        else:
+            return ['CHR', 'POS']
+
     @property
     def n_snps(self):
         """
@@ -341,21 +355,29 @@ def effect_sign(self):
     def infer_a2(self, reference_table, allow_na=False):
         """
         Infer the reference allele A2 (if not present in the SumstatsTable)
-        from a reference table. Make sure that the reference table contains the SNP ID,
-        the reference allele A2 and the alternative (i.e. effect) allele A1. It is the
-        user's responsibility to make sure that the reference table matches the summary
-        statistics in terms of the specification of reference vs. alternative. They are
-        allowed to be flipped, but they have to be consistent across the two tables.
+        from a reference table. Make sure that the reference table contains the identifier information
+        for each SNP, in addition to the reference allele A2 and the alternative (i.e. effect) allele A1.
+        It is the user's responsibility to make sure that the reference table matches the summary
+        statistics in terms of the specification of reference vs. alternative. They have to be consistent
+        across the two tables.

         :param reference_table: A pandas table containing the following columns at least:
-        `SNP`, `A1`, `A2`.
+            SNP identifiers (`SNP` or `CHR` & `POS`) and allele information (`A1` & `A2`).
         :param allow_na: If True, allow the reference allele to be missing from the final result.
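+
+        !!! example
+            A sketch, assuming `ss_tab` is a `SumstatsTable` lacking `A2` and `ref_df` is a
+            pandas DataFrame with `SNP`, `A1` and `A2` columns (both names are hypothetical):
+
+                ss_tab.infer_a2(ref_df, allow_na=True)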
""" - # Merge the summary statistics table with the reference table on `SNP` ID: - merged_table = self.table[['SNP', 'A1']].merge(reference_table[['SNP', 'A1', 'A2']], - how='left', - on='SNP') + # Get the identifier columns for this table: + id_cols = self.identifier_cols + + # Sanity checks: + assert all([col in reference_table.columns for col in id_cols + ['A1', 'A2']]) + + # Merge the summary statistics table with the reference table on unique ID: + merged_table = self.table[id_cols + ['A1']].merge( + reference_table[id_cols + ['A1', 'A2']], + how='left', + on=id_cols + ) # If `A1_x` agrees with `A1_y`, then `A2` is indeed the reference allele. # Otherwise, they are flipped and `A1_y` should be the reference allele: merged_table['A2'] = np.where(merged_table['A1_x'] == merged_table['A1_y'], @@ -368,6 +390,25 @@ def infer_a2(self, reference_table, allow_na=False): else: self.table['A2'] = merged_table['A2'] + def infer_snp_id(self, reference_table, allow_na=False): + """ + Infer the SNP ID (if not present in the SumstatsTable) from a reference table. + Make sure that the reference table contains the SNP ID, chromosome ID, and position. + + :param reference_table: A pandas table containing the following columns at least: + `SNP`, `CHR`, `POS`. + :param allow_na: If True, allow the SNP ID to be missing from the final result. + """ + + # Merge the summary statistics table with the reference table: + merged_table = self.table[['CHR', 'POS']].merge(reference_table[['SNP', 'CHR', 'POS']], how='left') + + # Check that the SNP ID could be inferred for all SNPs: + if not allow_na and merged_table['SNP'].isna().any(): + raise ValueError("The SNP ID could not be inferred for some SNPs!") + else: + self.table['SNP'] = merged_table['SNP'].values + def set_sample_size(self, n): """ Set the sample size for each variant in the summary table. @@ -394,14 +435,15 @@ def match(self, reference_table, correct_flips=True): correcting for potential flips in the effect alleles. :param reference_table: The SNP table to use as a reference. Must be a pandas - table with at least three columns: SNP, A1, A2. + table with the following columns: SNP identifier (either `SNP` or `CHR` & `POS`) and allele information + (`A1` & `A2`). :param correct_flips: If True, correct the direction of effect size estimates if the effect allele is reversed. """ from .utils.model_utils import merge_snp_tables - self.table = merge_snp_tables(ref_table=reference_table[['SNP', 'A1', 'A2']], + self.table = merge_snp_tables(ref_table=reference_table[self.identifier_cols + ['A1', 'A2']], alt_table=self.table, how='inner', correct_flips=correct_flips) @@ -457,14 +499,15 @@ def filter_snps(self, extract_snps=None, extract_file=None, extract_index=None): self.table = self.table.iloc[extract_index, ].reset_index(drop=True) else: raise Exception("To filter a summary statistics table, you must provide " - "the list of SNPs, a file containing the list of SNPs, or a list of indices to retain.") + "the list of SNPs, a file containing the list of SNPs, " + "or a list of indices to retain.") def drop_duplicates(self): """ Drop variants with duplicated rsIDs from the summary statistics table. 
""" - self.table = self.table.drop_duplicates(subset='SNP', keep=False) + self.table = self.table.drop_duplicates(subset=self.identifier_cols, keep=False) def get_col(self, col_name): """ @@ -583,7 +626,7 @@ def split_by_chromosome(self, snps_per_chrom=None): if 'CHR' in self.table.columns: chrom_tables = self.table.groupby('CHR') return { - c: SumstatsTable(chrom_tables.get_group(c)) + c: SumstatsTable(chrom_tables.get_group(c).copy()) for c in chrom_tables.groups } elif snps_per_chrom is not None: diff --git a/magenpy/data/lrld_hg19_GRCh37.txt b/magenpy/data/lrld_hg19_GRCh37.txt new file mode 100644 index 0000000..9d154c3 --- /dev/null +++ b/magenpy/data/lrld_hg19_GRCh37.txt @@ -0,0 +1,20 @@ +chr1 48227412 52227412 1 +chr2 86146488 101133568 2 +chr2 89563673 89564169 3 +chr2 98080299 98080324 4 +chr2 183291754 190291755 5 +chr3 47524995 50024996 6 +chr3 83417309 86917310 7 +chr5 44464242 50464243 8 +chr5 128972100 131972101 9 +chr6 25392020 33392022 10 +chr6 56892040 63942041 11 +chr6 139958306 142458307 12 +chr7 55032505 66362565 13 +chr8 7962589 11962591 14 +chr8 42880842 49837447 15 +chr8 111930823 114930824 16 +chr10 36959993 43679994 17 +chr11 87860351 90860352 18 +chr12 33108732 41713733 19 +chr20 32536338 35066586 20 \ No newline at end of file diff --git a/magenpy/stats/ld/c_utils.pyx b/magenpy/stats/ld/c_utils.pyx index 6c251be..0eb7dcd 100644 --- a/magenpy/stats/ld/c_utils.pyx +++ b/magenpy/stats/ld/c_utils.pyx @@ -26,6 +26,70 @@ ctypedef fused noncomplex_numeric: cnp.float64_t +cpdef find_tagging_variants(int[::1] variant_indices, + integral[::1] indptr, + noncomplex_numeric[::1] data, + noncomplex_numeric threshold): + """ + TODO: Implement function to find tagging variants. + """ + pass + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +@cython.exceptval(check=False) +cdef noncomplex_numeric numeric_abs(noncomplex_numeric x) noexcept nogil: + """ + Return the absolute value of a numeric type. + """ + if x < 0: + return -x + return x + +cpdef prune_ld_ut(integral[::1] indptr, + noncomplex_numeric[::1] data, + noncomplex_numeric r_threshold): + """ + Pass over the LD matrix once and prune it so that variants whose absolute correlation coefficient is above + or equal to a certain threshold are filtered away. If two variants are highly correlated, + this function keeps the one that occurs earlier in the matrix. + + This function works with LD matrices in any data type + (quantized to integers or floats), but it is the user's responsibility to set the appropriate + threshold for the data type used. + + !!! note + This function assumes that the LD matrix is in upper triangular form and doesn't include the + diagonal. We will try to generalize this implementation later. + + :param indptr: The index pointer array for the CSR matrix to be pruned. + :param data: The data array for the CSR matrix to be pruned. + :param r_threshold: The Pearson Correlation coefficient threshold above which to prune variants. + + :return: An boolean array of which variants are kept after pruning. 
+ """ + + cdef: + int64_t i, curr_row, curr_row_size, curr_data_idx, curr_shape=indptr.shape[0]-1 + char[::1] keep = np.ones(curr_shape, dtype=np.int8) + + with nogil: + for curr_row in range(curr_shape): + + if keep[curr_row] == 1: + + curr_row_size = indptr[curr_row + 1] - indptr[curr_row] + + for i in range(curr_row_size): + curr_data_idx = indptr[curr_row] + i + + if numeric_abs(data[curr_data_idx]) >= r_threshold: + keep[curr_row + i + 1] = 0 + + return np.asarray(keep, dtype=bool) + @cython.boundscheck(False) @cython.wraparound(False) diff --git a/magenpy/utils/data_utils.py b/magenpy/utils/data_utils.py index 1643ecd..feaa97e 100644 --- a/magenpy/utils/data_utils.py +++ b/magenpy/utils/data_utils.py @@ -18,3 +18,19 @@ def ukb_height_sumstats_path(): https://yanglab.westlake.edu.cn/data/fastgwa_data/UKB/50.v1.1.fastGWA.gz """ return osp.join(osp.dirname(osp.dirname(__file__)), 'data/ukb_height_chr22.fastGWA.gz') + + +def lrld_path(): + """ + The boundaries of Long Range LD (LRLD) regions derived from here: + + https://genome.sph.umich.edu/wiki/Regions_of_high_linkage_disequilibrium_(LD) + + Which is based on the work of + + > Anderson, Carl A., et al. "Data quality control in genetic case-control association studies." Nature protocols 5.9 (2010): 1564-1573. + + :return: The path of the attached BED file containing long-range linkage disequilibrium + (LD) regions in the human genome. The coordinates are in hg19/GRCh37. + """ + return osp.join(osp.dirname(osp.dirname(__file__)), 'data/lrld_hg19_GRCh37.txt')