Skip to content

Commit

Permalink
[SYSTEMDS-2603] New hybrid approach for lineage deduplication
Browse files Browse the repository at this point in the history
This patch makes a major refactoring of the lineage deduplication
framework. This changes the design of tracing all the
distinct paths in a loop-body before the first iteration, to trace
during execution. The number of distinct paths grows exponentially
with the number of control flow statements. Tracing all the paths
in advance can be a huge waste and overhead.

We now trace an iteration during execution. We count the number of
distinct paths before the iterations start, and we stop tracing
once all the paths are traced. Tracing during execution fits
very well with our multi-level reuse infrastructure.

Refer to JIRA for detailed discussions.
  • Loading branch information
phaniarnab committed Aug 8, 2020
1 parent a4f992e commit 1101533
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageDedupUtils;

public class ForProgramBlock extends ProgramBlock
{
Expand Down Expand Up @@ -115,9 +116,9 @@ public void execute(ExecutionContext ec) {
// prepare update in-place variables
UpdateType[] flags = prepareUpdateInPlaceVariables(ec, _tid);

// compute lineage patches for all distinct paths, and store globally
// compute and store the number of distinct paths
if (DMLScript.LINEAGE_DEDUP)
ec.getLineage().computeDedupBlock(this, ec);
ec.getLineage().initializeDedupBlock(this, ec);

// run for loop body for each instance of predicate sequence
SequenceIterator seqIter = new SequenceIterator(from, to, incr);
Expand All @@ -131,17 +132,22 @@ public void execute(ExecutionContext ec) {
Lineage li = ec.getLineage();
li.set(_iterPredVar, li.getOrCreate(new CPOperand(iterVar)));
}
if (DMLScript.LINEAGE_DEDUP)
// create a new dedup map, if needed, to trace this iteration
ec.getLineage().createDedupPatch(this, ec);

//execute all child blocks
for (int i = 0; i < _childBlocks.size(); i++) {
for (int i = 0; i < _childBlocks.size(); i++)
_childBlocks.get(i).execute(ec);
}

if( DMLScript.LINEAGE_DEDUP )
if (DMLScript.LINEAGE_DEDUP) {
LineageDedupUtils.replaceLineage(ec);
// hook the dedup map to the main lineage trace
ec.getLineage().traceCurrentDedupPath();
}
}

// clear current LineageDedupBlock
// clear the current LineageDedupBlock
if (DMLScript.LINEAGE_DEDUP)
ec.getLineage().clearDedupBlock();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.lineage.LineageDedupUtils;


public class WhileProgramBlock extends ProgramBlock
Expand Down Expand Up @@ -98,22 +99,29 @@ public void execute(ExecutionContext ec)
// prepare update in-place variables
UpdateType[] flags = prepareUpdateInPlaceVariables(ec, _tid);

// compute lineage patches for all distinct paths, and store globally
// compute and store the number of distinct paths
if (DMLScript.LINEAGE_DEDUP)
ec.getLineage().computeDedupBlock(this, ec);
ec.getLineage().initializeDedupBlock(this, ec);

//run loop body until predicate becomes false
while( executePredicate(ec).getBooleanValue() ) {
if (DMLScript.LINEAGE_DEDUP)
ec.getLineage().resetDedupPath();

if (DMLScript.LINEAGE_DEDUP)
// create a new dedup map, if needed, to trace this iteration
ec.getLineage().createDedupPatch(this, ec);

//execute all child blocks
for (int i=0 ; i < _childBlocks.size() ; i++) {
_childBlocks.get(i).execute(ec);
}

if( DMLScript.LINEAGE_DEDUP )
if (DMLScript.LINEAGE_DEDUP) {
LineageDedupUtils.replaceLineage(ec);
// hook the dedup map to the main lineage trace
ec.getLineage().traceCurrentDedupPath();
}
}

// clear current LineageDedupBlock
Expand Down
32 changes: 31 additions & 1 deletion src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ public Lineage(Lineage that) {
}

public void trace(Instruction inst, ExecutionContext ec) {
if (_activeDedupBlock == null)
if (_activeDedupBlock == null || !_activeDedupBlock.isAllPathsTaken())
_map.trace(inst, ec);
}

public void traceCurrentDedupPath() {
if( _activeDedupBlock != null ) {
long lpath = _activeDedupBlock.getPath();
LineageDedupUtils.setDedupMap(_activeDedupBlock, lpath);
LineageMap lm = _activeDedupBlock.getMap(lpath);
if (lm != null)
_map.processDedupItem(lm, lpath);
Expand All @@ -82,6 +83,14 @@ public LineageItem get(String varName) {
return _map.get(varName);
}

public void setDedupBlock(LineageDedupBlock ldb) {
_activeDedupBlock = ldb;
}

public LineageMap getLineageMap() {
return _map;
}

public void set(String varName, LineageItem li) {
_map.set(varName, li);
}
Expand Down Expand Up @@ -120,11 +129,32 @@ public void computeDedupBlock(ProgramBlock pb, ExecutionContext ec) {
}
_activeDedupBlock = _dedupBlocks.get(pb); //null if invalid
}

public void initializeDedupBlock(ProgramBlock pb, ExecutionContext ec) {
if( !(pb instanceof ForProgramBlock || pb instanceof WhileProgramBlock) )
throw new DMLRuntimeException("Invalid deduplication block: "+ pb.getClass().getSimpleName());
if (!_dedupBlocks.containsKey(pb)) {
// valid only if doesn't contain a nested loop
boolean valid = LineageDedupUtils.isValidDedupBlock(pb, false);
// count distinct paths and store in the dedupblock
_dedupBlocks.put(pb, valid? LineageDedupUtils.initializeDedupBlock(pb, ec) : null);
}
_activeDedupBlock = _dedupBlocks.get(pb); //null if invalid
}

public void createDedupPatch(ProgramBlock pb, ExecutionContext ec) {
if (_activeDedupBlock != null)
LineageDedupUtils.setNewDedupPatch(_activeDedupBlock, pb, ec);
}

public void clearDedupBlock() {
_activeDedupBlock = null;
}

public void clearLineageMap() {
_map.resetLineageMaps();
}

public Map<String,String> serialize() {
Map<String,String> ret = new HashMap<>();
for (Map.Entry<String,LineageItem> e : _map.getTraces().entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class LineageDedupBlock {
private int _numPaths = 0;

private long _activePath = -1;
private ArrayList<Long> _numDistinctPaths = new ArrayList<>();

public LineageMap getActiveMap() {
if (_activePath < 0 || !_distinctPaths.containsKey(_activePath))
Expand All @@ -47,9 +48,11 @@ public LineageMap getActiveMap() {
}

public LineageMap getMap(Long path) {
if (!_distinctPaths.containsKey(path))
throw new DMLRuntimeException("Given path in LineageDedupBlock could not be found.");
return _distinctPaths.get(path);
return _distinctPaths.containsKey(path) ? _distinctPaths.get(path) : null;
}

public void setMap(Long takenPath, LineageMap tracedMap) {
_distinctPaths.put(takenPath, new LineageMap(tracedMap));
}

public boolean pathExists(Long path) {
Expand All @@ -69,6 +72,10 @@ public long getPath() {
_path.toLongArray()[0];
}

public boolean isAllPathsTaken() {
return _distinctPaths.size() == _numDistinctPaths.size();
}

public void traceProgramBlocks(ArrayList<ProgramBlock> pbs, ExecutionContext ec) {
if (_distinctPaths.size() == 0) //main path
_distinctPaths.put(0L, new LineageMap());
Expand Down Expand Up @@ -117,4 +124,36 @@ public void traceBasicProgramBlock(BasicProgramBlock bpb, ExecutionContext ec, C
entry.getValue().trace(inst, ec);
}
}
// compute and save the number of distinct paths
public void setNumPathsInPBs (ArrayList<ProgramBlock> pbs, ExecutionContext ec) {
if (_numDistinctPaths.size() == 0)
_numDistinctPaths.add(0L);
for (ProgramBlock pb : pbs)
numPathsInPB(pb, ec, _numDistinctPaths);
}

private void numPathsInPB(ProgramBlock pb, ExecutionContext ec, ArrayList<Long> paths) {
if (pb instanceof IfProgramBlock)
numPathsInIfPB((IfProgramBlock)pb, ec, paths);
else if (pb instanceof BasicProgramBlock)
return;
else
throw new DMLRuntimeException("Only BasicProgramBlocks or "
+ "IfProgramBlocks are allowed inside a LineageDedupBlock.");
}

private void numPathsInIfPB(IfProgramBlock ipb, ExecutionContext ec, ArrayList<Long> paths) {
ipb.setLineageDedupPathPos(_numPaths++);
ArrayList<Long> rep = new ArrayList<>();
int pathKey = 1 << (_numPaths-1);
for (long p : paths) {
long pathIndex = p | pathKey;
rep.add(pathIndex);
}
_numDistinctPaths.addAll(rep);
for (ProgramBlock pb : ipb.getChildBlocksIfBody())
numPathsInPB(pb, ec, rep);
for (ProgramBlock pb : ipb.getChildBlocksElseBody())
numPathsInPB(pb, ec, paths);
}
}
109 changes: 109 additions & 0 deletions src/main/java/org/apache/sysds/runtime/lineage/LineageDedupUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

package org.apache.sysds.runtime.lineage;

import java.util.ArrayList;

import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
Expand All @@ -27,8 +30,14 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;

public class LineageDedupUtils {
private static Lineage _tmpLineage = null;
private static Lineage _mainLineage = null;
private static ArrayList<Long> _numDistinctPaths = new ArrayList<>();
private static long _maxNumPaths = 0;
private static int _numPaths = 0;

public static boolean isValidDedupBlock(ProgramBlock pb, boolean inLoop) {
// Only the last level loop-body in nested loop structure is valid for deduplication
boolean ret = true; //basic program block
if (pb instanceof FunctionProgramBlock) {
FunctionProgramBlock fsb = (FunctionProgramBlock)pb;
Expand Down Expand Up @@ -64,4 +73,104 @@ public static LineageDedupBlock computeDedupBlock(ProgramBlock fpb, ExecutionCon
ec.getLineage().setInitDedupBlock(null);
return ldb;
}

public static LineageDedupBlock initializeDedupBlock(ProgramBlock fpb, ExecutionContext ec) {
LineageDedupBlock ldb = new LineageDedupBlock();
ec.getLineage().setInitDedupBlock(ldb);
// create/reuse a lineage object to trace the loop iterations
initLocalLineage(ec);
// save the original lineage object
_mainLineage = ec.getLineage();
// count and save the number of distinct paths
ldb.setNumPathsInPBs(fpb.getChildBlocks(), ec);
ec.getLineage().setInitDedupBlock(null);
return ldb;
}

public static void setNewDedupPatch(LineageDedupBlock ldb, ProgramBlock fpb, ExecutionContext ec) {
// no need to trace anymore if all the paths are taken,
// instead reuse the stored maps for this and future interations
// NOTE: this optimization saves redundant tracing, but that
// kills reuse opportunities
if (ldb.isAllPathsTaken())
return;

// copy the input LineageItems of the loop-body
initLocalLineage(ec);
ArrayList<String> inputnames = fpb.getStatementBlock().getInputstoSB();
LineageItem[] liinputs = LineageItemUtils.getLineageItemInputstoSB(inputnames, ec);
// TODO: find the inputs from the ProgramBlock instead of StatementBlock
for (int i=0; i<liinputs.length; i++)
_tmpLineage.set(inputnames.get(i), liinputs[i]);
// also copy the dedupblock to trace the taken path (bitset)
_tmpLineage.setDedupBlock(ldb);
// attach the lineage object to the execution context
ec.setLineage(_tmpLineage);
}

public static void replaceLineage(ExecutionContext ec) {
// replace the local lineage with the original one
ec.setLineage(_mainLineage);
}

public static void setDedupMap(LineageDedupBlock ldb, long takenPath) {
// if this iteration took a new path, store the corresponding map
if (ldb.getMap(takenPath) == null)
ldb.setMap(takenPath, _tmpLineage.getLineageMap());
}

private static void initLocalLineage(ExecutionContext ec) {
_tmpLineage = _tmpLineage == null ? new Lineage() : _tmpLineage;
_tmpLineage.clearLineageMap();
_tmpLineage.clearDedupBlock();
}

/* The below static functions help to compute the number of distinct paths
* in any program block, and are used for diagnostic purposes. These will
* be removed in future.
*/

public static long computeNumPaths(ProgramBlock fpb, ExecutionContext ec) {
if (fpb == null || fpb.getChildBlocks() == null)
return 0;
_numDistinctPaths.clear();
long n = numPathsInPBs(fpb.getChildBlocks(), ec);
if (n > _maxNumPaths) {
_maxNumPaths = n;
System.out.println("\nmax no of paths : " + _maxNumPaths + "\n");
}
return n;
}

public static long numPathsInPBs (ArrayList<ProgramBlock> pbs, ExecutionContext ec) {
if (_numDistinctPaths.size() == 0)
_numDistinctPaths.add(0L);
for (ProgramBlock pb : pbs)
numPathsInPB(pb, ec, _numDistinctPaths);
return _numDistinctPaths.size();
}

private static void numPathsInPB(ProgramBlock pb, ExecutionContext ec, ArrayList<Long> paths) {
if (pb instanceof IfProgramBlock)
numPathsInIfPB((IfProgramBlock)pb, ec, paths);
else if (pb instanceof BasicProgramBlock)
return;
else
return;
}

private static void numPathsInIfPB(IfProgramBlock ipb, ExecutionContext ec, ArrayList<Long> paths) {
ipb.setLineageDedupPathPos(_numPaths++);
ArrayList<Long> rep = new ArrayList<>();
int pathKey = 1 << (_numPaths-1);
for (long p : paths) {
long pathIndex = p | pathKey;
rep.add(pathIndex);
}
_numDistinctPaths.addAll(rep);
for (ProgramBlock pb : ipb.getChildBlocksIfBody())
numPathsInPB(pb, ec, rep);
for (ProgramBlock pb : ipb.getChildBlocksElseBody())
numPathsInPB(pb, ec, paths);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.lineage.LineageItem.LineageItemType;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
Expand Down Expand Up @@ -640,6 +641,9 @@ private static void rSetDedupInputOntoOutput(String name, LineageItem item, Line
if( !tmp.isLiteral() && tmp.getName().equals(name) )
item.getInputs()[i] = dedupInput;
}
if (li.getType() == LineageItemType.Creation) {
item.getInputs()[i] = dedupInput;
}

rSetDedupInputOntoOutput(name, li, dedupInput);
}
Expand Down Expand Up @@ -817,7 +821,7 @@ else if (ins instanceof RandSPInstruction)
}

public static LineageItem[] getLineageItemInputstoSB(ArrayList<String> inputs, ExecutionContext ec) {
if (ReuseCacheType.isNone())
if (ReuseCacheType.isNone() && !DMLScript.LINEAGE_DEDUP)
return null;

ArrayList<CPOperand> CPOpInputs = inputs.size() > 0 ? new ArrayList<>() : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ private void processWriteLI(CPOperand input1, CPOperand input2, ExecutionContext

if (DMLScript.LINEAGE_DEDUP) {
LineageItemUtils.writeTraceToHDFS(Explain.explain(li), fName + ".lineage.dedup");
li = LineageItemUtils.rDecompress(li);
//li = LineageItemUtils.rDecompress(li);
// TODO:gracefully serialize the dedup maps without decompressing
}
LineageItemUtils.writeTraceToHDFS(Explain.explain(li), fName + ".lineage");
else
LineageItemUtils.writeTraceToHDFS(Explain.explain(li), fName + ".lineage");
}
}
Loading

0 comments on commit 1101533

Please sign in to comment.