-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Qwen2 in LLM Microservice (#133)
* Support qwen2 in llm microservice Signed-off-by: letonghan <[email protected]>
- Loading branch information
Showing
5 changed files
with
419 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
|
||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# HABANA environment | ||
FROM vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest as hpu | ||
|
||
ENV LANG=en_US.UTF-8 | ||
ARG REPO=https://github.com/huggingface/optimum-habana.git | ||
ARG REPO_VER=v1.11.1 | ||
|
||
RUN apt-get update && \ | ||
apt-get install git-lfs && \ | ||
git-lfs install && \ | ||
apt-get install -y --no-install-recommends --fix-missing \ | ||
libgl1-mesa-glx \ | ||
libjemalloc-dev \ | ||
vim | ||
|
||
RUN useradd -m -s /bin/bash user && \ | ||
mkdir -p /home/user && \ | ||
chown -R user /home/user/ | ||
|
||
USER user | ||
|
||
COPY comps /home/user/comps | ||
COPY comps/llm/text-generation/qwen2/qwen2.patch /home/user/qwen2.patch | ||
|
||
SHELL ["/bin/bash", "--login", "-c"] | ||
RUN git clone --single-branch -b ${REPO_VER} ${REPO} /optimum-habana | ||
|
||
ENV PYTHONPATH=/root:/home/user | ||
|
||
RUN cd /optimum-habana && git apply /qwen2.patch && \ | ||
cd /optimum-habana/examples/text-generation && pip install -r requirements.txt && \ | ||
cd /optimum-habana && python setup.py install | ||
|
||
WORKDIR /home/user/comps/llms/text-generation/qwen2 | ||
|
||
ENTRYPOINT ["python", "llm.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) 2024 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
from datetime import datetime | ||
|
||
import torch | ||
from fastapi.responses import StreamingResponse | ||
from langsmith import traceable | ||
from utils import initialize_model | ||
|
||
from comps import GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice | ||
|
||
|
||
def warmup(): | ||
input_sentences = ["DeepSpeed is a machine learning framework", "He is working on", "He has a", "He got all"] | ||
input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True) | ||
for t in input_tokens: | ||
if torch.is_tensor(input_tokens[t]): | ||
input_tokens[t] = input_tokens[t].to("hpu") | ||
for i in range(3): | ||
print(f"Current time: {datetime.now()}") | ||
print(f"Warming up {i+1}...") | ||
outputs = model.generate( | ||
**input_tokens, | ||
generation_config=generation_config, | ||
lazy_mode=True, | ||
hpu_graphs=True, | ||
profiling_steps=0, | ||
profiling_warmup_steps=0, | ||
).cpu() | ||
res = tokenizer.batch_decode(outputs, skip_special_tokens=True) | ||
print(f"res: {res}") | ||
|
||
|
||
@register_microservice( | ||
name="opea_service@llm_qwen", | ||
service_type=ServiceType.LLM, | ||
endpoint="/v1/chat/completions", | ||
host="0.0.0.0", | ||
port=8000, | ||
) | ||
@traceable(run_type="llm") | ||
def llm_generate(input: LLMParamsDoc): | ||
input_query = input.query | ||
input_tokens = tokenizer.batch_encode_plus([input_query], return_tensors="pt", padding=True) | ||
for t in input_tokens: | ||
if torch.is_tensor(input_tokens[t]): | ||
input_tokens[t] = input_tokens[t].to("hpu") | ||
|
||
print(f"[llm - qwen] Current time: {datetime.now()}") | ||
output = model.generate( | ||
**input_tokens, | ||
generation_config=generation_config, | ||
lazy_mode=True, | ||
hpu_graphs=True, | ||
profiling_steps=0, | ||
profiling_warmup_steps=0, | ||
).cpu() | ||
res = tokenizer.batch_decode(output, skip_special_tokens=True)[0] | ||
print(f"[llm - qwen] res: {res}") | ||
return res | ||
|
||
|
||
if __name__ == "__main__": | ||
model, tokenizer, generation_config = initialize_model( | ||
model_name_or_path="Qwen/Qwen1.5-7B-Chat", max_new_tokens=128 | ||
) | ||
import habana_frameworks.torch.hpu as torch_hpu | ||
|
||
print("[llm - qwen] model and tokenizer initialized.") | ||
|
||
from optimum.habana.utils import HabanaProfile | ||
|
||
# compilation stage disable profiling | ||
HabanaProfile.disable() | ||
# Compilation | ||
print("Graph compilation...") | ||
warmup() | ||
print("[llm - qwen] model warm up finished.") | ||
|
||
torch_hpu.synchronize() | ||
HabanaProfile.enable() | ||
print("[llm - qwen] Ready to inference") | ||
|
||
opea_microservices["opea_service@llm_qwen"].start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py | ||
index b086c80..e0e5a9f 100644 | ||
--- a/examples/text-generation/run_lm_eval.py | ||
+++ b/examples/text-generation/run_lm_eval.py | ||
@@ -75,13 +75,13 @@ class HabanaModelAdapter(lm_eval.base.BaseLM): | ||
self.options = options | ||
self._device = args.device | ||
self.model_inputs = {"use_cache": self.options.use_cache} | ||
- if self.model.config.model_type in ["llama", "falcon"]: | ||
+ if self.model.config.model_type in ["llama", "falcon", "qwen2"]: | ||
self.model_inputs.update( | ||
{ | ||
"reuse_cache": self.options.reuse_cache, | ||
} | ||
) | ||
- if self.model.config.model_type == "llama": | ||
+ if self.model.config.model_type in ["llama","mistral","qwen2"]: | ||
self.model_inputs.update( | ||
{ | ||
"attn_softmax_bf16": self.options.attn_softmax_bf16, | ||
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py | ||
index 8bce0ae..c29f458 100644 | ||
--- a/examples/text-generation/utils.py | ||
+++ b/examples/text-generation/utils.py | ||
@@ -234,7 +234,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): | ||
|
||
model = deepspeed.init_inference(model, **ds_inference_kwargs) | ||
model = model.module | ||
- if model.config.model_type in ["llama", "falcon"]: | ||
+ if model.config.model_type in ["llama", "falcon","qwen2"]: | ||
patch_scoped_linear_all_reduce(model) | ||
|
||
if args.quant_config: | ||
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py | ||
index 0d50470..94cc7eb 100755 | ||
--- a/optimum/habana/transformers/generation/utils.py | ||
+++ b/optimum/habana/transformers/generation/utils.py | ||
@@ -740,7 +740,7 @@ class GaudiGenerationMixin(GenerationMixin): | ||
) | ||
model_kwargs["kv_cache_len"] = calculated_max_length | ||
|
||
- if self.config.model_type in ["llama", "falcon"]: | ||
+ if self.config.model_type in ["llama", "falcon","qwen2"]: | ||
if self.config.max_position_embeddings < calculated_max_length: | ||
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) | ||
|
||
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py | ||
index 6dc40a7..b5044af 100644 | ||
--- a/optimum/habana/transformers/modeling_utils.py | ||
+++ b/optimum/habana/transformers/modeling_utils.py | ||
@@ -55,6 +55,9 @@ from .models import ( | ||
GaudiOPTForCausalLM, | ||
GaudiOPTLearnedPositionalEmbedding, | ||
GaudiPhiForCausalLM, | ||
+ GaudiQwen2Model, | ||
+ GaudiQwen2Attention, | ||
+ GaudiQwen2MLP, | ||
_gaudi_wav2vec2_compute_mask_indices, | ||
_gaudi_wav2vec2_mask_hidden_states, | ||
gaudi_albert_forward, | ||
@@ -118,6 +121,7 @@ from .models import ( | ||
gaudi_phi_attention_forward, | ||
gaudi_phi_decoder_layer_forward, | ||
gaudi_phi_model_forward, | ||
+ gaudi_qwen2_rmsnorm_forward, | ||
gaudi_rot_matmul, | ||
gaudi_rot_vec_mul, | ||
gaudi_SpeechT5Attention_forward, | ||
@@ -367,3 +371,11 @@ def adapt_transformers_to_gaudi(): | ||
transformers.models.speecht5.modeling_speecht5.SpeechT5SpeechDecoderPrenet.forward = ( | ||
gaudi_SpeechT5SpeechDecoderPrenet_forward | ||
) | ||
+ | ||
+ # Optimization for qwen2 on Gaudi | ||
+ transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM = GaudiQwen2ForCausalLM | ||
+ transformers.models.qwen2.modeling_qwen2.Qwen2Model = GaudiQwen2Model | ||
+ transformers.models.qwen2.modeling_qwen2.Qwen2Attention = GaudiQwen2Attention | ||
+ transformers.models.qwen2.modeling_qwen2.Qwen2MLP = GaudiQwen2MLP | ||
+ transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer = GaudiQwen2DecoderLayer | ||
+ transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm.forward = gaudi_qwen2_rmsnorm_forward | ||
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py | ||
index 1582d3f..41fdfdc 100644 | ||
--- a/optimum/habana/transformers/models/__init__.py | ||
+++ b/optimum/habana/transformers/models/__init__.py | ||
@@ -122,6 +122,14 @@ from .phi import ( | ||
gaudi_phi_decoder_layer_forward, | ||
gaudi_phi_model_forward, | ||
) | ||
+from .qwen2 import ( | ||
+ GaudiQwen2Attention, | ||
+ GaudiQwen2DecoderLayer, | ||
+ GaudiQwen2ForCausalLM, | ||
+ GaudiQwen2MLP, | ||
+ GaudiQwen2Model, | ||
+ gaudi_qwen2_rmsnorm_forward, | ||
+) | ||
from .speecht5 import ( | ||
gaudi_generate_speech, | ||
gaudi_SpeechT5Attention_forward, | ||
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py | ||
index dc6e136..7dfebaa 100644 | ||
--- a/optimum/habana/transformers/trainer.py | ||
+++ b/optimum/habana/transformers/trainer.py | ||
@@ -916,9 +916,9 @@ class GaudiTrainer(Trainer): | ||
if step % args.gradient_accumulation_steps == 0: | ||
self.control = self.callback_handler.on_step_begin(args, self.state, self.control) | ||
|
||
- # attn_softmax_bf16 and use_flash_attention is enabled only for llama | ||
+ # attn_softmax_bf16 and use_flash_attention is enabled only for llama and qwen2 | ||
if hasattr(self.model, "generation_config") and self.model.generation_config is not None: | ||
- if self.model.config.model_type == "llama": | ||
+ if self.model.config.model_type in ["llama", "qwen2"]: | ||
if self.model.generation_config.attn_softmax_bf16: | ||
inputs["attn_softmax_bf16"] = True | ||
if self.model.generation_config.use_flash_attention: | ||
@@ -1799,9 +1799,9 @@ class GaudiTrainer(Trainer): | ||
if batch_size is None: | ||
batch_size = observed_batch_size | ||
|
||
- # attn_softmax_bf16 and use_flash_attention are enabled only for llama | ||
+ # attn_softmax_bf16 and use_flash_attention are enabled only for llama and qwen2 | ||
if hasattr(self.model, "generation_config") and self.model.generation_config is not None: | ||
- if self.model.config.model_type == "llama": | ||
+ if self.model.config.model_type in ["llama", "qwen2"]: | ||
if self.model.generation_config.attn_softmax_bf16: | ||
inputs["attn_softmax_bf16"] = True | ||
if self.model.generation_config.use_flash_attention: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
docarray[full] | ||
fastapi | ||
langsmith | ||
opentelemetry-api | ||
opentelemetry-exporter-otlp | ||
opentelemetry-sdk | ||
shortuuid | ||
transformers |
Oops, something went wrong.