From 5919fff9099b50a492edc1a9ce2d94a5868bc779 Mon Sep 17 00:00:00 2001
From: Gabriele Oliaro
Date: Sun, 1 Oct 2023 05:47:32 -0400
Subject: [PATCH] Fix model configs (Falcon in C++, LLAMA in Python) (#1162)

* cleanup

* fix

* fix

* fix

* fix loading of weights

* import configs in models init (python)

* remove unnecessary warning
---
 inference/incr_decoding/incr_decoding.cc  |  2 +-
 inference/models/falcon.h                 |  8 +++-
 inference/spec_infer/spec_infer.cc        |  2 +-
 python/flexflow/core/__init__.py          |  3 +-
 python/flexflow/serve/models/__init__.py  | 10 ++---
 python/flexflow/serve/models/base.py      |  3 --
 python/flexflow/serve/models/falcon.py    | 25 ++---
 python/flexflow/serve/models/llama.py     | 27 ++----
 python/flexflow/serve/models/mpt.py       | 28 ++----
 python/flexflow/serve/models/opt.py       | 28 ++----
 python/flexflow/serve/models/starcoder.py | 28 ++----
 python/flexflow/serve/serve.py            | 52 +++++++++++++----------
 12 files changed, 60 insertions(+), 156 deletions(-)

diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc
index 3f913e4573..f3fd32878f 100644
--- a/inference/incr_decoding/incr_decoding.cc
+++ b/inference/incr_decoding/incr_decoding.cc
@@ -168,7 +168,7 @@ void FlexFlow::top_level_task(Task const *task,
     } else if (str == "OPTForCausalLM") {
       model_type = ModelType::OPT;
       break;
-    } else if (str == "RWForCausalLM") {
+    } else if (str == "RWForCausalLM" || str == "FalconForCausalLM") {
       model_type = ModelType::FALCON;
       break;
     } else if (str == "GPTBigCodeForCausalLM") {
diff --git a/inference/models/falcon.h b/inference/models/falcon.h
index a822f9be34..6c9124fe4c 100644
--- a/inference/models/falcon.h
+++ b/inference/models/falcon.h
@@ -37,13 +37,17 @@ class FALCON {
       hidden_size = model_config["hidden_size"];
       layer_norm_epsilon = model_config["layer_norm_epsilon"];
       multi_query = model_config["multi_query"];
-      n_head = model_config["n_head"];
+      n_head = (model_config.find("n_head") != model_config.end())
+                   ? model_config["n_head"]
+                   : model_config["num_attention_heads"];
       if (model_config.contains("n_head_kv")) {
         n_head_kv = model_config["n_head_kv"];
       } else {
         n_head_kv = 1;
       }
-      n_layer = model_config["n_layer"];
+      n_layer = (model_config.find("n_layer") != model_config.end())
+                    ? model_config["n_layer"]
+                    : model_config["num_hidden_layers"];
       parallel_attn = model_config["parallel_attn"];
       vocab_size = model_config["vocab_size"];
     } catch (json::exception const &e) {
diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc
index 2b1fb6e817..a95b26c930 100644
--- a/inference/spec_infer/spec_infer.cc
+++ b/inference/spec_infer/spec_infer.cc
@@ -163,7 +163,7 @@ void get_model_meta(FilePaths &file_paths,
     } else if (str == "OPTForCausalLM") {
       model_metadata.llm_model_type = ModelType::OPT;
       break;
-    } else if (str == "RWForCausalLM") {
+    } else if (str == "RWForCausalLM" || str == "FalconForCausalLM") {
       model_metadata.llm_model_type = ModelType::FALCON;
       break;
     } else if (str == "MPTForCausalLM") {
diff --git a/python/flexflow/core/__init__.py b/python/flexflow/core/__init__.py
index 5e8e4ece81..ace6030a1b 100644
--- a/python/flexflow/core/__init__.py
+++ b/python/flexflow/core/__init__.py
@@ -124,7 +124,8 @@ def init_flexflow_runtime(configs_dict: Optional[dict] = None, **kwargs):
     # Pass parameters to the FlexFlow C++ runtime via command line arguments
     for arg in ff_args:
         if arg not in ff_arg_to_sysarg:
-            warnings.warn(f"Ignoring parameter {arg}: not recognized.")
+            # warnings.warn(f"Ignoring parameter {arg}: not recognized.")
+            continue
         else:
             sys_arg = [ff_arg_to_sysarg[arg]]
             if type(ff_args[arg]) == bool:
diff --git a/python/flexflow/serve/models/__init__.py b/python/flexflow/serve/models/__init__.py
index a1ca9152ce..7b0e632f53 100644
--- a/python/flexflow/serve/models/__init__.py
+++ b/python/flexflow/serve/models/__init__.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .llama import FlexFlowLLAMA
-from .opt import FlexFlowOPT
-from .falcon import FlexFlowFalcon
-from .starcoder import FlexFlowSTARCODER
-from .mpt import FlexFlowMPT
+from .llama import FlexFlowLLAMA, LLAMAConfig
+from .opt import FlexFlowOPT, OPTConfig
+from .falcon import FlexFlowFalcon, FalconConfig
+from .starcoder import FlexFlowSTARCODER, STARCODERConfig
+from .mpt import FlexFlowMPT, MPTConfig
diff --git a/python/flexflow/serve/models/base.py b/python/flexflow/serve/models/base.py
index b7f4e54fc1..19affd9b47 100644
--- a/python/flexflow/serve/models/base.py
+++ b/python/flexflow/serve/models/base.py
@@ -34,6 +34,3 @@ def build_model(self):

     def convert_hf_model(model, dst_folder):
         assert False, "Not implemented yet"
-
-    def get_layers_with_weights(self):
-        assert False, "Not implemented yet"
diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py
index 2fd2f4953f..96268f5347 100644
--- a/python/flexflow/serve/models/falcon.py
+++ b/python/flexflow/serve/models/falcon.py
@@ -40,6 +40,9 @@ def __init__(self, hf_config):
         )
         self.parallel_attn = hf_config.parallel_attn
         self.vocab_size = hf_config.vocab_size
+        # Standardized FlexFlow num heads fields below
+        self.num_attention_heads = self.n_head
+        self.num_key_value_heads = self.n_head_kv


 class FlexFlowFalcon(FlexFlowModel):
@@ -277,25 +280,3 @@ def convert_hf_model(model, dst_folder):
         model.lm_head.weight.detach().cpu().numpy().tofile(
             os.path.join(dst_folder, "lm_head_weight")
         )
-
-    def get_layers_with_weights(self):
-        layer_names = [
-            "word_embeddings_weight",
-            "ln_f_weight",
-            "lm_head_weight",
-        ] + [
-            expr
-            for i in range(self.falcon_config.n_layer)
-            for expr in (
-                f"layers_{i}_input_layernorm_weight",
-                f"layers_{i}_attention_weight",
-                f"layers_{i}_mlp_dense_h_to_4h_weight",
-                f"layers_{i}_mlp_dense_4h_to_h_weight",
-            )
-        ]
-        layers_with_weights = {
-            layer_name: self.ffmodel.get_layer_by_name(layer_name)
-            for layer_name in layer_names
-        }
-
-        return layers_with_weights
diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py
index b8ea85b287..ba2f6e0826 100644
--- a/python/flexflow/serve/models/llama.py
+++ b/python/flexflow/serve/models/llama.py
@@ -25,15 +25,16 @@ def __init__(self, hf_config):
         self.max_beam_depth = 8
         self.num_hidden_layers = hf_config.num_hidden_layers
         self.vocab_size = hf_config.vocab_size
+        self.hidden_size = hf_config.hidden_size
+        self.rms_norm_eps = hf_config.rms_norm_eps
+        self.intermediate_size = hf_config.intermediate_size
+        # Standardized FlexFlow num heads fields below
         self.num_attention_heads = hf_config.num_attention_heads
         self.num_key_value_heads = (
             hf_config.num_attention_heads
             if hf_config.num_key_value_heads is None
             else hf_config.num_key_value_heads
         )
-        self.hidden_size = hf_config.hidden_size
-        self.rms_norm_eps = hf_config.rms_norm_eps
-        self.intermediate_size = hf_config.intermediate_size


 class FlexFlowLLAMA(FlexFlowModel):
@@ -262,23 +263,3 @@ def convert_hf_model(model, dst_folder):
                 .replace("model_", "")
             )
             params.detach().cpu().numpy().tofile(f"{dst_folder}/{name}")
-
-    def get_layers_with_weights(self):
-        layer_names = ["tok_embeddings_weight", "norm_weight", "output_weight"] + [
-            expr
-            for i in range(self.llama_config.num_hidden_layers)
-            for expr in (
-                f"layers_{i}_attention_norm_weight",
-                f"layers_{i}_attention_weight",
-                f"layers_{i}_ffn_norm_weight",
-                f"layers_{i}_feed_forward_w1_weight",
-                f"layers_{i}_feed_forward_w3_weight",
-                f"layers_{i}_feed_forward_w2_weight",
-            )
-        ]
-        layers_with_weights = {
-            layer_name: self.ffmodel.get_layer_by_name(layer_name)
-            for layer_name in layer_names
-        }
-
-        return layers_with_weights
diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py
index 6e1ca9fdfa..43a2514394 100644
--- a/python/flexflow/serve/models/mpt.py
+++ b/python/flexflow/serve/models/mpt.py
@@ -27,8 +27,9 @@ def __init__(self, hf_config):
         self.n_heads = hf_config.n_heads
         self.n_layers = hf_config.n_layers
         self.vocab_size = hf_config.vocab_size
-        hf_config.num_attention_heads = hf_config.n_heads
-        hf_config.hidden_size = hf_config.d_model
+        # Standardized FlexFlow num heads fields below
+        self.num_attention_heads = hf_config.n_heads
+        self.num_key_value_heads = hf_config.n_heads


 class FlexFlowMPT(FlexFlowModel):
@@ -274,26 +275,3 @@ def convert_hf_model(model, dst_folder):
             os.path.join(dst_folder, "transformer_wte_weight"),
             os.path.join(dst_folder, "lm_head_weight"),
         )
-
-    def get_layers_with_weights(self):
-        layer_names = [
-            "transformer_wte_weight",
-            "transformer_norm_f_weight",
-            "lm_head_weight",
-        ] + [
-            expr
-            for i in range(self.mpt_config.n_layers)
-            for expr in (
-                f"layers_{i}_norm_1_weight",
-                f"layers_{i}_attention_weight",
-                f"layers_{i}_norm_2_weight",
-                f"layers_{i}_ffn_up_proj_weight",
-                f"layers_{i}_ffn_down_proj_weight",
-            )
-        ]
-        layers_with_weights = {
-            layer_name: self.ffmodel.get_layer_by_name(layer_name)
-            for layer_name in layer_names
-        }
-
-        return layers_with_weights
diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py
index 639be2d5c4..d51287a181 100644
--- a/python/flexflow/serve/models/opt.py
+++ b/python/flexflow/serve/models/opt.py
@@ -30,10 +30,12 @@ def __init__(self, hf_config):
         self.hidden_size = hf_config.hidden_size
         self.layer_norm_elementwise_affine = hf_config.layer_norm_elementwise_affine
         self.max_position_embeddings = hf_config.max_position_embeddings
-        self.num_attention_heads = hf_config.num_attention_heads
         self.num_hidden_layers = hf_config.num_hidden_layers
         self.vocab_size = hf_config.vocab_size
         self.word_embed_proj_dim = hf_config.word_embed_proj_dim
+        # Standardized FlexFlow num heads fields below
+        self.num_attention_heads = hf_config.num_attention_heads
+        self.num_key_value_heads = hf_config.num_attention_heads


 class FlexFlowOPT(FlexFlowModel):
@@ -297,27 +299,3 @@ def convert_hf_model(model, dst_folder):
             os.path.join(dst_folder, "embed_tokens_weight"),
             os.path.join(dst_folder, "embed_tokens_weight_lm_head"),
         )
-
-    def get_layers_with_weights(self):
-        layer_names = [
-            "embed_tokens_weight",
-            "embed_positions_weight",
-            "final_layer_norm_weight",
-            "embed_tokens_weight_lm_head",
-        ] + [
-            expr
-            for i in range(self.opt_config.num_hidden_layers)
-            for expr in (
-                f"layers_{i}_attention_layer_norm_weight",
-                f"layers_{i}_attention_weight",
-                f"layers_{i}_final_layer_norm_weight",
-                f"layers_{i}_fc1_weight",
-                f"layers_{i}_fc2_weight",
-            )
-        ]
-        layers_with_weights = {
-            layer_name: self.ffmodel.get_layer_by_name(layer_name)
-            for layer_name in layer_names
-        }
-
-        return layers_with_weights
diff --git a/python/flexflow/serve/models/starcoder.py b/python/flexflow/serve/models/starcoder.py
index feb5be7d75..4eee3182d1 100644
--- a/python/flexflow/serve/models/starcoder.py
+++ b/python/flexflow/serve/models/starcoder.py
@@ -27,11 +27,13 @@ def __init__(self, hf_config):
         self.hidden_size = hf_config.n_embd
         self.layer_norm_epsilon = hf_config.layer_norm_epsilon
         self.max_position_embeddings = hf_config.n_positions
-        self.num_attention_heads = hf_config.n_head
         self.num_hidden_layers = hf_config.n_layer
         self.vocab_size = hf_config.vocab_size
         self.intermediate_size = hf_config.n_inner
         self.n_head_kv = 1 if hf_config.multi_query else hf_config.n_head
+        # Standardized FlexFlow num heads fields below
+        self.num_attention_heads = hf_config.n_head
+        self.num_key_value_heads = self.n_head_kv


 class FlexFlowSTARCODER(FlexFlowModel):
@@ -266,27 +268,3 @@ def convert_hf_model(model, dst_folder):
         model.lm_head.weight.detach().cpu().numpy().tofile(
             os.path.join(dst_folder, "lm_head_weight")
         )
-
-    def get_layers_with_weights(self):
-        layer_names = [
-            "transformer_wte_weight",
-            "transformer_wpe_weight",
-            "transformer_ln_f_weight",
-            "lm_head_weight",
-        ] + [
-            expr
-            for i in range(self.starcoder_config.num_hidden_layers)
-            for expr in (
-                f"layers_{i}_ln_1_weight",
-                f"layers_{i}_attention_weight",
-                f"layers_{i}_ln_2_weight",
-                f"layers_{i}_mlp_c_fc_weight",
-                f"layers_{i}_mlp_c_proj_weight",
-            )
-        ]
-        layers_with_weights = {
-            layer_name: self.ffmodel.get_layer_by_name(layer_name)
-            for layer_name in layer_names
-        }
-
-        return layers_with_weights
diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py
index 7e340a04e2..eace15f691 100644
--- a/python/flexflow/serve/serve.py
+++ b/python/flexflow/serve/serve.py
@@ -19,6 +19,13 @@
     FlexFlowSTARCODER,
     FlexFlowMPT,
 )
+from flexflow.serve.models import (
+    LLAMAConfig,
+    OPTConfig,
+    FalconConfig,
+    STARCODERConfig,
+    MPTConfig,
+)
 from flexflow.core import *
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 from huggingface_hub import HfApi
@@ -86,17 +93,25 @@ def __init__(
         :type output_file: str, optional
         """
         self.supported_models = {
-            "LlamaForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA),
-            "LLaMAForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA),
-            "OPTForCausalLM": (ModelType.OPT, FlexFlowOPT),
-            "RWForCausalLM": (ModelType.FALCON, FlexFlowFalcon),
-            "FalconForCausalLM": (ModelType.FALCON, FlexFlowFalcon),
-            "GPTBigCodeForCausalLM": (ModelType.STARCODER, FlexFlowSTARCODER),
-            "MPTForCausalLM": (ModelType.MPT, FlexFlowMPT),
+            "LlamaForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig),
+            "LLaMAForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig),
+            "OPTForCausalLM": (ModelType.OPT, FlexFlowOPT, OPTConfig),
+            "RWForCausalLM": (ModelType.FALCON, FlexFlowFalcon, FalconConfig),
+            "FalconForCausalLM": (ModelType.FALCON, FlexFlowFalcon, FalconConfig),
+            "GPTBigCodeForCausalLM": (
+                ModelType.STARCODER,
+                FlexFlowSTARCODER,
+                STARCODERConfig,
+            ),
+            "MPTForCausalLM": (ModelType.MPT, FlexFlowMPT, MPTConfig),
         }
         self.hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         self.model_name = self.hf_config._name_or_path
-        self.model_type, self.model_class = self.__get_ff_model_type()
+        (
+            self.model_type,
+            self.model_class,
+            self.config_class,
+        ) = self.__get_ff_model_type()
         self.data_type = data_type
         assert self.data_type == DataType.DT_HALF or self.data_type == DataType.DT_FLOAT
         self.cache_path = cache_path if len(cache_path) > 0 else "~/.cache/flexflow"
@@ -274,23 +289,14 @@ def __load_hf_weights(self):
        self.download_hf_weights_if_needed()

        # Create file data loader, load weights into tensors
-        if (
-            self.model_type == ModelType.FALCON
-            or self.model_type == ModelType.STARCODER
-        ):
-            n_q_heads = self.hf_config.num_attention_heads
-            if "n_head_kv" in self.hf_config.__dict__:
-                n_kv_heads = self.hf_config.n_head_kv
-            else:
-                n_kv_heads = 1
-        else:
-            n_q_heads = n_kv_heads = self.hf_config.num_attention_heads
+        model_configs = self.config_class(self.hf_config)
+
         self.fileloader = FileDataLoader(
             self.weights_path,
-            n_q_heads,
-            n_kv_heads,
-            self.hf_config.hidden_size,
-            self.hf_config.hidden_size // n_q_heads,
+            model_configs.num_attention_heads,
+            model_configs.num_key_value_heads,
+            model_configs.hidden_size,
+            model_configs.hidden_size // model_configs.num_attention_heads,
             self.ffconfig.tensor_parallelism_degree,
         )
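
For reference, a minimal Python sketch (not part of the patch) of the standardized config contract that __load_hf_weights now relies on: every per-model config class exposes num_attention_heads, num_key_value_heads, and hidden_size, so the FileDataLoader call needs no per-architecture branching. The FalconConfigSketch class, the getattr fallback, and the toy hf_config values below are illustrative assumptions that loosely mirror the n_head/num_attention_heads fallback added in inference/models/falcon.h; they are not FlexFlow API.

    # Sketch only: stand-in HF config plus a Falcon-style adapter exposing the
    # standardized fields consumed uniformly by the serving path.
    from types import SimpleNamespace


    class FalconConfigSketch:
        def __init__(self, hf_config):
            # Newer Falcon checkpoints use num_attention_heads/num_hidden_layers,
            # older ones use n_head/n_layer; fall back accordingly (assumption).
            self.n_head = getattr(hf_config, "n_head", None) or hf_config.num_attention_heads
            self.n_head_kv = getattr(hf_config, "n_head_kv", 1)
            self.hidden_size = hf_config.hidden_size
            # Standardized FlexFlow num heads fields below
            self.num_attention_heads = self.n_head
            self.num_key_value_heads = self.n_head_kv


    # Illustrative values only (roughly falcon-7b shaped).
    hf_config = SimpleNamespace(num_attention_heads=71, hidden_size=4544)
    cfg = FalconConfigSketch(hf_config)

    # The weight loader can now be sized the same way for every model:
    n_q_heads = cfg.num_attention_heads       # 71
    n_kv_heads = cfg.num_key_value_heads      # 1 (multi-query attention)
    head_dim = cfg.hidden_size // n_q_heads   # 4544 // 71 == 64
    print(n_q_heads, n_kv_heads, cfg.hidden_size, head_dim)

Keeping the field-name fallback inside each config class leaves serve.py free of model-specific special cases.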