Fix fp16/bf16 for Prefix Tuning #659

Merged · 1 commit · Apr 8, 2024
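The underlying issue: `torch.ones(...)` creates float32 tensors by default, so the prefix masks that prefix tuning prepends to the attention mask did not match the model's dtype under fp16 or bf16. The changes below build those masks with the attention mask's device and dtype (in `prefix_tuning.py` and the CLIP attention path) and extend the method tests to also run forward passes in half precision on CUDA. A minimal sketch of the pattern, with hypothetical shapes that are not taken from the code:

```python
# Minimal sketch (hypothetical shapes, not from the repository) of the fix pattern:
# build the prefix mask in the same dtype and on the same device as the mask the
# model already carries.
import torch

batch_size, seq_len, prefix_len = 2, 8, 4
attention_mask = torch.zeros(batch_size, 1, 1, seq_len, dtype=torch.bfloat16)

# Before: defaults to float32, so it no longer matches a half-precision mask.
prefix_mask_fp32 = torch.ones(batch_size, 1, 1, prefix_len).to(attention_mask.device)

# After: match both device and dtype of the existing attention mask.
prefix_mask = torch.ones(batch_size, 1, 1, prefix_len).to(
    device=attention_mask.device, dtype=attention_mask.dtype
)

combined = torch.cat([prefix_mask, attention_mask], dim=-1)
assert combined.dtype == torch.bfloat16
```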
6 changes: 4 additions & 2 deletions src/adapters/methods/prefix_tuning.py
@@ -519,10 +519,12 @@ def compose_single(self, adapter_setup: str, state: PrefixTuningState, lvl: int
         value_states = torch.cat([prefix_values, state.value_states], dim=2)
         if state.attention_mask is not None:
             if state.attention_mask.dim() == 2:  # e.g. for DistilBERT, attention_mask has shape (batch_size, seq_len)
-                prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(state.attention_mask.device)
+                prefix_mask = torch.ones(batch_size, prefix_keys.size(2)).to(
+                    device=state.attention_mask.device, dtype=state.attention_mask.dtype
+                )
             else:
                 prefix_mask = torch.ones(batch_size, 1, state.attention_mask.size(2), prefix_keys.size(2)).to(
-                    state.attention_mask.device
+                    device=state.attention_mask.device, dtype=state.attention_mask.dtype
                 )
             if state.invert_mask:
                 prefix_mask = 1.0 - prefix_mask
2 changes: 1 addition & 1 deletion src/adapters/models/clip/modeling_clip.py
@@ -66,7 +66,7 @@ def forward(
         if causal_attention_mask is not None:
             prefix_mask = torch.ones(
                 bsz, 1, causal_attention_mask.size(2), src_len - causal_attention_mask.size(-1)
-            ).to(causal_attention_mask.device)
+            ).to(device=causal_attention_mask.device, dtype=causal_attention_mask.dtype)
             causal_attention_mask = torch.cat([prefix_mask, causal_attention_mask], dim=-1)
         if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
             raise ValueError(
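With the masks built in the right dtype, a prefix-tuned model can be cast to half precision for inference. A rough usage sketch, assuming the `adapters` package API and a CUDA device; the checkpoint name is a placeholder rather than something taken from this PR:

```python
# Rough usage sketch, assuming a CUDA device; the checkpoint name is a placeholder.
# Mirrors the pattern used by the new forward tests: add the adapter first, then
# cast the whole model (base weights + adapter) to fp16.
import adapters
import torch
from adapters import PrefixTuningConfig
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("bert-base-uncased")
adapters.init(model)  # enable adapter support on the plain transformers model
model.add_adapter("prefix", config=PrefixTuningConfig(flat=True))
model.set_active_adapters("prefix")
model.to("cuda").to(torch.float16)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("A short fp16 smoke test.", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)  # used to break in fp16 because the prefix mask stayed float32
print(outputs.last_hidden_state.dtype)  # expected: torch.float16
```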
12 changes: 8 additions & 4 deletions tests/methods/base.py
@@ -35,6 +35,9 @@ def create_twin_models(model_class, config_creator=None):
 class AdapterMethodBaseTestMixin:
     """Provides base test running methods for testing an adapter method implementation."""
 
+    # Model weight dtypes to test in forward pass
+    dtypes_to_test = [torch.float32, torch.half] if torch_device == "cuda" else [torch.float32]
+
     def filter_parameters(self, model, filter_keys):
         return {k: v for (k, v) in model.named_parameters() if any([filter_key in k for filter_key in filter_keys])}
 
@@ -151,14 +154,15 @@ def run_get_test(self, model, adapter_config, num_expected_modules):
 
         model.delete_adapter("first")
 
-    def run_forward_test(self, model, adapter_config):
+    def run_forward_test(self, model, adapter_config, dtype=torch.float32):
         model.eval()
 
         name = adapter_config.__class__.__name__
-        model.add_adapter(name, config=adapter_config)
-        model.to(torch_device)
+        if name not in model.adapters_config:
+            model.add_adapter(name, config=adapter_config)
+        model.to(torch_device).to(dtype)
 
-        input_data = self.get_input_samples(config=model.config)
+        input_data = self.get_input_samples(config=model.config, dtype=dtype)
 
         # pass 1: set adapter via property
         model.set_active_adapters([name])
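`dtypes_to_test` only includes `torch.half` when the tests run on CUDA, and `run_forward_test` now casts both the model and its float inputs to the requested dtype. A self-contained sketch of the same parameterization pattern, using a stand-in module rather than the repository's test models:

```python
# Self-contained sketch of the parameterization pattern the mixin adds: run the same
# forward check once per dtype, using subTest so each dtype reports separately.
# The tiny Linear module here stands in for the adapter model; it is not from the repo.
import unittest
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"


class ForwardDtypeTest(unittest.TestCase):
    dtypes_to_test = [torch.float32, torch.half] if torch_device == "cuda" else [torch.float32]

    def test_forward(self):
        for dtype in self.dtypes_to_test:
            with self.subTest(dtype=dtype):
                model = torch.nn.Linear(8, 8).to(torch_device).to(dtype)
                x = torch.randn(2, 8, device=torch_device, dtype=dtype)
                y = model(x)
                self.assertEqual(y.dtype, dtype)


if __name__ == "__main__":
    unittest.main()
```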
9 changes: 6 additions & 3 deletions tests/methods/test_adapter_common.py
@@ -210,13 +210,16 @@ def test_reduction_factor_no_default(self):
         with self.assertRaises(KeyError):
             model.add_adapter(name, config=adapter_config)
 
-    def test_adapter_forward(self):
+    def test_forward_bottleneck(self):
         model = self.get_model()
         model.eval()
 
         for adapter_config, _ in self.adapter_configs_to_test:
-            with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
-                self.run_forward_test(model, adapter_config)
+            for dtype in self.dtypes_to_test:
+                with self.subTest(
+                    model_class=model.__class__.__name__, config=adapter_config.__class__.__name__, dtype=dtype
+                ):
+                    self.run_forward_test(model, adapter_config, dtype=dtype)
 
     def test_invertible_adapter_forward(self):
         model = self.get_model()
6 changes: 5 additions & 1 deletion tests/methods/test_lora.py
@@ -29,7 +29,11 @@ def test_get_lora(self):
 
     def test_forward_lora(self):
         model = self.get_model()
-        self.run_forward_test(model, LoRAConfig(init_weights="bert", intermediate_lora=True, output_lora=True))
+        for dtype in self.dtypes_to_test:
+            with self.subTest(model_class=model.__class__.__name__, dtype=dtype):
+                self.run_forward_test(
+                    model, LoRAConfig(init_weights="bert", intermediate_lora=True, output_lora=True), dtype=dtype
+                )
 
     def test_load_lora(self):
         self.run_load_test(LoRAConfig())
4 changes: 3 additions & 1 deletion tests/methods/test_prefix_tuning.py
@@ -40,7 +40,9 @@ def test_get_prefix_tuning(self):
 
     def test_forward_prefix_tuning(self):
         model = self.get_model()
-        self.run_forward_test(model, PrefixTuningConfig(flat=True))
+        for dtype in self.dtypes_to_test:
+            with self.subTest(model_class=model.__class__.__name__, dtype=dtype):
+                self.run_forward_test(model, PrefixTuningConfig(flat=True), dtype=dtype)
 
     def test_load_prefix_tuning(self):
         self.run_load_test(PrefixTuningConfig())
4 changes: 3 additions & 1 deletion tests/methods/test_prompt_tuning.py
@@ -24,7 +24,9 @@ def test_get_prompt_tuning(self):
 
     def test_forward_prompt_tuning(self):
         model = self.get_model()
-        self.run_forward_test(model, PromptTuningConfig(prompt_length=10))
+        for dtype in self.dtypes_to_test:
+            with self.subTest(model_class=model.__class__.__name__, dtype=dtype):
+                self.run_forward_test(model, PromptTuningConfig(prompt_length=10), dtype=dtype)
 
     def test_load_prompt_tuning(self):
         self.run_load_test(PromptTuningConfig(prompt_length=10))
6 changes: 3 additions & 3 deletions tests/test_adapter.py
@@ -49,7 +49,7 @@ def get_model(self):
         model.to(torch_device)
         return model
 
-    def get_input_samples(self, shape=None, vocab_size=5000, config=None):
+    def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs):
         shape = shape or self.default_input_samples_shape
         total_dims = 1
         for dim in shape:
@@ -96,15 +96,15 @@ def assert_adapter_unavailable(self, model, adapter_name):
 class VisionAdapterTestBase(AdapterTestBase):
     default_input_samples_shape = (3, 3, 224, 224)
 
-    def get_input_samples(self, shape=None, config=None):
+    def get_input_samples(self, shape=None, config=None, dtype=torch.float, **kwargs):
         shape = shape or self.default_input_samples_shape
         total_dims = 1
         for dim in shape:
             total_dims *= dim
         values = []
         for _ in range(total_dims):
             values.append(random.random())
-        pixel_values = torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
+        pixel_values = torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous()
         in_data = {"pixel_values": pixel_values}
 
         return in_data
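Accepting `dtype` (and swallowing extra keyword arguments via `**kwargs`) in `get_input_samples` lets the forward tests request inputs that match the model weights: floating-point inputs such as `pixel_values` follow the test dtype, while integer token ids are unaffected. A small illustrative sketch, with placeholder shapes:

```python
# Illustrative sketch with placeholder shapes: only floating-point inputs need to
# follow the requested dtype; integer input_ids keep their int64 type either way.
import torch

dtype = torch.half if torch.cuda.is_available() else torch.float32
pixel_values = torch.rand(3, 3, 224, 224).to(dtype)  # follows the test dtype
input_ids = torch.randint(0, 5000, (3, 64))          # stays torch.int64
print(pixel_values.dtype, input_ids.dtype)
```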
4 changes: 2 additions & 2 deletions tests/test_clip.py
@@ -170,7 +170,7 @@ class CLIPAdapterTestBase(AdapterTestBase):
     default_vision_input_samples_shape = (3, 3, 224, 224)
     do_run_train_tests = False
 
-    def get_input_samples(self, vocab_size=5000, config=None):
+    def get_input_samples(self, vocab_size=5000, config=None, dtype=torch.float, **kwargs):
         # text inputs
         shape = self.default_text_input_samples_shape
         total_dims = 1
@@ -194,7 +194,7 @@ def get_input_samples(self, vocab_size=5000, config=None):
         values = []
         for _ in range(total_dims):
             values.append(random.random())
-        pixel_values = torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
+        pixel_values = torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous()
         in_data["pixel_values"] = pixel_values
 
         return in_data