Skip to content

Commit

Permalink
Merge pull request #1799 from cmu-phil/vbc-2024-07-16
Browse files Browse the repository at this point in the history
Include specify file for test check nodewise markov
  • Loading branch information
jdramsey authored Jul 17, 2024
2 parents f046f67 + 829ed76 commit 1c5a45e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 1 deletion.
49 changes: 49 additions & 0 deletions testTrueGraphForCheckNodewiseMarkov.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
Graph Nodes:
X1;X2;X3;X4;X5;X6;X7;X8;X9;X10

Graph Edges:
1. X1 --> X2
2. X1 --> X3
3. X1 --> X5
4. X1 --> X6
5. X1 --> X7
6. X1 --> X8
7. X1 --> X9
8. X1 --> X10
9. X2 --> X3
10. X2 --> X4
11. X2 --> X5
12. X2 --> X8
13. X2 --> X9
14. X2 --> X10
15. X3 --> X4
16. X3 --> X5
17. X3 --> X6
18. X3 --> X7
19. X3 --> X8
20. X3 --> X9
21. X3 --> X10
22. X4 --> X5
23. X4 --> X6
24. X4 --> X9
25. X4 --> X10
26. X5 --> X6
27. X5 --> X7
28. X5 --> X8
29. X5 --> X9
30. X5 --> X10
31. X6 --> X7
32. X6 --> X8
33. X6 --> X9
34. X6 --> X10
35. X7 --> X8
36. X7 --> X9
37. X7 --> X10
38. X8 --> X9
39. X8 --> X10
40. X9 --> X10


Test True Graph size: 10
Test Estimated CPDAG Graph: Graph Nodes:
X1;X2;X3;X4;X5;X6;X7;X8;X9;X10
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,56 @@
import edu.cmu.tetrad.util.Params;
import org.junit.Test;

import java.io.File;
import java.util.List;



public class TestCheckNodewiseMarkov {

public static void main(String... args) {
testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket(10, 40, 40, 0.5, 1.0, 0.8);
// testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket(10, 40, 40, 0.5, 1.0, 0.8);
String filePath = "testTrueGraphForCheckNodewiseMarkov.txt";
File file = new File(filePath);
if (file.exists()) {
System.out.println("Loading true graph file: " + filePath);
testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket(file, 0.5, 1.0, 0.8);
} else {
System.out.println("File does not exist at the specified path.");
}
}

public static void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket(File txtFile, double threshold, double shuffleThreshold, double lowRecallBound) {
Graph trueGraph = GraphSaveLoadUtils.loadGraphTxt(txtFile);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
// Parameters without additional setting default tobe Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(10000, false);
SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
// TODO VBC: Next check different search algo to generate estimated graph. e.g. PC
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");
testGaussianDAGPrecisionRecallForLocalOnMarkovBlanketUsingAdjAHConfusionMatrix(data, trueGraph, estimatedCpdag, threshold, shuffleThreshold, lowRecallBound);
testGaussianDAGPrecisionRecallForLocalOnMarkovBlanketUsingLGConfusionMatrix(data, trueGraph, estimatedCpdag, threshold, shuffleThreshold, lowRecallBound);
System.out.println("~~~~~~~~~~~~~Full Graph~~~~~~~~~~~~~~~");
estimatedCpdag = GraphUtils.replaceNodes(estimatedCpdag, trueGraph.getNodes());
double whole_ap = new AdjacencyPrecision().getValue(trueGraph, estimatedCpdag, null);
double whole_ar = new AdjacencyRecall().getValue(trueGraph, estimatedCpdag, null);
double whole_ahp = new ArrowheadPrecision().getValue(trueGraph, estimatedCpdag, null);
double whole_ahr = new ArrowheadRecall().getValue(trueGraph, estimatedCpdag, null);
double whole_lgp = new LocalGraphPrecision().getValue(trueGraph, estimatedCpdag, null);
double whole_lgr = new LocalGraphRecall().getValue(trueGraph, estimatedCpdag, null);
System.out.println("whole_ap: " + whole_ap);
System.out.println("whole_ar: " + whole_ar );
System.out.println("whole_ahp: " + whole_ahp);
System.out.println("whole_ahr: " + whole_ahr);
System.out.println("whole_lgp: " + whole_lgp);
System.out.println("whole_lgr: " + whole_lgr);
}

public static void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket(int numNodes, int maxNumEdges, int maxDegree, double threshold, double shuffleThreshold, double lowRecallBound) {
Expand Down

0 comments on commit 1c5a45e

Please sign in to comment.