Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP2][NF4Tensor][2/n] implement torch.chunk and other ops #150

Merged
merged 47 commits into from
May 1, 2024

Conversation

weifengpy
Copy link
Contributor

@weifengpy weifengpy commented Apr 19, 2024

why FSDP needs those ops

  • torch.chunk / aten.split.Tensor: dim0 sharding on parameters torch.chunk(tensor, world_size, dim=0)
  • tensor.new_zeros / aten.new_zeros.default: allocate storage for padded params.
  • tensor[:end_idx] / aten.slice.Tensor and tensor.copy_: copy sharded params into padded params
  • tensor.view(-1) / aten.view.default: flatten ND tensors into 1D
  • torch.as_strided(tensor, orig_size) / aten.as_strided.default: restore 1D tensors to ND
  • tensor.pin_memory: move cpu tensor to pin memory for nonblocking D2H copy
  • tensor.cpu(): move gpu tensor to cpu

unit test: pytest test/dtypes/test_nf4.py

run fsdp in TorchTune

  • git clone https://github.com/weifengpy/torchtune.git
  • cd torchtune && pip install -e ".[dev]"
  • tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
  • tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config recipes/configs/llama2/7B_qlora_single_device.yaml max_steps_per_epoch=1

user flow and gaps

weifengpy and others added 26 commits April 3, 2024 18:18
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 19, 2024
@weifengpy weifengpy marked this pull request as draft April 19, 2024 20:38
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@msaroufim
Copy link
Member

Didn't forget to review, will give it a thorough read this afternoon

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Made a first pass and can do a second one tomorrow morning

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_to_cpu(self):
nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
nf4_tensor.cpu()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this is just testing against crashes or do also expect the nf4_tensor.device to be cpu?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch. this is testing against crashes but i will add assertion on nf4_tensor.device.type == 'cpu'

torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset())

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_pin_memory(self):
Copy link
Member

@msaroufim msaroufim Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you mentioned this briefly last week but could you remind me how you figured out these would be the functions that needed to be tested. (I'm thinking ahead with a tutorial for someone who wants to upstream some new exotic dtpye and get it working with fsdp). That's probably a good candidate for what I mean by we should add another smoke test so we know for sure FSDP will work

So I ran the tests locally and they all worked and fast! So this gives me confidence the nf4 tensor now supports many new ops but it doesnt give me confidence that fsdp won't break in some way

I was hoping we could have a smoke test of the sort fsdp(torch.nn.Sequential(LinearNF4(64,64))) that would ensure nothing breaks and that fsdp doesn't silently drop the dtype since that functionality wasn't tested for fsdp 1 and we had to rely on twitter to get that signal

Copy link
Contributor Author

@weifengpy weifengpy Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree that we need a smoke test on fsdp(model). Not sure how to setup a multi-gpu test in torchao though. Is there some .ci files to change? Is there some example in torchAO? I am happy to fill in the actual logic into the template. As a reference, FSDP tests in pytorch are done like this pytorch/test/distributed/_composable/fsdp/test_fully_shard_training.py

Copy link
Member

@msaroufim msaroufim Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something identical should work the machines we have in CI, every commit is already running on 4 A10Gs linux.g5.12xlarge. No existing example since this is our first distributed test

Let's just do this, first thing we meet tomorrow

def noop_detach(func, *args, **kwargs):
return args[0][0]


@implements(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more of a n00b q to @drisspg : what's up with all the args[0] I feel like there's some sort of contract I can't quite parse

EDIT: It's the NF4 tensor, could we add some comment somewhere to make this clearer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated PR with nf4tensor = args[0] at the begining to make it clearer

self.scaler_block_size,
self.scaler_mean,
self.nf4,
mesh.get_group().size(),
Copy link
Member

@msaroufim msaroufim Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n00b q: what is this doing?

Also more generally I don't follow what the 2 fsdp tests are trying to do. I think in fsdp_post_all_gather you are testing to make sure nf4 tensors are preserved and not silently casted to some other type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow what the 2 fsdp tests are trying to do

This is core logic in nf4tensor.py. unit tests happens in another file test_nf4.py

fsdp_pre_all_gather returns a tuple of two things

  • tuple[0] are quantized_scalers, quantization_factor and quantized_data. They are input for all-gather
  • tuple[1] are SubclassTensorArgs, block_size etc are metadata to reconstruct NF4Tensor. mesh.get_group().size() is the group size for all-gather (how many gpus). it's helpful to restore NF4Tensor.size. Eg for 2 gpus, all-gathering tensor(512) will return tensor(512 x 2)

scaler_mean = aten_op(args[0].scaler_mean, *args[1:], **kwargs)
nf4 = aten_op(args[0].nf4, *args[1:], **kwargs)
tensor_meta = SubclassTensorArgs(
args[0].size(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 This also confused me. I think what driss means is just give a human readable name to args[0] so its easier to read the code

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Made a first pass and can do a second one tomorrow morning

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
aten.detach.default,
]
)
def nf4_detach(aten_op, args, kwargs=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we make that assumption that requires_grad=False and detach is a no-op, can we add an assertion that checks for args[0].requires_grad?

Also, I am not sure that we need to detach all inner tensors. cc: @bdhirsh

raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
ratio = nf4_tensor.numel() // math.prod(new_size)

assert nf4_tensor.quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: These assertion messages preferably should include the values (i.e. both nf4_tensor.quantized_scalers.size(0) and ratio) so that they can be more actionable.

quantization_factor = aten_op(nf4_tensor.quantization_factor, *(args[1:]), **kwargs)
quantized_data = aten_op(nf4_tensor.quantized_data, *(args[1:]), **kwargs)
return NF4Tensor(
SubclassTensorArgs(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I seem to see this pattern a lot where we construct SubclassTensorArgs directly from an existing nf4_tensor. Perhaps, consider making this into a helper to avoid the duplication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha. not nit at all. added util function to keep the code dry: NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))

Comment on lines +891 to +898
assert (
quantized_scalers.untyped_storage().data_ptr()
== out.quantized_scalers.untyped_storage().data_ptr() and
quantization_factor.untyped_storage().data_ptr()
== out.quantization_factor.untyped_storage().data_ptr() and
quantized_data.untyped_storage().data_ptr()
== out.quantized_data.untyped_storage().data_ptr()
), f"Expects out's data to be the all-gather output"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may consider removing these asserts (in the future) especially if tracing through this becomes an issue. In theory, NF4Tensor should not need to make this kind of assert, but for now, it might be helpful for debugging as the FSDP extension is still in its early stages.

)
)
) and len(args) == 2:
# Tensor.to(device, non_blocking)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean that if we tried to use __torch_dispatch__, we would not be able to tell that it is simply .to(device, non_blocking=True) without a dtype argument/dtype change?

