Skip to content

Commit

Permalink
Merge pull request #1776 from cmu-phil/vbc-2024-05-23-2
Browse files Browse the repository at this point in the history
Introduce plot data collection for different local graph confusion statistics
  • Loading branch information
jdramsey authored May 23, 2024
2 parents 9bc1256 + 58f01ed commit ca0379b
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 27 deletions.
152 changes: 141 additions & 11 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,17 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind
return accepts_rejects;
}

/**
* Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics.
*
* Confusion statistics were calculated using Adjacency (AdjacencyPrecision, AdjacencyRecall) and Arrowhead (ArrowheadPrecision, ArrowheadRecall)
* @param independenceTest
* @param estimatedCpdag
* @param trueGraph
* @param threshold
* @param shuffleThreshold
* @return
*/
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) {
// When calling, default reject null as <=0.05
List<List<Node>> accepts_rejects = new ArrayList<>();
Expand Down Expand Up @@ -381,35 +392,35 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5
for (List<Double> localPValues: shuffledlocalPValues) {
// P value obtained from AD test
Double ADTest = checkAgainstAndersonDarlingTest(localPValues);
Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTest <= threshold) {
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(ap)) {
rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(ap)) {
accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
}
}
Expand All @@ -421,7 +432,7 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) {
writer.write(entry.getValue());
switch (entry.getKey()) {
case "acceptsAdjP_ADTestP_data.csv":
case "accepts_AdjP_ADTestP_data.csv":
for (List<Double> AdjP_ADTestP_pair : accepts_AdjP_ADTestP) {
writer.write(nf.format(AdjP_ADTestP_pair.get(0)) + "," + nf.format(AdjP_ADTestP_pair.get(1)) + "\n");
}
Expand Down Expand Up @@ -479,6 +490,112 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
return accepts_rejects;
}

/**
* Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics.
*
* Confusion statistics were calculated using Local Graph Precision and Recall (LocalGraphPrecision, LocalGraphRecall).
* @param independenceTest
* @param estimatedCpdag
* @param trueGraph
* @param threshold
* @param shuffleThreshold
* @return
*/
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) {
// When calling, default reject null as <=0.05
List<List<Node>> accepts_rejects = new ArrayList<>();
List<Node> accepts = new ArrayList<>();
List<Node> rejects = new ArrayList<>();
List<Node> allNodes = graph.getNodes();

// Confusion stats lists for data processing.
Map<String, String> fileContentMap = new HashMap<>();

// Using Local Graph Precision and Recall to calculate Confusion statistics.
List<List<Double>> accepts_LGP_ADTestP = new ArrayList<>();
List<List<Double>> accepts_LGR_ADTestP = new ArrayList<>();
fileContentMap.put("accepts_LGP_ADTestP_data.csv", "");
fileContentMap.put("accepts_LGR_ADTestP_data.csv", "");

List<List<Double>> rejects_LGP_ADTestP = new ArrayList<>();
List<List<Double>> rejects_LGR_ADTestP = new ArrayList<>();
fileContentMap.put("rejects_LGP_ADTestP_data.csv", "");
fileContentMap.put("rejects_LGR_ADTestP_data.csv", "");

NumberFormat nf = new DecimalFormat("0.00");
// Classify nodes into accepts and rejects base on ADTest result, and update confusion stats lists accordingly.
for (Node x : allNodes) {
List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
List<Double> lgp_lgr = getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(x, estimatedCpdag, trueGraph);
Double lgp = lgp_lgr.get(0);
Double lgr = lgp_lgr.get(1);
// All local nodes' p-values for node x.
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5
for (List<Double> localPValues: shuffledlocalPValues) {
// P value obtained from AD test
Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(lgp)) {
rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(lgp)) {
accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
}
}
}
accepts_rejects.add(accepts);
accepts_rejects.add(rejects);
// Write into data files.
for (Map.Entry<String, String> entry : fileContentMap.entrySet()) {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) {
writer.write(entry.getValue());
switch (entry.getKey()) {
case "accepts_LGP_ADTestP_data.csv":
for (List<Double> LGP_ADTestP_pair : accepts_LGP_ADTestP) {
writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n");
}
break;

case "accepts_LGR_ADTestP_data.csv":
for (List<Double> LGR_ADTestP_pair : accepts_LGR_ADTestP) {
writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n");
}
break;

case "rejects_LGP_ADTestP_data.csv":
for (List<Double> LGP_ADTestP_pair : rejects_LGP_ADTestP) {
writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n");
}
break;

case "rejects_LGR_ADTestP_data.csv":
for (List<Double> LGR_ADTestP_pair : rejects_LGR_ADTestP) {
writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n");
}
break;

default:
break;
}
System.out.println("Successfully written to " + entry.getKey());
} catch (IOException e) {
e.printStackTrace();
}
}
return accepts_rejects;
}

/**
* Calculates the precision and recall on the Markov Blanket graph for a given node. Prints the statistics to the
* console.
Expand Down Expand Up @@ -547,6 +664,19 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph2(Node x, Graph estimatedGr
" LocalGraphPrecision = " + nf.format(lgp) + " LocalGraphRecall = " + nf.format(lgr) + " \n");
}

public List<Double> getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(Node x, Graph estimatedGraph, Graph trueGraph) {
// Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes.
Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());
Graph xMBLookupGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(lookupGraph, x);
System.out.println("xMBLookupGraph:" + xMBLookupGraph);
Graph xMBEstimatedGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x);
System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph);

double lgp = new LocalGraphPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
double lgr = new LocalGraphRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
return Arrays.asList(lgp, lgr);
}

/**
* Returns the variables of the independence test.
*
Expand Down
20 changes: 4 additions & 16 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() {
}

@Test
public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() {
public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket2() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());
Expand All @@ -461,25 +461,13 @@ public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() {
IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
// ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5);
// List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3);

List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

List<Double> acceptsPrecision = new ArrayList<>();
List<Double> acceptsRecall = new ArrayList<>();
for(Node a: accepts) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph);
System.out.println("=====================");

}
for (Node a: rejects) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph);
System.out.println("=====================");
}
}

}

0 comments on commit ca0379b

Please sign in to comment.