From 4ff19e1ea1a465a512670bb168d53b08ee82e4c0 Mon Sep 17 00:00:00 2001 From: min-guk Date: Sat, 28 Sep 2024 21:07:36 +0900 Subject: [PATCH 1/2] [SYSTEMDS-3729] Add roll reorg operations in FED --- .../federated/FederationMap.java | 61 ++++++ .../instructions/FEDInstructionParser.java | 1 + .../instructions/cp/ReorgCPInstruction.java | 2 +- .../instructions/fed/ReorgFEDInstruction.java | 115 ++++++++++- .../instructions/fed/UnaryFEDInstruction.java | 6 +- .../primitives/part2/FederatedRollTest.java | 185 ++++++++++++++++++ .../functions/federated/FederatedRollTest.dml | 32 +++ .../federated/FederatedRollTestReference.dml | 26 +++ 8 files changed, 421 insertions(+), 7 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java create mode 100644 src/test/scripts/functions/federated/FederatedRollTest.dml create mode 100644 src/test/scripts/functions/federated/FederatedRollTestReference.dml diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 985fdb056e5..579f2c17b6e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -406,6 +406,29 @@ public Future[] executeMultipleSlices(long tid, boolean wait, return ret.toArray(new Future[0]); } + public Future[] executeRoll(long tid, boolean wait, FederatedRequest fr, FederatedRequest frEnd, + FederatedRequest frStart, FederatedRequest frCopy) { + // executes step1[] - step 2 - ... step4 (only first step federated-data-specific) + setThreadID(tid, new FederatedRequest[]{fr, frCopy, frStart, frEnd}); + List> ret = new ArrayList<>(); + + for(Pair e : _fedMap) { + if (e.getKey().getEndDims()[0] == 100) { + ret.add(e.getValue().executeFederatedOperation(fr, frEnd)); + } else if (e.getKey().getBeginDims()[0] == 0){ + ret.add(e.getValue().executeFederatedOperation(fr, frStart)); + } else{ + ret.add(e.getValue().executeFederatedOperation(fr, frCopy)); + } + } + + // prepare results (future federated responses), with optional wait to ensure the + // order of requests without data dependencies (e.g., cleanup RPCs) + if(wait) + FederationUtils.waitFor(ret); + return ret.toArray(new Future[0]); + } + public List>> requestFederatedData() { if(!isInitialized()) throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData"); @@ -692,6 +715,44 @@ public void reverseFedMap() { } } + public long rollFedMap(long shift, long rlen) { + long length = 0; + + int size = _fedMap.size(); + + for (int i = 0; i < size; i++) { + Pair entry = _fedMap.get(i); + FederatedRange fedRange = entry.getKey(); + + long beginRow = fedRange.getBeginDims()[0] + shift; + long endRow = fedRange.getEndDims()[0] + shift; + + beginRow = beginRow > rlen ? beginRow - rlen : beginRow; + endRow = endRow > rlen ? endRow - rlen : endRow; + + if (beginRow < endRow) { + fedRange.setBeginDim(0, beginRow); + fedRange.setEndDim(0, endRow); + } else { + FederatedData fedData = entry.getValue(); + + // End block + fedRange.setBeginDim(0, beginRow); + fedRange.setEndDim(0, rlen); + length = rlen - beginRow; + + // Start block + FederatedRange startRange = new FederatedRange(fedRange); + startRange.setBeginDim(0, 0); + startRange.setEndDim(0, endRow); + + _fedMap.add(Pair.of(startRange, fedData)); + } + } + + return length; + } + private static class MappingTask implements Callable { private final FederatedRange _range; private final FederatedData _data; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java index f61e86e800b..820d07031d6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java @@ -86,6 +86,7 @@ public class FEDInstructionParser extends InstructionParser String2FEDInstructionType.put( "r'" , FEDType.Reorg ); String2FEDInstructionType.put( "rdiag" , FEDType.Reorg ); String2FEDInstructionType.put( "rev" , FEDType.Reorg ); + String2FEDInstructionType.put( "roll" , FEDType.Reorg ); //String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser! //String2FEDInstructionType.put( "rsort" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser! diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java index e7b3000d52e..ab105a95855 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java @@ -86,7 +86,7 @@ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand c * @param istr ? */ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) { - super(CPType.Reorg, op, in, out, opcode, istr); + super(CPType.Reorg, op, in, shift, out, opcode, istr); _col = null; _desc = null; _ixret = null; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java index c10ca272593..93b1a105846 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java @@ -43,6 +43,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.functionobjects.DiagIndex; import org.apache.sysds.runtime.functionobjects.RevIndex; +import org.apache.sysds.runtime.functionobjects.RollIndex; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -57,6 +58,8 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; public class ReorgFEDInstruction extends UnaryFEDInstruction { + // roll-specific attributes + private CPOperand _shift = null; public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) { super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut); @@ -66,14 +69,29 @@ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opc super(FEDType.Reorg, op, in1, out, opcode, istr); } + private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode, String istr) { + super(FEDType.Reorg, op, in, shift, out, opcode, istr); + _shift = shift; + } + public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) { - return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), - rinst.getInstructionString(), FederatedOutput.NONE); + if (rinst.input2 != null) { + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(), + rinst.getInstructionString()); + } else{ + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), + rinst.getInstructionString(), FederatedOutput.NONE); + } } public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) { - return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), - rinst.getInstructionString(), FederatedOutput.NONE); + if (rinst.input2 != null) { + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(), + rinst.getInstructionString()); + } else{ + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), + rinst.getInstructionString(), FederatedOutput.NONE); + } } public static ReorgFEDInstruction parseInstruction(String str) { @@ -105,6 +123,14 @@ else if(opcode.equalsIgnoreCase("rev")) { return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str, fedOut); } + else if (opcode.equalsIgnoreCase("roll")) { + InstructionUtils.checkNumFields(str, 3); + in.split(parts[1]); + out.split(parts[3]); + CPOperand shift = new CPOperand(parts[2]); + return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)), + in, out, shift, opcode, str); + } else { throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode); } @@ -169,6 +195,41 @@ else if(instOpcode.equalsIgnoreCase("rev")) { optionalForceLocal(out); } + else if (instOpcode.equalsIgnoreCase("roll")) { + long inID = mo1.getFedMapping().getID(); + long outID = FederationUtils.getNextFedDataID(); + + long rlen = mo1.getNumRows(); + long shift = ec.getScalarInput(_shift).getLongValue(); + shift %= (rlen != 0 ? rlen : 1); // roll matrix with axis=none + + FederationMap outFedMap = mo1.getFedMapping().copyWithNewID(outID); + long length = outFedMap.rollFedMap(shift, rlen); + + FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, outID, + new MatrixCharacteristics(-1, -1), mo1.getDataType()); + + FederatedRequest frCopy = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new ReorgFEDInstruction.SplitRow(mo1.getFedMapping().getID(), outID, 0, false)); + + FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new ReorgFEDInstruction.SplitRow(mo1.getFedMapping().getID(), outID, length, true)); + + FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new ReorgFEDInstruction.SplitRow(inID, outID, length, false)); + + Future[] ffr = outFedMap.executeRoll(getTID(), true, fr, frEnd, frStart, frCopy); + + //derive output federated mapping + MatrixObject out = ec.getMatrixObject(output); + long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr); + out.getDataCharacteristics() + .setDimension(mo1.getNumRows(), mo1.getNumColumns()) + .setBlocksize(mo1.getBlocksize()) + .setNonZeros(nnz); + out.setFedMapping(outFedMap); + optionalForceLocal(out); + } else if (instOpcode.equals("rdiag")) { RdiagResult result; // diag(diag(X)) @@ -307,6 +368,52 @@ private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) { return new RdiagResult(diagFedMap, dcs); } + public static class SplitRow extends FederatedUDF { + private static final long serialVersionUID = -3466926635958851402L; + private final long _outputID; + private final int _sliceRow; + private final boolean _isRight; + + private SplitRow(long input, long outputID, long sliceRow, boolean isRight) { + super(new long[] {input}); + _outputID = outputID; + _sliceRow = (int) sliceRow; + _isRight = isRight; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock oriBlock = ((MatrixObject) data[0]).acquireReadAndRelease(); + MatrixBlock resBlock; + + if (_sliceRow == 0){ + ec.setMatrixOutput(String.valueOf(_outputID), oriBlock); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{oriBlock.getNumRows(), oriBlock.getNumColumns()}); + } + + if (_isRight){ + resBlock = oriBlock.slice(0, _sliceRow-1, 0, oriBlock.getNumColumns()-1, new MatrixBlock()); + ec.setMatrixOutput(String.valueOf(_outputID), resBlock); + } else { + resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1, 0, oriBlock.getNumColumns()-1, new MatrixBlock()); + ec.setMatrixOutput(String.valueOf(_outputID), resBlock); + } + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{resBlock.getNumRows(), resBlock.getNumColumns()}); + } + + @Override + public List getOutputIds() { + return new ArrayList<>(Arrays.asList(_outputID)); + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + return Pair.of(String.valueOf(_outputID), + new LineageItem()); + } + } + public static class Rdiag extends FederatedUDF { private static final long serialVersionUID = -3466926635958851402L; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java index f025983e741..2311a1afe26 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java @@ -88,7 +88,8 @@ public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, Exec } } else if(inst instanceof ReorgCPInstruction && - (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) { + (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") + || inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) { ReorgCPInstruction rinst = (ReorgCPInstruction) inst; CacheableData mo = ec.getCacheableData(rinst.input1); @@ -157,7 +158,8 @@ else if(inst instanceof AggregateUnarySPInstruction) { return AggregateUnaryFEDInstruction.parseInstruction(auinstruction); } else if(inst instanceof ReorgSPInstruction && - (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) { + (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") + || inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) { ReorgSPInstruction rinst = (ReorgSPInstruction) inst; CacheableData mo = ec.getCacheableData(rinst.input1); if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() && diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java new file mode 100644 index 00000000000..2a20c696799 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.primitives.part2; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedRollTest extends AutomatedTestBase { + // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName()); + + private final static String TEST_NAME = "FederatedRollTest"; + + private final static String TEST_DIR = "functions/federated/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRollTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Parameterized.Parameter(2) + public boolean rowPartitioned; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][] {{100, 12, true}, {100, 12, false}}); + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"})); + } + + @Test + public void testRevCP() { + runRevTest(ExecMode.SINGLE_NODE); + } + + @Test + public void testRevSP() { + runRevTest(ExecMode.SPARK); + } + + @Test + public void federatedCompilationRevCP() { + runRevTest(ExecMode.SINGLE_NODE, true); + } + + @Test + public void federatedCompilationRevSP() { + runRevTest(ExecMode.SPARK, true); + } + + private void runRevTest(ExecMode execMode) { + runRevTest(execMode, false); + } + + private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + ExecMode platformOld = rtplatform; + + if(rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int r = rows; + int c = cols / 4; + if(rowPartitioned) { + r = rows / 4; + c = cols; + } + + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + for(int k : new int[] {1, 2, 3}) { + Arrays.fill(X3[k], 0); + } + + MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); + Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); + Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); + Process t4 = startLocalFedWorker(port4); + + + try { + if(!isAlive(t1, t2, t3, t4)) + throw new RuntimeException("Failed starting federated worker"); + rtplatform = execMode; + if(rtplatform == ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), + Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; + + runTest(null); + + OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, + "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")}; + + runTest(null); + + // compare via files + compareResults(0.01, "Stat-DML1", "Stat-DML2"); + + Assert.assertTrue(heavyHittersContainsString("fed_roll")); + + // check that federated input files are still existing + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); + + } + finally { + TestUtils.shutdownThreads(t1, t2, t3, t4); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.FEDERATED_COMPILATION = false; + } + } +} diff --git a/src/test/scripts/functions/federated/FederatedRollTest.dml b/src/test/scripts/functions/federated/FederatedRollTest.dml new file mode 100644 index 00000000000..cb464256ed8 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRollTest.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- + # + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, + # software distributed under the License is distributed on an + # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + # KIND, either express or implied. See the License for the + # specific language governing permissions and limitations + # under the License. + # + #------------------------------------------------------------- +if ($rP) { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); +} else { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), + list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); +} + +s = roll(A, 1); +write(s, $out_S); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedRollTestReference.dml b/src/test/scripts/functions/federated/FederatedRollTestReference.dml new file mode 100644 index 00000000000..694bd5f1d4a --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRollTestReference.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- + # + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, + # software distributed under the License is distributed on an + # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + # KIND, either express or implied. See the License for the + # specific language governing permissions and limitations + # under the License. + # + #------------------------------------------------------------- + + if($5) { A = rbind(read($1), read($2), read($3), read($4)); } + else { A = cbind(read($1), read($2), read($3), read($4)); } + + s = roll(A, 1); + write(s, $6); \ No newline at end of file From 977bb9e708346c235e9f645a27081b1d4204aa6c Mon Sep 17 00:00:00 2001 From: min-guk Date: Mon, 14 Oct 2024 05:34:31 +0900 Subject: [PATCH 2/2] [SYSTEMDS-3729] Fix FED roll function (CP works, SP fails) --- .../federated/FederationMap.java | 58 +--- .../instructions/fed/ReorgFEDInstruction.java | 115 ++++--- .../spark/ReorgSPInstruction.java | 2 +- .../primitives/part2/FederatedRollTest.java | 287 +++++++++--------- 4 files changed, 229 insertions(+), 233 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 579f2c17b6e..6820ff37eeb 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -406,19 +406,17 @@ public Future[] executeMultipleSlices(long tid, boolean wait, return ret.toArray(new Future[0]); } - public Future[] executeRoll(long tid, boolean wait, FederatedRequest fr, FederatedRequest frEnd, - FederatedRequest frStart, FederatedRequest frCopy) { + public Future[] executeRoll(long tid, boolean wait, FederatedRequest frEnd, + FederatedRequest frStart, long rlen) { // executes step1[] - step 2 - ... step4 (only first step federated-data-specific) - setThreadID(tid, new FederatedRequest[]{fr, frCopy, frStart, frEnd}); + setThreadID(tid, new FederatedRequest[]{frStart, frEnd}); List> ret = new ArrayList<>(); for(Pair e : _fedMap) { - if (e.getKey().getEndDims()[0] == 100) { - ret.add(e.getValue().executeFederatedOperation(fr, frEnd)); + if (e.getKey().getEndDims()[0] == rlen) { + ret.add(e.getValue().executeFederatedOperation(frEnd)); } else if (e.getKey().getBeginDims()[0] == 0){ - ret.add(e.getValue().executeFederatedOperation(fr, frStart)); - } else{ - ret.add(e.getValue().executeFederatedOperation(fr, frCopy)); + ret.add(e.getValue().executeFederatedOperation(frStart)); } } @@ -434,9 +432,12 @@ public List>> requestFederatedDat throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData"); List>> readResponses = new ArrayList<>(); - FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID); - for(Pair e : _fedMap) + + for(Pair e : _fedMap){ + FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, e.getValue().getVarID()); readResponses.add(Pair.of(e.getKey(), e.getValue().executeFederatedOperation(request))); + } + return readResponses; } @@ -715,43 +716,6 @@ public void reverseFedMap() { } } - public long rollFedMap(long shift, long rlen) { - long length = 0; - - int size = _fedMap.size(); - - for (int i = 0; i < size; i++) { - Pair entry = _fedMap.get(i); - FederatedRange fedRange = entry.getKey(); - - long beginRow = fedRange.getBeginDims()[0] + shift; - long endRow = fedRange.getEndDims()[0] + shift; - - beginRow = beginRow > rlen ? beginRow - rlen : beginRow; - endRow = endRow > rlen ? endRow - rlen : endRow; - - if (beginRow < endRow) { - fedRange.setBeginDim(0, beginRow); - fedRange.setEndDim(0, endRow); - } else { - FederatedData fedData = entry.getValue(); - - // End block - fedRange.setBeginDim(0, beginRow); - fedRange.setEndDim(0, rlen); - length = rlen - beginRow; - - // Start block - FederatedRange startRange = new FederatedRange(fedRange); - startRange.setBeginDim(0, 0); - startRange.setEndDim(0, endRow); - - _fedMap.add(Pair.of(startRange, fedData)); - } - } - - return length; - } private static class MappingTask implements Callable { private final FederatedRange _range; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java index 93b1a105846..85af01b6625 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java @@ -36,6 +36,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; @@ -69,15 +70,15 @@ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opc super(FEDType.Reorg, op, in1, out, opcode, istr); } - private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode, String istr) { - super(FEDType.Reorg, op, in, shift, out, opcode, istr); + private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode, String istr, FederatedOutput fedOut) { + super(FEDType.Reorg, op, in, shift, out, opcode, istr, fedOut); _shift = shift; } public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) { if (rinst.input2 != null) { return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(), - rinst.getInstructionString()); + rinst.getInstructionString(), FederatedOutput.NONE); } else{ return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), rinst.getInstructionString(), FederatedOutput.NONE); @@ -87,7 +88,7 @@ public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) { public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) { if (rinst.input2 != null) { return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(), - rinst.getInstructionString()); + rinst.getInstructionString(), FederatedOutput.NONE); } else{ return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), rinst.getInstructionString(), FederatedOutput.NONE); @@ -128,8 +129,9 @@ else if (opcode.equalsIgnoreCase("roll")) { in.split(parts[1]); out.split(parts[3]); CPOperand shift = new CPOperand(parts[2]); + fedOut = parseFedOutFlag(str, 3); return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)), - in, out, shift, opcode, str); + in, out, shift, opcode, str, fedOut); } else { throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode); @@ -194,31 +196,29 @@ else if(instOpcode.equalsIgnoreCase("rev")) { out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID())); optionalForceLocal(out); - } - else if (instOpcode.equalsIgnoreCase("roll")) { - long inID = mo1.getFedMapping().getID(); - long outID = FederationUtils.getNextFedDataID(); - + } else if (instOpcode.equalsIgnoreCase("roll")) { long rlen = mo1.getNumRows(); long shift = ec.getScalarInput(_shift).getLongValue(); shift %= (rlen != 0 ? rlen : 1); // roll matrix with axis=none - FederationMap outFedMap = mo1.getFedMapping().copyWithNewID(outID); - long length = outFedMap.rollFedMap(shift, rlen); - - FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, outID, - new MatrixCharacteristics(-1, -1), mo1.getDataType()); - - FederatedRequest frCopy = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new ReorgFEDInstruction.SplitRow(mo1.getFedMapping().getID(), outID, 0, false)); - - FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new ReorgFEDInstruction.SplitRow(mo1.getFedMapping().getID(), outID, length, true)); - - FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new ReorgFEDInstruction.SplitRow(inID, outID, length, false)); - - Future[] ffr = outFedMap.executeRoll(getTID(), true, fr, frEnd, frStart, frCopy); + long inID = mo1.getFedMapping().getID(); + long outEndID = FederationUtils.getNextFedDataID(); + long outStartID = FederationUtils.getNextFedDataID(); + + List> inMap = mo1.getFedMapping().getMap(); + Pair rollResult = rollFedMap(inMap, inID, outEndID, outStartID, shift, + rlen, mo1.getFedMapping().getType()); + long length = rollResult.getValue(); + FederationMap outFedMap = rollResult.getKey(); + + FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID, + new ReorgFEDInstruction.SliceMatrix(inID, outEndID, length, true)); +// FederatedRequest frCopy = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID, +// new ReorgFEDInstruction.SliceMatrix(inID, outEndID, 0, true)); + FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outStartID, + new ReorgFEDInstruction.SliceMatrix(inID, outStartID, length, false)); + Future[] ffr = outFedMap.executeRoll(getTID(), true, frEnd, frStart, rlen); +// Future[] ffr = outFedMap.executeRoll(getTID(), true, frEnd, frStart, frCopy, rlen); //derive output federated mapping MatrixObject out = ec.getMatrixObject(output); @@ -250,6 +250,40 @@ else if (instOpcode.equals("rdiag")) { } } + + public Pair rollFedMap(List> oldMap, long inID, + long outEndID, long outStartID, long shift, long rlen, FType type) { + List> map = new ArrayList<>(); + long length = 0; + + for(Map.Entry e : oldMap) { + if(e.getKey().getSize() == 0) continue; + FederatedRange fedRange = new FederatedRange(e.getKey()); + long beginRow = fedRange.getBeginDims()[0] + shift; + long endRow = fedRange.getEndDims()[0] + shift; + + beginRow = beginRow > rlen ? beginRow - rlen : beginRow; + endRow = endRow > rlen ? endRow - rlen : endRow; + + if (beginRow < endRow) { + fedRange.setBeginDim(0, beginRow); + fedRange.setEndDim(0, endRow); + map.add(Pair.of(fedRange, e.getValue().copyWithNewID(inID))); + } else { + length = rlen - beginRow; + fedRange.setBeginDim(0, beginRow); + fedRange.setEndDim(0, rlen); + map.add(Pair.of(fedRange, e.getValue().copyWithNewID(outEndID))); + + FederatedRange startRange = new FederatedRange(fedRange); + startRange.setBeginDim(0, 0); + startRange.setEndDim(0, endRow); + map.add(Pair.of(startRange, e.getValue().copyWithNewID(outStartID))); + } + } + return Pair.of(new FederationMap(outEndID, map, type), length); + } + /** * Update the federated ranges of result and return the updated federation map. * @param result RdiagResult for which the fedmap is updated @@ -368,13 +402,13 @@ private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) { return new RdiagResult(diagFedMap, dcs); } - public static class SplitRow extends FederatedUDF { + public static class SliceMatrix extends FederatedUDF { private static final long serialVersionUID = -3466926635958851402L; private final long _outputID; private final int _sliceRow; private final boolean _isRight; - private SplitRow(long input, long outputID, long sliceRow, boolean isRight) { + private SliceMatrix(long input, long outputID, long sliceRow, boolean isRight) { super(new long[] {input}); _outputID = outputID; _sliceRow = (int) sliceRow; @@ -386,20 +420,19 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { MatrixBlock oriBlock = ((MatrixObject) data[0]).acquireReadAndRelease(); MatrixBlock resBlock; - if (_sliceRow == 0){ - ec.setMatrixOutput(String.valueOf(_outputID), oriBlock); - return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{oriBlock.getNumRows(), oriBlock.getNumColumns()}); - } - - if (_isRight){ - resBlock = oriBlock.slice(0, _sliceRow-1, 0, oriBlock.getNumColumns()-1, new MatrixBlock()); - ec.setMatrixOutput(String.valueOf(_outputID), resBlock); - } else { - resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1, 0, oriBlock.getNumColumns()-1, new MatrixBlock()); - ec.setMatrixOutput(String.valueOf(_outputID), resBlock); + if (_sliceRow != 0){ + if (_isRight){ + resBlock = oriBlock.slice(0, _sliceRow-1, 0, + oriBlock.getNumColumns()-1, new MatrixBlock()); + } else{ + resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1, + 0, oriBlock.getNumColumns()-1, new MatrixBlock()); + } + } else{ + resBlock = oriBlock; } - - return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{resBlock.getNumRows(), resBlock.getNumColumns()}); + ec.setMatrixOutput(String.valueOf(_outputID), resBlock); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, resBlock); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java index b096405959b..1a4f8fef0da 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java @@ -85,7 +85,7 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d } private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) { - this(op, in, out, opcode, istr); + super(SPType.Reorg, op, in, shift, null, out, opcode, istr); _shift = shift; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java index 2a20c696799..bb693719ced 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java @@ -38,148 +38,147 @@ @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe public class FederatedRollTest extends AutomatedTestBase { - // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName()); - - private final static String TEST_NAME = "FederatedRollTest"; - - private final static String TEST_DIR = "functions/federated/"; - private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRollTest.class.getSimpleName() + "/"; - - private final static int blocksize = 1024; - @Parameterized.Parameter() - public int rows; - @Parameterized.Parameter(1) - public int cols; - - @Parameterized.Parameter(2) - public boolean rowPartitioned; - - @Parameterized.Parameters - public static Collection data() { - return Arrays.asList(new Object[][] {{100, 12, true}, {100, 12, false}}); - } - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"})); - } - - @Test - public void testRevCP() { - runRevTest(ExecMode.SINGLE_NODE); - } - - @Test - public void testRevSP() { - runRevTest(ExecMode.SPARK); - } - - @Test - public void federatedCompilationRevCP() { - runRevTest(ExecMode.SINGLE_NODE, true); - } - - @Test - public void federatedCompilationRevSP() { - runRevTest(ExecMode.SPARK, true); - } - - private void runRevTest(ExecMode execMode) { - runRevTest(execMode, false); - } - - private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - ExecMode platformOld = rtplatform; - - if(rtplatform == ExecMode.SPARK) - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - - getAndLoadTestConfiguration(TEST_NAME); - String HOME = SCRIPT_DIR + TEST_DIR; - - // write input matrices - int r = rows; - int c = cols / 4; - if(rowPartitioned) { - r = rows / 4; - c = cols; - } - - double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); - double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); - double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); - double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); - - for(int k : new int[] {1, 2, 3}) { - Arrays.fill(X3[k], 0); - } - - MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); - writeInputMatrixWithMTD("X1", X1, false, mc); - writeInputMatrixWithMTD("X2", X2, false, mc); - writeInputMatrixWithMTD("X3", X3, false, mc); - writeInputMatrixWithMTD("X4", X4, false, mc); - - // empty script name because we don't execute any script, just start the worker - fullDMLScriptName = ""; - int port1 = getRandomAvailablePort(); - int port2 = getRandomAvailablePort(); - int port3 = getRandomAvailablePort(); - int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); - - - try { - if(!isAlive(t1, t2, t3, t4)) - throw new RuntimeException("Failed starting federated worker"); - rtplatform = execMode; - if(rtplatform == ExecMode.SPARK) { - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - } - TestConfiguration config = availableTestConfigurations.get(TEST_NAME); - loadTestConfiguration(config); - - // Run reference dml script with normal matrix - fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; - programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), - Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; - - runTest(null); - - OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation; - fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[] {"-stats", "100", "-nvargs", - "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), - "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), - "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), - "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, - "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")}; - - runTest(null); - - // compare via files - compareResults(0.01, "Stat-DML1", "Stat-DML2"); - - Assert.assertTrue(heavyHittersContainsString("fed_roll")); - - // check that federated input files are still existing - Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); - Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); - Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); - Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); - - } - finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); - - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - OptimizerUtils.FEDERATED_COMPILATION = false; - } - } + // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName()); + + private final static String TEST_NAME = "FederatedRollTest"; + + private final static String TEST_DIR = "functions/federated/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRollTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Parameterized.Parameter(2) + public boolean rowPartitioned; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][]{{100, 12, true}, {100, 12, false}}); + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"S"})); + } + + @Test + public void testRollCP() { + runRollTest(ExecMode.SINGLE_NODE); + } + + @Test + public void testRollSP() { + runRollTest(ExecMode.SPARK); + } + + @Test + public void federatedCompilationRollCP() { + runRollTest(ExecMode.SINGLE_NODE, true); + } + + @Test + public void federatedCompilationRollSP() { + runRollTest(ExecMode.SPARK, true); + } + + private void runRollTest(ExecMode execMode) { + runRollTest(execMode, false); + } + + private void runRollTest(ExecMode execMode, boolean activateFedCompilation) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + ExecMode platformOld = rtplatform; + + if (rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int r = rows; + int c = cols / 4; + if (rowPartitioned) { + r = rows / 4; + c = cols; + } + + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + for (int k : new int[]{1, 2, 3}) { + Arrays.fill(X3[k], 0); + } + + MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); + Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); + Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); + Process t4 = startLocalFedWorker(port4); + + + try { + if (!isAlive(t1, t2, t3, t4)) + throw new RuntimeException("Failed starting federated worker"); + rtplatform = execMode; + if (rtplatform == ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[]{"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), + Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; + + runTest(null); + + OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, + "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")}; + + runTest(null); + + // compare via files + compareResults(0.01, "Stat-DML1", "Stat-DML2"); + + Assert.assertTrue(heavyHittersContainsString("fed_roll")); + + // check that federated input files are still existing + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); + + } finally { + TestUtils.shutdownThreads(t1, t2, t3, t4); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.FEDERATED_COMPILATION = false; + } + } }