Fix llama3 generation

satyaog committed Sep 24, 2024
1 parent 134dd47 commit 8cc4926

Showing 2 changed files with 13 additions and 11 deletions.
12 changes: 7 additions & 5 deletions benchmarks/llm/prepare.py
@@ -68,6 +68,7 @@ def generate_model(

params = json.loads(params_path.read_text())
model = llama.model.Transformer(ModelArgs(**params))
+ model.to(torch.bfloat16)
torch.save(model.state_dict(), params_path.with_name(f"consolidated.{rank:02}.pth"))

except Exception as e:
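
The one-line addition above casts the randomly initialized Transformer to bfloat16 before serializing it, so the generated consolidated.*.pth shards use the same 2-byte dtype as the released Llama 3 weights and take roughly half the disk space of float32. A minimal sketch of the pattern, using a plain nn.Linear as a stand-in for llama.model.Transformer:

import torch
import torch.nn as nn

# Stand-in for llama.model.Transformer(ModelArgs(**params)); the real model is much larger.
model = nn.Linear(4096, 4096)
model.to(torch.bfloat16)  # cast parameters in place before saving

state = model.state_dict()
assert all(t.dtype == torch.bfloat16 for t in state.values())
torch.save(state, "consolidated.00.pth")  # shard name follows the rank-numbered pattern above
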
@@ -117,11 +118,12 @@ def generate_weights(args, config):
else:
# Note that at the time of writing torchtune doesn't support multi-*.pth
# files loading
+ ctx = multiprocessing.get_context("spawn")
params_path = next(args.output_dir.glob("**/params.json"))
model_parallel_size = len(config["checkpointer"]["checkpoint_files"])
- pipes = [multiprocessing.Pipe() for _ in range(model_parallel_size)]
+ pipes = [ctx.Pipe() for _ in range(model_parallel_size)]
processes = [
- multiprocessing.Process(
+ ctx.Process(
target=generate_model,
args=[conn, params_path, rank, model_parallel_size]
)
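
Switching from the module-level multiprocessing helpers to an explicit "spawn" context starts the workers as fresh interpreters instead of fork()ed copies of a parent that may already hold CUDA state or a large model, and it keeps the Pipe and Process objects compatible because both come from the same context. A minimal sketch of the pattern (worker here is illustrative, not the real generate_model):

import multiprocessing

def worker(conn, rank):
    # The real worker builds and saves one model shard; here we just report back.
    conn.send(f"rank {rank} done")
    conn.close()

def run(world_size=2):
    ctx = multiprocessing.get_context("spawn")  # fresh interpreters, no inherited CUDA state
    pipes = [ctx.Pipe() for _ in range(world_size)]
    processes = [
        ctx.Process(target=worker, args=(child, rank))
        for rank, (parent, child) in enumerate(pipes)
    ]
    for p in processes:
        p.start()
    print([parent.recv() for parent, _ in pipes])
    for p in processes:
        p.join()

if __name__ == "__main__":  # required for the spawn start method
    run()
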
@@ -162,9 +164,9 @@ def main():

#
huggingface_format = config.get("safetensors", False)
- pretrained = not config.get("no_pretrained", False)
+ untrained = config.get("untrained", False)

- if not pretrained:
+ if untrained:
# if we will generate the weights, do not download any weights
ignore_patterns = ["*.safetensors", "*consolidated.*.pth"]

@@ -203,7 +205,7 @@ def main():
args = parser.parse_args(download_args)
parser.run(args)

- if not pretrained:
+ if untrained:
generate_weights(args, config)

if "qlora" in config.get("model", {}).get("_component_", ""):
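
The rename from no_pretrained to untrained also flips the flag into positive form: when untrained is set, the weight files are excluded from the Hugging Face download and random bfloat16 shards are generated locally instead. A condensed sketch of that control flow (download_checkpoint is a placeholder, and generate_weights is stubbed rather than the real prepare.py function):

def download_checkpoint(repo_id, ignore_patterns):
    # Placeholder for the download step driven by prepare.py's argument parser.
    print(f"downloading {repo_id}, skipping {ignore_patterns}")

def generate_weights(config):
    # Stub: the real function writes randomly initialized bfloat16 shards per rank.
    print("generating randomly initialized bfloat16 shards")

def prepare(config):
    untrained = config.get("untrained", False)

    ignore_patterns = []
    if untrained:
        # Do not download any weights; everything else still comes from the hub.
        ignore_patterns += ["*.safetensors", "*consolidated.*.pth"]

    download_checkpoint(config["repo_id"], ignore_patterns)
    if untrained:
        generate_weights(config)

prepare({"repo_id": "meta-llama/Meta-Llama-3.1-8B", "untrained": True})
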
12 changes: 6 additions & 6 deletions config/base.yaml
@@ -550,7 +550,7 @@ llm-lora-single:
repo_id="meta-llama/Meta-Llama-3.1-8B": true
batch_size=8: true
gradient_accumulation_steps=8: true
- no_pretrained=True: true
+ untrained=True: true


llm-lora-ddp-gpus:
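
In base.yaml the benchmark arguments are written as key=value: true entries, so the rename only needs to swap the key that prepare.py looks up. A minimal sketch of how such overrides could be folded into the config dict read by config.get("untrained", False) (apply_overrides is a hypothetical helper; milabench's actual plumbing may differ):

def apply_overrides(config, overrides):
    # Hypothetical helper: turn "key=value" strings into config entries.
    for item in overrides:
        key, _, raw = item.partition("=")
        config[key] = {"True": True, "False": False}.get(raw, raw)
    return config

config = apply_overrides({}, ["untrained=True", "batch_size=8"])
assert config["untrained"] is True  # prepare.py then takes the untrained code path
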
@@ -571,7 +571,7 @@ llm-lora-ddp-gpus:
repo_id="meta-llama/Meta-Llama-3.1-8B": true
batch_size=8: true
gradient_accumulation_steps=8: true
- no_pretrained=True: true
+ untrained=True: true


llm-lora-ddp-nodes:
@@ -595,7 +595,7 @@ llm-lora-ddp-nodes:
repo_id="meta-llama/Meta-Llama-3.1-8B": true
batch_size=8: true
gradient_accumulation_steps=8: true
- no_pretrained=True: true
+ untrained=True: true

num_machines: 2
requires_capabilities:
@@ -621,7 +621,7 @@ llm-lora-mp-gpus:
repo_id="meta-llama/Meta-Llama-3.1-70B": true
batch_size=8: true
gradient_accumulation_steps=1: true
- no_pretrained=True: true
+ untrained=True: true

llm-full-mp-gpus:
inherits: _llm
@@ -642,7 +642,7 @@ llm-full-mp-gpus:
safetensors=true: true
batch_size=2: true
gradient_accumulation_steps=1: true
- no_pretrained=True: true
+ untrained=True: true

llm-full-mp-nodes:
tags:
@@ -666,7 +666,7 @@ llm-full-mp-nodes:
safetensors=true: true
batch_size=2: true
gradient_accumulation_steps=1: true
- no_pretrained=True: true
+ untrained=True: true

num_machines: 2
requires_capabilities:
