Skip to content

Commit

Permalink
Deep weighted logical rules.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed May 17, 2024
1 parent b2b6a64 commit 552a2f5
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public interface Formula extends Serializable {
/**
* Collapses nested formulas of the same type and remove duplicates at the top level.
* Does not change the context object.
* Order is not guarenteed.
* Order is not guaranteed.
* Ex: (A ^ B) ^ !!C ^ (D v E) becomes A ^ B ^ C ^ (D v E).
*
* Note that most formulas will return an object of the same type (eg a Conjunction will
Expand Down
18 changes: 16 additions & 2 deletions psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.util.HashCode;

/**
* A weight for a rule.
Expand Down Expand Up @@ -89,12 +90,25 @@ public String toString() {
}
}

public boolean equals(Object o) {
if (this == o) {
return true;
}

if (o == null || this.getClass() != o.getClass()) {
return false;
}

Weight other = (Weight)o;
return constantValue == other.constantValue && atom == other.atom;
}

public int hashCode() {
// Use the hash of the atom if it exists. Else, use the object's hash.
// Use the hash of the atom if it exists. Else, use a default hash shared by all constant weights.
if (atom != null) {
return atom.hashCode();
} else {
return super.hashCode();
return HashCode.DEFAULT_INITIAL_NUMBER;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,6 @@ private void groundSingleNonSummationRule(

resources.weightGroundAtom = weightGroundAtom;


if (weightGroundAtom instanceof UnmanagedRandomVariableAtom) {
resources.accessExceptionAtoms.add(weightGroundAtom);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2024 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.model.rule.arithmetic;

import org.linqs.psl.model.atom.GroundAtom;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ public boolean equals(Object other) {
return false;
}

if (!this.weight.equals(otherRule.weight)) {
return false;
}

return super.equals(other);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.linqs.psl.model.predicate.GroundingOnlyPredicate;
import org.linqs.psl.model.rule.AbstractRule;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Weight;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.reasoner.term.TermStore;
Expand All @@ -60,13 +61,15 @@ public abstract class AbstractLogicalRule extends AbstractRule {
/**
* A key to store per-rule threading grounding resource under.
*/
private final String groundingResourcesKey;
protected final String groundingResourcesKey;

protected Formula formula;
protected final DNFClause negatedDNF;

protected AbstractLogicalRule(Formula formula, String name) {
protected AbstractLogicalRule(Formula formula, String name, int hashcode) {
this.name = name;
this.active = true;
this.hashcode = hashcode;

this.formula = formula;
groundingResourcesKey = AbstractLogicalRule.class.getName() + ";" + formula + ";GroundingResources";
Expand Down Expand Up @@ -101,21 +104,10 @@ protected AbstractLogicalRule(Formula formula, String name) {
throw new IllegalArgumentException("Formula is not a valid rule for unknown reason.");
}

// Build up the hash code from positive and negative literals.
int hash = HashCode.DEFAULT_INITIAL_NUMBER;

for (Atom atom : negatedDNF.getPosLiterals()) {
hash = HashCode.build(hash, atom);
}

for (Atom atom : negatedDNF.getNegLiterals()) {
hash = HashCode.build(hash, atom);
}

this.hashcode = hash;
this.parentHashCode = hash;

ensureRegistration();

this.parentHashCode = hashcode;
this.childHashCodes = new HashSet<Integer>();
}

public Formula getFormula() {
Expand Down Expand Up @@ -163,15 +155,7 @@ public void ground(Constant[] constants, Map<Variable, Integer> variableMap, Dat
results.add(ground(constants, variableMap, database));
}

private GroundRule ground(Constant[] constants, Map<Variable, Integer> variableMap, Database database) {
// Get the grounding resources for this thread,
if (!Parallel.hasThreadObject(groundingResourcesKey)) {
Parallel.putThreadObject(groundingResourcesKey, new GroundingResources(negatedDNF));
}
GroundingResources resources = (GroundingResources)Parallel.getThreadObject(groundingResourcesKey);

return groundInternal(constants, variableMap, database, resources);
}
protected abstract GroundRule ground(Constant[] constants, Map<Variable, Integer> variableMap, Database database);

public long groundAll(QueryResultIterable groundVariables, TermStore termStore, Database database, Grounding.GroundRuleCallback groundRuleCallback) {
long initialCount = termStore.size();
Expand Down Expand Up @@ -238,9 +222,9 @@ public boolean equals(Object other) {
(new HashSet<Atom>(thisNegLiterals)).equals(new HashSet<Atom>(otherNegLiterals));
}

protected abstract AbstractGroundLogicalRule groundFormulaInstance(List<GroundAtom> positiveAtoms, List<GroundAtom> negativeAtoms);
protected abstract AbstractGroundLogicalRule groundFormulaInstance(List<GroundAtom> positiveAtoms, List<GroundAtom> negativeAtoms, Weight groundedWeight);

private GroundRule groundInternal(Constant[] row, Map<Variable, Integer> variableMap,
protected GroundRule groundInternal(Constant[] row, Map<Variable, Integer> variableMap,
Database database, GroundingResources resources) {
resources.positiveAtoms.clear();
resources.negativeAtoms.clear();
Expand Down Expand Up @@ -318,7 +302,26 @@ private GroundRule groundInternal(Constant[] row, Map<Variable, Integer> variabl
return null;
}

return groundFormulaInstance(resources.positiveAtoms, resources.negativeAtoms);
// Ground the deep weight if it exists.
if (resources.weightQueryAtom != null) {
GroundAtom weightGroundAtom = resources.weightQueryAtom.ground(database, row, variableMap, resources.weightArgumentsBuffer, -1.0f);
if (weightGroundAtom == null) {
return null;
}

resources.weightGroundAtom = weightGroundAtom;

if (weightGroundAtom instanceof UnmanagedRandomVariableAtom) {
resources.accessExceptionAtoms.add(weightGroundAtom);
}
}

Weight groundedWeight = null;
if (resources.weightGroundAtom != null) {
groundedWeight = new Weight(1.0f, resources.weightGroundAtom);
}

return groundFormulaInstance(resources.positiveAtoms, resources.negativeAtoms, groundedWeight);
}

private short createAtoms(Database database, Map<Variable, Integer> variableMap,
Expand Down Expand Up @@ -359,37 +362,4 @@ private short createAtoms(Database database, Map<Variable, Integer> variableMap,

return rvaCount;
}

/**
* Allocated resources needed for grounding.
* This will be stashed in the thread objects so each thread will have one.
*/
private static class GroundingResources {
// Remember that these are positive/negative in the CNF.
public List<GroundAtom> positiveAtoms;
public List<GroundAtom> negativeAtoms;

// Atoms that cause trouble for the atom manager.
public Set<GroundAtom> accessExceptionAtoms;

// Allocate up-front some buffers for grounding QueryAtoms into.
public Constant[][] positiveAtomArgs;
public Constant[][] negativeAtomArgs;

public GroundingResources(DNFClause negatedDNF) {
positiveAtoms = new ArrayList<GroundAtom>(4);
negativeAtoms = new ArrayList<GroundAtom>(4);
accessExceptionAtoms = new HashSet<GroundAtom>(4);

positiveAtomArgs = new Constant[negatedDNF.getPosLiterals().size()][];
for (int i = 0; i < negatedDNF.getPosLiterals().size(); i++) {
positiveAtomArgs[i] = new Constant[negatedDNF.getPosLiterals().get(i).getArity()];
}

negativeAtomArgs = new Constant[negatedDNF.getNegLiterals().size()][];
for (int i = 0; i < negatedDNF.getNegLiterals().size(); i++) {
negativeAtomArgs[i] = new Constant[negatedDNF.getNegLiterals().get(i).getArity()];
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2024 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.model.rule.logical;

import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.QueryAtom;
import org.linqs.psl.model.formula.FormulaAnalysis;
import org.linqs.psl.model.rule.Weight;
import org.linqs.psl.model.term.Constant;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* Allocated resources needed for grounding.
* This will be stashed in the thread objects so each thread will have one.
*/
public class GroundingResources {
// Remember that these are positive/negative in the CNF.
public List<GroundAtom> positiveAtoms;
public List<GroundAtom> negativeAtoms;

public QueryAtom weightQueryAtom;
public GroundAtom weightGroundAtom;
public Constant[] weightArgumentsBuffer;


// Atoms that cause trouble for the atom manager.
public Set<GroundAtom> accessExceptionAtoms;

// Allocate up-front some buffers for grounding QueryAtoms into.
public Constant[][] positiveAtomArgs;
public Constant[][] negativeAtomArgs;

public GroundingResources() {
positiveAtoms = null;
negativeAtoms = null;
accessExceptionAtoms = null;

positiveAtomArgs = null;
negativeAtomArgs = null;

weightQueryAtom = null;
weightGroundAtom = null;
weightArgumentsBuffer = null;
}

public void parseNegatedDNF(FormulaAnalysis.DNFClause negatedDNF, Weight weight) {
positiveAtoms = new ArrayList<GroundAtom>(4);
negativeAtoms = new ArrayList<GroundAtom>(4);
accessExceptionAtoms = new HashSet<GroundAtom>(4);

positiveAtomArgs = new Constant[negatedDNF.getPosLiterals().size()][];
for (int i = 0; i < negatedDNF.getPosLiterals().size(); i++) {
positiveAtomArgs[i] = new Constant[negatedDNF.getPosLiterals().get(i).getArity()];
}

negativeAtomArgs = new Constant[negatedDNF.getNegLiterals().size()][];
for (int i = 0; i < negatedDNF.getNegLiterals().size(); i++) {
negativeAtomArgs[i] = new Constant[negatedDNF.getNegLiterals().get(i).getArity()];
}

if ((weight != null) && (weight.isDeep())) {
assert (weight.getAtom() instanceof QueryAtom);

weightQueryAtom = (QueryAtom)weight.getAtom();
weightArgumentsBuffer = new Constant[weightQueryAtom.getArity()];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,39 @@
package org.linqs.psl.model.rule.logical;

import java.util.List;
import java.util.Map;

import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.UnweightedRule;
import org.linqs.psl.model.rule.Weight;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.util.Parallel;

public class UnweightedLogicalRule extends AbstractLogicalRule implements UnweightedRule {
public UnweightedLogicalRule(Formula formula) {
this(formula, formula.toString());
}

public UnweightedLogicalRule(Formula formula, String name) {
super(formula, name);
super(formula, name, formula.getDNF().hashCode());
}

@Override
protected GroundRule ground(Constant[] constants, Map<Variable, Integer> variableMap, Database database) {
// Get the grounding resources for this thread,
if (!Parallel.hasThreadObject(groundingResourcesKey)) {
GroundingResources groundingResources = new GroundingResources();
groundingResources.parseNegatedDNF(negatedDNF, null);
Parallel.putThreadObject(groundingResourcesKey, groundingResources);
}
GroundingResources groundingResources = (GroundingResources)Parallel.getThreadObject(groundingResourcesKey);

return groundInternal(constants, variableMap, database, groundingResources);
}

@Override
Expand All @@ -41,7 +60,8 @@ public WeightedRule relax(Weight weight, boolean squared) {
}

@Override
protected AbstractGroundLogicalRule groundFormulaInstance(List<GroundAtom> posLiterals, List<GroundAtom> negLiterals) {
protected AbstractGroundLogicalRule groundFormulaInstance(List<GroundAtom> posLiterals, List<GroundAtom> negLiterals, Weight groundedWeight) {
assert groundedWeight == null;
return new UnweightedGroundLogicalRule(this, posLiterals, negLiterals);
}

Expand Down
Loading

0 comments on commit 552a2f5

Please sign in to comment.