Skip to content

Commit

Permalink
RUL-107: Bugfix in parsing expert survival rules when survival status…
Browse files Browse the repository at this point in the history
… attribute is nominal.
  • Loading branch information
agudys committed Oct 15, 2024
1 parent c2a3b1d commit acd799e
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 49 deletions.
2 changes: 1 addition & 1 deletion adaa.analytics.rules/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ plugins {
id 'java'
}

version = '2.1.20'
version = '2.1.21'
java {
sourceCompatibility = JavaVersion.VERSION_1_8
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import adaa.analytics.rules.logic.representation.condition.ConditionBase;
import adaa.analytics.rules.logic.representation.condition.ElementaryCondition;
import adaa.analytics.rules.logic.representation.rule.Rule;
import adaa.analytics.rules.logic.representation.rule.SurvivalRule;
import adaa.analytics.rules.logic.representation.valueset.IValueSet;
import adaa.analytics.rules.logic.representation.valueset.SingletonSet;
import adaa.analytics.rules.logic.representation.valueset.UndefinedSet;
import adaa.analytics.rules.logic.representation.valueset.Universum;

import java.io.Serializable;
Expand Down Expand Up @@ -60,10 +62,7 @@ public class Knowledge implements Serializable {

/** Maximum number of preferred attributes per rule. */
protected int preferredAttributesPerRule;

/** Auxiliary files indicating whether the knowledge concerns regression problem. */
protected boolean isRegression;


/** Auxiliary field indicating the number of classes (classification problems only). */
protected int numClasses;

Expand Down Expand Up @@ -166,7 +165,7 @@ public void setPreferredAttributesPerRule(int v) {
*/
public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> preferredConditions, MultiSet<Rule> forbiddenConditions) {

this.isRegression = dataset.getAttributes().getLabel().isNumerical();
boolean isSurvival = dataset.getAttributes().getColumnByRole(SurvivalRule.SURVIVAL_TIME_ROLE) != null;

this.extendUsingPreferred = false;
this.extendUsingAutomatic = false;
Expand All @@ -176,7 +175,7 @@ public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> prefe
this.preferredConditionsPerRule = Integer.MAX_VALUE;
this.preferredAttributesPerRule = Integer.MAX_VALUE;

int numClasses = (dataset.getAttributes().getLabel().isNominal())
this.numClasses = (dataset.getAttributes().getLabel().isNominal() && !isSurvival)
? dataset.getAttributes().getLabel().getMapping().size() : 1;

for (int i = 0; i < numClasses; ++i) {
Expand All @@ -188,14 +187,14 @@ public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> prefe
}

for (Rule r : rules) {
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
int c = (int)set.getValue();
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
int c = (set instanceof UndefinedSet) ? 0 : (int)set.getValue();
this.rules.get(c).add(r);
}

for (Rule r : preferredConditions) {
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
int c = (int)set.getValue();
int c = (set instanceof UndefinedSet) ? 0 : (int)set.getValue();

ElementaryCondition ec = (ElementaryCondition) r.getPremise().getSubconditions().get(0);
if (ec.getValueSet() instanceof Universum) {
Expand All @@ -207,7 +206,7 @@ public Knowledge(IExampleSet dataset, MultiSet<Rule> rules, MultiSet<Rule> prefe

for (Rule r : forbiddenConditions) {
SingletonSet set = (SingletonSet)r.getConsequence().getValueSet();
int c = (int)set.getValue();
int c = (set instanceof UndefinedSet) ? 0 : (int)set.getValue();

ElementaryCondition ec = (ElementaryCondition) r.getPremise().getSubconditions().get(0);
if (ec.getValueSet() instanceof Universum) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@
import adaa.analytics.rules.logic.representation.rule.RegressionRule;
import adaa.analytics.rules.logic.representation.rule.Rule;
import adaa.analytics.rules.logic.representation.rule.SurvivalRule;
import adaa.analytics.rules.logic.representation.valueset.IValueSet;
import adaa.analytics.rules.logic.representation.valueset.Interval;
import adaa.analytics.rules.logic.representation.valueset.SingletonSet;
import adaa.analytics.rules.logic.representation.valueset.Universum;
import adaa.analytics.rules.logic.representation.valueset.*;
import adaa.analytics.rules.utils.Logger;
import org.apache.commons.lang3.math.NumberUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.regex.Matcher;
Expand All @@ -57,10 +55,7 @@ public static Rule parseRule(String s, IAttributes meta) {
Pattern pattern = Pattern.compile("IF\\s+(?<premise>.+)\\s+THEN(?<consequence>\\s+.*|\\s*)");
Matcher matcher = pattern.matcher(s);

boolean isSurvival = false;
if (meta.getColumnByRoleUnsafe(SurvivalRule.SURVIVAL_TIME_ROLE) != null) {
isSurvival = true;
}
boolean isSurvival = (meta.getColumnByRoleUnsafe(SurvivalRule.SURVIVAL_TIME_ROLE) != null);

if (matcher.find()) {
String pre = matcher.group("premise");
Expand All @@ -70,24 +65,28 @@ public static Rule parseRule(String s, IAttributes meta) {
CompoundCondition premise = parseCompoundCondition(pre, meta);

if (con == null || con.trim().length() == 0) {
if (!meta.getLabelUnsafe().isNumerical()) {
Logger.log("Empty conclusion for nominal label"+ "\n", Level.WARNING);
} else {
consequence = new ElementaryCondition(meta.getLabelUnsafe().getName(), new SingletonSet(NaN, null));
if (isSurvival) {
consequence = new ElementaryCondition(meta.getLabelUnsafe().getName(), new UndefinedSet());
} else if (meta.getLabelUnsafe().isNumerical()) {
consequence = new ElementaryCondition(meta.getLabelUnsafe().getName(), new SingletonSet(NaN, null));
consequence.setAdjustable(false);
consequence.setDisabled( false);
consequence.setDisabled(false);
} else{
Logger.log("Empty conclusion for nominal label"+ "\n", Level.WARNING);
}
} else {
consequence = parseElementaryCondition(con, meta);
}

if (premise != null && consequence != null) {

rule = meta.getLabelUnsafe().isNominal()
? new ClassificationRule(premise, consequence)
: (isSurvival
? new SurvivalRule(premise, consequence)
: new RegressionRule(premise, consequence));
if (isSurvival) {
rule = new SurvivalRule(premise, consequence);
} else {
rule = meta.getLabelUnsafe().isNominal()
? new ClassificationRule(premise, consequence)
: new RegressionRule(premise, consequence);
}
}
}

Expand Down Expand Up @@ -158,9 +157,11 @@ public static ElementaryCondition parseElementaryCondition(String s, IAttributes
}

IValueSet valueSet = null;
IAttribute attributeMeta = meta.get(attribute);

boolean isSurvival = (meta.getColumnByRole(SurvivalRule.SURVIVAL_TIME_ROLE) != null) && (meta.getLabel() == attributeMeta);

IAttribute attributeMeta = meta.get(attribute);
if (attributeMeta == null) {
if (attributeMeta == null) {
Logger.log("Attribute <" + attribute + "> not found"+ "\n", Level.WARNING);
return null;
}
Expand All @@ -176,13 +177,19 @@ public static ElementaryCondition parseElementaryCondition(String s, IAttributes
matcher = regex.matcher(valueString);
if (matcher.find()) {
String value = matcher.group("discrete");
List<String> mapping = new ArrayList<String>(attributeMeta.getMapping().getValues());
double v = mapping.indexOf(value);
if (v == -1) {
Logger.log("Invalid value <" + value + "> of the nominal attribute <" + attribute + ">"+ "\n", Level.WARNING);
return null;
}
valueSet = new SingletonSet(v, mapping);

if (value.equals("NaN") && isSurvival) {
valueSet = new UndefinedSet();
} else {

List<String> mapping = new ArrayList<String>(attributeMeta.getMapping().getValues());
double v = mapping.indexOf(value);
if (v == -1) {
Logger.log("Invalid value <" + value + "> of the nominal attribute <" + attribute + ">" + "\n", Level.WARNING);
return null;
}
valueSet = new SingletonSet(v, mapping);
}

}
} else if (attributeMeta.isNumerical()) {
Expand All @@ -191,8 +198,11 @@ public static ElementaryCondition parseElementaryCondition(String s, IAttributes
//
if (matcher.find()) {
String value = matcher.group("discrete");
double v = value.equals("NaN") ? Double.NaN : Double.parseDouble(value);
valueSet = new SingletonSet(v, null);
if (value.equals("NaN")) {
valueSet = new UndefinedSet();
} else {
valueSet = new SingletonSet(Double.parseDouble(value), null);
}
} else {
boolean leftClosed = Pattern.compile("\\<.+").matcher(valueString).find();
boolean rightClosed = Pattern.compile(".+\\>").matcher(valueString).find();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import adaa.analytics.rules.logic.representation.condition.ElementaryCondition;
import adaa.analytics.rules.logic.representation.IntegerBitSet;
import adaa.analytics.rules.logic.representation.valueset.SingletonSet;
import adaa.analytics.rules.logic.representation.valueset.UndefinedSet;
import org.jetbrains.annotations.NotNull;

import java.io.Serializable;
Expand Down Expand Up @@ -285,14 +286,7 @@ public void covers(IExampleSet set, ContingencyTable ct, Set<Integer> positives,
* @return Text representation.
*/
public String toString() {
String consequenceString;
if (consequence.getValueSet() instanceof SingletonSet &&
Double.isNaN(((SingletonSet) consequence.getValueSet()).getValue()) && ((SingletonSet) consequence.getValueSet()).getMapping() == null) {
consequenceString = "";
} else {
consequenceString = consequence.toString();
}
String s = "IF " + premise.toString() + " THEN " + consequenceString;
String s = "IF " + premise.toString() + " THEN " + consequence.toString();
return s;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import adaa.analytics.rules.logic.representation.condition.CompoundCondition;
import adaa.analytics.rules.logic.representation.condition.ElementaryCondition;
import adaa.analytics.rules.logic.representation.KaplanMeierEstimator;
import adaa.analytics.rules.logic.representation.valueset.UndefinedSet;
import org.jetbrains.annotations.NotNull;
import tech.tablesaw.api.DoubleColumn;

Expand Down Expand Up @@ -157,5 +158,11 @@ public Covering covers(IExampleSet set) {
}
return covered;
}

@Override
public String toString() {
String s = "IF " + premise.toString() + " THEN ";
return s;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package adaa.analytics.rules.logic.representation.valueset;

import adaa.analytics.rules.utils.DoubleFormatter;

import java.util.List;

public class UndefinedSet extends SingletonSet{

/** Gets {@link #value} */
public double getValue() { throw new RuntimeException("Illegal call for UndefinedSet: getValue"); }
/** Sets {@link #value} */
public void setValue(double v) { throw new RuntimeException("Illegal call for UndefinedSet: setValue"); }

/** Gets {@link #value} as string */
public String getValueAsString() { throw new RuntimeException("Illegal call for UndefinedSet: getValueAsString"); }

/** Gets {@link #mapping} */
public List<String> getMapping() { throw new RuntimeException("Illegal call for UndefinedSet: getMapping" ); }
/** Sets {@link #mapping} */
public void setMapping(List<String> v) { throw new RuntimeException("Illegal call for UndefinedSet: setMapping"); }

public UndefinedSet() {
super(Double.NaN, null);
}

@Override
public String toString() {
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,15 @@ public RuleSetBase learnWithExpert(IExampleSet exampleSet) {
*/
private void fixMappings(Iterable<Rule> rules, IExampleSet set) {

boolean isSurvival = (set.getAttributes().getColumnByRole(SurvivalRule.SURVIVAL_TIME_ROLE) != null);

for (Rule r : rules) {
List<ConditionBase> toCheck = new ArrayList<ConditionBase>(); // list of elementary conditions to check
toCheck.addAll(r.getPremise().getSubconditions());
toCheck.add(r.getConsequence());

if (!isSurvival) {
toCheck.add(r.getConsequence());
}

for (ConditionBase c : toCheck) {
ElementaryCondition ec = (c instanceof ElementaryCondition) ? (ElementaryCondition) c : null;
Expand Down

0 comments on commit acd799e

Please sign in to comment.