MarkDuplicatesSpark improvements checkpoint (broadinstitute#4656)
Co-authored-by: Louis Bergelson <[email protected]>

First part of a major rewrite of MarkDuplicatesSpark to improve performance. The tool still has a number of known issues, but it is much faster than the previous version.
jamesemery authored and cwhelan committed May 25, 2018
1 parent dd2955f commit 5e47ec2
Showing 28 changed files with 1,119 additions and 432 deletions.
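For downstream callers, the most visible API change here is the new signature of the static MarkDuplicatesSpark.mark(...) entry point, which now takes a dontMarkUnmappedMates flag and no longer requires a separate cleanupTemporaryAttributes pass. A minimal sketch of the new call, mirroring the updated call sites in the diff below (reads, header, and numReducers stand in for a tool's usual inputs and are not defined here):

// Illustrative sketch, not part of the commit; mirrors the call sites updated below.
final JavaRDD<GATKRead> marked = MarkDuplicatesSpark.mark(
        reads,                                                // JavaRDD<GATKRead> of aligned reads
        header,                                               // SAMFileHeader for those reads
        MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES,  // how the representative read of a group is chosen
        new OpticalDuplicateFinder(),
        numReducers,
        false);                                               // dontMarkUnmappedMates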
@@ -0,0 +1,23 @@
package org.broadinstitute.hellbender.cmdline.argumentcollections;

import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSpark;
import org.broadinstitute.hellbender.utils.read.markduplicates.MarkDuplicatesScoringStrategy;
import org.broadinstitute.hellbender.utils.read.markduplicates.OpticalDuplicateFinder;

import java.io.Serializable;


/**
* An argument collection for use with tools that mark duplicates, such as MarkDuplicatesSpark.
*/
public final class MarkDuplicatesSparkArgumentCollection implements Serializable {
private static final long serialVersionUID = 1L;

@Argument(shortName = "DS", fullName = "DUPLICATE_SCORING_STRATEGY", doc = "The scoring strategy for choosing the non-duplicate among candidates.")
public MarkDuplicatesScoringStrategy duplicatesScoringStrategy = MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES;

@Argument(fullName = MarkDuplicatesSpark.DO_NOT_MARK_UNMAPPED_MATES, doc = "If this option is enabled, unmapped mates of duplicate-marked reads will not be marked as duplicates.")
public boolean dontMarkUnmappedMates = false;
}
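A tool exposes these options by embedding the collection with @ArgumentCollection, exactly as the pipeline tools further down in this diff do. A minimal wiring sketch (the field name mdArgs is illustrative, not from the commit):

// Illustrative sketch, not part of the commit.
@ArgumentCollection
protected MarkDuplicatesSparkArgumentCollection mdArgs = new MarkDuplicatesSparkArgumentCollection();

// ...later, e.g. inside runTool(ctx), the fields are forwarded to the duplicate-marking code:
final MarkDuplicatesScoringStrategy strategy = mdArgs.duplicatesScoringStrategy;  // populated by Barclay from the DUPLICATE_SCORING_STRATEGY argument
final boolean dontMarkUnmappedMates = mdArgs.dontMarkUnmappedMates;               // populated from the do-not-mark-unmapped-mates argument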
@@ -6,8 +6,9 @@
import htsjdk.samtools.*;
import org.apache.spark.serializer.KryoRegistrator;
import org.bdgenomics.adam.serialization.ADAMKryoRegistrator;
import org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSparkUtils;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.markduplicates.PairedEnds;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.*;

import java.util.Collections;

@@ -77,8 +78,10 @@ private void registerGATKClasses(Kryo kryo) {
kryo.register(SAMFileHeader.SortOrder.class);
kryo.register(SAMProgramRecord.class);
kryo.register(SAMReadGroupRecord.class);

//register to avoid writing the full name of this class over and over
kryo.register(PairedEnds.class, new FieldSerializer<>(kryo, PairedEnds.class));
kryo.register(EmptyFragment.class, new FieldSerializer(kryo, EmptyFragment.class));
kryo.register(Fragment.class, new FieldSerializer(kryo, Fragment.class));
kryo.register(Pair.class, new Pair.Serializer());
kryo.register(Passthrough.class, new FieldSerializer(kryo, Passthrough.class));
kryo.register(MarkDuplicatesSparkUtils.IndexPair.class, new FieldSerializer(kryo, MarkDuplicatesSparkUtils.IndexPair.class));
}
}
@@ -16,7 +16,6 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.bdgenomics.formats.avro.AlignmentRecord;
@@ -27,10 +26,7 @@
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.BDGAlignmentRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadConstants;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.*;
import org.broadinstitute.hellbender.utils.spark.SparkUtils;
import org.seqdoop.hadoop_bam.AnySAMInputFormat;
import org.seqdoop.hadoop_bam.BAMInputFormat;
@@ -211,10 +207,10 @@ public boolean accept(Path path) {

/**
* Ensure reads in a pair fall in the same partition (input split), if the reads are queryname-sorted,
* so they are processed together. No shuffle is needed.
* or query-group sorted, so they are processed together. No shuffle is needed.
*/
JavaRDD<GATKRead> putPairsInSamePartition(final SAMFileHeader header, final JavaRDD<GATKRead> reads) {
if (!header.getSortOrder().equals(SAMFileHeader.SortOrder.queryname)) {
if (!ReadUtils.isReadNameGroupedBam(header)) {
return reads;
}
int numPartitions = reads.getNumPartitions();
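The new guard delegates to ReadUtils.isReadNameGroupedBam(header), whose body is not part of this diff. Consistent with the updated javadoc above, a check of roughly this shape is presumably what it performs (sketch only, not the actual GATK source):

// Sketch of the assumed behaviour: a bam counts as read-name grouped if it is queryname-sorted
// or query-grouped, so both reads of a pair are guaranteed to be adjacent in the input.
static boolean isReadNameGroupedBam(final SAMFileHeader header) {
    return SAMFileHeader.SortOrder.queryname.equals(header.getSortOrder())
            || SAMFileHeader.GroupOrder.query.equals(header.getGroupOrder());
}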
@@ -8,6 +8,7 @@
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.argumentcollections.MarkDuplicatesSparkArgumentCollection;
import picard.cmdline.programgroups.ReadDataManipulationProgramGroup;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
@@ -47,20 +48,20 @@ public final class BwaAndMarkDuplicatesPipelineSpark extends GATKSparkTool {
@ArgumentCollection
public final BwaArgumentCollection bwaArgs = new BwaArgumentCollection();

@Argument(shortName = "DS", fullName ="duplicates_scoring_strategy", doc = "The scoring strategy for choosing the non-duplicate among candidates.")
public MarkDuplicatesScoringStrategy duplicatesScoringStrategy = MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES;

@Argument(doc = "the output bam", shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME)
protected String output;

@ArgumentCollection
protected MarkDuplicatesSparkArgumentCollection markDuplicatesSparkArgumentCollection = new MarkDuplicatesSparkArgumentCollection();


@Override
protected void runTool(final JavaSparkContext ctx) {
try (final BwaSparkEngine bwaEngine = new BwaSparkEngine(ctx, referenceArguments.getReferenceFileName(), bwaArgs.indexImageFile, getHeaderForReads(), getReferenceSequenceDictionary())) {
final ReadFilter filter = makeReadFilter(bwaEngine.getHeader());
final JavaRDD<GATKRead> alignedReads = bwaEngine.alignPaired(getUnfilteredReads()).filter(filter::test);
final JavaRDD<GATKRead> markedReadsWithOD = MarkDuplicatesSpark.mark(alignedReads, bwaEngine.getHeader(), duplicatesScoringStrategy, new OpticalDuplicateFinder(), getRecommendedNumReducers());
final JavaRDD<GATKRead> markedReads = MarkDuplicatesSpark.cleanupTemporaryAttributes(markedReadsWithOD);
final JavaRDD<GATKRead> markedReads = MarkDuplicatesSpark.mark(alignedReads, bwaEngine.getHeader(), markDuplicatesSparkArgumentCollection.duplicatesScoringStrategy, new OpticalDuplicateFinder(), getRecommendedNumReducers(), markDuplicatesSparkArgumentCollection.dontMarkUnmappedMates);
try {
ReadsSparkSink.writeReads(ctx, output,
referenceArguments.getReferencePath().toAbsolutePath().toUri().toString(),
@@ -12,6 +12,7 @@
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.argumentcollections.MarkDuplicatesSparkArgumentCollection;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.engine.ReadContextData;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
@@ -113,10 +114,10 @@ public class ReadsPipelineSpark extends GATKSparkTool {
private JoinStrategy joinStrategy = JoinStrategy.BROADCAST;

@ArgumentCollection
public final BwaArgumentCollection bwaArgs = new BwaArgumentCollection();
protected MarkDuplicatesSparkArgumentCollection markDuplicatesSparkArgumentCollection = new MarkDuplicatesSparkArgumentCollection();

@Argument(shortName = "DS", fullName ="duplicates-scoring-strategy", doc = "The scoring strategy for choosing the non-duplicate among candidates.")
public MarkDuplicatesScoringStrategy duplicatesScoringStrategy = MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES;
@ArgumentCollection
public final BwaArgumentCollection bwaArgs = new BwaArgumentCollection();

/**
* all the command line arguments for BQSR and its covariates
@@ -166,8 +167,7 @@ protected void runTool(final JavaSparkContext ctx) {
header = getHeaderForReads();
}

final JavaRDD<GATKRead> markedReadsWithOD = MarkDuplicatesSpark.mark(alignedReads, header, duplicatesScoringStrategy, new OpticalDuplicateFinder(), getRecommendedNumReducers());
final JavaRDD<GATKRead> markedReads = MarkDuplicatesSpark.cleanupTemporaryAttributes(markedReadsWithOD);
final JavaRDD<GATKRead> markedReads = MarkDuplicatesSpark.mark(alignedReads, header, markDuplicatesSparkArgumentCollection.duplicatesScoringStrategy, new OpticalDuplicateFinder(), getRecommendedNumReducers(), markDuplicatesSparkArgumentCollection.dontMarkUnmappedMates);

// The markedReads have already had the WellformedReadFilter applied to them, which
// is all the filtering that MarkDupes and ApplyBQSR want. BQSR itself wants additional
@@ -2,6 +2,7 @@

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.metrics.MetricsFile;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -11,19 +12,25 @@
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.argumentcollections.MarkDuplicatesSparkArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.OpticalDuplicatesArgumentCollection;
import picard.cmdline.programgroups.ReadDataManipulationProgramGroup;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.markduplicates.DuplicationMetrics;
import org.broadinstitute.hellbender.utils.read.markduplicates.MarkDuplicatesScoringStrategy;
import org.broadinstitute.hellbender.utils.read.markduplicates.OpticalDuplicateFinder;
import picard.cmdline.programgroups.ReadDataManipulationProgramGroup;
import scala.Tuple2;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@DocumentedFeature
@CommandLineProgramProperties(
@@ -33,6 +40,7 @@
@BetaFeature
public final class MarkDuplicatesSpark extends GATKSparkTool {
private static final long serialVersionUID = 1L;
public static final String DO_NOT_MARK_UNMAPPED_MATES = "do-not-mark-unmapped-mates";

@Override
public boolean requiresReads() { return true; }
@@ -45,8 +53,8 @@ public final class MarkDuplicatesSpark extends GATKSparkTool {
shortName = "M", fullName = "METRICS_FILE")
protected String metricsFile;

@Argument(shortName = "DS", fullName = "DUPLICATE_SCORING_STRATEGY", doc = "The scoring strategy for choosing the non-duplicate among candidates.")
public MarkDuplicatesScoringStrategy duplicatesScoringStrategy = MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES;
@ArgumentCollection
protected MarkDuplicatesSparkArgumentCollection markDuplicatesSparkArgumentCollection = new MarkDuplicatesSparkArgumentCollection();

@ArgumentCollection
protected OpticalDuplicatesArgumentCollection opticalDuplicatesArgumentCollection = new OpticalDuplicatesArgumentCollection();
@@ -58,13 +66,69 @@ public List<ReadFilter> getDefaultReadFilters() {

public static JavaRDD<GATKRead> mark(final JavaRDD<GATKRead> reads, final SAMFileHeader header,
final MarkDuplicatesScoringStrategy scoringStrategy,
final OpticalDuplicateFinder opticalDuplicateFinder, final int numReducers) {
final OpticalDuplicateFinder opticalDuplicateFinder,
final int numReducers, final boolean dontMarkUnmappedMates) {

JavaPairRDD<MarkDuplicatesSparkUtils.IndexPair<String>, Integer> namesOfNonDuplicates = MarkDuplicatesSparkUtils.transformToDuplicateNames(header, scoringStrategy, opticalDuplicateFinder, reads, numReducers);

// Here we explicitly repartition the read names of the non-duplicate reads to match the partitioning of the original bam
final JavaRDD<Tuple2<String,Integer>> repartitionedReadNames = namesOfNonDuplicates
.mapToPair(pair -> new Tuple2<>(pair._1.getIndex(), new Tuple2<>(pair._1.getValue(),pair._2)))
.partitionBy(new KnownIndexPartitioner(reads.getNumPartitions()))
.values();

// Here we combine the original bam with the repartitioned non-duplicate read names to produce our marked reads
return reads.zipPartitions(repartitionedReadNames, (readsIter, readNamesIter) -> {
final Map<String,Integer> namesOfNonDuplicateReadsAndOpticalCounts = Utils.stream(readNamesIter).collect(Collectors.toMap(Tuple2::_1,Tuple2::_2));
return Utils.stream(readsIter).peek(read -> {
// Handle reads that have been marked as non-duplicates (which also get tagged with optical duplicate summary statistics)
if( namesOfNonDuplicateReadsAndOpticalCounts.containsKey(read.getName())) {
read.setIsDuplicate(false);
if (!(dontMarkUnmappedMates && read.isUnmapped())) {
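// Map.replace(...) returns the previous count and overwrites it with -1, so a later read with the
// same name in this partition (e.g. the mate) sees -1 and the optical-duplicate count is attached
// to only one read of the pair.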
int dupCount = namesOfNonDuplicateReadsAndOpticalCounts.replace(read.getName(), -1);
if (dupCount > -1) {
((SAMRecordToGATKReadAdapter) read).setTransientAttribute(MarkDuplicatesSparkUtils.OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME, dupCount);
}
}
// Mark unmapped read pairs as non-duplicates
} else if (ReadUtils.readAndMateAreUnmapped(read)) {
read.setIsDuplicate(false);
// Everything else is a duplicate
} else{
if (!(dontMarkUnmappedMates && read.isUnmapped())) {
read.setIsDuplicate(true);
}
}
}).iterator();
});
}

JavaRDD<GATKRead> primaryReads = reads.filter(v1 -> !ReadUtils.isNonPrimary(v1));
JavaRDD<GATKRead> nonPrimaryReads = reads.filter(v1 -> ReadUtils.isNonPrimary(v1));
JavaRDD<GATKRead> primaryReadsTransformed = MarkDuplicatesSparkUtils.transformReads(header, scoringStrategy, opticalDuplicateFinder, primaryReads, numReducers);

return primaryReadsTransformed.union(nonPrimaryReads);

/**
* A custom partitioner designed to cut down on Spark shuffle costs.
* It assumes that getPartition(key) is called with a key that is already the index of the target partition.
*
* Because the original partition index of each read is recorded and carried through the duplicate-marking
* process, only the small read-name objects have to be shuffled back to their partitions in the original bam,
* while the larger read objects are never shuffled.
*/
private static class KnownIndexPartitioner extends Partitioner {
private static final long serialVersionUID = 1L;
private final int numPartitions;

KnownIndexPartitioner(int numPartitions) {
this.numPartitions = numPartitions;
}

@Override
public int numPartitions() {
return numPartitions;
}

@Override
@SuppressWarnings("unchecked")
public int getPartition(Object key) {
return (Integer) key;
}
}

@Override
@@ -73,27 +137,17 @@ protected void runTool(final JavaSparkContext ctx) {
final OpticalDuplicateFinder finder = opticalDuplicatesArgumentCollection.READ_NAME_REGEX != null ?
new OpticalDuplicateFinder(opticalDuplicatesArgumentCollection.READ_NAME_REGEX, opticalDuplicatesArgumentCollection.OPTICAL_DUPLICATE_PIXEL_DISTANCE, null) : null;

final JavaRDD<GATKRead> finalReadsForMetrics = mark(reads, getHeaderForReads(), duplicatesScoringStrategy, finder, getRecommendedNumReducers());
final SAMFileHeader header = getHeaderForReads();
final JavaRDD<GATKRead> finalReadsForMetrics = mark(reads, header, markDuplicatesSparkArgumentCollection.duplicatesScoringStrategy, finder, getRecommendedNumReducers(), markDuplicatesSparkArgumentCollection.dontMarkUnmappedMates);

if (metricsFile != null) {
final JavaPairRDD<String, DuplicationMetrics> metricsByLibrary = MarkDuplicatesSparkUtils.generateMetrics(getHeaderForReads(), finalReadsForMetrics);
final JavaPairRDD<String, DuplicationMetrics> metricsByLibrary = MarkDuplicatesSparkUtils.generateMetrics(
header, finalReadsForMetrics);
final MetricsFile<DuplicationMetrics, Double> resultMetrics = getMetricsFile();
MarkDuplicatesSparkUtils.saveMetricsRDD(resultMetrics, getHeaderForReads(), metricsByLibrary, metricsFile);
MarkDuplicatesSparkUtils.saveMetricsRDD(resultMetrics, header, metricsByLibrary, metricsFile);
}

final JavaRDD<GATKRead> finalReads = cleanupTemporaryAttributes(finalReadsForMetrics);
writeReads(ctx, output, finalReads);
header.setSortOrder(SAMFileHeader.SortOrder.coordinate);
writeReads(ctx, output, finalReadsForMetrics, header);
}


/**
* The OD attribute was added to each read for optical dups.
* Now we have to clear it to avoid polluting the output.
*/
public static JavaRDD<GATKRead> cleanupTemporaryAttributes(final JavaRDD<GATKRead> reads) {
return reads.map(read -> {
read.clearAttribute(MarkDuplicatesSparkUtils.OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME);
return read;
});
}
}