diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 86f2d2909475..3522067b545b 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -1,4 +1,6 @@
+import os
 import shutil
+import tempfile
 from copy import deepcopy
 
 import pytest
@@ -19,17 +21,17 @@
 hidden_size = 8
 top_k = 2
 
+# Fixed temporary directory for all ranks
+TEMP_DIR_BASE = "/tmp"
+TEMP_DIR_NAME = "mixtral_test"
+
 
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
         if not torch.equal(p1.half(), p2.half()):
-            # exit distributed
             print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
             raise AssertionError(f"Model parameter {name} is not equal")
-            # dist.destroy_process_group()
-            # exit(1)
-        # print(f"Passed: {name}")
 
 
 def get_optimizer_snapshot(optim):
@@ -49,7 +51,6 @@ def get_optimizer_snapshot(optim):
 
 
 def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None):
-    # check param_groups
     assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
     for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
         assert set(group1.keys()) == set(group2.keys())
@@ -75,13 +76,14 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou
                     assert state1[k] == state2[k]
             if bug:
                 passed = False
-                # print(f"rank {dist.get_rank()} optim mismatch: {param2name[pid]}")
 
     if not passed:
         raise AssertionError(f"A total of {count} optim states are not equal")
 
 
 def check_mixtral_moe_layer():
+    if dist.get_rank() == 0:
+        tmpdirname = tempfile.mkdtemp(dir=TEMP_DIR_BASE, prefix=TEMP_DIR_NAME)
     torch.cuda.set_device(dist.get_rank())
     config = MixtralConfig(
         hidden_size=hidden_size,
@@ -117,20 +119,24 @@ def check_mixtral_moe_layer():
         optimizer,
     )
 
-    # check save model
-    booster.save_model(model, "mixtral_model", shard=True)
+    tmpdirname = os.path.join(TEMP_DIR_BASE, TEMP_DIR_NAME)
+    model_dir = os.path.join(tmpdirname, "mixtral_model")
+    hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
+    optim_dir = os.path.join(tmpdirname, "mixtral_optim")
+
+    booster.save_model(model, model_dir, shard=True)
     dist.barrier()
     if dist.get_rank() == 0:
-        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
+        saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
         check_model_equal(orig_model, saved_model)
         # check_model_equal(model, saved_model)
-        saved_model.save_pretrained("mixtral_hf_model")
+        saved_model.save_pretrained(hf_model_dir)
     dist.barrier()
     # check load model
     new_model = MixtralForCausalLM(config).cuda()
     new_optimizer = Adam(new_model.parameters(), lr=1e-3)
     new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
-    booster.load_model(new_model, "mixtral_hf_model")
+    booster.load_model(new_model, hf_model_dir)
     check_model_equal(model, new_model)
 
     # check save optimizer
@@ -138,7 +144,7 @@ def check_mixtral_moe_layer():
     for group in optimizer.param_groups:
         group["lr"] = 0.1
     snapshot = get_optimizer_snapshot(optimizer.unwrap())
-    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
+    booster.save_optimizer(optimizer, optim_dir, shard=True)
     dist.barrier()
 
     # working2master = optimizer.get_working_to_master_map()
@@ -148,16 +154,12 @@ def check_mixtral_moe_layer():
         for v in state.values():
             if isinstance(v, torch.Tensor):
                 v.zero_()
-    booster.load_optimizer(optimizer, "mixtral_optim")
+    booster.load_optimizer(optimizer, optim_dir)
     loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
     check_optimizer_snapshot_equal(snapshot, loaded_snapshot, None, model)
-
-    # Clean up
    dist.barrier()
     if dist.get_rank() == 0:
-        shutil.rmtree("mixtral_model")
-        shutil.rmtree("mixtral_hf_model")
-        shutil.rmtree("mixtral_optim")
+        shutil.rmtree(tmpdirname)
 
 
 def run_dist(rank: int, world_size: int, port: int):
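
Note: in this patch every rank derives the checkpoint paths from the fixed TEMP_DIR_BASE/TEMP_DIR_NAME location; the directory that rank 0 creates with tempfile.mkdtemp (whose randomized suffix is not visible to the other ranks) is not reused. A minimal sketch of an alternative pattern, assuming torch.distributed is already initialized; the shared_tempdir helper below is hypothetical and not part of this patch: rank 0 creates a randomized temporary directory and broadcasts the path so that all ranks save to, load from, and clean up the same location.

    import shutil
    import tempfile
    from contextlib import contextmanager

    import torch.distributed as dist


    @contextmanager
    def shared_tempdir():
        # Rank 0 creates a unique directory; the path is then broadcast so
        # every rank reads and writes the same checkpoint location.
        path = [tempfile.mkdtemp(prefix="mixtral_test") if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(path, src=0)
        try:
            yield path[0]
        finally:
            dist.barrier()  # wait until no rank is still touching the directory
            if dist.get_rank() == 0:
                shutil.rmtree(path[0])

Usage inside check_mixtral_moe_layer would then be `with shared_tempdir() as tmpdirname:`, with the os.path.join calls for model_dir, hf_model_dir, and optim_dir unchanged.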