diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md
index 12b93f4e..4d7e9f88 100644
--- a/examples/hello_world/README.md
+++ b/examples/hello_world/README.md
@@ -9,9 +9,9 @@ File organization:
* `my_functions.py` houses the logic that we want to compute. Note how the functions are named, and what input
parameters they require. That is how we create a DAG modeling the dataflow we want to happen.
* `my_script.py` houses how to get Hamilton to create the DAG and exercise it with some inputs.
-* `my_notebook_script.py` houses how one might iterate in a notebook environment and provide a way to inline define Hamilton
+* `my_notebook.ipynb` houses how one might iterate in a notebook environment and provide a way to inline define Hamilton
functions and add them to the DAG constructed. To be clear, it is not used by `my_script.py`, but showing an alternate path
-to running things.
+to running/developing things.
To run things:
```bash
diff --git a/examples/hello_world/my_notebook.ipynb b/examples/hello_world/my_notebook.ipynb
new file mode 100644
index 00000000..45ad6318
--- /dev/null
+++ b/examples/hello_world/my_notebook.ipynb
@@ -0,0 +1,557 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Cell 1 - import the things you need\n",
+ "import logging\n",
+ "import sys\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "from hamilton import ad_hoc_utils, driver\n",
+ "\n",
+ "logging.basicConfig(stream=sys.stdout)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Cell 2 - import modules to create part of the DAG from\n",
+ "# We use the autoreload extension that comes with ipython to automatically reload modules when\n",
+ "# the code in them changes.\n",
+ "\n",
+ "# import the jupyter extension\n",
+ "%load_ext autoreload\n",
+ "# set it to only reload the modules imported\n",
+ "%autoreload 1\n",
+ "# import the function modules you want to reload when they change.\n",
+    "# i.e. these should be the modules you write your functions in. As you change them,\n",
+ "# they will be reimported without you having to do anything.\n",
+ "%aimport my_functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Cell 3 - Define your new Hamilton functions & curate them into a TemporaryFunctionModule object.\n",
+ "# This enables you to add functions to your DAG without creating a proper module.\n",
+    "# This is ONLY INTENDED FOR QUICK DEVELOPMENT. When moving to production, move these to an actual module.\n",
+ "\n",
+ "# Look at `my_functions` to see how these functions connect.\n",
+ "def signups() -> pd.Series:\n",
+ " \"\"\"Returns sign up values\"\"\"\n",
+ " return pd.Series([1, 10, 50, 100, 200, 400])\n",
+ "\n",
+ "\n",
+ "def spend() -> pd.Series:\n",
+ " \"\"\"Returns the spend values\"\"\"\n",
+ " return pd.Series([10, 10, 20, 40, 40, 50])\n",
+ "\n",
+ "\n",
+ "def log_spend_per_signup(spend_per_signup: pd.Series) -> pd.Series:\n",
+ " \"\"\"Simple function taking the logarithm of spend over signups.\"\"\"\n",
+ " return np.log(spend_per_signup)\n",
+ "\n",
+ "\n",
+ "# Place the functions into a temporary module -- the idea is that this should house a curated set of functions.\n",
+ "# Don't be afraid to make multiple of them -- however we'd advise you to not use this method for production.\n",
+ "# Also note, that using a temporary function module does not work for scaling onto Ray, Dask, or Pandas on Spark.\n",
+ "temp_module = ad_hoc_utils.create_temporary_module(\n",
+ " spend, signups, log_spend_per_signup, module_name=\"function_example\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Cell 4 - Instantiate the Hamilton driver and pass it the right things in.\n",
+ "\n",
+ "initial_config = {}\n",
+ "# we need to tell hamilton where to load function definitions from\n",
+ "dr = driver.Driver(initial_config, my_functions, temp_module) # can pass in multiple modules\n",
+ "# we need to specify what we want in the final dataframe.\n",
+ "output_columns = [\n",
+ " \"spend\",\n",
+ " \"signups\",\n",
+ " \"avg_3wk_spend\",\n",
+ " \"spend_per_signup\",\n",
+ " \"spend_zero_mean_unit_variance\",\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Cell 5 - visualize execution\n",
+ "# To visualize do `pip install sf-hamilton[visualization]` if you want these to work\n",
+ "\n",
+ "# visualize all possible functions\n",
+ "dr.display_all_functions(None) # we pass None to not save the image to file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# visualize just the execution path\n",
+ "dr.visualize_execution(output_columns, None, {}) # we pass None to not save the image to file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
spend
\n",
+ "
signups
\n",
+ "
avg_3wk_spend
\n",
+ "
spend_per_signup
\n",
+ "
spend_zero_mean_unit_variance
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "
\n",
+ "
0
\n",
+ "
10
\n",
+ "
1
\n",
+ "
NaN
\n",
+ "
10.000
\n",
+ "
-1.064405
\n",
+ "
\n",
+ "
\n",
+ "
1
\n",
+ "
10
\n",
+ "
10
\n",
+ "
NaN
\n",
+ "
1.000
\n",
+ "
-1.064405
\n",
+ "
\n",
+ "
\n",
+ "
2
\n",
+ "
20
\n",
+ "
50
\n",
+ "
13.333333
\n",
+ "
0.400
\n",
+ "
-0.483821
\n",
+ "
\n",
+ "
\n",
+ "
3
\n",
+ "
40
\n",
+ "
100
\n",
+ "
23.333333
\n",
+ "
0.400
\n",
+ "
0.677349
\n",
+ "
\n",
+ "
\n",
+ "
4
\n",
+ "
40
\n",
+ "
200
\n",
+ "
33.333333
\n",
+ "
0.200
\n",
+ "
0.677349
\n",
+ "
\n",
+ "
\n",
+ "
5
\n",
+ "
50
\n",
+ "
400
\n",
+ "
43.333333
\n",
+ "
0.125
\n",
+ "
1.257934
\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " spend signups avg_3wk_spend spend_per_signup \\\n",
+ "0 10 1 NaN 10.000 \n",
+ "1 10 10 NaN 1.000 \n",
+ "2 20 50 13.333333 0.400 \n",
+ "3 40 100 23.333333 0.400 \n",
+ "4 40 200 33.333333 0.200 \n",
+ "5 50 400 43.333333 0.125 \n",
+ "\n",
+ " spend_zero_mean_unit_variance \n",
+ "0 -1.064405 \n",
+ "1 -1.064405 \n",
+ "2 -0.483821 \n",
+ "3 0.677349 \n",
+ "4 0.677349 \n",
+ "5 1.257934 "
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# let's create the dataframe!\n",
+ "dr.execute(output_columns)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/examples/hello_world/my_notebook_script.py b/examples/hello_world/my_notebook_script.py
deleted file mode 100644
index e6fa525f..00000000
--- a/examples/hello_world/my_notebook_script.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Cell 1 - import the things you need
-import logging
-import sys
-
-import numpy as np
-import pandas as pd
-
-from hamilton import ad_hoc_utils, driver
-
-logging.basicConfig(stream=sys.stdout)
-
-# Cell 2 - import modules to create part of the DAG from
-import my_functions
-
-
-# Cell 3 - Define your new Hamilton functions & curate them into a TemporaryFunctionModule object.
-# Look at `my_functions` to see how these functions connect.
-def signups() -> pd.Series:
- """Returns sign up values"""
- return pd.Series([1, 10, 50, 100, 200, 400])
-
-
-def spend() -> pd.Series:
- """Returns the spend values"""
- return pd.Series([10, 10, 20, 40, 40, 50])
-
-
-def log_spend_per_signup(spend_per_signup: pd.Series) -> pd.Series:
- """Simple function taking the logarithm of spend over signups."""
- return np.log(spend_per_signup)
-
-
-# Place the functions into a temporary module -- the idea is that this should house a curated set of functions.
-# Don't be afraid to make multiple of them -- however we'd advise you to not use this method for production.
-# Also note, that using a temporary function module does not work for scaling onto Ray, Dask, or Pandas on Spark.
-temp_module = ad_hoc_utils.create_temporary_module(
- spend, signups, log_spend_per_signup, module_name="function_example"
-)
-
-# Cell 4 - Instantiate the Hamilton driver and pass it the right things in.
-initial_config = {}
-# we need to tell hamilton where to load function definitions from
-dr = driver.Driver(initial_config, my_functions, temp_module) # can pass in multiple modules
-# we need to specify what we want in the final dataframe.
-output_columns = [
- "spend",
- "signups",
- "avg_3wk_spend",
- "spend_per_signup",
- "spend_zero_mean_unit_variance",
- "log_spend_per_signup",
-]
-# let's create the dataframe!
-df = dr.execute(output_columns)
-print(df.to_string())
-
-# To visualize do `pip install sf-hamilton[visualization]` if you want these to work
-# dr.visualize_execution(output_columns, './my_dag.dot', {})
-# dr.display_all_functions('./my_full_dag.dot')
diff --git a/hamilton/driver.py b/hamilton/driver.py
index 0c30c756..b012fd27 100644
--- a/hamilton/driver.py
+++ b/hamilton/driver.py
@@ -359,7 +359,7 @@ def list_available_variables(self) -> List[Variable]:
@capture_function_usage
def display_all_functions(
self, output_file_path: str, render_kwargs: dict = None, graphviz_kwargs: dict = None
- ):
+ ) -> Optional["graphviz.Digraph"]: # noqa F821
"""Displays the graph of all functions loaded!
:param output_file_path: the full URI of path + file name to save the dot file to.
@@ -370,9 +370,11 @@ def display_all_functions(
:param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
See https://graphviz.org/doc/info/attrs.html for options.
+ :return: the graphviz object if you want to do more with it.
+ If returned as the result in a Jupyter Notebook cell, it will render.
"""
try:
- self.graph.display_all(output_file_path, render_kwargs, graphviz_kwargs)
+ return self.graph.display_all(output_file_path, render_kwargs, graphviz_kwargs)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
@@ -384,7 +386,7 @@ def visualize_execution(
render_kwargs: dict,
inputs: Dict[str, Any] = None,
graphviz_kwargs: dict = None,
- ):
+ ) -> Optional["graphviz.Digraph"]: # noqa F821
"""Visualizes Execution.
Note: overrides are not handled at this time.
@@ -399,12 +401,14 @@ def visualize_execution(
:param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
See https://graphviz.org/doc/info/attrs.html for options.
+ :return: the graphviz object if you want to do more with it.
+ If returned as the result in a Jupyter Notebook cell, it will render.
"""
_final_vars = self._create_final_vars(final_vars)
nodes, user_nodes = self.graph.get_upstream_nodes(_final_vars, inputs)
self.validate_inputs(user_nodes, inputs, nodes)
try:
- self.graph.display(
+ return self.graph.display(
nodes,
user_nodes,
output_file_path,
@@ -440,7 +444,7 @@ def what_is_downstream_of(self, *node_names: str) -> List[Variable]:
@capture_function_usage
def display_downstream_of(
self, *node_names: str, output_file_path: str, render_kwargs: dict, graphviz_kwargs: dict
- ):
+ ) -> Optional["graphviz.Digraph"]: # noqa F821
"""Creates a visualization of the DAG starting from the passed in function name(s).
Note: for any "node" visualized, we will also add its parents to the visualization as well, so
@@ -448,15 +452,17 @@ def display_downstream_of(
:param node_names: names of function(s) that are starting points for traversing the graph.
:param output_file_path: the full URI of path + file name to save the dot file to.
- E.g. 'some/path/graph.dot'
+ E.g. 'some/path/graph.dot'. Pass in None to skip saving any file.
:param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
If you do not want to view the file, pass in `{'view':False}`.
:param graphviz_kwargs: Kwargs to be passed to the graphviz graph object to configure it.
E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
+ :return: the graphviz object if you want to do more with it.
+ If returned as the result in a Jupyter Notebook cell, it will render.
"""
downstream_nodes = self.graph.get_impacted_nodes(list(node_names))
try:
- self.graph.display(
+ return self.graph.display(
downstream_nodes,
set(),
output_file_path,
diff --git a/hamilton/graph.py b/hamilton/graph.py
index b8883502..f9d5780a 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -7,7 +7,7 @@
"""
import logging
from types import ModuleType
-from typing import Any, Callable, Collection, Dict, List, Set, Tuple, Type
+from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type
from hamilton import base, node
from hamilton.function_modifiers import base as fm_base
@@ -199,7 +199,7 @@ def display_all(
render_kwargs = {}
if graphviz_kwargs is None:
graphviz_kwargs = {}
- self.display(
+ return self.display(
defined_nodes,
user_nodes,
output_file_path=output_file_path,
@@ -243,12 +243,12 @@ def display(
output_file_path: str = "test-output/graph.gv",
render_kwargs: dict = None,
graphviz_kwargs: dict = None,
- ):
+ ) -> Optional["graphviz.Digraph"]: # noqa F821
"""Function to display the graph represented by the passed in nodes.
:param nodes: the set of nodes that need to be computed.
:param user_nodes: the set of inputs that the user provided.
- :param output_file_path: the path where we want to store the a `dot` file + pdf picture.
+ :param output_file_path: the path where we want to store the `dot` file + pdf picture. Pass in None to not save.
:param render_kwargs: kwargs to be passed to the render function to visualize.
:param graphviz_kwargs: kwargs to be passed to the graphviz graph object to configure it.
e.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
@@ -268,7 +268,9 @@ def display(
kwargs = {"view": True}
if render_kwargs and isinstance(render_kwargs, dict):
kwargs.update(render_kwargs)
- dot.render(output_file_path, **kwargs)
+ if output_file_path:
+ dot.render(output_file_path, **kwargs)
+ return dot
def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]:
"""Given our function graph, and a list of nodes that are changed,
diff --git a/tests/test_graph.py b/tests/test_graph.py
index b55518c3..1d66c616 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -530,6 +530,23 @@ def test_function_graph_display():
assert actual == expected
+def test_function_graph_display_without_saving():
+ """Tests that display works when None is passed in for path"""
+ fg = graph.FunctionGraph(tests.resources.dummy_functions, config={"b": 1, "c": 2})
+ defined_nodes = set()
+ user_nodes = set()
+ for n in fg.get_nodes():
+ if n.user_defined:
+ user_nodes.add(n)
+ else:
+ defined_nodes.add(n)
+ digraph = fg.display(defined_nodes, user_nodes, None)
+ assert digraph is not None
+ import graphviz
+
+ assert isinstance(digraph, graphviz.Digraph)
+
+
def test_create_graphviz_graph():
"""Tests that we create a graphviz graph"""
fg = graph.FunctionGraph(tests.resources.dummy_functions, config={})