Skip to content

Commit

Permalink
Replaced IntervalsSkipList with OverlapDetector
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbenjamin committed Jan 14, 2018
1 parent b4d1ddd commit e78fa7c
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 510 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package org.broadinstitute.hellbender.engine;

import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.collections.IntervalsSkipListOneContig;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.variant.GATKVariant;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

/**
* Immutable storage class.
Expand All @@ -21,7 +24,7 @@ public class ContextShard implements Serializable {
// the interval covered by this shard
public final SimpleInterval interval;
// variants that overlap with the shard.
public final IntervalsSkipListOneContig<GATKVariant> variants;
public final OverlapDetector<GATKVariant> variants;
// reads that start in the shard
public final List<GATKRead> reads;
// variants and reference for the particular read at the same index as this element.
Expand All @@ -38,7 +41,7 @@ public ContextShard(SimpleInterval interval) {
* Careful: this ctor takes ownership of the passed reads and ReadContextData array.
* Do not modify them after this call (ideally don't even keep a reference to them).
*/
private ContextShard(SimpleInterval interval, IntervalsSkipListOneContig<GATKVariant> variants, final List<GATKRead> reads, final List<ReadContextData> readContext) {
private ContextShard(SimpleInterval interval, OverlapDetector<GATKVariant> variants, final List<GATKRead> reads, final List<ReadContextData> readContext) {
this.interval = interval;
this.variants = variants;
this.reads = reads;
Expand All @@ -50,12 +53,8 @@ private ContextShard(SimpleInterval interval, IntervalsSkipListOneContig<GATKVar
* with the new interval. Reads, readContext, and the variants in readContext are unchanged.
*/
public ContextShard split(SimpleInterval newInterval) {
final IntervalsSkipListOneContig<GATKVariant> newVariants;
if (null==variants) {
newVariants = null;
} else {
newVariants = new IntervalsSkipListOneContig<>( variants.getOverlapping(newInterval) );
}
final OverlapDetector<GATKVariant> newVariants = variants == null ? null :
OverlapDetector.create( new ArrayList<>(variants.getOverlaps(newInterval)) );
return new ContextShard(newInterval, newVariants, reads, readContext);
}

Expand All @@ -64,7 +63,7 @@ public ContextShard split(SimpleInterval newInterval) {
* Note that readContext is unchanged (including the variants it may refer to).
*/
public ContextShard withVariants(List<GATKVariant> newVariants) {
return new ContextShard(this.interval, new IntervalsSkipListOneContig<>(newVariants), reads, readContext);
return new ContextShard(this.interval, OverlapDetector.create(newVariants), reads, readContext);
}

/**
Expand All @@ -86,8 +85,8 @@ public ContextShard withReadContext(List<ReadContextData> newReadContext) {
/**
* Returns the variants that overlap the query interval, in start-position order.
*/
public List<GATKVariant> variantsOverlapping(SimpleInterval interval) {
return variants.getOverlapping(interval);
public List<GATKVariant> variantsOverlapping(Locatable interval) {
return variants.getOverlaps(interval).stream().sorted(Comparator.comparingInt(GATKVariant::getStart)).collect(Collectors.toList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.util.OverlapDetector;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
Expand All @@ -16,16 +17,13 @@
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.collections.IntervalsSkipList;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.ReferenceBases;
import org.broadinstitute.hellbender.utils.variant.GATKVariant;
import scala.Tuple2;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -106,7 +104,7 @@ private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioni
.collect(Collectors.toList());

final Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceSource);
final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = variantsPaths == null ? ctx.broadcast(new IntervalsSkipList<>(variants.collect())) : null;
final Broadcast<OverlapDetector<GATKVariant>> variantsBroadcast = variantsPaths == null ? ctx.broadcast(OverlapDetector.create(variants.collect())) : null;

int maxLocatableSize = Math.min(shardSize, shardPadding);
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, intervalShards, maxLocatableSize);
Expand All @@ -119,20 +117,15 @@ public Iterator<Tuple2<GATKRead, ReadContextData>> call(Shard<GATKRead> shard) t
// get reference bases for this shard (padded)
SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(shardPadding, sequenceDictionary);
ReferenceBases referenceBases = bReferenceSource.getValue().getReferenceBases(paddedInterval);
final IntervalsSkipList<GATKVariant> intervalsSkipList = variantsPaths == null ? variantsBroadcast.getValue() :
final OverlapDetector<GATKVariant> overlapDetector = variantsPaths == null ? variantsBroadcast.getValue() :
KnownSitesCache.getVariants(variantsPaths);
Iterator<Tuple2<GATKRead, ReadContextData>> transform = Iterators.transform(shard.iterator(), new Function<GATKRead, Tuple2<GATKRead, ReadContextData>>() {
@Nullable
@Override
public Tuple2<GATKRead, ReadContextData> apply(@Nullable GATKRead r) {
List<GATKVariant> overlappingVariants;
if (SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())) {
overlappingVariants = intervalsSkipList.getOverlapping(new SimpleInterval(r));
} else {
//Sometimes we have reads that do not form valid intervals (reads that do not consume any ref bases, eg CIGAR 61S90I
//In those cases, we'll just say that nothing overlaps the read
overlappingVariants = Collections.emptyList();
}
final List<GATKVariant> overlappingVariants = SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())
? overlapDetector.getOverlaps(r).stream().sorted(Comparator.comparingInt(GATKVariant::getStart)).collect(Collectors.toList())
: Collections.emptyList();
return new Tuple2<>(r, new ReadContextData(referenceBases, overlappingVariants));
}
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.util.OverlapDetector;
import org.broadinstitute.hellbender.utils.SerializableFunction;

import htsjdk.samtools.SAMRecord;
Expand All @@ -20,7 +21,6 @@
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.collections.IntervalsSkipList;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
Expand All @@ -33,10 +33,8 @@
import java.io.IOException;
import java.io.Serializable;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.*;
import java.util.stream.Collectors;


public final class AddContextDataToReadSparkOptimized implements Serializable {
Expand Down Expand Up @@ -325,7 +323,7 @@ private void throwIfOutsideMargin(SAMRecordToGATKReadAdapter g, SAMRecord r) {
* This happens immediately, at the caller.
*/
public static ArrayList<ContextShard> fillVariants(List<SimpleInterval> shardedIntervals, List<GATKVariant> variants, int margin) {
IntervalsSkipList<GATKVariant> intervals = new IntervalsSkipList<>(variants);
OverlapDetector<GATKVariant> intervals = OverlapDetector.create(variants);
ArrayList<ContextShard> ret = new ArrayList<>();
for (SimpleInterval s : shardedIntervals) {
int start = Math.max(s.getStart() - margin, 1);
Expand All @@ -344,7 +342,7 @@ public static ArrayList<ContextShard> fillVariants(List<SimpleInterval> shardedI
//
// Since the read's length is less than margin, we know that by including all the variants that overlap
// with the expanded interval we are also including all the variants that overlap with all the reads in this shard.
ret.add(new ContextShard(s).withVariants(intervals.getOverlapping(expandedInterval)));
ret.add(new ContextShard(s).withVariants(intervals.getOverlaps(expandedInterval).stream().sorted(Comparator.comparingInt(GATKVariant::getStart)).collect(Collectors.toList())));
}
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.util.OverlapDetector;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.utils.collections.IntervalsSkipList;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.variant.GATKVariant;
Expand Down Expand Up @@ -32,7 +32,7 @@ private BroadcastJoinReadsWithVariants(){}
*/
public static JavaPairRDD<GATKRead, Iterable<GATKVariant>> join(final JavaRDD<GATKRead> reads, final JavaRDD<GATKVariant> variants) {
final JavaSparkContext ctx = new JavaSparkContext(reads.context());
final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = ctx.broadcast(new IntervalsSkipList<>(variants.collect()));
final Broadcast<OverlapDetector<GATKVariant>> variantsBroadcast = ctx.broadcast(OverlapDetector.create(variants.collect()));
return reads.mapToPair(r -> getOverlapping(r, variantsBroadcast.getValue()));
}

Expand All @@ -48,9 +48,9 @@ public static JavaPairRDD<GATKRead, Iterable<GATKVariant>> join(final JavaRDD<GA
return reads.mapToPair(r -> getOverlapping(r, KnownSitesCache.getVariants(variantsPaths)));
}

private static Tuple2<GATKRead, Iterable<GATKVariant>> getOverlapping(final GATKRead read, final IntervalsSkipList<GATKVariant> intervalsSkipList) {
private static Tuple2<GATKRead, Iterable<GATKVariant>> getOverlapping(final GATKRead read, final OverlapDetector<GATKVariant> overlapDetector) {
if (SimpleInterval.isValid(read.getContig(), read.getStart(), read.getEnd())) {
return new Tuple2<>(read, intervalsSkipList.getOverlapping(new SimpleInterval(read)));
return new Tuple2<>(read, overlapDetector.getOverlaps(read));
} else {
//Sometimes we have reads that do not form valid intervals (reads that do not consume any ref bases, eg CIGAR 61S90I
//In those cases, we'll just say that nothing overlaps the read
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.util.OverlapDetector;
import htsjdk.variant.variantcontext.VariantContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.engine.FeatureDataSource;
import org.broadinstitute.hellbender.utils.collections.IntervalsSkipList;
import org.broadinstitute.hellbender.utils.variant.GATKVariant;
import org.broadinstitute.hellbender.utils.variant.VariantContextVariantAdapter;

Expand All @@ -19,19 +19,19 @@ class KnownSitesCache {

private static final Logger log = LogManager.getLogger(KnownSitesCache.class);

private static final Map<List<String>, IntervalsSkipList<GATKVariant>> PATHS_TO_VARIANTS = new HashMap<>();
private static final Map<List<String>, OverlapDetector<GATKVariant>> PATHS_TO_VARIANTS = new HashMap<>();

public static synchronized IntervalsSkipList<GATKVariant> getVariants(List<String> paths) {
public static synchronized OverlapDetector<GATKVariant> getVariants(List<String> paths) {
if (PATHS_TO_VARIANTS.containsKey(paths)) {
return PATHS_TO_VARIANTS.get(paths);
}
IntervalsSkipList<GATKVariant> variants = retrieveVariants(paths);
OverlapDetector<GATKVariant> variants = retrieveVariants(paths);
PATHS_TO_VARIANTS.put(paths, variants);
return variants;
}

private static IntervalsSkipList<GATKVariant> retrieveVariants(List<String> paths) {
return new IntervalsSkipList<>(paths
private static OverlapDetector<GATKVariant> retrieveVariants(List<String> paths) {
return OverlapDetector.create(paths
.stream()
.map(KnownSitesCache::loadFromFeatureDataSource)
.flatMap(Collection::stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ private void setHadoopBAMConfigurationProperties(final String inputName, final S
* Tests if a given SAMRecord overlaps any interval in a collection. This is only used as a fallback option for
* formats that don't support query-by-interval natively at the Hadoop-BAM layer.
*/
//TODO: use IntervalsSkipList, see https://github.com/broadinstitute/gatk/issues/1531
//TODO: use OverlapDetector, see https://github.com/broadinstitute/gatk/issues/1531
private static boolean samRecordOverlaps(final SAMRecord record, final TraversalParameters traversalParameters ) {
if (traversalParameters == null) {
return true;
Expand Down

This file was deleted.

Loading

0 comments on commit e78fa7c

Please sign in to comment.