Skip to content

Commit

Permalink
fixup! Make ConvertBatchNormalization versions explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Apr 9, 2020
1 parent 896df01 commit b6c5c0e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
11 changes: 11 additions & 0 deletions nengo_dl/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def _conform_to_reference_input(self, tensor, ref_input):

network.Network._conform_to_reference_input = _conform_to_reference_input

if version.parse(tf.__version__) < version.parse("2.1.0rc0"):
from tensorflow.python.keras.layers import (
BatchNormalization as BatchNormalizationV1,
)
from tensorflow.python.keras.layers import BatchNormalizationV2
else:
from tensorflow.python.keras.layers import (
BatchNormalizationV1,
BatchNormalizationV2,
)

# Nengo compatibility

# monkeypatch fix for https://github.com/nengo/nengo/pull/1587
Expand Down
5 changes: 2 additions & 3 deletions nengo_dl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import nengo
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import BatchNormalizationV1, BatchNormalizationV2
from tensorflow.python.util import nest

from nengo_dl import compat
Expand Down Expand Up @@ -1170,8 +1169,8 @@ def convert(self, node_id):
return super().convert(node_id, dimensions=3)


@Converter.register(BatchNormalizationV1)
@Converter.register(BatchNormalizationV2)
@Converter.register(compat.BatchNormalizationV1)
@Converter.register(compat.BatchNormalizationV2)
class ConvertBatchNormalization(LayerConverter):
"""Convert ``tf.keras.layers.BatchNormalization`` to Nengo objects."""

Expand Down

0 comments on commit b6c5c0e

Please sign in to comment.