Add gemma and update recent changes to multiple hosts (#74)
FanhaiLu1 authored May 9, 2024
1 parent dab2d7a commit 9c0d2ac
Showing 3 changed files with 55 additions and 12 deletions.
7 changes: 7 additions & 0 deletions jetstream_pt/ray_engine.py
@@ -152,8 +152,14 @@ def create_pytorch_ray_engine(
quantize_weights=False,
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
) -> PyTorchRayEngine:

supported_models = ["llama-2", "llama-3", "gemma"]
if model_name not in supported_models:
raise NotImplementedError(
f"Model name should be one of{','.join(supported_models)}"
)
ray.init(ignore_reinit_error=True)
pod_name = tpu.get_current_pod_name()
num_hosts = tpu.get_current_pod_worker_count()
@@ -183,6 +189,7 @@ def create_pytorch_ray_engine(
quantize_weights=quantize_weights,
quantize_kv=quantize_kv,
max_cache_length=max_cache_length,
sharding_config=sharding_config,
)
engine_workers.append(engine_worker)
engine_master = PyTorchRayEngine(
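The new guard rejects unsupported model names before ray.init() is called, and the new sharding_config argument is forwarded to every worker. A minimal sketch of a call to the updated entry point, using only keyword arguments that appear in this commit (paths are hypothetical and the full signature may require more):

from jetstream_pt import ray_engine

engine = ray_engine.create_pytorch_ray_engine(
    model_name="gemma",                     # must be one of llama-2, llama-3, gemma
    tokenizer_path="/tmp/tokenizer.model",  # hypothetical path
    ckpt_path="/tmp/gemma_checkpoint",      # hypothetical path
    bf16_enable=True,
    max_cache_length=1024,
    sharding_config=None,                   # falsy value -> worker falls back to default_shardings/gemma.yaml
)

Any other model_name raises NotImplementedError listing the supported names.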
50 changes: 38 additions & 12 deletions jetstream_pt/ray_worker.py
@@ -17,6 +17,7 @@
from typing import Any, List, Optional, Tuple, Union
import threading
import functools
import os
import humanize


@@ -39,6 +40,7 @@
from jetstream_pt import cache_manager
from jetstream_pt import quantize
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model


Mesh = jax.sharding.Mesh
@@ -103,6 +105,7 @@ def __init__(
quantize_weights=False,
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
):

jax.config.update("jax_default_prng_impl", "unsafe_rbg")
@@ -144,38 +147,61 @@ def __init__(
checkpoint_format = "safetensors"
checkpoint_path = paths[0]

if not sharding_config:
sharding_config = os.path.join("default_shardings", model_name + ".yaml")

env_data = JetEngineEnvironmentData(
tokenizer_path=tokenizer_path,
checkpoint_path=checkpoint_path,
checkpoint_format=checkpoint_format,
model_type="llama-2-" + param_size,
batch_size=batch_size,
max_decode_length=max_decode_length,
max_input_sequence_length=context_length,
enable_weight_quantization=quantize_weights,
enable_kv_quantization=quantize_kv,
cache_sequence_length=max_cache_length,
bf16_enable=bf16_enable,
sharding_config_path=sharding_config,
)
env = JetEngineEnvironment(env_data)

pt_model = None
if "llama" in model_name:
if model_name.startswith("llama"):

args = model_args.get_model_args(
model_name + "-" + param_size,
context_length,
batch_size,
bf16_enable,
model_name + "-" + param_size, context_length, batch_size, bf16_enable
)
args.device = "meta"
args.quantize = quantize_weights
env_data.cache_shape = (
batch_size,
args.n_kv_heads,
max_cache_length,
args.dim // args.n_heads,
)
env_data.model_type = "llama-2-" + param_size
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)
elif model_name == "gemma":
args = gemma_config.get_model_config(param_size)
env_data.cache_shape = (
batch_size,
args.num_key_value_heads,
max_cache_length,
args.head_dim,
)
env_data.model_type = "gemma-" + param_size
env_data.num_layers = args.num_hidden_layers
env = JetEngineEnvironment(env_data)
pt_model = gemma_model.GemmaModel(args, env)
else:
raise RuntimeError(f"Model with name {model_name} not found")

num_params_size = 0
num_params = 0
for _, v in pt_model.state_dict().items():
num_params += 1
num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
num_params_size = 0
num_params = 0
for _, v in pt_model.state_dict().items():
num_params += 1
num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
print("Number of param Gbytes:", num_params_size / (1 << 30))
print("Number of param: ", num_params)

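Inside PyTorchRayWorker, each model family now fills in its own cache shape, model type, and layer count before the environment is built, and the parameter report counts one byte per int8 element and two bytes for everything else. A condensed paraphrase of that logic, using illustrative helper names (resolve_model_env and param_footprint_gib are not part of the module's API):

import os
import numpy as np
import jax.numpy as jnp

def resolve_model_env(model_name, param_size, batch_size, max_cache_length, args, sharding_config=None):
    # Condensed paraphrase of the per-model branch in PyTorchRayWorker.__init__.
    if not sharding_config:
        # An empty or missing value falls back to the per-model default in default_shardings/.
        sharding_config = os.path.join("default_shardings", model_name + ".yaml")
    if model_name.startswith("llama"):
        cache_shape = (batch_size, args.n_kv_heads, max_cache_length, args.dim // args.n_heads)
        model_type, num_layers = "llama-2-" + param_size, args.n_layers
    elif model_name == "gemma":
        cache_shape = (batch_size, args.num_key_value_heads, max_cache_length, args.head_dim)
        model_type, num_layers = "gemma-" + param_size, args.num_hidden_layers
    else:
        raise RuntimeError(f"Model with name {model_name} not found")
    return sharding_config, cache_shape, model_type, num_layers

def param_footprint_gib(state_dict):
    # Mirrors the size report above: dtype int8 counts 1 byte per element, anything else 2 (bf16).
    total_bytes = 0
    for v in state_dict.values():
        total_bytes += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
    return total_bytes / (1 << 30)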
10 changes: 10 additions & 0 deletions run_interactive_multiple_host.py
@@ -65,6 +65,14 @@
"max_cache_length", 1024, "kv_cache_quantize"
)

_MODEL_NAME = flags.DEFINE_string(
"model_name", None, "model type", required=False
)

_SHARDING_CONFIG = flags.DEFINE_string(
"sharding_config", "", "config file for sharding"
)


def create_engine():
"""create a pytorch engine"""
@@ -73,6 +81,7 @@ def create_engine():

start = time.perf_counter()
engine = ray_engine.create_pytorch_ray_engine(
model_name=_MODEL_NAME.value,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
@@ -82,6 +91,7 @@ def create_engine():
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
sharding_config=_SHARDING_CONFIG.value,
)

print("Initialize engine", time.perf_counter() - start)
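In the driver script the two new flags feed straight into create_pytorch_ray_engine via create_engine(); an empty --sharding_config lets each worker fall back to default_shardings/<model_name>.yaml. Because model_name defaults to None with required=False, a missing flag is only caught by the engine's guard. A purely illustrative, stricter variant of the flag definition (not part of this commit):

from absl import flags

# Illustrative alternative: make the flag mandatory so a missing --model_name fails
# at flag parsing instead of inside the engine, and spell out the accepted values.
_MODEL_NAME = flags.DEFINE_string(
    "model_name",
    None,
    "model to serve: one of llama-2, llama-3, gemma",
    required=True,
)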
