
Prologue trace orders arguments in a way that breaks aliasing relation #1172

Open · shino16 opened this issue Sep 19, 2024 · 2 comments · May be fixed by #1184

shino16 (Contributor) commented Sep 19, 2024

import torch
import thunder

@thunder.jit
def fn(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10):
    # do not skip functionalization process
    a9 += 1
    # do not drop arguments by dce
    return a0 + a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10

args = [torch.zeros((), device='cuda') for _ in range(11)]
args[0] = args[1] = torch.zeros((2,), device='cuda')
fn(*args)
Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py", line 16, in <module>
    fn(*args)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 717, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 219, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 530, in get_computation_and_inputs
    functionalize_inplace_ops(
  File "/opt/pytorch/lightning-thunder/thunder/core/functionalization.py", line 884, in functionalize_inplace_ops
    no_implicit_alias_trace, swap_map_for_aliases = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
  File "/opt/pytorch/lightning-thunder/thunder/core/functionalization.py", line 172, in replace_args_with_alias_map
    reshaped_arg = prims.reshape.meta(arg, arg_to_replace.shape)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/prims.py", line 3167, in reshape_meta
    utils.check(
  File "/opt/pytorch/lightning-thunder/thunder/core/baseutils.py", line 107, in check
    raise exception_type(s())
RuntimeError: Attempting to reshape a.shape=(2,) to shape=(), but a.numel=2 is different from the number of elements in shape, 1

_alias_tensor_of_args_kwargs_dict correctly recognizes that the 0th and 1st arguments are aliases. However, the prologue trace reorders the arguments into (a0, a10, a1, a2, a3, a4, a5, a6, a7, a8, a9), so positions 0 and 1 of its output now hold a0 and a10. The functionalizer therefore tries to resolve an alias between a0 and a10, and fails when it reshapes a0 (shape (2,)) to a10.shape, which is ().
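A minimal standalone sketch of the mismatch (the names here are hypothetical; the real pairing happens in replace_args_with_alias_map), assuming the recorded indices are applied positionally to the prologue outputs:

# Hypothetical sketch: alias indices are recorded against the caller's
# argument order but applied to the reordered prologue outputs.
prologue_outputs = ["a0", "a10", "a1", "a2", "a3", "a4",
                    "a5", "a6", "a7", "a8", "a9"]
alias_indices = [0, 1]  # recorded because args[0] and args[1] alias
print([prologue_outputs[i] for i in alias_indices])
# ['a0', 'a10'] -- the functionalizer pairs a0 with a10 instead of a1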

prologue trace:

import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  prims.check_len(args, 11)
  # kwargs: "Any"
  prims.check_len(kwargs, 0)
  a9: "cuda:0 f32[]" = args[9]
  a0: "cuda:0 f32[2]" = args[0]
  a1: "cuda:0 f32[2]" = args[1]
  a2: "cuda:0 f32[]" = args[2]
  a3: "cuda:0 f32[]" = args[3]
  a4: "cuda:0 f32[]" = args[4]
  a5: "cuda:0 f32[]" = args[5]
  a6: "cuda:0 f32[]" = args[6]
  a7: "cuda:0 f32[]" = args[7]
  a8: "cuda:0 f32[]" = args[8]
  a10: "cuda:0 f32[]" = args[10]
  prims.check_tensor_shape_and_metadata(a0, (2,), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a1, (2,), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a2, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a3, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a4, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a5, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a6, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a7, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a8, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a9, (), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(a10, (), 'cuda:0', torch.float32, False)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  prims.check_literal_like(cache_info_default_device, torch.device("cpu"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  cache_info_alias_tensor_indices: "str" = cache_info['alias_tensor_indices']
  prims.check_string_value(cache_info_alias_tensor_indices, '0,1')
  return ((a0, a10, a1, a2, a3, a4, a5, a6, a7, a8, a9), ())
shino16 (Contributor, Author) commented Sep 19, 2024

The prologue returns the arguments sorted by the string key f"[{idx}]", and lexicographically "[10]" comes before "[1]", which is why a10 (=args[10]) is ordered before a1 (=args[1]). Variable names do not matter.

param_ordering[id(output)] = (output, param_ordering[id(obj)][1] + [math.inf, "[" + str(idx) + "]"])
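That ordering can be checked standalone: sorting the string keys is lexicographic, so "[10]" sorts before "[1]".

# Lexicographic comparison: in "[10]" vs "[1]" the third characters differ,
# and '0' (0x30) sorts before ']' (0x5D), so "[10]" < "[1]".
keys = [f"[{idx}]" for idx in range(11)]
print(sorted(keys))
# ['[0]', '[10]', '[1]', '[2]', '[3]', '[4]', '[5]', '[6]', '[7]', '[8]', '[9]']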

t-vi (Collaborator) commented Sep 19, 2024

Thank you for catching this! We should change it to not make a string here.
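One possible direction, sketched here only as an illustration (the linked PR may implement the fix differently): keep idx as an integer in the ordering key so the list comparison stays numeric rather than lexicographic.

# Hypothetical sketch, not the actual patch: with the raw int, element-wise
# list comparison orders 10 after 1 as intended.
param_ordering[id(output)] = (output, param_ordering[id(obj)][1] + [math.inf, idx])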

t-vi closed this as completed Sep 19, 2024
t-vi reopened this Sep 19, 2024
t-vi self-assigned this Sep 19, 2024
t-vi linked a pull request Sep 20, 2024 that will close this issue