Skip to content

Commit

Permalink
Fix an import issue with pygraphviz
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Bloebaum <[email protected]>
  • Loading branch information
bloebp committed Jul 21, 2023
1 parent 88d79c3 commit 04abdfa
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions dowhy/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import networkx as nx
import numpy as np
import pandas as pd
import pygraphviz
from matplotlib import image

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -172,7 +171,7 @@ def _calc_arrow_width(strength: float, max_strength: float):
return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)


def _plot_as_pyplot_figure(pygraphviz_graph: pygraphviz.AGraph, figure_size: Optional[Tuple[int, int]] = None) -> None:
def _plot_as_pyplot_figure(pygraphviz_graph: Any, figure_size: Optional[Tuple[int, int]] = None) -> None:
with tempfile.TemporaryDirectory() as tmp_dir_name:
pygraphviz_graph.draw(tmp_dir_name + os.sep + "Graph.png")
img = image.imread(tmp_dir_name + os.sep + "Graph.png")
Expand Down

0 comments on commit 04abdfa

Please sign in to comment.