Implemented binary_cross_entropy_with_logits in paddle.nn.functional.loss #17033

Merged: 3 commits into ivy-llc:master on Jun 19, 2023

Conversation

@KevinUli (Contributor) commented Jun 17, 2023

Close #16741

@KevinUli (Contributor, Author):

The result from the paddle ground truth differs from the other backends, so I set test_values=False.

@xoiga123 (Contributor):

@KevinUli How do they differ? If you explicitly set test_values=False, then you have to implement your own testing from the returned values.

@KevinUli (Contributor, Author) commented Jun 19, 2023

@xoiga123 When running the function with torch and paddle:

>>> arr = torch.as_tensor([0.5, 0.5,], dtype=torch.float32)
>>> pos_weight = torch.as_tensor([1., 2.,], dtype=torch.float32)
>>> torch.nn.functional.binary_cross_entropy_with_logits(arr, arr, weight=arr, reduction="none", pos_weight=pos_weight)
tensor([0.3620, 0.4806])
>>> arr = paddle.to_tensor([0.5, 0.5,], dtype="float32")
>>> pos_weight = paddle.to_tensor([1., 2.,], dtype="float32")
>>> paddle.nn.functional.binary_cross_entropy_with_logits(arr, arr, weight=arr, reduction="none", pos_weight=pos_weight)
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, [0.36203849, 0.54305774])

Getting rid of the pos_weight kwarg returns the same value for all frameworks, so the pos_weight kwarg seems to work differently in paddle than in the other frameworks.
Should I simply remove the pos_weight kwarg from the testing and add a comment about it?
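
A minimal NumPy sketch (not part of the PR; the "paddle-style" formula is inferred from the outputs above) that reproduces both results:

import numpy as np

x = np.array([0.5, 0.5], dtype=np.float32)  # logits == label == weight in the repro above
label, weight = x, x
pos_weight = np.array([1.0, 2.0], dtype=np.float32)

log_sig = -np.logaddexp(0.0, -x)     # log(sigmoid(x)), numerically stable
log_1m_sig = -np.logaddexp(0.0, x)   # log(1 - sigmoid(x))

# torch semantics: pos_weight scales only the positive (label) term
torch_style = -weight * (pos_weight * label * log_sig + (1 - label) * log_1m_sig)

# paddle's apparent semantics: the whole unweighted loss is scaled
# afterwards by (pos_weight - 1) * label + 1
base = -weight * (label * log_sig + (1 - label) * log_1m_sig)
paddle_style = base * ((pos_weight - 1) * label + 1)

print(torch_style)   # ~[0.3620, 0.4806], matches torch
print(paddle_style)  # ~[0.3620, 0.5431], matches paddle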

@xoiga123 (Contributor):

@KevinUli Yeah, seems weird. I'll look into the source and see what I can find. Thank you.

@xoiga123 (Contributor):

@KevinUli Yeah, pytorch's implementation of pos_weight is clearly correct, while paddle's is incorrect. This seems easily fixable. Please add a comment below the test definition:

# TODO: paddle's implementation of pos_weight is wrong
# https://github.com/PaddlePaddle/Paddle/blob/f0422a28d75f9345fa3b801c01cd0284b3b44be3/python/paddle/nn/functional/loss.py#L831

And then proceed to open an issue on their side, or open a PR to fix it if you want to. I'll merge this PR, but please keep track of when it's fixed and add pos_weight back. Thank you 🤗
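
For anyone picking up that fix, a minimal reference sketch of the torch-matching semantics (the function name and structure are illustrative, not paddle's actual internals):

import paddle

def bce_with_logits_reference(logit, label, weight=None, pos_weight=None):
    # pos_weight should scale only the positive (label) term, matching
    # torch.nn.functional.binary_cross_entropy_with_logits with reduction="none"
    log_sig = paddle.nn.functional.log_sigmoid(logit)
    log_1m_sig = log_sig - logit  # log(1 - sigmoid(x)) == log_sigmoid(x) - x
    pos = label * log_sig
    if pos_weight is not None:
        pos = pos_weight * pos
    loss = -(pos + (1 - label) * log_1m_sig)
    if weight is not None:
        loss = weight * loss
    return loss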

@KevinUli (Contributor, Author):

I've opened an issue on their side.

@xoiga123 (Contributor) left a review comment:


Well done 👍

@xoiga123 merged commit 41054c7 into ivy-llc:master on Jun 19, 2023.
@KevinUli deleted the binary_cross_entropy_with_logits branch on Jun 19, 2023 at 04:42.