IndexError when flat inputs are concatenated during trace #176

Open
austinleedavis opened this issue Jul 20, 2024 · 0 comments
Description

The NNSight object raises an IndexError when unbatched (1-D) token IDs are used as input while tracing in a loop. This bug is easy to trip over, and the error message does little to point at the real cause. It would be nice if the trace invocation checked whether the inputs are correctly batched (e.g., via len(input_ids.shape)) before applying the tensor concatenation. A simple workaround is to call tracer.invoke(input_ids.unsqueeze(0)).
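The suggested check could be sketched on the caller's side as a small helper that normalizes the input before invoking. This is a minimal sketch using plain torch; `ensure_batched` is a hypothetical helper name, not part of the nnsight API:

```python
import torch

def ensure_batched(input_ids: torch.Tensor) -> torch.Tensor:
    """Add a leading batch dimension to a flat (1-D) tensor of token IDs."""
    if len(input_ids.shape) == 1:
        return input_ids.unsqueeze(0)  # [seq_len] -> [1, seq_len]
    return input_ids  # already batched; leave untouched

flat = torch.tensor([1, 2, 3])     # shape: torch.Size([3])
batched = ensure_batched(flat)     # shape: torch.Size([1, 3])
same = ensure_batched(batched)     # unchanged: torch.Size([1, 3])
```

Passing `ensure_batched(input_ids)` to `tracer.invoke(...)` would make the failing example below behave like the working one.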

Root Cause

When an NNSight object batches inputs, it uses torch.concatenate() to stack the input_ids tensors along dimension zero (0). However, if a one-dimensional tensor is used as input, i.e., len(input_ids.shape) == 1, concatenation appends each subsequent input to the original tensor along the sequence axis rather than stacking it as a new batch row. After enough concatenations, the length of the resulting "batched" input exceeds the model's context window, and the forward pass raises an IndexError during embedding because the position indices exceed the size of the positional-embedding matrix.
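The shape difference behind the root cause can be reproduced with plain torch, independent of nnsight:

```python
import torch

# 2-D inputs (shape [1, 1]): concatenating along dim 0 stacks new rows,
# so repeated invocations grow the batch dimension, not the sequence.
batched = torch.concatenate([torch.tensor([[1]]), torch.tensor([[2]])], dim=0)
print(batched.shape)  # torch.Size([2, 1])

# 1-D inputs (shape [1]): the same call appends along the only axis,
# producing one ever-longer "sequence" instead of a batch.
flat = torch.concatenate([torch.tensor([1]), torch.tensor([2])], dim=0)
print(flat.shape)  # torch.Size([2])
```

In the 1-D case, each loop iteration lengthens the single sequence by one token, so after enough iterations it outgrows the context window and the embedding lookup fails.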

Working Example

import torch  # `model` is assumed to be a preloaded GPT-2-style transformer

input_ids = torch.tensor([[1]])  # <-- shape = torch.Size([1, 1]), batched
model = NNSight(model)
with model.trace() as tracer:
    for i in range(model.transformer.wpe.weight.shape[0]):
        with tracer.invoke(input_ids):
            pass

Failing Example

import torch  # `model` is assumed to be a preloaded GPT-2-style transformer

input_ids = torch.tensor([1])  # <-- shape = torch.Size([1]), flat (unbatched)
model = NNSight(model)
with model.trace() as tracer:
    for i in range(model.transformer.wpe.weight.shape[0]):
        with tracer.invoke(input_ids):
            pass

Info

  • nnsight==0.2.16
  • torch==2.3.1
  • transformers==4.40.1