diff --git a/src/juicebox/data/Block.java b/src/juicebox/data/Block.java
index 9b00c985..82a2ed79 100644
--- a/src/juicebox/data/Block.java
+++ b/src/juicebox/data/Block.java
@@ -29,6 +29,7 @@
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Random;
 
 /**
@@ -65,6 +66,40 @@ public List getContactRecords() {
         return records;
     }
 
+    /**
+     * Returns a randomly subsampled copy of this block's contact records.
+     * <p>
+     * Every record keeps its bin coordinates. Its count is redrawn by treating the integer part of the
+     * original count as that many individual contacts and keeping each one when a fresh
+     * {@code nextDouble()} draw is {@code <= subsampleFraction}. The fractional part of a count is dropped
+     * by the {@code (int)} cast, and records whose new count is 0 are still returned.
+     * <p>
+     * When {@code subsampleFraction} lies outside {@code (0, 1]} (NaN included) no random numbers are
+     * drawn and every returned record has a count of 0.
+     *
+     * @param subsampleFraction        fraction of contacts to keep, expected to be in {@code (0, 1]}
+     * @param randomSubsampleGenerator generator that supplies one uniform draw per original contact
+     * @return a new list holding one new record per stored record, in the stored order
+     */
+    public List<ContactRecord> getContactRecords(double subsampleFraction, Random randomSubsampleGenerator) {
+        // The fraction never changes during the scan, so test it once rather than once per contact.
+        final boolean fractionIsValid = subsampleFraction > 0 && subsampleFraction <= 1;
+        final List<ContactRecord> subsampledRecords = new ArrayList<>(records.size());
+        for (ContactRecord record : records) {
+            int keptCounts = 0;
+            if (fractionIsValid) {
+                final int originalCounts = (int) record.getCounts();
+                for (int contact = 0; contact < originalCounts; contact++) {
+                    if (randomSubsampleGenerator.nextDouble() <= subsampleFraction) {
+                        keptCounts++;
+                    }
+                }
+            }
+            subsampledRecords.add(new ContactRecord(record.getBinX(), record.getBinY(), (float) keptCounts));
+        }
+        return subsampledRecords;
+    }
+
     public void clear() {
         records.clear();
     }
