Support both eager and graph mode
drasmuss committed Jul 22, 2020
1 parent 4698cb6 commit 0790159
Showing 9 changed files with 141 additions and 41 deletions.
4 changes: 2 additions & 2 deletions .nengobones.yml
@@ -216,9 +216,9 @@ travis_yml:
env:
TF_VERSION: tensorflow-gpu
GPU_NUM: 0
- test_args: --unroll-simulation 5 --simulator-only
- test_args: --dtype float64 --simulator-only
- test_args: --unroll-simulation 5 --dtype float64 --simulator-only
- test_args: --inference-only --simulator-only
- test_args: --graph-mode --simulator-only
pypi_user: drasmuss
deploy_dists:
- sdist
6 changes: 3 additions & 3 deletions .travis.yml
@@ -69,13 +69,13 @@ jobs:
SCRIPT="remote-examples"
-
env:
TEST_ARGS="--unroll-simulation 5 --simulator-only"
TEST_ARGS="--unroll-simulation 5 --dtype float64 --simulator-only"
-
env:
TEST_ARGS="--dtype float64 --simulator-only"
TEST_ARGS="--inference-only --simulator-only"
-
env:
TEST_ARGS="--inference-only --simulator-only"
TEST_ARGS="--graph-mode --simulator-only"
- stage: deploy
if: branch =~ ^release-candidate-* OR tag =~ ^v[0-9]*
env: SCRIPT="deploy"
10 changes: 10 additions & 0 deletions conftest.py
@@ -32,6 +32,10 @@ def pytest_runtest_setup(item):
):
pytest.skip("Skipping performance test")

if item.config.getvalue("--graph-mode"):
tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_control_flow_v2()

tf.keras.backend.clear_session()


@@ -67,6 +71,12 @@ def pytest_addoption(parser):
default=False,
help="Run performance tests",
)
parser.addoption(
"--graph-mode",
action="store_true",
default=False,
help="Run tests in graph (not eager) mode",
)


@pytest.fixture(scope="session")
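With the new option in place, the test suite can be run in graph mode by passing the flag on the command line, for example (combining it with the existing --simulator-only flag, as the CI configs above do):

pytest --graph-mode --simulator-only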
11 changes: 11 additions & 0 deletions docs/tips.rst
@@ -52,6 +52,17 @@ search depth by setting

See :ref:`the documentation <config-planner>` for more details.

TensorFlow reworked a lot of its internal implementation details in TensorFlow 2.0,
but some models run faster with the pre-2.0 implementation. That behaviour can be
restored by calling

.. testcode::

tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_control_flow_v2()

at the top of your script.

Training a spiking deep network
-------------------------------

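A minimal sketch (not part of the commit) of verifying that the tip above took effect: after the two disable calls, TensorFlow reports that eager execution is off for the rest of the process.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_control_flow_v2()

# TensorFlow is now running in graph mode for the rest of the process
assert not tf.executing_eagerly()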
27 changes: 21 additions & 6 deletions nengo_dl/neuron_builders.py
@@ -3,6 +3,7 @@
"""

from collections import OrderedDict
import contextlib
import logging
import warnings

@@ -50,6 +51,8 @@ def build_pre(self, signals, config):
for key in state_keys
]

self.prev_result = []

def neuron_step(dt, J, *states):
output = None
J_offset = 0
@@ -103,15 +106,27 @@ def build_step(self, signals):
states = [signals.gather(x) for x in self.state_data]
states_dtype = [x.dtype for x in self.state_data]

ret = tf.numpy_function(
self.neuron_step,
[signals.dt, J] + states,
[self.output_data.dtype] + states_dtype,
name=self.neuron_step.__name__,
)
if tf.executing_eagerly():
# noop
control_deps = contextlib.suppress()
else:
# we need to make sure that the previous call to this function
# has completed before the next starts, since we don't know that the
# functions are thread safe
control_deps = tf.control_dependencies(self.prev_result)

with control_deps:
ret = tf.numpy_function(
self.neuron_step,
[signals.dt, J] + states,
[self.output_data.dtype] + states_dtype,
name=self.neuron_step.__name__,
)

neuron_out, state_out = ret[0], ret[1:]

self.prev_result = [neuron_out]

neuron_out.set_shape((signals.minibatch_size,) + self.output_data.shape)
signals.scatter(self.output_data, neuron_out)

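A standalone sketch of the pattern introduced in build_step above (assumed names such as step_fn; not NengoDL code): in graph mode each tf.numpy_function call is placed under a control dependency on the previous call's output, so successive calls run one at a time, while in eager mode a no-op contextlib.suppress() context is used, since eager ops execute sequentially anyway.

import contextlib

import tensorflow as tf


def step_fn(x):
    # stand-in for a numpy computation that is not known to be thread safe
    return x + 1.0


prev_result = []  # output of the previous numpy_function call, if any


def build_step(x):
    global prev_result

    if tf.executing_eagerly():
        deps = contextlib.suppress()  # no-op context manager
    else:
        # serialize successive calls by depending on the previous output
        deps = tf.control_dependencies(prev_result)

    with deps:
        out = tf.numpy_function(step_fn, [x], tf.float64, name="step_fn")

    prev_result = [out]
    return out


print(build_step(tf.constant([1.0, 2.0], dtype=tf.float64)))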
27 changes: 21 additions & 6 deletions nengo_dl/process_builders.py
@@ -3,6 +3,7 @@
"""

from collections import OrderedDict
import contextlib
import logging

from nengo.builder.processes import SimProcess
@@ -48,6 +49,8 @@ def build_pre(self, signals, config):
]
self.mode = "inc" if self.ops[0].mode == "inc" else "update"

self.prev_result = []

# `merged_func` calls the step function for each process and
# combines the result
def merged_func(time, *input_state):
@@ -92,12 +95,22 @@ def build_step(self, signals):
input = [] if self.input_data is None else [signals.gather(self.input_data)]
state = [signals.gather(s) for s in self.state_data]

result = tf.numpy_function(
self.merged_func,
time + input + state,
[self.output_data.dtype] + [s.dtype for s in self.state_data],
name=self.merged_func.__name__,
)
if tf.executing_eagerly():
# noop
control_deps = contextlib.suppress()
else:
# we need to make sure that the previous call to this function
# has completed before the next starts, since we don't know that the
# functions are thread safe
control_deps = tf.control_dependencies(self.prev_result)

with control_deps:
result = tf.numpy_function(
self.merged_func,
time + input + state,
[self.output_data.dtype] + [s.dtype for s in self.state_data],
name=self.merged_func.__name__,
)

# TensorFlow will automatically squeeze length-1 outputs (if there is
# no state), which we don't want
@@ -106,6 +119,8 @@ def build_step(self, signals):
output = result[0]
state = result[1:]

self.prev_result = [output]

output.set_shape(self.output_data.full_shape)
signals.scatter(self.output_data, output, mode=self.mode)
for i, s in enumerate(state):
42 changes: 31 additions & 11 deletions nengo_dl/simulator.py
@@ -5,6 +5,7 @@
"""

import collections
import contextlib
import copy
from functools import partial
import logging
@@ -602,17 +603,29 @@ def reset(
"""

if self.stateful:
for key, var in self.tensor_graph.saved_state.items():
var.assign(
self.tensor_graph.initial_values[key](var.shape, dtype=var.dtype)
if tf.executing_eagerly():
for key, var in self.tensor_graph.saved_state.items():
var.assign(
self.tensor_graph.initial_values[key](
var.shape, dtype=var.dtype
)
)
else:
tf.keras.backend.batch_get_value(
[var.initializer for var in self.tensor_graph.saved_state.values()]
)

if include_trainable:
for key, var in self.tensor_graph.base_params.items():
var.assign(
self.tensor_graph.initial_values[key](var.shape, dtype=var.dtype)
if tf.executing_eagerly():
for key, var in self.tensor_graph.base_params.items():
var.assign(
self.tensor_graph.initial_values[key](
var.shape, dtype=var.dtype
)
)
else:
tf.keras.backend.batch_get_value(
[var.initializer for var in self.tensor_graph.base_params.values()]
)

if include_probes:
for p in self.model.probes:
self.model.params[p] = []
@@ -1605,11 +1618,18 @@ def arg_func(*args, output=None):
include_probes=False, include_trainable=False, include_processes=False
)

if tf.executing_eagerly():
# noop
ctx = contextlib.suppress()
else:
ctx = tf.compat.v1.keras.backend.get_session().as_default()

grads = dict()
for output in outputs:
analytic, numeric = tf.test.compute_gradient(
partial(arg_func, output=output), inputs
)
with ctx:
analytic, numeric = tf.test.compute_gradient(
partial(arg_func, output=output), inputs
)
grads[output] = dict()
grads[output]["analytic"] = analytic
grads[output]["numeric"] = numeric
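A condensed sketch of the reset logic above (assumed helper name reset_variables; not the Simulator API): in eager mode the variables are assigned their initial values directly, while in graph mode their initializer ops are run through the backing session via tf.keras.backend.batch_get_value. The check_gradients change follows the same split, entering tf.compat.v1.keras.backend.get_session().as_default() only when eager execution is disabled.

import tensorflow as tf


def reset_variables(variables, init_fn):
    """Reset `variables` to their initial values in either execution mode."""
    if tf.executing_eagerly():
        # assign the initial values directly
        for var in variables:
            var.assign(init_fn(var.shape, dtype=var.dtype))
    else:
        # in graph mode, evaluating the initializer ops (re)runs them in the session
        tf.keras.backend.batch_get_value([var.initializer for var in variables])


# e.g. reset a variable back to zeros (in eager mode)
reset_variables([tf.Variable(tf.zeros((2, 2)))], tf.zeros)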
21 changes: 15 additions & 6 deletions nengo_dl/tensor_graph.py
@@ -386,6 +386,13 @@ def unbuild(layer):

tf.keras.backend.batch_set_value(zip(weight_sets, weight_vals))

if not tf.executing_eagerly():
# initialize state variables (need to do this manually because we're not
# adding them to self.weights)
tf.keras.backend.batch_get_value(
[var.initializer for var in self.saved_state.values()]
)

@tf.autograph.experimental.do_not_convert
def call(self, inputs, training=None, progress=None, stateful=False):
"""
@@ -495,12 +502,11 @@ def call(self, inputs, training=None, progress=None, stateful=False):
# number of steps, even if there are no output probes
outputs = list(probe_arrays.values()) + [steps_run]

n_state = 0
updates = []
if stateful:
# update saved state
for var, val in zip(self.saved_state.values(), final_internal_state):
var.assign(val)
n_state += 1
updates.append(var.assign(val))

# if any of the base params have changed (due to online learning rules) then we
# also need to assign those back to the original variable (so that their
@@ -513,10 +519,13 @@ def call(self, inputs, training=None, progress=None, stateful=False):
minibatched = self.base_arrays_init["trainable"][key][-1]

if minibatched:
var.assign(val)
n_state += 1
updates.append(var.assign(val))

logger.info("Number of state updates: %d", len(updates))

logger.info("Number of state updates: %d", n_state)
if not tf.executing_eagerly() and len(updates) > 0:
with tf.control_dependencies(updates):
outputs = [tf.identity(x) for x in outputs]

return outputs

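A minimal sketch of the update handling added to call above (standalone, assumed name finalize_outputs): in eager mode var.assign executes immediately, but in graph mode the assignments only run if something that is fetched depends on them, so the outputs are routed through tf.identity under a control dependency on the collected updates.

import tensorflow as tf


def finalize_outputs(outputs, updates):
    """Ensure `updates` (e.g. results of `var.assign(val)`) run whenever `outputs` are fetched."""
    if not tf.executing_eagerly() and len(updates) > 0:
        with tf.control_dependencies(updates):
            outputs = [tf.identity(x) for x in outputs]
    return outputs


# in eager mode this is a no-op passthrough; the assign has already executed
v = tf.Variable(0.0)
outs = finalize_outputs([tf.constant(1.0)], [v.assign(2.0)])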
34 changes: 27 additions & 7 deletions nengo_dl/tests/test_benchmarks.py
@@ -195,12 +195,26 @@ def test_lmu(Simulator, native_nengo, pytestconfig):

@pytest.mark.performance
@pytest.mark.parametrize(
"net, train, minibatch_size, min, max",
"net, train, minibatch_size, eager, min, max",
[
(benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 1.0, 1.15),
(benchmarks.cconv(128, 64, nengo.LIF()), False, 64, 2.25, 2.55),
(benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.6, 0.9),
(benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 0.95, 1.15),
(
benchmarks.cconv(128, 64, nengo.RectifiedLinear()),
False,
64,
True,
1.0,
1.15,
),
(benchmarks.cconv(128, 64, nengo.LIF()), False, 64, True, 2.25, 2.55),
(
benchmarks.integrator(128, 32, nengo.RectifiedLinear()),
True,
64,
True,
0.6,
0.9,
),
(benchmarks.integrator(128, 32, nengo.LIF()), True, 64, True, 0.95, 1.15),
(
benchmarks.random_network(
64,
@@ -212,13 +226,15 @@ def test_lmu(Simulator, native_nengo, pytestconfig):
),
False,
None,
True,
0.5,
0.7,
),
(benchmarks.lmu(1000, 1, native_nengo=True), True, 100, 1.3, 1.5),
(benchmarks.lmu(1000, 1, native_nengo=True), True, 100, True, 1.3, 1.5),
(benchmarks.lmu(1000, 1, native_nengo=True), True, 100, False, 1.05, 1.25),
],
)
def test_performance(net, train, minibatch_size, min, max):
def test_performance(net, train, minibatch_size, eager, min, max):
# performance is based on Azure NC6 VM
# CPU: Intel Xeon E5-2690 v3 @ 2.60Ghz
# GPU: Nvidia Tesla K80
@@ -227,6 +243,10 @@ def test_performance(net, train, minibatch_size, eager, min, max):
# Nengo version: 3.1.0
# NengoDL version: 3.3.0

if not eager:
tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_control_flow_v2()

time = benchmarks.run_profile(
net,
minibatch_size=minibatch_size,
