Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add l2_loss function to tensorflow frontend #23073

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions ivy/data_classes/array/experimental/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,49 @@ def l1_loss(
"""
return ivy.l1_loss(self._data, target, reduction=reduction, out=out)

def l2_loss(
    self: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    reduction: Optional[str] = "sum",
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    """
    ivy.Array instance method variant of ivy.l2_loss. This method simply
    wraps the function, and so the docstring for ivy.l2_loss also applies
    to this method with minimal changes.

    Parameters
    ----------
    self
        Input array containing input values.
    reduction
        Reduction method for the output loss. Options:
        ``"none"`` (no reduction), ``"mean"`` (mean of losses),
        ``"sum"`` (sum of losses). Default: ``"sum"``.
    out
        Optional output array for writing the result to.
        It must have a shape that the inputs broadcast to.

    Returns
    -------
    ivy.Array
        The L2 loss (half the squared L2 norm, without the sqrt) of the
        given input.

    Examples
    --------
    >>> x = ivy.array([0.5, 1.5, 9.0])
    >>> print(x.l2_loss())
    ivy.array(41.75)
    """
    # Forward `reduction` and `out` so this wrapper honours the same
    # keyword arguments as the functional ivy.l2_loss (previously they
    # were silently dropped).
    return ivy.l2_loss(self._data, reduction=reduction, out=out)

def huber_loss(
self: ivy.Array,
pred: Union[ivy.Array, ivy.NativeArray],
Expand Down
50 changes: 50 additions & 0 deletions ivy/data_classes/container/experimental/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,56 @@ def l1_loss(
out=out,
)

@staticmethod
def _static_l2_loss(
    input: Union[ivy.Container, ivy.Array, ivy.NativeArray],
    /,
    *,
    reduction: Optional[Union[str, ivy.Container]] = "sum",
    key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
    to_apply: Union[bool, ivy.Container] = True,
    prune_unapplied: Union[bool, ivy.Container] = False,
    map_sequences: Union[bool, ivy.Container] = False,
    out: Optional[ivy.Container] = None,
) -> ivy.Container:
    """
    ivy.Container static method variant of ivy.l2_loss. This method simply
    wraps the function, and so the docstring for ivy.l2_loss also applies
    to this method with minimal changes.

    Parameters
    ----------
    input
        Input array or container containing input values.
    reduction
        Reduction method for the output loss. Options:
        ``"none"`` (no reduction), ``"mean"`` (mean of losses),
        ``"sum"`` (sum of losses). Default: ``"sum"`` to match the
        functional ``ivy.l2_loss`` default.
    key_chains
        The key-chains to apply or not apply the method to.
        Default is ``None``.
    to_apply
        If True, the method will be applied to key_chains, otherwise
        key_chains will be skipped. Default is ``True``.
    prune_unapplied
        Whether to prune key_chains for which the function was not
        applied. Default is ``False``.
    map_sequences
        Whether to also map method to sequences (lists, tuples).
        Default is ``False``.
    out
        Optional output container, for writing the result to.

    Returns
    -------
    ivy.Container
        Container of the L2 losses of the leaves of the given input.
    """
    return ContainerBase.cont_multi_map_in_function(
        "l2_loss",
        input,
        reduction=reduction,
        key_chains=key_chains,
        to_apply=to_apply,
        prune_unapplied=prune_unapplied,
        map_sequences=map_sequences,
        out=out,
    )

def l2_loss(
    self: ivy.Container,
    /,
    *,
    reduction: Optional[Union[str, ivy.Container]] = "sum",
    key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
    to_apply: Union[bool, ivy.Container] = True,
    prune_unapplied: Union[bool, ivy.Container] = False,
    map_sequences: Union[bool, ivy.Container] = False,
    out: Optional[ivy.Container] = None,
) -> ivy.Container:
    """
    ivy.Container instance method variant of ivy.l2_loss. This method
    simply wraps the function, and so the docstring for ivy.l2_loss also
    applies to this method with minimal changes.

    Parameters
    ----------
    self
        Input container containing input values.
    reduction
        Reduction method for the output loss. Options:
        ``"none"`` (no reduction), ``"mean"`` (mean of losses),
        ``"sum"`` (sum of losses). Default: ``"sum"`` to match the
        functional ``ivy.l2_loss`` default.
    key_chains
        The key-chains to apply or not apply the method to.
        Default is ``None``.
    to_apply
        If True, the method will be applied to key_chains, otherwise
        key_chains will be skipped. Default is ``True``.
    prune_unapplied
        Whether to prune key_chains for which the function was not
        applied. Default is ``False``.
    map_sequences
        Whether to also map method to sequences (lists, tuples).
        Default is ``False``.
    out
        Optional output container, for writing the result to.

    Returns
    -------
    ivy.Container
        Container of the L2 losses of the leaves of the given input.
    """
    return self._static_l2_loss(
        self,
        reduction=reduction,
        key_chains=key_chains,
        to_apply=to_apply,
        prune_unapplied=prune_unapplied,
        map_sequences=map_sequences,
        out=out,
    )

@staticmethod
def _static_smooth_l1_loss(
input: Union[ivy.Container, ivy.Array, ivy.NativeArray],
Expand Down
5 changes: 5 additions & 0 deletions ivy/functional/frontends/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,11 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
return ivy.log_poisson_loss(targets, log_input, compute_full_loss=compute_full_loss)


@to_ivy_arrays_and_back
def l2_loss(t, name=None):
    """Compute half the sum of the squared values of `t`, i.e.
    ``sum(t ** 2) / 2``, mirroring ``tf.nn.l2_loss``.

    `name` is accepted for TensorFlow API compatibility and ignored.
    """
    return ivy.l2_loss(t)


@to_ivy_arrays_and_back
def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
    # Thin wrapper over ivy.max_pool1d mirroring tf.nn.max_pool1d.
    # `name` is accepted for TensorFlow API compatibility and ignored.
    return ivy.max_pool1d(input, ksize, strides, padding, data_format=data_format)
Expand Down
54 changes: 54 additions & 0 deletions ivy/functional/ivy/experimental/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,60 @@ def l1_loss(
return ivy.inplace_update(out, loss) if out is not None else loss


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@inputs_to_ivy_arrays
@handle_array_function
def l2_loss(
    input: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    reduction: Optional[str] = "sum",
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    """
    Compute half the L2 norm of an array, without the sqrt.

    With the default ``reduction="sum"`` this equals
    ``sum(input ** 2) / 2``, matching ``tf.nn.l2_loss``.

    Parameters
    ----------
    input
        Input array containing input values.
    reduction
        Reduction method for the output loss. Options:
        ``"none"`` (no reduction; returns the per-element squared
        values), ``"mean"`` (mean of losses), ``"sum"`` (sum of losses).
        Default: ``"sum"``.
    out
        Optional output array for writing the result to.
        It must have a shape that the result broadcasts to.

    Returns
    -------
    ivy.Array
        The L2 loss of the given input.

    Examples
    --------
    >>> x = ivy.array([0.5, 1.5, 9.0])
    >>> y = ivy.array([0.1, 2.0, 1e-2])
    >>> print(ivy.l2_loss(x))
    ivy.array(41.75)
    >>> print(ivy.l2_loss(x - y))
    ivy.array(40.61505127)
    """
    loss = ivy.pow(input, 2)
    if reduction == "sum":
        loss = ivy.divide(ivy.sum(loss), 2)
    elif reduction == "mean":
        loss = ivy.divide(ivy.mean(loss), 2)
    # NOTE(review): with reduction="none" the un-halved squared values are
    # returned, preserving the original behaviour — confirm this is the
    # intended per-element contract.
    # Only write into `out` once the final value is known; previously the
    # intermediate (un-halved) sum was stored in `out` while the halved
    # result was returned as a fresh array.
    return ivy.inplace_update(out, loss) if out is not None else loss


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
Expand Down
36 changes: 36 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,7 @@ def test_tensorflow_local_response_normalization(
compute_full_loss=st.booleans(),
test_with_out=st.just(False),
)

def test_tensorflow_log_poisson_loss(
*,
dtype_target_log_inputs,
Expand All @@ -1398,6 +1399,41 @@ def test_tensorflow_log_poisson_loss(
)


# l2_loss
@handle_frontend_test(
    fn_tree="tensorflow.nn.l2_loss",
    dtype_target_inputs=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("float"),
        num_arrays=1,
        min_value=0,
        max_value=1,
        min_num_dims=1,
        max_num_dims=5,
        shared_dtype=True,
    ),
    test_with_out=st.just(False),
)
def test_tensorflow_l2_loss(
    *,
    dtype_target_inputs,
    test_flags,
    frontend,
    fn_tree,
    on_device,
    backend_fw,
):
    """Test the tf.nn.l2_loss frontend against the native implementation."""
    # dtype_and_values returns (dtypes, values) with values a list of
    # num_arrays arrays; pass the single array, not the list.
    input_dtype, input_values = dtype_target_inputs
    helpers.test_frontend_function(
        input_dtypes=input_dtype,
        backend_to_test=backend_fw,
        test_flags=test_flags,
        frontend=frontend,
        fn_tree=fn_tree,
        on_device=on_device,
        t=input_values[0],
    )

# max_pool1d
@handle_frontend_test(
fn_tree="tensorflow.nn.max_pool1d",
Expand Down
Loading