Skip to content

Commit

Permalink
Multiple speed improvements
Browse files Browse the repository at this point in the history
* Classification:
   * Significantly faster growing (two orders of magnitude for sets with >100k instances) and faster pruning.
   * Added approximate mode (`approximate_induction` parameter).
* Regression:
   * Mean-based growing set as default (a few times faster than median-based, with no significant impact on accuracy).
* Survival:
   * Faster growing and pruning (a few-fold improvement).
  • Loading branch information
agudys committed Dec 19, 2023
1 parent 9720962 commit eba946c
Show file tree
Hide file tree
Showing 27 changed files with 744 additions and 417 deletions.
2 changes: 1 addition & 1 deletion adaa.analytics.rules/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ codeQuality {
}

sourceCompatibility = 1.8
version = '1.6.2'
version = '1.7.0'


jar {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ private void parse(String[] args) {
RapidMiner.setExecutionMode(RapidMiner.ExecutionMode.COMMAND_LINE);

RapidMiner.init();
//System.in.read();
// System.in.read();
execute(argList.get(0));

} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ public void close() {
/**
* Can be implemented by subclasses to perform some initial processing prior to growing.
* @param trainSet Training set.
* @return Preprocessed training set.
*/
public void preprocess(ExampleSet trainSet) {}
public ExampleSet preprocess(ExampleSet trainSet) { return trainSet; }

/**
* Adds elementary conditions to the rule premise until termination conditions are fulfilled.
Expand All @@ -104,11 +105,14 @@ public int grow(
int initialConditionsCount = rule.getPremise().getSubconditions().size();

// get current covering
Covering covering = new Covering();
rule.covers(dataset, covering, covering.positives, covering.negatives);
Set<Integer> covered = new HashSet<Integer>();
covered.addAll(covering.positives);
covered.addAll(covering.negatives);
ContingencyTable contingencyTable = new Covering();
IntegerBitSet positives = new IntegerBitSet(dataset.size());
IntegerBitSet negatives = new IntegerBitSet(dataset.size());
rule.covers(dataset, contingencyTable, positives, negatives);
//Set<Integer> covered = new HashSet<Integer>();
IntegerBitSet covered = new IntegerBitSet(dataset.size());
covered.addAll(positives);
covered.addAll(negatives);
Set<Attribute> allowedAttributes = new TreeSet<Attribute>(new AttributeComparator());
for (Attribute a: dataset.getAttributes()) {
allowedAttributes.add(a);
Expand All @@ -126,18 +130,23 @@ public int grow(

notifyConditionAdded(condition);

covering = new Covering();
rule.covers(dataset, covering, covering.positives, covering.negatives);
covered.clear();
covered.addAll(covering.positives);
covered.addAll(covering.negatives);
//recalculate covering only when needed
if (condition.getCovering() != null) {
positives.retainAll(condition.getCovering());
negatives.retainAll(condition.getCovering());
covered.retainAll(condition.getCovering());
} else {
contingencyTable.clear();
positives.clear();
negatives.clear();

rule.covers(dataset, contingencyTable, positives, negatives);
covered.clear();
covered.addAll(positives);
covered.addAll(negatives);
}

rule.setCoveringInformation(covering);
rule.getCoveredPositives().setAll(covering.positives);
rule.getCoveredNegatives().setAll(covering.negatives);

rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure());

Logger.log("Condition " + rule.getPremise().getSubconditions().size() + " added: "
+ rule.toString() + ", weight=" + rule.getWeight() + "\n", Level.FINER);

Expand All @@ -152,12 +161,25 @@ public int grow(
carryOn = false;
}

} while (carryOn);

} while (carryOn);

// ugly
Covering covering = new Covering();
covering.positives = positives;
covering.negatives = negatives;

rule.setCoveringInformation(covering);
rule.getCoveredPositives().setAll(positives);
rule.getCoveredNegatives().setAll(negatives);

// if rule has been successfully grown
int addedConditionsCount = rule.getPremise().getSubconditions().size() - initialConditionsCount;
rule.setInducedContitionsCount(addedConditionsCount);

if (addedConditionsCount > 0) {
rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure());
}

rule.setInducedContitionsCount(addedConditionsCount);
notifyGrowingFinished(rule);

return addedConditionsCount;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public ActionFinder(ActionInductionParameters params) {
classificationFinder = new ClassificationFinder(params);
}

public void preprocess(ExampleSet trainSet) {
classificationFinder.preprocess(trainSet);
public ExampleSet preprocess(ExampleSet trainSet) {
return classificationFinder.preprocess(trainSet);
}

private void log(String msg, Level level) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ public ConditionCandidate(String attribute, IValueSet valueSet) {
}
}

protected static final int MAX_BINS = 100;

// Example description:
// [0-31] - example id (32 bits)
// [32-47] - block id (16 bits)
Expand Down Expand Up @@ -73,27 +71,36 @@ public ApproximateClassificationFinder(InductionParameters params) {
}

@Override
public void preprocess(ExampleSet dataset) {
public ExampleSet preprocess(ExampleSet dataset) {
int n_examples = dataset.size();
int n_attributes = dataset.getAttributes().size();

trainSet = dataset;
descriptions = new long[n_attributes][n_examples];
mappings = new int[n_attributes][n_examples];

bins_positives = new int[n_attributes][MAX_BINS];
bins_negatives = new int[n_attributes][MAX_BINS];
bins_newPositives = new int[n_attributes][MAX_BINS];
bins_begins = new int[n_attributes][MAX_BINS];
bins_positives = new int[n_attributes][];
bins_negatives = new int[n_attributes][];
bins_newPositives = new int[n_attributes][];
bins_begins = new int[n_attributes][];

ruleRanges = new int[n_attributes][2];

for (Attribute attr: dataset.getAttributes()) {
int ia = attr.getTableIndex();
int n_vals = attr.isNominal() ? attr.getMapping().size() : params.getApproximateBinsCount();

bins_positives[ia] = new int [n_vals];
bins_negatives[ia] = new int[n_vals];
bins_newPositives[ia] = new int[n_vals];
bins_begins[ia] = new int[n_vals];

determineBins(dataset, attr, descriptions[ia], mappings[ia], bins_begins[ia], ruleRanges[ia]);

arrayCopies.put("ruleRanges", (Object)Arrays.stream(ruleRanges).map(int[]::clone).toArray(int[][]::new));
}

return dataset;
}

/**
Expand Down Expand Up @@ -293,13 +300,14 @@ protected ElementaryCondition induceCondition(
int covered_n = 0;
int covered_new_p = 0;

// use first attribute to establish number of covered elements
// use first attribute to establish number of covered elements
for (int bid = ruleRanges[0][0]; bid < ruleRanges[0][1]; ++bid) {
covered_p += bins_positives[0][bid];
covered_n += bins_negatives[0][bid];
covered_new_p += bins_newPositives[0][bid];
}


// iterate over all allowed decision attributes
for (Attribute attr : dataset.getAttributes()) {

Expand Down Expand Up @@ -462,7 +470,10 @@ class Stats {

if (current != null && current.getAttribute() != null) {
Logger.log("\tAttribute best: " + current + ", quality=" + current.quality, Level.FINEST);
updateMidpoint(dataset, current);
Attribute attr = dataset.getAttributes().get(current.getAttribute());
if (attr.isNumerical()) {
updateMidpoint(dataset, current);
}
Logger.log(", adjusted: " + current + "\n", Level.FINEST);
}

Expand All @@ -482,13 +493,13 @@ class Stats {
return null; // empty condition - discard
}

updateMidpoint(dataset, best);

Logger.log("\tFinal best: " + best + ", quality=" + best.quality + "\n", Level.FINEST);

if (bestAttr.isNominal()) {
if (bestAttr.isNumerical()) {
updateMidpoint(dataset, best);
} else {
allowedAttributes.remove(bestAttr);
}

Logger.log("\tFinal best: " + best + ", quality=" + best.quality + "\n", Level.FINEST);
}

return best;
Expand All @@ -508,7 +519,7 @@ protected void notifyConditionAdded(ConditionBase cnd) {
ruleRanges[aid][0] = blockId + 1;
ruleRanges[aid][1] = blockId;
} else {
excludeExamplesFromArrays(trainSet, attr, ruleRanges[aid][0], candidate.blockId + 1);
excludeExamplesFromArrays(trainSet, attr, ruleRanges[aid][0], candidate.blockId);
excludeExamplesFromArrays(trainSet, attr, candidate.blockId + 1, ruleRanges[aid][1]);
ruleRanges[aid][0] = blockId;
ruleRanges[aid][1] = blockId + 1;
Expand Down Expand Up @@ -546,6 +557,7 @@ protected void determineBins(ExampleSet dataset, Attribute attr,
vals[i] = dataset.getExample(i).getValue(attr);
}


/*
class ValuesComparator implements IntComparator {
double [] vals;
Expand Down Expand Up @@ -597,12 +609,12 @@ public int compare(Bin p, Bin q) {
}
}

PriorityQueue<Bin> bins = new PriorityQueue<Bin>(100, new SizeBinComparator());
PriorityQueue<Bin> finalBins = new PriorityQueue<Bin>(100, new IndexBinComparator());
PriorityQueue<Bin> bins = new PriorityQueue<Bin>(binsBegins.length, new SizeBinComparator());
PriorityQueue<Bin> finalBins = new PriorityQueue<Bin>(binsBegins.length, new IndexBinComparator());

bins.add(new Bin(0, mappings.length));

while (bins.size() > 0 && (bins.size() + finalBins.size()) < MAX_BINS) {
while (bins.size() > 0 && (bins.size() + finalBins.size()) < binsBegins.length) {
Bin b = bins.poll();

int id = (b.end + b.begin) / 2;
Expand All @@ -611,9 +623,13 @@ public int compare(Bin p, Bin q) {
// decide direction
if (vals[b.begin] == midval) {
// go up
while (vals[id] == midval) { ++id; }
while (vals[id] == midval) {
++id;
}
} else {
while (vals[id - 1] == midval) { --id; }
while (vals[id - 1] == midval) {
--id;
}
}

Bin leftBin = new Bin(b.begin, id);
Expand Down Expand Up @@ -646,17 +662,16 @@ public int compare(Bin p, Bin q) {
descriptions[i] |= bid << OFFSET_BIN;
}

binsBegins[(int)bid] = b.begin;
binsBegins[(int) bid] = b.begin;
++bid;
}

ruleRanges[0] = 0;
ruleRanges[1] = (int)bid;

// print bins
for (int i = 0; i < bid; ++i) {
ruleRanges[1] = (int) bid;
// print bins
for (int i = 0; i < ruleRanges[1]; ++i) {
int lo = binsBegins[i];
int hi = (i == bid - 1) ? trainSet.size() : binsBegins[i+1] - 1;
int hi = (i == ruleRanges[1] - 1) ? trainSet.size() : binsBegins[i+1] - 1;
Logger.log("[" + lo + ", " + hi + "]:" + vals[lo] + "\n", Level.FINER);
}
}
Expand All @@ -665,6 +680,10 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int

Logger.log("Excluding examples: " + attr.getName() + " from [" + binLo + "," + binHi + "]\n", Level.FINER);

if (binLo == binHi) {
return;
}

int n_examples = dataset.size();
int src_row = attr.getTableIndex();
long[] src_descriptions = descriptions[src_row];
Expand Down Expand Up @@ -695,9 +714,11 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int
int dst_row = other.getTableIndex();

// if nominal attribute was already used
/*
if (other.isNominal() && Math.abs(ruleRanges[dst_row][1] - ruleRanges[dst_row][0]) == 1) {
continue;
}
*/

Future<Object> future = pool.submit(() -> {

Expand All @@ -717,8 +738,14 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int

int bid = (int) ((desc & MASK_BIN) >> OFFSET_BIN);

boolean opposite = dst_ranges[0] > dst_ranges[1]; // this indicate nominal opposite condition
int dst_bin_lo = Math.min(dst_ranges[0], dst_ranges[1]);
int dst_bin_hi = Math.max(dst_ranges[0], dst_ranges[1]);

// update stats only in bins covered by the rule
if (bid >= dst_ranges[0] && bid < dst_ranges[1] && ((desc & FLAG_COVERED) != 0)) {
boolean in_range = (bid >= dst_bin_lo && bid < dst_bin_hi) || (opposite && (bid < dst_bin_lo || bid >= dst_bin_hi));

if (in_range && ((desc & FLAG_COVERED) != 0)) {

if ((desc & FLAG_POSITIVE) != 0) {
--dst_positives[bid];
Expand Down Expand Up @@ -755,12 +782,16 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) {

int n_examples = dataset.size();

int[][] copy_ranges = (int[][])arrayCopies.get("ruleRanges");

for (Attribute attr: dataset.getAttributes()) {
int attribute_id = attr.getTableIndex();

Arrays.fill(bins_positives[attribute_id], 0);
Arrays.fill(bins_negatives[attribute_id], 0);
Arrays.fill(bins_newPositives[attribute_id], 0);
ruleRanges[attribute_id][0] = 0;
ruleRanges[attribute_id][1] = copy_ranges[attribute_id][1];

long[] descriptions_row = descriptions[attribute_id];
int[] mappings_row = mappings[attribute_id];
Expand Down Expand Up @@ -792,6 +823,9 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) {
}
}

// reset rule ranges


Logger.log("Reset arrays for class " + targetLabel + "\n", Level.FINER);
printArrays();

Expand All @@ -816,9 +850,13 @@ protected void printArrays() {

int bin_p = 0, bin_n = 0, bin_new_p = 0, bin_outside = 0;

for (int i = 0; i < MAX_BINS; ++i) {
boolean opposite = ruleRanges[attribute_id][0] > ruleRanges[attribute_id][1]; // this indicate nominal opposite condition
int lo = Math.min(ruleRanges[attribute_id][0], ruleRanges[attribute_id][1]);
int hi = Math.max(ruleRanges[attribute_id][0], ruleRanges[attribute_id][1]);

for (int i = 0; i < bins_positives[attribute_id].length; ++i) {

if (i >= ruleRanges[attribute_id][0] && i < ruleRanges[attribute_id][1]) {
if ((i >= lo && i < hi) || (opposite && (i < lo || i >= hi)) ) {
bin_p += bins_positives[attribute_id][i];
bin_n += bins_negatives[attribute_id][i];
bin_new_p += bins_newPositives[attribute_id][i];
Expand Down
Loading

0 comments on commit eba946c

Please sign in to comment.