Skip to content

Commit

Permalink
New updates to LDMatrix + SumstatsTable data structures.
Browse files Browse the repository at this point in the history
  • Loading branch information
shz9 committed Jun 3, 2024
1 parent e6d4481 commit ea3dc4b
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 33 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ name (before it was returning true for inliers...).

- Added `get_peak_memory_usage` to `system_utils` to inspect peak memory usage of a process.
- Placeholder method to perform QC on `SumstatsTable` objects (not yet implemented).
- New attached dataset for long-range LD regions.
- New method in SumstatsTable to impute rsID (if missing).
- Preliminary support for matching with CHR+POS in SumstatsTable (still needs more work).
- LDMatrix updates:
    - New method to filter long-range LD regions.
    - New method to prune LD matrix.
- New algorithm for symmetrizing upper triangular and block diagonal LD matrices.
    - Much faster and more memory efficient than using `scipy`.
- New `LDMatrix` class has efficient data loading in `.load_data` method.
Expand Down
78 changes: 63 additions & 15 deletions magenpy/GWADataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,10 @@ def read_summary_statistics(self,
desc="Reading summary statistics",
disable=not self.verbose or len(sumstats_files) < 2):

ss_tab = SumstatsTable.from_file(f, sumstats_format=sumstats_format, parser=parser, **parse_kwargs)
ss_tab = SumstatsTable.from_file(f,
sumstats_format=sumstats_format,
parser=parser,
**parse_kwargs)

