2D whole model compile fails at embedding layer #534

Open
tianyu-l opened this issue Aug 20, 2024 · 2 comments
Labels
bug Something isn't working

Comments

tianyu-l (Contributor) commented Aug 20, 2024

Specifically, compilation fails when handling the DTensor _MaskPartial placement of the sharded embedding.

This only happens when we do whole-model compile.
TransformerBlock-level compilation (the default), plus separately compiling the embedding layer, doesn't have this issue.
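
For context, here is a minimal sketch of the two compilation strategies being compared; the module names (`model.layers`, `model.tok_embeddings`) follow the Llama model in this repo, but the exact wrapping torchtitan applies may differ:

```python
import torch

def compile_blockwise(model):
    # TransformerBlock-level compilation (the default): each block gets its
    # own dynamo graph, and the embedding layer is compiled separately.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, torch.compile(block))
    model.tok_embeddings = torch.compile(model.tok_embeddings)
    return model

def compile_whole_model(model):
    # Whole-model compilation: one graph over the full forward, including the
    # TP output hook that redistributes the _MaskPartial embedding output --
    # this is where the failure in the log below is raised.
    return torch.compile(model)
```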

error log ./run_llama_train.sh + NGPU=8 + LOG_RANK=0 + CONFIG_FILE=./train_configs/llama3_8b.toml + overrides= + '[' 0 -ne 0 ']' + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 train.py --job.config_file ./train_configs/llama3_8b.toml W0819 17:09:53.189000 1633288 torch/distributed/run.py:793] W0819 17:09:53.189000 1633288 torch/distributed/run.py:793] ***************************************** W0819 17:09:53.189000 1633288 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0819 17:09:53.189000 1633288 torch/distributed/run.py:793] ***************************************** [rank0]:2024-08-19 17:09:55,187 - root - INFO - Starting job: Llama 3 8B training [rank0]:2024-08-19 17:09:58,867 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:2024-08-19 17:09:58,879 - root - INFO - GPU capacity: NVIDIA H100 (0) with 95.04GiB memory [rank0]:2024-08-19 17:09:58,879 - root - INFO - Building 2-D device mesh with ['dp', 'tp'], [4, 2] [rank0]:2024-08-19 17:09:58,906 - root - INFO - Building tiktoken tokenizer locally from ./torchtitan/datasets/tokenizer/original/tokenizer.model [rank0]:2024-08-19 17:09:59,059 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank0]:2024-08-19 17:09:59,059 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:2024-08-19 17:10:06,989 - root - INFO - Building llama3 8B with ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_batch_size=32, max_seq_len=8192, depth_init=True, norm_type='rmsnorm') [rank0]:2024-08-19 17:10:07,101 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:2024-08-19 17:10:07,155 - root - INFO - Applied Tensor Parallelism to the model [rank0]:2024-08-19 17:10:07,425 - root - WARNING - detected that the pytorch is built from source. Please make sure the PR (https://github.com/pytorch/pytorch/pull/130760) is included in pytorch for correct 2D/3D DCP usage. 
[rank0]:2024-08-19 17:10:07,475 - root - INFO - Applied FSDP to the model [rank0]:2024-08-19 17:10:07,835 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%) [rank0]:2024-08-19 17:10:07,836 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 4, sequence length 8192, total steps 10 (warmup 2) [rank0]:NCCL version 2.21.5+cuda12.0 [rank0]:[rank0]: Traceback (most recent call last): [rank0]:[rank0]: File "/data/users/lty/torchtitan/train.py", line 424, in [rank0]:[rank0]: main(config) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper [rank0]:[rank0]: return f(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/torchtitan/train.py", line 299, in main [rank0]:[rank0]: pred = model(input_ids) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl [rank0]:[rank0]: return self._call_impl(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl [rank0]:[rank0]: result = forward_call(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/eval_frame.py", line 509, in _fn [rank0]:[rank0]: return fn(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl [rank0]:[rank0]: return self._call_impl(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl [rank0]:[rank0]: return forward_call(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/torchtitan/torchtitan/models/llama/model.py", line 436, in forward [rank0]:[rank0]: h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl [rank0]:[rank0]: return self._call_impl(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1801, in _call_impl [rank0]:[rank0]: hook_result = hook(self, args, result) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 1238, in __call__ [rank0]:[rank0]: return self._torchdynamo_orig_callable( [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 1039, in __call__ [rank0]:[rank0]: result = self._inner_convert( [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 514, in __call__ [rank0]:[rank0]: return _compile( [rank0]:[rank0]: ^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 902, in _compile [rank0]:[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 653, in compile_inner [rank0]:[rank0]: return _compile_inner(code, one_graph, hooks, transform) 
[rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_utils_internal.py", line 87, in wrapper_function [rank0]:[rank0]: return function(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 686, in _compile_inner [rank0]:[rank0]: out_code = transform_code_object(code, transform) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object [rank0]:[rank0]: transformations(instructions, code_options) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 208, in _fn [rank0]:[rank0]: return fn(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 622, in transform [rank0]:[rank0]: tracer.run() [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2731, in run [rank0]:[rank0]: super().run() [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 958, in run [rank0]:[rank0]: while self.step(): [rank0]:[rank0]: ^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 870, in step [rank0]:[rank0]: self.dispatch_table[inst.opcode](self, inst) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 558, in wrapper [rank0]:[rank0]: return inner_fn(self, inst) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2242, in CALL [rank0]:[rank0]: self._call(inst) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2236, in _call [rank0]:[rank0]: self.call_function(fn, args, kwargs) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 805, in call_function [rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward [rank0]:[rank0]: return getattr(self.realize(), name)(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/functions.py", line 906, in call_function [rank0]:[rank0]: return self.func.call_function(tx, merged_args, merged_kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/functions.py", line 322, in call_function [rank0]:[rank0]: return super().call_function(tx, args, kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/functions.py", line 106, in call_function [rank0]:[rank0]: return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 811, in inline_user_function_return [rank0]:[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File 
"/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2946, in inline_call [rank0]:[rank0]: return cls.inline_call_(parent, func, args, kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 3062, in inline_call_ [rank0]:[rank0]: tracer.run() [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 958, in run [rank0]:[rank0]: while self.step(): [rank0]:[rank0]: ^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 870, in step [rank0]:[rank0]: self.dispatch_table[inst.opcode](self, inst) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 558, in wrapper [rank0]:[rank0]: return inner_fn(self, inst) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2242, in CALL [rank0]:[rank0]: self._call(inst) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2236, in _call [rank0]:[rank0]: self.call_function(fn, args, kwargs) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 805, in call_function [rank0]:[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/misc.py", line 970, in call_function [rank0]:[rank0]: return self.obj.call_method(tx, self.name, args, kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method [rank0]:[rank0]: result = handler_method(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/tensor.py", line 905, in method_redistribute [rank0]:[rank0]: return wrap_fx_proxy( [rank0]:[rank0]: ^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/builder.py", line 1916, in wrap_fx_proxy [rank0]:[rank0]: return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/builder.py", line 2003, in wrap_fx_proxy_cls [rank0]:[rank0]: example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 2051, in get_fake_value [rank0]:[rank0]: raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 1983, in get_fake_value [rank0]:[rank0]: ret_val = wrap_fake_exception( [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 1468, in wrap_fake_exception [rank0]:[rank0]: return fn() [rank0]:[rank0]: ^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 1984, in [rank0]:[rank0]: lambda: run_node(tx.output, node, args, kwargs, nnmodule) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 2119, in run_node [rank0]:[rank0]: raise RuntimeError(make_error_message(e)).with_traceback( 
[rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 2101, in run_node [rank0]:[rank0]: return node.target(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_dynamo/variables/tensor.py", line 898, in redistribute_fn_with_prim_types [rank0]:[rank0]: return x.redistribute(*args_as_value, **kwargs_as_value) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/_tensor/api.py", line 541, in redistribute [rank0]:[rank0]: return Redistribute.apply(self, device_mesh, placements, async_op) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/autograd/function.py", line 575, in apply [rank0]:[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc] [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/_tensor/_redistribute.py", line 295, in forward [rank0]:[rank0]: output = redistribute_local_tensor( [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/_tensor/_redistribute.py", line 214, in redistribute_local_tensor [rank0]:[rank0]: new_local_tensor = partial_spec._reduce_shard_value( [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/_tensor/ops/_embedding_ops.py", line 143, in _reduce_shard_value [rank0]:[rank0]: self.mask_buffer.apply_mask(tensor) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/_tensor/ops/_embedding_ops.py", line 67, in apply_mask [rank0]:[rank0]: tensor[self.data, :] = 0.0 [rank0]:[rank0]: ~~~~~~^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/utils/_stats.py", line 21, in wrapper [rank0]:[rank0]: return fn(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__ [rank0]:[rank0]: return self.dispatch(func, types, args, kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1705, in dispatch [rank0]:[rank0]: return self._cached_dispatch_impl(func, types, args, kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl [rank0]:[rank0]: output = self._dispatch_impl(func, types, args, kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1800, in _dispatch_impl [rank0]:[rank0]: (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors( [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 2104, in validate_and_convert_non_fake_tensors [rank0]:[rank0]: validated_args = [validate(a) for a in flat_args] [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 2104, in [rank0]:[rank0]: validated_args = [validate(a) for a in flat_args] [rank0]:[rank0]: ^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 2092, in validate [rank0]:[rank0]: 
raise AssertionError( [rank0]:[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function .redistribute_fn_with_prim_types at 0x7fd6b53b8a40>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(1, 8192, 4096), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)), placements=(_MaskPartial(offset_shape=(128256, 4096), offset_dim=0),)),), **{}): [rank0]:[rank0]: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.index_put_.default(FakeTensor(..., device='cuda:0', size=(1, 8192, 4096), dtype=torch.bfloat16), [tensor([...], device='cuda:0', size=(1, 8192))], FakeTensor(..., size=(), dtype=torch.bfloat16)) [rank0]: [rank0]:[rank0]: from user code: [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/_tensor/api.py", line 895, in [rank0]:[rank0]: lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) [rank0]:[rank0]: File "/data/users/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 251, in _prepare_output_fn [rank0]:[rank0]: outputs = outputs.redistribute(placements=output_layouts, async_op=True) [rank0]: [rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
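
The assertion at the bottom comes from FakeTensorMode rejecting the real (non-fake) index tensor held in `_MaskPartial`'s mask buffer when dynamo traces `tensor[self.data, :] = 0.0`. Below is a rough stand-alone sketch of that failure mode (not the actual torchtitan/DTensor code path; shapes are taken from the log above and exact behavior may vary across PyTorch builds):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode()

# Fake "embedding output", standing in for the DTensor's local tensor.
out = mode.from_tensor(torch.empty(1, 8192, 4096, dtype=torch.bfloat16))

# Real boolean mask, standing in for mask_buffer.data, which is created
# eagerly and is never converted to a FakeTensor during tracing.
mask = torch.zeros(1, 8192, dtype=torch.bool)

with mode:
    # Mirrors `tensor[self.data, :] = 0.0` in _MaskPartial's apply_mask.
    # Mixing a real tensor into an op on FakeTensors raises:
    #   "Please convert all Tensors to FakeTensors first or instantiate
    #    FakeTensorMode with 'allow_non_fake_inputs'."
    out[mask, :] = 0.0
```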
tianyu-l added the bug label on Aug 20, 2024
bdhirsh commented Aug 20, 2024

Hmmm, I wonder if this is the same as what Wanchao and I saw here: pytorch/pytorch#130028 (comment)

tianyu-l (Contributor, Author) commented Aug 20, 2024

It looks like Wanchao and Brian are already aware of this. Given how hard it is to tackle, let's stick with TransformerBlock-level compilation for now.

Also as of 08/19:

  • It seems whole-model compile doesn't work well with SAC: performance dropped quite a bit (5700 -> 5000 tok/s on the 8B model) compared with block-level compile.
  • Whole-model compile provides a ~1.6% throughput gain, but produces recompilation warnings.
