
Add reduction ops #28

Merged: sbrunk merged 3 commits into main from reduction-ops on Jun 20, 2023
Conversation

sbrunk (Owner) commented Jun 17, 2023

sbrunk force-pushed the reduction-ops branch 6 times, most recently from 1fa94d4 to 9901dd0, on June 19, 2023 at 20:40
sbrunk marked this pull request as ready for review on June 20, 2023 at 07:07
sbrunk (Owner, Author) commented Jun 20, 2023

Hey @davoclavo, could you have a look when you find the time? I've added most reduction ops and also tests for most of them (made much easier thanks to your helpers).

davoclavo (Contributor) left a comment

LGTM! Amazing work! And awesome docs, I will try to follow that style from now on.

I just added a quick comment regarding type promotion.

Also, another quick question: I have been changing snake_case to camelCase to match Scala style, however there is some benefit to sticking with snake_case to match PyTorch nomenclature. Should we pick one over the other for consistency?

* @param p
* the norm to be computed
*/
// TODO dtype promotion floatNN/complexNN => highest floatNN
davoclavo (Contributor) commented:

Maybe you could use FloatPromoted[D]
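
i.e. something along these lines on the signature (just a sketch; the op and parameter names here are made up, only the FloatPromoted[D] part is the actual suggestion):

```scala
// Sketch: return the float-promoted dtype, so integer inputs yield a
// floating-point result while float inputs keep a float dtype.
def norm[D <: DType](input: Tensor[D], p: Float = 2): Tensor[FloatPromoted[D]] = ???
```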

sbrunk (Owner, Author) commented:

Right, but trying that, I think we might need to add cases for float16/bfloat16 to FloatPromoted:

For example, sin preserves bfloat16 in PyTorch, but the Scala version currently fails:

torch.sin(torch.tensor(1., dtype=torch.bfloat16)).dtype # torch.bfloat16
scala> torch.sin(torch.Tensor(1f).to(dtype=torch.bfloat16)).dtype
java.lang.ClassCastException: class torch.DType$bfloat16$ cannot be cast to class torch.Float32 (torch.DType$bfloat16$ and torch.Float32 are in unnamed module of loader sbt.internal.inc.classpath.ClasspathUtil$$anon$2 @586fb5d6)
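
So the fix is presumably adding those cases to the match type, roughly along these lines (just a sketch; the exact set of cases and the default are assumptions, not the actual definition):

```scala
// Sketch of FloatPromoted with half-precision cases added, so that
// float16/bfloat16 inputs keep their own dtype instead of falling through
// to a Float32 default (case set and default are assumed here).
type FloatPromoted[D <: DType] = D match
  case Float16  => Float16
  case BFloat16 => BFloat16
  case Float64  => Float64
  case _        => Float32
```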

sbrunk (Owner, Author) commented:

I think I got it now, although the type is getting somewhat verbose 😆

Tensor[Promoted[FloatPromoted[ComplexToReal[D]], FloatPromoted[ComplexToReal[D2]]]]
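
(For context: on a binary op like dist that expands to a signature roughly like the sketch below; the op and parameter names here are just illustrative.)

```scala
// Sketch: map complex dtypes to their real counterparts, float-promote each
// side, then promote both sides to a common result dtype.
def dist[D <: DType, D2 <: DType](
    input: Tensor[D],
    other: Tensor[D2],
    p: Float = 2
): Tensor[Promoted[FloatPromoted[ComplexToReal[D]], FloatPromoted[ComplexToReal[D2]]]] = ???
```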

sbrunk (Owner, Author) commented Jun 20, 2023

> LGTM! Amazing work! And awesome docs, I will try to follow that style from now on.

Thanks! I wanted to document how I convert the docs, but haven't gotten to it yet.

I'm using the original PyTorch rst doc sources, i.e. for ops in torch:
https://raw.githubusercontent.com/pytorch/pytorch/master/torch/_torch_docs.py
or for torch.tensor methods:
https://github.com/pytorch/pytorch/blob/main/torch/_tensor_docs.py
Sometimes they are also regular docstrings:
https://github.com/pytorch/pytorch/blob/main/torch/functional.py

  1. Search for the function you're about to port. I usually have the generated doc open in the browser as well, to help with searching and to verify I get everything I want from the source (sometimes things are factored out into variables and reused, etc.).
  2. Copy and convert the rst to markdown. I'm just using the online interface of pandoc: https://pandoc.org/try/
  3. Copy the markdown into Scaladoc and manually clean it up/fix things
  4. Run scalafmt

The conversion is not perfect, but it often gets you 80% of the way there, reducing a lot of tedious manual work.
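
For a rough idea of the end result, the converted text ends up as regular Scaladoc, e.g. (an abbreviated, made-up example in the style of the dist docs, not the actual text from this PR):

```scala
/** Returns the p-norm of (`input` - `other`).
  *
  * The shapes of `input` and `other` must be broadcastable.
  *
  * @param p
  *   the norm to be computed
  */
```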

Also please don't feel pressured to add all the docs immediately if you don't feel like it when porting things over. They can always be added/improved later on as well. :)

> Also, another quick question: I have been changing snake_case to camelCase to match Scala style, however there is some benefit to sticking with snake_case to match PyTorch nomenclature. Should we pick one over the other for consistency?

Yeah that's something I'm not satisfied with either way. I started with camelCase for idiomatic style but realized it makes porting and searching for docs a bit more inconvenient. Since most names are now camelCase already, I'm currently inclined to stick with it, but I'm absolutely not religious about it. If you feel it makes more sense to move to snake_case, now or later, let's discuss it.

sbrunk merged commit 89c4d5f into main on Jun 20, 2023 (7 checks passed)
sbrunk deleted the reduction-ops branch on June 20, 2023 at 20:46