From 21b361193e7b142432a4a773987110394ca5970b Mon Sep 17 00:00:00 2001
From: Clifton
Date: Sat, 16 Mar 2024 20:34:57 +0000
Subject: [PATCH] Fix fp16/bf16 for Prefix Tuning

---
 src/adapters/methods/prefix_tuning.py     |  6 ++++--
 src/adapters/models/clip/modeling_clip.py |  2 +-
 tests/methods/base.py                     | 12 ++++++++----
 tests/methods/test_adapter_common.py      |  9 ++++++---
 tests/methods/test_lora.py                |  6 +++++-
 tests/methods/test_prefix_tuning.py       |  4 +++-
 tests/methods/test_prompt_tuning.py       |  4 +++-
 tests/test_adapter.py                     |  6 +++---
 tests/test_clip.py                        |  4 ++--
 9 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py
index 2a716e2acc..95ccdf50d8 100644
--- a/src/adapters/methods/prefix_tuning.py
+++ b/src/adapters/methods/prefix_tuning.py
@@ -519,10 +519,12 @@ def compose_single(self, adapter_setup: str, state: PrefixTuningState, lvl: int
         value_states = torch.cat([prefix_values, state.value_states], dim=2)
         if state.attention_mask is not None:
             if state.attention_mask.dim() == 2:  # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len)
-                prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(state.attention_mask.device)
+                prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(
+                    device=state.attention_mask.device, dtype=state.attention_mask.dtype
+                )
             else:
                 prefix_mask = torch.ones(batch_size, 1, state.attention_mask.size(2), prefix_keys.size(2)).to(
-                    state.attention_mask.device
+                    device=state.attention_mask.device, dtype=state.attention_mask.dtype
                 )
             if state.invert_mask:
                 prefix_mask = 1.0 - prefix_mask
diff --git a/src/adapters/models/clip/modeling_clip.py b/src/adapters/models/clip/modeling_clip.py
index a67491dc58..b74a0308ef 100644
--- a/src/adapters/models/clip/modeling_clip.py
+++ b/src/adapters/models/clip/modeling_clip.py
@@ -66,7 +66,7 @@ def forward(
         if causal_attention_mask is not None:
             prefix_mask = torch.ones(
                 bsz, 1, causal_attention_mask.size(2), src_len - causal_attention_mask.size(-1)
-            ).to(causal_attention_mask.device)
+            ).to(device=causal_attention_mask.device, dtype=causal_attention_mask.dtype)
             causal_attention_mask = torch.cat([prefix_mask, causal_attention_mask], dim=-1)
         if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
             raise ValueError(
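
Note: the core of the two fixes above is that torch.ones(...) always allocates in the
default dtype (float32), and .to(device) moves the tensor without changing that dtype.
When the model's weights and masks are fp16/bf16, the freshly created prefix mask then
no longer matches the existing attention mask. A minimal sketch of the failure mode
(shapes are illustrative, not taken from the patch):

    import torch

    attention_mask = torch.zeros(2, 1, 8, 8, dtype=torch.half)

    # Before the fix: .to(device) preserves the default float32 dtype.
    prefix_mask = torch.ones(2, 1, 8, 30).to(attention_mask.device)
    assert prefix_mask.dtype == torch.float32  # mismatch with the half-precision mask

    # After the fix: passing dtype= keeps the prefix mask consistent with the model's mask.
    prefix_mask = torch.ones(2, 1, 8, 30).to(
        device=attention_mask.device, dtype=attention_mask.dtype
    )
    assert prefix_mask.dtype == torch.half
    combined = torch.cat([prefix_mask, attention_mask], dim=-1)  # dtype-consistent
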
diff --git a/tests/methods/base.py b/tests/methods/base.py
index 67d8e26909..3954aece46 100644
--- a/tests/methods/base.py
+++ b/tests/methods/base.py
@@ -35,6 +35,9 @@ def create_twin_models(model_class, config_creator=None):
 class AdapterMethodBaseTestMixin:
     """Provides base test running methods for testing an adapter method implementation."""
 
+    # Model weight dtypes to test in forward pass
+    dtypes_to_test = [torch.float32, torch.half] if torch_device == "cuda" else [torch.float32]
+
     def filter_parameters(self, model, filter_keys):
         return {k: v for (k, v) in model.named_parameters() if any([filter_key in k for filter_key in filter_keys])}
 
@@ -151,14 +154,15 @@ def run_get_test(self, model, adapter_config, num_expected_modules):
 
         model.delete_adapter("first")
 
-    def run_forward_test(self, model, adapter_config):
+    def run_forward_test(self, model, adapter_config, dtype=torch.float32):
         model.eval()
 
         name = adapter_config.__class__.__name__
-        model.add_adapter(name, config=adapter_config)
-        model.to(torch_device)
+        if name not in model.adapters_config:
+            model.add_adapter(name, config=adapter_config)
+        model.to(torch_device).to(dtype)
 
-        input_data = self.get_input_samples(config=model.config)
+        input_data = self.get_input_samples(config=model.config, dtype=dtype)
 
         # pass 1: set adapter via property
         model.set_active_adapters([name])
diff --git a/tests/methods/test_adapter_common.py b/tests/methods/test_adapter_common.py
index 8d78c88e4a..b1d67757b0 100644
--- a/tests/methods/test_adapter_common.py
+++ b/tests/methods/test_adapter_common.py
@@ -210,13 +210,16 @@ def test_reduction_factor_no_default(self):
         with self.assertRaises(KeyError):
             model.add_adapter(name, config=adapter_config)
 
-    def test_adapter_forward(self):
+    def test_forward_bottleneck(self):
         model = self.get_model()
         model.eval()
 
         for adapter_config, _ in self.adapter_configs_to_test:
-            with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
-                self.run_forward_test(model, adapter_config)
+            for dtype in self.dtypes_to_test:
+                with self.subTest(
+                    model_class=model.__class__.__name__, config=adapter_config.__class__.__name__, dtype=dtype
+                ):
+                    self.run_forward_test(model, adapter_config, dtype=dtype)
 
     def test_invertible_adapter_forward(self):
         model = self.get_model()
diff --git a/tests/methods/test_lora.py b/tests/methods/test_lora.py
index 90f6d26ae8..0ade2bdbb0 100644
--- a/tests/methods/test_lora.py
+++ b/tests/methods/test_lora.py
@@ -29,7 +29,11 @@ def test_get_lora(self):
 
     def test_forward_lora(self):
         model = self.get_model()
-        self.run_forward_test(model, LoRAConfig(init_weights="bert", intermediate_lora=True, output_lora=True))
+        for dtype in self.dtypes_to_test:
+            with self.subTest(model_class=model.__class__.__name__, dtype=dtype):
+                self.run_forward_test(
+                    model, LoRAConfig(init_weights="bert", intermediate_lora=True, output_lora=True), dtype=dtype
+                )
 
     def test_load_lora(self):
         self.run_load_test(LoRAConfig())
diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py
index 3c98b2854d..a1c41268ba 100644
--- a/tests/methods/test_prefix_tuning.py
+++ b/tests/methods/test_prefix_tuning.py
@@ -40,7 +40,9 @@ def test_get_prefix_tuning(self):
 
     def test_forward_prefix_tuning(self):
         model = self.get_model()
-        self.run_forward_test(model, PrefixTuningConfig(flat=True))
+        for dtype in self.dtypes_to_test:
+            with self.subTest(model_class=model.__class__.__name__, dtype=dtype):
+                self.run_forward_test(model, PrefixTuningConfig(flat=True), dtype=dtype)
 
     def test_load_prefix_tuning(self):
         self.run_load_test(PrefixTuningConfig())
diff --git a/tests/methods/test_prompt_tuning.py b/tests/methods/test_prompt_tuning.py
index d0b12d259c..a5150e1aa3 100644
--- a/tests/methods/test_prompt_tuning.py
+++ b/tests/methods/test_prompt_tuning.py
@@ -24,7 +24,9 @@ def test_get_prompt_tuning(self):
 
     def test_forward_prompt_tuning(self):
         model = self.get_model()
-        self.run_forward_test(model, PromptTuningConfig(prompt_length=10))
+        for dtype in self.dtypes_to_test:
+            with self.subTest(model_class=model.__class__.__name__, dtype=dtype):
+                self.run_forward_test(model, PromptTuningConfig(prompt_length=10), dtype=dtype)
 
     def test_load_prompt_tuning(self):
         self.run_load_test(PromptTuningConfig(prompt_length=10))
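
Note: the test changes above all follow one pattern. The shared mixin's dtypes_to_test
yields float32 everywhere and additionally half precision on CUDA, and each forward test
loops over it inside a subTest so per-dtype failures are reported individually, while
run_forward_test casts both the model and the generated inputs to the requested dtype.
As a hedged sketch, a forward test for any further adapter method would plug into the
same hook (SomeMethodConfig is a placeholder, not a config from this patch):

    def test_forward_some_method(self):
        model = self.get_model()
        for dtype in self.dtypes_to_test:
            with self.subTest(model_class=model.__class__.__name__, dtype=dtype):
                # run_forward_test moves the model to `dtype` and passes it on
                # to get_input_samples so the inputs match.
                self.run_forward_test(model, SomeMethodConfig(), dtype=dtype)
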
diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index c01f0295a2..ce6f2c1b34 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -49,7 +49,7 @@ def get_model(self):
         model.to(torch_device)
         return model
 
-    def get_input_samples(self, shape=None, vocab_size=5000, config=None):
+    def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs):
         shape = shape or self.default_input_samples_shape
         total_dims = 1
         for dim in shape:
@@ -96,7 +96,7 @@ def assert_adapter_unavailable(self, model, adapter_name):
 class VisionAdapterTestBase(AdapterTestBase):
     default_input_samples_shape = (3, 3, 224, 224)
 
-    def get_input_samples(self, shape=None, config=None):
+    def get_input_samples(self, shape=None, config=None, dtype=torch.float, **kwargs):
         shape = shape or self.default_input_samples_shape
         total_dims = 1
         for dim in shape:
@@ -104,7 +104,7 @@ def get_input_samples(self, shape=None, config=None):
         values = []
         for _ in range(total_dims):
             values.append(random.random())
-        pixel_values = torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
+        pixel_values = torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous()
         in_data = {"pixel_values": pixel_values}
 
         return in_data
diff --git a/tests/test_clip.py b/tests/test_clip.py
index 2ed57268e4..57c264f6b5 100644
--- a/tests/test_clip.py
+++ b/tests/test_clip.py
@@ -170,7 +170,7 @@ class CLIPAdapterTestBase(AdapterTestBase):
     default_vision_input_samples_shape = (3, 3, 224, 224)
     do_run_train_tests = False
 
-    def get_input_samples(self, vocab_size=5000, config=None):
+    def get_input_samples(self, vocab_size=5000, config=None, dtype=torch.float, **kwargs):
         # text inputs
         shape = self.default_text_input_samples_shape
         total_dims = 1
@@ -194,7 +194,7 @@ def get_input_samples(self, vocab_size=5000, config=None):
         values = []
         for _ in range(total_dims):
             values.append(random.random())
-        pixel_values = torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
+        pixel_values = torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous()
         in_data["pixel_values"] = pixel_values
 
         return in_data
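
Note: taken together, these changes let a prefix-tuned model run in reduced precision
end to end. A rough usage sketch of the scenario under the adapters library API, not a
verbatim example from the patch (model and adapter names are illustrative; requires a
CUDA device):

    import torch
    import adapters
    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-uncased")
    adapters.init(model)
    model.add_adapter("prefix", config="prefix_tuning")
    model.set_active_adapters("prefix")

    # Cast weights to half precision; before this patch, the prefix masks
    # created in float32 clashed with the model's fp16 attention masks.
    model = model.to(device="cuda", dtype=torch.half)

    input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
    with torch.no_grad():
        outputs = model(input_ids=input_ids)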