[misc] use tempfile
Hz188 committed Jun 27, 2024
1 parent 3a25166 commit 502e514
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions tests/test_moe/test_moe_checkpoint.py
@@ -1,4 +1,6 @@
 import os
+import shutil
+import tempfile
 from copy import deepcopy
 
 import pytest
@@ -19,17 +21,17 @@
 hidden_size = 8
 top_k = 2
 
+# Fixed temporary directory for all ranks
+TEMP_DIR_BASE = "/tmp"
+TEMP_DIR_NAME = "mixtral_test"
 
 
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
         if not torch.equal(p1.half(), p2.half()):
             # exit distributed
             print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
             raise AssertionError(f"Model parameter {name} is not equal")
-            # dist.destroy_process_group()
-            # exit(1)
-        # print(f"Passed: {name}")
 
 
@@ -49,7 +51,6 @@ def get_optimizer_snapshot(optim):


 def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None):
-    # check param_groups
     assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
     for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
         assert set(group1.keys()) == set(group2.keys())
@@ -75,13 +76,14 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None):
                 assert state1[k] == state2[k]
         if bug:
             passed = False
-            # print(f"rank {dist.get_rank()} optim mismatch: {param2name[pid]}")
 
     if not passed:
         raise AssertionError(f"A total of {count} optim states are not equal")
 
 
 def check_mixtral_moe_layer():
+    if dist.get_rank() == 0:
+        tmpdirname = tempfile.mkdtemp(dir=TEMP_DIR_BASE, prefix=TEMP_DIR_NAME)
     torch.cuda.set_device(dist.get_rank())
     config = MixtralConfig(
         hidden_size=hidden_size,
@@ -117,28 +119,32 @@ def check_mixtral_moe_layer():
         optimizer,
     )
 
     # check save model
-    booster.save_model(model, "mixtral_model", shard=True)
+    tmpdirname = os.path.join(TEMP_DIR_BASE, TEMP_DIR_NAME)
+    model_dir = os.path.join(tmpdirname, "mixtral_model")
+    hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
+    optim_dir = os.path.join(tmpdirname, "mixtral_optim")
+
+    booster.save_model(model, model_dir, shard=True)
     dist.barrier()
     if dist.get_rank() == 0:
-        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
+        saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
         check_model_equal(orig_model, saved_model)
-        # check_model_equal(model, saved_model)
-        saved_model.save_pretrained("mixtral_hf_model")
+        saved_model.save_pretrained(hf_model_dir)
     dist.barrier()
     # check load model
     new_model = MixtralForCausalLM(config).cuda()
     new_optimizer = Adam(new_model.parameters(), lr=1e-3)
     new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
-    booster.load_model(new_model, "mixtral_hf_model")
+    booster.load_model(new_model, hf_model_dir)
     check_model_equal(model, new_model)
 
     # check save optimizer
     optimizer.step()
     for group in optimizer.param_groups:
         group["lr"] = 0.1
     snapshot = get_optimizer_snapshot(optimizer.unwrap())
-    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
+    booster.save_optimizer(optimizer, optim_dir, shard=True)
     dist.barrier()
 
     # working2master = optimizer.get_working_to_master_map()
@@ -148,16 +154,12 @@ def check_mixtral_moe_layer():
         for v in state.values():
             if isinstance(v, torch.Tensor):
                 v.zero_()
-    booster.load_optimizer(optimizer, "mixtral_optim")
+    booster.load_optimizer(optimizer, optim_dir)
     loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
     check_optimizer_snapshot_equal(snapshot, loaded_snapshot, None, model)
-
-    # Clean up
     dist.barrier()
     if dist.get_rank() == 0:
-        shutil.rmtree("mixtral_model")
-        shutil.rmtree("mixtral_hf_model")
-        shutil.rmtree("mixtral_optim")
+        shutil.rmtree(tmpdirname)
 
 
 def run_dist(rank: int, world_size: int, port: int):
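The change swaps hard-coded checkpoint directories for paths rooted in a stdlib temporary directory. A minimal standalone sketch of that pattern (directory names mirror the test; the snippet is illustrative, not part of the commit):

import os
import shutil
import tempfile

TEMP_DIR_BASE = "/tmp"
TEMP_DIR_NAME = "mixtral_test"

# mkdtemp creates a fresh directory with a unique random suffix,
# e.g. /tmp/mixtral_testab12cd, and returns its path. It is never
# deleted automatically, so the caller must clean it up.
tmpdir = tempfile.mkdtemp(dir=TEMP_DIR_BASE, prefix=TEMP_DIR_NAME)
try:
    model_dir = os.path.join(tmpdir, "mixtral_model")
    os.makedirs(model_dir, exist_ok=True)
    # ... save and reload checkpoint files under model_dir ...
finally:
    shutil.rmtree(tmpdir)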

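Because the test runs one process per rank, every rank must resolve the identical directory; the test achieves this by rebuilding the fixed path with os.path.join(TEMP_DIR_BASE, TEMP_DIR_NAME) and fencing cleanup behind dist.barrier(). If the randomized name returned by mkdtemp had to be shared instead, a common pattern is to create the directory on rank 0 and broadcast its path (a sketch assuming an initialized torch.distributed process group; shared_tmpdir is a hypothetical helper, not part of this test):

import tempfile

import torch.distributed as dist

def shared_tmpdir(prefix: str = "mixtral_test") -> str:
    # Rank 0 creates the directory; the randomized name is then
    # broadcast so every rank resolves the identical path.
    holder = [tempfile.mkdtemp(prefix=prefix) if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]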