[shardformer] fix gpt2headdouble
flybird11111 committed Sep 13, 2023
1 parent b0c4f28 commit 3a7e209
Showing 4 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion applications/Chat/requirements.txt
@@ -1,4 +1,4 @@
-transformers>=4.20.1
+transformers>=4.33.1
 tqdm
 datasets
 loralib
4 changes: 3 additions & 1 deletion colossalai/shardformer/modeling/gpt2.py
@@ -58,6 +58,7 @@ def gpt2_model_forward(
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     logger = logging.get_logger(__name__)
+    print("attention_mask_input" + str(attention_mask.shape))
 
     # Preprocess passed in arguments
     # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
@@ -94,9 +95,9 @@ def gpt2_model_forward(
         if hidden_states is None:
             raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
         input_shape = hidden_states.size()[:-1]
-        batch_size = input_shape[0]
         device = hidden_states.device
         hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
+        batch_size = hidden_states.shape[0]
 
     # GPT2Attention mask.
     if attention_mask is not None:
@@ -176,6 +177,7 @@ def gpt2_model_forward(
         block = self.h[i]
         torch.cuda.set_device(hidden_states.device)
         # Ensure that attention_mask is always on the same device as hidden_states
+        print("attention_mask_pp" + str(attention_mask.shape))
         if attention_mask is not None:
             attention_mask = attention_mask.to(hidden_states.device)
         if isinstance(head_mask, torch.Tensor):
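Editor's note: the substantive change in this file is the reshape at non-first pipeline stages. GPT2DoubleHeadsModel feeds 3-D input_ids of shape (batch, num_choices, seq_len), so the hidden_states handed to a later stage carries an extra leading dimension; the view() call flattens it, and batch_size is then derived from the flattened tensor, presumably so it agrees with the attention-mask shape that the two added print() calls are tracing. A minimal sketch of that reshape, with illustrative shapes (not the library code):

import torch

# Illustrative GPT2DoubleHeadsModel-style shapes: (batch, num_choices, seq_len, hidden_size)
batch, num_choices, seq_len, hidden_size = 2, 4, 16, 768
hidden_states = torch.randn(batch, num_choices, seq_len, hidden_size)

# The removed line took batch_size from the unflattened shape (-> 2);
# after flattening, batch_size counts (batch * num_choices) rows instead.
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
batch_size = hidden_states.shape[0]

print(hidden_states.shape)  # torch.Size([8, 16, 768])
print(batch_size)           # 8 == batch * num_choices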
15 changes: 10 additions & 5 deletions tests/test_shardformer/test_model/_utils.py
@@ -150,12 +150,17 @@ def _criterion(outputs, inputs):
                data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,))
 
     sharded_model.train()
-    for k, v in data.items():
-        if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
-            new_shape = [1] * v.dim()
-            new_shape[0] = 4
-            data[k] = v.to('cuda').repeat(*new_shape)
+    if booster.plugin.stage_manager is not None:
+        for k, v in data.items():
+            if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+                new_shape = [1] * v.dim()
+                new_shape[0] = 4
+                data[k] = v.to('cuda').repeat(*new_shape)
+    # for k, v in data.items():
+    #     if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+    #         new_shape = [1] * v.dim()
+    #         new_shape[0] = 4
+    #         data[k] = v.to('cuda').repeat(*new_shape)
 
     data_iter = iter([data])
     sharded_output = booster.execute_pipeline(data_iter,
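Editor's note: the test helper now inflates the batch dimension only when the plugin actually has a pipeline stage manager; the old unconditional loop is left commented out. What the guarded loop does to a single tensor, with illustrative shapes and the .to('cuda') move omitted so the sketch runs on CPU:

import torch

v = torch.randint(0, 100, (2, 128))  # e.g. an input_ids tensor of shape (batch, seq_len)
new_shape = [1] * v.dim()            # [1, 1] -- repeat factor per dimension
new_shape[0] = 4                     # repeat only the batch dimension
v = v.repeat(*new_shape)             # (2, 128) -> (8, 128); the real helper also moves v to CUDA
print(v.shape)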
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -171,7 +171,7 @@ def run_gpt2_test(test_config):
    {
        'tp_size': 2,
        'pp_size': 2,
-       'num_microbatches': 4,
+       'num_microbatches': 2,
        'enable_all_optimization': False,
        'use_lazy_init': False,
        'precision': 'fp32',
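Editor's note: the only config change is num_microbatches: 4 -> 2 for the tp=2/pp=2 case. The commit does not state the reason, but a pipeline schedule needs the batch it receives to split evenly across microbatches, which is the constraint this value has to satisfy. A hypothetical check of that constraint (helper name and numbers are illustrative, not taken from the test):

def microbatch_size(global_batch_size: int, num_microbatches: int) -> int:
    """Return the per-microbatch size, or raise if the batch does not split evenly."""
    if global_batch_size % num_microbatches != 0:
        raise ValueError(
            f"global batch {global_batch_size} is not divisible by {num_microbatches} microbatches"
        )
    return global_batch_size // num_microbatches

print(microbatch_size(4, 2))  # -> 2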
