Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix inspect grad mean/std from None to 0 #60

Merged
merged 1 commit into from
Jan 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions bmtrain/inspect/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def inspect_pipeline_transformer_block_list(pipe_model: PipelineTransformerBlock
"shape": tuple(shape),
"std": p.std().cpu().item(),
"mean": p.mean().cpu().item(),
"grad_std": None,
"grad_mean": None,
"grad_std": 0.,
"grad_mean": 0.,
"max": p.max().cpu().item(),
"min": p.min().cpu().item(),
}
Expand Down Expand Up @@ -180,8 +180,8 @@ def inspect_checkpoint_block(model : CheckpointBlock, param_name : str, prefix :
"shape": tuple(shape),
"std": p.std().cpu().item(),
"mean": p.mean().cpu().item(),
"grad_std": None,
"grad_mean": None,
"grad_std": 0.,
"grad_mean": 0.,
"max": p.max().cpu().item(),
"min": p.min().cpu().item(),
})
Expand Down Expand Up @@ -236,8 +236,8 @@ def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''):
stats["grad_std"] = g.std().cpu().item()
stats["grad_mean"] = g.mean().cpu().item()
else:
stats["grad_std"] = None
stats["grad_mean"] = None
stats["grad_std"] = 0.
stats["grad_mean"] = 0.
ret.append(stats)
for name, module in model._modules.items():
ret.extend(inspect_model(module, param_name, prefix + name + '.'))
Expand Down
1 change: 1 addition & 0 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
("init_parameters_multi_gpu", 4),

("requires_grad", 1),
("requires_grad_multi_gpu", 2),
("has_inf_nan", 1),
("dropout", 1),
("loss_func", 1),
Expand Down
3 changes: 3 additions & 0 deletions tests/test_init_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ def test_main():
manual_seed(33)
m[2] = Linear_BMTInitializer(*shape)
bmt.init_parameters(m[2])
bmt.synchronize()
ret[2] = (m[2].weight.data, m[2].bias.data)

manual_seed(33)
m[3] = Linear_ManualInitBefore(*shape)
bmt.synchronize()
ret[3] = (m[3].weight.data, m[3].bias.data)

# manual_seed(33)
Expand Down Expand Up @@ -211,6 +213,7 @@ def test_main():
print(ret[i])
for i in range(10):
for j in range(10):
print(i, j)
assert_all_eq(ret[i][0], ret[j][0])
assert_all_eq(ret[i][1], ret[j][1])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len):
inspector.get_summary()
) + "\n"

return ret.replace("None ", "0.0000") + "\n" # replace for matching None grad with zero_grad
return ret + "\n" # replace for matching None grad with zero_grad

def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
ret = ""
Expand All @@ -215,8 +215,7 @@ def test_main():
assert len(words) == len(words2)
for w, w2 in zip(words, words2):
try:
if isinstance(eval(w), float):
is_float = True
is_float = isinstance(eval(w), float)
except:
is_float = False
if is_float:
Expand Down
34 changes: 24 additions & 10 deletions tests/test_middle_hidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid
ret += bmt.inspect.format_summary(
bmt.inspect.inspect_model(m, '*')
)
return ret.replace("None ", "0.0000") + "\n" # replace for matching None grad with zero_grad
return ret + "\n" # replace for matching None grad with zero_grad

def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
ret = ""
Expand All @@ -181,16 +181,30 @@ def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
return ret

def test_main():
ret = []
ret.append( run("normal", Model_NORMAL) )
ret.append( run("block", Model_BLOCK) )
ret.append( run("zero", Model_ZERO) )
ret.append( run("pipe", Model_PIPE) )
for r in ret:
ret = {}
ret["normal"] = run("normal", Model_NORMAL)
ret["block"] = run("block", Model_BLOCK)
ret["zero"] = run("zero", Model_ZERO)
ret["pipe"] = run("pipe", Model_PIPE)
for k, r in ret.items():
bmt.print_rank(f"============={k}============")
bmt.print_rank(r)
for r in ret:
for r2 in ret:
assert_eq(r, r2)
for r in ret.values():
for r2 in ret.values():
lines, lines2 = r.split('\n'), r2.split('\n')
assert len(lines) == len(lines2)
for line, line2 in zip(lines, lines2):
words, words2 = line.split(), line2.split()
assert len(words) == len(words2)
for w, w2 in zip(words, words2):
try:
is_float = isinstance(eval(w), float)
except:
is_float = False
if is_float:
assert_lt(abs(float(w)-float(w2)), 2.)
else:
assert_eq(w, w2)