What is the story for dequantization? Namely, what is the outer NF4Tensor's dtype, and what happens when you call .to(dtype) with that same dtype? (e.g. if NF4Tensor.dtype == torch.bfloat16, what if you call NF4Tensor.to(torch.bfloat16)?)

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

NF4_OPS_TABLE: Dict[Any, Any] = {}

INNER_TENSOR_NAMES_FOR_FSDP = ["quantized_scalers", "quantization_factor", "quantized_data"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I exclude two tiny tensors: nf4 (numel=16) and scaler_mean (numel=1)
when GPU > numel, we need to implement padding for inner tensors. it's not worth the time in my opinion

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it'd apply to more than just FSDP. Is that correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it'd apply to more than just FSDP. Is that correct?

it applies general distributed case when we shard a single tensor to N GPUs. I can change the name to INNER_TENSOR_NAMES_FOR_SHARDING if that's clearer

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
assert nf4tensor.quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}"
quantized_scalers = aten_op(nf4tensor.quantized_scalers, [nf4tensor.quantized_scalers.size(0) // ratio], **kwargs)

assert nf4tensor.quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe these asserts could be unified?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good suggestion. I removed duplicative asserts with for loop over inner tensors

@weifengpy weifengpy requested a review from cpuhrsch May 1, 2024 18:14

NF4_OPS_TABLE: Dict[Any, Any] = {}

INNER_TENSOR_NAMES_FOR_SHARDING = ["quantized_scalers", "quantization_factor", "quantized_data"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is something FSDP2 requires any Tensor subclass to have defined?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for any Tensor subclass, we prefer reusing __tensor_flatten__ to lookup inner tensors. For NF4, we define INNER_TENSOR_NAMES_FOR_SHARDING as a subset of __tensor_flatten__ because scaler_mean and nf4 are too tiny to shard

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, isn't that something you could filter with a numel based heuristic within FSDP itself instead of requiring some tensor subclasses to communicate it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the inner tensors that are sharded needs to match the torch.chunk implementation in the subclass, so FSDP cannot necessarily determine the tensors to shard itself. (E.g., if FSDP filtered by numel but the subclass implemented torch.chunk to still shard some tensor smaller than the numel threshold, then there would be a correctness issue.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to private const with underscore _INNER_TENSOR_NAMES_FOR_SHARDING after discussion with @cpuhrsch

weifengpy and others added 2 commits May 1, 2024 14:20
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@weifengpy weifengpy requested a review from cpuhrsch May 1, 2024 22:06
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the heroic work. Let's open up an issue with known gaps

@weifengpy
Copy link
Contributor Author

Thank you for the heroic work. Let's open up an issue with known gaps

yes, will open issue for the renaming work

@cpuhrsch cpuhrsch merged commit ac53d7f into pytorch:main May 1, 2024
13 checks passed
@weifengpy
Copy link
Contributor Author

Thank you for the heroic work. Let's open up an issue with known gaps

opened issues and linked them here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants