Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'chr' annotation in fasta and vcf can differ #64

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions mmsplice/exon_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kipoiseq.dataclasses import Interval, Variant
from kipoi.data import Dataset
from kipoiseq.extractors import VariantSeqExtractor
from mmsplice.utils import encodeDNA, region_annotate
from mmsplice.utils import encodeDNA, region_annotate, normalise_chrom

logger = logging.getLogger('mmsplice')

Expand Down Expand Up @@ -226,16 +226,31 @@ def __init__(self, fasta_file, split_seq=True, encode=True,
self.spliter = seq_spliter or SeqSpliter()
self.vseq_extractor = ExonVariantSeqExtrator(fasta_file)
self.fasta = self.vseq_extractor.fasta
self.fasta_contains_chr = self._fasta_contains_chr()
self.tissue_specific = tissue_specific
self.tissue_overhang = tissue_overhang

def _fasta_contains_chr(self):
fasta_chroms = set(self.fasta.fasta.keys())
return any(chrom.startswith('chr')
for chrom in fasta_chroms)

def _normalise_chrom_to_fasta(self, chrom):
    """Rewrite ``chrom`` so its 'chr' prefix style matches the fasta file.

    Args:
      chrom: chromosome name, with or without a leading 'chr'.

    Returns:
      The value produced by ``normalise_chrom`` for ``chrom``: converted
      towards the 'chr' style when ``self.fasta_contains_chr`` is truthy,
      towards the bare style otherwise.
    """
    # normalise_chrom only inspects the prefix of its second argument, so
    # 'chr1' / '1' act as templates for "with prefix" / "without prefix".
    template = 'chr1' if self.fasta_contains_chr else '1'
    return normalise_chrom(chrom, template)

def _next(self, exon, variant, overhang=None, mask_module=None):
overhang = overhang or self.overhang

exon._chrom = self._normalise_chrom_to_fasta(exon.chrom)

inputs = {
'seq': self.fasta.extract(Interval(
exon.chrom, exon.start - overhang[0],
exon.end + overhang[1], strand=exon.strand)).upper(),
exon.end + overhang[1], strand=exon.strand), use_strand=True).upper(), #with older kipoiseq version delete use_strand=True
'mut_seq': self.vseq_extractor.extract(
exon, [variant], overhang=overhang).upper()
}
Expand Down Expand Up @@ -271,6 +286,9 @@ def _next(self, exon, variant, overhang=None, mask_module=None):
if self.encode:
inputs = {k: self._encode_seq(v) for k, v in inputs.items()}

# normalise chrom annotation for output
exon._chrom = normalise_chrom(exon.chrom, variant.chrom)

return {
'inputs': inputs,
'metadata': {
Expand Down Expand Up @@ -388,6 +406,16 @@ def _check_chrom_annotation(self):
fasta_chroms = set(self.fasta.fasta.keys())
exon_chroms = set(self.exons['Chromosome'])

if not fasta_chroms.intersection(exon_chroms):
logger.warning(
'Mismatch of chromosome names in Fasta and VCF file.')
chr_annotaion = any(chrom.startswith('chr')
for chrom in exon_chroms)
if not chr_annotaion:
fasta_chroms = set([normalise_chrom(x, '1') for x in sorted(fasta_chroms)])
else:
fasta_chroms = set([normalise_chrom(x, 'chr1') for x in sorted(fasta_chroms)])

if not fasta_chroms.intersection(exon_chroms):
raise ValueError(
'Fasta chrom names do not match with vcf chrom names')
Expand Down
13 changes: 12 additions & 1 deletion mmsplice/mmsplice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
predict_pathogenicity, predict_splicing_efficiency, encodeDNA, \
read_ref_psi_annotation, delta_logit_PSI_to_delta_PSI, \
mmsplice_ref_modules, mmsplice_alt_modules, \
df_batch_writer, df_batch_writer_parquet
df_batch_writer, df_batch_writer_parquet, normalise_chrom
from mmsplice.exon_dataloader import SeqSpliter
from mmsplice.mtsplice import MTSplice, tissue_names
from mmsplice.layers import GlobalAveragePooling1D_Mask0, ConvDNA
Expand Down Expand Up @@ -154,6 +154,17 @@ def _predict_batch_mtsplice(self, batch, df, mtsplice,
df = pd.concat([df, tissue_pred], axis=1)

if natural_scale:

chr_ref = any(chrom.startswith('chr') for chrom in df_ref.index)
chr_preds = any(chrom.startswith('chr') for chrom in df['exons'].values)
if chr_ref != chr_preds:
df_ref = df_ref.reset_index()
if chr_preds:
df_ref['exons'] = df_ref['exons'].apply(lambda x: normalise_chrom(x, 'chr1'))
else:
df_ref['exons'] = df_ref['exons'].apply(lambda x: normalise_chrom(x, '1'))
df_ref = df_ref.set_index('exons')

df_ref = df_ref[df_ref.columns[6:]]
df = df.join(df_ref, on='exons', rsuffix='_ref')

Expand Down
13 changes: 13 additions & 0 deletions mmsplice/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ def pyrange_add_chr_from_chrom_annotation(pr):
return pyranges.PyRanges(df)


def normalise_chrom(source, target):
    """ Make the 'chr' prefix of `source` match the style of `target`.
    Args:
      source: chromosome name to convert, e.g. 'chr17' or '17'.
      target: name whose prefix style is copied, e.g. 'chr1' or '1'.
        Only whether it starts with 'chr' is looked at.
    Returns:
      `source` without its leading 'chr' if only `source` has the prefix,
      `source` with 'chr' prepended if only `target` has the prefix,
      and `source` unchanged when both already use the same style.
    """
    prefix = 'chr'
    source_has_prefix = source.startswith(prefix)
    target_has_prefix = target.startswith(prefix)

    if source_has_prefix and not target_has_prefix:
        # Slice the prefix off instead of using str.strip('chr'):
        # strip() treats its argument as a set of characters and removes
        # any of 'c', 'h', 'r' from BOTH ends, which corrupts names such
        # as 'chrhs37d5' or names ending in one of those letters.
        return source[len(prefix):]
    if target_has_prefix and not source_has_prefix:
        return prefix + source

    return source


def max_varEff(df):
""" Summarize largest absolute effect per variant across all affected exons.
Args:
Expand Down
12 changes: 11 additions & 1 deletion mmsplice/vcf_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from kipoi.data import SampleIterator
from kipoiseq.extractors import MultiSampleVCF, SingleVariantMatcher
from mmsplice.utils import pyrange_remove_chr_from_chrom_annotation,\
pyrange_add_chr_from_chrom_annotation
pyrange_add_chr_from_chrom_annotation, normalise_chrom
from mmsplice.exon_dataloader import ExonSplicingMixin

logger = logging.getLogger('mmsplice')
Expand Down Expand Up @@ -81,6 +81,16 @@ def _check_chrom_annotation(self):
fasta_chroms = set(self.fasta.fasta.keys())
vcf_chroms = set(self.vcf.seqnames)

if not fasta_chroms.intersection(vcf_chroms):
logger.warning(
'Mismatch of chromosome names in Fasta and VCF file.')
chr_annotaion = any(chrom.startswith('chr')
for chrom in vcf_chroms)
if not chr_annotaion:
fasta_chroms = set([normalise_chrom(x, '1') for x in sorted(fasta_chroms)])
else:
fasta_chroms = set([normalise_chrom(x, 'chr1') for x in sorted(fasta_chroms)])

if not fasta_chroms.intersection(vcf_chroms):
raise ValueError(
'Fasta chrom names do not match with vcf chrom names')
Expand Down
12 changes: 9 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@


snps = [
"17:41276033:C:['G']"
"17:41276033:C:['G']",
"17:41203228:T:['A']", # delta_logit_psi over 10, potential minus strand error due to kipoiseq version
]

deletions = [
Expand Down Expand Up @@ -51,13 +52,18 @@ def parse_vcf_id(vcf_id):

@pytest.fixture
def vcf_path():

chr_annotation = 'chr'
# chr_annotation = ''

with tempfile.NamedTemporaryFile('w') as temp_vcf:
temp_vcf.write('##fileformat=VCFv4.0\n')
temp_vcf.write('##contig=<ID=13,length=115169878>\n')
temp_vcf.write('##contig=<ID=17,length=81195210>\n')
temp_vcf.write(f'##contig=<ID={chr_annotation}13,length=115169878>\n')
temp_vcf.write(f'##contig=<ID={chr_annotation}17,length=81195210>\n')
temp_vcf.write('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\n')

for v in variants:
v = chr_annotation + v
temp_vcf.write('%s\t%s\t1\t%s\t%s\t.\t.\t.\n'
% tuple(parse_vcf_id(v)))

Expand Down
24 changes: 23 additions & 1 deletion tests/test_exon_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
from conftest import fasta_file, exon_file
from mmsplice.exon_dataloader import ExonDataset
from mmsplice.exon_dataloader import ExonDataset, ExonSplicingMixin
from kipoiseq.dataclasses import Interval, Variant


def test_ExonDataset():
Expand Down Expand Up @@ -62,3 +63,24 @@ def test_ExonDataset__len__():
dl = ExonDataset(exon_file, fasta_file)
df = pd.read_csv(exon_file)
assert len(dl) == df.shape[0]


def test_ExonSplicingMixin_extract_seq_in_mmsplice(vcf_path):
    """Check that reference and mutated inputs of the first row differ
    by at most one encoded position per module, and by exactly one in at
    least one module.
    """
    from conftest import gtf_file, fasta_file
    from mmsplice.vcf_dataloader import SplicingVCFDataloader

    dl = SplicingVCFDataloader(gtf_file, fasta_file, vcf_path)

    rows = list(dl)
    row = rows[0]
    assert 41203228 == row['metadata']['variant']['pos']

    # Build the frame once; it does not change between modules.
    inputs = pd.DataFrame(row['inputs'])

    mutated_seq = False
    for module in ['acceptor_intron', 'acceptor', 'exon', 'donor', 'donor_intron']:
        # for snvs there should be only one mutated nucleotide -> sum over
        # different entries in one-hot encoded vector == 2
        n_diff = (inputs['seq'][module] != inputs['mut_seq'][module]).sum()
        assert n_diff <= 2
        if n_diff == 2:
            mutated_seq = True

    # The variant must show up in at least one of the modules.
    assert mutated_seq
4 changes: 3 additions & 1 deletion tests/test_mmsplice.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def test_predict_all_table(vcf_path):
dl = SplicingVCFDataloader(gtf_file, fasta_file, vcf_path)
df = predict_all_table(model, dl, pathogenicity=True,
splicing_efficiency=True)


assert df[df['ID'] == 'chr17:41203228:T>A'].iloc[0]['delta_logit_psi'] < 10

assert len(df['delta_logit_psi']) == len(variants) - 1
assert df.shape[1] == 8 + 10 + 2

Expand Down
16 changes: 14 additions & 2 deletions tests/test_vcf_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ def test_splicing_vcf_dataloader_prebuild_grch38(vcf_path):


def test_splicing_vcf_loads_all(vcf_path):
    def _format_variants(variant):
        # Turn a conftest id such as "17:41276033:C:['G']" into the
        # "17:41276033:C>G" form; only the first listed alt is used.
        chrom, pos, ref, alts = variant.split(':')[:4]
        alt = alts.split("'")[1]
        return chrom + ':' + pos + ':' + ref + '>' + alt

    def _drop_chr_prefix(annotation):
        # Remove only a leading 'chr'. str.strip('chr') would instead
        # remove any of the characters c/h/r from both ends.
        if annotation.startswith('chr'):
            return annotation[len('chr'):]
        return annotation

    dl = SplicingVCFDataloader(gtf_file, fasta_file, vcf_path)
    variants_dl = [_drop_chr_prefix(i['metadata']['variant']['annotation'])
                   for i in dl]
    variants_conftest = [_format_variants(x) for x in variants]
    # Exactly one conftest variant is expected to be missing from the
    # dataloader output.
    assert len(set(variants_conftest).difference(set(variants_dl))) == 1

    dl = SplicingVCFDataloader(gtf_file, fasta_file, vcf_path)
    assert sum(1 for i in dl) == len(variants) - 1

Expand Down Expand Up @@ -385,10 +398,9 @@ def test_splicing_vcf_loads_snps(vcf_path):
}

rows = list(dl)
# d = rows[11]

for i in rows:
if '17:41276033:C>G' == i['metadata']['variant']['annotation']:
if '17:41276033:C>G' in i['metadata']['variant']['annotation']: # use 'in' here because could be with or without 'chr' annotation
d = i

assert d['inputs']['seq'] == expected_snps_seq['seq']
Expand Down