if __name__ == "__main__":
bmt.init_distributed(pipe_size=4)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_other_hidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_pre=False, only_post
ret += bmt.inspect.format_summary(
bmt.inspect.inspect_model(m, '*')
)
return ret.replace("None ", "0.0000") + "\n" # replace for matching None grad with zero_grad
return ret + "\n" # replace for matching None grad with zero_grad

def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
ret = ""
Expand Down
34 changes: 32 additions & 2 deletions tests/test_requires_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from bmtrain import config
from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList
from bmtrain.pipe_layer import PipelineTransformerBlockList
from typing import List
import torch.nn.functional as F

Expand Down Expand Up @@ -67,7 +68,36 @@ def test_main():
assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])

def test_main_pipe():
a = Linear(256, 256)
b = Linear(256, 256)
m = PipelineTransformerBlockList([CheckpointBlock(a), CheckpointBlock(b)])
bmt.init_parameters(m)

a.bias.requires_grad_(False)
awg, abg, sm1 = run(m, a, b)
print(awg, abg, sm1)
assert_eq((awg, abg), (False, True))
assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"])

a.weight.requires_grad_(False)
a.bias.requires_grad_(True)
awg, abg, sm2 = run(m, a, b)
print(awg, abg, sm2)
assert_eq((awg, abg), (False, False))
assert_eq(sm1.split('\n')[1], sm2.split('\n')[1])
assert_neq(sm1.split('\n')[2], sm2.split('\n')[2])

a.weight.requires_grad_(True)
a.bias.requires_grad_(False)
awg, abg, sm3 = run(m, a, b)
print(awg, abg, sm3)
assert_eq((awg, abg), (False, False))
assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])

if __name__ == "__main__":
bmt.init_distributed()
bmt.init_distributed(pipe_size=1)

test_main()
test_main()
test_main_pipe()
96 changes: 96 additions & 0 deletions tests/test_requires_grad_multi_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from utils import *

import bmtrain as bmt
import torch
from bmtrain import config
from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList
from bmtrain.pipe_layer import PipelineTransformerBlockList
from typing import List
import torch.nn.functional as F

class Linear(bmt.DistributedModule):
def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
super().__init__()

self.in_features = in_features
self.out_features = out_features
self.out = {}
if init_weight:
self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
else:
self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)

if init_bias:
self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
else:
self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)

def forward(self, input):
ret = F.linear(input, self.weight, self.bias)
return ret

def run(m, a, b):
inp = torch.rand((1, 10, 256)).cuda()*100
logits = m(inp)
loss = logits.sum()
loss.backward()

sm = bmt.inspect.format_summary(
bmt.inspect.inspect_model(m, '*')
)
return sm

def test_main():
a = Linear(256, 256)
b = Linear(256, 256)
m = TransformerBlockList([CheckpointBlock(a), CheckpointBlock(b)])
bmt.init_parameters(m)

a.bias.requires_grad_(False)
sm1 = run(m, a, b)
print(sm1)
assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"])

a.weight.requires_grad_(False)
a.bias.requires_grad_(True)
sm2 = run(m, a, b)
print(sm2)
assert_eq(sm1.split('\n')[1], sm2.split('\n')[1])
assert_neq(sm1.split('\n')[2], sm2.split('\n')[2])

a.weight.requires_grad_(True)
a.bias.requires_grad_(False)
sm3 = run(m, a, b)
assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])

def test_main_pipe():
a = Linear(256, 256)
b = Linear(256, 256)
m = PipelineTransformerBlockList([CheckpointBlock(a), CheckpointBlock(b)])
bmt.init_parameters(m)

a.bias.requires_grad_(False)
sm1 = run(m, a, b)
print(sm1)
assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"])

a.weight.requires_grad_(False)
a.bias.requires_grad_(True)
sm2 = run(m, a, b)
print(sm2)
assert_eq(sm1.split('\n')[1], sm2.split('\n')[1])
assert_neq(sm1.split('\n')[2], sm2.split('\n')[2])

a.weight.requires_grad_(True)
a.bias.requires_grad_(False)
sm3 = run(m, a, b)
print(sm3)
assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])

if __name__ == "__main__":
bmt.init_distributed(pipe_size=2)

test_main()
test_main_pipe()