Merge pull request #989 from aidenlab/hicSummingPre
expected calculation and subsampling
suhas-rao authored Jun 23, 2022
2 parents e68b2f6 + 8cf9e6b commit 3431835
Showing 8 changed files with 241 additions and 22 deletions.
17 changes: 17 additions & 0 deletions src/juicebox/data/Block.java
@@ -29,6 +29,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Random;


/**
@@ -65,6 +66,22 @@ public List<ContactRecord> getContactRecords() {
return records;
}

public List<ContactRecord> getContactRecords(double subsampleFraction, Random randomSubsampleGenerator) {
List<ContactRecord> newRecords = new ArrayList<>();
for (ContactRecord i : records) {
int newBinX = i.getBinX();
int newBinY = i.getBinY();
int newCounts = 0;
for (int j = 0; j < (int) i.getCounts(); j++) {
if ( subsampleFraction <= 1 && subsampleFraction > 0 && randomSubsampleGenerator.nextDouble() <= subsampleFraction) {
newCounts += 1;
}
}
newRecords.add(new ContactRecord(newBinX, newBinY, (float) newCounts));
}
return newRecords;
}

public void clear() {
records.clear();
}
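The getContactRecords(double, Random) overload added above thins each contact by keeping every read independently with probability subsampleFraction, so a record's subsampled count follows Binomial(counts, subsampleFraction). Below is a minimal standalone sketch of that per-read thinning (the class and method names are hypothetical, not part of this commit), with the fraction bounds checked once instead of on every read:

import java.util.Random;

// Illustrative sketch only: the same per-read Bernoulli thinning used by
// Block.getContactRecords(subsampleFraction, randomSubsampleGenerator).
public final class SubsampleSketch {

    // Keep each of n reads independently with probability p; the kept count
    // is distributed Binomial(n, p).
    static int thinCount(int n, double p, Random rng) {
        if (p <= 0 || p > 1) {
            throw new IllegalArgumentException("subsample fraction must be in (0, 1]");
        }
        int kept = 0;
        for (int read = 0; read < n; read++) {
            if (rng.nextDouble() <= p) {
                kept++;
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        System.out.println(thinCount(40, 0.25, rng)); // keeps roughly a quarter of 40 reads
    }
}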
207 changes: 197 additions & 10 deletions src/juicebox/data/MatrixZoomData.java
@@ -1250,7 +1250,8 @@ public IteratorContainer getFromFileIteratorContainer() {

// Merge and write out blocks one at a time.
public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(LittleEndianOutputStream los, Deflater compressor,
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler) throws IOException {
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler,
double subsampleFraction, Random randomSubsampleGenerator) throws IOException {

List<Integer> sortedBlockNumbers = reader.getBlockNumbers(this);
List<IndexEntry> indexEntries = new ArrayList<>();
@@ -1268,7 +1269,7 @@ public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(Littl

// Output block
long position = los.getWrittenCount();
writeBlock(currentBlock, los, compressor, zoomCalc);
writeBlock(currentBlock, los, compressor, zoomCalc, subsampleFraction, randomSubsampleGenerator);
long size = los.getWrittenCount() - position;

indexEntries.add(new IndexEntry(num, position, (int) size));
@@ -1280,7 +1281,8 @@ public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(Littl

// Merge and write out blocks multithreaded.
public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(LittleEndianOutputStream[] losArray, Deflater compressor, int whichZoom, int numResolutions,
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler) {
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler,
double subsampleFraction, Random randomSubsampleGenerator) {
List<Integer> sortedBlockNumbers = reader.getBlockNumbers(this);
int numCPUThreads = (losArray.length - 1) / numResolutions;
List<Integer> sortedBlockSizes = new ArrayList<>();
@@ -1348,7 +1350,7 @@ public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(Littl
@Override
public void run() {
try {
writeBlockChunk(startBlock, endBlock, sortedBlockNumbers, losArray, whichLos, indexEntries, calc);
writeBlockChunk(startBlock, endBlock, sortedBlockNumbers, losArray, whichLos, indexEntries, calc, subsampleFraction, randomSubsampleGenerator);
} catch (Exception e) {
e.printStackTrace();
}
@@ -1407,7 +1409,7 @@ public void run() {
}

private void writeBlockChunk(int startBlock, int endBlock, List<Integer> sortedBlockNumbers, LittleEndianOutputStream[] losArray,
int threadNum, List<IndexEntry> indexEntries, ExpectedValueCalculation calc ) throws IOException{
int threadNum, List<IndexEntry> indexEntries, ExpectedValueCalculation calc, double subsampleFraction, Random randomSubsampleGenerator) throws IOException{
Deflater compressor = new Deflater();
compressor.setLevel(Deflater.DEFAULT_COMPRESSION);
//System.err.println(threadBlocks.length);
@@ -1418,7 +1420,7 @@ private void writeBlockChunk(int startBlock, int endBlock, List<Integer> sortedB

if (currentBlock != null) {
long position = losArray[threadNum + 1].getWrittenCount();
writeBlock(currentBlock, losArray[threadNum + 1], compressor, calc);
writeBlock(currentBlock, losArray[threadNum + 1], compressor, calc, subsampleFraction, randomSubsampleGenerator);
long size = losArray[threadNum + 1].getWrittenCount() - position;
indexEntries.add(new IndexEntry(num, position, (int) size));
}
@@ -1434,10 +1436,14 @@ private void writeBlockChunk(int startBlock, int endBlock, List<Integer> sortedB
* @param block Block to write
* @throws IOException
*/
protected void writeBlock(Block block, LittleEndianOutputStream los, Deflater compressor, ExpectedValueCalculation calc) throws IOException {

final List<ContactRecord> records = block.getContactRecords();// getContactRecords();
protected void writeBlock(Block block, LittleEndianOutputStream los, Deflater compressor, ExpectedValueCalculation calc, double subsampleFraction, Random randomSubsampleGenerator) throws IOException {

final List<ContactRecord> records;
if (subsampleFraction < 1) {
records = block.getContactRecords(subsampleFraction, randomSubsampleGenerator);// getContactRecords();
} else {
records = block.getContactRecords();
}
// System.out.println("Write contact records : records count = " + records.size());

// Count records first
@@ -1492,7 +1498,7 @@ public int compare(ContactRecord o1, ContactRecord o2) {
final int px = record.getBinX() - binXOffset;
final int py = record.getBinY() - binYOffset;
if (calc != null && chr1.getIndex() == chr2.getIndex()) {
calc.addDistance(chr1.getIndex(), px, py, counts);
calc.addDistance(chr1.getIndex(), record.getBinX(), record.getBinY(), counts);
}
List<ContactRecord> row = rows.get(py);
if (row == null) {
@@ -1636,4 +1642,185 @@ protected byte[] compress(byte[] data, Deflater compressor) {

return bos.toByteArray();
}

// Merge and write out blocks multithreaded.
public ExpectedValueCalculation computeExpected(boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler, int numCPUThreads) {
List<Integer> sortedBlockNumbers = reader.getBlockNumbers(this);
List<Integer> sortedBlockSizes = new ArrayList<>();
long totalBlockSize = 0;
for (int i = 0; i < sortedBlockNumbers.size(); i++) {
int blockSize = reader.getBlockSize(this, sortedBlockNumbers.get(i));
sortedBlockSizes.add(blockSize);
totalBlockSize += blockSize;
}
ExecutorService executor = Executors.newFixedThreadPool(numCPUThreads);
Map<Integer, Long> blockChunkSizes = new ConcurrentHashMap<>(numCPUThreads);
ExpectedValueCalculation zoomCalc; Map<Integer, ExpectedValueCalculation> localExpectedValueCalculations;
if (calculateExpecteds) {
zoomCalc = new ExpectedValueCalculation(chromosomeHandler, zoom.getBinSize(), fragmentCountMap, NormalizationHandler.NONE);
localExpectedValueCalculations = new ConcurrentHashMap<>(numCPUThreads);
} else {
zoomCalc = null;
localExpectedValueCalculations = null;
}

int placeholder = 0;
for (int l = 0; l < numCPUThreads; l++) {
final int threadNum = l;
final long blockSizePerThread = (long) Math.floor(1.5 * (long) Math.floor(totalBlockSize / numCPUThreads));
final int startBlock;
if (l == 0) {
startBlock = 0;
} else {
startBlock = placeholder;
}
int tmpEnd = startBlock + ((int) Math.floor((double) sortedBlockNumbers.size() / numCPUThreads)) - 1;
long tmpBlockSize = 0;
int tmpEnd2 = 0;
for (int b = startBlock; b <= tmpEnd; b++) {
tmpBlockSize += sortedBlockSizes.get(b);
if (tmpBlockSize > blockSizePerThread) {
tmpEnd2 = b-1;
break;
}
}
if (tmpEnd2 > 0) {
tmpEnd = tmpEnd2;
}
if (l + 1 == numCPUThreads && tmpEnd < sortedBlockNumbers.size()-1) {
tmpEnd = sortedBlockNumbers.size()-1;
}
final int endBlock = tmpEnd;
placeholder = endBlock + 1;
//System.err.println(binSize + " " + blockNumbers.size() + " " + sortedBlockNumbers.length + " " + startBlock + " " + endBlock);
if (startBlock > endBlock) {
blockChunkSizes.put(threadNum,(long) 0);
continue;
}
List<IndexEntry> indexEntries = new ArrayList<>();
final ExpectedValueCalculation calc;
if (calculateExpecteds) {
calc = new ExpectedValueCalculation(chromosomeHandler, zoom.getBinSize(), fragmentCountMap, NormalizationHandler.NONE);
} else {
calc = null;
}

Runnable worker = new Runnable() {
@Override
public void run() {
try {
calculateExpectedBlockChunk(startBlock, endBlock, sortedBlockNumbers, calc);
} catch (Exception e) {
e.printStackTrace();
}
if (calculateExpecteds) {
localExpectedValueCalculations.put(threadNum, calc);
}
}
};
executor.execute(worker);
}
executor.shutdown();

// Wait until all threads finish
while (!executor.isTerminated()) {
try {
Thread.sleep(50);
} catch (InterruptedException e) {
System.err.println(e.getLocalizedMessage());
}
}


for (int l = 0; l < numCPUThreads; l++) {
if (calculateExpecteds) {
ExpectedValueCalculation tmpCalc = localExpectedValueCalculations.get(l);
if (tmpCalc != null) {
zoomCalc.merge(tmpCalc);
}
tmpCalc = null;

}
}


return zoomCalc;
}
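
// Illustrative, standalone sketch (hypothetical class, not part of this commit):
// a toy run of the chunking rule used in computeExpected above. Each thread takes
// a nominal share of consecutive blocks, is cut short once its cumulative byte
// size exceeds roughly 1.5x the per-thread average, and the last thread absorbs
// whatever remains.
class ChunkingSketch {
    public static void main(String[] args) {
        java.util.List<Integer> blockSizes = java.util.Arrays.asList(10, 10, 10, 100, 10, 10);
        int numCPUThreads = 2;
        long totalBlockSize = 150;
        long cap = (long) Math.floor(1.5 * Math.floor((double) totalBlockSize / numCPUThreads)); // 112
        int placeholder = 0;
        for (int l = 0; l < numCPUThreads; l++) {
            int startBlock = (l == 0) ? 0 : placeholder;
            int tmpEnd = startBlock + (int) Math.floor((double) blockSizes.size() / numCPUThreads) - 1;
            long running = 0;
            int cappedEnd = 0;
            for (int b = startBlock; b <= tmpEnd; b++) {
                running += blockSizes.get(b);
                if (running > cap) {
                    cappedEnd = b - 1;
                    break;
                }
            }
            if (cappedEnd > 0) {
                tmpEnd = cappedEnd;
            }
            if (l + 1 == numCPUThreads && tmpEnd < blockSizes.size() - 1) {
                tmpEnd = blockSizes.size() - 1; // last thread picks up the tail
            }
            placeholder = tmpEnd + 1;
            System.out.println("thread " + l + ": blocks " + startBlock + ".." + tmpEnd);
        }
        // prints "thread 0: blocks 0..2" and "thread 1: blocks 3..5"
    }
}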

private void calculateExpectedBlockChunk(int startBlock, int endBlock, List<Integer> sortedBlockNumbers, ExpectedValueCalculation calc ) throws IOException{
//System.err.println(threadBlocks.length);
for (int i = startBlock; i <= endBlock; i++) {

Block currentBlock = reader.readNormalizedBlock(sortedBlockNumbers.get(i), this, NormalizationHandler.NONE);
int num = sortedBlockNumbers.get(i);

if (currentBlock != null) {
calcExpectedBlock(currentBlock, calc);
}
currentBlock.clear();
//System.err.println("Used Memory after writing block " + i);
//System.err.println(Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory());
}
}

protected void calcExpectedBlock(Block block, ExpectedValueCalculation calc) throws IOException {

final List<ContactRecord> records = block.getContactRecords();// getContactRecords();

// System.out.println("Write contact records : records count = " + records.size());

// Count records first
int nRecords = records.size();


// Find extents of occupied cells
int binXOffset = Integer.MAX_VALUE;
int binYOffset = Integer.MAX_VALUE;
int binXMax = 0;
int binYMax = 0;
for (ContactRecord entry : records) {
binXOffset = Math.min(binXOffset, entry.getBinX());
binYOffset = Math.min(binYOffset, entry.getBinY());
binXMax = Math.max(binXMax, entry.getBinX());
binYMax = Math.max(binYMax, entry.getBinY());
}


// Sort keys in row-major order
records.sort(new Comparator<ContactRecord>() {
@Override
public int compare(ContactRecord o1, ContactRecord o2) {
if (o1.getBinY() != o2.getBinY()) {
return o1.getBinY() - o2.getBinY();
} else {
return o1.getBinX() - o2.getBinX();
}
}
});
ContactRecord lastRecord = records.get(records.size() - 1);
final short w = (short) (binXMax - binXOffset + 1);
final int w1 = binXMax - binXOffset + 1;
final int w2 = binYMax - binYOffset + 1;

boolean isInteger = true;
float maxCounts = 0;

for (ContactRecord record : records) {
float counts = record.getCounts();

if (counts >= 0) {

isInteger = isInteger && (Math.floor(counts) == counts);
maxCounts = Math.max(counts, maxCounts);

final int px = record.getBinX() - binXOffset;
final int py = record.getBinY() - binYOffset;
if (calc != null && chr1.getIndex() == chr2.getIndex()) {
calc.addDistance(chr1.getIndex(), record.getBinX(), record.getBinY(), counts);
}

}
}
}

}
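One behavioral change worth noting: both writeBlock and the new calcExpectedBlock now feed ExpectedValueCalculation.addDistance the absolute bin coordinates (record.getBinX(), record.getBinY()) rather than the block-local px/py offsets. Since binXOffset and binYOffset generally differ, offset-subtracted coordinates do not preserve the genomic bin distance. Below is a minimal sketch of the discrepancy, assuming addDistance derives its distance bin from the difference of the two coordinates (the class below is hypothetical, for illustration only):

// Illustrative sketch only: per-axis block offsets distort the apparent bin distance.
public class ExpectedDistanceSketch {
    public static void main(String[] args) {
        int binX = 1005, binY = 1030;              // absolute bins: true distance is 25
        int binXOffset = 1000, binYOffset = 1020;  // block extents differ per axis
        int px = binX - binXOffset;                // 5
        int py = binY - binYOffset;                // 10
        System.out.println(Math.abs(binY - binX)); // 25 -> distance the accumulator should see
        System.out.println(Math.abs(py - px));     // 5  -> distance implied by the old px/py call
    }
}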
2 changes: 1 addition & 1 deletion src/juicebox/gui/SuperAdapter.java
@@ -1078,7 +1078,7 @@ public void run() {
LoadDialog.LAST_LOADED_HIC_FILE_PATH = files[0];


CustomNormVectorFileHandler.unsafeHandleUpdatingOfNormalizations(SuperAdapter.this, files, isControl);
CustomNormVectorFileHandler.unsafeHandleUpdatingOfNormalizations(SuperAdapter.this, files, isControl, 1);

boolean versionStatus = hic.getDataset().getVersion() >= HiCGlobals.minVersion;
if (isControl) {
2 changes: 1 addition & 1 deletion src/juicebox/tools/clt/old/AddNorm.java
@@ -132,7 +132,7 @@ public void run() {
HiCGlobals.allowDynamicBlockIndex = false;
try {
if (inputVectorFile != null) {
CustomNormVectorFileHandler.updateHicFile(file, inputVectorFile);
CustomNormVectorFileHandler.updateHicFile(file, inputVectorFile, numCPUThreads);
} else {
launch(file, normalizationTypes, genomeWideResolution, noFragNorm,
numCPUThreads, resolutionsToBuildTo);