Fix model configs (Falcon in C++, LLAMA in Python) (#1162)
* cleanup

* fix

* fix

* fix

* fix loading of weights

* import configs in models init (python)

* remove unnecessary warning
goliaro authored Oct 1, 2023
1 parent 65cb570 commit 5919fff
Showing 12 changed files with 60 additions and 156 deletions.
2 changes: 1 addition & 1 deletion inference/incr_decoding/incr_decoding.cc
@@ -168,7 +168,7 @@ void FlexFlow::top_level_task(Task const *task,
} else if (str == "OPTForCausalLM") {
model_type = ModelType::OPT;
break;
- } else if (str == "RWForCausalLM") {
+ } else if (str == "RWForCausalLM" || str == "FalconForCausalLM") {
model_type = ModelType::FALCON;
break;
} else if (str == "GPTBigCodeForCausalLM") {
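Both inference entry points (incr_decoding.cc here and spec_infer.cc below) gain the same check, because newer Falcon checkpoints list "FalconForCausalLM" in the architectures field of config.json while older ones list "RWForCausalLM". A minimal Python sketch of that aliasing, for illustration only: the ModelType enum is a local stand-in and only the architecture names are taken from the commit.

import json
from enum import Enum, auto

class ModelType(Enum):
    # Local stand-in for FlexFlow's ModelType enum, so the sketch is self-contained.
    LLAMA = auto()
    OPT = auto()
    FALCON = auto()
    STARCODER = auto()
    MPT = auto()

ARCH_TO_MODEL_TYPE = {
    "LlamaForCausalLM": ModelType.LLAMA,
    "OPTForCausalLM": ModelType.OPT,
    "RWForCausalLM": ModelType.FALCON,      # legacy Falcon architecture name
    "FalconForCausalLM": ModelType.FALCON,  # current Falcon architecture name
    "GPTBigCodeForCausalLM": ModelType.STARCODER,
    "MPTForCausalLM": ModelType.MPT,
}

def model_type_from_config(config_path: str) -> ModelType:
    # Return the model type for the first recognized entry in config.json's "architectures".
    with open(config_path) as f:
        architectures = json.load(f).get("architectures", [])
    for arch in architectures:
        if arch in ARCH_TO_MODEL_TYPE:
            return ARCH_TO_MODEL_TYPE[arch]
    raise ValueError(f"unsupported architectures: {architectures}")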
8 changes: 6 additions & 2 deletions inference/models/falcon.h
@@ -37,13 +37,17 @@ class FALCON {
hidden_size = model_config["hidden_size"];
layer_norm_epsilon = model_config["layer_norm_epsilon"];
multi_query = model_config["multi_query"];
- n_head = model_config["n_head"];
+ n_head = (model_config.find("n_head") != model_config.end())
+ ? model_config["n_head"]
+ : model_config["num_attention_heads"];
if (model_config.contains("n_head_kv")) {
n_head_kv = model_config["n_head_kv"];
} else {
n_head_kv = 1;
}
- n_layer = model_config["n_layer"];
+ n_layer = (model_config.find("n_layer") != model_config.end())
+ ? model_config["n_layer"]
+ : model_config["num_hidden_layers"];
parallel_attn = model_config["parallel_attn"];
vocab_size = model_config["vocab_size"];
} catch (json::exception const &e) {
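The same aliasing exists inside Falcon's config.json keys: older checkpoints use n_head / n_layer, newer ones use num_attention_heads / num_hidden_layers, and n_head_kv may be absent. A short Python sketch of the fallback the C++ above performs with nlohmann::json (illustration only; the helper name is hypothetical):

def read_falcon_config_keys(model_config: dict):
    # Prefer the legacy Falcon key names and fall back to the standard HF ones,
    # mirroring the find()/contains() checks in falcon.h above.
    n_head = model_config.get("n_head", model_config.get("num_attention_heads"))
    n_head_kv = model_config.get("n_head_kv", 1)  # same default of 1 as the C++ else branch
    n_layer = model_config.get("n_layer", model_config.get("num_hidden_layers"))
    return n_head, n_head_kv, n_layer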
2 changes: 1 addition & 1 deletion inference/spec_infer/spec_infer.cc
@@ -163,7 +163,7 @@ void get_model_meta(FilePaths &file_paths,
} else if (str == "OPTForCausalLM") {
model_metadata.llm_model_type = ModelType::OPT;
break;
- } else if (str == "RWForCausalLM") {
+ } else if (str == "RWForCausalLM" || str == "FalconForCausalLM") {
model_metadata.llm_model_type = ModelType::FALCON;
break;
} else if (str == "MPTForCausalLM") {
3 changes: 2 additions & 1 deletion python/flexflow/core/__init__.py
@@ -124,7 +124,8 @@ def init_flexflow_runtime(configs_dict: Optional[dict] = None, **kwargs):
# Pass parameters to the FlexFlow C++ runtime via command line arguments
for arg in ff_args:
if arg not in ff_arg_to_sysarg:
- warnings.warn(f"Ignoring parameter {arg}: not recognized.")
+ # warnings.warn(f"Ignoring parameter {arg}: not recognized.")
+ continue
else:
sys_arg = [ff_arg_to_sysarg[arg]]
if type(ff_args[arg]) == bool:
10 changes: 5 additions & 5 deletions python/flexflow/serve/models/__init__.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- from .llama import FlexFlowLLAMA
- from .opt import FlexFlowOPT
- from .falcon import FlexFlowFalcon
- from .starcoder import FlexFlowSTARCODER
- from .mpt import FlexFlowMPT
+ from .llama import FlexFlowLLAMA, LLAMAConfig
+ from .opt import FlexFlowOPT, OPTConfig
+ from .falcon import FlexFlowFalcon, FalconConfig
+ from .starcoder import FlexFlowSTARCODER, STARCODERConfig
+ from .mpt import FlexFlowMPT, MPTConfig
3 changes: 0 additions & 3 deletions python/flexflow/serve/models/base.py
@@ -34,6 +34,3 @@ def build_model(self):

def convert_hf_model(model, dst_folder):
assert False, "Not implemented yet"

- def get_layers_with_weights(self):
- assert False, "Not implemented yet"
25 changes: 3 additions & 22 deletions python/flexflow/serve/models/falcon.py
@@ -40,6 +40,9 @@ def __init__(self, hf_config):
)
self.parallel_attn = hf_config.parallel_attn
self.vocab_size = hf_config.vocab_size
+ # Standardized FlexFlow num heads fields below
+ self.num_attention_heads = self.n_head
+ self.num_key_value_heads = self.n_head_kv


class FlexFlowFalcon(FlexFlowModel):
@@ -277,25 +280,3 @@ def convert_hf_model(model, dst_folder):
model.lm_head.weight.detach().cpu().numpy().tofile(
os.path.join(dst_folder, "lm_head_weight")
)

- def get_layers_with_weights(self):
- layer_names = [
- "word_embeddings_weight",
- "ln_f_weight",
- "lm_head_weight",
- ] + [
- expr
- for i in range(self.falcon_config.n_layer)
- for expr in (
- f"layers_{i}_input_layernorm_weight",
- f"layers_{i}_attention_weight",
- f"layers_{i}_mlp_dense_h_to_4h_weight",
- f"layers_{i}_mlp_dense_4h_to_h_weight",
- )
- ]
- layers_with_weights = {
- layer_name: self.ffmodel.get_layer_by_name(layer_name)
- for layer_name in layer_names
- }
-
- return layers_with_weights
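The standardized num_attention_heads / num_key_value_heads fields added here (and to the other Python configs below) are what the weight loader in serve.py consumes. As a rough illustration of why both counts matter, and not FlexFlow code: the key/value projections shrink relative to the query projection whenever num_key_value_heads is smaller than num_attention_heads, with Falcon's multi_query case (a single KV head) as the extreme.

def attention_projection_sizes(hidden_size: int,
                               num_attention_heads: int,
                               num_key_value_heads: int):
    # Per-head width, computed the same way serve.py does (hidden_size // num_attention_heads).
    head_dim = hidden_size // num_attention_heads
    q_size = num_attention_heads * head_dim   # equals hidden_size
    kv_size = num_key_value_heads * head_dim  # much smaller under multi-query attention
    return q_size, kv_size

# e.g. Falcon-7B (hidden_size=4544, 71 query heads, 1 KV head):
# attention_projection_sizes(4544, 71, 1) -> (4544, 64)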
27 changes: 4 additions & 23 deletions python/flexflow/serve/models/llama.py
@@ -25,15 +25,16 @@ def __init__(self, hf_config):
self.max_beam_depth = 8
self.num_hidden_layers = hf_config.num_hidden_layers
self.vocab_size = hf_config.vocab_size
+ self.hidden_size = hf_config.hidden_size
+ self.rms_norm_eps = hf_config.rms_norm_eps
+ self.intermediate_size = hf_config.intermediate_size
+ # Standardized FlexFlow num heads fields below
self.num_attention_heads = hf_config.num_attention_heads
self.num_key_value_heads = (
hf_config.num_attention_heads
if hf_config.num_key_value_heads is None
else hf_config.num_key_value_heads
)
- self.hidden_size = hf_config.hidden_size
- self.rms_norm_eps = hf_config.rms_norm_eps
- self.intermediate_size = hf_config.intermediate_size


class FlexFlowLLAMA(FlexFlowModel):
@@ -262,23 +263,3 @@ def convert_hf_model(model, dst_folder):
.replace("model_", "")
)
params.detach().cpu().numpy().tofile(f"{dst_folder}/{name}")

- def get_layers_with_weights(self):
- layer_names = ["tok_embeddings_weight", "norm_weight", "output_weight"] + [
- expr
- for i in range(self.llama_config.num_hidden_layers)
- for expr in (
- f"layers_{i}_attention_norm_weight",
- f"layers_{i}_attention_weight",
- f"layers_{i}_ffn_norm_weight",
- f"layers_{i}_feed_forward_w1_weight",
- f"layers_{i}_feed_forward_w3_weight",
- f"layers_{i}_feed_forward_w2_weight",
- )
- ]
- layers_with_weights = {
- layer_name: self.ffmodel.get_layer_by_name(layer_name)
- for layer_name in layer_names
- }
-
- return layers_with_weights
28 changes: 3 additions & 25 deletions python/flexflow/serve/models/mpt.py
@@ -27,8 +27,9 @@ def __init__(self, hf_config):
self.n_heads = hf_config.n_heads
self.n_layers = hf_config.n_layers
self.vocab_size = hf_config.vocab_size
- hf_config.num_attention_heads = hf_config.n_heads
- hf_config.hidden_size = hf_config.d_model
+ # Standardized FlexFlow num heads fields below
+ self.num_attention_heads = hf_config.n_heads
+ self.num_key_value_heads = hf_config.n_heads


class FlexFlowMPT(FlexFlowModel):
@@ -274,26 +275,3 @@ def convert_hf_model(model, dst_folder):
os.path.join(dst_folder, "transformer_wte_weight"),
os.path.join(dst_folder, "lm_head_weight"),
)

- def get_layers_with_weights(self):
- layer_names = [
- "transformer_wte_weight",
- "transformer_norm_f_weight",
- "lm_head_weight",
- ] + [
- expr
- for i in range(self.mpt_config.n_layers)
- for expr in (
- f"layers_{i}_norm_1_weight",
- f"layers_{i}_attention_weight",
- f"layers_{i}_norm_2_weight",
- f"layers_{i}_ffn_up_proj_weight",
- f"layers_{i}_ffn_down_proj_weight",
- )
- ]
- layers_with_weights = {
- layer_name: self.ffmodel.get_layer_by_name(layer_name)
- for layer_name in layer_names
- }
-
- return layers_with_weights
28 changes: 3 additions & 25 deletions python/flexflow/serve/models/opt.py
@@ -30,10 +30,12 @@ def __init__(self, hf_config):
self.hidden_size = hf_config.hidden_size
self.layer_norm_elementwise_affine = hf_config.layer_norm_elementwise_affine
self.max_position_embeddings = hf_config.max_position_embeddings
- self.num_attention_heads = hf_config.num_attention_heads
self.num_hidden_layers = hf_config.num_hidden_layers
self.vocab_size = hf_config.vocab_size
self.word_embed_proj_dim = hf_config.word_embed_proj_dim
+ # Standardized FlexFlow num heads fields below
+ self.num_attention_heads = hf_config.num_attention_heads
+ self.num_key_value_heads = hf_config.num_attention_heads


class FlexFlowOPT(FlexFlowModel):
@@ -297,27 +299,3 @@ def convert_hf_model(model, dst_folder):
os.path.join(dst_folder, "embed_tokens_weight"),
os.path.join(dst_folder, "embed_tokens_weight_lm_head"),
)

- def get_layers_with_weights(self):
- layer_names = [
- "embed_tokens_weight",
- "embed_positions_weight",
- "final_layer_norm_weight",
- "embed_tokens_weight_lm_head",
- ] + [
- expr
- for i in range(self.opt_config.num_hidden_layers)
- for expr in (
- f"layers_{i}_attention_layer_norm_weight",
- f"layers_{i}_attention_weight",
- f"layers_{i}_final_layer_norm_weight",
- f"layers_{i}_fc1_weight",
- f"layers_{i}_fc2_weight",
- )
- ]
- layers_with_weights = {
- layer_name: self.ffmodel.get_layer_by_name(layer_name)
- for layer_name in layer_names
- }
-
- return layers_with_weights
28 changes: 3 additions & 25 deletions python/flexflow/serve/models/starcoder.py
@@ -27,11 +27,13 @@ def __init__(self, hf_config):
self.hidden_size = hf_config.n_embd
self.layer_norm_epsilon = hf_config.layer_norm_epsilon
self.max_position_embeddings = hf_config.n_positions
- self.num_attention_heads = hf_config.n_head
self.num_hidden_layers = hf_config.n_layer
self.vocab_size = hf_config.vocab_size
self.intermediate_size = hf_config.n_inner
self.n_head_kv = 1 if hf_config.multi_query else hf_config.n_head
+ # Standardized FlexFlow num heads fields below
+ self.num_attention_heads = hf_config.n_head
+ self.num_key_value_heads = self.n_head_kv


class FlexFlowSTARCODER(FlexFlowModel):
@@ -266,27 +268,3 @@ def convert_hf_model(model, dst_folder):
model.lm_head.weight.detach().cpu().numpy().tofile(
os.path.join(dst_folder, "lm_head_weight")
)

- def get_layers_with_weights(self):
- layer_names = [
- "transformer_wte_weight",
- "transformer_wpe_weight",
- "transformer_ln_f_weight",
- "lm_head_weight",
- ] + [
- expr
- for i in range(self.starcoder_config.num_hidden_layers)
- for expr in (
- f"layers_{i}_ln_1_weight",
- f"layers_{i}_attention_weight",
- f"layers_{i}_ln_2_weight",
- f"layers_{i}_mlp_c_fc_weight",
- f"layers_{i}_mlp_c_proj_weight",
- )
- ]
- layers_with_weights = {
- layer_name: self.ffmodel.get_layer_by_name(layer_name)
- for layer_name in layer_names
- }
-
- return layers_with_weights
52 changes: 29 additions & 23 deletions python/flexflow/serve/serve.py
@@ -19,6 +19,13 @@
FlexFlowSTARCODER,
FlexFlowMPT,
)
+ from flexflow.serve.models import (
+ LLAMAConfig,
+ OPTConfig,
+ FalconConfig,
+ STARCODERConfig,
+ MPTConfig,
+ )
from flexflow.core import *
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from huggingface_hub import HfApi
@@ -86,17 +93,25 @@ def __init__(
:type output_file: str, optional
"""
self.supported_models = {
"LlamaForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA),
"LLaMAForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA),
"OPTForCausalLM": (ModelType.OPT, FlexFlowOPT),
"RWForCausalLM": (ModelType.FALCON, FlexFlowFalcon),
"FalconForCausalLM": (ModelType.FALCON, FlexFlowFalcon),
"GPTBigCodeForCausalLM": (ModelType.STARCODER, FlexFlowSTARCODER),
"MPTForCausalLM": (ModelType.MPT, FlexFlowMPT),
"LlamaForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig),
"LLaMAForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig),
"OPTForCausalLM": (ModelType.OPT, FlexFlowOPT, OPTConfig),
"RWForCausalLM": (ModelType.FALCON, FlexFlowFalcon, FalconConfig),
"FalconForCausalLM": (ModelType.FALCON, FlexFlowFalcon, FalconConfig),
"GPTBigCodeForCausalLM": (
ModelType.STARCODER,
FlexFlowSTARCODER,
STARCODERConfig,
),
"MPTForCausalLM": (ModelType.MPT, FlexFlowMPT, MPTConfig),
}
self.hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.model_name = self.hf_config._name_or_path
- self.model_type, self.model_class = self.__get_ff_model_type()
+ (
+ self.model_type,
+ self.model_class,
+ self.config_class,
+ ) = self.__get_ff_model_type()
self.data_type = data_type
assert self.data_type == DataType.DT_HALF or self.data_type == DataType.DT_FLOAT
self.cache_path = cache_path if len(cache_path) > 0 else "~/.cache/flexflow"
@@ -274,23 +289,14 @@ def __load_hf_weights(self):
self.download_hf_weights_if_needed()

# Create file data loader, load weights into tensors
- if (
- self.model_type == ModelType.FALCON
- or self.model_type == ModelType.STARCODER
- ):
- n_q_heads = self.hf_config.num_attention_heads
- if "n_head_kv" in self.hf_config.__dict__:
- n_kv_heads = self.hf_config.n_head_kv
- else:
- n_kv_heads = 1
- else:
- n_q_heads = n_kv_heads = self.hf_config.num_attention_heads
+ model_configs = self.config_class(self.hf_config)

self.fileloader = FileDataLoader(
self.weights_path,
- n_q_heads,
- n_kv_heads,
- self.hf_config.hidden_size,
- self.hf_config.hidden_size // n_q_heads,
+ model_configs.num_attention_heads,
+ model_configs.num_key_value_heads,
+ model_configs.hidden_size,
+ model_configs.hidden_size // model_configs.num_attention_heads,
self.ffconfig.tensor_parallelism_degree,
)

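Taken together, the serve.py changes drop the Falcon/STARCODER special cases in __load_hf_weights: every registered config class now exposes num_attention_heads, num_key_value_heads, and hidden_size, so the loader arguments can be computed uniformly. A condensed sketch of the resulting flow, with a stand-in dataclass in place of the real FileDataLoader (which lives in the FlexFlow runtime):

from dataclasses import dataclass

@dataclass
class LoaderArgs:
    # Stand-in for the arguments serve.py passes to FileDataLoader (see the diff above).
    weights_path: str
    num_q_heads: int
    num_kv_heads: int
    hidden_size: int
    head_dim: int
    tensor_parallelism_degree: int

def loader_args_for(config_class, hf_config, weights_path: str, tp_degree: int = 1) -> LoaderArgs:
    # Any registered config class (FalconConfig, LLAMAConfig, OPTConfig, ...) works here,
    # because each now exposes the same standardized attributes.
    cfg = config_class(hf_config)
    return LoaderArgs(
        weights_path,
        cfg.num_attention_heads,
        cfg.num_key_value_heads,
        cfg.hidden_size,
        cfg.hidden_size // cfg.num_attention_heads,
        tp_degree,
    )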
