From ef6d68b7ed260db4d1be0e71231867a36ec66835 Mon Sep 17 00:00:00 2001 From: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> Date: Wed, 19 Jun 2024 22:04:53 +0530 Subject: [PATCH] Add support for GPTJ (#5) * Add support for GPTJ model and unit tests Signed-off-by: mamtsing Signed-off-by: mamtsing * Update modeling file to support GPTJ model based on new interface Signed-off-by: mamtsing * Update modeling_gptj.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update generate_inputs.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update run_utils.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update generate_inputs.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update config.json Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update class comments Signed-off-by: Mamta Singh * Update _utils.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update _utils.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update modeling_gptj.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update config.json Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update modeling_gptj.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> * Update generate_inputs.py Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> --------- Signed-off-by: mamtsing Signed-off-by: mamtsing Signed-off-by: Mamta Singh <168400541+quic-mamta@users.noreply.github.com> Signed-off-by: Mamta Singh --- QEfficient/transformers/modeling_utils.py | 9 +- .../models/codegen/modeling_codegen.py | 2 + .../models/falcon/modeling_falcon.py | 6 +- .../transformers/models/gptj/__init__.py | 7 + .../transformers/models/gptj/modeling_gptj.py | 426 ++++++++++++++++++ .../models/llama/modeling_llama.py | 6 +- .../models/mistral/modeling_mistral.py | 4 +- .../models/mixtral_moe/modeling_mixtral.py | 5 +- .../transformers/models/mpt/modeling_mpt.py | 8 +- .../transformers/models/phi3/modeling_phi3.py | 4 +- .../models/qwen2/modeling_qwen2.py | 4 +- .../models/starcoder2/modeling_starcoder2.py | 6 +- QEfficient/utils/_utils.py | 4 +- QEfficient/utils/generate_inputs.py | 147 ++---- QEfficient/utils/run_utils.py | 15 +- tests/config.json | 6 +- 16 files changed, 528 insertions(+), 131 deletions(-) create mode 100644 QEfficient/transformers/models/gptj/__init__.py create mode 100644 QEfficient/transformers/models/gptj/modeling_gptj.py diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 598e3d4a..29cd8883 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -20,6 +20,7 @@ FalconModel, ) from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model +from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaModel, LlamaRMSNorm from transformers.models.mistral.modeling_mistral import ( MistralAttention, @@ -56,6 +57,7 @@ QEffFalconModel, ) from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model +from .models.gptj.modeling_gptj import QEffGPTJAttention, QEffGPTJForCausalLM, QEffGPTJModel from 
.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaForCausalLM, @@ -85,9 +87,10 @@ # Required for the Automation tool ModelArchitectures = namedtuple("ModelArchitectures", ["architectures"]) # Create an instance of the named tuple -my_architectures = ModelArchitectures( +qeff_supported_architectures = ModelArchitectures( [ GPT2LMHeadModel.__name__, + GPTJForCausalLM.__name__, MptForCausalLM.__name__, CodeGenForCausalLM.__name__, LlamaForCausalLM.__name__, @@ -108,6 +111,10 @@ GPT2Block: QEffGPT2Block, GPT2Attention: QEffGPT2Attention, GPT2LMHeadModel: QEffGPT2LMHeadModel, + # GPTJ model layers + GPTJModel: QEffGPTJModel, + GPTJAttention: QEffGPTJAttention, + GPTJForCausalLM: QEffGPTJForCausalLM, # Llama model layers LlamaModel: QEffLlamaModel, LlamaAttention: QEffLlamaAttention, diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py index d82d29b4..f234e739 100644 --- a/QEfficient/transformers/models/codegen/modeling_codegen.py +++ b/QEfficient/transformers/models/codegen/modeling_codegen.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +"""PyTorch Codegen model.""" + from typing import Optional, Tuple, Union import torch diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 81c5b579..61be5a1c 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +"""PyTorch Falcon model.""" + import math import warnings from typing import Optional, Tuple, Union @@ -34,7 +36,7 @@ class QEffFalconAttention(FalconAttention): """ - Copied from FalconForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -214,7 +216,7 @@ def forward( class QEffFalconModel(FalconModel): """ - Copied from FalconForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + Copied from FalconModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update causal attention mask diff --git a/QEfficient/transformers/models/gptj/__init__.py b/QEfficient/transformers/models/gptj/__init__.py new file mode 100644 index 00000000..91fee0a4 --- /dev/null +++ b/QEfficient/transformers/models/gptj/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py new file mode 100644 index 00000000..5bd061bb --- /dev/null +++ b/QEfficient/transformers/models/gptj/modeling_gptj.py @@ -0,0 +1,426 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""PyTorch GPT-J model.""" + +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import DynamicCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.gptj.modeling_gptj import ( + GPTJAttention, + GPTJForCausalLM, + GPTJModel, + get_embed_positions, + logger, + rotate_every_two, +) +from transformers.utils.import_utils import is_torch_fx_proxy + +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor: + *sin_shape, sin_last_shape = sin.shape + sin = sin.reshape(-1, 1).repeat(1, 2).reshape(*sin_shape, 1, 2 * sin_last_shape) + *cos_shape, cos_last_shape = cos.shape + cos = cos.reshape(-1, 1).repeat(1, 2).reshape(*cos_shape, 1, 2 * cos_last_shape) + return (tensor * cos) + (rotate_every_two(tensor) * sin) + + +class QEffGPTJAttention(GPTJAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + """ + Copied from GPTJAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + """ + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.float32) + key = key.to(torch.float32) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: torch.FloatTensor, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + query = self.q_proj(hidden_states) + key = 
self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + + if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): + # The logic to conditionally copy to GPU could not be traced, so we do this + # every time in the torch.fx case + embed_positions = get_embed_positions(self.embed_positions, position_ids) + else: + embed_positions = self._get_embed_positions(position_ids) + + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + repeated_position_ids = torch.where(repeated_position_ids==-1, 0, repeated_position_ids) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + # Added for optimized GPTJ Attention for AI 100 KV Retention + # Update the cache_kwargs with position_ids for Cloud AI 100 + cache_kwargs = {"position_ids": position_ids} + pkv = DynamicCache() + pkv.key_cache.append(layer_past[0]) + pkv.value_cache.append(layer_past[1]) + key, value = pkv.update(key, value, 0, cache_kwargs) + + if use_cache is True: + # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation. 
+ # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128 + present = (key.to(hidden_states.dtype), value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class QEffGPTJModel(GPTJModel): + """ + Copied from GPTJModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update causal attention mask + """ + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + if not self._use_flash_attention_2: + # Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. 
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + else: + # update attention mask for Cloud AI 100 + attention_mask = _create_causal_mask(position_ids, past_length, None) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class QEffGPTJForCausalLM(GPTJForCausalLM): + """ + Copied from GPTJForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + lm_logits = self.lm_head(hidden_states) + lm_logits = lm_logits.float() + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 793edcb0..f3e068b3 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +"""PyTorch Llama model.""" + import math from typing import List, Optional, Tuple, Union @@ -32,7 +34,7 @@ class QEffLlamaAttention(LlamaAttention): """ - Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + Copied from LlamaAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -226,7 +228,7 @@ def forward( class QEffLlamaModel(LlamaModel): """ - Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + Copied from LlamaModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update causal 
attention mask diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index a183275e..d703ea3f 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -37,7 +37,7 @@ class QEffMistralAttention(MistralAttention): """ - Copied from MistralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py + Copied from MistralAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -121,7 +121,7 @@ def forward( class QEffMistralModel(MistralModel): """ - Copied from MistralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py + Copied from MistralModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update causal attention mask diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 03f2a32e..da8e5cf8 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +"""PyTorch Mixtral model.""" import math from typing import List, Optional, Tuple, Union @@ -36,7 +37,7 @@ class QEffMixtralAttention(MixtralAttention): """ - Copied from MixtralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + Copied from MixtralAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -173,7 +174,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral class QEffMixtralModel(MixtralModel): """ - Copied from MixtralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + Copied from MixtralModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update causal attention mask diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index dcac9d5a..48e43b0d 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +"""PyTorch MPT model.""" + from typing import Optional, Tuple, Union import torch @@ -24,7 +26,7 @@ class QEffMptAttention(MptAttention): """ - Copied from MptForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py + Copied from MptAttention: 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -92,7 +94,7 @@ def forward( class QEffMptBlock(MptBlock): """ - Copied from MptForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py + Copied from MptBlock: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py The only differences are: - add new args cache idx for the kv retention """ @@ -145,7 +147,7 @@ def forward( class QEFfMptModel(MptModel): """ - Copied from MptForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py + Copied from MptModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py The only differences are: - add new args cache idx for the kv retention """ diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 769392d9..aebc92bf 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -33,7 +33,7 @@ class QEffPhi3Attention(Phi3Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" """ - Copied from Phi3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py + Copied from Phi3Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -123,7 +123,7 @@ def forward( class QEffPhi3Model(Phi3Model): """ - Copied from Phi3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py + Copied from Phi3Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update causal attention mask diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 56fda43b..0d5cd19f 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -37,7 +37,7 @@ class QEffQwen2Attention(Qwen2Attention): """ - Copied from Qwen2ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py + Copied from Qwen2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -117,7 +117,7 @@ def forward( class QEffQwen2Model(Qwen2Model): """ - Copied from Qwen2ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py + Copied from Qwen2Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update causal attention mask diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index 701d9166..f63b1932 100644 --- 
a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -36,7 +36,7 @@ class QEffStarcoder2Attention(Starcoder2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" """ - Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py + Copied from Starcoder2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py The only differences are: - add new args position idx for the cache_kwargs for kv retention """ @@ -122,7 +122,7 @@ def forward( class QEffStarcoder2Model(Starcoder2Model): """ - Copied from Starcoder2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py + Copied from Starcoder2Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update causal attention mask @@ -281,7 +281,7 @@ def forward( class QEffStarcoder2ForCausalLM(Starcoder2ForCausalLM): """ - Copied from Starcoder2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py + Copied from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py The only differences are: - add new args position idx for the cache_kwargs for kv retention - update the hidden_states, and fix for onnx model diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index ed2c898e..3bbe2e7a 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -119,13 +119,13 @@ def onnx_exists(model_name: str) -> Tuple[bool, str, str]: return onnx_exists_bool, onnx_dir_path, onnx_model_path -def load_hf_tokenizer(model_name: str, cache_dir: Optional[str] = None, hf_token: Optional[str] = None, padding_side:str = "left", **kwargs) -> Union[PreTrainedTokenizerFast, PreTrainedTokenizer]: +def load_hf_tokenizer(model_name: str, cache_dir: Optional[str] = None, hf_token: Optional[str] = None, padding_side:str = "right", **kwargs) -> Union[PreTrainedTokenizerFast, PreTrainedTokenizer]: logger.info(f"Loading Tokenizer for {model_name}") if hf_token is not None: login(hf_token) # Download tokenizer along with model if it doesn't exist - model_hf_path = hf_download(repo_id=model_name, cache_dir=cache_dir, allow_patterns=["*.json", "*.py", "*token*"]) + model_hf_path = hf_download(repo_id=model_name, cache_dir=cache_dir, allow_patterns=["*.json", "*.py", "*token*", "*.txt"]) tokenizer = AutoTokenizer.from_pretrained(model_hf_path, padding_side=padding_side, trust_remote_code=True, **kwargs) padding_check_and_fix(tokenizer) # Check and fix tokenizer viability diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 1da64e2c..d09e1be5 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -1,6 +1,6 @@ # ----------------------------------------------------------------------------- # -# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- @@ -44,8 +44,6 @@ def prepare_pytorch_inputs(self, n_layer, padding_shape): batch_size, input_len = input_ids.shape inputs.pop("attention_mask") position_ids = torch.arange(input_len).view(1, -1) - print(batch_size, input_len, position_ids) - inputs["input_ids"] = torch.concat( [ input_ids, @@ -54,7 +52,6 @@ def prepare_pytorch_inputs(self, n_layer, padding_shape): ], 1, ) - inputs["position_ids"] = torch.concat( [ position_ids, @@ -69,7 +66,6 @@ def prepare_pytorch_inputs(self, n_layer, padding_shape): past_value = torch.zeros((padding_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) - inputs["past_key_values"] = tuple(past_key_values) return inputs @@ -89,7 +85,6 @@ def update_pytorch_inputs(self, iteration, inputs, pt_outputs): updated_inputs["past_key_values"] = tuple( [(key.detach(), value.detach()) for key, value in pt_outputs["past_key_values"]] ) - return updated_inputs def prepare_ort_inputs(self, n_layer, padding_shape): @@ -100,43 +95,29 @@ def prepare_ort_inputs(self, n_layer, padding_shape): :return inputs: Dict - input_ids, position_ids, past_key_values """ - model_inputs = self.tokenizer( + inputs = self.tokenizer( self.input_str, return_tensors="np", padding=True, ) - input_ids = model_inputs["input_ids"] - - inputs = {} - inputs["input_ids"] = input_ids - - batch_size, input_len = inputs["input_ids"].shape + input_ids = inputs["input_ids"] + batch_size, input_len = input_ids.shape + inputs.pop("attention_mask") position_ids = np.arange(input_len).reshape(1, -1) - print(batch_size, input_len, position_ids) - - if len(input_ids.shape) == 1: - inputs["input_ids"] = input_ids.astype(np.int64)[:, np.newaxis] - else: - input_ids = np.concatenate( - [ - input_ids, - np.ones((batch_size, self.prompt_len - input_len)) * (self.tokenizer.pad_token_id), - ], - axis=1, - ).astype(np.int64) - inputs["input_ids"] = input_ids.astype(np.int64) - - if len(position_ids.shape) == 1: - inputs["position_ids"] = position_ids.astype(np.int64)[:, np.newaxis] - else: - position_ids = np.concatenate( - [ - position_ids, - np.ones((batch_size, self.prompt_len - input_len)) * (-1), - ], - axis=1, - ).astype(np.int64) - inputs["position_ids"] = position_ids.astype(np.int64) + inputs["input_ids"] = np.concatenate( + [ + input_ids, + np.full((batch_size, self.prompt_len - input_len), self.tokenizer.pad_token_id) + ], + axis=1, + ).astype(np.int64) + inputs["position_ids"] = np.concatenate( + [ + position_ids, + np.full((batch_size, self.prompt_len - input_len), -1) + ], + axis=1, + ).astype(np.int64) for i in range(n_layer): inputs["past_key." 
+ str(i)] = np.zeros((padding_shape), dtype=np.float32) @@ -154,26 +135,14 @@ def update_ort_inputs(self, iteration, inputs, ort_outputs, n_layer): :return inputs: Dict - input_ids, position_ids, past_key_values """ - past_key_values = ort_outputs["past_key_values"] - - input_ids = ort_outputs["logits"].argmax(-1) - position_ids = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1 - - if len(input_ids.shape) == 1: - inputs["input_ids"] = input_ids.astype(np.int64)[:, np.newaxis] - else: - inputs["input_ids"] = input_ids.astype(np.int64) - - if len(position_ids.shape) == 1: - inputs["position_ids"] = position_ids.astype(np.int64)[:, np.newaxis] - else: - inputs["position_ids"] = position_ids.astype(np.int64) - + updated_inputs = {} + updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1) + updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1 for i in range(n_layer): - inputs["past_key." + str(i)] = past_key_values[i * 2] - inputs["past_value." + str(i)] = past_key_values[i * 2 + 1] + updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2] + updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1] - return inputs + return updated_inputs def prepare_cloud_ai_100_inputs(self, n_layer, padding_shape): """ @@ -183,48 +152,34 @@ def prepare_cloud_ai_100_inputs(self, n_layer, padding_shape): :return inputs: Dict - input_ids, position_ids, past_key_values """ - model_inputs = self.tokenizer( + inputs = self.tokenizer( self.input_str, return_tensors="np", padding=True, ) - input_ids = model_inputs["input_ids"] - - inputs = {} - inputs["input_ids"] = input_ids - + input_ids = inputs["input_ids"] batch_size, input_len = inputs["input_ids"].shape + inputs.pop("attention_mask") position_ids = np.arange(input_len).reshape(1, -1) - print(batch_size, input_len, position_ids) - - if len(input_ids.shape) == 1: - inputs["input_ids"] = input_ids.astype(np.int64)[:, np.newaxis] - else: - input_ids = np.concatenate( - [ - input_ids, - np.ones((batch_size, self.prompt_len - input_len)) * (self.tokenizer.pad_token_id), - ], - axis=1, - ).astype(np.int64) - inputs["input_ids"] = input_ids.astype(np.int64) - - if len(position_ids.shape) == 1: - inputs["position_ids"] = position_ids.astype(np.int64)[:, np.newaxis] - else: - position_ids = np.concatenate( - [ - position_ids, - np.ones((batch_size, self.prompt_len - input_len)) * (-1), - ], - axis=1, - ).astype(np.int64) - inputs["position_ids"] = position_ids.astype(np.int64) + inputs["input_ids"] = np.concatenate( + [ + input_ids, + np.full((batch_size, self.prompt_len - input_len), self.tokenizer.pad_token_id) + ], + axis=1, + ).astype(np.int64) + inputs["position_ids"] = np.concatenate( + [ + position_ids, + np.full((batch_size, self.prompt_len - input_len), -1) + ], + axis=1, + ).astype(np.int64) for i in range(n_layer): inputs["past_key." + str(i)] = np.zeros((padding_shape), dtype=np.float16) inputs["past_value." 
+ str(i)] = np.zeros((padding_shape), dtype=np.float16) - + return inputs def update_cloud_ai_100_inputs(self, iteration, inputs, outputs): @@ -238,17 +193,7 @@ def update_cloud_ai_100_inputs(self, iteration, inputs, outputs): """ updated_inputs = {} + updated_inputs["input_ids"] = outputs["logits"].argmax(-1) + updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1 - input_ids = outputs["logits"].argmax(-1) - position_ids = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1 - - if len(input_ids.shape) == 1: - updated_inputs["input_ids"] = input_ids.astype(np.int64)[:, np.newaxis] - else: - updated_inputs["input_ids"] = input_ids.astype(np.int64) - if len(position_ids.shape) == 1: - updated_inputs["position_ids"] = position_ids.astype(np.int64)[:, np.newaxis] - else: - updated_inputs["position_ids"] = position_ids.astype(np.int64) - - return updated_inputs \ No newline at end of file + return updated_inputs diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index e95be3c6..573685bf 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -40,6 +40,7 @@ def __init__(self, tokenizer, prompt, prompt_len, ctx_len): self.input_handler = InputHandler(self.tokenizer, self.prompt, self.prompt_len, self.ctx_len) + @torch.no_grad() def run_hf_model_on_pytorch(self, model_hf): """ Function responsible for running Huggingface PyTorch model and return the output tokens @@ -50,12 +51,11 @@ def run_hf_model_on_pytorch(self, model_hf): input_ids_len = len(input_ids[0]) - with torch.no_grad(): - for _ in range(self.gen_len): - outputs = model_hf(input_ids) - logits = outputs.logits[:, -1, :] - predicted_token_id = torch.argmax(logits, dim=-1) - input_ids = torch.cat([input_ids, predicted_token_id.unsqueeze(1)], dim=-1) + for _ in range(self.gen_len): + outputs = model_hf(input_ids) + logits = outputs.logits[:, -1, :] + predicted_token_id = torch.argmax(logits, dim=-1) + input_ids = torch.cat([input_ids, predicted_token_id.unsqueeze(1)], dim=-1) generated_ids = input_ids[0][input_ids_len:].detach().numpy() generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) @@ -165,8 +165,7 @@ def run_kv_model_on_ort(self, model_path, n_layer, padding_shape): def run_kv_model_on_cloud_ai_100(self, session, n_layer, padding_shape): """ Function responsible for running ONNX model on Cloud AI 100 and return the output tokens - :param qpc_path: str - :param device_id: List[int] + :param session: QAICInferenceSession :param n_layer : int :param padding_shape : List[int] :return generated_ids: numpy.ndarray - output tokens diff --git a/tests/config.json b/tests/config.json index 302ed48a..b71d9c22 100644 --- a/tests/config.json +++ b/tests/config.json @@ -36,9 +36,13 @@ "model_name": "wtang06/mpt-125m-c4", "model_class": "MptForCausalLM" }, + { + "model_name":"hakurei/gpt-j-random-tinier", + "model_class":"GPTJForCausalLM" + }, { "model_name":"mistralai/Mixtral-8x7B-Instruct-v0.1", "model_class":"MixtralForCausalLM" } ] -} \ No newline at end of file +}
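
Note for reviewers: a small illustrative sketch (PyTorch only, with made-up shapes and a random stand-in for GPT-J's concatenated sin/cos table) of the contract this patch sets up between the updated input preparation in QEfficient/utils/generate_inputs.py and QEffGPTJAttention — padded prompt slots carry position id -1, and the attention clamps those markers to 0 before gathering rotary embeddings, while the causal mask built from position_ids keeps the padded slots masked.

import torch

# Hypothetical sizes; the real values come from the model config and prompt_len.
prompt_len, input_len, pos_embd_dim, max_positions = 8, 4, 16, 2048

# generate_inputs.py pads position_ids for the unused prompt slots with -1.
position_ids = torch.cat(
    [torch.arange(input_len), torch.full((prompt_len - input_len,), -1, dtype=torch.long)]
).view(1, -1)

# Stand-in for GPTJAttention.embed_positions (the real table is built from the config).
embed_positions = torch.randn(1, max_positions, pos_embd_dim)

# QEffGPTJAttention clamps the -1 markers to 0 so the gather stays in bounds;
# those slots contribute nothing because the causal mask derived from position_ids masks them.
repeated = position_ids.unsqueeze(-1).repeat(1, 1, pos_embd_dim)
repeated = torch.where(repeated == -1, torch.zeros_like(repeated), repeated)
sincos = torch.gather(embed_positions, 1, repeated)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
print(sin.shape, cos.shape)  # torch.Size([1, 8, 8]) torch.Size([1, 8, 8])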