Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expected calculation and subsampling #989

Merged
merged 1 commit into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/juicebox/data/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Random;


/**
Expand Down Expand Up @@ -65,6 +66,22 @@ public List<ContactRecord> getContactRecords() {
return records;
}

/**
 * Returns a randomly down-sampled copy of this block's contact records.
 * <p>
 * Each record's count is truncated to an {@code int}, and every unit contact is kept
 * independently when {@code randomSubsampleGenerator.nextDouble() <= subsampleFraction}.
 * The returned list holds one new record per stored record, in the same order and with the
 * same bin coordinates; a record whose contacts were all dropped is still returned with a
 * count of 0.
 *
 * @param subsampleFraction        fraction of contacts to keep; for values outside (0, 1]
 *                                 (including NaN) nothing is kept, every returned count is 0
 *                                 and the generator is never consulted
 * @param randomSubsampleGenerator source of the uniform draws, one draw per unit contact
 * @return a new list of subsampled records
 */
public List<ContactRecord> getContactRecords(double subsampleFraction, Random randomSubsampleGenerator) {
    // The range check depends on neither the record nor the draw, so evaluate it once.
    final boolean validFraction = subsampleFraction <= 1 && subsampleFraction > 0;
    final List<ContactRecord> newRecords = new ArrayList<>(records.size());
    for (ContactRecord record : records) {
        int newCounts = 0;
        if (validFraction) {
            final int originalCounts = (int) record.getCounts();
            for (int j = 0; j < originalCounts; j++) {
                if (randomSubsampleGenerator.nextDouble() <= subsampleFraction) {
                    newCounts++;
                }
            }
        }
        // NOTE(review): records that end up with 0 counts are still emitted, as before.
        newRecords.add(new ContactRecord(record.getBinX(), record.getBinY(), (float) newCounts));
    }
    return newRecords;
}

/**
 * Removes every contact record held by this block.
 */
