Port of nvscorevariants into GATK, with a basic tool frontend #8004

Merged · 17 commits · Oct 17, 2024
34 changes: 23 additions & 11 deletions build.gradle
@@ -11,7 +11,7 @@ plugins {
id "application" // provides installDist
id 'maven-publish'
id 'signing'
id "jacoco"
// id "jacoco"
id "de.undercouch.download" version "5.4.0" //used for downloading GSA lib
id "com.github.johnrengelman.shadow" version "8.1.1" //used to build the shadow and sparkJars
id "com.github.ben-manes.versions" version "0.12.0" //used for identifying dependencies that need updating
@@ -625,17 +625,22 @@ task bundle(type: Zip) {
    }
}

jacocoTestReport {
//jacocoTestReport {
//    dependsOn test
//
//    group = "Reporting"
//    description = "Generate Jacoco coverage reports after running tests."
//    getAdditionalSourceDirs().from(sourceSets.main.allJava.srcDirs)
//
//    reports {
//        xml.required = true
//        html.required = true
//    }
//}

task jacocoTestReport {
    dependsOn test

    group = "Reporting"
    description = "Generate Jacoco coverage reports after running tests."
    getAdditionalSourceDirs().from(sourceSets.main.allJava.srcDirs)

    reports {
        xml.required = true
        html.required = true
    }
}

task condaStandardEnvironmentDefinition(type: Copy) {
@@ -687,6 +692,13 @@ task localDevCondaEnv(type: Exec) {
commandLine "conda", "env", "create", "--yes", "-f", gatkCondaYML
}

task localDevCondaUpdate(type: Exec) {
    dependsOn 'condaEnvironmentDefinition'
    inputs.file("$buildDir/$pythonPackageArchiveName")
    workingDir "$buildDir"
    commandLine "conda", "env", "update", "-f", gatkCondaYML
}

task javadocJar(type: Jar, dependsOn: javadoc) {
    archiveClassifier = 'javadoc'
    from "$docBuildDir/javadoc"
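Reviewer note: the new `localDevCondaUpdate` task above mirrors `localDevCondaEnv` but runs `conda env update` against an existing environment instead of recreating it. A minimal Python sketch of the equivalent invocation, assuming the rendered environment file lands at `build/gatkcondaenv.yml` (an assumption; the build resolves the real path from the `gatkCondaYML` property):

```python
# Minimal sketch of what localDevCondaUpdate executes, for readers who want
# the same behavior outside Gradle. The YAML path is an assumption: the
# build derives it from the gatkCondaYML property.
import subprocess

def update_gatk_conda_env(yml_path: str = "build/gatkcondaenv.yml") -> None:
    # "conda env update" modifies the existing gatk environment in place,
    # unlike localDevCondaEnv, which recreates it via "conda env create --yes".
    subprocess.run(["conda", "env", "update", "-f", yml_path], check=True)

if __name__ == "__main__":
    update_gatk_conda_env()
```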
45 changes: 24 additions & 21 deletions scripts/docker/dockertest.gradle
@@ -8,7 +8,7 @@ buildscript {

plugins {
id "java" // set up default java compile and test tasks
id "jacoco"
// id "jacoco"
}

repositories {
@@ -113,9 +113,9 @@ def getJVMArgs(runtimeAddOpens, testAddOpens) {

test {
    jvmArgs = getJVMArgs(runtimeAddOpens, testAddOpens)
    jacoco {
        jvmArgs = getJVMArgs(runtimeAddOpens, testAddOpens)
    }
    // jacoco {
    //     jvmArgs = getJVMArgs(runtimeAddOpens, testAddOpens)
    // }
}

task testOnPackagedReleaseJar(type: Test){
@@ -153,22 +153,25 @@

// Task intended to collect coverage data from testOnPackagedReleaseJar executed inside the docker image
// the classpath for these tests is set at execution time for testOnPackagedReleaseJar
task jacocoTestReportOnPackagedReleaseJar(type: JacocoReport) {
    String sourceFiles = "$System.env.SOURCE_DIR"
    String testClassesUnpacked = "$System.env.CP_DIR"

//task jacocoTestReportOnPackagedReleaseJar(type: JacocoReport) {
//    String sourceFiles = "$System.env.SOURCE_DIR"
//    String testClassesUnpacked = "$System.env.CP_DIR"
//
//    dependsOn testOnPackagedReleaseJar
//    executionData testOnPackagedReleaseJar
//    additionalSourceDirs.setFrom(sourceSets.main.allJava.srcDirs)
//
//    sourceDirectories.setFrom(sourceFiles)
//    classDirectories.setFrom(testClassesUnpacked)
//
//    group = "Reporting"
//    description = "Generate Jacoco coverage reports after running tests inside the docker image."
//
//    reports {
//        xml.required = true
//        html.required = true
//    }
//}
task jacocoTestReportOnPackagedReleaseJar {
    dependsOn testOnPackagedReleaseJar
    executionData testOnPackagedReleaseJar
    additionalSourceDirs.setFrom(sourceSets.main.allJava.srcDirs)

    sourceDirectories.setFrom(sourceFiles)
    classDirectories.setFrom(testClassesUnpacked)

    group = "Reporting"
    description = "Generate Jacoco coverage reports after running tests inside the docker image."

    reports {
        xml.required = true
        html.required = true
    }
}
2 changes: 2 additions & 0 deletions scripts/gatkcondaenv.yml.template
@@ -30,11 +30,13 @@ dependencies:
- conda-forge::scipy=1.11.4
- conda-forge::h5py=3.10.0
- conda-forge::pytorch=2.1.0=*mkl*100
- conda-forge::pytorch-lightning=2.4.0 # supports Pytorch >= 2.1 and <= 2.4, used by NVScoreVariants
- conda-forge::scikit-learn=1.3.2
- conda-forge::matplotlib=3.8.2
- conda-forge::pandas=2.1.3
- conda-forge::tqdm=4.66.1
- conda-forge::dill=0.3.7 # used for pickling lambdas in TrainVariantAnnotationsModel
- conda-forge::biopython=1.84 # used by NVScoreVariants

# core R dependencies; these should only be used for plotting and do not take precedence over core python dependencies!
- r-base=4.3.1
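Reviewer note: the two new pins above exist specifically for NVScoreVariants. A quick hedged sanity check, run inside the activated gatk conda environment, that the packages resolve under their usual import names (`pytorch_lightning` and `Bio`):

```python
# Sanity check for the NVScoreVariants dependencies added to the template.
# Run inside the activated gatk conda environment; versions should match
# the pins above (pytorch 2.1.0, pytorch-lightning 2.4.0, biopython 1.84).
import torch
import pytorch_lightning
import Bio  # biopython installs under the import name "Bio"

print("pytorch:", torch.__version__)
print("pytorch-lightning:", pytorch_lightning.__version__)
print("biopython:", Bio.__version__)
```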
@@ -0,0 +1,187 @@
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.argparser.ExperimentalFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonExecutorBase;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import org.broadinstitute.hellbender.utils.runtime.ProcessOutput;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Annotate a VCF with scores from a PyTorch-based Convolutional Neural Network (CNN).
 *
 * It contains both a 1D model that uses only the reference sequence and variant annotations,
 * and a 2D model that uses reads in addition to the reference sequence and variant annotations.
 *
 * The scores for each variant record will be placed in an INFO field annotation named CNN_1D
 * (if using the 1D model) or CNN_2D (if using the 2D model). These scores represent the
 * log odds of being a true variant versus being false under the trained convolutional neural
 * network.
 *
 * The provided models were trained on short-read human sequencing data, and will likely not perform
 * well for other kinds of sequencing data, or for non-human data. A companion training tool for
 * NVScoreVariants will be released in the future to support users who need to train their own models.
 *
 * Example command for running with the 1D model:
 *
 * <pre>
 * gatk NVScoreVariants \
 *     -V src/test/resources/large/VQSR/recalibrated_chr20_start.vcf \
 *     -R src/test/resources/large/human_g1k_v37.20.21.fasta \
 *     -O output.vcf
 * </pre>
 *
 * Example command for running with the 2D model:
 *
 * <pre>
 * gatk NVScoreVariants \
 *     -V src/test/resources/large/VQSR/recalibrated_chr20_start.vcf \
 *     -R src/test/resources/large/human_g1k_v37.20.21.fasta \
 *     --tensor-type read_tensor \
 *     -I src/test/resources/large/VQSR/g94982_contig_20_start_bamout.bam \
 *     -O output.vcf
 * </pre>
 *
 * <b><i>The PyTorch Python code that this tool relies upon was contributed by engineers at
 * <a href="https://github.com/NVIDIA-Genomics-Research">NVIDIA Genomics Research</a>.
 * We would like to give particular thanks to Babak Zamirai of NVIDIA, who authored
 * the tool, as well as to Ankit Sethia, Mehrzad Samadi, and George Vacek (also of NVIDIA),
 * without whom this project would not have been possible.</i></b>
 */
@CommandLineProgramProperties(
        summary = "Annotate a VCF with scores from a PyTorch-based Convolutional Neural Network (CNN)",
        oneLineSummary = "Annotate a VCF with scores from a PyTorch-based Convolutional Neural Network (CNN)",
        programGroup = VariantFilteringProgramGroup.class
)
@ExperimentalFeature
public class NVScoreVariants extends CommandLineProgram {

    public static final String NV_SCORE_VARIANTS_PACKAGE = "scorevariants";
    public static final String NV_SCORE_VARIANTS_SCRIPT = "nvscorevariants.py";
    public static final String NV_SCORE_VARIANTS_1D_MODEL_FILENAME = "1d_cnn_mix_train_full_bn.pt";
    public static final String NV_SCORE_VARIANTS_2D_MODEL_FILENAME = "small_2d.pt";
    public static final String NV_SCORE_VARIANTS_1D_MODEL = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/nvscorevariants/" + NV_SCORE_VARIANTS_1D_MODEL_FILENAME;
    public static final String NV_SCORE_VARIANTS_2D_MODEL = Resource.LARGE_RUNTIME_RESOURCES_PATH + "/nvscorevariants/" + NV_SCORE_VARIANTS_2D_MODEL_FILENAME;

    public enum TensorType {
        reference,
        read_tensor
    }

    @Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, doc = "Output VCF file")
    private File outputVCF;

    @Argument(fullName = StandardArgumentDefinitions.VARIANT_LONG_NAME, shortName = StandardArgumentDefinitions.VARIANT_SHORT_NAME, doc = "Input VCF file containing variants to score")
    private File inputVCF;

    @Argument(fullName = StandardArgumentDefinitions.REFERENCE_LONG_NAME, shortName = StandardArgumentDefinitions.REFERENCE_SHORT_NAME, doc = "Reference sequence file")
    private File reference;

    @Argument(fullName = StandardArgumentDefinitions.INPUT_LONG_NAME, shortName = StandardArgumentDefinitions.INPUT_SHORT_NAME, doc = "BAM file containing reads, if using the 2D model", optional = true)
    private File bam;

    @Argument(fullName = "tensor-type", doc = "Name of the tensors to generate: reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
    private TensorType tensorType = TensorType.reference;

    @Argument(fullName = "batch-size", doc = "Batch size", optional = true)
    private int batchSize = 64;

    @Argument(fullName = "random-seed", doc = "Seed to initialize the random number generator", optional = true)
    private int randomSeed = 724;

    @Argument(fullName = "tmp-file", doc = "The temporary VCF-like file where variants scores will be written", optional = true)
    private File tmpFile;

    @Argument(fullName = "accelerator", doc = "Type of hardware accelerator to use (auto, cpu, cuda, mps, tpu, etc)", optional = true)
    private String accelerator = "auto";

    @Override
    protected void onStartup() {
        PythonScriptExecutor.checkPythonEnvironmentForPackage(NV_SCORE_VARIANTS_PACKAGE);
    }

    @Override
    protected Object doWork() {
        final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(PythonExecutorBase.PythonExecutableName.PYTHON3, true);
        final Resource pythonScriptResource = new Resource(NV_SCORE_VARIANTS_SCRIPT, NVScoreVariants.class);
        final File extractedModelDirectory = extractModelFilesToTempDirectory();

        if ( tmpFile == null ) {
            tmpFile = IOUtils.createTempFile("NVScoreVariants_tmp", ".txt");
        }

        final List<String> arguments = new ArrayList<>(Arrays.asList(
                "--output-file", outputVCF.getAbsolutePath(),
                "--vcf-file", inputVCF.getAbsolutePath(),
                "--ref-file", reference.getAbsolutePath(),
                "--tensor-type", tensorType.name(),
                "--batch-size", Integer.toString(batchSize),
                "--seed", Integer.toString(randomSeed),
                "--tmp-file", tmpFile.getAbsolutePath(),
                "--model-directory", extractedModelDirectory.getAbsolutePath()
        ));

        if ( accelerator != null ) {
            arguments.addAll(List.of("--accelerator", accelerator));
        }

        if ( tensorType == TensorType.reference && bam != null ) {
            throw new UserException.BadInput("--" + StandardArgumentDefinitions.INPUT_LONG_NAME +
                    " should only be specified when running with --tensor-type " + TensorType.read_tensor.name());
        }
        else if ( tensorType == TensorType.read_tensor && bam == null ) {
            throw new UserException.BadInput("Need to specify a BAM file via --" + StandardArgumentDefinitions.INPUT_LONG_NAME +
                    " when running with --tensor-type " + TensorType.read_tensor.name());
        }

        if ( bam != null ) {
            arguments.addAll(Arrays.asList("--input-file", bam.getAbsolutePath()));
        }

        logger.info("Running Python NVScoreVariants module with arguments: " + arguments);
        final ProcessOutput pythonOutput = pythonExecutor.executeScriptAndGetOutput(
                pythonScriptResource,
                null,
                arguments
        );

        if ( pythonOutput.getExitValue() != 0 ) {
            logger.error("Error running NVScoreVariants Python command:\n" + pythonOutput.getStatusSummary(true));
        }

        return pythonOutput.getExitValue();
    }

    private File extractModelFilesToTempDirectory() {
        final File extracted1DModel = IOUtils.writeTempResourceFromPath(NV_SCORE_VARIANTS_1D_MODEL, null);
        final File extracted2DModel = IOUtils.writeTempResourceFromPath(NV_SCORE_VARIANTS_2D_MODEL, null);
        final File modelDirectory = IOUtils.createTempDir("NVScoreVariants_models");

        if ( ! extracted1DModel.renameTo(new File(modelDirectory, NV_SCORE_VARIANTS_1D_MODEL_FILENAME)) ) {
            throw new UserException("Error moving " + extracted1DModel.getAbsolutePath() + " to " + modelDirectory.getAbsolutePath());
        }
        if ( ! extracted2DModel.renameTo(new File(modelDirectory, NV_SCORE_VARIANTS_2D_MODEL_FILENAME)) ) {
            throw new UserException("Error moving " + extracted2DModel.getAbsolutePath() + " to " + modelDirectory.getAbsolutePath());
        }

        logger.info("Extracted models to: " + modelDirectory.getAbsolutePath());
        return modelDirectory;
    }

    @Override
    protected void onShutdown() {
        super.onShutdown();
    }
}
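Reviewer note: the flag set that doWork() hands to the Python side can be reproduced directly. A minimal sketch, assuming the `scorevariants` package from this PR is installed and the .pt models have already been extracted to disk; all paths below are hypothetical placeholders:

```python
# Sketch of the command line NVScoreVariants.doWork() assembles. All paths
# are hypothetical; in the real tool, GATK extracts nvscorevariants.py and
# the .pt model files from its jar to temp locations before running this.
import subprocess

args = [
    "python3", "nvscorevariants.py",
    "--output-file", "output.vcf",
    "--vcf-file", "input.vcf",
    "--ref-file", "reference.fasta",
    "--tensor-type", "reference",        # "read_tensor" also requires --input-file reads.bam
    "--batch-size", "64",                # tool default
    "--seed", "724",                     # tool default
    "--tmp-file", "NVScoreVariants_tmp.txt",
    "--model-directory", "/tmp/NVScoreVariants_models",
    "--accelerator", "auto",
]
subprocess.run(args, check=True)
```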
Empty file.
@@ -0,0 +1,39 @@
#!/usr/bin/python3

from pysam import VariantFile
import re
import argparse
import sys

CONTIG_INDEX = 0
POS_INDEX = 1
REF_INDEX = 2
ALT_INDEX = 3
KEY_INDEX = 4

def create_output_vcf(vcf_in, scores_file, vcf_out, label):
    variant_file = VariantFile(vcf_in)
    variant_file.reset()

    variant_file.header.info.add(id=label, number=1, type='Float', description='Log odds of being a true variant versus \
being false under the trained Convolutional Neural Network')
    header = variant_file.header.copy()
    vcfWriter = VariantFile(vcf_out, 'w', header=header)

    with open(scores_file) as scoredVariants:
        sv = next(scoredVariants)
        for variant in variant_file:
            scoredVariant = sv.split('\t')
            if variant.contig == scoredVariant[CONTIG_INDEX] and \
                    variant.pos == int(scoredVariant[POS_INDEX]) and \
                    variant.ref == scoredVariant[REF_INDEX] and \
                    ', '.join(variant.alts or []) == re.sub(r'[\[\]]', '', scoredVariant[ALT_INDEX]):

                if len(scoredVariant) > KEY_INDEX:
                    variant.info.update({label: float(scoredVariant[KEY_INDEX])})

                vcfWriter.write(variant)

                sv = next(scoredVariants, None)
            else:
                sys.exit("Score file out of sync with original VCF. Score file has: " + sv + "\nBut VCF has: " + str(variant))
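Reviewer note: judging from the index constants, the scores file is expected to hold one tab-separated line per VCF record: contig, position, ref allele, bracketed alt alleles, then the score. A hedged usage sketch with hypothetical file names:

```python
# Hypothetical call into create_output_vcf. The scores-file format is
# inferred from the index constants above, e.g. one line per variant like:
#   20\t10000117\tC\t[T]\t5.47
create_output_vcf(
    vcf_in="recalibrated_chr20_start.vcf",
    scores_file="NVScoreVariants_tmp.txt",
    vcf_out="output.vcf",
    label="CNN_1D",  # CNN_2D when scoring with the read_tensor model
)
```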