Add support for GPTJ (#5)
* Add support for GPTJ model and unit tests

Signed-off-by: mamtsing <[email protected]>

* Update modeling file to support GPTJ model based on new interface

Signed-off-by: mamtsing <[email protected]>

* Update modeling_gptj.py

Signed-off-by: Mamta Singh <[email protected]>

* Update generate_inputs.py

Signed-off-by: Mamta Singh <[email protected]>

* Update run_utils.py

Signed-off-by: Mamta Singh <[email protected]>

* Update generate_inputs.py

Signed-off-by: Mamta Singh <[email protected]>

* Update config.json

Signed-off-by: Mamta Singh <[email protected]>

* Update class comments

Signed-off-by: Mamta Singh <[email protected]>

* Update _utils.py

Signed-off-by: Mamta Singh <[email protected]>

* Update _utils.py

Signed-off-by: Mamta Singh <[email protected]>

* Update modeling_gptj.py

Signed-off-by: Mamta Singh <[email protected]>

* Update config.json

Signed-off-by: Mamta Singh <[email protected]>

* Update modeling_gptj.py

Signed-off-by: Mamta Singh <[email protected]>

* Update generate_inputs.py

Signed-off-by: Mamta Singh <[email protected]>

---------

Signed-off-by: mamtsing <[email protected]>
Signed-off-by: Mamta Singh <[email protected]>
quic-mamta authored Jun 19, 2024
1 parent ddf7ac1 commit ef6d68b
Showing 16 changed files with 528 additions and 131 deletions.
9 changes: 8 additions & 1 deletion QEfficient/transformers/modeling_utils.py
@@ -20,6 +20,7 @@
FalconModel,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
@@ -56,6 +57,7 @@
QEffFalconModel,
)
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
from .models.gptj.modeling_gptj import QEffGPTJAttention, QEffGPTJForCausalLM, QEffGPTJModel
from .models.llama.modeling_llama import (
QEffLlamaAttention,
QEffLlamaForCausalLM,
@@ -85,9 +87,10 @@
# Required for the Automation tool
ModelArchitectures = namedtuple("ModelArchitectures", ["architectures"])
# Create an instance of the named tuple
-my_architectures = ModelArchitectures(
+qeff_supported_architectures = ModelArchitectures(
[
GPT2LMHeadModel.__name__,
GPTJForCausalLM.__name__,
MptForCausalLM.__name__,
CodeGenForCausalLM.__name__,
LlamaForCausalLM.__name__,
@@ -108,6 +111,10 @@
GPT2Block: QEffGPT2Block,
GPT2Attention: QEffGPT2Attention,
GPT2LMHeadModel: QEffGPT2LMHeadModel,
# GPTJ model layers
GPTJModel: QEffGPTJModel,
GPTJAttention: QEffGPTJAttention,
GPTJForCausalLM: QEffGPTJForCausalLM,
# Llama model layers
LlamaModel: QEffLlamaModel,
LlamaAttention: QEffLlamaAttention,
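The hunk above registers the GPTJ architecture and maps each Hugging Face class to its QEff counterpart. As a rough illustration of how such a class mapping is typically applied (the helper and variable names below are hypothetical, not identifiers from this commit), a loaded model's modules can be re-pointed to the optimized implementations in place:

```python
# Illustrative sketch only; helper and variable names are hypothetical,
# not taken from this commit.
import torch.nn as nn
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel

# Import path inferred from the file layout added in this commit.
from QEfficient.transformers.models.gptj.modeling_gptj import (
    QEffGPTJAttention,
    QEffGPTJForCausalLM,
    QEffGPTJModel,
)

# A mapping shaped like the dictionary extended in the hunk above.
gptj_module_mapping = {
    GPTJModel: QEffGPTJModel,
    GPTJAttention: QEffGPTJAttention,
    GPTJForCausalLM: QEffGPTJForCausalLM,
}

def replace_transformer_layers(model: nn.Module, module_mapping: dict) -> nn.Module:
    """Swap each mapped module's class for its QEff counterpart, keeping weights intact."""
    for module in model.modules():
        if type(module) in module_mapping:
            # Re-pointing __class__ preserves parameters/buffers while routing
            # forward() through the replacement implementation.
            module.__class__ = module_mapping[type(module)]
    return model
```

With such a helper, a GPTJ checkpoint loaded through `GPTJForCausalLM.from_pretrained(...)` could be retargeted to the QEff classes registered above; whether the library performs the swap exactly this way is not shown in this diff.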
2 changes: 2 additions & 0 deletions QEfficient/transformers/models/codegen/modeling_codegen.py
@@ -5,6 +5,8 @@
#
# -----------------------------------------------------------------------------

"""PyTorch Codegen model."""

from typing import Optional, Tuple, Union

import torch
6 changes: 4 additions & 2 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -5,6 +5,8 @@
#
# -----------------------------------------------------------------------------

"""PyTorch Falcon model."""

import math
import warnings
from typing import Optional, Tuple, Union
@@ -34,7 +36,7 @@

class QEffFalconAttention(FalconAttention):
"""
-    Copied from FalconForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
+    Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
The only differences are:
- add new args position idx for the cache_kwargs for kv retention
"""
@@ -214,7 +216,7 @@ def forward(

class QEffFalconModel(FalconModel):
"""
-    Copied from FalconForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
+    Copied from FalconModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
The only differences are:
- add new args position idx for the cache_kwargs for kv retention
- update causal attention mask
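The Falcon docstrings above note that the only functional change is forwarding the position index through `cache_kwargs` for KV retention. A minimal sketch of that pattern, assuming the standard `transformers` `Cache.update()` interface; the kwarg key name is an assumption for illustration, not taken from this commit:

```python
# Minimal sketch, not code from this commit: forwards position ids through
# cache_kwargs so a retention-aware cache implementation can use token positions.
import torch
from transformers.cache_utils import DynamicCache

def update_kv_cache(past_key_value: DynamicCache, key_states: torch.Tensor,
                    value_states: torch.Tensor, layer_idx: int,
                    position_ids: torch.Tensor):
    # "position_ids" as the key is illustrative; Cache.update() simply passes
    # the cache_kwargs dict through to the cache implementation.
    cache_kwargs = {"position_ids": position_ids}
    return past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
```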
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/gptj/__init__.py
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
