Loss augmented inference using SparseNetworks #445

Merged: 38 commits (Dec 5, 2016)

Commits
be2f7b4
-added the Badge Example
kordjamshidi Nov 4, 2016
b6acc34
-added the Badge Example Reader
kordjamshidi Nov 4, 2016
4d4d979
-added Badge example with loss augmented inference
kordjamshidi Nov 4, 2016
75a071e
-format
kordjamshidi Nov 4, 2016
e1a106e
-fixed the test
kordjamshidi Nov 4, 2016
9f40b24
-fixed the tests due to the fix in initialization
kordjamshidi Nov 4, 2016
4a578c4
-test size of weights
kordjamshidi Nov 4, 2016
cd57748
learning configuration
kordjamshidi Nov 5, 2016
55a808c
Merge remote-tracking branch 'upstream/master' into loss-augmented
kordjamshidi Nov 5, 2016
09f4aaa
-added documentation
kordjamshidi Nov 5, 2016
bca4262
-modified and documented the badge example
kordjamshidi Nov 5, 2016
ea0198b
-assert the type
kordjamshidi Nov 6, 2016
ed706a3
-minor
kordjamshidi Nov 6, 2016
da9b768
-added pipeline example with Badge
kordjamshidi Nov 10, 2016
b15189e
-improved documentation
kordjamshidi Nov 11, 2016
0776234
-improved documentation
kordjamshidi Nov 11, 2016
053f965
-relative path works for java folder?!
kordjamshidi Nov 14, 2016
77b0a13
-relative path works for java folder?!
kordjamshidi Nov 14, 2016
73c67a0
-improved documentation
kordjamshidi Nov 14, 2016
4ad5840
format
kordjamshidi Nov 14, 2016
c104fff
SRL join-training
kordjamshidi Nov 16, 2016
e668bdd
Fixed the RunningApps name in the documentation
kordjamshidi Nov 16, 2016
dce32e9
temporarily removed joinnodes populate for SRL experiments
kordjamshidi Nov 16, 2016
8059746
-train mode
kordjamshidi Nov 16, 2016
f91e18d
-use Gurobi
kordjamshidi Nov 16, 2016
e006bb9
-remove logger messages
kordjamshidi Nov 16, 2016
ff5c9f3
-jointTrain setting
kordjamshidi Nov 18, 2016
754d694
format
kordjamshidi Nov 18, 2016
92de796
-fixed the test units path for SRL
kordjamshidi Nov 18, 2016
9cf1f25
-format
kordjamshidi Nov 18, 2016
5186251
-replaced configuration parameters
kordjamshidi Nov 18, 2016
e1ea9a4
-replaced configuration parameters
kordjamshidi Nov 18, 2016
aba475b
-added results of join training (IBT) with SRL ArgTypeClassifier
kordjamshidi Dec 3, 2016
a65cd4e
-brought the logger messages back
kordjamshidi Dec 3, 2016
bcdba7e
-changed back the solver for tests
kordjamshidi Dec 3, 2016
dab55f1
-changed back the commented out join node population
kordjamshidi Dec 3, 2016
4d8f361
-fixed typos in blocking
kordjamshidi Dec 5, 2016
b0baf94
-fixed typos in blocking
kordjamshidi Dec 5, 2016
@@ -18,16 +18,16 @@ object JointTrainSparseNetwork {

val logger: Logger = LoggerFactory.getLogger(this.getClass)
var difference = 0
def apply[HEAD <: AnyRef](node: Node[HEAD], cls: List[ConstrainedClassifier[_, HEAD]], init: Boolean)(implicit headTag: ClassTag[HEAD]) = {
train[HEAD](node, cls, 1, init)
def apply[HEAD <: AnyRef](node: Node[HEAD], cls: List[ConstrainedClassifier[_, HEAD]], init: Boolean, lossAugmented: Boolean)(implicit headTag: ClassTag[HEAD]) = {
Member:

could you add a doc to this function and explain what it does as well as the parameters?

danyaljj (Member), Nov 14, 2016:

Actually here what I meant was documentation for the function.
Like:

/** 
* This function does blah blah ... 
* @param node .... 
* @param cls ... 
* .... 
* @param lossAugmented ....
*/

train[HEAD](node, cls, 1, init, lossAugmented)
}

def apply[HEAD <: AnyRef](node: Node[HEAD], cls: List[ConstrainedClassifier[_, HEAD]], it: Int, init: Boolean)(implicit headTag: ClassTag[HEAD]) = {
train[HEAD](node, cls, it, init)
def apply[HEAD <: AnyRef](node: Node[HEAD], cls: List[ConstrainedClassifier[_, HEAD]], it: Int, init: Boolean, lossAugmented: Boolean = false)(implicit headTag: ClassTag[HEAD]) = {
train[HEAD](node, cls, it, init, lossAugmented)
}

@scala.annotation.tailrec
def train[HEAD <: AnyRef](node: Node[HEAD], cls: List[ConstrainedClassifier[_, HEAD]], it: Int, init: Boolean)(implicit headTag: ClassTag[HEAD]): Unit = {
def train[HEAD <: AnyRef](node: Node[HEAD], cls: List[ConstrainedClassifier[_, HEAD]], it: Int, init: Boolean, lossAugmented: Boolean = false)(implicit headTag: ClassTag[HEAD]): Unit = {
// forall members in collection of the head (dm.t) do
logger.info("Training iteration: " + it)
Contributor:

We should add an assertion here to check that the base classifiers are of the type SparseNetworkLearner. Also you can add that to the function documentation.

Member Author:

I guess I already had a line about it in the new documentation.

if (init) ClassifierUtils.InitializeClassifiers(node, cls: _*)
@@ -43,19 +43,24 @@ object JointTrainSparseNetwork {
if (idx % 5000 == 0)
logger.info(s"Training: $idx examples inferred.")

cls.foreach {
case classifier: ConstrainedClassifier[_, HEAD] =>
val typedClassifier = classifier.asInstanceOf[ConstrainedClassifier[_, HEAD]]
val oracle = typedClassifier.onClassifier.getLabeler
if (lossAugmented)
cls.foreach { cls_i =>
cls_i.onClassifier.classifier.setLossFlag()
cls_i.onClassifier.classifier.setCandidates(cls_i.getCandidates(h).size * cls.size)
}

typedClassifier.getCandidates(h) foreach {
cls.foreach {
currentClassifier: ConstrainedClassifier[_, HEAD] =>
val oracle = currentClassifier.onClassifier.getLabeler
val baseClassifier = currentClassifier.onClassifier.classifier.asInstanceOf[SparseNetworkLearner]
currentClassifier.getCandidates(h) foreach {
candidate =>
{
def trainOnce() = {
val result = typedClassifier.classifier.discreteValue(candidate)

val result = currentClassifier.classifier.discreteValue(candidate)
val trueLabel = oracle.discreteValue(candidate)
val ilearner = typedClassifier.onClassifier.classifier.asInstanceOf[SparseNetworkLearner]
val lLexicon = typedClassifier.onClassifier.getLabelLexicon
val lLexicon = currentClassifier.onClassifier.getLabelLexicon
var LTU_actual: Int = 0
var LTU_predicted: Int = 0
for (i <- 0 until lLexicon.size()) {
@@ -69,26 +74,26 @@ object JointTrainSparseNetwork {
// and the LTU of the predicted class should be demoted.
if (!result.equals(trueLabel)) //equals("true") && trueLabel.equals("false") )
{
val a = typedClassifier.onClassifier.getExampleArray(candidate)
val a = currentClassifier.onClassifier.getExampleArray(candidate)
val a0 = a(0).asInstanceOf[Array[Int]] //exampleFeatures
val a1 = a(1).asInstanceOf[Array[Double]] // exampleValues
val exampleLabels = a(2).asInstanceOf[Array[Int]]
val label = exampleLabels(0)
var N = ilearner.getNetwork.size
var N = baseClassifier.getNetwork.size

if (label >= N || ilearner.getNetwork.get(label) == null) {
val conjugateLabels = ilearner.isUsingConjunctiveLabels | ilearner.getLabelLexicon.lookupKey(label).isConjunctive
ilearner.setConjunctiveLabels(conjugateLabels)
if (label >= N || baseClassifier.getNetwork.get(label) == null) {
val conjugateLabels = baseClassifier.isUsingConjunctiveLabels | baseClassifier.getLabelLexicon.lookupKey(label).isConjunctive
baseClassifier.setConjunctiveLabels(conjugateLabels)

val ltu: LinearThresholdUnit = ilearner.getBaseLTU
ltu.initialize(ilearner.getNumExamples, ilearner.getNumFeatures)
ilearner.getNetwork.set(label, ltu)
val ltu: LinearThresholdUnit = baseClassifier.getBaseLTU.clone().asInstanceOf[LinearThresholdUnit]
ltu.initialize(baseClassifier.getNumExamples, baseClassifier.getNumFeatures)
baseClassifier.getNetwork.set(label, ltu)
N = label + 1
Contributor:

This line is not required. Also N can be made a val.

}

// test push
val ltu_actual = ilearner.getLTU(LTU_actual).asInstanceOf[LinearThresholdUnit]
val ltu_predicted = ilearner.getLTU(LTU_predicted).asInstanceOf[LinearThresholdUnit]
val ltu_actual = baseClassifier.getLTU(LTU_actual).asInstanceOf[LinearThresholdUnit]
val ltu_predicted = baseClassifier.getLTU(LTU_predicted).asInstanceOf[LinearThresholdUnit]

if (ltu_actual != null)
ltu_actual.promote(a0, a1, 0.1)
Contributor:

We are promoting/demoting by a fixed update of 0.1; shouldn't we take the learning rate parameter into account? The update rule inside LinearThresholdUnit's learn function uses the learning rate and margin thickness.

kordjamshidi (Member Author), Nov 6, 2016:

Yes, this has remained here from my very first trial version. How should I pass the parameters? Do you think I should just add them to the list of input parameters? Since we have two apply versions, it cannot have a default value in both cases, I guess. Isn't having a consistent way of setting parameters in Saul a separate issue?

Contributor:

The baseLTU already has all parameters to use. We can directly call the learn function to use those parameters.

val labelValues = a(3).asInstanceOf[Array[Double]]

if (ltu_actual != null) {
    // learn as a positive example
    ltu_actual.learn(a0, a1, Array(1), labelValues)
}

if (ltu_predicted != null) {
    // learn as a negative example
    ltu_predicted.learn(a0, a1, Array(0), labelValues)
}

Also, it might be better to rename the variables a, a0, a1, etc. for better readability.

Member Author:

Call learn?! And what are we doing here, then?

Member Author:

learn does not use the internal prediction result?

Contributor:

https://github.com/IllinoisCogComp/lbjava/blob/master/lbjava/src/main/java/edu/illinois/cs/cogcomp/lbjava/learn/LinearThresholdUnit.java#L462

learn promotes or demotes the LTU's weight vector; the third argument controls whether promote or demote is called.

Member Author:

what about the score, s?

Contributor:

Looks fine if we cannot use learn. My only concern was using that having a fixed learning rate might affect performance. We can fix that separately.
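The discussion above is about how an LTU's learn dispatches to promote/demote based on the label argument, using the unit's own learning rate instead of a hard-coded 0.1. A minimal sketch of that dispatch pattern (hypothetical names; not the actual LBJava implementation, whose learn takes additional arguments):

```java
// Minimal model of a LinearThresholdUnit-style learner: the label
// argument decides whether the example is treated as positive
// (promote) or negative (demote), and updates use the unit's own
// learning rate rather than a fixed constant such as 0.1.
public class SketchLTU {
    private final double[] weights;
    private final double learningRate;

    public SketchLTU(int numFeatures, double learningRate) {
        this.weights = new double[numFeatures];
        this.learningRate = learningRate;
    }

    // label == 1 -> promote, otherwise demote
    public void learn(int[] features, double[] values, int label) {
        if (label == 1) promote(features, values, learningRate);
        else demote(features, values, learningRate);
    }

    public void promote(int[] features, double[] values, double rate) {
        for (int i = 0; i < features.length; i++)
            weights[features[i]] += rate * values[i];
    }

    public void demote(int[] features, double[] values, double rate) {
        for (int i = 0; i < features.length; i++)
            weights[features[i]] -= rate * values[i];
    }

    public double weightAt(int f) { return weights[f]; }
}
```

Calling learn with label 1 and then label 0 on the same feature cancels the update, which is exactly the promote/demote symmetry the reviewer points at.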

@@ -100,8 +105,13 @@ object JointTrainSparseNetwork {
trainOnce()
}
}

}
}
if (lossAugmented)
cls.foreach { cls_i =>
cls_i.onClassifier.classifier.unsetLossFlag()
}
}
train(node, cls, it - 1, false)
}
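The diff above wraps each example in a set-flag / update / unset-flag pattern: when lossAugmented is on, every classifier's base learner gets setLossFlag() and a candidate count before the updates, and unsetLossFlag() afterwards. A minimal model of that control flow (method names taken from the diff; the surrounding structure is a hypothetical sketch, not Saul's actual classes):

```java
import java.util.Arrays;
import java.util.List;

public class LossAugmentedSketch {
    // Stand-in for the base learner carrying the loss-augmentation state.
    static class BaseLearner {
        boolean lossFlag = false;
        int candidates = 0;
        void setLossFlag() { lossFlag = true; }
        void unsetLossFlag() { lossFlag = false; }
        void setCandidates(int n) { candidates = n; }
    }

    static void trainExample(List<BaseLearner> cls, int candidatesPerClassifier,
                             boolean lossAugmented) {
        if (lossAugmented)
            for (BaseLearner c : cls) {
                c.setLossFlag();
                // mirrors cls_i.getCandidates(h).size * cls.size in the PR
                c.setCandidates(candidatesPerClassifier * cls.size());
            }

        // ... loss-augmented inference and promote/demote updates here ...

        if (lossAugmented)
            for (BaseLearner c : cls) c.unsetLossFlag();
    }

    public static void main(String[] args) {
        List<BaseLearner> cls =
            Arrays.asList(new BaseLearner(), new BaseLearner());
        trainExample(cls, 3, true);
        // after training the flags are cleared but candidate counts remain
        for (BaseLearner c : cls)
            System.out.println(c.lossFlag + " " + c.candidates);
    }
}
```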
@@ -33,7 +33,7 @@ object InitSparseNetwork {
if (label >= N || iLearner.getNetwork.get(label) == null) {
val isConjunctiveLabels = iLearner.isUsingConjunctiveLabels | iLearner.getLabelLexicon.lookupKey(label).isConjunctive
iLearner.setConjunctiveLabels(isConjunctiveLabels)
val ltu: LinearThresholdUnit = iLearner.getBaseLTU
val ltu: LinearThresholdUnit = iLearner.getBaseLTU.clone().asInstanceOf[LinearThresholdUnit]
Member:

what is the necessity for clone()?

Member Author:

This bug was caught by @bhargav: we need to create a new instance of the linear threshold unit each time a new label is met. This was the main bug in the SparseNetwork initialization.
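The aliasing bug behind this fix can be shown in a few lines (a hypothetical minimal model, not LBJava itself): if the same mutable base unit object is stored for every label, an update for one label silently changes all labels; cloning the prototype gives each label its own weights.

```java
import java.util.ArrayList;
import java.util.List;

class Unit implements Cloneable {
    double weight = 0.0;
    @Override public Unit clone() {
        try { return (Unit) super.clone(); }
        catch (CloneNotSupportedException e) { throw new AssertionError(e); }
    }
}

public class CloneBugDemo {
    public static void main(String[] args) {
        // Buggy: every label shares the same instance.
        Unit base = new Unit();
        List<Unit> shared = new ArrayList<>();
        shared.add(base);
        shared.add(base);
        shared.get(0).weight += 1.0;
        System.out.println(shared.get(1).weight); // prints 1.0: labels interfere

        // Fixed: each label gets its own clone of the prototype.
        Unit fresh = new Unit();
        List<Unit> cloned = new ArrayList<>();
        cloned.add(fresh.clone());
        cloned.add(fresh.clone());
        cloned.get(0).weight += 1.0;
        System.out.println(cloned.get(1).weight); // prints 0.0: independent
    }
}
```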

ltu.initialize(iLearner.getNumExamples, iLearner.getNumFeatures)
iLearner.getNetwork.set(label, ltu)
}
@@ -1,3 +1,9 @@
/** This software is released under the University of Illinois/Research and Academic Use License. See
* the LICENSE file in the root folder for details. Copyright (c) 2016
*
* Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign
* http://cogcomp.cs.illinois.edu/
*/
package edu.illinois.cs.cogcomp.saul.classifier.JoinTrainingTests

import edu.illinois.cs.cogcomp.infer.ilp.OJalgoHook
@@ -71,7 +77,7 @@ class InitializeSparseNetwork extends FlatSpec with Matchers {
val wv1After = clNet1.getNetwork.get(0).asInstanceOf[LinearThresholdUnit].getWeightVector
val wv2After = clNet2.getNetwork.get(0).asInstanceOf[LinearThresholdUnit].getWeightVector

wv1After.size() should be(5)
wv1After.size() should be(6)
wv2After.size() should be(12)
}

@@ -0,0 +1,33 @@
/** This software is released under the University of Illinois/Research and Academic Use License. See
* the LICENSE file in the root folder for details. Copyright (c) 2016
*
* Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign
* http://cogcomp.cs.illinois.edu/
*/
package edu.illinois.cs.cogcomp.saulexamples.Badge;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

public class BadgeReader {
public List<String> badges;
// int currentBadge;
Member:

drop this?


public BadgeReader(String dataFile) {
badges = new ArrayList<String>();

try {
BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(dataFile)));

String str;
while ((str = br.readLine()) != null) {
badges.add(str);
}

br.close();
}catch (Exception e) {}
}
}
Member:

Could you apply the autoformatter on this file?

@@ -0,0 +1,39 @@
/** This software is released under the University of Illinois/Research and Academic Use License. See
* the LICENSE file in the root folder for details. Copyright (c) 2016
*
* Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign
* http://cogcomp.cs.illinois.edu/
*/
package edu.illinois.cs.cogcomp.saulexamples.Badge

import edu.illinois.cs.cogcomp.lbjava.learn.{ SparseNetworkLearner, SparsePerceptron }

/** Created by Parisa on 9/13/16.
*/
Member:

drop the comment?


object BadgeClassifiers {
import BadgeDataModel._
import edu.illinois.cs.cogcomp.saul.classifier.Learnable
Member:

move this import to the top?

object BadgeClassifier extends Learnable[String](badge) {
def label = BadgeLabel
override lazy val classifier = new SparsePerceptron()
override def feature = using(BadgeFeature1)
}
object BadgeOppositClassifier extends Learnable[String](badge) {
def label = BadgeOppositLabel
override lazy val classifier = new SparsePerceptron()
override def feature = using(BadgeFeature1)
}

object BadgeClassifierMulti extends Learnable[String](badge) {
def label = BadgeLabel
override lazy val classifier = new SparseNetworkLearner()
override def feature = using(BadgeFeature1)
}

object BadgeOppositClassifierMulti extends Learnable[String](badge) {
def label = BadgeOppositLabel
override lazy val classifier = new SparseNetworkLearner()
override def feature = using(BadgeFeature1)
}
}
Member:

could you add a little comment to each of these classifiers?

@@ -0,0 +1,43 @@
/** This software is released under the University of Illinois/Research and Academic Use License. See
* the LICENSE file in the root folder for details. Copyright (c) 2016
*
* Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign
* http://cogcomp.cs.illinois.edu/
*/
package edu.illinois.cs.cogcomp.saulexamples.Badge

import edu.illinois.cs.cogcomp.infer.ilp.OJalgoHook
import edu.illinois.cs.cogcomp.saul.classifier.ConstrainedClassifier
import edu.illinois.cs.cogcomp.saul.constraint.ConstraintTypeConversion._
import edu.illinois.cs.cogcomp.saulexamples.Badge.BadgeClassifiers.{ BadgeOppositClassifierMulti, BadgeClassifierMulti, BadgeClassifier, BadgeOppositClassifier }

/** Created by Parisa on 11/1/16.
*/
object BadgeConstraintClassifiers {

val binaryConstraint = ConstrainedClassifier.constraint[String] {
x: String =>
(BadgeClassifier on x is "negative") ==> (BadgeOppositClassifier on x is "positive")
}

object badgeConstrainedClassifier extends ConstrainedClassifier[String, String](BadgeClassifier) {
def subjectTo = binaryConstraint
override val solver = new OJalgoHook
}

object oppositBadgeConstrainedClassifier extends ConstrainedClassifier[String, String](BadgeOppositClassifier) {
def subjectTo = binaryConstraint
override val solver = new OJalgoHook
}

object badgeConstrainedClassifierMulti extends ConstrainedClassifier[String, String](BadgeClassifierMulti) {
def subjectTo = binaryConstraint
override val solver = new OJalgoHook
}

object oppositBadgeConstrainedClassifierMulti extends ConstrainedClassifier[String, String](BadgeOppositClassifierMulti) {
def subjectTo = binaryConstraint
override val solver = new OJalgoHook
}

}
@@ -0,0 +1,46 @@
/** This software is released under the University of Illinois/Research and Academic Use License. See
* the LICENSE file in the root folder for details. Copyright (c) 2016
*
* Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign
* http://cogcomp.cs.illinois.edu/
*/
package edu.illinois.cs.cogcomp.saulexamples.Badge

import edu.illinois.cs.cogcomp.saul.datamodel.DataModel

/** Created by Parisa on 9/13/16.
*/
object BadgeDataModel extends DataModel {

val badge = node[String]

val BadgeFeature1 = property(badge) {
x: String =>
{
val tokens = x.split(" ")
tokens(1).charAt(1).toString
}
Member:

Drop these parentheses?

}

val BadgeLabel = property(badge)("true", "false") {
x: String =>
{
val tokens = x.split(" ")
if (tokens(0).equals("+"))
"true"
else
"false"
}
}

val BadgeOppositLabel = property(badge)("true", "false") {
x: String =>
{
val tokens = x.split(" ")
if (tokens(0).equals("+"))
"false"
else
"true"
}
}
danyaljj (Member), Nov 5, 2016:

What is the purpose of doing this?

Why not reuse BadgeLabel here and say:
if (BadgeLabel(x) == "true") "false" else "true"?

Member Author:

No specific reason; I guess the overhead is the same.

Member:

Yeah but you don't repeat the code; instead reuse it.
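The reuse the reviewer suggests can be sketched outside Saul's property DSL (names here are illustrative, not the Saul API): derive the opposite label from the original labeling function instead of duplicating the token-parsing logic.

```java
import java.util.function.Function;

public class LabelReuseDemo {
    // Same rule as BadgeLabel in the diff: "+" as the first token means "true".
    public static final Function<String, String> badgeLabel =
        x -> x.split(" ")[0].equals("+") ? "true" : "false";

    // Reuses badgeLabel rather than re-parsing the input string.
    public static final Function<String, String> badgeOppositeLabel =
        x -> badgeLabel.apply(x).equals("true") ? "false" : "true";

    public static void main(String[] args) {
        System.out.println(badgeLabel.apply("+ ockham"));          // prints true
        System.out.println(badgeOppositeLabel.apply("+ ockham"));  // prints false
    }
}
```

If the parsing rule ever changes (say, a different token marks the label), only badgeLabel needs updating; the opposite label follows automatically.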

}
@@ -0,0 +1,72 @@
/** This software is released under the University of Illinois/Research and Academic Use License. See
* the LICENSE file in the root folder for details. Copyright (c) 2016
*
* Developed by: The Cognitive Computations Group, University of Illinois at Urbana-Champaign
* http://cogcomp.cs.illinois.edu/
*/
package edu.illinois.cs.cogcomp.saulexamples.Badge

/** Created by Parisa on 9/13/16.
*/

import edu.illinois.cs.cogcomp.saul.classifier.{ JointTrain, JointTrainSparseNetwork }
import edu.illinois.cs.cogcomp.saulexamples.Badge.BadgeClassifiers.{ BadgeOppositClassifier, BadgeClassifier }
import edu.illinois.cs.cogcomp.saulexamples.Badge.BadgeConstraintClassifiers.{ badgeConstrainedClassifier, badgeConstrainedClassifierMulti, oppositBadgeConstrainedClassifier, oppositBadgeConstrainedClassifierMulti }
import edu.illinois.cs.cogcomp.saulexamples.Badge.BadgeDataModel._

import scala.collection.JavaConversions._
object BadgesApp {

val allNamesTrain = new BadgeReader("data/badges/badges.train").badges
val allNamesTest = new BadgeReader("data/badges/badges.test").badges

badge.populate(allNamesTrain)
badge.populate(allNamesTest, false)

val cls = List(badgeConstrainedClassifierMulti, oppositBadgeConstrainedClassifierMulti)

object BadgeExperimentType extends Enumeration {
val JoinTrainSparsePerceptron, JointTrainSparseNetwork, JointTrainSparseNetworkLossAugmented = Value
}

def main(args: Array[String]): Unit = {

/** Choose the experiment you're interested in by changing the following line */
val testType = BadgeExperimentType.JointTrainSparseNetworkLossAugmented

testType match {
case BadgeExperimentType.JoinTrainSparsePerceptron => JoinTrainSparsePerceptron()
case BadgeExperimentType.JointTrainSparseNetwork => JoinTrainSparseNetwork()
case BadgeExperimentType.JointTrainSparseNetworkLossAugmented => LossAugmentedJoinTrainSparseNetwork()
}
}

/*Test the join training with SparsePerceptron*/
def JoinTrainSparsePerceptron(): Unit = {
BadgeClassifier.test()
BadgeOppositClassifier.test()
JointTrain.train(BadgeDataModel.badge, List(badgeConstrainedClassifier, oppositBadgeConstrainedClassifier), 5)
oppositBadgeConstrainedClassifier.test()
badgeConstrainedClassifier.test()
BadgeClassifier.test()
}

/*Test the joinTraining with SparseNetwork*/
def JoinTrainSparseNetwork(): Unit = {

JointTrainSparseNetwork.train(badge, cls, 5, init = true)

badgeConstrainedClassifierMulti.test()
oppositBadgeConstrainedClassifierMulti.test()
}

/*Test the joinTraining with SparseNetwork and doing loss augmented inference*/
def LossAugmentedJoinTrainSparseNetwork(): Unit = {

JointTrainSparseNetwork.train(badge, cls, 5, init = true, lossAugmented = true)

badgeConstrainedClassifierMulti.test()
oppositBadgeConstrainedClassifierMulti.test()
}

}
@@ -92,14 +92,12 @@ class EntityRelationTests extends FlatSpec with Matchers {

ClassifierUtils.TrainClassifiers(1, cls_base)

PerConstrainedClassifier.onClassifier.classifier.asInstanceOf[SparseNetworkLearner].getNetwork.get(0).asInstanceOf[LinearThresholdUnit].getWeightVector.size() should be(1660)
PerConstrainedClassifier.onClassifier.classifier.asInstanceOf[SparseNetworkLearner].getNetwork.get(0).asInstanceOf[LinearThresholdUnit].getWeightVector.size() should be(1654)

val jointTrainIteration = 1
JointTrainSparseNetwork.train[ConllRelation](
pairs, cls, jointTrainIteration, init = true
)
val jointTrainIteration = 2
JointTrainSparseNetwork.train[ConllRelation](pairs, cls, jointTrainIteration, init = true)

PerConstrainedClassifier.onClassifier.classifier.asInstanceOf[SparseNetworkLearner].getNetwork.get(0).asInstanceOf[LinearThresholdUnit].getWeightVector.size() should be(50)
PerConstrainedClassifier.onClassifier.classifier.asInstanceOf[SparseNetworkLearner].getNetwork.get(0).asInstanceOf[LinearThresholdUnit].getWeightVector.size() should be(81)

}
}