diff --git a/src/juicebox/data/MatrixZoomData.java b/src/juicebox/data/MatrixZoomData.java
index 55b8affe..6ee969ab 100644
--- a/src/juicebox/data/MatrixZoomData.java
+++ b/src/juicebox/data/MatrixZoomData.java
@@ -1250,7 +1250,8 @@ public IteratorContainer getFromFileIteratorContainer() {
 
     // Merge and write out blocks one at a time.
public Pair,ExpectedValueCalculation> mergeAndWriteBlocks(LittleEndianOutputStream los, Deflater compressor, - boolean calculateExpecteds, Map fragmentCountMap, ChromosomeHandler chromosomeHandler) throws IOException { + boolean calculateExpecteds, Map fragmentCountMap, ChromosomeHandler chromosomeHandler, + double subsampleFraction, Random randomSubsampleGenerator) throws IOException { List sortedBlockNumbers = reader.getBlockNumbers(this); List indexEntries = new ArrayList<>(); @@ -1268,7 +1269,7 @@ public Pair,ExpectedValueCalculation> mergeAndWriteBlocks(Littl // Output block long position = los.getWrittenCount(); - writeBlock(currentBlock, los, compressor, zoomCalc); + writeBlock(currentBlock, los, compressor, zoomCalc, subsampleFraction, randomSubsampleGenerator); long size = los.getWrittenCount() - position; indexEntries.add(new IndexEntry(num, position, (int) size)); @@ -1280,7 +1281,8 @@ public Pair,ExpectedValueCalculation> mergeAndWriteBlocks(Littl // Merge and write out blocks multithreaded. 
public Pair,ExpectedValueCalculation> mergeAndWriteBlocks(LittleEndianOutputStream[] losArray, Deflater compressor, int whichZoom, int numResolutions, - boolean calculateExpecteds, Map fragmentCountMap, ChromosomeHandler chromosomeHandler) { + boolean calculateExpecteds, Map fragmentCountMap, ChromosomeHandler chromosomeHandler, + double subsampleFraction, Random randomSubsampleGenerator) { List sortedBlockNumbers = reader.getBlockNumbers(this); int numCPUThreads = (losArray.length - 1) / numResolutions; List sortedBlockSizes = new ArrayList<>(); @@ -1348,7 +1350,7 @@ public Pair,ExpectedValueCalculation> mergeAndWriteBlocks(Littl @Override public void run() { try { - writeBlockChunk(startBlock, endBlock, sortedBlockNumbers, losArray, whichLos, indexEntries, calc); + writeBlockChunk(startBlock, endBlock, sortedBlockNumbers, losArray, whichLos, indexEntries, calc, subsampleFraction, randomSubsampleGenerator); } catch (Exception e) { e.printStackTrace(); } @@ -1407,7 +1409,7 @@ public void run() { } private void writeBlockChunk(int startBlock, int endBlock, List sortedBlockNumbers, LittleEndianOutputStream[] losArray, - int threadNum, List indexEntries, ExpectedValueCalculation calc ) throws IOException{ + int threadNum, List indexEntries, ExpectedValueCalculation calc, double subsampleFraction, Random randomSubsampleGenerator) throws IOException{ Deflater compressor = new Deflater(); compressor.setLevel(Deflater.DEFAULT_COMPRESSION); //System.err.println(threadBlocks.length); @@ -1418,7 +1420,7 @@ private void writeBlockChunk(int startBlock, int endBlock, List sortedB if (currentBlock != null) { long position = losArray[threadNum + 1].getWrittenCount(); - writeBlock(currentBlock, losArray[threadNum + 1], compressor, calc); + writeBlock(currentBlock, losArray[threadNum + 1], compressor, calc, subsampleFraction, randomSubsampleGenerator); long size = losArray[threadNum + 1].getWrittenCount() - position; indexEntries.add(new IndexEntry(num, position, (int) size)); } @@ 
-1434,10 +1436,14 @@ private void writeBlockChunk(int startBlock, int endBlock, List sortedB * @param block Block to write * @throws IOException */ - protected void writeBlock(Block block, LittleEndianOutputStream los, Deflater compressor, ExpectedValueCalculation calc) throws IOException { - - final List records = block.getContactRecords();// getContactRecords(); + protected void writeBlock(Block block, LittleEndianOutputStream los, Deflater compressor, ExpectedValueCalculation calc, double subsampleFraction, Random randomSubsampleGenerator) throws IOException { + final List records; + if (subsampleFraction < 1) { + records = block.getContactRecords(subsampleFraction, randomSubsampleGenerator);// getContactRecords(); + } else { + records = block.getContactRecords(); + } // System.out.println("Write contact records : records count = " + records.size()); // Count records first @@ -1492,7 +1498,7 @@ public int compare(ContactRecord o1, ContactRecord o2) { final int px = record.getBinX() - binXOffset; final int py = record.getBinY() - binYOffset; if (calc != null && chr1.getIndex() == chr2.getIndex()) { - calc.addDistance(chr1.getIndex(), px, py, counts); + calc.addDistance(chr1.getIndex(), record.getBinX(), record.getBinY(), counts); } List row = rows.get(py); if (row == null) { @@ -1636,4 +1642,185 @@ protected byte[] compress(byte[] data, Deflater compressor) { return bos.toByteArray(); } + + // Merge and write out blocks multithreaded. 
+    /**
+     * Accumulates the expected-value (distance) statistics of this zoom level on a pool of worker threads.
+     * <p>
+     * The sorted block list is cut into contiguous chunks, one per thread. A chunk holds at most
+     * {@code blocks / threads} blocks (never fewer than one while blocks remain) and stops early once its
+     * summed block size exceeds 1.5 times the per-thread average; the last thread takes whatever is left.
+     * Each worker fills its own {@link ExpectedValueCalculation}; the per-thread results are merged in
+     * thread order after every worker has finished.
+     * <p>
+     * A worker that fails reports the error on {@code System.err} and contributes only what it had
+     * accumulated up to that point, so the returned statistics can be incomplete in that case.
+     * <p>
+     * NOTE(review): blocks are read with {@code NormalizationHandler.NONE} and no normalization vector is
+     * applied to the counts - confirm that this is what callers computing custom-norm expected values need.
+     *
+     * @param calculateExpecteds when {@code false} no block is read and {@code null} is returned
+     * @param fragmentCountMap   forwarded to every {@link ExpectedValueCalculation} created here
+     * @param chromosomeHandler  forwarded to every {@link ExpectedValueCalculation} created here
+     * @param numCPUThreads      number of worker threads; values below 1 are treated as 1
+     * @return the merged calculation, or {@code null} when {@code calculateExpecteds} is {@code false}
+     */
+    public ExpectedValueCalculation computeExpected(boolean calculateExpecteds, Map fragmentCountMap, ChromosomeHandler chromosomeHandler, int numCPUThreads) {
+        if (!calculateExpecteds) {
+            // Without a calculation to fill there is no reason to start threads or read blocks.
+            return null;
+        }
+
+        final List<Integer> sortedBlockNumbers = reader.getBlockNumbers(this);
+        final int numBlocks = sortedBlockNumbers.size();
+        final List<Integer> sortedBlockSizes = new ArrayList<>(numBlocks);
+        long totalBlockSize = 0;
+        for (Integer blockNumber : sortedBlockNumbers) {
+            int blockSize = reader.getBlockSize(this, blockNumber);
+            sortedBlockSizes.add(blockSize);
+            totalBlockSize += blockSize;
+        }
+
+        // newFixedThreadPool rejects a non-positive size, so fall back to a single worker.
+        final int numThreads = Math.max(1, numCPUThreads);
+        final int maxBlocksPerThread = Math.max(1, numBlocks / numThreads);
+        final long maxBytesPerThread = (long) (1.5 * (totalBlockSize / numThreads));
+
+        final ExpectedValueCalculation zoomCalc = new ExpectedValueCalculation(chromosomeHandler, zoom.getBinSize(), fragmentCountMap, NormalizationHandler.NONE);
+        final Map<Integer, ExpectedValueCalculation> localExpectedValueCalculations = new ConcurrentHashMap<>(numThreads);
+        final ExecutorService executor = Executors.newFixedThreadPool(numThreads);
+        try {
+            int nextBlock = 0;
+            for (int l = 0; l < numThreads; l++) {
+                final int threadNum = l;
+                final int startBlock = nextBlock;
+                int lastBlock = Math.min(startBlock + maxBlocksPerThread - 1, numBlocks - 1);
+                long chunkBytes = 0;
+                for (int b = startBlock; b <= lastBlock; b++) {
+                    chunkBytes += sortedBlockSizes.get(b);
+                    if (chunkBytes > maxBytesPerThread) {
+                        // Always keep the first block of the chunk: an empty chunk would hand the same
+                        // oversized block to the next thread and leave all the work to the last one.
+                        lastBlock = Math.max(b - 1, startBlock);
+                        break;
+                    }
+                }
+                if (l + 1 == numThreads) {
+                    // The last thread picks up every block that has not been assigned yet.
+                    lastBlock = numBlocks - 1;
+                }
+                final int endBlock = lastBlock;
+                nextBlock = endBlock + 1;
+                if (startBlock > endBlock) {
+                    continue;
+                }
+
+                final ExpectedValueCalculation calc = new ExpectedValueCalculation(chromosomeHandler, zoom.getBinSize(), fragmentCountMap, NormalizationHandler.NONE);
+                executor.execute(new Runnable() {
+                    @Override
+                    public void run() {
+                        try {
+                            calculateExpectedBlockChunk(startBlock, endBlock, sortedBlockNumbers, calc);
+                        } catch (Exception e) {
+                            System.err.println("Expected value calculation failed for blocks " + startBlock + " to " + endBlock + "; statistics will be incomplete");
+                            e.printStackTrace();
+                        }
+                        // Published even after a failure so the partial sums are still merged.
+                        localExpectedValueCalculations.put(threadNum, calc);
+                    }
+                });
+            }
+        } finally {
+            executor.shutdown();
+        }
+
+        // Wait until all threads finish. An interrupt does not abort the wait (the partial results
+        // would be unusable); it is remembered and restored for the caller afterwards.
+        boolean interrupted = false;
+        while (true) {
+            try {
+                if (executor.awaitTermination(1, java.util.concurrent.TimeUnit.SECONDS)) {
+                    break;
+                }
+            } catch (InterruptedException e) {
+                interrupted = true;
+            }
+        }
+        if (interrupted) {
+            Thread.currentThread().interrupt();
+        }
+
+        for (int l = 0; l < numThreads; l++) {
+            ExpectedValueCalculation threadCalc = localExpectedValueCalculations.get(l);
+            if (threadCalc != null) {
+                zoomCalc.merge(threadCalc);
+            }
+        }
+        return zoomCalc;
+    }
+
+    /**
+     * Reads the blocks {@code startBlock..endBlock} (inclusive indices into {@code sortedBlockNumbers})
+     * with {@code NormalizationHandler.NONE} and feeds each one to {@link #calcExpectedBlock}.
+     * Blocks that the reader returns as {@code null} are skipped; every other block is cleared after use.
+     *
+     * @param startBlock         index of the first block number to process
+     * @param endBlock           index of the last block number to process
+     * @param sortedBlockNumbers block numbers of this zoom level
+     * @param calc               calculation that receives the distances
+     * @throws IOException declared for the block read; propagated to the caller
+     */
+    private void calculateExpectedBlockChunk(int startBlock, int endBlock, List<Integer> sortedBlockNumbers, ExpectedValueCalculation calc) throws IOException {
+        for (int i = startBlock; i <= endBlock; i++) {
+            Block currentBlock = reader.readNormalizedBlock(sortedBlockNumbers.get(i), this, NormalizationHandler.NONE);
+            if (currentBlock == null) {
+                continue;
+            }
+            calcExpectedBlock(currentBlock, calc);
+            currentBlock.clear();
+        }
+    }
+
+    /**
+     * Adds every non-negative contact of {@code block} to {@code calc} using the record's absolute bin
+     * coordinates. Nothing is added when {@code calc} is {@code null} or when the two chromosomes of this
+     * matrix differ. An empty block is a no-op.
+     * <p>
+     * The row-major sort and extent bookkeeping that {@code writeBlock} performs are not repeated here
+     * because nothing in this method reads them.
+     * NOTE(review): this assumes {@code ExpectedValueCalculation.addDistance} is a plain accumulation
+     * whose result does not depend on the order of the calls - confirm.
+     *
+     * @param block block whose contact records are scanned
+     * @param calc  calculation that receives the distances; may be {@code null}
+     * @throws IOException kept in the signature for callers and overriding classes
+     */
+    protected void calcExpectedBlock(Block block, ExpectedValueCalculation calc) throws IOException {
+        if (calc == null || chr1.getIndex() != chr2.getIndex()) {
+            return;
+        }
+        final int chrIndex = chr1.getIndex();
+        for (ContactRecord record : block.getContactRecords()) {
+            final float counts = record.getCounts();
+            if (counts >= 0) {
+                calc.addDistance(chrIndex, record.getBinX(), record.getBinY(), counts);
+            }
+        }
+    }
+
 }
diff --git a/src/juicebox/gui/SuperAdapter.java b/src/juicebox/gui/SuperAdapter.java
index 7e7c71cd..38ae0b50 100644
--- a/src/juicebox/gui/SuperAdapter.java
+++ b/src/juicebox/gui/SuperAdapter.java
@@ -1078,7 +1078,7 @@ public void run() {
                 LoadDialog.LAST_LOADED_HIC_FILE_PATH = files[0];
 
-                CustomNormVectorFileHandler.unsafeHandleUpdatingOfNormalizations(SuperAdapter.this, files, isControl);
+
CustomNormVectorFileHandler.unsafeHandleUpdatingOfNormalizations(SuperAdapter.this, files, isControl, 1); boolean versionStatus = hic.getDataset().getVersion() >= HiCGlobals.minVersion; if (isControl) { diff --git a/src/juicebox/tools/clt/old/AddNorm.java b/src/juicebox/tools/clt/old/AddNorm.java index 24cb2e18..8c4ed317 100644 --- a/src/juicebox/tools/clt/old/AddNorm.java +++ b/src/juicebox/tools/clt/old/AddNorm.java @@ -132,7 +132,7 @@ public void run() { HiCGlobals.allowDynamicBlockIndex = false; try { if (inputVectorFile != null) { - CustomNormVectorFileHandler.updateHicFile(file, inputVectorFile); + CustomNormVectorFileHandler.updateHicFile(file, inputVectorFile, numCPUThreads); } else { launch(file, normalizationTypes, genomeWideResolution, noFragNorm, numCPUThreads, resolutionsToBuildTo); diff --git a/src/juicebox/tools/utils/norm/CustomNormVectorFileHandler.java b/src/juicebox/tools/utils/norm/CustomNormVectorFileHandler.java index 5296a73d..58f0d2e2 100644 --- a/src/juicebox/tools/utils/norm/CustomNormVectorFileHandler.java +++ b/src/juicebox/tools/utils/norm/CustomNormVectorFileHandler.java @@ -42,22 +42,25 @@ import java.util.concurrent.ExecutorService; import java.util.zip.GZIPInputStream; + + public class CustomNormVectorFileHandler extends NormVectorUpdater { - public static void updateHicFile(String path, String vectorPath) throws IOException { + + public static void updateHicFile(String path, String vectorPath, int numCPUThreads) throws IOException { DatasetReaderV2 reader = new DatasetReaderV2(path); Dataset ds = reader.read(); HiCGlobals.verifySupportedHiCFileVersion(reader.getVersion()); String[] vectorPaths = vectorPath.split(","); - NormVectorInfo normVectorInfo = completeCalculationsNecessaryForUpdatingCustomNormalizations(ds, vectorPaths, true); + NormVectorInfo normVectorInfo = completeCalculationsNecessaryForUpdatingCustomNormalizations(ds, vectorPaths, true, numCPUThreads); writeNormsToUpdateFile(reader, path, false, null, 
normVectorInfo.getExpectedValueFunctionMap(), normVectorInfo.getNormVectorIndices(), normVectorInfo.getNormVectorBuffers(), "Finished adding another normalization."); System.out.println("all custom norms added"); } - public static void unsafeHandleUpdatingOfNormalizations(SuperAdapter superAdapter, File[] files, boolean isControl) { + public static void unsafeHandleUpdatingOfNormalizations(SuperAdapter superAdapter, File[] files, boolean isControl, int numCPUThreads) { Dataset ds = superAdapter.getHiC().getDataset(); if (isControl) { @@ -70,7 +73,7 @@ public static void unsafeHandleUpdatingOfNormalizations(SuperAdapter superAdapte } try { - NormVectorInfo normVectorInfo = completeCalculationsNecessaryForUpdatingCustomNormalizations(ds, filePaths, false); + NormVectorInfo normVectorInfo = completeCalculationsNecessaryForUpdatingCustomNormalizations(ds, filePaths, false, numCPUThreads); for (NormalizationType customNormType : normVectorInfo.getNormalizationVectorsMap().keySet()) { ds.addNormalizationType(customNormType); @@ -89,7 +92,7 @@ public static void unsafeHandleUpdatingOfNormalizations(SuperAdapter superAdapte } private static NormVectorInfo completeCalculationsNecessaryForUpdatingCustomNormalizations( - final Dataset ds, String[] filePaths, boolean overwriteHicFileFooter) throws IOException { + final Dataset ds, String[] filePaths, boolean overwriteHicFileFooter, int numCPUThreads) throws IOException { Map> normalizationVectorMap = readVectorFile(filePaths, ds.getChromosomeHandler(), ds.getNormalizationHandler()); @@ -135,6 +138,7 @@ private static NormVectorInfo completeCalculationsNecessaryForUpdatingCustomNorm } } } + System.out.println("loaded existing norms"); ExecutorService executor = HiCGlobals.newFixedThreadPool(); for (NormalizationType customNormType : normalizationVectorMap.keySet()) { @@ -191,7 +195,7 @@ public void run() { if (zd == null) continue; handleLoadedVector(customNormType, chr.getIndex(), zoom, 
normalizationVectorMap.get(customNormType), - normVectorBuffers, normVectorIndices, zd, evLoaded); + normVectorBuffers, normVectorIndices, zd, evLoaded, fragCountMap, chromosomeHandler, numCPUThreads); } expectedValueFunctionMap.put(key, evLoaded.getExpectedValueFunction()); } @@ -203,7 +207,7 @@ public void run() { private static void handleLoadedVector(NormalizationType customNormType, final int chrIndx, HiCZoom zoom, Map normVectors, List normVectorBuffers, List normVectorIndex, - MatrixZoomData zd, ExpectedValueCalculation evLoaded) throws IOException { + MatrixZoomData zd, ExpectedValueCalculation evLoaded, Map fragmentCountMap, ChromosomeHandler chromosomeHandler, int numCPUThreads) throws IOException { String key = NormalizationVector.getKey(customNormType, chrIndx, zoom.getUnit().toString(), zoom.getBinSize()); if (normVectors.containsKey(key)) { @@ -226,7 +230,9 @@ private static void handleLoadedVector(NormalizationType customNormType, final i normVectorIndex.add(new NormalizationVectorIndexEntry( customNormType.toString(), chrIndx, zoom.getUnit().toString(), zoom.getBinSize(), position, sizeInBytes)); - evLoaded.addDistancesFromIterator(chrIndx, zd.getIteratorContainer(), vector.getData().convertToFloats()); + evLoaded.addDistancesFromZD(zd, fragmentCountMap, chromosomeHandler, numCPUThreads); + System.out.println("done with "+key); + } } diff --git a/src/juicebox/tools/utils/norm/NormVectorUpdater.java b/src/juicebox/tools/utils/norm/NormVectorUpdater.java index 430c9dae..7afcb91e 100644 --- a/src/juicebox/tools/utils/norm/NormVectorUpdater.java +++ b/src/juicebox/tools/utils/norm/NormVectorUpdater.java @@ -211,6 +211,10 @@ private static void appendExpectedValuesToBuffer(List expect } private static BufferedByteWriter getBufferWithEnoughSpace(List expectedBuffers, int bytesNeeded) { + if (expectedBuffers.size()==0) { + expectedBuffers.add(new BufferedByteWriter()); + } + BufferedByteWriter buffer = expectedBuffers.get(expectedBuffers.size() - 1); int 
freeBytes = Integer.MAX_VALUE - 10 - buffer.bytesWritten();
diff --git a/src/juicebox/tools/utils/original/ExpectedValueCalculation.java b/src/juicebox/tools/utils/original/ExpectedValueCalculation.java
index 6b1f9b76..14d4560b 100644
--- a/src/juicebox/tools/utils/original/ExpectedValueCalculation.java
+++ b/src/juicebox/tools/utils/original/ExpectedValueCalculation.java
@@ -29,6 +29,7 @@
 import juicebox.data.ChromosomeHandler;
 import juicebox.data.ContactRecord;
 import juicebox.data.ExpectedValueFunctionImpl;
+import juicebox.data.MatrixZoomData;
 import juicebox.data.basics.Chromosome;
 import juicebox.data.basics.ListOfDoubleArrays;
 import juicebox.data.basics.ListOfFloatArrays;
@@ -352,6 +353,22 @@ public void addDistancesFromIterator(int chrIndx, IteratorContainer ic, ListOfFl
             }
         }
     }
+
+    /**
+     * Adds the distance statistics of one zoom level to this calculation.
+     * <p>
+     * The statistics are obtained from {@code zd.computeExpected(true, ...)} and combined with the
+     * values already held here through {@code merge}.
+     *
+     * @param zd                matrix zoom data that computes the statistics
+     * @param fragmentCountMap  forwarded unchanged to {@code computeExpected}
+     * @param chromosomeHandler forwarded unchanged to {@code computeExpected}
+     * @param numCPUThreads     forwarded unchanged to {@code computeExpected}
+     */
+    public void addDistancesFromZD(MatrixZoomData zd, Map fragmentCountMap, ChromosomeHandler chromosomeHandler, int numCPUThreads) {
+        final ExpectedValueCalculation blockStatistics = zd.computeExpected(true, fragmentCountMap, chromosomeHandler, numCPUThreads);
+        merge(blockStatistics);
+    }
 }
diff --git a/src/juicebox/tools/utils/original/MultithreadedPreprocessorHic.java b/src/juicebox/tools/utils/original/MultithreadedPreprocessorHic.java
index ff65124e..ac89b412 100644
--- a/src/juicebox/tools/utils/original/MultithreadedPreprocessorHic.java
+++ b/src/juicebox/tools/utils/original/MultithreadedPreprocessorHic.java
@@ -228,9 +228,9 @@ void writeIndividualMatrix(Integer chromosomePair, int numOfNeededThreads, boole
             if (zd != null) {
                 Pair, ExpectedValueCalculation> zdOutput;
                 if (localLos.length > 1) {
-                    zdOutput = zd.mergeAndWriteBlocks(localLos, compressor, i, numResolutions, calculateExpecteds, fragmentCountMap, chromosomeHandler);
+                    zdOutput = zd.mergeAndWriteBlocks(localLos, compressor, i, numResolutions, calculateExpecteds, fragmentCountMap, chromosomeHandler, subsampleFraction, randomSubsampleGenerator);
                 } else {
-                    zdOutput = zd.mergeAndWriteBlocks(localLos[0], compressor, calculateExpecteds, fragmentCountMap, chromosomeHandler);
+
zdOutput = zd.mergeAndWriteBlocks(localLos[0], compressor, calculateExpecteds, fragmentCountMap, chromosomeHandler, subsampleFraction, randomSubsampleGenerator); } localBlockIndexes.put(zd.blockIndexPosition, zdOutput.getFirst()); if (calculateExpecteds) {