Commit 64c1f4f

refactor the code

FoolPlayer committed Oct 20, 2023
1 parent bd00085 commit 64c1f4f
Showing 9 changed files with 103 additions and 277 deletions.
colossalai/inference/pipeline/README.md (2 changes: 1 addition & 1 deletion)
@@ -47,7 +47,7 @@ inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInfer

input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
-output = inferengine.inference([data.to('cuda').data])
+output = inferengine.inference(data.to('cuda'))


```
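
For context around the one-line change above, here is a minimal end-to-end sketch of the updated README flow. It assumes a two-stage pipeline and a local Llama checkpoint as in the README; the `LlamaModelInferPolicy` import path is not shown in this commit and is only a guess.

```python
# Sketch of the updated README usage; the checkpoint path, pp_size, and the
# LlamaModelInferPolicy import path are assumptions, not part of this diff.
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.pipeline.engine import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaModelInferPolicy  # assumed path

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")

# Assume the model is inferred with 2 pipeline stages.
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)

input = ["Introduce a landmark in China ", "Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')

# New API: pass the BatchEncoding itself instead of wrapping it as [data.to('cuda').data].
output = inferengine.inference(data.to('cuda'))
```
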
colossalai/inference/pipeline/batch_infer_state.py (120 changes: 0 additions & 120 deletions)

This file was deleted.

colossalai/inference/pipeline/engine.py (20 changes: 17 additions & 3 deletions)
@@ -1,13 +1,14 @@
import torch
import torch.nn as nn
+from transformers.tokenization_utils_base import BatchEncoding

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

-from .kvcache_manager import MemoryManager
+from ..tensor_parallel.kvcache_manager import MemoryManager
from .microbatch_manager import MicroBatchManager


@@ -38,7 +39,7 @@ class PPInferEngine:
colossalai.launch_from_torch(config={})
-model = LlamaForCausalLM.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
+model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
# assume the model is infered with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
@@ -103,7 +104,20 @@ def __init__(
        self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

    def inference(self, input_list):
-        out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
+        """
+        Args:
+            input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
+        Returns:
+            out (list): a list of output data, each element is a list of token.
+            timestamp (float): the time cost of the inference, only return when verbose is `True`.
+        """
+        assert isinstance(
+            input_list, (BatchEncoding, dict)
+        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
+        if isinstance(input_list, BatchEncoding):
+            input_list = input_list.data
+        out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
        if self.verbose:
            return out, timestamp
        else:
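
The reworked `inference` now accepts a single `BatchEncoding` or plain `dict` (the old signature took a list) and unwraps `.data` before handing the batch to the generate schedule. A small self-contained sketch of that relationship; the token ids below are placeholders, and the `inferengine.inference(...)` calls are shown only as comments since they need a constructed `PPInferEngine`.

```python
import torch
from transformers import BatchEncoding

# Stand-in for tokenizer output: a real BatchEncoding wraps exactly this kind
# of dict of tensors (the ids here are placeholders, not real Llama tokens).
batch = BatchEncoding({
    "input_ids": torch.tensor([[1, 2, 3, 4]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
})

assert isinstance(batch, BatchEncoding)
assert isinstance(batch.data, dict)   # .data is the raw dict the engine forwards
print(list(batch.data.keys()))        # ['input_ids', 'attention_mask']

# With an engine built as in the README example:
#   out = inferengine.inference(batch)        # BatchEncoding is accepted directly
#   out = inferengine.inference(batch.data)   # ...and so is the underlying dict
#   inferengine.inference([batch.data])       # the old list form now fails the assert
```
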
colossalai/inference/pipeline/kvcache_manager.py (104 changes: 0 additions & 104 deletions)

This file was deleted.

colossalai/inference/pipeline/microbatch_manager.py (4 changes: 2 additions & 2 deletions)
@@ -3,8 +3,8 @@

import torch

-from .batch_infer_state import BatchInferState
-from .kvcache_manager import MemoryManager
+from ..tensor_parallel.batch_infer_state import BatchInferState
+from ..tensor_parallel.kvcache_manager import MemoryManager

__all__ = "MicroBatchManager"

colossalai/inference/pipeline/modeling/llama.py (14 changes: 11 additions & 3 deletions)
@@ -12,7 +12,7 @@
from transformers.utils import logging

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.pipeline.stage_manager import PipelineStageManager

from ._utils import copy_kv_to_mem_cache
@@ -31,6 +31,14 @@
)
HAS_VLLM_KERNERL = False

+try:
+    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
+
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -363,8 +371,8 @@ def llama_flash_attn_kvcache_forward(

cos, sin = infer_state.position_cos, infer_state.position_sin

-rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)

query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
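
The last hunk swaps the colossalai Triton rotary kernel for lightllm's `rotary_emb_fwd`. As a reading aid, here is a plain-PyTorch sketch of what such a rotary-embedding kernel computes on the flattened (num_tokens, num_heads, head_dim) tensors used above, assuming the standard half-rotation Llama convention (the same `rotate_half` defined in this file) and per-token cos/sin tables of shape (num_tokens, head_dim // 2); the real kernel operates in place and its exact memory layout may differ.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard Llama convention: negate the second half and swap it with the first.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_embedding_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Reference rotary embedding for x of shape (num_tokens, num_heads, head_dim).

    cos/sin have shape (num_tokens, head_dim // 2) and are duplicated across both halves.
    """
    cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)  # -> (num_tokens, 1, head_dim)
    sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
    return x * cos + rotate_half(x) * sin


# Tiny shape check with random data (placeholder sizes).
q = torch.randn(5, 32, 128)   # (num_tokens, num_heads, head_dim)
cos, sin = torch.randn(5, 64), torch.randn(5, 64)
assert rotary_embedding_reference(q, cos, sin).shape == q.shape
```
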
colossalai/inference/pipeline/utils.py (35 changes: 0 additions & 35 deletions)

This file was deleted.