if drop_duplicated:
ss_tab.drop_duplicates()
Expand All @@ -565,6 +568,19 @@ def read_summary_statistics(self,

self.sumstats_table.update(ss_tab.split_by_chromosome(snps_per_chrom=ref_table))

# If SNP information is not present in the sumstats tables, try to impute it
# using other reference tables:

missing_snp = any('SNP' not in ss.table.columns for ss in self.sumstats_table.values())

if missing_snp and (self.genotype is not None or self.ld is not None):

ref_table = self.to_snp_table(col_subset=['CHR', 'POS', 'SNP'], per_chromosome=True)

for c, ss in self.sumstats_table.items():
if 'SNP' not in ss.table.columns and c in ref_table:
ss.infer_snp_id(ref_table[c], allow_na=True)

def read_ld(self, ld_store_paths):
"""
Read the LD matrix files stored on-disk in Zarr array format.
Expand Down Expand Up @@ -672,6 +688,10 @@ def harmonize_data(self):
However, if you read or manipulate the data sources after initialization,
you may need to call this method again to ensure that the data sources remain aligned.
!!! warning
Harmonization for now depends on having SNP rsID be present in all the resources. Hopefully
this requirement will be relaxed in the future.
"""

data_sources = (self.genotype, self.sumstats_table, self.ld, self.annotation)
Expand Down Expand Up @@ -705,8 +725,8 @@ def harmonize_data(self):

else:

# Find the set of SNPs that are shared across all data sources:
common_snps = np.array(list(set.intersection(*[set(ds[c].snps)
# Find the set of SNPs that are shared across all data sources (exclude missing values):
common_snps = np.array(list(set.intersection(*[set(ds[c].snps[~pd.isnull(ds[c].snps)])
for ds in initialized_data_sources])))

# If necessary, filter the data sources to only have the common SNPs:
Expand All @@ -717,10 +737,17 @@ def harmonize_data(self):
# Harmonize the summary statistics data with either genotype or LD reference.
# This procedure checks for flips in the effect allele between data sources.
if self.sumstats_table is not None:

id_cols = self.sumstats_table[c].identifier_cols

if self.genotype is not None:
self.sumstats_table[c].match(self.genotype[c].get_snp_table(col_subset=['SNP', 'A1', 'A2']))
self.sumstats_table[c].match(self.genotype[c].get_snp_table(
col_subset=id_cols + ['A1', 'A2']
))
elif self.ld is not None:
self.sumstats_table[c].match(self.ld[c].to_snp_table(col_subset=['SNP', 'A1', 'A2']))
self.sumstats_table[c].match(self.ld[c].to_snp_table(
col_subset=id_cols + ['A1', 'A2']
))

# If during the allele matching process we discover incompatibilities,
# we filter those SNPs:
Expand Down Expand Up @@ -763,15 +790,17 @@ def score(self, beta=None, standardize_genotype=False):
try:
beta = {c: s.marginal_beta or s.get_snp_pseudo_corr() for c, s in self.sumstats_table.items()}
except Exception:
raise ValueError("To perform linear scoring, you must provide effect size estimates (BETA)!")
raise ValueError("To perform linear scoring, you must "
"provide effect size estimates (BETA)!")

# Here, we have a very ugly way of accounting for
# the fact that the chromosomes may be coded differently between the genotype
# and the beta dictionary. Maybe we can find a better solution in the future.
common_chr_g, common_chr_b = match_chromosomes(self.genotype.keys(), beta.keys(), return_both=True)

if len(common_chr_g) < 1:
raise ValueError("No common chromosomes found between the genotype and the effect size estimates!")
raise ValueError("No common chromosomes found between "
"the genotype and the effect size estimates!")

if self.verbose and len(common_chr_g) < 2:
print("> Generating polygenic scores...")
Expand Down Expand Up @@ -831,7 +860,7 @@ def to_phenotype_table(self):

return self.sample_table.get_phenotype_table()

def to_snp_table(self, col_subset=None, per_chromosome=False):
def to_snp_table(self, col_subset=None, per_chromosome=False, resource='auto'):
"""
Get a dataframe of SNP data for all variants
across different chromosomes.
Expand All @@ -840,21 +869,40 @@ def to_snp_table(self, col_subset=None, per_chromosome=False):
:param per_chromosome: If True, returns a dictionary where the key
is the chromosome number and the value is the SNP table per
chromosome.
:param resource: The data source to extract the SNP table from. By default, the method
will try to extract the SNP table from the genotype matrix. If the genotype matrix is not
available, then it will try to extract the SNP information from the LD matrix or the summary
statistics table. Possible values: `auto`, `genotype`, `ld`, `sumstats`.
:return: A dataframe (or dictionary of dataframes) of SNP data.
"""

# Sanity checks:
assert resource in ('auto', 'genotype', 'ld', 'sumstats')
if resource != 'auto':
if resource == 'genotype' and self.genotype is None:
raise ValueError("Genotype matrix is not available!")
if resource == 'ld' and self.ld is None:
raise ValueError("LD matrix is not available!")
if resource == 'sumstats' and self.sumstats_table is None:
raise ValueError("Summary statistics table is not available!")
else:
if all(ds is None for ds in (self.genotype, self.ld, self.sumstats_table)):
raise ValueError("No data sources available to extract SNP data from!")

# Extract the SNP data:

snp_tables = {}

for c in self.chromosomes:
if self.sumstats_table is not None:
snp_tables[c] = self.sumstats_table[c].to_table(col_subset=col_subset)
elif self.genotype is not None:
if resource in ('auto', 'genotype') and self.genotype is not None:
for c in self.chromosomes:
snp_tables[c] = self.genotype[c].get_snp_table(col_subset=col_subset)
elif self.ld is not None:
elif resource in ('auto', 'ld') and self.ld is not None:
for c in self.chromosomes:
snp_tables[c] = self.ld[c].to_snp_table(col_subset=col_subset)
else:
raise ValueError("GWADataLoader instance is not properly initialized!")
else:
return self.to_summary_statistics_table(col_subset=col_subset,
per_chromosome=per_chromosome)

if per_chromosome:
return snp_tables
Expand Down
58 changes: 58 additions & 0 deletions magenpy/LDMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,40 @@ def indptr(self):
else:
return self._zg['matrix/indptr']

def filter_long_range_ld_regions(self):
    """
    Exclude variants that fall inside long-range LD (LRLD) regions.

    The region boundaries are read from the BED file returned by `lrld_path()`. They are
    derived from:
    https://genome.sph.umich.edu/wiki/Regions_of_high_linkage_disequilibrium_(LD)
    which is based on the work of

    > Anderson, Carl A., et al. "Data quality control in genetic case-control
    association studies." Nature protocols 5.9 (2010): 1564-1573.

    .. note ::
        This method is experimental and may not work as expected for all LD matrices.
    """

    from .parsers.annotation_parsers import parse_annotation_bed_file
    from .utils.data_utils import lrld_path

    # Load all annotated regions and keep the ones on this matrix's chromosome:
    regions = parse_annotation_bed_file(lrld_path())
    regions = regions.loc[regions['CHR'] == self.chromosome]

    positions = self.bp_position
    keep = np.ones(len(positions), dtype=bool)

    # Clear the flag of every variant whose position lies within [Start, End] of a region:
    for region_start, region_end in zip(regions['Start'], regions['End']):
        keep &= ~((positions >= region_start) & (positions <= region_end))

    # Retain only the variants that are outside all of the LRLD regions:
    self.filter_snps(self.snps[keep])

def filter_snps(self, extract_snps=None, extract_file=None):
"""
Filter the LDMatrix to keep a subset of variants. This mainly sets
Expand Down Expand Up @@ -859,6 +893,30 @@ def reset_mask(self):
return_symmetric=self.is_symmetric,
dtype=self.dtype)

def prune(self, threshold):
    """
    Perform LD pruning to remove variants that are in high LD with other variants.
    If two variants are in high LD, this function keeps the variant that occurs
    earlier in the matrix. This behavior will be updated in the future to allow
    for arbitrary ordering of variants.

    !!! note
        Experimental for now. Needs further testing & improvement.

    :param threshold: The absolute value of the Pearson correlation coefficient above which
    to prune variants. A positive floating point number between 0. and 1.
    :return: A boolean array indicating whether a variant is kept after pruning.
    """

    from .stats.ld.c_utils import prune_ld_ut

    assert 0. < threshold <= 1.

    # When the entries are stored as integers, express the threshold on the same
    # quantized scale before comparing against the stored values:
    stored_as_integer = np.issubdtype(self.stored_dtype, np.integer)
    if stored_as_integer:
        threshold = quantize(np.array([threshold]), int_dtype=self.stored_dtype)[0]

    index_pointer = self.indptr[:]
    ld_values = self.data[:]

    return prune_ld_ut(index_pointer, ld_values, threshold)

def to_snp_table(self, col_subset=None):
"""
:param col_subset: The subset of columns to add to the table. If None, it returns
Expand Down
79 changes: 61 additions & 18 deletions magenpy/SumstatsTable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,18 @@ def __init__(self, ss_table: pd.DataFrame):
"""
self.table: pd.DataFrame = ss_table

assert all([col in self.table.columns for col in ('SNP', 'A1')])
# Check that the table contains some of the required columns (non exhaustive):

# Either has SNP or CHR+POS:
assert 'SNP' in self.table.columns or all([col in self.table.columns for col in ('CHR', 'POS')])
# Assert that the table has at least one of the alleles:
assert any([col in self.table.columns for col in ('A1', 'A2')])
# TODO: Add other assertions?

@property
def shape(self):
    """
    :return: The shape of the summary statistics table.
    """
    sumstats_df = self.table
    return sumstats_df.shape

Expand All @@ -49,7 +55,8 @@ def __len__(self):
@property
def chromosome(self):
"""
A convenience method to return the chromosome number if there is only one chromosome in the summary statistics.
A convenience method to return the chromosome number if there is only
one chromosome in the summary statistics.
If multiple chromosomes are present, it returns None.
:return: The chromosome number if there is only one chromosome in the summary statistics.
Expand All @@ -76,6 +83,13 @@ def m(self):
"""
return self.n_snps

@property
def identifier_cols(self):
    """
    :return: The list of column names that identify a variant in this table:
    `['SNP']` if the table has a `SNP` column, otherwise `['CHR', 'POS']`.
    """
    has_snp_column = 'SNP' in self.table.columns
    return ['SNP'] if has_snp_column else ['CHR', 'POS']

@property
def n_snps(self):
"""
Expand Down Expand Up @@ -341,21 +355,29 @@ def effect_sign(self):
def infer_a2(self, reference_table, allow_na=False):
"""
Infer the reference allele A2 (if not present in the SumstatsTable)
from a reference table. Make sure that the reference table contains the SNP ID,
the reference allele A2 and the alternative (i.e. effect) allele A1. It is the
user's responsibility to make sure that the reference table matches the summary
statistics in terms of the specification of reference vs. alternative. They are
allowed to be flipped, but they have to be consistent across the two tables.
from a reference table. Make sure that the reference table contains the identifier information
for each SNP, in addition to the reference allele A2 and the alternative (i.e. effect) allele A1.
It is the user's responsibility to make sure that the reference table matches the summary
statistics in terms of the specification of reference vs. alternative. They have to be consistent
across the two tables.
:param reference_table: A pandas table containing the following columns at least:
`SNP`, `A1`, `A2`.
SNP identifiers (`SNP` or `CHR` & `POS`) and allele information (`A1` & `A2`).
:param allow_na: If True, allow the reference allele to be missing from the final result.
"""

# Merge the summary statistics table with the reference table on `SNP` ID:
merged_table = self.table[['SNP', 'A1']].merge(reference_table[['SNP', 'A1', 'A2']],
how='left',
on='SNP')
# Get the identifier columns for this table:
id_cols = self.identifier_cols

# Sanity checks:
assert all([col in reference_table.columns for col in id_cols + ['A1', 'A2']])

# Merge the summary statistics table with the reference table on unique ID:
merged_table = self.table[id_cols + ['A1']].merge(
reference_table[id_cols + ['A1', 'A2']],
how='left',
on=id_cols
)
# If `A1_x` agrees with `A1_y`, then `A2` is indeed the reference allele.
# Otherwise, they are flipped and `A1_y` should be the reference allele:
merged_table['A2'] = np.where(merged_table['A1_x'] == merged_table['A1_y'],
Expand All @@ -368,6 +390,25 @@ def infer_a2(self, reference_table, allow_na=False):
else:
self.table['A2'] = merged_table['A2']

def infer_snp_id(self, reference_table, allow_na=False):
    """
    Infer the SNP ID (if not present in the SumstatsTable) from a reference table
    by matching variants on chromosome (`CHR`) and base-pair position (`POS`).

    If the reference table lists the same (`CHR`, `POS`) pair more than once, only the
    first entry is used, so that every variant in the summary statistics table is
    assigned exactly one ID.

    :param reference_table: A pandas table containing the following columns at least:
    `SNP`, `CHR`, `POS`.
    :param allow_na: If True, allow the SNP ID to be missing from the final result.

    :raises ValueError: If `allow_na` is False and the SNP ID could not be inferred
    for some of the variants. The table is left unmodified in that case.
    """

    position_cols = ['CHR', 'POS']

    # Keep a single reference entry per position. Otherwise, repeated positions in the
    # reference would make the left merge return more rows than `self.table` has,
    # and the column assignment below would fail with a length mismatch.
    unique_reference = reference_table[['SNP'] + position_cols].drop_duplicates(
        subset=position_cols,
        keep='first'
    )

    # Left merge: exactly one output row per summary statistics row, in the same order.
    merged_table = self.table[position_cols].merge(
        unique_reference,
        how='left',
        on=position_cols
    )

    # Check that the SNP ID could be inferred for all SNPs:
    if not allow_na and merged_table['SNP'].isna().any():
        raise ValueError("The SNP ID could not be inferred for some SNPs!")

    # Assign positionally (`.values`): the merge resets the index, so aligning on
    # index labels would be wrong if `self.table` does not have a default index.
    self.table['SNP'] = merged_table['SNP'].values

def set_sample_size(self, n):
"""
Set the sample size for each variant in the summary table.
Expand All @@ -394,14 +435,15 @@ def match(self, reference_table, correct_flips=True):
correcting for potential flips in the effect alleles.
:param reference_table: The SNP table to use as a reference. Must be a pandas
table with at least three columns: SNP, A1, A2.
table with the following columns: SNP identifier (either `SNP` or `CHR` & `POS`) and allele information
(`A1` & `A2`).
:param correct_flips: If True, correct the direction of effect size
estimates if the effect allele is reversed.
"""

from .utils.model_utils import merge_snp_tables

self.table = merge_snp_tables(ref_table=reference_table[['SNP', 'A1', 'A2']],
self.table = merge_snp_tables(ref_table=reference_table[self.identifier_cols + ['A1', 'A2']],
alt_table=self.table,
how='inner',
correct_flips=correct_flips)
Expand Down Expand Up @@ -457,14 +499,15 @@ def filter_snps(self, extract_snps=None, extract_file=None, extract_index=None):
self.table = self.table.iloc[extract_index, ].reset_index(drop=True)
else:
raise Exception("To filter a summary statistics table, you must provide "
"the list of SNPs, a file containing the list of SNPs, or a list of indices to retain.")
"the list of SNPs, a file containing the list of SNPs, "
"or a list of indices to retain.")

def drop_duplicates(self):
    """
    Drop variants with duplicated identifiers from the summary statistics table.

    Variants are identified by the columns listed in `identifier_cols`. Every variant
    whose identifier occurs more than once is removed (`keep=False`), i.e. none of the
    duplicated copies is retained.
    """

    # De-duplicate on the identifier columns only. A hard-coded `subset='SNP'` would
    # raise a KeyError for tables that do not have a `SNP` column.
    self.table = self.table.drop_duplicates(subset=self.identifier_cols, keep=False)

def get_col(self, col_name):
"""
Expand Down Expand Up @@ -583,7 +626,7 @@ def split_by_chromosome(self, snps_per_chrom=None):
if 'CHR' in self.table.columns:
chrom_tables = self.table.groupby('CHR')
return {
c: SumstatsTable(chrom_tables.get_group(c))
c: SumstatsTable(chrom_tables.get_group(c).copy())
for c in chrom_tables.groups
}
elif snps_per_chrom is not None:
Expand Down
Loading

0 comments on commit ea3dc4b

Please sign in to comment.