python slice is not represented properly in thunder #1182

Open
jjsjann123 opened this issue Sep 20, 2024 · 0 comments
Assignees
Labels
enhancement New feature or request

Comments

@jjsjann123
Collaborator

🚀 Feature

A `slice` currently shows up in the trace differently from a list/tuple: the NumberProxy inside it is printed with its full repr instead of its name, which results in an invalid Python program.

For example, with the script below:

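```python
import thunder
import torch


def foo(a):
    # Slice the last dimension up to half its (symbolic) length.
    return a[..., : (a.shape[-1] // 2)]


# With cache="symbolic values", shapes are traced as IntegerProxy values
# instead of being baked in as constants.
jfoo = thunder.jit(foo, cache="symbolic values")

a = torch.randn(2, 2, device="cuda")
out = jfoo(a)

# Print the final computation trace to inspect how the slice is rendered.
print(thunder.last_traces(jfoo)[-1])
```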

On PR #1027 (commit 4d260aed2b0939cebdeeeb4f04cf47358d3d9c8b), this produces the following trace:

```python
def computation(a, i1):
  # a: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # i1: "int 2"

  # /volume/thunder_dynamic/t.py:7:         return a[..., : (a.shape[-1] // 2)]
  b2 = prims.signbit(i1)  # b2: "bool False"
  _ = prims.signbit(2)
  b3 = prims.ne(b2, False)  # b3: "bool False"
  f4 = prims.fmod(i1, 2)  # f4: "float 0.0"
  i5 = prims.convert_element_type(f4, int)  # i5: "int 0"
  b6 = prims.ne(i5, 0)  # b6: "bool False"
  b7 = prims.bitwise_and(b3, b6)  # b7: "bool False"
  i8 = prims.div(i1, 2)  # i8: "int 1"
  i9 = prims.convert_element_type(b7, int)  # i9: "int 0"
  i10 = prims.sub(i8, i9)  # i10: "int 1"
  t37 = ltorch.getitem(a, (..., slice(None, [IntegerProxy name=i10, value=1, static=CONSTRAINT.CONSTRAINABLE], None)))  # t37: "cuda:0 f32[[IntegerProxy name=i29, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i36, value=1, static=CONSTRAINT.CONSTRAINABLE]]"
```
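Most of the trace before `t37` is the floor division `a.shape[-1] // 2` lowered to primitives. Reading the trace, it appears to compute a truncating division (`i8`) and subtract 1 (`i9`) when the operands' signs differ (`b3`; `signbit(2)` is folded to `False`) and the remainder is nonzero (`b6`). The sketch below re-implements that recipe in plain Python and checks it against `//`. It assumes `prims.div` truncates toward zero for integers, and uses `math.copysign` / `math.fmod` as stand-ins for `prims.signbit` / `prims.fmod`; these are assumptions, not thunder's implementation.

```python
import math


def signbit(x):
    # Stand-in for prims.signbit: True when x is negative.
    return math.copysign(1.0, x) < 0


def trunc_div(a, b):
    # Assumed semantics of prims.div on ints: C-style division truncating toward zero.
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def floor_div(a, b):
    # b3 in the trace: do the operands have different signs?
    signs_differ = signbit(a) != signbit(b)
    # f4 -> i5 -> b6 in the trace: is the (float) remainder nonzero once cast to int?
    remainder_nonzero = int(math.fmod(a, b)) != 0
    # b7 -> i9 in the trace: the correction term, 0 or 1.
    correction = int(signs_differ and remainder_nonzero)
    # i10 = i8 - i9
    return trunc_div(a, b) - correction


print(floor_div(2, 2), floor_div(-3, 2))  # 1 -2
```

For the traced input (`i1 = 2`) the correction is 0 and `i10` is 1, matching the trace's comments.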

The last line is not valid Python, because the IntegerProxy inside the slice is printed with its full repr rather than its name:

```python
t37 = ltorch.getitem(a, (..., slice(None, [IntegerProxy name=i10, value=1, static=CONSTRAINT.CONSTRAINABLE], None)))
```

It should instead print something like `slice(None, i10, None)`.
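The failure can be reproduced outside thunder with a minimal sketch. `FakeIntegerProxy` below is a hypothetical stand-in whose `__repr__` only mimics the bracketed form seen in the trace; it is not thunder's class. Python's built-in `slice` repr calls `repr()` on its start/stop/step, so any printer that falls back to `repr()` for a slice leaks the proxy's debug repr into the generated source. `ast.parse` confirms the leaked form is a syntax error while the proposed form parses.

```python
import ast


class FakeIntegerProxy:
    """Hypothetical stand-in for an IntegerProxy; only its printed form matters here."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __repr__(self):
        # Mimics the debug repr that shows up inside the slice in the trace.
        return f"[IntegerProxy name={self.name}, value={self.value}, static=CONSTRAINT.CONSTRAINABLE]"


def is_valid_python(src):
    try:
        ast.parse(src)
    except SyntaxError:
        return False
    return True


i10 = FakeIntegerProxy("i10", 1)
key = (..., slice(None, i10, None))

# repr(tuple) -> repr(slice) -> repr(proxy). Note repr(...) is "Ellipsis" here,
# whereas thunder's trace already prints "...".
bad = f"t37 = ltorch.getitem(a, {key!r})"
good = "t37 = ltorch.getitem(a, (..., slice(None, i10, None)))"

print(is_valid_python(bad), is_valid_python(good))  # False True
```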

We plan to simply handle printing of slices inline, similar to how lists/tuples are handled. Thanks to @t-vi and @mruberry for the suggestion.
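A minimal sketch of what "handled inline" could look like, assuming a recursive printer that walks tuples, lists, and slices and prints proxies by name. `render_source` and `FakeIntegerProxy` are hypothetical names used for illustration only; thunder's actual codegen/printing code may be structured differently.

```python
class FakeIntegerProxy:
    """Hypothetical stand-in for an IntegerProxy: has a name and a debug repr."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __repr__(self):
        return f"[IntegerProxy name={self.name}, value={self.value}]"


def render_source(x):
    """Render x as valid Python source, printing proxies by name (hypothetical helper)."""
    if isinstance(x, FakeIntegerProxy):
        return x.name
    if x is Ellipsis:
        return "..."
    if isinstance(x, slice):
        # Treat a slice like any other container: recurse into start, stop, step.
        return f"slice({render_source(x.start)}, {render_source(x.stop)}, {render_source(x.step)})"
    if isinstance(x, tuple):
        inner = ", ".join(render_source(e) for e in x)
        # A one-element tuple needs a trailing comma to stay a tuple.
        return f"({inner},)" if len(x) == 1 else f"({inner})"
    if isinstance(x, list):
        return "[" + ", ".join(render_source(e) for e in x) + "]"
    # None, bools, ints, strs, and finite floats already repr as valid source.
    return repr(x)


i10 = FakeIntegerProxy("i10", 1)
key = (..., slice(None, i10, None))
print(f"t37 = ltorch.getitem(a, {render_source(key)})")
# t37 = ltorch.getitem(a, (..., slice(None, i10, None)))
```

Treating `slice` as just another container in the recursion keeps the change local to printing, which is the less invasive option compared with introducing a new proxy type.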

Alternative

My initial thought was that we needed a SliceProxy, like ListProxy / TupleProxy. However, as @t-vi pointed out:

> I think that TupleProxy, DictProxy, and ListProxy are currently not used at all; they are relics of the functional JIT, which tried to proxy everything. I think CollectionProxy is used exclusively for autograd.
> Having a slice proxy is certainly possible and may be needed if we run into trouble putting numberproxies into slices.
> (but I wondered why you have lists in your slices above until...).

I don't think we need SliceProxy for now, so I'll proceed with the less invasive approach first to unblock myself.

jjsjann123 added the enhancement (New feature or request) label Sep 20, 2024
jjsjann123 self-assigned this Sep 20, 2024