Skip to content

Commit

Permalink
Invariant nodes removal (#1013)
Browse files Browse the repository at this point in the history
Add parameter to define invariant nodes in distribution_change function

---------

Signed-off-by: priyadutt <[email protected]>
  • Loading branch information
bhatt-priyadutt authored Aug 24, 2023
1 parent 734b616 commit b4f80dc
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
32 changes: 20 additions & 12 deletions dowhy/gcm/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,28 @@ def assign_causal_mechanisms(
if not override_models and CAUSAL_MECHANISM in causal_model.graph.nodes[node]:
validate_causal_model_assignment(causal_model.graph, node)
continue
assign_causal_mechanism_node(causal_model, node, based_on, quality)

if is_root_node(causal_model.graph, node):
causal_model.set_causal_mechanism(node, EmpiricalDistribution())
else:
prediction_model = select_model(
based_on[get_ordered_predecessors(causal_model.graph, node)].to_numpy(),
based_on[node].to_numpy(),
quality,
)

if isinstance(prediction_model, ClassificationModel):
causal_model.set_causal_mechanism(node, ClassifierFCM(prediction_model))
else:
causal_model.set_causal_mechanism(node, AdditiveNoiseModel(prediction_model))
def assign_causal_mechanism_node(
    causal_model: ProbabilisticCausalModel,
    node: str,
    based_on: pd.DataFrame,
    quality: AssignmentQuality = AssignmentQuality.GOOD,
) -> None:
    """Select and set the causal mechanism of a single node in the given causal model.

    A root node gets an EmpiricalDistribution. For every other node, a prediction model is chosen with
    select_model, using the samples of the node's ordered predecessors as inputs and the samples of the
    node itself as target. The chosen model is wrapped in a ClassifierFCM if it is a ClassificationModel
    and in an AdditiveNoiseModel otherwise.

    :param causal_model: The causal model that contains the node and on which the mechanism is set.
    :param node: The node to which a causal mechanism is assigned.
    :param based_on: Data that is indexed by node name to obtain the samples used for the model selection.
    :param quality: Assignment quality that is passed on to select_model.
    """
    if is_root_node(causal_model.graph, node):
        causal_model.set_causal_mechanism(node, EmpiricalDistribution())
        return

    parent_samples = based_on[get_ordered_predecessors(causal_model.graph, node)].to_numpy()
    node_samples = based_on[node].to_numpy()
    prediction_model = select_model(parent_samples, node_samples, quality)

    # Classification models need the classifier FCM; everything else is treated as additive noise model.
    mechanism_type = ClassifierFCM if isinstance(prediction_model, ClassificationModel) else AdditiveNoiseModel
    causal_model.set_causal_mechanism(node, mechanism_type(prediction_model))


def select_model(
Expand Down
45 changes: 43 additions & 2 deletions dowhy/gcm/distribution_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
Expand All @@ -13,7 +13,7 @@
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm

from dowhy.gcm.auto import AssignmentQuality, assign_causal_mechanisms
from dowhy.gcm.auto import AssignmentQuality, assign_causal_mechanism_node, assign_causal_mechanisms
from dowhy.gcm.causal_mechanisms import ConditionalStochasticModel
from dowhy.gcm.causal_models import (
PARENTS_DURING_FIT,
Expand Down Expand Up @@ -93,6 +93,7 @@ def distribution_change(
old_data: pd.DataFrame,
new_data: pd.DataFrame,
target_node: Any,
invariant_nodes: List[Any] = None,
num_samples: int = 2000,
difference_estimation_func: Callable[[np.ndarray, np.ndarray], float] = auto_estimate_kl_divergence,
independence_test: Callable[[np.ndarray, np.ndarray], float] = kernel_based,
Expand All @@ -119,6 +120,8 @@ def distribution_change(
:param old_data: Joint samples from the 'old' distribution.
:param new_data: Joint samples from the 'new' distribution.
:param target_node: Target node of interest for attributing the marginal distribution change.
:param invariant_nodes: List of nodes where the mechanism is kept constant regardless of changes in the
datasets being analyzed.
:param num_samples: Number of samples used for estimating Shapley values. This can have a significant influence
on runtime and accuracy.
:param difference_estimation_func: Function for quantifying the distribution change. This function should expect
Expand Down Expand Up @@ -149,13 +152,17 @@ def distribution_change(
returned: a dictionary indicating whether each node's mechanism changed, the causal DAG whose causal models
are learned from old data, and the causal DAG whose causal models are learned from new data.
"""
if invariant_nodes is None:
invariant_nodes = []
causal_graph_old = graph_factory(node_connected_subgraph_view(causal_model.graph, target_node))
causal_model_old = ProbabilisticCausalModel(causal_graph_old)

if auto_assignment_quality is None:
clone_causal_models(causal_model.graph, causal_model_old.graph)
else:
assign_causal_mechanisms(causal_model_old, old_data, override_models=True, quality=auto_assignment_quality)
invariant_nodes = list(set(invariant_nodes).intersection(set(causal_graph_old.nodes)))
_remove_invariant_nodes(invariant_nodes, causal_model_old, old_data, auto_assignment_quality)

causal_graph_new = graph_factory(causal_graph_old)
causal_model_new = ProbabilisticCausalModel(causal_graph_new)
Expand All @@ -173,6 +180,7 @@ def distribution_change(
conditional_independence_test,
mechanism_change_test_significance_level,
mechanism_change_test_fdr_control_method,
invariant_nodes,
)

attributions = distribution_change_of_graphs(
Expand All @@ -184,6 +192,9 @@ def distribution_change(
shapley_config,
graph_factory,
)
# set attributions to zero for left out invariant nodes
for node in invariant_nodes:
attributions[node] = 0
if return_additional_info:
return attributions, mechanism_changes, causal_model_old, causal_model_new
else:
Expand Down Expand Up @@ -238,6 +249,33 @@ def distribution_change_of_graphs(
)


def _remove_invariant_nodes(
    invariant_nodes: List[Any],
    causal_model: ProbabilisticCausalModel,
    old_data: pd.DataFrame,
    auto_assignment_quality: Optional[AssignmentQuality],
) -> None:
    """Bypass invariant nodes in the graph of the given causal model, modifying the graph in place.

    An invariant node with exactly one child is removed from the graph, each of its parents is connected
    directly to that child, and a new causal mechanism is assigned to the child based on old_data, since its
    set of parents has changed.

    Invariant nodes that do not have exactly one child are left untouched:

    - A node with more than one child is kept, because removing it can introduce hidden confounders between
      its children.
    - A node without children is kept, because there is nothing to reconnect and removing it would only
      delete the node (e.g. a target node that was listed as invariant) together with its incoming edges.

    :param invariant_nodes: Nodes whose mechanism is considered constant. All of them must be nodes of the
                            graph of the causal model.
    :param causal_model: The causal model whose graph and causal mechanisms are updated in place.
    :param old_data: Data used for selecting the new causal mechanism of a child node.
    :param auto_assignment_quality: Quality used for the mechanism assignment. If None, AssignmentQuality.GOOD
                                    is used.
    """
    if auto_assignment_quality is None:
        auto_assignment_quality = AssignmentQuality.GOOD

    for invar_node in invariant_nodes:
        # Parents and children must be collected before the node (and with it all adjacent edges) is removed.
        parents = get_ordered_predecessors(causal_model.graph, invar_node)
        children = list(causal_model.graph.successors(invar_node))

        # Only a node with exactly one child can be bypassed safely; see the docstring for the other cases.
        if len(children) != 1:
            continue
        (child,) = children

        # Remove the middle node and connect its parents directly to its child.
        causal_model.graph.remove_node(invar_node)
        for parent in parents:
            causal_model.graph.add_edge(parent, child)

        # The parents of the child changed, so its causal mechanism needs to be re-assigned.
        assign_causal_mechanism_node(causal_model, child, old_data, quality=auto_assignment_quality)


def _fit_accounting_for_mechanism_change(
causal_model_old: ProbabilisticCausalModel,
causal_model_new: ProbabilisticCausalModel,
Expand All @@ -247,6 +285,7 @@ def _fit_accounting_for_mechanism_change(
conditional_independence_test: Callable[[np.ndarray, np.ndarray, np.ndarray], float],
significance_level: float,
fdr_control_method: Optional[str],
invariant_nodes: List[Any],
) -> Dict[Any, bool]:
mechanism_changed_for_node = _check_significant_mechanism_change(
causal_model_old.graph,
Expand All @@ -261,6 +300,8 @@ def _fit_accounting_for_mechanism_change(
joint_data = pd.concat([old_data, new_data], ignore_index=True, sort=True)

for node in causal_model_new.graph.nodes:
if node in invariant_nodes:
mechanism_changed_for_node[node] = False
if mechanism_changed_for_node[node]:
fit_causal_model_of_target(causal_model_old, node, old_data)
fit_causal_model_of_target(causal_model_new, node, new_data)
Expand Down
20 changes: 20 additions & 0 deletions tests/gcm/test_distribution_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,26 @@ def test_given_two_graphs_fitted_on_data_sets_with_different_mechanisms_when_eva
assert results["X0"] == approx(0, abs=0.1)


@flaky(max_runs=5)
def test_given_list_of_invariant_nodes_to_remove_return_expected_results():
    """Runs distribution_change for target X3 with X0 and X1 declared as invariant nodes.

    Expects that X3 receives a larger attribution than X2 and that the attribution of the invariant
    node X0 is (approximately) zero.
    """
    old_samples, new_samples = _generate_data()

    graph = nx.DiGraph([("X0", "X1"), ("X0", "X2"), ("X2", "X3")])
    causal_model = ProbabilisticCausalModel(graph)
    _assign_causal_mechanisms(causal_model)

    attributions = distribution_change(
        causal_model,
        old_samples,
        new_samples,
        "X3",
        shapley_config=ShapleyConfig(n_jobs=1),
        invariant_nodes=["X0", "X1"],
    )

    assert attributions["X3"] > attributions["X2"]
    assert attributions["X0"] == approx(0)


@flaky(max_runs=5)
def test_when_using_distribution_change_without_fdrc_then_returns_valid_results():
original_observations, outlier_observations = _generate_data()
Expand Down

0 comments on commit b4f80dc

Please sign in to comment.