Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3729] Add roll reorg operations in FED #2126

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,38 @@ public Future<FederatedResponse>[] executeMultipleSlices(long tid, boolean wait,
return ret.toArray(new Future[0]);
}

public Future<FederatedResponse>[] executeRoll(long tid, boolean wait, FederatedRequest frEnd,
FederatedRequest frStart, long rlen) {
// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
setThreadID(tid, new FederatedRequest[]{frStart, frEnd});
List<Future<FederatedResponse>> ret = new ArrayList<>();

for(Pair<FederatedRange, FederatedData> e : _fedMap) {
if (e.getKey().getEndDims()[0] == rlen) {
ret.add(e.getValue().executeFederatedOperation(frEnd));
} else if (e.getKey().getBeginDims()[0] == 0){
ret.add(e.getValue().executeFederatedOperation(frStart));
}
}

// prepare results (future federated responses), with optional wait to ensure the
// order of requests without data dependencies (e.g., cleanup RPCs)
if(wait)
FederationUtils.waitFor(ret);
return ret.toArray(new Future[0]);
}

public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
if(!isInitialized())
throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");

List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
for(Pair<FederatedRange, FederatedData> e : _fedMap)

for(Pair<FederatedRange, FederatedData> e : _fedMap){
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, e.getValue().getVarID());
readResponses.add(Pair.of(e.getKey(), e.getValue().executeFederatedOperation(request)));
}

return readResponses;
}

Expand Down Expand Up @@ -692,6 +716,7 @@ public void reverseFedMap() {
}
}


private static class MappingTask implements Callable<Void> {
private final FederatedRange _range;
private final FederatedData _data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
String2FEDInstructionType.put( "rev" , FEDType.Reorg );
String2FEDInstructionType.put( "roll" , FEDType.Reorg );
//String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
//String2FEDInstructionType.put( "rsort" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand c
* @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
super(CPType.Reorg, op, in, out, opcode, istr);
super(CPType.Reorg, op, in, shift, out, opcode, istr);
_col = null;
_desc = null;
_ixret = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
Expand All @@ -57,6 +59,8 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

public class ReorgFEDInstruction extends UnaryFEDInstruction {
// roll-specific attributes
private CPOperand _shift = null;

public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
Expand All @@ -66,14 +70,29 @@ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opc
super(FEDType.Reorg, op, in1, out, opcode, istr);
}

private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in, shift, out, opcode, istr, fedOut);
_shift = shift;
}

public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
if (rinst.input2 != null) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
} else{
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
}
}

public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
if (rinst.input2 != null) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
} else{
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
}
}

public static ReorgFEDInstruction parseInstruction(String str) {
Expand Down Expand Up @@ -105,6 +124,15 @@ else if(opcode.equalsIgnoreCase("rev")) {
return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str,
fedOut);
}
else if (opcode.equalsIgnoreCase("roll")) {
InstructionUtils.checkNumFields(str, 3);
in.split(parts[1]);
out.split(parts[3]);
CPOperand shift = new CPOperand(parts[2]);
fedOut = parseFedOutFlag(str, 3);
return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)),
in, out, shift, opcode, str, fedOut);
}
else {
throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode);
}
Expand Down Expand Up @@ -167,6 +195,39 @@ else if(instOpcode.equalsIgnoreCase("rev")) {
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));

optionalForceLocal(out);
} else if (instOpcode.equalsIgnoreCase("roll")) {
long rlen = mo1.getNumRows();
long shift = ec.getScalarInput(_shift).getLongValue();
shift %= (rlen != 0 ? rlen : 1); // roll matrix with axis=none

long inID = mo1.getFedMapping().getID();
long outEndID = FederationUtils.getNextFedDataID();
long outStartID = FederationUtils.getNextFedDataID();

List<Pair<FederatedRange, FederatedData>> inMap = mo1.getFedMapping().getMap();
Pair<FederationMap, Long> rollResult = rollFedMap(inMap, inID, outEndID, outStartID, shift,
rlen, mo1.getFedMapping().getType());
long length = rollResult.getValue();
FederationMap outFedMap = rollResult.getKey();

FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID,
new ReorgFEDInstruction.SliceMatrix(inID, outEndID, length, true));
// FederatedRequest frCopy = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID,
// new ReorgFEDInstruction.SliceMatrix(inID, outEndID, 0, true));
FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outStartID,
new ReorgFEDInstruction.SliceMatrix(inID, outStartID, length, false));
Future<FederatedResponse>[] ffr = outFedMap.executeRoll(getTID(), true, frEnd, frStart, rlen);
// Future<FederatedResponse>[] ffr = outFedMap.executeRoll(getTID(), true, frEnd, frStart, frCopy, rlen);

