Deep rule weights in learning applications.
dickensc committed May 16, 2024
1 parent 73ac8d4 commit 0f9b070
Showing 16 changed files with 415 additions and 192 deletions.
@@ -114,9 +114,9 @@ public WeightLearningApplication(List<Rule> rules, Database trainTargetDatabase,
 
             if (rule instanceof WeightedRule) {
                 if (((WeightedRule) rule).getWeight().isDeep()) {
-                    mutableRules.add((WeightedRule) rule);
-                } else {
                     deepRules.add((WeightedRule) rule);
+                } else {
+                    mutableRules.add((WeightedRule) rule);
                 }
             }
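The hunk above swaps the two branch bodies: deep-weighted rules now land in deepRules, and symbolically weighted rules in mutableRules, where the routing was previously inverted. A minimal, self-contained sketch of the corrected partition, using hypothetical stand-ins for the PSL Rule, WeightedRule, and Weight types (the real classes carry much more state):

    import java.util.ArrayList;
    import java.util.List;

    public class RulePartitionSketch {
        // Stand-in types, illustrative only.
        interface Rule {}

        static final class Weight {
            private final boolean deep;
            Weight(boolean deep) { this.deep = deep; }
            boolean isDeep() { return deep; }
        }

        static final class WeightedRule implements Rule {
            private final Weight weight;
            WeightedRule(Weight weight) { this.weight = weight; }
            Weight getWeight() { return weight; }
        }

        public static void main(String[] args) {
            List<Rule> rules = List.of(
                    new WeightedRule(new Weight(true)),
                    new WeightedRule(new Weight(false)));

            List<WeightedRule> deepRules = new ArrayList<>();
            List<WeightedRule> mutableRules = new ArrayList<>();

            // The corrected routing: deep-weighted rules are tracked separately
            // from the symbolically weighted rules updated by gradient descent.
            for (Rule rule : rules) {
                if (rule instanceof WeightedRule) {
                    if (((WeightedRule) rule).getWeight().isDeep()) {
                        deepRules.add((WeightedRule) rule);
                    } else {
                        mutableRules.add((WeightedRule) rule);
                    }
                }
            }

            System.out.println("deep: " + deepRules.size() + ", symbolic: " + mutableRules.size());
        }
    }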

Large diffs are not rendered by default.

@@ -131,6 +131,7 @@ public void generateBatches() {
             SimpleTermStore<? extends ReasonerTerm> batchTermStore = batchTermStores.get(i);
             batchDeepModelPredicates.add(new ArrayList<DeepModelPredicate>());
 
+            // Copy all deep model predicates.
             for (DeepPredicate deepPredicate : deepPredicates) {
                 DeepModelPredicate batchDeepModelPredicate = deepPredicate.getDeepModel().copy();
                 batchDeepModelPredicate.setAtomStore(batchTermStore.getAtomStore(), true);
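Each batch gets its own copies of the deep model predicates, rebound to that batch's atom store, so per-batch inference reads deep predictions against the right atoms. A stand-in sketch of that copy-and-rebind step (state copying and the exact meaning of the boolean flag are elided here):

    import java.util.ArrayList;
    import java.util.List;

    final class BatchDeepModelSketch {
        interface AtomStore {}

        static class DeepModelPredicate {
            private AtomStore atomStore;

            DeepModelPredicate copy() {
                // The real copy would duplicate the deep model's state.
                return new DeepModelPredicate();
            }

            void setAtomStore(AtomStore atomStore, boolean rebuild) {
                this.atomStore = atomStore;
            }
        }

        static List<DeepModelPredicate> copyForBatch(List<DeepModelPredicate> deepModels, AtomStore batchAtomStore) {
            List<DeepModelPredicate> batchCopies = new ArrayList<>();
            for (DeepModelPredicate deepModel : deepModels) {
                DeepModelPredicate batchCopy = deepModel.copy();
                // Rebind the copy to this batch's atom store.
                batchCopy.setAtomStore(batchAtomStore, true);
                batchCopies.add(batchCopy);
            }
            return batchCopies;
        }
    }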
@@ -56,7 +56,8 @@ public abstract class Minimizer extends GradientDescent {
     private static final Logger log = Logger.getLogger(Minimizer.class);
 
     protected float latentInferenceEnergy;
-    protected float[] latentInferenceIncompatibility;
+    protected float[] symbolicWeightRuleLatentInferenceIncompatibility;
+    protected float[] deepWeightRuleLatentInferenceIncompatibility;
     protected TermState[] latentInferenceTermState;
     protected float[] latentInferenceAtomValueState;
     protected List<TermState[]> batchLatentInferenceTermStates;
@@ -66,12 +67,16 @@ public abstract class Minimizer extends GradientDescent {
     protected float energyLossCoefficient;
 
     protected float mapEnergy;
-    protected float[] mapIncompatibility;
-    protected float[] mapSquaredIncompatibility;
+    protected float[] symbolicWeightRuleMAPIncompatibility;
+    protected float[] deepWeightRuleMAPIncompatibility;
+    protected float[] symbolicWeightRuleMAPSquaredIncompatibility;
+    protected float[] deepWeightRuleMAPSquaredIncompatibility;
 
     protected float augmentedInferenceEnergy;
-    protected float[] augmentedInferenceIncompatibility;
-    protected float[] augmentedInferenceSquaredIncompatibility;
+    protected float[] symbolicWeightRuleAugmentedInferenceIncompatibility;
+    protected float[] deepWeightRuleAugmentedInferenceIncompatibility;
+    protected float[] symbolicWeightRuleAugmentedInferenceSquaredIncompatibility;
+    protected float[] deepWeightRuleAugmentedInferenceSquaredIncompatibility;
 
     protected TermState[] augmentedInferenceTermState;
     protected float[] augmentedInferenceAtomValueState;
@@ -122,7 +127,8 @@ public Minimizer(List<Rule> rules, Database trainTargetDatabase, Database trainT
         super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
 
         latentInferenceEnergy = Float.POSITIVE_INFINITY;
-        latentInferenceIncompatibility = new float[mutableRules.size()];
+        symbolicWeightRuleLatentInferenceIncompatibility = new float[mutableRules.size()];
+        deepWeightRuleLatentInferenceIncompatibility = null;
         latentInferenceTermState = null;
         latentInferenceAtomValueState = null;
         batchLatentInferenceTermStates = new ArrayList<TermState[]>();
@@ -132,12 +138,16 @@
         energyLossCoefficient = Options.MINIMIZER_ENERGY_LOSS_COEFFICIENT.getFloat();
 
         mapEnergy = Float.POSITIVE_INFINITY;
-        mapIncompatibility = new float[mutableRules.size()];
-        mapSquaredIncompatibility = new float[mutableRules.size()];
+        symbolicWeightRuleMAPIncompatibility = new float[mutableRules.size()];
+        deepWeightRuleMAPIncompatibility = null;
+        symbolicWeightRuleMAPSquaredIncompatibility = new float[mutableRules.size()];
+        deepWeightRuleMAPSquaredIncompatibility = null;
 
         augmentedInferenceEnergy = Float.POSITIVE_INFINITY;
-        augmentedInferenceIncompatibility = new float[mutableRules.size()];
-        augmentedInferenceSquaredIncompatibility = new float[mutableRules.size()];
+        symbolicWeightRuleAugmentedInferenceIncompatibility = new float[mutableRules.size()];
+        deepWeightRuleAugmentedInferenceIncompatibility = null;
+        symbolicWeightRuleAugmentedInferenceSquaredIncompatibility = new float[mutableRules.size()];
+        deepWeightRuleAugmentedInferenceSquaredIncompatibility = null;
         augmentedInferenceTermState = null;
         augmentedInferenceAtomValueState = null;
         batchAugmentedInferenceTermStates = new ArrayList<TermState[]>();
@@ -180,6 +190,19 @@ public Minimizer(List<Rule> rules, Database trainTargetDatabase, Database trainT
         finalConstraintTolerance = Options.MINIMIZER_OBJECTIVE_DIFFERENCE_TOLERANCE.getFloat();
     }
 
+    @Override
+    protected void postInitGroundModel() {
+        super.postInitGroundModel();
+
+        deepWeightRuleLatentInferenceIncompatibility = new float[groundedDeepWeightedRules.size()];
+
+        deepWeightRuleMAPIncompatibility = new float[groundedDeepWeightedRules.size()];
+        deepWeightRuleMAPSquaredIncompatibility = new float[groundedDeepWeightedRules.size()];
+
+        deepWeightRuleAugmentedInferenceIncompatibility = new float[groundedDeepWeightedRules.size()];
+        deepWeightRuleAugmentedInferenceSquaredIncompatibility = new float[groundedDeepWeightedRules.size()];
+    }
+
     @Override
     protected void initForLearning() {
         super.initForLearning();
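The deep-weight incompatibility buffers are set to null in the constructor and only allocated here, once grounding has produced groundedDeepWeightedRules and its size is known. A compact sketch of that two-phase allocation pattern (names hypothetical):

    // Two-phase allocation: the constructor cannot size per-ground-rule buffers,
    // so allocation waits until after grounding.
    public abstract class TwoPhaseAllocationSketch {
        protected float[] deepWeightRuleIncompatibility;

        public TwoPhaseAllocationSketch() {
            // Grounded deep-weighted rules do not exist yet.
            deepWeightRuleIncompatibility = null;
        }

        protected void postInitGroundModel() {
            // Safe now: grounding has produced the deep-weighted ground rules.
            deepWeightRuleIncompatibility = new float[numGroundedDeepWeightedRules()];
        }

        protected abstract int numGroundedDeepWeightedRules();
    }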
@@ -544,9 +567,9 @@ private void computeMAPInferenceStatistics() {
         inTrainingMAPState = true;
 
         mapEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
-        computeCurrentIncompatibility(mapIncompatibility);
-        computeCurrentSquaredIncompatibility(mapSquaredIncompatibility);
-        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
+        computeCurrentIncompatibility(symbolicWeightRuleMAPIncompatibility, deepWeightRuleMAPIncompatibility);
+        computeCurrentSquaredIncompatibility(symbolicWeightRuleMAPSquaredIncompatibility, deepWeightRuleMAPSquaredIncompatibility);
+        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), expressionRVAtomMAPEnergyGradient, expressionDeepAtomMAPEnergyGradient);
     }
 
     /**
@@ -560,8 +583,8 @@ protected void computeAugmentedInferenceStatistics() {
         inTrainingMAPState = true;
 
         augmentedInferenceEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
-        computeCurrentIncompatibility(augmentedInferenceIncompatibility);
-        computeCurrentSquaredIncompatibility(augmentedInferenceSquaredIncompatibility);
+        computeCurrentIncompatibility(symbolicWeightRuleAugmentedInferenceIncompatibility, deepWeightRuleAugmentedInferenceIncompatibility);
+        computeCurrentSquaredIncompatibility(symbolicWeightRuleAugmentedInferenceSquaredIncompatibility, deepWeightRuleAugmentedInferenceSquaredIncompatibility);
         trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), augmentedRVAtomEnergyGradient, augmentedDeepAtomEnergyGradient);
 
         deactivateAugmentedInferenceProxTerms();
@@ -579,7 +602,7 @@ protected void computeLatentInferenceStatistics() {
         inTrainingMAPState = true;
 
         latentInferenceEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
-        computeCurrentIncompatibility(latentInferenceIncompatibility);
+        computeCurrentIncompatibility(symbolicWeightRuleLatentInferenceIncompatibility, deepWeightRuleLatentInferenceIncompatibility);
         trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), rvLatentAtomGradient, deepLatentAtomGradient);
 
         unfixLabeledRandomVariables();
@@ -698,10 +721,19 @@ protected void addAugmentedLagrangianProxRuleConstantsGradient() {
     protected abstract void addSupervisedProxRuleObservedAtomValueGradient();
 
     @Override
-    protected void addLearningLossWeightGradient() {
+    protected void addLearningLossSymbolicWeightGradient() {
+        addRuleWeightGradient(symbolicWeightGradient, symbolicWeightRuleMAPIncompatibility, symbolicWeightRuleLatentInferenceIncompatibility, symbolicWeightRuleAugmentedInferenceIncompatibility);
+    }
+
+    @Override
+    protected void addTotalDeepRuleWeightGradient() {
+        addRuleWeightGradient(deepWeightGradient, deepWeightRuleMAPIncompatibility, deepWeightRuleLatentInferenceIncompatibility, deepWeightRuleAugmentedInferenceIncompatibility);
+    }
+
+    private void addRuleWeightGradient(float[] ruleWeightGradient, float[] MAPIncompatibility, float[] latentInferenceIncompatibility, float[] augmentedInferenceIncompatibility) {
         // Energy loss gradient.
-        for (int i = 0; i < mutableRules.size(); i++) {
-            weightGradient[i] += energyLossCoefficient * latentInferenceIncompatibility[i];
+        for (int i = 0; i < ruleWeightGradient.length; i++) {
+            ruleWeightGradient[i] += energyLossCoefficient * latentInferenceIncompatibility[i];
         }
 
         // Energy difference constraint gradient.
@@ -710,15 +742,15 @@ protected void addLearningLossWeightGradient() {
             return;
         }
 
-        for (int i = 0; i < mutableRules.size(); i++) {
-            weightGradient[i] += linearPenaltyCoefficient * (augmentedInferenceIncompatibility[i] - mapIncompatibility[i]);
-            weightGradient[i] += squaredPenaltyCoefficient * (augmentedInferenceEnergy - mapEnergy - constraintRelaxationConstant)
-                    * (augmentedInferenceIncompatibility[i] - mapIncompatibility[i]);
+        for (int i = 0; i < ruleWeightGradient.length; i++) {
+            ruleWeightGradient[i] += linearPenaltyCoefficient * (augmentedInferenceIncompatibility[i] - MAPIncompatibility[i]);
+            ruleWeightGradient[i] += squaredPenaltyCoefficient * (augmentedInferenceEnergy - mapEnergy - constraintRelaxationConstant)
+                    * (augmentedInferenceIncompatibility[i] - MAPIncompatibility[i]);
         }
     }
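Both weight-gradient overrides now funnel into one private helper, because symbolic and deep rule weights share the same gradient form: an energy-loss term plus an augmented-Lagrangian penalty on the energy gap between augmented inference and MAP. Per rule i, roughly: gradient[i] += energyLossCoefficient * latentIncompatibility[i] + linearPenalty * (augIncompatibility[i] - mapIncompatibility[i]) + squaredPenalty * (augEnergy - mapEnergy - relaxation) * (augIncompatibility[i] - mapIncompatibility[i]). A self-contained sketch of that accumulation (the early-return condition between the two terms in the real method is elided):

    // Sketch of the shared rule-weight gradient; all names are illustrative.
    public final class RuleWeightGradientSketch {
        public static void accumulate(
                float[] gradient,
                float[] mapIncompatibility,
                float[] latentIncompatibility,
                float[] augmentedIncompatibility,
                float energyLossCoefficient,
                float linearPenaltyCoefficient,
                float squaredPenaltyCoefficient,
                float augmentedEnergy,
                float mapEnergy,
                float relaxationConstant) {
            float energyViolation = augmentedEnergy - mapEnergy - relaxationConstant;

            for (int i = 0; i < gradient.length; i++) {
                // Energy loss gradient.
                gradient[i] += energyLossCoefficient * latentIncompatibility[i];

                // Augmented-Lagrangian penalty on the energy-difference constraint.
                float incompatibilityDifference = augmentedIncompatibility[i] - mapIncompatibility[i];
                gradient[i] += linearPenaltyCoefficient * incompatibilityDifference;
                gradient[i] += squaredPenaltyCoefficient * energyViolation * incompatibilityDifference;
            }
        }
    }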

     @Override
-    protected void addTotalAtomGradient() {
+    protected void addTotalExpressionAtomGradient() {
         // Energy Loss Gradient.
         for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
             GroundAtom atom = trainInferenceApplication.getTermStore().getAtomStore().getAtom(i);
@@ -727,7 +759,7 @@ protected void addTotalAtomGradient() {
                 continue;
             }
 
-            deepGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
+            expressionDeepAtomGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
         }
 
         // Energy difference constraint gradient.
@@ -745,12 +777,12 @@ protected void addTotalAtomGradient() {
                 continue;
             }
 
-            float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVEnergyGradient[i];
-            float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepEnergyGradient[i];
+            float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - expressionRVAtomMAPEnergyGradient[i];
+            float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - expressionDeepAtomMAPEnergyGradient[i];
 
-            rvGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
+            expressionRVAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
                     + linearPenaltyCoefficient * rvEnergyGradientDifference;
-            deepGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
+            expressionDeepAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
                     + linearPenaltyCoefficient * deepAtomEnergyGradientDifference;
         }
     }
@@ -796,9 +828,10 @@ protected float computeGradientNorm() {
     /**
      * A method for computing the squared incompatibilities of the rules with atom values in their current state.
      */
-    protected void computeCurrentSquaredIncompatibility(float[] incompatibilityArray) {
+    protected void computeCurrentSquaredIncompatibility(float[] symbolicWeightRuleIncompatibility, float[] deepWeightRuleIncompatibility) {
         // Zero out the incompatibility first.
-        Arrays.fill(incompatibilityArray, 0.0f);
+        Arrays.fill(symbolicWeightRuleIncompatibility, 0.0f);
+        Arrays.fill(deepWeightRuleIncompatibility, 0.0f);
 
         float[] atomValues = trainInferenceApplication.getTermStore().getAtomStore().getAtomValues();
 
@@ -811,13 +844,27 @@ protected void computeCurrentSquaredIncompatibility(float[] incompatibilityArray
                 continue;
             }
 
-            Integer index = ruleIndexMap.get((WeightedRule)term.getRule());
+            Weight weight = ((WeightedRule)term.getRule()).getWeight();
+
+            Integer index = null;
+            if (weight.isDeep()) {
+                index = groundedDeepWeightedRuleIndexMap.get((WeightedRule) term.getRule());
+            } else {
+                index = symbolicWeightedRuleIndexMap.get((WeightedRule) term.getRule());
+            }
 
             if (index == null) {
                 // Relaxed constraints are weighted rules that are not part of the optimization.
                 continue;
             }
 
-            incompatibilityArray[index] += term.evaluateSquaredHingeLoss(atomValues);
+            float squaredIncompatibility = term.evaluateSquaredHingeLoss(atomValues);
+
+            if (weight.isDeep()) {
+                deepWeightRuleIncompatibility[index] += squaredIncompatibility;
+            } else {
+                symbolicWeightRuleIncompatibility[index] += squaredIncompatibility;
+            }
         }
     }
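computeCurrentSquaredIncompatibility now routes each ground term's squared hinge loss into one of two accumulators, picking the index map by the weight's type and skipping relaxed constraints that have no index. A generic sketch of that dispatch, with the rule reduced to a map key plus an isDeep flag (the PSL types carry more):

    import java.util.Map;

    final class IncompatibilityDispatchSketch {
        static <R> void accumulate(
                R rule,
                boolean isDeepWeight,
                float squaredHingeLoss,
                Map<R, Integer> symbolicIndexMap,
                Map<R, Integer> deepIndexMap,
                float[] symbolicIncompatibility,
                float[] deepIncompatibility) {
            Integer index = isDeepWeight ? deepIndexMap.get(rule) : symbolicIndexMap.get(rule);

            if (index == null) {
                // Relaxed constraints are weighted rules outside the optimization; skip them.
                return;
            }

            if (isDeepWeight) {
                deepIncompatibility[index] += squaredHingeLoss;
            } else {
                symbolicIncompatibility[index] += squaredHingeLoss;
            }
        }
    }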
