Skip to content

Commit

Permalink
[SYSTEMDS-3764] Fix parser-level constant propagation w/ multi-returns
Browse files Browse the repository at this point in the history
This patch fixes an old bug of parser-level constant propagation where
updates to scalars through assignments from multi-return function
were not updated the set of constant vars, leading to incorrectly
propagated scalars, constant folding, and utimately wrong results.

This bug was discovered while improving incremental slice line.
  • Loading branch information
mboehm7 committed Sep 8, 2024
1 parent 472e69f commit b3517cb
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 6 deletions.
3 changes: 1 addition & 2 deletions scripts/builtin/incSliceLine.dml
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ m_incSliceLine = function(
minsc = -Inf
if( nrow(prevTK) > 0 ) {
prevTK2 = oneHotEncodeUsingOffsets(prevTK, foffb, foffe);
[minsc2, prevTKC2] = computeLowestPrevTK(prevTK2, X2, totalE, eAvg, alpha)
minsc = minsc2; #FIXME otherwise -Inf incorrectly propagated
[minsc, prevTKC2] = computeLowestPrevTK(prevTK2, X2, totalE, eAvg, alpha)
}

# create and score basic slices (conjunctions of 1 feature)
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/apache/sysds/parser/IfStatementBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,20 @@ public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String,

/////////////////////////////////////////////////////////////////////////////////
// check data type and value type are same for updated variables in both
// if statement and else statement
// if statement and else statement
// (reject conditional data type change)
/////////////////////////////////////////////////////////////////////////////////
for (String updatedVar : this._updated.getVariableNames()){
DataIdentifier origVersion = idsOrigCopy.getVariable(updatedVar);
DataIdentifier ifVersion = idsIfCopy.getVariable(updatedVar);
DataIdentifier ifVersion = idsIfCopy.getVariable(updatedVar);
DataIdentifier elseVersion = idsElseCopy.getVariable(updatedVar);

//data type handling: reject conditional data type change
if( ifVersion != null && elseVersion != null ) //both branches exist
{
if (!ifVersion.getOutput().getDataType().equals(elseVersion.getOutput().getDataType())){
raiseValidateError("IfStatementBlock has unsupported conditional data type change of variable '"+updatedVar+"' in if/else branch.", conditional);
}
}
}
else if( origVersion !=null ) //only if branch exists
{
Expand All @@ -99,7 +99,7 @@ else if( origVersion !=null ) //only if branch exists
}
}

//value type handling
//value type handling
if (ifVersion != null && elseVersion != null && !ifVersion.getOutput().getValueType().equals(elseVersion.getOutput().getValueType())){
LOG.warn(elseVersion.printWarningLocation() + "Variable " + elseVersion.getName() + " defined with different value type in if and else clause.");
}
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/apache/sysds/parser/StatementBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,10 @@ else if ( source instanceof BuiltinFunctionExpression || source instanceof Param
ids.addVariable(targetList.get(j).getName(), (DataIdentifier)outputs[j]);
}
}

// remove updated constant vars (for correctness)
for(DataIdentifier target : targetList)
currConstVars.remove(target.getName());
}

public void setStatementFormatType(OutputStatement s, boolean conditionalValidate)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.misc;

import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;

public class IPAScalarPropagationMulitReturnTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "ScalarPropagationMultiReturn";
private final static String TEST_DIR = "functions/misc/";
private final static String TEST_CLASS_DIR = TEST_DIR + IPAScalarPropagationMulitReturnTest.class.getSimpleName() + "/";

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}) );
}

@Test
public void testScalarPropagationNoIPA() {
runIPAScalarVariablePropagationTest( TEST_NAME1, false );
}

@Test
public void testScalarPropagationIPA() {
runIPAScalarVariablePropagationTest( TEST_NAME1, true );
}

private void runIPAScalarVariablePropagationTest( String testname, boolean IPA )
{
boolean oldFlagIPA = OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS;

try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-explain","-args", output("R") };
OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA;

//run test, incl expected MR jobs (in case if IPA 0 due to scalar propagation)
runTest(true, false, null, IPA ? 0 : 35);

double[][] ret = TestUtils.convertHashMapToDoubleArray(
readDMLMatrixFromOutputDir("R"), 1, 1);
Assert.assertEquals(Double.valueOf(2), Double.valueOf(ret[0][0]));
}
finally {
OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = oldFlagIPA;
}
}
}
38 changes: 38 additions & 0 deletions src/test/scripts/functions/misc/ScalarPropagationMultiReturn.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------


foo = function(matrix[double] A) return (Double minsc, matrix[double] B) {
while(FALSE) {} # prevent inlining
minsc = min(A);
B = A + 7;
}

X = seq(1,7);
minsc = -Inf;
if( avg(X) <= 5 ) {
[minsc, X] = foo(X)
}

# bug: minsc was incorrectly propagated as -Inf, leading to -Inf instead of 2
R = as.matrix(minsc + 1);
write(R, $1);

0 comments on commit b3517cb

Please sign in to comment.