public void clear() {
    this.records.clear();
}
Expand Down
207 changes: 197 additions & 10 deletions src/juicebox/data/MatrixZoomData.java
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,8 @@ public IteratorContainer getFromFileIteratorContainer() {

// Merge and write out blocks one at a time.
public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(LittleEndianOutputStream los, Deflater compressor,
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler) throws IOException {
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler,
double subsampleFraction, Random randomSubsampleGenerator) throws IOException {

List<Integer> sortedBlockNumbers = reader.getBlockNumbers(this);
List<IndexEntry> indexEntries = new ArrayList<>();
Expand All @@ -1268,7 +1269,7 @@ public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(Littl

// Output block
long position = los.getWrittenCount();
writeBlock(currentBlock, los, compressor, zoomCalc);
writeBlock(currentBlock, los, compressor, zoomCalc, subsampleFraction, randomSubsampleGenerator);
long size = los.getWrittenCount() - position;

indexEntries.add(new IndexEntry(num, position, (int) size));
Expand All @@ -1280,7 +1281,8 @@ public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(Littl

// Merge and write out blocks multithreaded.
public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(LittleEndianOutputStream[] losArray, Deflater compressor, int whichZoom, int numResolutions,
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler) {
boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler,
double subsampleFraction, Random randomSubsampleGenerator) {
List<Integer> sortedBlockNumbers = reader.getBlockNumbers(this);
int numCPUThreads = (losArray.length - 1) / numResolutions;
List<Integer> sortedBlockSizes = new ArrayList<>();
Expand Down Expand Up @@ -1348,7 +1350,7 @@ public Pair<List<IndexEntry>,ExpectedValueCalculation> mergeAndWriteBlocks(Littl
@Override
public void run() {
try {
writeBlockChunk(startBlock, endBlock, sortedBlockNumbers, losArray, whichLos, indexEntries, calc);
writeBlockChunk(startBlock, endBlock, sortedBlockNumbers, losArray, whichLos, indexEntries, calc, subsampleFraction, randomSubsampleGenerator);
} catch (Exception e) {
e.printStackTrace();
}
Expand Down Expand Up @@ -1407,7 +1409,7 @@ public void run() {
}

private void writeBlockChunk(int startBlock, int endBlock, List<Integer> sortedBlockNumbers, LittleEndianOutputStream[] losArray,
int threadNum, List<IndexEntry> indexEntries, ExpectedValueCalculation calc ) throws IOException{
int threadNum, List<IndexEntry> indexEntries, ExpectedValueCalculation calc, double subsampleFraction, Random randomSubsampleGenerator) throws IOException{
Deflater compressor = new Deflater();
compressor.setLevel(Deflater.DEFAULT_COMPRESSION);
//System.err.println(threadBlocks.length);
Expand All @@ -1418,7 +1420,7 @@ private void writeBlockChunk(int startBlock, int endBlock, List<Integer> sortedB

if (currentBlock != null) {
long position = losArray[threadNum + 1].getWrittenCount();
writeBlock(currentBlock, losArray[threadNum + 1], compressor, calc);
writeBlock(currentBlock, losArray[threadNum + 1], compressor, calc, subsampleFraction, randomSubsampleGenerator);
long size = losArray[threadNum + 1].getWrittenCount() - position;
indexEntries.add(new IndexEntry(num, position, (int) size));
}
Expand All @@ -1434,10 +1436,14 @@ private void writeBlockChunk(int startBlock, int endBlock, List<Integer> sortedB
* @param block Block to write
* @throws IOException
*/
protected void writeBlock(Block block, LittleEndianOutputStream los, Deflater compressor, ExpectedValueCalculation calc) throws IOException {

final List<ContactRecord> records = block.getContactRecords();// getContactRecords();
protected void writeBlock(Block block, LittleEndianOutputStream los, Deflater compressor, ExpectedValueCalculation calc, double subsampleFraction, Random randomSubsampleGenerator) throws IOException {

final List<ContactRecord> records;
if (subsampleFraction < 1) {
records = block.getContactRecords(subsampleFraction, randomSubsampleGenerator);// getContactRecords();
} else {
records = block.getContactRecords();
}
// System.out.println("Write contact records : records count = " + records.size());

// Count records first
Expand Down Expand Up @@ -1492,7 +1498,7 @@ public int compare(ContactRecord o1, ContactRecord o2) {
final int px = record.getBinX() - binXOffset;
final int py = record.getBinY() - binYOffset;
if (calc != null && chr1.getIndex() == chr2.getIndex()) {
calc.addDistance(chr1.getIndex(), px, py, counts);
calc.addDistance(chr1.getIndex(), record.getBinX(), record.getBinY(), counts);
}
List<ContactRecord> row = rows.get(py);
if (row == null) {
Expand Down Expand Up @@ -1636,4 +1642,185 @@ protected byte[] compress(byte[] data, Deflater compressor) {

return bos.toByteArray();
}

/**
 * Builds the expected-value calculation for this {@code MatrixZoomData} by scanning all of its
 * blocks with a pool of worker threads.
 * <p>
 * The sorted block list is cut into at most {@code numCPUThreads} contiguous chunks. Each chunk
 * is accumulated into its own {@code ExpectedValueCalculation} by one worker, and the per-chunk
 * results are merged in chunk order after every worker has finished.
 *
 * @param calculateExpecteds when {@code false} nothing is read or computed and {@code null} is returned
 * @param fragmentCountMap   passed through to every {@code ExpectedValueCalculation} created here
 * @param chromosomeHandler  passed through to every {@code ExpectedValueCalculation} created here
 * @param numCPUThreads      size of the worker pool and maximum number of chunks
 * @return the merged calculation, or {@code null} when {@code calculateExpecteds} is {@code false}
 */
public ExpectedValueCalculation computeExpected(boolean calculateExpecteds, Map<String, Integer> fragmentCountMap, ChromosomeHandler chromosomeHandler, int numCPUThreads) {
    // With no calculation to feed there is no reason to start threads or read any block.
    if (!calculateExpecteds) {
        return null;
    }

    final List<Integer> sortedBlockNumbers = reader.getBlockNumbers(this);
    final List<Integer> sortedBlockSizes = new ArrayList<>(sortedBlockNumbers.size());
    long totalBlockSize = 0;
    for (Integer blockNumber : sortedBlockNumbers) {
        int blockSize = reader.getBlockSize(this, blockNumber);
        sortedBlockSizes.add(blockSize);
        totalBlockSize += blockSize;
    }

    final ExecutorService executor = Executors.newFixedThreadPool(numCPUThreads);
    final Map<Integer, ExpectedValueCalculation> localExpectedValueCalculations = new ConcurrentHashMap<>(numCPUThreads);
    final ExpectedValueCalculation zoomCalc = new ExpectedValueCalculation(chromosomeHandler, zoom.getBinSize(), fragmentCountMap, NormalizationHandler.NONE);

    // Size budget per chunk: 1.5x an even share of the total block size (loop-invariant).
    final long blockSizePerThread = (long) Math.floor(1.5 * (long) Math.floor(totalBlockSize / numCPUThreads));
    // Nominal number of blocks per chunk before the size budget is applied.
    final int blocksPerThread = (int) Math.floor((double) sortedBlockNumbers.size() / numCPUThreads);

    try {
        int nextStart = 0;
        for (int l = 0; l < numCPUThreads; l++) {
            final int threadNum = l;
            final int startBlock = nextStart;
            int tmpEnd = startBlock + blocksPerThread - 1;

            // Cut the chunk short at the first block that pushes it over the size budget.
            // NOTE(review): tmpEnd2 == 0 doubles as "budget not exceeded", so a cut landing at
            // b - 1 <= 0 is ignored, and a cut at startBlock - 1 leaves this chunk empty.
            // Kept as-is so the partitioning is unchanged.
            long tmpBlockSize = 0;
            int tmpEnd2 = 0;
            for (int b = startBlock; b <= tmpEnd; b++) {
                tmpBlockSize += sortedBlockSizes.get(b);
                if (tmpBlockSize > blockSizePerThread) {
                    tmpEnd2 = b - 1;
                    break;
                }
            }
            if (tmpEnd2 > 0) {
                tmpEnd = tmpEnd2;
            }
            // The last chunk absorbs every block that is still unassigned.
            if (l + 1 == numCPUThreads && tmpEnd < sortedBlockNumbers.size() - 1) {
                tmpEnd = sortedBlockNumbers.size() - 1;
            }
            final int endBlock = tmpEnd;
            nextStart = endBlock + 1;
            if (startBlock > endBlock) {
                // Empty chunk: nothing to submit.
                continue;
            }

            final ExpectedValueCalculation calc = new ExpectedValueCalculation(chromosomeHandler, zoom.getBinSize(), fragmentCountMap, NormalizationHandler.NONE);
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        calculateExpectedBlockChunk(startBlock, endBlock, sortedBlockNumbers, calc);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    // Stored even after a failure, so whatever the chunk accumulated is still merged.
                    localExpectedValueCalculations.put(threadNum, calc);
                }
            });
        }
    } finally {
        // Always release the pool, even if setting up a chunk throws.
        executor.shutdown();
    }

    // Block until all workers finish. An interrupt does not cancel the wait (same as before),
    // but the interrupt status is restored for the caller afterwards.
    boolean interrupted = false;
    while (!executor.isTerminated()) {
        try {
            executor.awaitTermination(Long.MAX_VALUE, java.util.concurrent.TimeUnit.NANOSECONDS);
        } catch (InterruptedException e) {
            System.err.println(e.getLocalizedMessage());
            interrupted = true;
        }
    }
    if (interrupted) {
        Thread.currentThread().interrupt();
    }

    for (int l = 0; l < numCPUThreads; l++) {
        ExpectedValueCalculation tmpCalc = localExpectedValueCalculations.get(l);
        if (tmpCalc != null) {
            zoomCalc.merge(tmpCalc);
        }
    }

    return zoomCalc;
}

/**
 * Reads the blocks at positions {@code startBlock..endBlock} (inclusive) of
 * {@code sortedBlockNumbers} and feeds each one into {@code calc}.
 * <p>
 * Blocks are read with {@code NormalizationHandler.NONE}. A block that the reader returns as
 * {@code null} is skipped; every other block is cleared once it has been processed.
 *
 * @param startBlock         index into {@code sortedBlockNumbers} of the first block to process
 * @param endBlock           index into {@code sortedBlockNumbers} of the last block to process
 * @param sortedBlockNumbers block numbers to look up by index
 * @param calc               accumulator handed to {@code calcExpectedBlock}
 * @throws IOException if reading or processing a block fails
 */
private void calculateExpectedBlockChunk(int startBlock, int endBlock, List<Integer> sortedBlockNumbers, ExpectedValueCalculation calc ) throws IOException{
    for (int i = startBlock; i <= endBlock; i++) {
        Block currentBlock = reader.readNormalizedBlock(sortedBlockNumbers.get(i), this, NormalizationHandler.NONE);
        if (currentBlock == null) {
            // Nothing was read for this block number, so there is nothing to add or to clear.
            continue;
        }
        calcExpectedBlock(currentBlock, calc);
        currentBlock.clear();
    }
}

/**
 * Adds the contacts of one block to an expected-value calculation.
 * <p>
 * When {@code calc} is {@code null}, when {@code chr1} and {@code chr2} have different indices,
 * or when the block has no records, the method returns without touching the block. Otherwise
 * the block's record list is sorted in place by (binY, binX) and every record with a count
 * {@code >= 0} is passed to {@code calc.addDistance} with its own bin coordinates.
 *
 * @param block block whose records are accumulated; its record list may be reordered
 * @param calc  accumulator to update; may be {@code null}, in which case nothing happens
 * @throws IOException kept in the signature for callers and overriders; this body does not throw it
 */
protected void calcExpectedBlock(Block block, ExpectedValueCalculation calc) throws IOException {
    // Neither condition depends on the record, so decide once instead of per record.
    if (calc == null || chr1.getIndex() != chr2.getIndex()) {
        return;
    }

    final List<ContactRecord> records = block.getContactRecords();
    if (records.isEmpty()) {
        return;
    }

    // Row-major order, as before, so records reach calc in the same sequence.
    // NOTE(review): the sort may be unnecessary if addDistance is order-independent - confirm before removing.
    records.sort(Comparator.comparingInt(ContactRecord::getBinY).thenComparingInt(ContactRecord::getBinX));

    final int chrIndex = chr1.getIndex();
    for (ContactRecord record : records) {
        final float counts = record.getCounts();
        // The comparison also filters out NaN counts.
        if (counts >= 0) {
            calc.addDistance(chrIndex, record.getBinX(), record.getBinY(), counts);
        }
    }
}

}
2 changes: 1 addition & 1 deletion src/juicebox/gui/SuperAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ public void run() {
LoadDialog.LAST_LOADED_HIC_FILE_PATH = files[0];


CustomNormVectorFileHandler.unsafeHandleUpdatingOfNormalizations(SuperAdapter.this, files, isControl);
CustomNormVectorFileHandler.unsafeHandleUpdatingOfNormalizations(SuperAdapter.this, files, isControl, 1);

boolean versionStatus = hic.getDataset().getVersion() >= HiCGlobals.minVersion;
if (isControl) {
Expand Down
2 changes: 1 addition & 1 deletion src/juicebox/tools/clt/old/AddNorm.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public void run() {
HiCGlobals.allowDynamicBlockIndex = false;
try {
if (inputVectorFile != null) {
CustomNormVectorFileHandler.updateHicFile(file, inputVectorFile);
CustomNormVectorFileHandler.updateHicFile(file, inputVectorFile, numCPUThreads);
} else {
launch(file, normalizationTypes, genomeWideResolution, noFragNorm,
numCPUThreads, resolutionsToBuildTo);
Expand Down
Loading