Skip to content

Commit

Permalink
implement tf.nn.bias_add (#8835)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedo42 authored Dec 18, 2022
1 parent 90107cd commit 5a4c565
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
13 changes: 13 additions & 0 deletions ivy/functional/frontends/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,16 @@ def moments(x, axes, shift=None, keepdims=False, name=None):
return ivy.mean(x, axis=axes, keepdims=keepdims), ivy.var(
x, axis=axes, keepdims=keepdims
)


@to_ivy_arrays_and_back
def bias_add(value, bias, data_format=None, name=None):
if data_format is None:
data_format = "N...C"

if data_format == "N...C":
return ivy.add(value, bias)
else:
value = ivy.swapaxes(value, 1, -1)
res = ivy.add(value, bias)
return ivy.swapaxes(res, 1, -1)
46 changes: 46 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,3 +958,49 @@ def test_tensorflow_moments(
axes=axis,
keepdims=keepdims,
)


@st.composite
def _generate_bias_data(draw):
data_format = draw(st.sampled_from(["NC...", "N...C", None]))
channel_dim = 1 if data_format == "NC..." else -1
dtype, value, shape = draw(
helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
min_num_dims=3,
ret_shape=True,
)
)
channel_size = shape[channel_dim]
bias = draw(helpers.array_values(dtype=dtype[0], shape=(channel_size,)))
return data_format, dtype, value, bias


@handle_frontend_test(
fn_tree="tensorflow.nn.bias_add",
data=_generate_bias_data(),
)
def test_tensorflow_bias_add(
*,
data,
as_variable,
num_positional_args,
native_array,
frontend,
fn_tree,
on_device,
):
data_format, dtype, value, bias = data
helpers.test_frontend_function(
input_dtypes=dtype * 2,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
value=value[0],
bias=bias,
data_format=data_format,
)

0 comments on commit 5a4c565

Please sign in to comment.