Quantizing manually #118

Draft
wants to merge 9 commits into llama_fp8

Conversation

@rohan-tan-bhowmik (Contributor) commented Jul 25, 2024

The output when I print dataset.root_theta._tree for llama-3b:

{'token_embd': {'weight': PrimitiveTensor(token_embd.weight, [32000, 3200], torch.float16),
                'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
 'blk': {
   '0': {'attn_q':      {'weight': PrimitiveTensor(blk.0.attn_q.weight,      [3200, 3200], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'attn_k':      {'weight': PrimitiveTensor(blk.0.attn_k.weight,      [3200, 3200], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'attn_v':      {'weight': PrimitiveTensor(blk.0.attn_v.weight,      [3200, 3200], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'attn_output': {'weight': PrimitiveTensor(blk.0.attn_output.weight, [3200, 3200], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'ffn_gate':    {'weight': PrimitiveTensor(blk.0.ffn_gate.weight,    [8640, 3200], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'ffn_down':    {'weight': PrimitiveTensor(blk.0.ffn_down.weight,    [3200, 8640], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'ffn_up':      {'weight': PrimitiveTensor(blk.0.ffn_up.weight,      [8640, 3200], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'attn_norm':   {'weight': PrimitiveTensor(blk.0.attn_norm.weight,   [3200], torch.float32), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'ffn_norm':    {'weight': PrimitiveTensor(blk.0.ffn_norm.weight,    [3200], torch.float32), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
         'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
   '1' … '25': same structure as block '0' (identical shapes, dtypes, and per-tensor q_input entries; only the block index changes),
   'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
 'output_norm': {'weight': PrimitiveTensor(output_norm.weight, [3200], torch.float32), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
 'output': {'weight': PrimitiveTensor(output.weight, [32000, 3200], torch.float16), 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)},
 'q_input': DynamicScaledQuantizer(q_input) -> dtype=torch.float8_e4m3fn)}
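
For reference, a minimal sketch of how a dump like this can be produced, assuming sharktank's Dataset API; the file path is a placeholder:

from sharktank.types import Dataset

# Placeholder path for the quantized llama-3b dataset produced by this branch.
dataset = Dataset.load("/tmp/llama-3b-fp8.irpa")
# root_theta._tree is the nested dict of PrimitiveTensors and quantizers shown above.
print(dataset.root_theta._tree)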

@dan-garvey (Member) left a comment

After seeing this and thinking about it, how I would approach your task generally is:

1. Look at how llama is generated from weights under models.
2. Make a small case where you only have attn weights and generate just the attention part.
3. Play with the quantization like you're doing here.
4. Finish a small test case (see the sketch below).
5. Actually try incorporating it into a model.

@rsuderman do you think that makes sense?
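
A hypothetical, torch-only illustration of steps 2 through 4 (requires torch >= 2.1 for float8 support; the shapes, names, and per-tensor dynamic-scale scheme are illustrative, not this PR's implementation):

import torch

# Stand-ins for one attention projection from the dump above.
w = torch.randn(3200, 3200, dtype=torch.float16)
x = torch.randn(4, 3200, dtype=torch.float16)

# Dynamic per-tensor scale: map the largest |weight| onto the fp8 e4m3fn max (448.0).
scale = w.abs().max().float() / 448.0
w_fp8 = (w.float() / scale).to(torch.float8_e4m3fn)

# Eager mode has no fp8 matmul, so dequantize for the reference comparison.
ref = x.float() @ w.float().T
out = x.float() @ (w_fp8.float() * scale).T
print((ref - out).abs().max())  # expect a small quantization error, not garbage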

parser.add_argument(
    "--use-fp8-quantization",
    help="DType to use for activations in the model",
    default="false",

Check out argparse.BooleanOptionalAction or something like that; then you don't have to do all the post-process parsing like line 241.
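
A minimal sketch of that suggestion, assuming Python 3.9+ (where argparse.BooleanOptionalAction was added); the flag name mirrors the diff above:

import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction generates --use-fp8-quantization / --no-use-fp8-quantization
# and stores a real bool, so no string post-processing is needed.
parser.add_argument(
    "--use-fp8-quantization",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Quantize model activations to fp8",
)
args = parser.parse_args(["--use-fp8-quantization"])
assert args.use_fp8_quantization is True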

@@ -201,6 +201,20 @@ def pad_block_ids(self) -> torch.Tensor:
        return torch.tensor(rows, device=self.parent.model.device)


def quantize_theta(theta):
    if isinstance(theta, Theta) or isinstance(theta, dict):
        if "q_input" not in (theta._tree if isinstance(theta, Theta) else theta):

You can probably do all this isinstance stuff once at the beginning by assigning the result to a variable; much more readable.
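
A hypothetical version of that refactor (the body is elided because the diff above is truncated):

from sharktank.types import Theta  # as used elsewhere in this PR

def quantize_theta(theta):
    # Resolve the underlying dict once, instead of branching on isinstance everywhere.
    tree = theta._tree if isinstance(theta, Theta) else theta
    if not isinstance(tree, dict):
        return theta  # leaf node; nothing to recurse into
    if "q_input" not in tree:
        ...  # rest of the original logic, now operating on `tree`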