//derive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr);
out.getDataCharacteristics()
.setDimension(mo1.getNumRows(), mo1.getNumColumns())
.setBlocksize(mo1.getBlocksize())
.setNonZeros(nnz);
out.setFedMapping(outFedMap);
optionalForceLocal(out);
}
else if (instOpcode.equals("rdiag")) {
Expand All @@ -189,6 +250,40 @@ else if (instOpcode.equals("rdiag")) {
}
}


public Pair<FederationMap, Long> rollFedMap(List<Pair<FederatedRange, FederatedData>> oldMap, long inID,
long outEndID, long outStartID, long shift, long rlen, FType type) {
List<Pair<FederatedRange, FederatedData>> map = new ArrayList<>();
long length = 0;

for(Map.Entry<FederatedRange, FederatedData> e : oldMap) {
if(e.getKey().getSize() == 0) continue;
FederatedRange fedRange = new FederatedRange(e.getKey());
long beginRow = fedRange.getBeginDims()[0] + shift;
long endRow = fedRange.getEndDims()[0] + shift;

beginRow = beginRow > rlen ? beginRow - rlen : beginRow;
endRow = endRow > rlen ? endRow - rlen : endRow;

if (beginRow < endRow) {
fedRange.setBeginDim(0, beginRow);
fedRange.setEndDim(0, endRow);
map.add(Pair.of(fedRange, e.getValue().copyWithNewID(inID)));
} else {
length = rlen - beginRow;
fedRange.setBeginDim(0, beginRow);
fedRange.setEndDim(0, rlen);
map.add(Pair.of(fedRange, e.getValue().copyWithNewID(outEndID)));

FederatedRange startRange = new FederatedRange(fedRange);
startRange.setBeginDim(0, 0);
startRange.setEndDim(0, endRow);
map.add(Pair.of(startRange, e.getValue().copyWithNewID(outStartID)));
}
}
return Pair.of(new FederationMap(outEndID, map, type), length);
}

/**
* Update the federated ranges of result and return the updated federation map.
* @param result RdiagResult for which the fedmap is updated
Expand Down Expand Up @@ -307,6 +402,51 @@ private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) {
return new RdiagResult(diagFedMap, dcs);
}

public static class SliceMatrix extends FederatedUDF {
private static final long serialVersionUID = -3466926635958851402L;
private final long _outputID;
private final int _sliceRow;
private final boolean _isRight;

private SliceMatrix(long input, long outputID, long sliceRow, boolean isRight) {
super(new long[] {input});
_outputID = outputID;
_sliceRow = (int) sliceRow;
_isRight = isRight;
}

@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixBlock oriBlock = ((MatrixObject) data[0]).acquireReadAndRelease();
MatrixBlock resBlock;

if (_sliceRow != 0){
if (_isRight){
resBlock = oriBlock.slice(0, _sliceRow-1, 0,
oriBlock.getNumColumns()-1, new MatrixBlock());
} else{
resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1,
0, oriBlock.getNumColumns()-1, new MatrixBlock());
}
} else{
resBlock = oriBlock;
}
ec.setMatrixOutput(String.valueOf(_outputID), resBlock);
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, resBlock);
}

@Override
public List<Long> getOutputIds() {
return new ArrayList<>(Arrays.asList(_outputID));
}

@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
return Pair.of(String.valueOf(_outputID),
new LineageItem());
}
}

public static class Rdiag extends FederatedUDF {

private static final long serialVersionUID = -3466926635958851402L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, Exec
}
}
else if(inst instanceof ReorgCPInstruction &&
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
|| inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
CacheableData<?> mo = ec.getCacheableData(rinst.input1);

Expand Down Expand Up @@ -157,7 +158,8 @@ else if(inst instanceof AggregateUnarySPInstruction) {
return AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
}
else if(inst instanceof ReorgSPInstruction &&
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
|| inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d
}

private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
this(op, in, out, opcode, istr);
super(SPType.Reorg, op, in, shift, null, out, opcode, istr);
_shift = shift;
}

Expand Down
Loading
Loading