From 64e5c2b63fc202480be261efa1eb62eb52587bc5 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 17 Sep 2024 11:25:33 -0700 Subject: [PATCH] partial nnx impl --- .vscode/launch.json | 70 - .vscode/settings.json | 6 - MaxText/layers/models.py | 9 +- MaxText/nnx_layers/__init__.py | 47 + MaxText/nnx_layers/attentions.py | 566 +-- MaxText/nnx_layers/embeddings.py | 43 +- MaxText/nnx_layers/linears.py | 442 +- MaxText/nnx_layers/mistral.py | 133 +- MaxText/nnx_layers/models.py | 410 +- MaxText/nnx_layers/normalizations.py | 34 +- MaxText/nnx_layers/quantizations.py | 2 +- MaxText/train.py | 12 +- testing_nnx_layers.ipynb | 6477 ++++++++++++++++++++++++++ 13 files changed, 7391 insertions(+), 860 deletions(-) delete mode 100644 .vscode/launch.json delete mode 100644 .vscode/settings.json create mode 100644 MaxText/nnx_layers/__init__.py create mode 100644 testing_nnx_layers.ipynb diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index ddd8eb0f6..000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,70 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "Debug MaxText Decode", - "type": "python", - "request": "launch", - "console": "integratedTerminal", - "justMyCode": false, - "python": "python3", - "program": "${workspaceFolder}/MaxText/decode.py", - "args": ["MaxText/configs/base.yml", - "run_name=runner_$(date +%Y-%m-%d-%H-%M)", - "base_output_directory=gs://test-maxtext-output", - "dataset_path=gs://test-maxtext-dataset", - "steps=2", - "attention=dot_product", - "enable_checkpointing=false"] - }, - { - "name": "Debug MaxText Train", - "type": "python", - "request": "launch", - "console": "integratedTerminal", - "justMyCode": false, - "python": "python3", - "program": "${workspaceFolder}/MaxText/train.py", - "args": ["MaxText/configs/base.yml", - "run_name=runner_$(date +%Y-%m-%d-%H-%M)", - "base_output_directory=gs://test-maxtext-output", - "dataset_path=gs://test-maxtext-dataset", - "steps=2", - "enable_checkpointing=false"] - }, - { - "name": "Debug MaxText Inference Microbenchmark", - "type": "python", - "request": "launch", - "console": "integratedTerminal", - "justMyCode": false, - "python": "python3", - "program": "${workspaceFolder}/MaxText/inference_microbenchmark.py", - "args": [ - "MaxText/configs/base.yml", - "model_name=llama2-7b", - "tokenizer_path=assets/tokenizer.llama2", - "weight_dtype=bfloat16", - "scan_layers=false", - "attention=dot_product", - "max_prefill_predict_length=1024", - "max_target_length=2048", - "ici_fsdp_parallelism=1", - "ici_tensor_parallelism=-1", - "ici_autoregressive_parallelism=1", - "inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024", - "inference_microbenchmark_stages=generate", - "inference_microbenchmark_loop_iters=1", - "run_name=runner_$(date +%Y-%m-%d-%H-%M)", - "base_output_directory=gs://test-maxtext-output", - "prefill_cache_axis_order=0,2,1,3", - "ar_cache_axis_order=0,2,1,3", - "compute_axis_order=0,2,1,3", - "reshape_q=true", - "per_device_batch_size=24", - "quantization=int8", - "quantize_kvcache=True", - ] - }, - ] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 661647513..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "python.testing.pytestArgs": [], - "python.testing.cwd": "${workspaceFolder}/MaxText", - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 
82cc47294..cc01d9180 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -144,7 +144,7 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
 
-    return layer_output, None if cfg.scan_layers else layer_output
+    return (layer_output, None) if cfg.scan_layers else layer_output
 
 
 class SequentialBlockDecoderLayers(nn.Module):
   """Sequential unscanned series of decoder layers."""
@@ -264,6 +264,8 @@ def __call__(
     y = self.shared_embedding(decoder_input_tokens.astype("int32"))
     y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
     y = y.astype(cfg.dtype)
+
+    self.sow("intermediates", "rdyro_shared_embedding", y)
 
     if cfg.use_untrainable_positional_embedding:
       y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions)
@@ -337,6 +339,7 @@ def __call__(
           policy=policy,
           static_argnums=(-1, -2, -3, -4, -5),
       )
+    self.sow("intermediates", "rdyro_before_scan", y)
     if cfg.using_pipeline_parallelism:
       if cfg.num_layers_per_pipeline_stage == 1:
         stage_module = BlockLayer(config=cfg, mesh=mesh, quant=self.quant)
@@ -370,6 +373,8 @@ def __call__(
             deterministic,
             model_mode,
         )
+        self.sow("intermediates", f"rdyro_after_layer_{lyr}", y)
+    self.sow("intermediates", "rdyro_after_scan", y)
 
     y = self.get_norm_layer()(
         dtype=cfg.dtype,
@@ -379,6 +384,7 @@ def __call__(
         kernel_axes=("norm",),
     )(y)
     y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
+    self.sow("intermediates", "rdyro_after_norm", y)
 
     # [batch, length, emb_dim] -> [batch, length, vocab_size]
     if cfg.logits_via_embedding:
@@ -401,6 +407,7 @@ def __call__(
     )(
         y
     )  # We do not quantize the logits matmul.
+    self.sow("intermediates", "rdyro_after_logits", logits)
     logits = nn.with_logical_constraint(logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab"))
     logits = logits.astype(jnp.float32)
     return logits
diff --git a/MaxText/nnx_layers/__init__.py b/MaxText/nnx_layers/__init__.py
new file mode 100644
index 000000000..dd0a0cb62
--- /dev/null
+++ b/MaxText/nnx_layers/__init__.py
@@ -0,0 +1,36 @@
+import jax
+from flax import nnx
+from flax import linen as nn
+from flax.core import meta
+
+def _maybe_unbox_value(x):
+  return x.unbox() if isinstance(x, meta.Partitioned) else x
+
+class LinenToNNX(nnx.Module):
+  def __init__(self, module: nn.Module, rngs=None):
+    assert rngs is not None, "You must provide `rngs=`"
+    self.linen_module = module
+    self.initialized, self.rngs, self.linen_state = False, rngs, None
+    # flax's flags to keep track of train / eval
+    self.use_running_average, self.deterministic = False, False
+
+  def __call__(self, *args, **kw):
+    if not self.initialized:
+      self.linen_state = self.linen_module.init(self.rngs(), *args, **kw)
+      self.linen_state["params"] = jax.tree.map(lambda x: nnx.Param(
+        _maybe_unbox_value(x)), self.linen_state["params"],
+        is_leaf=lambda x: isinstance(x, meta.Partitioned))
+      for key in (set(self.linen_state.keys()) - set(["params"])):
+        self.linen_state[key] = jax.tree.map(
+          lambda x: nnx.Variable(_maybe_unbox_value(x)), self.linen_state[key],
+          is_leaf=lambda x: isinstance(x, meta.Partitioned))
+      self.initialized = True
+    mutable_keys = [k for k in self.linen_state.keys() if k != "params"]
+    linen_state = 
jax.tree.map(lambda x: x.value, self.linen_state) + ret, updates = self.linen_module.apply(linen_state, *args, **kw, + mutable=mutable_keys) + #print(f"Update keys: {mutable_keys}") + if not self.use_running_average or not self.deterministic: + updates = jax.tree.map(lambda x: nnx.Variable(x), updates) + nnx.update(self, nnx.state({"linen_state": updates})) + return ret \ No newline at end of file diff --git a/MaxText/nnx_layers/attentions.py b/MaxText/nnx_layers/attentions.py index b54f668dd..04c9a52b1 100644 --- a/MaxText/nnx_layers/attentions.py +++ b/MaxText/nnx_layers/attentions.py @@ -18,8 +18,10 @@ import functools import math from typing import Any, Optional +import dataclasses from flax import linen as nn +from flax import nnx import jax from jax import lax from jax.ad_checkpoint import checkpoint_name @@ -27,17 +29,12 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask import jax.numpy as jnp -import common_types -from kernels.ragged_attention import ragged_gqa -from kernels.ragged_attention import ragged_mha -from layers import embeddings -from layers import initializers -from layers import linears -from layers import quantizations - -# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes -# pytype: disable=attribute-error +import common_types +from nnx_layers import embeddings +from nnx_layers import initializers +from nnx_layers import linears +from nnx_layers import quantizations class AttentionType(enum.Enum): @@ -75,7 +72,7 @@ class AttentionType(enum.Enum): CACHE_SCALE_SEQUENCE = common_types.CACHE_SCALE_SEQUENCE CACHE_SCALE_HEADS = common_types.CACHE_SCALE_HEADS CACHE_SCALE_KV = common_types.CACHE_SCALE_KV -DEFAULT_MASK_VALUE = common_types.DEFAULT_MASK_VALUE +DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) nd_dense_init = initializers.nd_dense_init @@ -83,6 +80,9 @@ class AttentionType(enum.Enum): dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) +# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes +# pytype: disable=attribute-error + def validate_compute_axis_order(s: AxisIdxes) -> None: valid_compute_axis_order = ((0,1,2,3), (0,2,1,3)) @@ -115,7 +115,8 @@ def apply_mask_to_logits(logits: Array, mask: Array): return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) -class AttentionOp(nn.Module): +@dataclasses.dataclass +class AttentionOp(nnx.Module): mesh: Mesh attention_kernel: str max_target_length: int @@ -127,8 +128,6 @@ class AttentionOp(nn.Module): flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV) cache_scale_logical_axis_names: AxisNames = (CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV) - ragged_qkv_axis_names: AxisNames = (CACHE_BATCH, CACHE_HEADS, CACHE_SEQUENCE, CACHE_KV) - ragged_lengths_names: AxisNames = (CACHE_BATCH,) prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3) ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3) compute_axis_order: AxisIdxes = (0, 1, 2, 3) @@ -140,8 +139,19 @@ class AttentionOp(nn.Module): attention_type: AttentionType = AttentionType.GLOBAL # Default to global attention attn_logits_soft_cap: float | None = None sliding_window_size: int | None = None - use_ragged_attention: bool = False - 
ragged_block_size: int = 256 + name: str = "" + rngs: nnx.Rngs = None + + def __post_init__(self): + # prefill cache + self.cached_key_var, self.cached_key_scale_var = None, None + self.cached_value_var, self.cached_value_scale_var = None, None + self.cached_segment_id_var = None + + # ar cache + self.cached_key_var, self.cached_key_scale_var = None, None + self.cached_value_var, self.cached_value_scale_var = None, None + self.cached_segment_id_var, self.cache_index_var = None, None def check_attention_inputs(self, query: Array, key: Array| KVTensor, value: Array| KVTensor) -> None: """Check attention inputs.""" @@ -196,15 +206,10 @@ def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None - def apply_attention(self, query: Array, key: Array | KVTensor, value: Array | KVTensor, decoder_segment_ids: Array | None, lengths: Array | None, model_mode: str, use_ragged_attention: bool = False): + def apply_attention(self, query: Array, key: Array| KVTensor, value: Array| KVTensor, decoder_segment_ids: Array | None, model_mode: str): self.check_attention_inputs(query, key, value) length = query.shape[-3] - if use_ragged_attention and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - if lengths is None: - lengths = jnp.sum(decoder_segment_ids, axis=-1) - - return self.ragged_attention(query, key, value, lengths, self.ragged_block_size) - elif ( + if ( self.attention_kernel == "dot_product" or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) or (self.attention_kernel == "autoselected" and length < 128) @@ -236,34 +241,6 @@ def apply_attention(self, query: Array, key: Array | KVTensor, value: Array | KV else: raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") - - def ragged_attention(self, query: Array, key: Array | KVTensor, value: Array | KVTensor, lengths: Array, block_size: int) -> tuple[Array, Array, Array]: - """Ragged Attention.""" - if isinstance(query, KVTensor) or isinstance(query, KVTensor): - raise TypeError("Ragged attention does not currently support quantized tensors.") - b = nn.logical_to_mesh_axes(self.ragged_lengths_names) - bsnd = nn.logical_to_mesh_axes(self.cache_logical_axis_names) - @functools.partial( - shard_map, - mesh=self.mesh, - in_specs=( - bsnd, - bsnd, - bsnd, - b, - None, - ), - out_specs=bsnd, - check_rep=False, - ) - def wrap_ragged_attention(query, key, value, lengths, block_size): - if query.shape[-2] == key.shape[-2]: - return ragged_mha(query, key, value, lengths, block_size=block_size) - else: - return ragged_gqa(query, key, value, lengths, block_size=block_size) - - return wrap_ragged_attention(query, key, value, lengths, block_size) - def tpu_flash_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None, attn_logits_soft_cap: float | None = None) -> Array: """TPU Flash Attention.""" # Transpose to ('batch', 'heads', 'length', 'kv') @@ -530,140 +507,112 @@ def _get_cache_scale_logical_shape(self, batch, heads): def _get_prefill_cache_vars(self, batch, heads, kv_head_size): + if (self.cached_key_var is None or self.cached_key_scale_var is None): + dtype = self._get_cached_kv_dtype(self.dtype) + cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size) - dtype = self._get_cached_kv_dtype(self.dtype) - cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size) + cache_axis_names = 
self.transpose_tuple(self.cache_logical_axis_names, self.prefill_cache_axis_order)
+      cache_shape = self.transpose_tuple(cache_logical_shape, self.prefill_cache_axis_order)
-    cache_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.prefill_cache_axis_order)
-    cache_shape = self.transpose_tuple(cache_logical_shape, self.prefill_cache_axis_order)
-
-    cached_key_var = self.variable(
-        "cache",
-        "cached_prefill_key",
-        nn.with_logical_partitioning(jnp.zeros, cache_axis_names),
-        cache_shape,
-        dtype,
-    )
-    cached_value_var = self.variable(
-        "cache",
-        "cached_prefill_value",
-        nn.with_logical_partitioning(jnp.zeros, cache_axis_names),
-        cache_shape,
-        dtype,
-    )
-    cached_segment_id_var = self.variable(
-        "cache",
-        "cache_prefill_segment_id",
-        nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)),
-        (cache_logical_shape[0], self.max_prefill_predict_length),
-        jnp.int32,
-    )
-
-    if self.kv_quant:
-      cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
-      cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.prefill_cache_axis_order)
-      cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.prefill_cache_axis_order)
-
-      cached_key_scale_var = self.variable(
-          "cache",
-          "cached_prefill_key_scale",
-          nn.with_logical_partitioning(jnp.zeros, cache_scale_axis_names),
-          cache_scale_shape,
-          jnp.bfloat16,
+      self.cached_key_var = nnx.Cache(
+        "cached_prefill_key",
+        nnx.with_partitioning(jnp.zeros, cache_axis_names)(cache_shape, dtype)
       )
-      cached_value_scale_var = self.variable(
-          "cache",
-          "cached_prefill_value_scale",
-          nn.with_logical_partitioning(jnp.zeros, cache_scale_axis_names),
-          cache_scale_shape,
-          jnp.bfloat16,
+      self.cached_value_var = nnx.Cache(
+        "cached_prefill_value",
+        nnx.with_partitioning(jnp.zeros, cache_axis_names)(cache_shape, dtype)
+      )
+      self.cached_segment_id_var = nnx.Cache(
+        "cache_prefill_segment_id",
+        nnx.with_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE))(
+          (cache_logical_shape[0], self.max_prefill_predict_length),
+          jnp.int32,
+        )
       )
-    else:
-      cached_key_scale_var = None
-      cached_value_scale_var = None
-    key_vars = (cached_key_var, cached_key_scale_var)
-    value_vars = (cached_value_var, cached_value_scale_var)
-    return key_vars, value_vars, cached_segment_id_var
+      if self.kv_quant:
+        cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
+        cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.prefill_cache_axis_order)
+        cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.prefill_cache_axis_order)

-  def _get_ar_cache_vars(self, batch, heads, kv_head_size):
+        self.cached_key_scale_var = nnx.Cache(
+          "cached_prefill_key_scale",
+          nnx.with_partitioning(jnp.zeros, cache_scale_axis_names)(cache_scale_shape, jnp.bfloat16)
+        )
+        self.cached_value_scale_var = nnx.Cache(
+          "cached_prefill_value_scale",
+          nnx.with_partitioning(jnp.zeros, cache_scale_axis_names)(cache_scale_shape, jnp.bfloat16)
+        )
+      else:
+        self.cached_key_scale_var = None
+        self.cached_value_scale_var = None

-    dtype = self._get_cached_kv_dtype(self.dtype)
-    cache_length = self.max_target_length - self.max_prefill_predict_length
-    cache_logical_shape = (batch, cache_length, heads, kv_head_size)
+    key_vars = (self.cached_key_var, self.cached_key_scale_var)
+    value_vars = (self.cached_value_var, self.cached_value_scale_var)
+    return key_vars, value_vars, self.cached_segment_id_var

-    cache_axis_names = 
self.transpose_tuple(self.cache_logical_axis_names, self.ar_cache_axis_order) - cache_shape = self.transpose_tuple(cache_logical_shape, self.ar_cache_axis_order) + def _get_ar_cache_vars(self, batch, heads, kv_head_size): - # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding - cached_key_var = self.variable( - "cache", - "cached_ar_key", - nn.with_logical_partitioning(jnp.zeros, cache_axis_names), - cache_shape, - dtype, - ) - cached_key_var.value = nn.with_logical_constraint( - cached_key_var.value, - cache_axis_names, - ) + if self.cached_key_var is None or self.cached_value_var is None: + dtype = self._get_cached_kv_dtype(self.dtype) + cache_length = self.max_target_length - self.max_prefill_predict_length + cache_logical_shape = (batch, cache_length, heads, kv_head_size) - cached_value_var = self.variable( - "cache", - "cached_ar_value", - nn.with_logical_partitioning(jnp.zeros, cache_axis_names), - cache_shape, - dtype, - ) - cached_value_var.value = nn.with_logical_constraint( - cached_value_var.value, - cache_axis_names, - ) + cache_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.ar_cache_axis_order) + cache_shape = self.transpose_tuple(cache_logical_shape, self.ar_cache_axis_order) - cached_segment_id_var = self.variable( - "cache", - "cache_ar_segment_id", - nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)), - (cache_logical_shape[0], cache_length), - jnp.int32, - ) + self.cached_key_var = nnx.Cache( + "cached_ar_key", + nnx.with_partitioning(jnp.zeros, cache_axis_names)(cache_shape, dtype) + ) + # # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding + #cached_key_var.value = nn.with_logical_constraint( + # cached_key_var.value, + # cache_axis_names, + #) + + self.cached_value_var = nnx.Cache( + "cached_ar_value", + nnx.with_partitioning(jnp.zeros, cache_axis_names)(cache_shape, dtype) + ) + #cached_value_var.value = nn.with_logical_constraint( + # cached_value_var.value, + # cache_axis_names, + #) + + self.cached_segment_id_var = nnx.Cache( + "cache_ar_segment_id", + nnx.with_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE))( + (cache_logical_shape[0], cache_length), jnp.int32) + ) - cached_lengths_var = self.variable( - "cache", - "cached_ar_lengths", - nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, )), - (cache_logical_shape[0], ), - jnp.int32, - ) + if self.kv_quant: + cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads) + cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order) + cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.ar_cache_axis_order) - if self.kv_quant: - cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads) - cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order) - cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.ar_cache_axis_order) - - cached_key_scale_var = self.variable( - "cache", - "cached_ar_key_scale", - nn.with_logical_partitioning(jnp.zeros, cache_scale_axis_names), - cache_scale_shape, - jnp.bfloat16, - ) - cached_value_scale_var = self.variable( - "cache", - "cached_ar_value_scale", - nn.with_logical_partitioning(jnp.zeros, cache_scale_axis_names), - cache_scale_shape, - jnp.bfloat16, + self.cached_key_scale_var = nnx.Cache( + "cached_ar_key_scale", + nnx.with_partitioning(jnp.zeros, 
cache_scale_axis_names)(cache_scale_shape, jnp.bfloat16) + ) + self.cached_value_scale_var = nnx.Cache( + "cached_ar_value_scale", + nnx.with_partitioning(jnp.zeros, cache_scale_axis_names)(cache_scale_shape, jnp.bfloat16) + ) + else: + self.cached_key_scale_var = None + self.cached_value_scale_var = None + + self.cache_index_var = nnx.Cache( + "cache_ar_index", + nnx.with_partitioning(jnp.zeros, ())((1,), jnp.int32) ) - else: - cached_key_scale_var = None - cached_value_scale_var = None - cache_index_var = self.variable( - "cache", "cache_ar_index", nn.with_logical_partitioning(jnp.zeros, ()), (1,), jnp.int32) - key_vars = (cached_key_var, cached_key_scale_var) - value_vars = (cached_value_var, cached_value_scale_var) - return key_vars, value_vars, cached_segment_id_var, cache_index_var, cached_lengths_var + key_vars = (self.cached_key_var, self.cached_key_scale_var) + value_vars = (self.cached_value_var, self.cached_value_scale_var) + return key_vars, value_vars, self.cached_segment_id_var, self.cache_index_var def kv_cache_prefill( self, @@ -717,8 +666,6 @@ def update_ar_key_value( cached_key_vars: tuple[nn.Variable, nn.Variable | None], cached_value_vars: tuple[nn.Variable, nn.Variable | None], one_hot_indices: Array, - lengths: Array, - use_ragged_attention: bool, ) -> None: """Adds a single token's results to the ar kv cache @@ -748,42 +695,17 @@ def update_ar_key_value( one_token_value_shaped_for_cache, one_token_value_scale_shaped_for_cache = self.kv_quant.quantize( one_token_value_shaped_for_cache, ar_cache_axis_names) - + one_hot_indices = one_hot_indices.astype(int) ar_cache_update_idx = jnp.squeeze(one_hot_indices) - ar_cache_sequence_axis = ar_cache_update_axis = ar_cache_axis_names.index(CACHE_SEQUENCE) - ar_cache_batch_axis = ar_cache_axis_names.index(CACHE_BATCH) - - if use_ragged_attention: - cache_locations = [slice(None)] * 4 - new_token_locations = [slice(None)] * 4 - new_token_locations[ar_cache_sequence_axis] = 0 - - def key_body(i, val): - cache_locations[ar_cache_batch_axis] = i - cache_locations[ar_cache_sequence_axis] = lengths[i] - new_token_locations[ar_cache_batch_axis] = i - return val.at[tuple(cache_locations)].set(one_token_key_shaped_for_cache[tuple(new_token_locations)]) - - def value_body(i, val): - cache_locations[ar_cache_batch_axis] = i - cache_locations[ar_cache_sequence_axis] = lengths[i] - new_token_locations[ar_cache_batch_axis] = i - return val.at[tuple(cache_locations)].set(one_token_value_shaped_for_cache[tuple(new_token_locations)]) - - cached_key_var.value = jax.lax.fori_loop(0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key_var.value, unroll=8) - cached_value_var.value = jax.lax.fori_loop(0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value_var.value, unroll=8) - - else: - one_hot_indices = one_hot_indices.astype(int) - cached_key_var.value = jax.lax.dynamic_update_index_in_dim( - cached_key_var.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis) - cached_value_var.value = jax.lax.dynamic_update_index_in_dim( - cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis) + ar_cache_update_axis = ar_cache_axis_names.index(CACHE_SEQUENCE) + cached_key_var.value = jax.lax.dynamic_update_index_in_dim( + cached_key_var.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis) cached_key_var.value = nn.with_logical_constraint(cached_key_var.value, ar_cache_axis_names) + cached_value_var.value = 
jax.lax.dynamic_update_index_in_dim( + cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis) cached_value_var.value = nn.with_logical_constraint(cached_value_var.value, ar_cache_axis_names) - if self.kv_quant: ar_cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order) ar_cache_scale_update_axis = ar_cache_scale_axis_names.index(CACHE_SCALE_SEQUENCE) @@ -818,7 +740,6 @@ def kv_cache_autoregressive( self, key: Array, value: Array, - use_ragged_attention: bool = False, ): """In autoregressive mode, we update the cache for this entry and then return the full cache. @@ -840,15 +761,14 @@ def kv_cache_autoregressive( if not is_initialized: raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") - cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, cache_ar_index_var, cache_ar_lengths_var = self._get_ar_cache_vars(batch, heads, kv_head_size) + cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, cache_ar_index_var = self._get_ar_cache_vars(batch, heads, kv_head_size) - self.update_ar_key_value(key, value, cached_ar_key_vars, cached_ar_value_vars, cache_ar_index_var.value, cache_ar_lengths_var.value, use_ragged_attention) + self.update_ar_key_value(key, value, cached_ar_key_vars, cached_ar_value_vars, cache_ar_index_var.value) active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR cached_ar_segment_id_var.value = jax.lax.dynamic_update_index_in_dim( cached_ar_segment_id_var.value, active_indicator, jnp.squeeze(cache_ar_index_var.value), 1 ) cache_ar_index_var.value = jnp.mod(cache_ar_index_var.value + 1, self.max_target_length - self.max_prefill_predict_length) - cache_ar_lengths_var.value = cache_ar_lengths_var.value.at[:].add(1) # The below retrieves the existing prefill cache variables, not creating new ones cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars(batch, heads, kv_head_size) @@ -863,11 +783,10 @@ def kv_cache_autoregressive( self.get_cached_values(cached_ar_key_vars, key.dtype, self.ar_cache_axis_order), self.get_cached_values(cached_ar_value_vars, value.dtype, self.ar_cache_axis_order), cached_ar_segment_id_var.value, - cache_ar_lengths_var.value ) return cached_prefill, cached_ar - def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str, use_ragged_attention: bool = False) -> tuple: + def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str) -> tuple: """KV cache takes the current state and updates the state accordingly. The key and value have dimension [b, s, n_kv, d], @@ -893,7 +812,7 @@ def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_m elif model_mode == common_types.MODEL_MODE_PREFILL: return self.kv_cache_prefill(key, value, decoder_segment_ids), None elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - return self.kv_cache_autoregressive(key, value, use_ragged_attention) + return self.kv_cache_autoregressive(key, value) else: raise ValueError(f"Model Mode isn't supported! 
{model_mode=}") @@ -920,18 +839,16 @@ def normalize_attention(self, local_outs, local_maxes, local_sums): attn_out += local_normalizer * local_out return attn_out - @nn.compact + #@nn.compact def __call__(self, query, key, value, decoder_segment_ids, model_mode): - prefill_kv_cache, ar_kv_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention) + prefill_kv_cache, ar_kv_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode) prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention( query=query, key=prefill_kv_cache[0], value=prefill_kv_cache[1], decoder_segment_ids=prefill_kv_cache[2], - lengths=None, model_mode=model_mode, - use_ragged_attention=self.use_ragged_attention, ) # Return the "prefill" cache if it actually the combined prefill+ar kv cache @@ -945,21 +862,18 @@ def __call__(self, query, key, value, decoder_segment_ids, model_mode): key=ar_kv_cache[0], value=ar_kv_cache[1], decoder_segment_ids=ar_kv_cache[2], - lengths=ar_kv_cache[3], model_mode=model_mode, - use_ragged_attention=self.use_ragged_attention, ) - if ar_unnormalized_output is not None: - unnormalized_outputs = [prefill_unnormalized_output, ar_unnormalized_output] - exponentials_maxes = [prefill_exponentials_max, ar_exponentials_max] - exponentials_sums = [prefill_exponentials_sum, ar_exponentials_sum] - return self.normalize_attention(unnormalized_outputs, exponentials_maxes, exponentials_sums) - else: - return prefill_unnormalized_output / prefill_exponentials_sum + unnormalized_outputs = [prefill_unnormalized_output, ar_unnormalized_output] + exponentials_maxes = [prefill_exponentials_max, ar_exponentials_max] + exponentials_sums = [prefill_exponentials_sum, ar_exponentials_sum] + return self.normalize_attention(unnormalized_outputs, exponentials_maxes, exponentials_sums) +#################################################################################################### -class Attention(nn.Module): +@dataclasses.dataclass +class Attention(nnx.Module): """Generic Attention. Attributes: @@ -1003,8 +917,6 @@ class Attention(nn.Module): attention_type: AttentionType = AttentionType.GLOBAL # Default to global attention attn_logits_soft_cap: float | None = None sliding_window_size: int | None = None - use_ragged_attention: bool = False - ragged_block_size: int = 256 # Shard the query activation as the same as the key and value. # TODO: Find a better sharding axis name. @@ -1017,22 +929,38 @@ class Attention(nn.Module): ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3) compute_axis_order: AxisIdxes = (0, 1, 2, 3) reshape_q: bool = False + + name: str = "" + rngs: nnx.Rngs = None - - def query_projection(self, inputs_q: Array) -> Array: - """Query projection.""" - - # NOTE: T5 does not explicitly rescale the attention logits by - # 1/sqrt(depth_kq)! This is folded into the initializers of the - # linear transformations, which is equivalent under Adafactor. - depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - - def query_init(*args): - # pylint: disable=no-value-for-parameter - return self.kernel_init(*args) / depth_scaling - - query_proj = DenseGeneral( - features=(self.num_query_heads, self.head_dim), + def __post_init__(self): + # apply projection. 
+ if self.config.fused_qkv: + self.qkv_proj = DenseGeneral( + in_features=self.config.emb_dim, + out_features=(3, self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="qkv_proj", + quant=self.quant, + rngs=self.rngs, + ) + else: + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + + def query_init(*args): + # pylint: disable=no-value-for-parameter + return self.kernel_init(*args) / depth_scaling + + self.q_proj = DenseGeneral( + in_features=self.config.emb_dim, + out_features=(self.num_query_heads, self.head_dim), axis=-1, kernel_init=query_init, kernel_axes=("embed", "heads", "kv"), @@ -1040,63 +968,78 @@ def query_init(*args): weight_dtype=self.weight_dtype, name="query", quant=self.quant, - matmul_precision=self.config.matmul_precision, - )(inputs_q) - return query_proj - - def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: - """Projection for Key and Value. - - Args: - inputs_kv: inputs_kv: key/values of shape `[batch, kv_length, - num_kv_heads, kv_dim]`. - proj_name: name of projection, `key` or `value`. - - Returns: - Projection of key or value, in shape of `[batch, kv_length, head_dim]`. - """ - if self.num_kv_heads == -1: - raise ValueError("num_kv_heads is not defined.") + rngs=self.rngs, + ) - if self.num_query_heads % self.num_kv_heads != 0: - raise ValueError("Invalid num_kv_heads for GQA.") + if self.num_kv_heads == -1: + raise ValueError("num_kv_heads is not defined.") - kernel_axes = ("embed", "kv_heads", "kv_head_dim") + if self.num_query_heads % self.num_kv_heads != 0: + raise ValueError("Invalid num_kv_heads for GQA.") - kv_proj = DenseGeneral( - features=(self.num_kv_heads, self.head_dim), + kernel_axes = ("embed", "kv_heads", "kv_head_dim") + self.k_proj = DenseGeneral( + in_features=self.config.emb_dim, + out_features=(self.num_kv_heads, self.head_dim), axis=-1, kernel_init=self.kernel_init, kernel_axes=kernel_axes, dtype=self.dtype, weight_dtype=self.weight_dtype, - name=proj_name, + name="key", quant=self.quant, - matmul_precision=self.config.matmul_precision, - )(inputs_kv) - return kv_proj - - def qkv_projection(self, inputs: Array, proj_name: str): - """Fused QKV projection""" - - qkv_proj = DenseGeneral( - features=(3, self.num_query_heads, self.head_dim), + rngs=self.rngs, + ) + self.v_proj = DenseGeneral( + in_features=self.config.emb_dim, + out_features=(self.num_kv_heads, self.head_dim), axis=-1, kernel_init=self.kernel_init, - kernel_axes=("embed", "qkv", "heads", "kv"), + kernel_axes=kernel_axes, dtype=self.dtype, weight_dtype=self.weight_dtype, - name=proj_name, + name="value", quant=self.quant, - matmul_precision=self.config.matmul_precision, - )(inputs) - qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") - query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] 
- return query, key, value + rngs=self.rngs, + ) + + # apply ROPE + self.query_rotary = RotaryEmbedding(min_timescale=self.config.rope_min_timescale, + max_timescale = self.config.rope_max_timescale, + embedding_dims=self.head_dim, name="query_rotary", + rngs=self.rngs) + self.key_rotary = RotaryEmbedding(min_timescale=self.config.rope_min_timescale, + max_timescale = self.config.rope_max_timescale, + embedding_dims=self.head_dim, name="key_rotary", + rngs=self.rngs) - def out_projection(self, output_dim: int, out: Array) -> Array: - out_proj = DenseGeneral( - features=output_dim, + assert not self.config.quantize_kvcache or self.kv_quant + self.attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + kv_quant=self.kv_quant, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + prefill_cache_axis_order=self.prefill_cache_axis_order, + ar_cache_axis_order=self.ar_cache_axis_order, + compute_axis_order=self.compute_axis_order, + reshape_q=self.reshape_q, + attention_type=self.attention_type, + attn_logits_soft_cap=self.attn_logits_soft_cap, + sliding_window_size=self.sliding_window_size, + rngs=self.rngs, + ) + + self.out_proj = DenseGeneral( + in_features=(self.num_query_heads, self.head_dim), + out_features=self.config.emb_dim, axis=(-2, -1), kernel_init=self.kernel_init, kernel_axes=("heads", "kv", "embed"), @@ -1104,17 +1047,16 @@ def out_projection(self, output_dim: int, out: Array) -> Array: weight_dtype=self.weight_dtype, name="out", quant=self.quant, - matmul_precision=self.config.matmul_precision, - )(out) - return out_proj + rngs=self.rngs, + ) - def key_rotary(self, key: Array, inputs_positions: Array): - """Apply Rotary Embedding to key.""" - key = RotaryEmbedding(min_timescale=self.config.rope_min_timescale, max_timescale = self.config.rope_max_timescale, - embedding_dims=self.head_dim, name="key_rotary")(inputs=key, position=inputs_positions) - return key + def qkv_projection(self, inputs: Array): + """Fused QKV projection""" + qkv_proj = self.qkv_proj(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] + return query, key, value - @nn.compact def __call__( self, inputs_q: Array, @@ -1149,15 +1091,14 @@ def __call__( """ # apply projection. if self.config.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") + query, key, value = self.qkv_projection(inputs_q) else: - query = self.query_projection(inputs_q) - key = self.kv_projection(inputs_kv, proj_name="key") - value = self.kv_projection(inputs_kv, proj_name="value") + query = self.q_proj(inputs_q) + key = self.k_proj(inputs_kv) + value = self.v_proj(inputs_kv) # apply ROPE - query = RotaryEmbedding(min_timescale=self.config.rope_min_timescale, max_timescale = self.config.rope_max_timescale, - embedding_dims=self.head_dim, name="query_rotary")(inputs=query, position=inputs_positions) + query = self.query_rotary(inputs=query, position=inputs_positions) key = self.key_rotary(key, inputs_positions) # annotate with sharding constraint. 
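The hunk below completes the central mechanical change of this port: the `AttentionOp` that linen built inline under `@nn.compact` is now the pre-constructed `self.attention_op`. As a minimal self-contained sketch of the idiom (`TinyBlock` and its sizes are illustrative, not MaxText code), an NNX module builds its submodules eagerly at construction time and `__call__` only applies them:

```python
import jax.numpy as jnp
from flax import nnx

class TinyBlock(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    # built once up front: the NNX analogue of defining a layer
    # inside linen's @nn.compact __call__
    self.proj = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)

y = TinyBlock(8, 16, rngs=nnx.Rngs(0))(jnp.ones((2, 8)))  # shape (2, 16)
```

This is also why the patch threads `rngs` through every constructor: parameters are materialized eagerly, so random keys are needed at `__post_init__` time rather than at the first call.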
@@ -1168,36 +1109,11 @@ def __call__( value = nn.with_logical_constraint(value, self.value_axis_names) value = checkpoint_name(value, "value_proj") - assert not self.config.quantize_kvcache or self.kv_quant - attention_op = AttentionOp( - mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - kv_quant=self.kv_quant, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - dropout_rate=self.dropout_rate, - dtype=self.dtype, - prefill_cache_axis_order=self.prefill_cache_axis_order, - ar_cache_axis_order=self.ar_cache_axis_order, - compute_axis_order=self.compute_axis_order, - reshape_q=self.reshape_q, - attention_type=self.attention_type, - attn_logits_soft_cap=self.attn_logits_soft_cap, - sliding_window_size=self.sliding_window_size, - use_ragged_attention=self.use_ragged_attention, - ragged_block_size=self.ragged_block_size, - ) - - out = attention_op(query, key, value, decoder_segment_ids, model_mode) + out = self.attention_op(query, key, value, decoder_segment_ids, model_mode) out = nn.with_logical_constraint(out, self.out_axis_names) # apply output projection, output dim is set to the input dim. - out = self.out_projection(inputs_q.shape[-1], out) + out = self.out_proj(out) out = checkpoint_name(out, "out_proj") return out diff --git a/MaxText/nnx_layers/embeddings.py b/MaxText/nnx_layers/embeddings.py index 1be1ce7a4..9a08eb649 100644 --- a/MaxText/nnx_layers/embeddings.py +++ b/MaxText/nnx_layers/embeddings.py @@ -15,12 +15,15 @@ """Embedding Layers.""" from typing import Any, Optional +import dataclasses from flax import linen as nn import jax from jax import lax import jax.numpy as jnp from layers import initializers +from flax import nnx +from flax.core.meta import Partitioned Config = Any Array = jnp.ndarray @@ -33,7 +36,8 @@ _MAX_WAVELENGTH = 10_000 -class Embed(nn.Module): +@dataclasses.dataclass +class Embed(nnx.Module): """A parameterized function from integers [0, n) to d-dimensional vectors. Attributes: @@ -51,14 +55,13 @@ class Embed(nn.Module): dtype: DType = jnp.float32 attend_dtype: Optional[DType] = None embedding_init: Initializer = default_embed_init - - def setup(self): - self.embedding = self.param( - "embedding", - with_logical_partitioning(self.embedding_init, ("vocab", "embed")), - (self.num_embeddings, self.features), - self.config.weight_dtype, - ) + rngs: nnx.Rngs = None + + def __post_init__(self): + value = nnx.with_partitioning(self.embedding_init, ("vocab", "embed"))( + self.rngs(), (self.num_embeddings, self.features), + self.config.weight_dtype) + self.embedding = nnx.Param(value) def __call__(self, inputs: Array) -> Array: """Embeds the inputs along the last dimension. 
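For reference, the parameter-creation idiom used in `Embed.__post_init__` above, as a minimal sketch (the vocab/feature sizes are made up, and `normal(1.0)` stands in for `default_embed_init`):

```python
import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)
# Wrap the initializer so the created array carries sharding metadata,
# then store it as an nnx.Param instead of registering it via linen's
# self.param.
init = nnx.with_partitioning(nnx.initializers.normal(1.0), ("vocab", "embed"))
embedding = nnx.Param(init(rngs.params(), (1000, 64), jnp.float32))
```

Note that `__call__` then reads `self.embedding.value` explicitly, where the linen version read the bare `self.embedding` array.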
@@ -75,13 +78,15 @@ def __call__(self, inputs: Array) -> Array: inputs = inputs.astype(self.cast_input_dtype) if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError("Input type must be an integer or unsigned integer.") + + embedding = self.embedding.value if cfg.use_iota_embed: iota = lax.iota(jnp.int32, self.num_embeddings) one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) - output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) + output = jnp.dot(one_hot, jnp.asarray(embedding, self.dtype)) else: - output = jnp.asarray(self.embedding, self.dtype)[inputs] + output = jnp.asarray(embedding, self.dtype)[inputs] output = nn.with_logical_constraint(output, ("activation_embed_and_logits_batch", "activation_length", "activation_embed")) return output @@ -102,7 +107,8 @@ def attend(self, query: Array) -> Array: return jnp.dot(query, jnp.asarray(self.embedding, jnp.bfloat16).T) -class RotaryEmbedding(nn.Module): +@dataclasses.dataclass +class RotaryEmbedding(nnx.Module): """RoPE Attributes: @@ -118,8 +124,10 @@ class RotaryEmbedding(nn.Module): embedding_dims: int = 0 cast_as_fprop_dtype: bool = True fprop_dtype: DType = jnp.bfloat16 - - def setup(self) -> None: + name: str = "" + rngs: nnx.Rngs = None + + def __post_init__(self) -> None: if self.embedding_dims % 2: raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.") @@ -166,9 +174,14 @@ def __call__( return x_out -class PositionalEmbedding(nn.Module): +@dataclasses.dataclass +class PositionalEmbedding(nnx.Module): embedding_dims: int max_wavelength: int = _MAX_WAVELENGTH + rngs: nnx.Rngs | None = None + + def __post_init__(self): + pass def __call__( self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks diff --git a/MaxText/nnx_layers/linears.py b/MaxText/nnx_layers/linears.py index 496436a3e..efb15f1b6 100644 --- a/MaxText/nnx_layers/linears.py +++ b/MaxText/nnx_layers/linears.py @@ -17,14 +17,16 @@ import functools import operator from typing import Any, Callable, Iterable, Sequence, Tuple, Union, Optional +import dataclasses import flax.linen as nn +from flax import nnx import jax from jax import lax import jax.numpy as jnp import common_types -from layers import initializers -from layers import normalizations +from nnx_layers import initializers +from nnx_layers import normalizations from layers import quantizations import numpy as np from jax.ad_checkpoint import checkpoint_name @@ -76,10 +78,12 @@ def _canonicalize_tuple(x): return (x,) -class DenseGeneral(nn.Module): +@dataclasses.dataclass +class DenseGeneral(nnx.Module): """A linear transformation with flexible axes. Attributes: + in_features: tuple with input features sizes. features: tuple with numbers of output features. axis: tuple with axes to apply the transformation on. weight_dtype: the dtype of the weights (default: float32). @@ -89,7 +93,8 @@ class DenseGeneral(nn.Module): quant: quantization config, defaults to None implying no quantization. """ - features: Union[Iterable[int], int] + in_features: Union[Iterable[int], int] + out_features: Union[Iterable[int], int] axis: Union[Iterable[int], int] = -1 weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 @@ -97,9 +102,40 @@ class DenseGeneral(nn.Module): kernel_axes: Tuple[str, ...] 
= ()
   quant: Optional[Quant] = None
   use_bias: bool = False
-  matmul_precision: str = "default"
+  name: str = ""
+  rngs: nnx.Rngs = None
+
+  def __post_init__(self):
+    in_features = _canonicalize_tuple(self.in_features)
+    out_features = _canonicalize_tuple(self.out_features)
+    axis = _canonicalize_tuple(self.axis)
+
+    kernel_shape = in_features + out_features
+    kernel_in_axis = np.arange(len(axis))
+    kernel_out_axis = np.arange(len(axis), len(axis) + len(out_features))
+
+    if quantizations.in_serve_mode(self.quant):
+      # During aqt convert state we delete kernel weight from params to save memory.
+      # Instead they are retrieved from the tensors stored in the 'aqt' collection.
+      # kernel = jnp.zeros(kernel_shape)
+      self.kernel = None
+    else:
+      self.kernel = nnx.Param(
+        nnx.with_partitioning(self.kernel_init, self.kernel_axes)(
+          self.rngs(),
+          kernel_shape,
+          self.weight_dtype,
+          kernel_in_axis,
+          kernel_out_axis,
+        )
+      )
+
+    if self.use_bias:
+      bias_axes, bias_shape = self.kernel_axes[-len(out_features) :], kernel_shape[-len(out_features) :]
+      # bias_init needs an rng key, just like the kernel initializer above
+      self.bias = nnx.Param(nnx.with_partitioning(bias_init, bias_axes)(
+        self.rngs(), bias_shape, self.weight_dtype))
+
-  @nn.compact
   def __call__(self, inputs: Array) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
@@ -113,54 +149,26 @@ def __call__(self, inputs: Array) -> Array:
     def compute_dot_general(inputs, kernel, axis, contract_ind):
       """Computes a dot_general operation that may be quantized."""
       dot_general = lax.dot_general
-      matmul_precision = lax.Precision(self.matmul_precision)
       if self.quant:
         dot_general_cls = self.quant.dot_general_cls(mesh_axes=self.kernel_axes)
         dot_general = dot_general_cls()
-        return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None)
-      return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision)
-
-    features = _canonicalize_tuple(self.features)
-    axis = _canonicalize_tuple(self.axis)
+      return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None)

     inputs = jnp.asarray(inputs, self.dtype)
-    axis = _normalize_axes(axis, inputs.ndim)
-
-    kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
-    kernel_in_axis = np.arange(len(axis))
-    kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
-    if quantizations.in_serve_mode(self.quant):
-      # During aqt convert state we delete kernel weight from params to save memory.
-      # Instead they are retrieved from the tensors stored in the 'aqt' collection.
-      kernel = jnp.zeros(kernel_shape)
-    else:
-      kernel = self.param(
-          "kernel",
-          nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
-          kernel_shape,
-          self.weight_dtype,
-          kernel_in_axis,
-          kernel_out_axis,
-      )
-    kernel = jnp.asarray(kernel, self.dtype)
+    axis = _normalize_axes(_canonicalize_tuple(self.axis), inputs.ndim)
+    kernel = jnp.asarray(self.kernel, self.dtype)

     contract_ind = tuple(range(0, len(axis)))
     output = compute_dot_general(inputs, kernel, axis, contract_ind)

     if self.use_bias:
-      bias_axes, bias_shape = self.kernel_axes[-len(features) :], kernel_shape[-len(features) :]
-      bias = self.param(
-          "bias",
-          nn.with_logical_partitioning(bias_init, bias_axes),
-          bias_shape,
-          self.weight_dtype,
-      )
-      bias = jnp.asarray(bias, self.dtype)
+      bias = jnp.asarray(self.bias, self.dtype)
       output += bias
     return output


-class MlpBlock(nn.Module):
+@dataclasses.dataclass
+class MlpBlock(nnx.Module):
   """Transformer MLP / feed-forward block.
Attributes: @@ -178,6 +186,7 @@ class MlpBlock(nn.Module): """ config: Config + input_dim: int = 2048 intermediate_dim: int = 2048 activations: Sequence[Union[str, Callable[..., Any]]] = ("relu",) kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") @@ -187,6 +196,64 @@ class MlpBlock(nn.Module): use_bias: bool = False use_pre_norm: bool = False quant: Optional[Quant] = None + name: str = "" + rngs: nnx.Rngs = None + + def __post_init__(self): + cfg = self.config + self.mlp_layer_norm = self.get_norm_layer()( + features=self.input_dim, + name="mlp_layer_norm", + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + rngs=self.rngs, + ) + if self.config.fused_mlp: + self.wi = DenseGeneral( + self.input_dim, + (len(self.activations), self.intermediate_dim), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "num_activations", "mlp"), + name="wi", + quant=self.quant, + use_bias=self.use_bias, + rngs=self.rngs, + ) + else: + self.wi = dict() + for idx, _ in enumerate(self.activations): + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" + self.wi[dense_name] = DenseGeneral( + self.input_dim, + self.intermediate_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "mlp"), + name=dense_name, + quant=self.quant, + use_bias=self.use_bias, + rngs=self.rngs, + ) + self.dropout = nnx.Dropout(rate=self.intermediate_dropout_rate, + broadcast_dims=(-2,), rngs=self.rngs) + self.wo = DenseGeneral( + self.intermediate_dim, + self.input_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("mlp", "embed"), + name="wo", + quant=self.quant, + use_bias=self.use_bias, + rngs=self.rngs, + ) + def get_norm_layer(self): if self.config.decoder_block in ("default", "llama2", "mistral", "gemma"): @@ -198,80 +265,45 @@ def get_norm_layer(self): else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") - @nn.compact + #@nn.compact def __call__(self, inputs, decode: bool = False, deterministic: bool = False): """Applies Transformer MlpBlock module.""" cfg = self.config if self.use_pre_norm: - inputs = self.get_norm_layer()( - name="mlp_layer_norm", - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - kernel_axes=("norm",), - epsilon=cfg.normalization_layer_epsilon, - )(inputs) + inputs = self.mlp_layer_norm(inputs) # Iterate over specified MLP input activation functions. # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. 
activations = []
     if cfg.fused_mlp:
-      x = DenseGeneral(
-          (len(self.activations), self.intermediate_dim),
-          dtype=self.dtype,
-          weight_dtype=self.weight_dtype,
-          kernel_init=self.kernel_init,
-          kernel_axes=("embed", "num_activations", "mlp"),
-          name="wi",
-          quant=self.quant,
-          use_bias=self.use_bias,
-          matmul_precision=self.config.matmul_precision,
-      )(inputs)
+      x = self.wi(inputs)
       for idx, act_fn in enumerate(self.activations):
         y = _convert_to_activation_function(act_fn)(x[:, :, idx, ...])
         activations.append(y)
     else:
       for idx, act_fn in enumerate(self.activations):
         dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
-        x = DenseGeneral(
-            self.intermediate_dim,
-            dtype=self.dtype,
-            weight_dtype=self.weight_dtype,
-            kernel_init=self.kernel_init,
-            kernel_axes=("embed", "mlp"),
-            name=dense_name,
-            quant=self.quant,
-            use_bias=self.use_bias,
-            matmul_precision=self.config.matmul_precision,
-        )(inputs)
-        x = _convert_to_activation_function(act_fn)(x.astype(jnp.float32))
+        x = self.wi[dense_name](inputs)
+        x = _convert_to_activation_function(act_fn)(x)
         activations.append(x)

     # Take elementwise product of above intermediate activations.
-    x = functools.reduce(operator.mul, activations).astype(self.dtype)
+    x = functools.reduce(operator.mul, activations)
     x = checkpoint_name(x, "mlpwi")
     # Apply dropout and final dense output projection.
-    x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
+    x = self.dropout(
         x, deterministic=deterministic
     )  # Broadcast along length.
     x = nn.with_logical_constraint(x, ("activation_batch", "activation_length", "activation_mlp"))
-    output = DenseGeneral(
-        inputs.shape[-1],
-        dtype=self.dtype,
-        weight_dtype=self.weight_dtype,
-        kernel_init=self.kernel_init,
-        kernel_axes=("mlp", "embed"),
-        name="wo",
-        quant=self.quant,
-        use_bias=self.use_bias,
-        matmul_precision=self.config.matmul_precision,
-    )(x)
+    output = self.wo(x)
     output = checkpoint_name(output, "mlpwo")
     return output


-class MoeBlock(nn.Module):
+@dataclasses.dataclass
+class MoeBlock(nnx.Module):
   """Mixture of Experts (MoE) block.
Attributes: @@ -294,44 +326,64 @@ class MoeBlock(nn.Module): weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 quant: Optional[Quant] = None - - # The first axes is expert - wi_kernel_axes = ('exp', 'embed_no_exp', 'mlp') - wo_kernel_axes = ('exp', 'mlp', 'embed_no_exp') + name: str = "" + rngs: nnx.Rngs = None + + def __post_init__(self): + cfg = self.config + self.gate = DenseGeneral( + in_features=cfg.emb_dim, + out_features=self.num_experts, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + name="gate", + rngs=self.rngs, + ) + self.w0_kernel, self.w1_kernel, self.wo_kernel = self.generate_kernels( + cfg.num_experts, cfg.emb_dim, cfg.mlp_dim) def generate_kernels(self, num_experts, emb_dim, mlp_dim): - kernel_in_axis = np.arange(1) kernel_out_axis = np.arange(1, 2) kernel_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal') - w0_kernel = self.param( - 'wi_0', - nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes), - (num_experts, emb_dim, mlp_dim), - self.weight_dtype, - kernel_in_axis, - kernel_out_axis, + # The first axes is expert + kernel_axes = (None, 'embed', 'mlp') + wo_kernel_axes = (None, 'mlp', 'embed') + + w0_kernel = nnx.Param( + nnx.with_partitioning(kernel_init, kernel_axes)( + self.rngs(), + (num_experts, emb_dim, mlp_dim), + self.weight_dtype, + kernel_in_axis, + kernel_out_axis + ) ) - w0_kernel = jnp.asarray(w0_kernel, self.dtype) - w1_kernel = self.param( - 'wi_1', - nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes), - (num_experts, emb_dim, mlp_dim), - self.weight_dtype, - kernel_in_axis, - kernel_out_axis, + w0_kernel.value = jnp.asarray(w0_kernel.value, self.dtype) + w1_kernel = nnx.Param( + nnx.with_partitioning(kernel_init, kernel_axes)( + self.rngs(), + (num_experts, emb_dim, mlp_dim), + self.weight_dtype, + kernel_in_axis, + kernel_out_axis + ) ) - w1_kernel = jnp.asarray(w1_kernel, self.dtype) - wo_kernel = self.param( - 'wo', - nn.with_logical_partitioning(kernel_init, self.wo_kernel_axes), - (num_experts, mlp_dim, emb_dim), - self.weight_dtype, - kernel_in_axis, - kernel_out_axis, + w1_kernel.value = jnp.asarray(w1_kernel.value, self.dtype) + wo_kernel = nnx.Param( + nnx.with_partitioning(kernel_init, wo_kernel_axes)( + self.rngs(), + (num_experts, mlp_dim, emb_dim), + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ) ) - wo_kernel = jnp.asarray(wo_kernel, self.dtype) + wo_kernel.value = jnp.asarray(wo_kernel.value, self.dtype) return w0_kernel, w1_kernel, wo_kernel def permute(self, inputs, gate_logits): @@ -358,8 +410,7 @@ def unpermute(self, intermediate, sorted_selected_experts, weights): tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism reshaped_intermediate = jnp.reshape(unsort_intermediate, (-1, self.num_experts_per_tok, self.config.emb_dim // tensor_parallelism)) with jax.named_scope("weight_sum"): - matmul_precision = lax.Precision(self.config.matmul_precision) - output = jnp.einsum("BKE,BK -> BE", reshaped_intermediate.astype(jnp.float32), reshaped_weights.astype(jnp.float32), precision=matmul_precision) + output = jnp.einsum("BKE,BK -> BE", reshaped_intermediate, reshaped_weights) return output.reshape(-1, self.config.max_target_length, self.config.emb_dim // tensor_parallelism).astype(self.dtype) def megablox(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): @@ -413,23 +464,15 @@ def wrapper(x, logits, w0, w1, wo): output = self.unpermute(intermediate_output, 
sorted_selected_experts, weights) - return output, None + return output return wrapper(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) - def reshape_and_update_weights(self, weights, indices): - # input of weights & indices: (batch_size, seq_len, num_experts_per_tok) - # output of updated weights: (batch_size, seq_len, num_experts) - update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) - index_update = (jnp.arange(weights.shape[0])[:, None, None], jnp.arange(weights.shape[1])[:, None], indices) - update_weights = update_weights.at[index_update].set(weights) - return update_weights - - def generate_masks(self, top_k_indices, softmax_probs): + def generate_masks(self, top_k_indices): # calculate expert_capacity = (tokens_per_batch / num_experts) * capacity_factor batch_size, seq_len, _ = top_k_indices.shape - tokens_per_batch = seq_len * self.num_experts_per_tok - expert_capacity_per_batch = int((tokens_per_batch / self.num_experts) * self.config.capacity_factor) - max_logging.log(f"Applying potential token dropping with a batch expert_capacity of {expert_capacity_per_batch}") + tokens_per_batch = batch_size * seq_len * self.num_experts_per_tok + expert_capacity = int((tokens_per_batch / self.num_experts) * self.config.capacity_factor) + max_logging.log(f"Applying potential token dropping with an expert_capacity of {expert_capacity}") # calculate expert mask and drop tokens if needed # shape of output expert mask: (batch, sequence, num_experts_per_tok) @@ -442,124 +485,53 @@ def generate_masks(self, top_k_indices, softmax_probs): # trunc_expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of updated_expert_mask is [[[1, 1],[0, 1]]]. 
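The worked example in the comment above can be checked standalone. A minimal sketch (illustrative shapes; `capacity_factor = 1` folded into `expert_capacity = 1`) applying the same one-hot/cumsum construction used below:

```python
import jax
import jax.numpy as jnp

# 1 batch, 2 tokens, top-2 routing over 4 experts; tokens_per_batch = 4,
# so capacity_factor = 1 gives expert_capacity = 1.
top_k_indices = jnp.array([[[0, 1], [1, 3]]])
num_experts, expert_capacity = 4, 1

expert_mask = jax.nn.one_hot(top_k_indices, num_classes=num_experts, dtype=jnp.int32)
b, s, k, _ = expert_mask.shape
count = jnp.cumsum(expert_mask.reshape(b, s * k, num_experts), axis=1)
count = count.reshape(b, s, k, num_experts)
trunc_expert_mask = expert_mask * jnp.less_equal(count, expert_capacity)
# expert 1 is picked by both tokens but has capacity 1: its second slot drops
print(jnp.sum(trunc_expert_mask, axis=3))  # [[[1 1] [0 1]]]
```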
expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) - expert_mask_fused = jnp.reshape(expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts)) - expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None)) - expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) - expert_token_count = jnp.reshape(expert_token_count_fused, ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts))) - expert_token_count = nn.with_logical_constraint(expert_token_count, ("activation_batch", "activation_length", None, None)) - trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) - combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) - - # reshape & update weights - softmax_probs *= combined_expert_mask - - # calculate token position in expert capacity dimension - expert_token_position_fused = expert_mask_fused * expert_token_count_fused - expert_token_position = jnp.reshape(expert_token_position_fused, (batch_size, seq_len, self.num_experts_per_tok, self.num_experts)) - combined_expert_token_position = jnp.sum(expert_token_position, axis=2) * combined_expert_mask - expert_token_position_in_capacity = jax.nn.one_hot(combined_expert_token_position, num_classes=expert_capacity_per_batch+1, dtype=jnp.int32) - - # shape of combine_mask is (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1), - # and cut 0-dimension which is always 0 - combine_mask = (softmax_probs[..., None] * expert_token_position_in_capacity) - combine_mask = combine_mask[..., 1:] - dispatch_mask = combine_mask.astype(bool) - return dispatch_mask, combine_mask - - # See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. - def load_balance_loss(self, top_k_indices, logits): - expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) - summed_expert_mask = jnp.sum(expert_mask, axis=2) - # Get fraction of tokens dispatched to each expert - density = jnp.mean(summed_expert_mask, axis=1) - # get fraction of probability allocated to each expert - density_prob = jnp.mean(logits, axis=1) - loss = jnp.mean(density * density_prob) * (self.num_experts ** 2) * self.config.load_balance_loss_weight - return loss - - def get_einsum(self, rhs_mesh_axes: Tuple[Optional[str], ...] 
= ()): - if self.quant: - einsum_op = self.quant.einsum(rhs_mesh_axes) - else: - einsum_op = jnp.einsum - return einsum_op + reshaped_expert_mask = jnp.reshape(expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts)) + expert_token_count = jnp.cumsum(reshaped_expert_mask, axis=1) + expert_token_count = jnp.reshape(expert_token_count, ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts))) + trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity) + updated_expert_mask = jnp.sum(trunc_expert_mask, axis=3) + return updated_expert_mask def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): - gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", "activation_embed")) - softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) # shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok) - top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, self.num_experts_per_tok) - matmul_precision = lax.Precision(self.config.matmul_precision) + top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) + softmax_probs = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype) + # token dropping if needed if self.config.capacity_factor > 0: - # token dropping if needed - dispatch_mask, combine_mask = self.generate_masks(top_k_indices, softmax_probs) - mask_axes = ("activation_batch", "activation_length", None, None) - dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) - combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) - loss = self.load_balance_loss(top_k_indices, softmax_probs) - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) - with jax.named_scope("dispatch"): - dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask, precision=matmul_precision) - dispatch = nn.with_logical_constraint(dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")) - with jax.named_scope("wi_0"): - w0_kernel_axes = ("exp", None, None) - w0_kernel = nn.with_logical_constraint(w0_kernel, w0_kernel_axes) - layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w0_kernel, precision=matmul_precision).astype(jnp.float32) - layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp")) - with jax.named_scope("wi_1"): - w1_kernel_axes = ("exp", None, None) - w1_kernel = nn.with_logical_constraint(w1_kernel, w1_kernel_axes) - layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w1_kernel, precision=matmul_precision).astype(jnp.float32) - layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_exp", "activation_batch_no_exp",None, "activation_mlp")) - layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) - layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype) - with jax.named_scope("wo"): - wo_kernel_axes = ("exp", None, None) - wo_kernel = nn.with_logical_constraint(wo_kernel, wo_kernel_axes) - intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)("EBCH,EHM -> EBCM", layer_multiply, wo_kernel, precision=matmul_precision) - intermediate_layer = nn.with_logical_constraint(intermediate_layer, ("activation_exp", "activation_batch_no_exp", None, 
"activation_embed")) - with jax.named_scope("combine"): - # Matmul & element wise operation - output = self.get_einsum(rhs_mesh_axes=mask_axes)("EBCM,BSEC -> BSM", intermediate_layer, combine_mask, precision=matmul_precision) - return output, loss - else: - weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) - with jax.named_scope("wi_0"): - layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)("BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision).astype(jnp.float32) - with jax.named_scope("wi_1"): - layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)("BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision).astype(jnp.float32) - layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) - layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype) - with jax.named_scope("wo"): - intermediate_layer = self.get_einsum(rhs_mesh_axes=self.wo_kernel_axes)("BSEH,EHM -> BSEM", layer_multiply, wo_kernel, precision=matmul_precision) - with jax.named_scope("w_sum"): - weights_axis = ("activation_batch", "activation_length", "activation_exp") - output = self.get_einsum(rhs_mesh_axes=weights_axis)("BSEM,BSE -> BSM", intermediate_layer.astype(jnp.float32), weights.astype(jnp.float32)).astype(self.dtype) - return output, None + expert_mask = self.generate_masks(top_k_indices) + softmax_probs *= expert_mask + + weights = jnp.zeros_like(gate_logits) + index_update = (jnp.arange(gate_logits.shape[0])[:, None, None], jnp.arange(gate_logits.shape[1])[:, None], top_k_indices) + weights = weights.at[index_update].set(softmax_probs) + + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + with jax.named_scope("wi_0"): + layer_w0 = jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel) + with jax.named_scope("wi_1"): + layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel) + layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) + layer_multiply = jnp.multiply(layer_w0_act, layer_w1) + with jax.named_scope("wo"): + intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_multiply, wo_kernel) + with jax.named_scope("w_sum"): + output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights) + return output @nn.compact def __call__(self, inputs): cfg = self.config inputs = inputs.astype(cfg.dtype) - gate_logits = DenseGeneral( - self.num_experts, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - name="gate", - matmul_precision=self.config.matmul_precision)(inputs) - - w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, - cfg.emb_dim, - cfg.mlp_dim) + gate_logits = self.gate(inputs) if cfg.megablox: max_logging.log("Running MoE megablox implementation.") - return self.megablox(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) + return self.megablox( + inputs, gate_logits, self.w0_kernel.value, self.w1_kernel.value, + self.wo_kernel.value) else: max_logging.log("Running MoE matmul implementation.") - return self.dense_matmul(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) + return self.dense_matmul( + inputs, gate_logits, self.w0_kernel.value, self.w1_kernel.value, + self.wo_kernel.value) diff --git a/MaxText/nnx_layers/mistral.py b/MaxText/nnx_layers/mistral.py index 8ac08787b..cba686c2c 100644 --- 
a/MaxText/nnx_layers/mistral.py +++ b/MaxText/nnx_layers/mistral.py @@ -20,17 +20,21 @@ from typing import Optional -from layers import quantizations -from layers import linears -from layers import initializers +import dataclasses + import jax from jax.sharding import Mesh from flax import linen as nn +from flax import nnx import jax.numpy as jnp -from layers import attentions -from layers import embeddings -from layers import normalizations -from layers import models + +from nnx_layers import quantizations +from nnx_layers import linears +from nnx_layers import initializers +from nnx_layers import attentions +from nnx_layers import embeddings +from nnx_layers import normalizations +from nnx_layers import models import common_types import max_logging @@ -50,40 +54,29 @@ # ----------------------------------------- -class MistralDecoderLayer(nn.Module): +@dataclasses.dataclass +class MistralDecoderLayer(nnx.Module): """Transformer decoder layer that attends to the encoder.""" config: models.Config mesh: Mesh quant: Optional[Quant] = None - - @nn.compact - def __call__( - self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + name: str = "" + rngs: nnx.Rngs = None + + def __post_init__(self): cfg = self.config mesh = self.mesh - - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) - - lnx_rms = models.RMSNorm( + self.pre_self_attention_layer_norm = RMSNorm( + features=cfg.emb_dim, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_layer_norm", kernel_axes=("norm",), epsilon=cfg.normalization_layer_epsilon, + rngs=self.rngs, ) - lnx = lnx_rms(inputs) - - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) - - # Self-attention block - attention_layer = Attention( + self.self_attention = Attention( config=cfg, num_query_heads=cfg.num_query_heads, num_kv_heads=cfg.num_kv_heads, @@ -98,33 +91,19 @@ def __call__( name="self_attention", quant=self.quant, kv_quant=quantizations.configure_kv_quant(cfg), + rngs=self.rngs, ) - - attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode, - ) - - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) - intermediate_inputs = inputs + attention_lnx - - # Fully Connected - hidden_states = models.RMSNorm( + self.post_self_attention_layer_norm = RMSNorm( + features=cfg.emb_dim, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_layer_norm", kernel_axes=("norm",), epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) - - load_balance_loss = None + rngs=self.rngs, + ) if cfg.num_experts > 1: - mlp_lnx, load_balance_loss = linears.MoeBlock( + self.mlp = linears.MoeBlock( config=cfg, num_experts=cfg.num_experts, num_experts_per_tok=cfg.num_experts_per_tok, @@ -134,12 +113,10 @@ def __call__( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, quant=self.quant, - )(hidden_states) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + rngs=self.rngs, ) else: - mlp_lnx = linears.MlpBlock( + self.mlp = linears.MlpBlock( intermediate_dim=cfg.mlp_dim, activations=cfg.mlp_activations, 
intermediate_dropout_rate=cfg.dropout_rate, @@ -148,20 +125,61 @@ def __call__( name="mlp", config=cfg, quant=self.quant, - )(hidden_states, deterministic=deterministic) + rngs=self.rngs, + ) + self.post_dropout = nnx.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): + cfg = self.config + inputs_dtype = inputs.dtype + + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + + lnx = self.pre_self_attention_layer_norm(inputs) + + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + + # Self-attention block + attention_lnx = self.self_attention( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + intermediate_inputs = inputs + attention_lnx + + # Fully Connected + hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) + + if cfg.num_experts > 1: + mlp_lnx = self.mlp(hidden_states) + mlp_lnx = nn.with_logical_constraint( + mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + ) + else: + mlp_lnx = self.mlp(hidden_states, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.post_dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, ("activation_batch", "activation_length", "activation_embed"), ) - if cfg.num_experts > 1 and load_balance_loss is not None: - self.sow("intermediates", "moe_lb_loss", load_balance_loss) - if cfg.record_internal_nn_metrics: self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) @@ -170,6 +188,7 @@ def __call__( "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) + layer_output = layer_output.astype(inputs_dtype) if cfg.scan_layers: return layer_output, None diff --git a/MaxText/nnx_layers/models.py b/MaxText/nnx_layers/models.py index 82cc47294..5122fcd4d 100644 --- a/MaxText/nnx_layers/models.py +++ b/MaxText/nnx_layers/models.py @@ -17,19 +17,28 @@ # pylint: disable=no-name-in-module from typing import Any, Callable, Optional - +import dataclasses +from inspect import signature +import time from flax import linen as nn +from flax import nnx +from flax.nnx import bridge import functools import jax import jax.numpy as jnp import common_types -from layers import attentions -from layers import embeddings -from layers import linears -from layers import normalizations, quantizations +from nnx_layers import attentions +from nnx_layers import embeddings +#from layers import embeddings +#from layers import linears +from nnx_layers import linears +#from layers import normalizations, quantizations +from nnx_layers import normalizations +from layers import quantizations from layers import pipeline + Array = common_types.Array Config = 
common_types.Config DType = common_types.DType @@ -46,15 +55,71 @@ # The network: Decoder & Transformer Definitions # ------------------------------------------------------------------------------ +@jax.tree_util.register_static +class StaticStr(str): pass + -class DecoderLayer(nn.Module): +@dataclasses.dataclass +class DecoderLayer(nnx.Module): """Transformer decoder layer that attends to the encoder.""" config: Config mesh: Mesh quant: Optional[Quant] = None + name: str = "decoder_layer" + rngs: nnx.Rngs | None = None + + def __post_init__(self) -> None: + cfg = self.config + mesh = self.mesh - @nn.compact + self.input_norm = RMSNorm( + features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=self.rngs, + ) + self.attention_layer = Attention( + config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]), + ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]), + compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]), + reshape_q=cfg.reshape_q, + rngs=self.rngs + ) + + self.mlp_block = linears.MlpBlock( + input_dim=cfg.emb_dim, + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + config=cfg, + quant=self.quant, + rngs=self.rngs + ) + self.out_dropout = nnx.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,), + rngs=self.rngs) + + #@nn.compact def __call__( self, inputs, @@ -69,37 +134,10 @@ def __call__( inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - lnx = RMSNorm( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="pre_self_attention_norm", - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - )(inputs) + lnx = self.input_norm(inputs) lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) - attention_layer = Attention( - config=self.config, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name="self_attention", - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(cfg), - prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]), - ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]), - compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]), - reshape_q=cfg.reshape_q, - ) - - attention_lnx = attention_layer( + attention_lnx = self.attention_layer( lnx, lnx, decoder_positions, @@ -107,27 +145,16 @@ def __call__( deterministic=deterministic, 
model_mode=model_mode, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) # MLP block. - mlp_lnx = linears.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="mlp", - config=cfg, - quant=self.quant, - )(lnx, deterministic=deterministic) + mlp_lnx = self.mlp_block(lnx, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) next_layer_addition = mlp_lnx + attention_lnx - - next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - next_layer_addition, deterministic=deterministic - ) + + next_layer_addition_dropped_out = self.out_dropout( + next_layer_addition, deterministic=deterministic) layer_output = next_layer_addition_dropped_out + inputs layer_output = nn.with_logical_constraint( @@ -144,7 +171,7 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) - return layer_output, None if cfg.scan_layers else layer_output + return (layer_output, None) if cfg.scan_layers else layer_output class SequentialBlockDecoderLayers(nn.Module): """Sequential unscanned series of decoder layers.""" @@ -165,14 +192,18 @@ def __call__(self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, model_mode, ) return inputs + +#################################################################################################### -class Decoder(nn.Module): +@dataclasses.dataclass +class Decoder(nnx.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" config: Config shared_embedding: nn.Module mesh: Mesh quant: Optional[Quant] = None + rngs: nnx.Rngs | None = None def get_decoder_layer(self): if self.config.decoder_block == "default": @@ -183,7 +214,7 @@ def get_decoder_layer(self): return llama2.LlamaDecoderLayer elif self.config.decoder_block == "mistral": # TODO(ranran): update to Mistral with sliding window attention - from layers import mistral + from nnx_layers import mistral return mistral.MistralDecoderLayer elif self.config.decoder_block == "gemma": @@ -202,15 +233,11 @@ def get_decoder_layer(self): from layers import simple_layer return simple_layer.SimpleDecoderLayer - elif self.config.decoder_block == "simple_mlp": - from layers import simple_layer - - return simple_layer.SimpleMlpDecoderLayer else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") def get_norm_layer(self): - if self.config.decoder_block in ("default", "llama2", "mistral", "gemma", "gemma2", "simple", "simple_mlp"): + if self.config.decoder_block in ("default", "llama2", "mistral", "gemma", "gemma2", "simple"): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 @@ -219,10 +246,21 @@ def get_norm_layer(self): else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") - def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mesh): - initializing = self.is_mutable_collection("params") + def make_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mesh): + #initializing = self.is_mutable_collection("params") + initializing = True params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) cache_spec = 0 + + # convert keyword args to positional arguments, nnx.remat layer only supports positional + args = dict(config=cfg, 
mesh=mesh, quant=self.quant, rngs=self.rngs) + arg_kws = list(signature(decoder_layer.__init__).parameters.keys())[1:] # skip "self" + args = [args.get(kw, None) for kw in arg_kws] + # call example: + # decoder_layers(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + return nnx.Scan.constructor(decoder_layer, length=cfg.base_num_decoder_layers, + in_axes=(0, 0, None, None, None, None))(*args) + scan_fn = nn.scan( decoder_layer, variable_axes={ @@ -245,39 +283,66 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mes length=length, metadata_params={nn.PARTITION_NAME: metdata_axis_name}, ) + #return bridge.ToNNX(scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant), rngs=self.rngs) return scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant) - @nn.compact - def __call__( - self, - decoder_input_tokens, - decoder_positions, - decoder_segment_ids=None, - deterministic=False, - model_mode=common_types.MODEL_MODE_TRAIN, - ): + def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mesh): + #initializing = self.is_mutable_collection("params") + initializing = True + params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) + cache_spec = 0 + #scan_fn = nn.scan( + scan_fn = nnx.scan( + decoder_layer, + variable_axes={ + "params": params_spec, + "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={ + "params": True, + "dropout": cfg.enable_dropout, + }, + in_axes=( + nn.broadcast, + nn.broadcast, + nn.broadcast, + nn.broadcast, + ), + length=length, + metadata_params={nn.PARTITION_NAME: metdata_axis_name}, + ) + #return bridge.ToNNX(scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant), rngs=self.rngs) + return scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant) + + def __post_init__(self): cfg = self.config mesh = self.mesh - assert decoder_input_tokens.ndim == 2 # [batch, len] - # [batch, length] -> [batch, length, emb_dim] - y = self.shared_embedding(decoder_input_tokens.astype("int32")) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) - y = y.astype(cfg.dtype) + # input dropout ################################################################################ + self.input_dropout = nnx.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,), + rngs=self.rngs) + # input dropout ################################################################################ + # position embedding ########################################################################### if cfg.use_untrainable_positional_embedding: - y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) - + self.positional_embedding = PositionalEmbedding(cfg.emb_dim, rngs=self.rngs) if cfg.trainable_position_size > 0: - y += Embed( - num_embeddings=cfg.trainable_position_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - name="position_embedder", - config=cfg, - )(decoder_positions) + self.trainable_position_embedder = Embed( + num_embeddings=cfg.trainable_position_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + name="position_embedder", + config=cfg, + rngs=self.rngs, + ) + # position embedding ########################################################################### + + # block decoder layer ########################################################################## BlockLayer = 
self.get_decoder_layer() if cfg.remat_policy != "none": @@ -331,55 +396,128 @@ def __call__( assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" policy = None - RemattedBlockLayer = nn.remat( # pylint: disable=invalid-name + RemattedBlockLayer = nnx.remat( # pylint: disable=invalid-name BlockLayer, prevent_cse=not cfg.scan_layers, policy=policy, static_argnums=(-1, -2, -3, -4, -5), ) - if cfg.using_pipeline_parallelism: - if cfg.num_layers_per_pipeline_stage == 1: - stage_module = BlockLayer(config=cfg, mesh=mesh, quant=self.quant) - elif cfg.scan_layers: - stage_module = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_layers_per_pipeline_stage, "layers_per_stage", mesh) - elif not cfg.scan_layers: - stage_module=SequentialBlockDecoderLayers(decoder_layer=RemattedBlockLayer, num_decoder_layers=cfg.num_layers_per_pipeline_stage, config=cfg, mesh=mesh,quant=self.quant) - - y = pipeline.Pipeline(config=cfg, mesh=mesh, layers=stage_module, remat_policy=policy)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) - else: + + #if cfg.using_pipeline_parallelism: + # if cfg.num_layers_per_pipeline_stage == 1: + # stage_module = BlockLayer(config=cfg, mesh=mesh, quant=self.quant) + # elif cfg.scan_layers: + # stage_module = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_layers_per_pipeline_stage, "layers_per_stage", mesh) + # elif not cfg.scan_layers: + # stage_module=SequentialBlockDecoderLayers(decoder_layer=RemattedBlockLayer, num_decoder_layers=cfg.num_layers_per_pipeline_stage, config=cfg, mesh=mesh,quant=self.quant) + + # y = pipeline.Pipeline(config=cfg, mesh=mesh, layers=stage_module, remat_policy=policy)( + # y, + # decoder_segment_ids, + # decoder_positions, + # deterministic, + # model_mode, + # ) + #else: + if True: if cfg.scan_layers: - y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_decoder_layers, "layers", mesh)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) + self.decoder_layers = self.make_decoder_layers(cfg, RemattedBlockLayer, cfg.num_decoder_layers, "layers", mesh) else: + self.decoder_layers = [] for lyr in range(cfg.num_decoder_layers): - y = RemattedBlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) - - y = self.get_norm_layer()( + arg_keys = list(signature(BlockLayer.__init__).parameters.keys())[1:] # skip "self" + args = dict(config=cfg, mesh=mesh, quant=self.quant, rngs=self.rngs, name=f"layers_{lyr}") + args = [args.get(kw, None) for kw in arg_keys] + self.decoder_layers.append(RemattedBlockLayer(*args)) + # block decoder layer ########################################################################## + + # norm layer ################################################################################### + #self.norm_layer = bridge.ToNNX(self.get_norm_layer()( + self.norm_layer = self.get_norm_layer()( + features=cfg.emb_dim, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="decoder_norm", epsilon=cfg.normalization_layer_epsilon, kernel_axes=("norm",), - )(y) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + rngs=self.rngs, + ) + #), rngs=self.rngs) + self.norm_layer_dropout = nnx.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,), + rngs=self.rngs) + # norm layer ################################################################################### + + # logits mapping 
############################################################################### + if not cfg.logits_via_embedding: + self.logits_transpose = linears.DenseGeneral( + cfg.emb_dim, + cfg.vocab_size, + weight_dtype=cfg.weight_dtype, + dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability + kernel_axes=("embed", "vocab"), + name="logits_dense", + rngs=self.rngs) + # logits mapping ############################################################################### + + # output ####################################################################################### + # output ####################################################################################### + + #@nn.compact + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=common_types.MODEL_MODE_TRAIN, + ): + cfg = self.config + #mesh = self.mesh + assert decoder_input_tokens.ndim == 2 # [batch, len] + + # [batch, length] -> [batch, length, emb_dim] + y = self.shared_embedding(decoder_input_tokens.astype("int32")) + self.sow(nnx.Intermediate, "rdyro_shared_embedding", y) + # input dropout ################################################################################ + y = self.input_dropout(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + self.sow(nnx.Intermediate, "rdyro_input_dropout", y) + # input dropout ################################################################################ + + # position embedding ########################################################################### + if cfg.use_untrainable_positional_embedding: + y = self.positional_embedding(y, decoder_positions) + + if cfg.trainable_position_size > 0: + y += self.trainable_position_embedder(decoder_positions) + # position embedding ########################################################################### + + + # block decoder layer ########################################################################## + self.sow(nnx.Intermediate, "rdyro_before_scan", y) + if True: + if cfg.scan_layers: + #y, _ = nnx.scan(lambda y, layer: layer(y, decoder_segment_ids, decoder_positions, deterministic, model_mode))(y, self.decoder_layers) + y, _ = self.decoder_layers(y, decoder_segment_ids, decoder_positions, deterministic, StaticStr(model_mode)) + #y, _ = self.decoder_layers(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + else: + layer: nnx.Module + for lyr, layer in enumerate(self.decoder_layers): + #layer.lazy_init(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + y = layer(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + self.sow(nnx.Intermediate, f"rdyro_after_layer_{lyr}", y) + self.sow(nnx.Intermediate, "rdyro_after_scan", y) + # block decoder layer ########################################################################## + + # norm layer ################################################################################### + #self.norm_layer.lazy_init(y) + y = self.norm_layer(y) + y = self.norm_layer_dropout(y, deterministic=deterministic) + self.sow(nnx.Intermediate, "rdyro_after_norm", y) + # norm layer ################################################################################### + + # logits mapping ############################################################################### # [batch, length, emb_dim] -> [batch, length, vocab_size] if cfg.logits_via_embedding: # Use the transpose of embedding matrix for logit transform.
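The rdyro_* sow calls above stash debugging intermediates on the module as nnx.Intermediate variables. A minimal sketch of reading them back after a forward pass, assuming a model, input_tokens, and input_positions constructed as in the notebook at the end of this patch:

from flax import nnx

logits = model(input_tokens, input_positions)       # executes the self.sow(nnx.Intermediate, ...) calls
intermediates = nnx.state(model, nnx.Intermediate)  # filters the module state down to sown values
# entries such as "rdyro_before_scan" and "rdyro_after_norm" appear under the decoder's path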
@@ -391,22 +529,20 @@ def __call__( logits = logits / cfg.final_logits_soft_cap logits = jnp.tanh(logits) * cfg.final_logits_soft_cap else: - logits = linears.DenseGeneral( - cfg.vocab_size, - weight_dtype=cfg.weight_dtype, - dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=("embed", "vocab"), - name="logits_dense", - matmul_precision=self.config.matmul_precision, - )( - y - ) # We do not quantize the logits matmul. + #self.logits_transpose.lazy_init(y) + logits = self.logits_transpose(y) # We do not quantize the logits matmul. + self.sow(nnx.Intermediate, "rdyro_after_logits", logits) + # logits mapping ############################################################################### + + # output ####################################################################################### logits = nn.with_logical_constraint(logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")) logits = logits.astype(jnp.float32) return logits + # output ####################################################################################### -class Transformer(nn.Module): +@dataclasses.dataclass +class Transformer(nnx.Module): """An decoder-only Transformer model.""" # Make new attributes required, so that all Transformer dependencies (train, decode, compile, etc) will error instead of silently use defaults. @@ -414,10 +550,12 @@ class Transformer(nn.Module): config: Config mesh: Mesh quant: Quant + rngs: nnx.Rngs | None = None - def setup(self): + def __post_init__(self): """Initialize shared_embedding & decoder layers.""" - + if self.rngs is None: + self.rngs = nnx.Rngs(time.time_ns() % 2 ** 31) cfg = self.config mesh = self.mesh self.shared_embedding = Embed( @@ -426,11 +564,13 @@ def setup(self): dtype=cfg.dtype, attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability embedding_init=nn.initializers.normal(stddev=1.0), - name="token_embedder", + #name="token_embedder", config=cfg, + rngs=self.rngs, ) - self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant) + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, + quant=self.quant, rngs=self.rngs) def __call__( self, diff --git a/MaxText/nnx_layers/normalizations.py b/MaxText/nnx_layers/normalizations.py index 862c586c9..884635a81 100644 --- a/MaxText/nnx_layers/normalizations.py +++ b/MaxText/nnx_layers/normalizations.py @@ -15,8 +15,10 @@ """Normalization Layers.""" from typing import Any, Tuple +import dataclasses from flax import linen as nn +from flax import nnx from jax import lax import jax.numpy as jnp from layers import initializers @@ -24,28 +26,38 @@ Initializer = initializers.Initializer -class RMSNorm(nn.Module): +@dataclasses.dataclass +class RMSNorm(nnx.Module): """RMS normalization.""" + features: int epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 kernel_axes: Tuple[str, ...] 
= () scale_init: Initializer = nn.initializers.ones + name: str = "rms_norm" + rngs: nnx.Rngs | None = None + + def __post_init__(self): + #value = self.scale_init(self.rngs(), (self.features,), self.weight_dtype) + #self.scale = nnx.Param(value, names=self.kernel_axes) + self.scale = nnx.Param( + nnx.with_partitioning(self.scale_init, self.kernel_axes)( + self.rngs(), (self.features,), self.weight_dtype)) - @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Applies layer normalization on the input.""" x = jnp.asarray(x, jnp.float32) - features = x.shape[-1] + assert self.features == x.shape[-1], f"{self.features} != {x.shape[-1]}" + #features = x.shape[-1] mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) - scale = self.param( - "scale", - nn.with_logical_partitioning(self.scale_init, self.kernel_axes), - (features,), - self.weight_dtype, - ) - - scale = jnp.asarray(scale, self.dtype) + #scale = self.param( + # "scale", + # nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + # (features,), + # self.weight_dtype, + #) + scale = jnp.asarray(self.scale.value, self.dtype) return y * scale
diff --git a/MaxText/nnx_layers/quantizations.py b/MaxText/nnx_layers/quantizations.py index 5b5fcf050..26842db3a 100644 --- a/MaxText/nnx_layers/quantizations.py +++ b/MaxText/nnx_layers/quantizations.py @@ -143,7 +143,7 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): def einsum(self, mesh_axes: Tuple[str, ...] = ()): """Returns einsum configured with aqt params.""" rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper( mesh_axes) aqt_einsum = functools.partial( aqt_flax.AqtEinsum(
diff --git a/MaxText/train.py b/MaxText/train.py index e0dc1953c..5628691d6 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -8,7 +8,7 @@ https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
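The nnx RMSNorm above creates its scale parameter eagerly in __post_init__, so the feature size and an nnx.Rngs must be supplied at construction time; there is no separate init/apply step. A minimal usage sketch, mirroring how the notebook at the end of this patch calls it (import path as used there):

import jax.numpy as jnp
from flax import nnx
from MaxText.nnx_layers.normalizations import RMSNorm

# parameters exist immediately after construction, no lazy shape inference
norm = RMSNorm(features=100, kernel_axes=("norm",), rngs=nnx.Rngs(0))
y = norm(jnp.ones((2, 100)))  # __call__ asserts features == x.shape[-1]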
@@ -48,6 +48,7 @@ # Placeholder: internal from input_pipeline.input_pipeline_interface import create_data_iterator +from nnx_layers import models as nnx_models from layers import models import jax.numpy as jnp @@ -65,7 +66,7 @@ from ml_goodput_measurement import goodput from ml_goodput_measurement import monitoring -Transformer = models.Transformer +#Transformer = models.Transformer EPS = 1e-8 _CHUNK_BYTE_SIZE = 2 * 1024 **3 @@ -435,7 +436,7 @@ def check_example_batch(config, example_batch): err, _ = jax.jit(jittable_f)(example_batch['inputs'][: config.global_batch_size_to_train_on, :]) err.throw() -def setup_mesh_and_model(config): +def setup_mesh_and_model(config, use_nnx: bool): """Set up the mesh and the model for training Args: @@ -461,7 +462,10 @@ def setup_mesh_and_model(config): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh, quant=quant) + if use_nnx: + model = nnx_models.Transformer(config, mesh, quant=quant) + else: + model = models.Transformer(config, mesh, quant=quant) learning_rate_schedule = max_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) logger = checkpointing.setup_checkpoint_logger(config) diff --git a/testing_nnx_layers.ipynb b/testing_nnx_layers.ipynb new file mode 100644 index 000000000..632431956 --- /dev/null +++ b/testing_nnx_layers.ipynb @@ -0,0 +1,6477 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-08-28 15:03:12.386928: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-08-28 15:03:12.398480: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-08-28 15:03:12.402136: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2024-08-28 15:03:13.345597: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "/home/rdyro/.pyenv/versions/devel/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import sys\n", + "import os\n", + "from pathlib import Path\n", + "import time\n", + "import dataclasses\n", + "import functools\n", + "import contextlib\n", + "from io import StringIO\n", + "from pprint import pprint\n", + "\n", + "os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=2\"\n", + "\n", + "import jax\n", + "from jax import numpy as jnp, random as jrandom\n", + "from flax import nnx, linen as nn\n", + "from jax import sharding\n", + "from jax.sharding import Mesh\n", + "from jax.experimental.mesh_utils import create_device_mesh\n", + "import optax\n", + "import equinox as eqx\n", + "#from flax.linen import partitioning as nn_partitioning\n", + "#from flax.core import meta\n", + "from flax.nnx import bridge\n", + "\n", + "paths = [Path(\"MaxText\").absolute()]\n", + "[sys.path.append(str(path)) for path in paths if str(path) not in sys.path]\n", + "\n", + "from MaxText.layers.normalizations import RMSNorm\n", + "from MaxText.nnx_layers.normalizations import RMSNorm as NNXRMSNorm\n", + "from MaxText.nnx_layers import LinenToNNX\n", + "from MaxText import pyconfig, train, max_utils\n", + "from MaxText.nnx_layers.models import Transformer\n", + "from MaxText.layers.models import Transformer as Transformer2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num_devices: 2, shape (1, 1, 2, 1, 1, 1, 1)\n" + ] + } + ], + "source": [ + "buf = StringIO()\n", + "with contextlib.redirect_stdout(buf):\n", + " pyconfig.initialize([\"python3\", \"MaxText/configs/base.yml\", \"hardware=other\", \n", + " \"enable_single_controller=True\", \"decoder_block=default\", \n", + " \"scan_layers=True\"])\n", + " config = pyconfig.config\n", + "pyconfig_output = (buf.seek(0), buf.read())[1]\n", + "\n", + "input_tokens, input_positions = jnp.array([[0]]), jnp.array([[0]])\n", + "devices_array = max_utils.create_device_mesh(config)\n", + "mesh = Mesh(devices_array, config.mesh_axes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Testing `nnx.Scan` and others" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "class MyMod(nnx.Module):\n", + " def __init__(self, in_dim, out_dim, rngs):\n", + " self.rngs = rngs\n", + " self.linear1 = nnx.Linear(in_dim, 2 * out_dim, rngs=self.rngs)\n", + " #self.linear2 = nnx.Linear(2 * out_dim, out_dim, rngs=self.rngs)\n", + " self.linear2 = bridge.ToNNX(nn.Dense(out_dim), rngs=self.rngs)\n", + "\n", + " def __call__(self, x):\n", + " return self.linear2(jax.nn.tanh(self.linear1(x))), None" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([[-0.08562934, 0.16581826, 0.07718924, 0.07391047, -0.11791413,\n", + " -0.39341435, -0.3302102 , 0.02588149, -0.15963459, 0.05993025]], dtype=float32),\n", + " None)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def init_fn(x, mod: nnx.Scan):\n", + " return bridge.lazy_init(mod.scan_module, x)\n", + " \n", + "smod = nnx.Scan.constructor(MyMod, length=5)(in_dim=10, out_dim=10, rngs=nnx.Rngs(0))\n", + "#nnx.vmap(init_fn, in_axes=(None, 0))(jnp.ones((1, 10)), smod)\n", + "\n", + "bridge.lazy_init(smod, jnp.ones((1, 10)))\n", 
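+ "# note: bridge.lazy_init materializes the ToNNX-wrapped nn.Dense via a forward pass\n", + "# before first use; nnx.Scan.constructor(MyMod, length=5) should stack each MyMod\n", + "# parameter along a new leading axis of length 5 (one slice per scan step).\n",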
+ "smod(jnp.ones((1, 10)))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "model = Transformer(config, mesh, quant=None, \n", + " rngs=nnx.Rngs(default=0, params=0))\n", + "#with jax.profiler.trace(\"nnx_init\"):\n", + "# model(input_tokens, input_positions)\n", + "#model = Transformer(config, mesh, quant=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Transformer(config=, mesh=Mesh(device_ids=array([[[[[[[0]]]],\n", + "\n", + "\n", + "\n", + " [[[[1]]]]]]]), axis_names=('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive')), quant=None, rngs=Rngs(\n", + " default=RngStream(\n", + " count=RngCount(\n", + " tag='default',\n", + " value=Array(3, dtype=uint32)\n", + " ),\n", + " key=RngKey(\n", + " tag='default',\n", + " value=Array((), dtype=key) overlaying:\n", + " [0 0]\n", + " )\n", + " ),\n", + " params=RngStream(\n", + " count=RngCount(\n", + " tag='params',\n", + " value=Array(1, dtype=uint32)\n", + " ),\n", + " key=RngKey(\n", + " tag='params',\n", + " value=Array((), dtype=key) overlaying:\n", + " [0 0]\n", + " )\n", + " )\n", + "))" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bridge.lazy_init(model, input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[[ 0.41741037, -1.0495998 , -0.6531233 , ..., -0.22075425,\n", + " -1.859736 , 0.42536822]]], dtype=float32)" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "state = nnx.state(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flax.core.meta import Partitioned\n", + "jax.tree.map(lambda x: None if not isinstance(x, Partitioned) else x, state, is_leaf=lambda x: isinstance(x, Partitioned))" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [], + "source": [ + "class LazyMod(nnx.Module):\n", + " def __init__(self):\n", + " self.p1 = nnx.Param(nnx.with_partitioning(jnp.ones, (\"embed\", \"fsdp\"))(100))\n", + "\n", + " def __call__(self, x):\n", + " if not hasattr(self, \"p2\"):\n", + " self.p2 = nnx.Param(x, sharding=(\"embed\", \"embed\"))\n", + " return self.p1 + (self.p2 * x)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "@dataclasses.dataclass(unsafe_hash=True)\n", + "class MeshRules:\n", + " embed: str | None = None\n", + " mlp: str | None = None\n", + " kv: str | None = None\n", + " vocab: str | None = None\n", + "\n", + " def __call__(self, *keys: str) -> tuple[str, ...]:\n", + " return tuple(getattr(self, key) for key in keys)\n", + " \n", + "mesh_rules = MeshRules(embed='fsdp', mlp='tensor', kv='tensor', vocab='tensor')" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('fsdp', 'tensor')" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mesh_rules(\"embed\", \"mlp\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + 
"metadata": {}, + "outputs": [], + "source": [ + "p = nnx.Param(nnx.with_partitioning(jnp.zeros, (\"fsdp\", \"tensor\"))((100, 2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Param(\n", + " value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32),\n", + " sharding=(),\n", + " mesh=None\n", + ")" + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "NNXRMSNorm(100, rngs=nnx.Rngs(time.time_ns())).scale" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test Flax DictModule" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "class DictMmodule(nn.Module):\n", + " def __init__(self, d: dict | None = None):\n", + " if d is None:\n", + " return\n", + "\n", + " assert isinstance(d, dict)\n", + " for key, item in d.items():\n", + " self[key] = item\n", + "\n", + " def __getitem__(self, key):\n", + " return getattr(self, key)\n", + "\n", + " def __setitem__(self, key, value):\n", + " setattr(self, key, value)\n", + "\n", + " def __repr__(self):\n", + " out = \"DictModule(\\n\"\n", + " for attr in dir(self):\n", + " if not attr.startswith(\"_\"):\n", + " line = str(getattr(self, attr))\n", + " line = \"\\n\".join(\" \" + row for row in line.strip().split(\"\\n\"))\n", + " out += line + \",\\n\"\n", + " out += \")\"\n", + " return out\n", + "\n", + "\n", + "class DictModuleV2(nn.Module):\n", + " #def __init__(self, d: dict | None = None):\n", + " # super().__init__()\n", + " # self._setup(d)\n", + " # self.d = d\n", + " \n", + " def __getitem__(self, key):\n", + " #return getattr(self, key)\n", + " return self.get(key)\n", + "\n", + " def __setitem__(self, key, value):\n", + " self._user_set.add(key)\n", + " #setattr(self, key, value)\n", + " self.set(key, value)\n", + " \n", + " def get(self, key):\n", + " return self.d[key]\n", + "\n", + " def set(self, key, value):\n", + " self.d[key] = value\n", + " \n", + " def setup(self, d: dict | None = None):\n", + " self._setup(d)\n", + "\n", + " def _setup(self, d: dict | None = None):\n", + " self.d = d\n", + " self._user_set = set()\n", + " #if d is None:\n", + " # return\n", + " #assert isinstance(d, dict)\n", + " #for key, item in d.items():\n", + " # #self[key] = item\n", + " # #self.set(key, item)\n", + " # self._user_set.add(key)\n", + " \n", + " def add_modules(self, d):\n", + " self.d = d\n", + "\n", + " def __call__(self, key, *args, **kw):\n", + " #return self[key](*args, **kw)\n", + " return self.get(key)(*args, **kw)\n", + "\n", + " def __repr__(self):\n", + " out = \"DictModuleV2(\\n\"\n", + " for key in self._user_set:\n", + " #line = str(getattr(self, key))\n", + " line = str(self.get(key))\n", + " line = \"\\n\".join(\" \" + f\"{key}={row}\" for row in line.strip().split(\"\\n\"))\n", + " out += line + \",\\n\"\n", + " out += \")\"\n", + " return out\n" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": 
"\"DictModuleV2\" object has no attribute \"_user_set\". If \"_user_set\" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[60], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m mod \u001b[38;5;241m=\u001b[39m DictModuleV2()\n\u001b[1;32m 2\u001b[0m mod\u001b[38;5;241m.\u001b[39madd_modules({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ma\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhello\u001b[39m\u001b[38;5;124m\"\u001b[39m})\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;43mprint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmod\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[59], line 71\u001b[0m, in \u001b[0;36mDictModuleV2.__repr__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__repr__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 70\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDictModuleV2(\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 71\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_user_set\u001b[49m:\n\u001b[1;32m 72\u001b[0m \u001b[38;5;66;03m#line = str(getattr(self, key))\u001b[39;00m\n\u001b[1;32m 73\u001b[0m line \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget(key))\n\u001b[1;32m 74\u001b[0m line \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrow\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m line\u001b[38;5;241m.\u001b[39mstrip()\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n", + "File \u001b[0;32m~/.pyenv/versions/devel/lib/python3.11/site-packages/flax/linen/module.py:1304\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1299\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscope \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1300\u001b[0m msg \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1301\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m If \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m is defined in \u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m.setup()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m, remember these fields \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 1302\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mare only 
accessible from inside \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minit\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mapply\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1303\u001b[0m )\n\u001b[0;32m-> 1304\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(msg)\n", + "\u001b[0;31mAttributeError\u001b[0m: \"DictModuleV2\" object has no attribute \"_user_set\". If \"_user_set\" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'." + ] + } + ], + "source": [ + "mod = DictModuleV2()\n", + "mod.add_modules({\"a\": \"hello\"})\n", + "print(mod)" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "class FlaxModel(nn.Module):\n", + " def setup(self): \n", + " self.mod_dict = DictModuleV2()\n", + " self.mod_dict.add_modules({\"a\": nn.Dense(10), \"b\": nn.Dense(20)})\n", + "\n", + " def __call__(self, x):\n", + " #return mod_dict[\"b\"](jax.nn.tanh(mod_dict[\"a\"](x)))\n", + " return self.mod_dict(\"b\", jax.nn.tanh(self.mod_dict(\"a\", x)))" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "d = DictMmodule({\"hi\": nn.Dense(2)})\n", + "d[\"hello\"] = nn.Dense(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Continue with Transformer testing" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "model = Transformer(config, mesh, quant=None, \n", + " rngs=nnx.Rngs(time.time_ns() % 2 ** 31))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 1, 32000)\n" + ] + } + ], + "source": [ + "bridge.lazy_init(model, input_tokens, input_positions)\n", + "out = model(input_tokens, input_positions)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def get_model(input_tokens, input_positions):\n", + " model = Transformer(config, mesh, quant=None, \n", + " rngs=nnx.Rngs(time.time_ns() % 2 ** 31))\n", + " bridge.lazy_init(model, input_tokens, input_positions)\n", + " return nnx.split(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "mod, params = jax.eval_shape(get_model, input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "all_leaves = jax.tree.leaves(params, is_leaf=lambda x: isinstance(x, nnx.VariableState))\n", + "all_leaves = [x for x in all_leaves if not issubclass(x.type, nnx.RngState)]" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all(hasattr(x, \"sharding\") for x in all_leaves)" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 1, 32000)\n" + ] + } + ], + "source": [ + "model2 = Transformer2(config, mesh, quant=None)\n", + "params2 = model2.init(nnx.Rngs(default=0, params=0)(), input_tokens, 
input_positions)\n", + "out = model2.apply(params2, input_tokens, input_positions)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def get_model_def():\n", + " model = Transformer(config, mesh, quant=None, \n", + " rngs=nnx.Rngs(default=0, params=0))\n", + " bridge.lazy_init(model, input_tokens, input_positions)\n", + " #model(input_tokens, input_positions)\n", + " return nnx.split(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "gdef, _ = jax.eval_shape(get_model_def)\n", + "_, params = get_model_def()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "#@functools.partial(jax.jit, static_argnums=(0,))\n", + "@jax.jit\n", + "def fwd_fn(gdef, state, *input):\n", + " return nnx.merge(gdef, state)(*input)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "with jax.log_compiles():\n", + " fwd_fn(gdef, params, input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[[ 0.90953326, 1.16376 , 1.1164488 , ..., 0.5105222 ,\n", + " -0.5272695 , 0.1440109 ]]], dtype=float32)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def eval_param_shape():\n", + " input_tokens, input_positions = jnp.array([[0]]), jnp.array([[0]])\n", + " with nn_partitioning.axis_rules(config.logical_axis_rules):\n", + " model = Transformer(config, mesh, quant=None)\n", + " model(input_tokens, input_positions)\n", + " #return nnx.state(model)\n", + " return nnx.split(model)\n", + "\n", + "def get_model_def():\n", + " input_tokens, input_positions = jnp.array([[0]]), jnp.array([[0]])\n", + " with nn_partitioning.axis_rules(config.logical_axis_rules):\n", + " model = Transformer(config, mesh, quant=None)\n", + " model(input_tokens, input_positions)\n", + " return nnx.graphdef(model)\n", + " #return nnx.split(model)[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n", + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n", + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n", + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . 
Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "gdef, params = jax.eval_shape(eval_param_shape)\n", + "state_logical_annotations = nn.get_partition_spec(params)\n", + "state_mesh_shardings = nn.logical_to_mesh_sharding(\n", + " state_logical_annotations, mesh, config.logical_axis_rules)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GraphDef(\n", + " nodedef=NodeDef(\n", + " type=Transformer,\n", + " index=0,\n", + " attributes=('config', 'decoder', 'mesh', 'quant', 'shared_embedding'),\n", + " subgraphs={\n", + " 'decoder': NodeDef(\n", + " type=LinenToNNX,\n", + " index=1,\n", + " attributes=('deterministic', 'initialized', 'linen_module', 'linen_state', 'rngs', 'use_running_average'),\n", + " subgraphs={\n", + " 'linen_state': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('params',),\n", + " subgraphs={\n", + " 'params': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('decoder_norm', 'layers', 'logits_dense'),\n", + " subgraphs={\n", + " 'decoder_norm': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('scale',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'scale': 2\n", + " },\n", + " metadata=PyTreeDef({'scale': *})\n", + " ),\n", + " 'layers': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('mlp', 'post_self_attention_layer_norm', 'pre_self_attention_layer_norm', 'self_attention'),\n", + " subgraphs={\n", + " 'mlp': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('wi_0', 'wi_1', 'wo'),\n", + " subgraphs={\n", + " 'wi_0': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 3\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'wi_1': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 4\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'wo': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 5\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'wi_0': *, 'wi_1': *, 'wo': *})\n", + " ),\n", + " 'post_self_attention_layer_norm': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('scale',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'scale': 6\n", + " },\n", + " metadata=PyTreeDef({'scale': *})\n", + " ),\n", + " 'pre_self_attention_layer_norm': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('scale',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'scale': 7\n", + " },\n", + " metadata=PyTreeDef({'scale': *})\n", + " ),\n", + " 'self_attention': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('key', 'out', 'query', 'value'),\n", + " subgraphs={\n", + " 'key': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 8\n", + " },\n", + " metadata=PyTreeDef({'kernel': 
*})\n", + " ),\n", + " 'out': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 9\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'query': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 10\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'value': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 11\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'key': *, 'out': *, 'query': *, 'value': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'mlp': *, 'post_self_attention_layer_norm': *, 'pre_self_attention_layer_norm': *, 'self_attention': *})\n", + " ),\n", + " 'logits_dense': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 12\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'decoder_norm': *, 'layers': *, 'logits_dense': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'params': *})\n", + " ),\n", + " 'rngs': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=(),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef(None)\n", + " )\n", + " },\n", + " static_fields={\n", + " 'deterministic': False,\n", + " 'initialized': True,\n", + " 'linen_module': Decoder(\n", + " # attributes\n", + " config = \n", + " shared_embedding = Embed(config=, num_embeddings=32000, features=2048, cast_input_dtype=None, dtype=dtype(bfloat16), attend_dtype=, embedding_init=.init at 0x7f49c411f740>, rngs=None)\n", + " mesh = Mesh(device_ids=array([[[[[[[0]]]],\n", + " \n", + " \n", + " \n", + " [[[[1]]]]]]]), axis_names=('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'))\n", + " quant = None\n", + " ),\n", + " 'use_running_average': False\n", + " },\n", + " leaves={},\n", + " metadata=\n", + " ),\n", + " 'quant': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=(),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef(None)\n", + " ),\n", + " 'shared_embedding': NodeDef(\n", + " type=Embed,\n", + " index=13,\n", + " attributes=('attend_dtype', 'cast_input_dtype', 'config', 'dtype', 'embedding', 'embedding_init', 'features', 'num_embeddings'),\n", + " subgraphs={\n", + " 'cast_input_dtype': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=(),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef(None)\n", + " )\n", + " },\n", + " static_fields={\n", + " 'attend_dtype': ,\n", + " 'config': ,\n", + " 'dtype': dtype(bfloat16),\n", + " 'embedding_init': .init at 0x7f49c411f740>,\n", + " 'features': 2048,\n", + " 'num_embeddings': 32000\n", + " },\n", + " leaves={\n", + " 'embedding': 14\n", + " },\n", + " metadata=\n", + " )\n", + " },\n", + " static_fields={\n", + " 'config': ,\n", + " 'mesh': 
Mesh(device_ids=array([[[[[[[0]]]],\n", + " \n", + " \n", + " \n", + " [[[[1]]]]]]]), axis_names=('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'))\n", + " },\n", + " leaves={},\n", + " metadata=\n", + " ),\n", + " index_mapping=None\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gdef" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "params = jax.jit(lambda: eval_param_shape()[1], out_shardings=state_mesh_shardings)()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + " 'embedding': VariableState(\n", + " type=Param,\n", + " value=LogicallyPartitioned(value=Array([[ 0.0575558 , -0.31785473, 0.11529455, ..., 1.0821939 ,\n", + " 1.4235774 , -1.2933688 ],\n", + " [-2.0068665 , -0.06486757, 0.1310754 , ..., -1.5467689 ,\n", + " 0.37397835, 0.41232687],\n", + " [-0.57422966, 0.1731033 , 0.9584525 , ..., 0.07480869,\n", + " 0.15087242, 0.41225332],\n", + " ...,\n", + " [ 0.7054784 , -0.4994459 , 0.07542419, ..., -1.2780907 ,\n", + " -0.12462003, 0.4509493 ],\n", + " [ 1.3809816 , -1.2765152 , 0.77147233, ..., 1.7020334 ,\n", + " 0.6716798 , -0.24864346],\n", + " [ 1.4495107 , 0.41864708, 1.412156 , ..., -1.0488809 ,\n", + " 0.12066022, 1.5232936 ]], dtype=float32), names=('fsdp', 'embed'), mesh=None, rules=None)\n", + " )\n", + "})" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params[\"shared_embedding\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "params2 = jax.tree.map(lambda x: x.unbox() if isinstance(x, meta.Partitioned) else x, params, is_leaf=lambda x: isinstance(x, meta.Partitioned))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "fwd_fn = jax.jit(lambda p, *inputs: nnx.merge(gdef, p)(*inputs))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GraphDef(\n", + " nodedef=NodeDef(\n", + " type=Transformer,\n", + " index=0,\n", + " attributes=('config', 'decoder', 'mesh', 'quant', 'shared_embedding'),\n", + " subgraphs={\n", + " 'decoder': NodeDef(\n", + " type=LinenToNNX,\n", + " index=1,\n", + " attributes=('deterministic', 'initialized', 'linen_module', 'linen_state', 'rngs', 'use_running_average'),\n", + " subgraphs={\n", + " 'linen_state': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('params',),\n", + " subgraphs={\n", + " 'params': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('decoder_norm', 'layers', 'logits_dense'),\n", + " subgraphs={\n", + " 'decoder_norm': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('scale',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'scale': 2\n", + " },\n", + " metadata=PyTreeDef({'scale': *})\n", + " ),\n", + " 'layers': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('mlp', 'post_self_attention_layer_norm', 'pre_self_attention_layer_norm', 'self_attention'),\n", + " subgraphs={\n", + " 'mlp': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('wi_0', 'wi_1', 'wo'),\n", + " subgraphs={\n", + " 'wi_0': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " 
attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 3\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'wi_1': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 4\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'wo': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 5\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'wi_0': *, 'wi_1': *, 'wo': *})\n", + " ),\n", + " 'post_self_attention_layer_norm': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('scale',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'scale': 6\n", + " },\n", + " metadata=PyTreeDef({'scale': *})\n", + " ),\n", + " 'pre_self_attention_layer_norm': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('scale',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'scale': 7\n", + " },\n", + " metadata=PyTreeDef({'scale': *})\n", + " ),\n", + " 'self_attention': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('key', 'out', 'query', 'value'),\n", + " subgraphs={\n", + " 'key': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 8\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'out': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 9\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'query': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 10\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " ),\n", + " 'value': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 11\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'key': *, 'out': *, 'query': *, 'value': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'mlp': *, 'post_self_attention_layer_norm': *, 'pre_self_attention_layer_norm': *, 'self_attention': *})\n", + " ),\n", + " 'logits_dense': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=('kernel',),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={\n", + " 'kernel': 12\n", + " },\n", + " metadata=PyTreeDef({'kernel': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'decoder_norm': *, 'layers': *, 'logits_dense': *})\n", + " )\n", + " },\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef({'params': *})\n", + " ),\n", + " 'rngs': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=(),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef(None)\n", + " )\n", + " },\n", + 
" static_fields={\n", + " 'deterministic': False,\n", + " 'initialized': True,\n", + " 'linen_module': Decoder(\n", + " # attributes\n", + " config = \n", + " shared_embedding = Embed(config=, num_embeddings=32000, features=2048, cast_input_dtype=None, dtype=dtype(bfloat16), attend_dtype=, embedding_init=.init at 0x7f49c411f740>, rngs=None)\n", + " mesh = Mesh(device_ids=array([[[[[[[0]]]],\n", + " \n", + " \n", + " \n", + " [[[[1]]]]]]]), axis_names=('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'))\n", + " quant = None\n", + " ),\n", + " 'use_running_average': False\n", + " },\n", + " leaves={},\n", + " metadata=\n", + " ),\n", + " 'quant': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=(),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef(None)\n", + " ),\n", + " 'shared_embedding': NodeDef(\n", + " type=Embed,\n", + " index=13,\n", + " attributes=('attend_dtype', 'cast_input_dtype', 'config', 'dtype', 'embedding', 'embedding_init', 'features', 'num_embeddings'),\n", + " subgraphs={\n", + " 'cast_input_dtype': NodeDef(\n", + " type=PytreeType,\n", + " index=-1,\n", + " attributes=(),\n", + " subgraphs={},\n", + " static_fields={},\n", + " leaves={},\n", + " metadata=PyTreeDef(None)\n", + " )\n", + " },\n", + " static_fields={\n", + " 'attend_dtype': ,\n", + " 'config': ,\n", + " 'dtype': dtype(bfloat16),\n", + " 'embedding_init': .init at 0x7f49c411f740>,\n", + " 'features': 2048,\n", + " 'num_embeddings': 32000\n", + " },\n", + " leaves={\n", + " 'embedding': 14\n", + " },\n", + " metadata=\n", + " )\n", + " },\n", + " static_fields={\n", + " 'config': ,\n", + " 'mesh': Mesh(device_ids=array([[[[[[[0]]]],\n", + " \n", + " \n", + " \n", + " [[[[1]]]]]]]), axis_names=('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'))\n", + " },\n", + " leaves={},\n", + " metadata=\n", + " ),\n", + " index_mapping=None\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gdef" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "ename": "UnexpectedTracerError", + "evalue": "Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[32000,2048] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.\nJAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.\nThe function being traced when the value leaked was jit_fn at /home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py:139 traced for jit.\n------------------------------\nThe leaked intermediate value was created on line /home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/linen/spmd.py:361:6 (with_logical_partitioning..wrapper). 
\n------------------------------\nWhen the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:\n------------------------------\n/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:88:2 (_graph_node_meta_call)\n/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:82:4 (ObjectMeta._object_meta_construct)\n:11:2 (__create_fn__..__init__)\n/home/rdyro/maxtext/MaxText/nnx_layers/embeddings.py:65:12 (Embed.__post_init__)\n/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/linen/spmd.py:361:6 (with_logical_partitioning..wrapper)\n------------------------------\n\nTo catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mUnexpectedTracerError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mfwd_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_positions\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[0;31m[... skipping hidden 11 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[10], line 1\u001b[0m, in \u001b[0;36m\u001b[0;34m(p, *inputs)\u001b[0m\n\u001b[0;32m----> 1\u001b[0m fwd_fn \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mjit(\u001b[38;5;28;01mlambda\u001b[39;00m p, \u001b[38;5;241m*\u001b[39minputs: \u001b[43mnnx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmerge\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgdef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m)\n", + "File \u001b[0;32m~/maxtext/MaxText/nnx_layers/models.py:483\u001b[0m, in \u001b[0;36mTransformer.__call__\u001b[0;34m(self, decoder_input_tokens, decoder_positions, decoder_segment_ids, enable_dropout, model_mode)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m decoder_segment_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m model_mode \u001b[38;5;241m==\u001b[39m common_types\u001b[38;5;241m.\u001b[39mMODEL_MODE_AUTOREGRESSIVE:\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 479\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDuring autoregressive decoding we assume the tokens are in the active sequence\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 480\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m which is always \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcommon_types\u001b[38;5;241m.\u001b[39mDECODING_ACTIVE_SEQUENCE_INDICATOR\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 481\u001b[0m )\n\u001b[0;32m--> 483\u001b[0m logits \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 484\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_input_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_input_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_positions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_positions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 486\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_segment_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_segment_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 487\u001b[0m \u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43menable_dropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m logits\n", + "File \u001b[0;32m~/maxtext/MaxText/nnx_layers/__init__.py:41\u001b[0m, in \u001b[0;36mLinenToNNX.__call__\u001b[0;34m(self, *args, **kw)\u001b[0m\n\u001b[1;32m 39\u001b[0m mutable_keys \u001b[38;5;241m=\u001b[39m [k \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlinen_state\u001b[38;5;241m.\u001b[39mkeys() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 40\u001b[0m linen_state \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mtree\u001b[38;5;241m.\u001b[39mmap(\u001b[38;5;28;01mlambda\u001b[39;00m x: x\u001b[38;5;241m.\u001b[39mvalue, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlinen_state)\n\u001b[0;32m---> 41\u001b[0m ret, updates \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinen_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlinen_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkw\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 42\u001b[0m \u001b[43m \u001b[49m\u001b[43mmutable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmutable_keys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;66;03m#print(f\"Update keys: {mutable_keys}\")\u001b[39;00m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_running_average \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdeterministic:\n", + " \u001b[0;31m[... 
skipping hidden 6 frame]\u001b[0m\n", + "File \u001b[0;32m~/maxtext/MaxText/nnx_layers/models.py:261\u001b[0m, in \u001b[0;36mDecoder.__call__\u001b[0;34m(self, decoder_input_tokens, decoder_positions, decoder_segment_ids, deterministic, model_mode)\u001b[0m\n\u001b[1;32m 258\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m decoder_input_tokens\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;66;03m# [batch, len]\u001b[39;00m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# [batch, length] -> [batch, length, emb_dim]\u001b[39;00m\n\u001b[0;32m--> 261\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared_embedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdecoder_input_tokens\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mint32\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 262\u001b[0m y \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mDropout(rate\u001b[38;5;241m=\u001b[39mcfg\u001b[38;5;241m.\u001b[39mdropout_rate, broadcast_dims\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m,))(y, deterministic\u001b[38;5;241m=\u001b[39mdeterministic)\n\u001b[1;32m 263\u001b[0m y \u001b[38;5;241m=\u001b[39m y\u001b[38;5;241m.\u001b[39mastype(cfg\u001b[38;5;241m.\u001b[39mdtype)\n", + "File \u001b[0;32m~/maxtext/MaxText/nnx_layers/embeddings.py:103\u001b[0m, in \u001b[0;36mEmbed.__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 101\u001b[0m output \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mdot(one_hot, jnp\u001b[38;5;241m.\u001b[39masarray(embedding, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype))\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 103\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43membedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m[inputs]\n\u001b[1;32m 104\u001b[0m output \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mwith_logical_constraint(output, (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mactivation_embed_and_logits_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mactivation_length\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mactivation_embed\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3568\u001b[0m, in \u001b[0;36masarray\u001b[0;34m(a, dtype, order, copy, device)\u001b[0m\n\u001b[1;32m 3566\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3567\u001b[0m dtype \u001b[38;5;241m=\u001b[39m dtypes\u001b[38;5;241m.\u001b[39mcanonicalize_dtype(dtype, allow_extended_dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;66;03m# type: ignore[assignment]\u001b[39;00m\n\u001b[0;32m-> 3568\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3459\u001b[0m, in \u001b[0;36marray\u001b[0;34m(object, dtype, copy, order, ndmin, device)\u001b[0m\n\u001b[1;32m 3457\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3458\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected input type for array: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mobject\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 3459\u001b[0m out_array: Array \u001b[38;5;241m=\u001b[39m \u001b[43mlax_internal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_convert_element_type\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3460\u001b[0m \u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweak_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweak_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msharding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msharding\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3461\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ndmin \u001b[38;5;241m>\u001b[39m ndim(out_array):\n\u001b[1;32m 3462\u001b[0m out_array \u001b[38;5;241m=\u001b[39m lax\u001b[38;5;241m.\u001b[39mexpand_dims(out_array, \u001b[38;5;28mrange\u001b[39m(ndmin \u001b[38;5;241m-\u001b[39m ndim(out_array)))\n", + " \u001b[0;31m[... skipping hidden 4 frame]\u001b[0m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1737\u001b[0m, in \u001b[0;36mDynamicJaxprTracer._assert_live\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_assert_live\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1736\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trace\u001b[38;5;241m.\u001b[39mmain\u001b[38;5;241m.\u001b[39mjaxpr_stack: \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[0;32m-> 1737\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m core\u001b[38;5;241m.\u001b[39mescaped_tracer_error(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n", + "\u001b[0;31mUnexpectedTracerError\u001b[0m: Encountered an unexpected tracer. 
A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[32000,2048] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.\nJAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.\nThe function being traced when the value leaked was jit_fn at /home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py:139 traced for jit.\n------------------------------\nThe leaked intermediate value was created on line /home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/linen/spmd.py:361:6 (with_logical_partitioning..wrapper). \n------------------------------\nWhen the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:\n------------------------------\n/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:88:2 (_graph_node_meta_call)\n/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:82:4 (ObjectMeta._object_meta_construct)\n:11:2 (__create_fn__..__init__)\n/home/rdyro/maxtext/MaxText/nnx_layers/embeddings.py:65:12 (Embed.__post_init__)\n/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/linen/spmd.py:361:6 (with_logical_partitioning..wrapper)\n------------------------------\n\nTo catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError" + ] + } + ], + "source": [ + "fwd_fn(params2, input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#nnx.merge(get_model_def(), params2)(input_tokens, input_positions)\n", + "with jax.checking_leaks():\n", + " fwd_fn(params2, input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n" + ], + "text/plain": [ + "┌───────┬───────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ CPU \u001b[1;36m0\u001b[0m │ CPU \u001b[1;36m1\u001b[0m │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "jax.debug.visualize_array_sharding(params2[\"shared_embedding\"][\"embedding\"].value)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n" + ], + "text/plain": [ + "┌──────────────────────────────────────────────────────────────────────────────┐\n", + "│ │\n", + "│ CPU \u001b[1;36m0\u001b[0m │\n", + "│ │\n", + "│ │\n", + "├──────────────────────────────────────────────────────────────────────────────┤\n", + "│ │\n", + "│ CPU \u001b[1;36m1\u001b[0m │\n", + "│ │\n", + "│ │\n", + "└──────────────────────────────────────────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "jax.debug.visualize_array_sharding(params2[\"decoder\"][\"linen_state\"][\"params\"][\"layers\"][\"mlp\"][\"wi_0\"][\"kernel\"].value[:, 0, ...])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n", + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n", + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n", + "/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: . Attempting to hash a tracer will lead to an error in a future JAX release.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "params_shape = jax.eval_shape(eval_param_shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + " 'decoder': {\n", + " 'linen_state': {\n", + " 'params': {\n", + " 'decoder_norm': {\n", + " 'scale': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048,), dtype=float32)\n", + " )\n", + " },\n", + " 'layers': {\n", + " 'mlp': {\n", + " 'wi_0': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048, 16, 7168), dtype=float32)\n", + " )\n", + " },\n", + " 'wi_1': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048, 16, 7168), dtype=float32)\n", + " )\n", + " },\n", + " 'wo': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(7168, 16, 2048), dtype=float32)\n", + " )\n", + " }\n", + " },\n", + " 'post_self_attention_layer_norm': {\n", + " 'scale': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048, 16), dtype=float32)\n", + " )\n", + " },\n", + " 'pre_self_attention_layer_norm': {\n", + " 'scale': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048, 16), dtype=float32)\n", + " )\n", + " },\n", + " 'self_attention': {\n", + " 'key': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048, 16, 16, 128), dtype=float32)\n", + " )\n", + " },\n", + " 'out': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(16, 16, 128, 2048), dtype=float32)\n", + " )\n", + " },\n", + " 'query': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " 
value=ShapeDtypeStruct(shape=(2048, 16, 16, 128), dtype=float32)\n", + " )\n", + " },\n", + " 'value': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048, 16, 16, 128), dtype=float32)\n", + " )\n", + " }\n", + " }\n", + " },\n", + " 'logits_dense': {\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(2048, 32000), dtype=float32)\n", + " )\n", + " }\n", + " }\n", + " }\n", + " },\n", + " 'shared_embedding': {\n", + " 'embedding': VariableState(\n", + " type=Param,\n", + " value=ShapeDtypeStruct(shape=(32000, 2048), dtype=float32)\n", + " ),\n", + " 'rngs': {\n", + " 'default': {\n", + " 'count': VariableState(\n", + " type=RngCount,\n", + " value=ShapeDtypeStruct(shape=(), dtype=uint32),\n", + " tag='default'\n", + " ),\n", + " 'key': VariableState(\n", + " type=RngKey,\n", + " value=ShapeDtypeStruct(shape=(), dtype=key),\n", + " tag='default'\n", + " )\n", + " }\n", + " }\n", + " }\n", + "})" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params_shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'params': {'decoder_norm': {'scale': LogicallyPartitioned(value=Param(\n", + " value=Array([1., 1., 1., ..., 1., 1., 1.], dtype=float32)\n", + " ), names=('norm',), mesh=None, rules=None)},\n", + " 'layers': {'mlp': {'wi_0': {'kernel': LogicallyPartitioned(value=Param(\n", + " value=Array([[[-1.46824829e-02, 1.97991189e-02, 4.92318161e-02, ...,\n", + " -2.27814689e-02, -1.77739235e-03, 7.30113918e-03],\n", + " [ 6.71353471e-03, 2.53547207e-02, 2.49213744e-02, ...,\n", + " 1.92197636e-02, 4.14784951e-03, -3.55075486e-02],\n", + " [ 2.83581223e-02, 3.04420362e-03, -8.07493739e-03, ...,\n", + " 6.35656202e-03, -3.63132101e-04, -8.82705022e-03],\n", + " ...,\n", + " [ 6.87895669e-03, 9.28692240e-03, 1.86208095e-02, ...,\n", + " 1.10543484e-03, 1.05740549e-02, 2.69666575e-02],\n", + " [ 3.79684679e-02, -4.19698209e-02, -4.50059175e-02, ...,\n", + " -2.73766350e-02, 4.00221422e-02, -1.40723391e-02],\n", + " [ 4.64158319e-03, -9.33785981e-04, -1.14092585e-02, ...,\n", + " -4.58411761e-02, 2.25836635e-02, 1.27054006e-02]],\n", + " \n", + " [[-4.04913798e-02, 3.38724777e-02, -3.74968746e-03, ...,\n", + " -2.00307509e-03, 2.26409640e-03, 1.51474886e-02],\n", + " [ 6.18777750e-03, -1.34507846e-02, 1.72182843e-02, ...,\n", + " -2.49324962e-02, 4.47796024e-02, 8.73657409e-03],\n", + " [-1.91211812e-02, -1.54849403e-02, -2.67096162e-02, ...,\n", + " 1.78834908e-02, 2.68951114e-02, -3.28964069e-02],\n", + " ...,\n", + " [-1.78560577e-02, 4.29300331e-02, -8.31132289e-03, ...,\n", + " -2.73101591e-03, 2.10166071e-02, 2.23623049e-02],\n", + " [-4.53571714e-02, -2.04789490e-02, -2.43269242e-02, ...,\n", + " 1.73852816e-02, 1.10745034e-03, -3.77959609e-02],\n", + " [-1.90838501e-02, 1.00691952e-02, 3.43200873e-06, ...,\n", + " 7.77074043e-03, 1.55691197e-02, -2.20722910e-02]],\n", + " \n", + " [[ 3.61189470e-02, -4.27912362e-02, -1.06900493e-02, ...,\n", + " 1.59358489e-03, 1.39744056e-03, -3.65842246e-02],\n", + " [ 1.11668864e-02, 3.53815816e-02, 4.75407615e-02, ...,\n", + " -3.76004949e-02, 1.23280250e-02, -1.76602218e-03],\n", + " [ 7.03665148e-03, -2.13977769e-02, 2.48558521e-02, ...,\n", + " -2.15282571e-02, 2.82202866e-02, -1.65731255e-02],\n", + " ...,\n", + " [ 9.53033101e-03, -4.23673429e-02, 7.94526562e-03, ...,\n", + " -2.16249805e-02, 1.12661840e-02, 1.83546072e-04],\n", 
+ " [ 1.64078325e-02, -1.22003397e-02, 1.86574105e-02, ...,\n", + " -8.56488571e-03, 4.03752960e-02, -1.16453422e-02],\n", + " [ 1.11683747e-02, 3.37623693e-02, 2.34768149e-02, ...,\n", + " 3.37207504e-02, 2.58489735e-02, -3.99921685e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.06673529e-02, 1.62901934e-02, -1.58934742e-02, ...,\n", + " -3.04784980e-02, -7.57506862e-03, -2.69532092e-02],\n", + " [ 1.48534160e-02, 1.74835064e-02, 4.35699150e-02, ...,\n", + " 2.65684705e-02, -1.07743908e-02, 3.03987395e-02],\n", + " [-4.83171903e-02, 3.29687223e-02, -9.76319425e-03, ...,\n", + " -4.29098606e-02, -4.14303057e-02, 5.52123832e-03],\n", + " ...,\n", + " [ 2.07692813e-02, 2.06820983e-02, 1.08975815e-02, ...,\n", + " 1.06183039e-02, -1.57454293e-02, 2.00878945e-03],\n", + " [-3.22883832e-03, -3.03809531e-02, 3.59453596e-02, ...,\n", + " -1.33549236e-02, 5.82014350e-03, 1.84253324e-03],\n", + " [-6.24883513e-04, -5.99902635e-03, 3.36017609e-02, ...,\n", + " -1.02666970e-02, -2.40748618e-02, -1.17630488e-03]],\n", + " \n", + " [[ 1.13392184e-02, 1.39854755e-02, -1.60739142e-02, ...,\n", + " 3.97099070e-02, -3.44168767e-02, 3.52152064e-02],\n", + " [-2.29297262e-02, 1.79340411e-02, -2.24616248e-02, ...,\n", + " 6.26819488e-03, -1.30856168e-02, -1.70083076e-03],\n", + " [-2.59619020e-02, -4.64426093e-02, -4.98010451e-03, ...,\n", + " 2.34234314e-02, 3.21373940e-02, 8.47257953e-03],\n", + " ...,\n", + " [ 1.99932139e-02, -4.01608348e-02, -1.48963686e-02, ...,\n", + " -2.41215713e-02, 3.44673358e-03, 2.91437414e-02],\n", + " [-3.81128080e-02, -1.31169497e-03, -2.68604811e-02, ...,\n", + " -1.91383976e-02, 1.82469580e-02, 3.86135988e-02],\n", + " [ 4.87306993e-03, 1.54371411e-02, 3.27008238e-05, ...,\n", + " -3.02628987e-03, 3.03459093e-02, -3.01156454e-02]],\n", + " \n", + " [[-1.56527795e-02, 2.43246201e-02, 1.49646932e-02, ...,\n", + " 8.33549909e-03, 1.17203360e-02, 3.33694508e-03],\n", + " [-1.45665100e-02, 9.50825959e-03, 1.97047945e-02, ...,\n", + " 1.06094982e-02, -2.73795035e-02, 2.57862210e-02],\n", + " [ 1.42094120e-02, -3.70660163e-02, -3.45249125e-03, ...,\n", + " -1.38101866e-02, 1.34967249e-02, -3.43502574e-02],\n", + " ...,\n", + " [ 3.12731676e-02, 1.88873825e-03, -4.28854674e-03, ...,\n", + " -8.72620102e-03, 3.68478112e-02, 2.49224752e-02],\n", + " [ 2.09684130e-02, 3.38535458e-02, 3.13230045e-03, ...,\n", + " 1.26158334e-02, 1.17295515e-03, -1.31474240e-02],\n", + " [ 1.77969635e-02, 2.23454926e-02, 1.40393171e-02, ...,\n", + " 5.43067558e-03, -2.21616328e-02, -2.42203125e-03]]], dtype=float32)\n", + " ), names=('embed', 'layers', 'mlp'), mesh=None, rules=None)},\n", + " 'wi_1': {'kernel': LogicallyPartitioned(value=Param(\n", + " value=Array([[[ 0.03565774, 0.03441827, 0.02874585, ..., 0.03410576,\n", + " 0.03435576, -0.00528875],\n", + " [-0.01005271, 0.03558357, 0.00193986, ..., -0.01204118,\n", + " 0.04934866, -0.00306885],\n", + " [ 0.00393492, -0.00620722, -0.0219217 , ..., 0.00182931,\n", + " 0.03783138, 0.00232728],\n", + " ...,\n", + " [-0.00120604, -0.01271016, 0.03092119, ..., 0.04444938,\n", + " 0.02749676, 0.00381973],\n", + " [ 0.0113045 , 0.01060509, -0.00763659, ..., 0.00148459,\n", + " 0.03150247, 0.02341036],\n", + " [-0.00897911, -0.00409753, 0.01125375, ..., -0.00259383,\n", + " -0.02653666, 0.00266445]],\n", + " \n", + " [[-0.00561835, 0.03607391, -0.02593885, ..., 0.03354722,\n", + " 0.02193886, -0.03198051],\n", + " [-0.04530573, -0.03533448, 0.02430564, ..., 0.00361731,\n", + " -0.02755632, -0.01669513],\n", + " [ 0.02448408, 0.00108543, -0.0269188 , ..., 
-0.01477237,\n", + " 0.03186396, 0.02323902],\n", + " ...,\n", + " [-0.00723606, 0.02727294, 0.01028091, ..., 0.01844209,\n", + " -0.02139114, 0.0090496 ],\n", + " [-0.00218486, 0.01491658, -0.00867404, ..., -0.02923171,\n", + " 0.0253869 , 0.01560827],\n", + " [ 0.00709936, 0.00571416, -0.00153445, ..., -0.04361975,\n", + " 0.03077151, -0.02040462]],\n", + " \n", + " [[ 0.00902917, -0.0264896 , 0.02174903, ..., 0.00893519,\n", + " -0.03779956, 0.00516337],\n", + " [-0.01593422, -0.00213347, 0.00992963, ..., -0.03922895,\n", + " 0.02405727, -0.00849431],\n", + " [-0.02404962, 0.01765902, 0.01439368, ..., 0.0071812 ,\n", + " -0.03078856, -0.03064431],\n", + " ...,\n", + " [-0.04452487, -0.01695952, -0.00798001, ..., -0.03105808,\n", + " 0.03840261, -0.03804101],\n", + " [-0.03912866, -0.02157033, 0.0172698 , ..., 0.03589072,\n", + " 0.00449919, -0.01690734],\n", + " [-0.03414461, -0.0199658 , 0.00251831, ..., 0.0452245 ,\n", + " 0.01691104, 0.02390193]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.00479411, 0.00946209, 0.01421728, ..., 0.01489236,\n", + " -0.03535576, 0.00109393],\n", + " [ 0.01023187, 0.01845231, 0.03831874, ..., -0.0117135 ,\n", + " -0.03645476, 0.04585216],\n", + " [-0.03881903, 0.01121489, 0.03912411, ..., 0.00688312,\n", + " 0.00834181, -0.01929677],\n", + " ...,\n", + " [ 0.0208546 , -0.04601963, 0.02617503, ..., 0.0054558 ,\n", + " -0.00307089, -0.02127012],\n", + " [ 0.01081495, -0.04890989, 0.0449676 , ..., -0.00358266,\n", + " -0.00696406, 0.03550996],\n", + " [ 0.00361995, -0.01770126, -0.02872516, ..., -0.00467019,\n", + " 0.00382693, -0.00923806]],\n", + " \n", + " [[ 0.0005481 , 0.04066269, -0.02653156, ..., 0.03153158,\n", + " -0.03530968, -0.04896537],\n", + " [ 0.00294591, 0.0302179 , -0.00401548, ..., 0.00696581,\n", + " 0.00399454, -0.03124043],\n", + " [ 0.00358062, 0.00385296, 0.03248601, ..., -0.02260119,\n", + " 0.00422479, -0.00980416],\n", + " ...,\n", + " [ 0.04386496, -0.00206178, 0.00615798, ..., -0.01405665,\n", + " 0.00951219, 0.01838201],\n", + " [ 0.01730962, -0.03085295, 0.01302828, ..., -0.01612841,\n", + " -0.02780299, -0.00122608],\n", + " [-0.03326694, 0.00283201, 0.04142758, ..., -0.00470959,\n", + " 0.00938603, 0.00546736]],\n", + " \n", + " [[-0.00952438, 0.03114594, -0.00877742, ..., -0.02477983,\n", + " -0.04234194, -0.00398861],\n", + " [ 0.004522 , 0.00986265, 0.0449134 , ..., 0.01936236,\n", + " 0.00350893, 0.04222813],\n", + " [-0.0231675 , -0.00156449, 0.01716213, ..., -0.02140141,\n", + " -0.01919747, -0.01885796],\n", + " ...,\n", + " [-0.01526466, -0.02172572, -0.00303553, ..., -0.01049941,\n", + " 0.00976944, -0.01317649],\n", + " [-0.0292046 , 0.00984967, -0.03089336, ..., -0.03161765,\n", + " -0.02122704, 0.0194024 ],\n", + " [ 0.03095125, 0.02289206, 0.00686356, ..., -0.03189711,\n", + " 0.00271455, 0.04876042]]], dtype=float32)\n", + " ), names=('embed', 'layers', 'mlp'), mesh=None, rules=None)},\n", + " 'wo': {'kernel': LogicallyPartitioned(value=Param(\n", + " value=Array([[[-9.07508098e-03, -3.76353087e-03, -2.11386643e-02, ...,\n", + " -4.97729145e-03, -2.30344795e-02, -5.21374773e-03],\n", + " [ 1.49644222e-02, 2.48974524e-02, -6.44102739e-03, ...,\n", + " -9.40560270e-03, -1.29256596e-05, 2.05899850e-02],\n", + " [-2.18340438e-02, 1.11139053e-02, 1.20316977e-02, ...,\n", + " -4.91878437e-03, -1.63546521e-02, 1.88892409e-02],\n", + " ...,\n", + " [ 1.67979114e-02, -7.91628938e-03, -6.65052095e-03, ...,\n", + " -9.00100078e-03, -3.35256173e-03, -8.91312584e-03],\n", + " [-5.00768889e-03, -8.05124454e-03, 
+    "  [full float32 parameter arrays elided -- this cell output records the freshly initialized decoder-layer state; every leaf is boxed as LogicallyPartitioned(value=Param(value=Array(...)), names=..., mesh=None, rules=None). Recoverable structure:]\n",
+    "    mlp kernel                                   names=('mlp', 'layers', 'embed')\n",
+    "    post_self_attention_layer_norm/scale (ones)  names=('norm', 'layers')\n",
+    "    pre_self_attention_layer_norm/scale (ones)   names=('norm', 'layers')\n",
+    "    self_attention/key/kernel                    names=('embed', 'layers', 'kv_heads', 'kv_head_dim')\n",
+    "    self_attention/out/kernel                    names=('heads', 'layers', 'kv', 'embed')\n",
+    "    self_attention/query/kernel                  (array values truncated)\n",
6.91065914e-04, 2.91744107e-03],\n", + " [-9.19053098e-04, 5.41089615e-03, -1.17358658e-03, ...,\n", + " 4.26690734e-04, 9.11797164e-04, 8.53107020e-04],\n", + " ...,\n", + " [ 2.22268180e-04, 2.15905742e-03, 6.81063029e-05, ...,\n", + " -2.65566370e-04, -3.98743578e-04, -9.66946653e-04],\n", + " [-9.39210702e-04, 1.44900451e-03, -2.40868816e-04, ...,\n", + " -1.64364942e-03, -1.81416050e-03, -3.08289705e-03],\n", + " [-2.93824269e-04, 4.24254173e-03, 8.94048368e-04, ...,\n", + " 1.90320599e-03, 5.80231826e-05, 1.35100691e-03]],\n", + " \n", + " [[-7.55565066e-04, 1.36540376e-03, 1.70124258e-04, ...,\n", + " 1.54401571e-03, -1.93784805e-03, -6.11010415e-04],\n", + " [-6.04787201e-04, -2.83261435e-03, -2.06062151e-03, ...,\n", + " -1.32616283e-03, 2.37897830e-03, -1.12545548e-03],\n", + " [ 3.66652757e-03, -3.69462953e-03, -4.61522024e-03, ...,\n", + " -1.94709550e-03, 6.11315627e-05, 8.75181693e-04],\n", + " ...,\n", + " [ 1.63245574e-03, -1.57542317e-03, -8.06179305e-04, ...,\n", + " -1.78271817e-04, -1.61887694e-03, 2.76966742e-03],\n", + " [-4.74011758e-03, 8.16200511e-04, 3.16539896e-04, ...,\n", + " -2.81311711e-03, -2.45223753e-03, -5.45330695e-04],\n", + " [-2.03077635e-03, -5.00337547e-03, -4.07255022e-03, ...,\n", + " -1.30678155e-03, 1.24201085e-03, -1.29970768e-03]],\n", + " \n", + " [[-4.52117296e-03, -3.16268852e-05, 3.85353895e-04, ...,\n", + " 1.59144984e-03, -2.38110148e-03, 2.48184055e-03],\n", + " [-1.28935569e-03, 2.57373811e-03, 2.94212904e-03, ...,\n", + " -5.61565510e-04, -7.35115464e-05, -3.42047913e-03],\n", + " [-2.21239243e-04, -1.15832384e-03, 1.08135364e-03, ...,\n", + " -4.36729228e-04, 3.60808475e-03, 8.21812602e-04],\n", + " ...,\n", + " [ 8.23041075e-04, -5.96798724e-04, 2.46903393e-04, ...,\n", + " -4.91942163e-04, 2.06866098e-04, -4.35019581e-04],\n", + " [-2.81880749e-03, -2.84888525e-03, -1.09620718e-03, ...,\n", + " -2.69397220e-04, 1.59656862e-03, -2.64589535e-03],\n", + " [-2.12765578e-03, 3.33865726e-04, 1.20675005e-03, ...,\n", + " 1.29106629e-03, -1.32574548e-03, -7.36644084e-04]],\n", + " \n", + " ...,\n", + " \n", + " [[ 3.23526119e-03, -3.99052026e-03, -2.08828118e-04, ...,\n", + " 2.66660145e-03, 2.10053753e-03, 3.37542081e-03],\n", + " [ 4.63207718e-04, 1.01864873e-03, 2.56253150e-03, ...,\n", + " 8.33682076e-04, -7.01415061e-04, -1.19667535e-03],\n", + " [ 3.29541188e-04, -7.66945013e-04, -3.33548873e-03, ...,\n", + " 1.15198782e-04, 6.42844476e-04, 1.88415416e-03],\n", + " ...,\n", + " [-2.37255497e-03, -3.17891943e-03, -4.51162038e-03, ...,\n", + " -3.05468752e-03, 1.62586628e-03, 4.91956249e-04],\n", + " [ 3.52150155e-03, -7.57674628e-04, -7.59741000e-04, ...,\n", + " -1.49976392e-03, 4.13274392e-03, -1.23380893e-03],\n", + " [-1.63820712e-03, -3.56243784e-03, 2.99573201e-03, ...,\n", + " 1.15981442e-03, 2.10368703e-03, -1.90565491e-03]],\n", + " \n", + " [[-1.68043072e-03, -7.33379973e-04, 1.13100593e-03, ...,\n", + " -1.47540949e-03, 1.16571796e-03, -1.62291632e-03],\n", + " [-7.54969427e-04, -7.51511368e-04, -1.31694577e-03, ...,\n", + " 1.92298542e-03, 2.01423699e-03, -2.87026423e-03],\n", + " [ 4.23348835e-03, 3.31194652e-03, 1.20164186e-03, ...,\n", + " -3.45917325e-03, -5.01860690e-04, -6.43040752e-04],\n", + " ...,\n", + " [-1.13918819e-03, 1.25175423e-03, -5.31003519e-04, ...,\n", + " 1.30030024e-03, 1.82856078e-04, 3.85270314e-03],\n", + " [ 1.88736513e-03, 1.40991912e-03, -1.96992129e-04, ...,\n", + " 1.47780683e-03, -3.49664269e-03, -5.96455648e-04],\n", + " [ 2.10178574e-03, -6.58707111e-04, -3.20032449e-03, ...,\n", + " 
4.44282312e-03, -9.55423529e-05, -6.79109362e-05]],\n", + " \n", + " [[ 4.32580709e-03, -1.76192063e-03, -8.17409891e-04, ...,\n", + " 1.47083332e-03, -4.84955177e-04, 2.61401339e-03],\n", + " [ 9.38941230e-05, 2.97831109e-04, -2.87603028e-03, ...,\n", + " -1.46740745e-03, 3.66672815e-04, 1.13328511e-03],\n", + " [ 8.38277279e-04, -4.03696817e-04, -2.04606727e-03, ...,\n", + " -1.77424017e-03, 6.40474609e-04, -1.81213778e-04],\n", + " ...,\n", + " [-9.99074429e-04, 8.79953790e-04, 2.73118983e-03, ...,\n", + " 4.27598227e-03, -1.27154599e-05, 2.96790688e-03],\n", + " [-7.35054957e-04, -1.12943852e-03, 4.87833656e-03, ...,\n", + " -1.77495775e-03, 2.17769132e-03, -7.87812984e-04],\n", + " [-5.18430781e-04, -9.90051660e-04, -1.82940188e-04, ...,\n", + " 7.18105119e-04, -1.10850716e-03, -2.89660739e-03]]],\n", + " \n", + " \n", + " [[[ 2.00630305e-03, -4.95062117e-03, -4.56968788e-04, ...,\n", + " -1.83992810e-03, 1.72986125e-03, -2.18573492e-03],\n", + " [ 1.02666032e-03, -9.63373634e-04, 1.31931133e-03, ...,\n", + " -1.08042883e-03, 3.69013008e-03, 4.98798210e-04],\n", + " [-1.49450160e-03, 2.07022694e-03, -1.25916011e-03, ...,\n", + " -4.21528355e-04, 1.17012812e-03, 2.67903018e-03],\n", + " ...,\n", + " [ 2.19121412e-03, -1.96273625e-03, -3.54813470e-04, ...,\n", + " -8.27860931e-05, 2.05961126e-03, 6.40570011e-04],\n", + " [-3.85378022e-04, -3.60000110e-03, 1.03685306e-03, ...,\n", + " -7.80751638e-04, 1.68622661e-04, -2.18253554e-04],\n", + " [ 1.79904362e-03, 8.04174168e-04, 8.65711714e-04, ...,\n", + " -1.36796746e-03, -2.21515936e-03, 4.18083742e-03]],\n", + " \n", + " [[ 2.64886185e-03, 9.37252364e-04, 5.05205011e-04, ...,\n", + " -1.78583933e-03, -3.10236076e-03, -1.77450886e-03],\n", + " [-7.71016232e-04, -2.01772014e-03, 3.56957875e-03, ...,\n", + " 2.09710281e-03, 4.32718371e-04, 9.70929337e-04],\n", + " [-6.02441782e-04, -3.46877705e-03, -4.35600610e-04, ...,\n", + " 3.12189193e-04, 1.23401522e-03, 4.83738352e-03],\n", + " ...,\n", + " [-1.91007310e-03, 5.97948325e-04, 1.64363731e-03, ...,\n", + " -1.59621777e-04, -6.47419074e-04, -2.56336294e-04],\n", + " [ 1.03034393e-03, 3.03110550e-03, -8.79029743e-04, ...,\n", + " 1.95900793e-03, 1.07618619e-03, 2.39637378e-03],\n", + " [ 7.04997437e-05, 1.85755501e-03, 1.01020315e-03, ...,\n", + " 6.83933380e-04, 3.04347283e-04, -6.33398440e-06]],\n", + " \n", + " [[-8.46583338e-04, 2.73965008e-04, 1.94287591e-03, ...,\n", + " -9.67003987e-04, 8.02981202e-04, 1.09669403e-03],\n", + " [-1.29807019e-03, -1.72013033e-03, 3.48320464e-04, ...,\n", + " 1.01493066e-03, 6.41077582e-04, 3.57675436e-03],\n", + " [-2.66979099e-04, -3.02361068e-03, 7.02642021e-04, ...,\n", + " -4.68853768e-03, 1.61933684e-04, -2.48300523e-04],\n", + " ...,\n", + " [-1.44773466e-03, -1.23606145e-03, -2.20396928e-03, ...,\n", + " -2.25923583e-03, -1.31231471e-04, -2.53896200e-04],\n", + " [-2.28308560e-03, 1.76166324e-03, 2.32184120e-03, ...,\n", + " 4.53253509e-03, 2.42436567e-04, 2.58734100e-04],\n", + " [-1.79768505e-03, -2.70375423e-03, -6.67729473e-04, ...,\n", + " 3.52663337e-03, -4.09787375e-04, -4.61249510e-05]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.34084048e-03, 4.55678877e-04, -5.53742866e-04, ...,\n", + " -5.79219894e-04, -9.84570244e-04, -3.76733276e-03],\n", + " [ 3.57258366e-03, -7.67665741e-04, -1.34740560e-03, ...,\n", + " -2.59171799e-03, 2.15514703e-03, -1.73505931e-03],\n", + " [-1.97921018e-03, 3.77656252e-04, 2.20133443e-04, ...,\n", + " -7.59347575e-04, 2.07287259e-03, 2.89211725e-03],\n", + " ...,\n", + " [ 2.49226158e-03, 2.89102271e-03, 
-2.69971183e-03, ...,\n", + " 6.72032766e-04, -1.12714886e-03, 1.31482127e-04],\n", + " [ 1.31130824e-03, 2.97283143e-04, 1.21282545e-04, ...,\n", + " 9.43228079e-04, -1.52130856e-03, 6.03098248e-04],\n", + " [-3.54440243e-04, -7.89023587e-04, 5.88736264e-04, ...,\n", + " -1.42484438e-04, -2.84745032e-03, -2.25020526e-03]],\n", + " \n", + " [[ 1.06017338e-03, 1.55658473e-03, -3.22846742e-03, ...,\n", + " -7.13421847e-04, 4.72694199e-04, 2.86514871e-04],\n", + " [ 2.53080670e-03, -1.72682886e-03, 3.88275919e-04, ...,\n", + " -1.06701092e-03, -5.00054099e-04, -1.33809773e-03],\n", + " [-6.95102324e-04, -1.15148688e-03, 2.09788210e-03, ...,\n", + " -4.50784544e-04, 1.14079696e-04, 2.03500036e-03],\n", + " ...,\n", + " [-3.23885027e-03, -1.39316218e-03, -3.44876666e-03, ...,\n", + " -7.55057437e-04, -4.25444543e-03, 4.08475142e-04],\n", + " [ 3.31097376e-03, 8.34188366e-04, 5.23621238e-05, ...,\n", + " 1.54662342e-03, -1.17046654e-03, -1.48927595e-03],\n", + " [ 2.07010005e-03, 1.74302724e-03, -1.60723494e-03, ...,\n", + " 1.12819602e-03, -9.21999163e-04, -7.72282772e-04]],\n", + " \n", + " [[ 1.57463260e-03, 1.23621873e-03, -1.29727251e-03, ...,\n", + " 5.93696139e-04, 4.51685250e-04, -1.03837647e-03],\n", + " [ 1.09677261e-03, -8.32310470e-04, 1.56681077e-03, ...,\n", + " -5.88828814e-04, -3.10627976e-03, 1.55787449e-03],\n", + " [-3.39843496e-03, 4.74721855e-05, 2.13810778e-03, ...,\n", + " 1.40511198e-03, -1.27306802e-03, 3.43295396e-03],\n", + " ...,\n", + " [ 2.37556454e-03, -7.40309653e-04, -1.80129474e-03, ...,\n", + " 7.77595676e-04, -5.10393875e-04, 4.90164943e-03],\n", + " [ 9.77985095e-04, 1.08225248e-03, -8.85272166e-04, ...,\n", + " 5.64833535e-05, 5.32481761e-04, 2.27898522e-03],\n", + " [ 7.14328198e-04, 1.32288283e-03, 1.14245317e-03, ...,\n", + " 2.22827354e-03, 1.35761010e-03, 8.66894785e-04]]],\n", + " \n", + " \n", + " [[[-8.43666901e-04, -2.43223505e-03, -2.36211228e-03, ...,\n", + " 1.55125992e-04, -8.11930222e-04, 3.67028639e-04],\n", + " [ 1.37160486e-03, -8.72522825e-04, -1.13000767e-03, ...,\n", + " -2.95673893e-03, 8.78142077e-04, 8.05620977e-04],\n", + " [-1.97373400e-03, 1.15039863e-03, -3.29498551e-03, ...,\n", + " -4.07117652e-03, 3.91823007e-03, 1.59254312e-04],\n", + " ...,\n", + " [-8.09907040e-04, -4.63840814e-04, 2.20341398e-03, ...,\n", + " -2.36079958e-03, 5.61598536e-05, 1.41227804e-03],\n", + " [ 2.01672185e-04, 1.96335604e-03, 1.49982388e-03, ...,\n", + " 2.51943804e-03, -1.17054128e-03, -5.91556076e-04],\n", + " [ 1.78214640e-03, -6.45851193e-04, 7.47291720e-04, ...,\n", + " 8.55092134e-04, 1.45126032e-04, -1.66391581e-03]],\n", + " \n", + " [[ 1.45349966e-03, 1.23377086e-03, 1.75998185e-03, ...,\n", + " -1.41576596e-03, 1.99613604e-03, 3.75185371e-03],\n", + " [ 3.98416072e-03, 8.72770208e-04, -4.13282309e-03, ...,\n", + " -4.29047039e-03, -4.32789419e-03, 2.53033242e-03],\n", + " [ 7.08233056e-05, 2.74872407e-03, -1.89754076e-03, ...,\n", + " 1.64624432e-03, 4.81134630e-04, -1.18645874e-03],\n", + " ...,\n", + " [ 1.61946591e-04, -1.57826871e-04, -9.06688685e-04, ...,\n", + " -8.53267207e-04, 1.08717661e-03, -8.50711018e-04],\n", + " [ 2.14377884e-03, 3.63736972e-03, 6.65535976e-04, ...,\n", + " 5.46371564e-03, 1.00986031e-03, 1.85510307e-03],\n", + " [ 3.63659579e-04, -1.43811316e-03, 2.70165550e-03, ...,\n", + " 2.26890971e-03, 4.18603094e-03, 3.54178861e-04]],\n", + " \n", + " [[ 2.25326070e-03, -2.37576314e-03, 3.88632558e-04, ...,\n", + " -2.10003019e-03, 1.11576184e-04, 1.19418104e-03],\n", + " [-4.65946621e-04, 2.03397311e-03, 3.26732756e-03, 
...,\n", + " 5.47510630e-04, -2.23908015e-03, 1.93968823e-03],\n", + " [-1.16974569e-03, 1.42465008e-03, -2.80775159e-04, ...,\n", + " 2.07854644e-03, -3.48520582e-03, -1.00561406e-03],\n", + " ...,\n", + " [-8.99634091e-04, -1.43154676e-03, -1.86001227e-04, ...,\n", + " -3.30833755e-05, -1.04856736e-03, -5.98404324e-04],\n", + " [ 3.85779096e-03, -2.46858317e-03, 8.08119716e-04, ...,\n", + " 1.95104536e-03, -2.86672451e-03, 2.45206640e-04],\n", + " [-1.27436535e-03, 7.71830091e-05, 1.67707109e-03, ...,\n", + " -1.34763587e-03, -2.44313938e-04, 9.55835043e-04]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.59653137e-03, -5.26345102e-05, -7.39213545e-04, ...,\n", + " 1.76239735e-03, -1.16860913e-03, -2.29923986e-03],\n", + " [-1.73654140e-03, -9.84819373e-04, -2.89764209e-03, ...,\n", + " -1.70617877e-03, 6.74121256e-04, 1.57688430e-03],\n", + " [-2.09680083e-03, 6.21461775e-04, -2.24610212e-05, ...,\n", + " -8.95891746e-04, -1.08481525e-03, 1.78939721e-04],\n", + " ...,\n", + " [-5.24749514e-03, 1.45776651e-03, -1.68309943e-03, ...,\n", + " 2.90547335e-03, 7.26866361e-04, 3.69962258e-03],\n", + " [ 4.19326313e-03, -9.62809427e-04, 1.69594641e-04, ...,\n", + " -9.21472310e-05, -4.84630145e-04, -5.59113594e-03],\n", + " [ 4.95695625e-04, -1.29670720e-04, 9.26245993e-04, ...,\n", + " 2.76659592e-03, -7.26665894e-04, -1.17162266e-03]],\n", + " \n", + " [[-3.46557819e-04, -1.78748823e-03, 9.99188400e-04, ...,\n", + " -5.31620753e-04, 1.07362331e-03, -4.52276086e-04],\n", + " [ 2.88659753e-03, 1.27288362e-03, -5.06431214e-04, ...,\n", + " 9.71116184e-04, 9.12718126e-04, 1.14133487e-04],\n", + " [ 2.28739838e-04, -3.31611256e-03, -1.97146367e-03, ...,\n", + " -1.96627295e-03, 2.65030703e-03, -1.35979161e-03],\n", + " ...,\n", + " [-3.28782108e-03, -1.96355046e-04, -1.51833403e-03, ...,\n", + " -6.41251565e-04, -4.32930421e-04, -9.70596913e-04],\n", + " [ 6.23346888e-04, -2.15629814e-03, -1.50742428e-03, ...,\n", + " -3.14926519e-03, 2.00443948e-03, -5.03565068e-04],\n", + " [-3.17326048e-03, -7.57485279e-04, 6.78380718e-04, ...,\n", + " -5.96601807e-04, -1.76243757e-05, 6.95572176e-04]],\n", + " \n", + " [[ 9.40771250e-04, 2.44649965e-03, -4.04537597e-04, ...,\n", + " 1.68599433e-03, -6.75203977e-04, -7.53352593e-04],\n", + " [ 9.48417932e-04, 4.11994843e-05, -2.99012219e-03, ...,\n", + " -1.31526857e-03, 2.84452783e-03, -2.04165070e-03],\n", + " [-4.16489202e-04, -2.52062059e-03, -7.13209040e-04, ...,\n", + " 4.82805015e-04, -4.03559860e-03, 1.59164844e-03],\n", + " ...,\n", + " [-1.64044905e-04, -8.18465021e-04, 3.18248081e-03, ...,\n", + " -2.26559397e-03, 4.88495221e-03, -1.20419846e-03],\n", + " [-1.37704040e-03, -6.33488351e-04, -3.65660846e-04, ...,\n", + " 1.93017040e-04, 2.77919346e-03, -4.99912363e-04],\n", + " [ 2.65176088e-04, -2.23132898e-03, 1.65310968e-03, ...,\n", + " 3.08551290e-03, -1.35207956e-03, -1.95106410e-03]]]], dtype=float32)\n", + " ), names=('embed', 'layers', 'heads', 'kv'), mesh=None, rules=None)},\n", + " 'value': {'kernel': LogicallyPartitioned(value=Param(\n", + " value=Array([[[[ 3.06652510e-03, -3.40593383e-02, 5.96511029e-02, ...,\n", + " -3.83924507e-03, -4.91481721e-02, -1.22140190e-02],\n", + " [-4.45110403e-04, -1.12190032e-02, -3.67582031e-02, ...,\n", + " 6.66676834e-03, -2.52818130e-02, 5.53275235e-02],\n", + " [ 1.29760029e-02, -2.44194288e-02, 6.40960224e-03, ...,\n", + " 5.75300539e-03, 4.83112633e-02, 7.17488676e-03],\n", + " ...,\n", + " [-2.69112978e-02, -2.17129402e-02, -2.08647326e-02, ...,\n", + " -2.31449939e-02, 1.11215003e-02, 
4.64028604e-02],\n", + " [-3.14151794e-02, 2.22921241e-02, 5.00940834e-04, ...,\n", + " 1.59594193e-02, -1.89934240e-03, 2.27000304e-02],\n", + " [-4.81172185e-03, -5.67634357e-03, -2.20337119e-02, ...,\n", + " 4.43634056e-02, 4.48966213e-02, -2.77173476e-05]],\n", + " \n", + " [[ 9.17493738e-03, -2.12259926e-02, -4.23315056e-02, ...,\n", + " 1.16434591e-02, -3.23921628e-02, -1.40072890e-02],\n", + " [-1.03522446e-02, 3.49451462e-03, -1.70773100e-02, ...,\n", + " -6.15655258e-03, 8.96936655e-03, 3.01204044e-02],\n", + " [-2.45943945e-02, 9.93571430e-03, -3.00959572e-02, ...,\n", + " -2.10262053e-02, -3.60233746e-02, -2.20616385e-02],\n", + " ...,\n", + " [ 1.83614844e-03, -9.12897009e-03, -4.70034033e-02, ...,\n", + " -2.06966940e-02, -1.16917146e-02, -4.04882357e-02],\n", + " [-5.87343471e-03, 4.76479158e-03, 5.43761486e-03, ...,\n", + " 1.53478701e-02, 4.94849198e-02, 4.39479202e-02],\n", + " [ 5.09543419e-02, -1.03180734e-02, 6.41846610e-03, ...,\n", + " -2.85183433e-02, 1.70078110e-02, 9.58583318e-03]],\n", + " \n", + " [[ 9.45213623e-03, -9.23942495e-03, 7.32704997e-03, ...,\n", + " 8.17284454e-03, -1.38332192e-02, -4.58166562e-02],\n", + " [ 1.21177714e-02, 1.73714980e-02, 1.76457148e-02, ...,\n", + " -1.05940029e-02, 5.91923995e-03, 2.80251876e-02],\n", + " [-7.46240839e-03, 4.93337177e-02, -2.09294762e-02, ...,\n", + " 2.62671113e-02, -1.13453902e-02, 7.63379794e-05],\n", + " ...,\n", + " [-3.73549722e-02, -6.08540233e-03, -2.07422627e-03, ...,\n", + " -8.60962737e-03, -4.87790033e-02, -3.51211876e-02],\n", + " [-7.81677291e-03, -2.54890416e-02, -1.88444667e-02, ...,\n", + " 1.76166762e-02, -2.41615456e-02, 1.83594134e-02],\n", + " [ 1.85655616e-02, -8.21443554e-03, 5.84631637e-02, ...,\n", + " 3.68116703e-03, -2.57569142e-02, -3.73738632e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.21473258e-02, 3.05725690e-02, 4.36328612e-02, ...,\n", + " -2.53117550e-02, -3.86569928e-03, 1.95729937e-02],\n", + " [-1.65991783e-02, 2.87079569e-02, -1.36366794e-02, ...,\n", + " -9.19500366e-03, 1.80108510e-02, -1.86498184e-02],\n", + " [ 2.67185587e-02, 1.56763550e-02, -8.81534815e-03, ...,\n", + " 1.68805607e-02, -2.09701713e-02, -2.29248088e-02],\n", + " ...,\n", + " [ 2.80966740e-02, 6.25316706e-03, -1.70223098e-02, ...,\n", + " 1.57940518e-02, 4.27884236e-03, -3.87805775e-02],\n", + " [-1.74644089e-03, 4.62792069e-03, -3.55610400e-02, ...,\n", + " 1.15110597e-04, 1.99411754e-02, -2.24533193e-02],\n", + " [ 2.65536625e-02, 1.14388391e-02, 1.18201505e-02, ...,\n", + " 3.33781317e-02, -1.30241383e-02, -1.18359737e-02]],\n", + " \n", + " [[ 2.47075558e-02, 2.51646992e-02, -3.84029001e-02, ...,\n", + " -2.81387437e-02, -1.58542562e-02, -2.11098418e-03],\n", + " [-1.44320522e-02, 1.56766362e-02, -9.69647802e-03, ...,\n", + " 3.63230560e-04, -5.17170392e-02, -1.73999202e-02],\n", + " [-6.78725680e-03, -2.10417453e-02, 1.51290959e-02, ...,\n", + " -1.36650316e-02, -1.01177972e-02, 2.28998438e-02],\n", + " ...,\n", + " [-3.87843363e-02, -2.44068243e-02, 2.40774751e-02, ...,\n", + " -3.17261666e-02, -1.96458120e-02, -1.40469708e-02],\n", + " [ 5.94725609e-02, -5.66626433e-03, -1.69408191e-02, ...,\n", + " 8.01148635e-05, 1.65617783e-02, 4.69114725e-03],\n", + " [-1.94618925e-02, -9.08343308e-03, -4.63657640e-02, ...,\n", + " -2.39744410e-02, 6.69258973e-03, -6.99904340e-04]],\n", + " \n", + " [[ 4.87205535e-02, 1.38993695e-04, 2.23054737e-03, ...,\n", + " -1.27633754e-02, -3.68132629e-02, 2.22445950e-02],\n", + " [-2.74003018e-02, 4.13585483e-04, -4.29724762e-03, ...,\n", + " -2.86933817e-02, 
2.78535150e-02, 2.72844005e-02],\n", + " [ 7.71898450e-03, -1.32342260e-02, 1.83440410e-02, ...,\n", + " 1.92444921e-02, 2.27058958e-02, -3.01651638e-02],\n", + " ...,\n", + " [ 6.22110383e-04, -1.51168248e-02, 6.38947543e-03, ...,\n", + " -2.40315683e-02, -2.66360212e-02, 2.32341439e-02],\n", + " [-1.25938491e-03, 4.07406548e-03, -1.42413229e-02, ...,\n", + " -1.83072891e-02, 4.22833394e-03, 1.94233898e-02],\n", + " [ 4.52653319e-02, 9.63478070e-03, 3.76412906e-02, ...,\n", + " -1.92716718e-02, -5.48598021e-02, -2.00155359e-02]]],\n", + " \n", + " \n", + " [[[-2.62019988e-02, -3.21775824e-02, 1.13541111e-02, ...,\n", + " 1.35714328e-02, 7.56171625e-03, 1.45942960e-02],\n", + " [ 1.80564802e-02, -1.25953695e-02, -2.00767666e-02, ...,\n", + " -1.34720020e-02, 1.15183508e-02, -1.16392365e-02],\n", + " [ 9.96008236e-03, -1.00148385e-02, -8.30273703e-03, ...,\n", + " -2.52189208e-02, 1.60079151e-02, 4.94243205e-03],\n", + " ...,\n", + " [-2.52499897e-03, -8.99344275e-04, -8.53826478e-03, ...,\n", + " 1.31221572e-02, 5.04696416e-03, 1.10431928e-02],\n", + " [-1.12230815e-02, 2.71696188e-02, -8.19974951e-03, ...,\n", + " 2.14992780e-02, -9.01464466e-03, -1.69684459e-02],\n", + " [-1.82418595e-03, -1.92433018e-02, -4.25211154e-02, ...,\n", + " -5.03852889e-02, -1.55010046e-02, 4.29662392e-02]],\n", + " \n", + " [[-2.04912350e-02, -1.92880873e-02, 1.90340169e-02, ...,\n", + " -1.48209212e-02, -3.58604975e-02, -1.75445173e-02],\n", + " [ 1.47496546e-02, -1.36547992e-02, -5.32851787e-03, ...,\n", + " 5.73085994e-02, -1.88640505e-02, 2.21359599e-02],\n", + " [ 4.06005755e-02, -2.76730815e-03, 1.81265399e-02, ...,\n", + " 3.41760404e-02, 8.10095342e-04, -1.31845530e-02],\n", + " ...,\n", + " [-3.81129272e-02, -1.10455826e-02, 8.38429946e-03, ...,\n", + " 3.88442054e-02, -9.68133658e-03, 2.79435311e-02],\n", + " [-5.28661255e-03, 2.29175631e-02, -1.32719362e-02, ...,\n", + " 2.81694066e-02, 1.11214090e-02, -3.70353125e-02],\n", + " [ 2.28716116e-02, 7.34687224e-03, -3.63622904e-02, ...,\n", + " 1.76609419e-02, 1.34537295e-02, -8.08722340e-03]],\n", + " \n", + " [[ 3.00175436e-02, -2.76538152e-02, 1.36637399e-02, ...,\n", + " 1.67092159e-02, -5.44693600e-03, -1.62662137e-02],\n", + " [ 1.07326992e-02, 7.60155031e-03, 1.78580098e-02, ...,\n", + " 1.86731610e-02, -1.19036157e-02, 1.50683969e-02],\n", + " [-3.77028957e-02, 6.09626994e-03, 2.46099476e-02, ...,\n", + " 1.92622524e-02, -3.44793778e-03, 6.06140634e-03],\n", + " ...,\n", + " [ 2.54833195e-02, -7.10646156e-04, 1.69997420e-02, ...,\n", + " 1.99005604e-02, -1.57875125e-04, -3.18801850e-02],\n", + " [-7.04139564e-03, -1.49664776e-02, -1.46921184e-02, ...,\n", + " -8.75255745e-03, 5.08271717e-02, 1.75600275e-02],\n", + " [-2.55666263e-02, 2.28687958e-03, -2.22181692e-03, ...,\n", + " -8.75537191e-03, -3.38227712e-02, -9.81499162e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 3.41032930e-02, -9.74745303e-03, -8.57152883e-03, ...,\n", + " -3.44515331e-02, 3.60203981e-02, 2.41048969e-02],\n", + " [ 7.50735449e-03, 6.85176626e-03, -4.04918455e-02, ...,\n", + " -5.57275675e-03, 2.69399509e-02, -1.15887402e-02],\n", + " [ 2.77798139e-02, -2.43773293e-02, 1.01266149e-02, ...,\n", + " 1.49775296e-02, 2.09767744e-02, 3.29823755e-02],\n", + " ...,\n", + " [ 5.14271576e-03, 2.36189216e-02, -1.48960091e-02, ...,\n", + " -6.18681088e-02, -3.10407858e-02, -4.38875845e-03],\n", + " [ 2.36156695e-02, 2.60603335e-02, 8.04961473e-03, ...,\n", + " 5.69023564e-03, -2.42884718e-02, 1.67629682e-02],\n", + " [ 1.58409458e-02, -3.55317555e-02, 1.37851909e-02, ...,\n", + 
" -1.87481027e-02, -2.03634333e-02, -2.92136576e-02]],\n", + " \n", + " [[ 2.18475722e-02, 5.55585697e-03, -2.35971119e-02, ...,\n", + " 7.93927629e-03, 8.65900423e-03, -1.62848141e-02],\n", + " [ 1.47624444e-02, 6.89860852e-03, -1.01652108e-02, ...,\n", + " 2.15822235e-02, -2.80216988e-03, -8.57752934e-03],\n", + " [ 1.22773759e-02, 1.80008030e-03, 5.52693196e-03, ...,\n", + " -1.47179002e-02, 1.65592954e-02, -1.44722769e-02],\n", + " ...,\n", + " [-7.16243265e-03, 3.83422598e-02, -1.71445720e-02, ...,\n", + " -1.57408230e-02, 2.93814316e-02, -9.90876090e-03],\n", + " [-2.20019612e-02, -1.03703085e-02, 2.44039875e-02, ...,\n", + " 1.62561685e-02, -4.09307703e-03, -4.61321836e-03],\n", + " [-1.82296541e-02, -3.38266678e-02, -1.25086252e-02, ...,\n", + " -7.27624539e-03, -3.06446850e-03, 1.82815529e-02]],\n", + " \n", + " [[-1.12073391e-03, -2.27560885e-02, 3.12993452e-02, ...,\n", + " 5.94655937e-03, 1.28314625e-02, 8.55658483e-03],\n", + " [ 9.73275118e-03, -1.59513168e-02, 1.02656409e-02, ...,\n", + " -1.20393978e-02, -1.33230891e-02, -2.73184329e-02],\n", + " [-2.67258789e-02, 8.50535743e-03, 2.84755696e-02, ...,\n", + " -1.95129458e-02, -9.14636301e-04, 1.85332540e-02],\n", + " ...,\n", + " [ 3.32031101e-02, -5.78648504e-03, 2.70566549e-02, ...,\n", + " 1.82050057e-02, -2.43596081e-03, -1.79888289e-02],\n", + " [-3.02373990e-03, 9.64739453e-03, -1.95111837e-02, ...,\n", + " -1.61980074e-02, 3.26526016e-02, -9.20403283e-03],\n", + " [ 9.02173852e-05, -1.90607030e-02, 3.33247297e-02, ...,\n", + " -1.55236712e-02, 7.43342238e-03, 1.34351226e-02]]],\n", + " \n", + " \n", + " [[[-7.59421987e-03, -1.74888372e-02, -4.63528857e-02, ...,\n", + " 9.10903607e-03, -9.87619627e-04, 3.95952305e-03],\n", + " [-1.62045117e-02, 3.34008560e-02, -1.97095070e-02, ...,\n", + " 9.35199298e-03, -8.93968716e-03, 1.00332703e-02],\n", + " [ 5.23450924e-03, 1.00945178e-02, 6.64633699e-03, ...,\n", + " -2.58395411e-02, 3.04494295e-02, 1.99235901e-02],\n", + " ...,\n", + " [-1.48946503e-02, -2.62932479e-02, -1.04750283e-02, ...,\n", + " -5.83845377e-03, 2.95074395e-04, 3.06571857e-03],\n", + " [ 3.21042612e-02, -1.37577001e-02, -2.46191509e-02, ...,\n", + " -7.59620569e-04, 1.73968896e-02, 3.98548599e-03],\n", + " [-2.41934191e-02, 2.44605541e-02, 2.10886472e-03, ...,\n", + " -9.42997448e-03, 3.73646081e-03, 2.13809172e-03]],\n", + " \n", + " [[-2.56869067e-02, -2.52917130e-02, -2.69120727e-02, ...,\n", + " 2.56822929e-02, 6.19478300e-02, 2.58485996e-03],\n", + " [ 3.24825272e-02, -2.34900564e-02, -4.41881567e-02, ...,\n", + " -4.86498792e-03, 1.99105572e-02, -4.00135890e-02],\n", + " [-2.92951353e-02, -2.67245080e-02, -2.83932351e-02, ...,\n", + " -4.03682254e-02, 1.63123105e-02, -7.63610657e-03],\n", + " ...,\n", + " [-3.34154628e-02, 2.03104550e-03, 2.28485297e-02, ...,\n", + " -5.08218221e-02, -2.01304704e-02, -4.03831806e-03],\n", + " [-6.07190020e-02, -3.66450809e-02, 1.70268100e-02, ...,\n", + " -7.61246309e-03, -2.53441688e-02, 4.31227125e-02],\n", + " [ 7.33191241e-03, 1.68190449e-02, 8.46065674e-03, ...,\n", + " -3.29000056e-02, -6.22411864e-03, 2.91801710e-03]],\n", + " \n", + " [[-7.75819086e-03, -8.56200419e-03, -2.64466144e-02, ...,\n", + " 1.88126322e-02, -5.42989932e-04, 9.22569889e-04],\n", + " [-2.84281839e-02, 2.43528504e-02, 5.60956867e-03, ...,\n", + " -2.80705038e-02, 2.85478518e-03, 1.99402235e-02],\n", + " [ 1.05457604e-02, 3.30954939e-02, 1.12595335e-02, ...,\n", + " -7.82961585e-03, -2.75975242e-02, -4.71524941e-03],\n", + " ...,\n", + " [ 4.90672560e-03, 7.55071128e-03, -8.94824159e-04, 
...,\n", + " 9.89169721e-03, -4.93725622e-03, -1.08299935e-02],\n", + " [-4.84592430e-02, 2.19271798e-02, -2.98727979e-03, ...,\n", + " 1.79016292e-02, -1.30447252e-02, -4.46950272e-02],\n", + " [ 1.14269052e-02, -8.27078149e-03, 1.72928590e-02, ...,\n", + " 2.07924489e-02, -1.69223472e-02, -2.15395726e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.39364237e-03, -4.12076600e-02, -4.12419774e-02, ...,\n", + " 9.67259705e-03, -9.85163916e-03, -3.41256615e-04],\n", + " [ 1.17117744e-02, -3.39599624e-02, -1.13945045e-02, ...,\n", + " 3.97023931e-03, 2.15489808e-02, -9.64823365e-03],\n", + " [-2.29518563e-02, 7.71512510e-03, -3.65445949e-02, ...,\n", + " 8.25929176e-03, 1.67346764e-02, -1.15262193e-03],\n", + " ...,\n", + " [-3.64172310e-02, 1.30174495e-02, 1.70525536e-02, ...,\n", + " 7.13298423e-03, -3.94735783e-02, -3.15543525e-02],\n", + " [ 5.97299933e-02, -1.43012321e-02, -8.13527789e-04, ...,\n", + " 5.50276414e-03, 2.08980683e-02, -1.07029872e-02],\n", + " [ 2.32037399e-02, 1.94064304e-02, -3.38720181e-03, ...,\n", + " -2.81267380e-03, 1.19097885e-02, -6.09530602e-03]],\n", + " \n", + " [[-3.87838073e-02, -2.20712908e-02, 5.16349508e-04, ...,\n", + " 4.82120216e-02, 6.93353638e-03, -4.42185625e-03],\n", + " [ 9.11362190e-03, 6.84799394e-04, -3.49367335e-02, ...,\n", + " 9.02116485e-03, 5.45205781e-03, -4.76168608e-03],\n", + " [-9.39901546e-03, -2.74084918e-02, -2.08995510e-02, ...,\n", + " 5.09775383e-03, 3.29988822e-02, 8.77375132e-04],\n", + " ...,\n", + " [-2.44022924e-02, -2.37628096e-03, 2.34068371e-02, ...,\n", + " -3.55339460e-02, -5.35484916e-03, -5.52246673e-03],\n", + " [ 6.19369152e-04, 2.14328207e-02, 2.03471235e-03, ...,\n", + " 2.79640895e-03, -1.94615889e-02, -3.47214863e-02],\n", + " [ 6.00966671e-03, 9.67412256e-03, -3.91207775e-03, ...,\n", + " 1.28835645e-02, -1.87539142e-02, 7.97474105e-03]],\n", + " \n", + " [[ 4.42411192e-02, 3.42848673e-02, 3.43981013e-02, ...,\n", + " -1.54774180e-02, -4.62314636e-02, -3.76372226e-02],\n", + " [ 1.50129497e-02, -2.02215696e-03, 6.96580391e-03, ...,\n", + " 2.14976873e-02, 1.61050279e-02, 2.71202046e-02],\n", + " [ 3.28233019e-02, -2.14138422e-02, -5.75384870e-03, ...,\n", + " 6.58221962e-03, -5.89284929e-04, -1.97059494e-02],\n", + " ...,\n", + " [-4.07181568e-02, -5.18422797e-02, 3.51646543e-02, ...,\n", + " -2.91977283e-02, 2.57178373e-03, -2.14522388e-02],\n", + " [ 2.91784331e-02, -3.59218195e-02, 4.06476296e-02, ...,\n", + " -1.63284652e-02, 1.95505023e-02, 1.01783033e-02],\n", + " [-3.82008143e-02, 1.03504187e-03, 2.64316946e-02, ...,\n", + " -1.18750669e-02, 1.51615236e-02, -2.46810284e-03]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-4.58724704e-03, 1.84262847e-03, 1.59414019e-02, ...,\n", + " -7.30866706e-03, 1.94087308e-02, -1.21839512e-02],\n", + " [-3.84173505e-02, 1.60658192e-02, -3.36556658e-02, ...,\n", + " 4.20719124e-02, -2.44133975e-02, -1.21817300e-02],\n", + " [-6.62909374e-02, -9.45606921e-03, 1.46906907e-02, ...,\n", + " 4.38689403e-02, -7.52058718e-03, 5.88563550e-03],\n", + " ...,\n", + " [-2.36306675e-02, -4.02815528e-02, -1.64191367e-03, ...,\n", + " -5.18989600e-02, -7.86373112e-03, 1.27146505e-02],\n", + " [ 1.62428431e-02, -1.80150811e-02, 1.40160695e-02, ...,\n", + " 2.84018684e-02, -1.56360231e-02, 2.25506574e-02],\n", + " [-1.26797641e-02, 1.82986483e-02, -2.31718887e-02, ...,\n", + " 3.32747437e-02, 1.79026648e-02, 2.18483582e-02]],\n", + " \n", + " [[-1.25358384e-02, -5.07524312e-02, -4.36438248e-03, ...,\n", + " 4.32352945e-02, -1.57412663e-02, -1.62391143e-03],\n", + " 
[-3.00274249e-02, -4.21895050e-02, -4.98003000e-03, ...,\n", + " 8.16286821e-03, -2.58241519e-02, 1.48327760e-02],\n", + " [ 7.31880451e-03, -3.23302224e-02, 1.24811409e-02, ...,\n", + " 1.01921009e-02, 2.39183824e-03, -1.84088740e-02],\n", + " ...,\n", + " [-1.57114454e-02, 1.13585033e-02, 2.76730936e-02, ...,\n", + " -2.91880909e-02, 1.70474313e-02, 3.28698545e-03],\n", + " [ 2.64854208e-02, 2.09573247e-02, -5.59193268e-03, ...,\n", + " 7.62763713e-03, 3.65343876e-03, -3.10379528e-02],\n", + " [-1.52877066e-02, -3.32704833e-04, 1.86000261e-02, ...,\n", + " 1.75157320e-02, 1.65743902e-02, -1.36229014e-02]],\n", + " \n", + " [[ 5.48482165e-02, -4.11887355e-02, 1.27601884e-02, ...,\n", + " 1.10465207e-03, -2.92004528e-03, -2.41212491e-02],\n", + " [ 7.49747828e-03, -1.03475144e-02, -2.32921038e-02, ...,\n", + " 1.13517065e-02, 1.80653408e-02, 3.02632581e-02],\n", + " [ 4.80407802e-03, -3.32131684e-02, -6.75722212e-03, ...,\n", + " -3.37354280e-02, 1.01653030e-02, 1.20083075e-02],\n", + " ...,\n", + " [ 2.26004310e-02, 5.17798103e-02, 3.06624789e-02, ...,\n", + " 2.80908942e-02, -2.65690926e-02, 2.48037186e-03],\n", + " [ 3.67342420e-02, -1.43716848e-02, 5.88846952e-02, ...,\n", + " 2.45979838e-02, -1.16324788e-02, -2.13733874e-03],\n", + " [-2.98646186e-02, 2.87309885e-02, 1.41075179e-02, ...,\n", + " 3.56943458e-02, 5.47803454e-02, -2.99454317e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 9.84334387e-03, -1.11082187e-02, 1.31320897e-02, ...,\n", + " 4.62813973e-02, -1.54536478e-02, -1.90398972e-02],\n", + " [ 1.15578668e-02, 5.06232791e-02, 6.15853304e-03, ...,\n", + " -4.54988284e-03, -2.10013706e-02, -2.85706930e-02],\n", + " [ 1.12148533e-02, -1.29149882e-02, -1.84205566e-02, ...,\n", + " -2.10194774e-02, -1.82451922e-02, -7.77353719e-03],\n", + " ...,\n", + " [ 2.56632958e-02, 1.33961244e-02, 1.94497164e-02, ...,\n", + " 2.26942115e-02, 5.38479630e-03, 3.45805176e-02],\n", + " [ 8.41068756e-03, -2.62886509e-02, 3.16147879e-02, ...,\n", + " 4.49473560e-02, 1.20342092e-03, -1.61337741e-02],\n", + " [ 3.53462920e-02, -1.53516391e-02, 1.62687283e-02, ...,\n", + " 2.82897986e-03, -4.38735774e-03, 1.54578174e-02]],\n", + " \n", + " [[-1.83449127e-02, 6.24263240e-03, 5.79582993e-03, ...,\n", + " 2.43802238e-02, 4.61507887e-02, -3.05204885e-03],\n", + " [ 1.28827188e-02, 1.49206603e-02, -1.59613602e-02, ...,\n", + " -2.22381465e-02, -2.54402477e-02, 2.86431354e-03],\n", + " [-5.00601307e-02, -1.18409574e-03, 1.71309114e-02, ...,\n", + " 1.20848315e-02, 1.89993810e-02, -2.35487078e-03],\n", + " ...,\n", + " [ 1.95662561e-03, -2.97833867e-02, 4.16429043e-02, ...,\n", + " 3.10091246e-02, -4.07143636e-03, 2.25523394e-02],\n", + " [-2.29367204e-02, 2.70870030e-02, 2.37994324e-02, ...,\n", + " 1.50637887e-02, -7.98132794e-04, -7.16836844e-03],\n", + " [-1.36653828e-02, -8.65338277e-03, 2.96106446e-03, ...,\n", + " 2.58903531e-03, -2.00002827e-02, 2.40406627e-03]],\n", + " \n", + " [[-3.74989421e-03, -2.69103814e-02, -1.37282973e-02, ...,\n", + " -2.74942094e-03, 5.02489321e-03, -5.55477710e-03],\n", + " [-1.59526784e-02, 9.80219990e-03, -5.22539299e-03, ...,\n", + " -9.71710775e-03, 1.13452058e-02, 8.27637315e-03],\n", + " [ 7.51705095e-03, 3.93428728e-02, 4.67957929e-03, ...,\n", + " 9.86303575e-03, -1.65611356e-02, 5.38460584e-03],\n", + " ...,\n", + " [ 2.96352841e-02, 2.02489607e-02, -2.15385910e-02, ...,\n", + " 1.58443283e-02, 6.32582530e-02, 1.81198213e-02],\n", + " [ 1.86992867e-03, -3.50251980e-02, -1.58548646e-04, ...,\n", + " -3.68681289e-02, 1.86539739e-02, -4.13889475e-02],\n", + " [ 
6.54638419e-03, -1.10186059e-02, 2.72593647e-03, ...,\n", + " 1.37353577e-02, -1.05523868e-02, -1.47886993e-02]]],\n", + " \n", + " \n", + " [[[ 2.96085104e-02, -8.64459726e-04, 1.03734676e-02, ...,\n", + " -1.93101298e-02, 2.15447005e-02, -2.30967999e-02],\n", + " [ 5.59520861e-03, 1.69687867e-02, 3.46487537e-02, ...,\n", + " -2.16689073e-02, -3.65057662e-02, 2.16613840e-02],\n", + " [-2.52702776e-02, -3.11615374e-02, -9.37360711e-03, ...,\n", + " -1.83256101e-02, -6.97410712e-03, 3.54643650e-02],\n", + " ...,\n", + " [-9.59516596e-03, 4.44376729e-02, 2.67800130e-02, ...,\n", + " -1.79972444e-02, 1.55405728e-02, -3.89852305e-03],\n", + " [-6.87336689e-03, -6.83802553e-03, -5.33916382e-03, ...,\n", + " -1.07136136e-02, -5.67768095e-03, 2.36731302e-02],\n", + " [ 1.66032724e-02, -3.83345671e-02, 1.15090553e-02, ...,\n", + " -2.14124266e-02, -1.31887654e-02, 4.18281108e-02]],\n", + " \n", + " [[-1.44259548e-02, 5.07674180e-03, -4.98639653e-03, ...,\n", + " -1.68191213e-02, -4.76384275e-02, 2.09134184e-02],\n", + " [ 4.31073420e-02, -8.46340973e-03, 8.46047234e-03, ...,\n", + " -9.49777197e-03, -1.60638355e-02, -1.59489177e-02],\n", + " [ 2.35831570e-02, -1.58593301e-02, 3.62763293e-02, ...,\n", + " 2.14741263e-03, -9.14497953e-03, 2.74271201e-02],\n", + " ...,\n", + " [-1.58326887e-02, -1.21537352e-03, 3.51192616e-03, ...,\n", + " -1.68082267e-02, -6.16165297e-03, -7.83039723e-03],\n", + " [-6.79581314e-02, 4.39122356e-02, 1.17320858e-03, ...,\n", + " 2.64919293e-03, -4.65588365e-03, 2.58365385e-02],\n", + " [-7.22379703e-03, -2.25982675e-03, 3.32807750e-03, ...,\n", + " 5.11071319e-03, -2.32066233e-02, 8.58908240e-03]],\n", + " \n", + " [[ 1.14204669e-02, 1.26187224e-02, -2.02171747e-02, ...,\n", + " 1.06056072e-02, 1.35602718e-02, -1.07879229e-02],\n", + " [-1.91155132e-02, -1.15982499e-02, 3.75030488e-02, ...,\n", + " -3.01600322e-02, -8.43106117e-03, 5.20722102e-03],\n", + " [ 8.55474826e-03, -2.38150712e-02, 1.90938637e-03, ...,\n", + " 5.33795792e-05, -3.30268852e-02, 1.40983250e-03],\n", + " ...,\n", + " [-3.49477888e-03, -1.50410691e-02, 1.31678740e-02, ...,\n", + " 9.57109407e-03, -2.06619725e-02, -1.52875977e-02],\n", + " [-2.17521345e-04, 2.06902511e-02, 1.68445949e-02, ...,\n", + " 2.32611280e-02, 1.21494792e-02, -3.16542131e-03],\n", + " [ 5.08197630e-03, -4.16894481e-02, 1.32774599e-02, ...,\n", + " 1.38081471e-02, -2.10417956e-02, -1.77357569e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-9.32270661e-03, -4.86975759e-02, -1.22559275e-02, ...,\n", + " 9.91007127e-03, -5.22737205e-02, 3.44520994e-02],\n", + " [ 2.99138613e-02, -3.09917349e-02, 3.80303040e-02, ...,\n", + " -2.13460624e-02, -4.87153651e-03, -2.64047552e-02],\n", + " [ 2.43745316e-02, -2.80445255e-02, -1.49960816e-02, ...,\n", + " -2.09649410e-02, 1.33474041e-02, 7.09103839e-03],\n", + " ...,\n", + " [-2.44469643e-02, -2.85505224e-02, -1.03405584e-02, ...,\n", + " 1.12081319e-02, -2.20116417e-04, -2.62159202e-03],\n", + " [-1.48854954e-02, -1.89094320e-02, 1.69817265e-02, ...,\n", + " 1.29670845e-02, -7.10745016e-03, -2.23227069e-02],\n", + " [ 5.00965491e-03, 6.61183032e-04, 2.95177032e-03, ...,\n", + " -6.37814477e-02, 1.02637561e-04, 2.14206260e-02]],\n", + " \n", + " [[ 2.08203234e-02, -3.28206643e-02, -1.80193428e-02, ...,\n", + " 3.27008404e-02, 2.36736541e-03, 2.52143815e-02],\n", + " [-1.14564821e-02, 3.51683050e-02, -2.54075564e-02, ...,\n", + " -1.98455248e-02, -1.60334688e-02, 1.26598291e-02],\n", + " [ 2.52072290e-02, -1.41177718e-02, 2.33329218e-02, ...,\n", + " 1.99852865e-02, 1.60512365e-02, 
1.72943398e-02],\n", + " ...,\n", + " [-3.89409885e-02, 2.57347133e-02, 2.67292224e-02, ...,\n", + " 4.47717197e-02, 6.62771380e-03, -1.28623890e-02],\n", + " [-8.40376969e-03, 9.19526722e-03, 1.17159765e-02, ...,\n", + " -4.06121509e-03, 1.35967098e-02, -1.33481603e-02],\n", + " [ 2.79057045e-02, -1.28383907e-02, 1.17101567e-02, ...,\n", + " -1.40030123e-02, -1.03625832e-02, 9.47575085e-03]],\n", + " \n", + " [[-1.02209095e-02, -1.28642032e-02, -2.01340988e-02, ...,\n", + " -3.31911887e-03, 2.95157637e-02, 1.13839433e-02],\n", + " [-3.32815275e-02, -1.09317852e-03, -9.99160483e-03, ...,\n", + " 8.19176715e-03, -9.69397661e-04, -1.54546918e-02],\n", + " [ 1.85882282e-02, -2.31177099e-02, 5.53894937e-02, ...,\n", + " -5.57155674e-03, 1.32098431e-02, 3.43774669e-02],\n", + " ...,\n", + " [-6.87400177e-02, -1.64256785e-02, -7.84636568e-03, ...,\n", + " 2.36314107e-02, 2.35060435e-02, -1.73750557e-02],\n", + " [ 1.24063510e-02, 6.88708015e-03, -1.09144384e-02, ...,\n", + " 2.15095934e-02, -2.39985064e-02, -2.54034027e-02],\n", + " [ 3.67148519e-02, 5.17751789e-03, 4.77852002e-02, ...,\n", + " 3.17434743e-02, -3.43416329e-03, -6.54102070e-03]]],\n", + " \n", + " \n", + " [[[-4.50115353e-02, 9.88601148e-03, 3.44616990e-03, ...,\n", + " 1.60399508e-02, -2.60483660e-02, -3.04819588e-02],\n", + " [ 3.09726167e-02, 1.07981935e-02, -1.26751475e-02, ...,\n", + " 2.30894201e-02, -5.98903513e-03, -1.99429840e-02],\n", + " [-8.74540210e-03, 4.40312587e-02, -7.14880088e-03, ...,\n", + " -6.83315611e-03, 2.13570315e-02, -4.90062498e-03],\n", + " ...,\n", + " [-3.58510427e-02, 5.99759445e-03, -8.54603387e-03, ...,\n", + " -1.97365992e-02, -3.76066156e-02, -1.65376235e-02],\n", + " [-1.46962302e-02, -1.32065946e-02, -3.22574895e-04, ...,\n", + " -2.25602239e-02, -3.99535522e-03, -1.79724246e-02],\n", + " [ 1.76086389e-02, -2.11215671e-02, -1.82414241e-02, ...,\n", + " 2.38824133e-02, -8.81620857e-04, 1.30195590e-03]],\n", + " \n", + " [[-3.22711375e-03, 1.82708371e-02, 4.94276248e-02, ...,\n", + " 3.03644557e-02, 2.07000505e-03, -8.10617674e-03],\n", + " [ 2.13704128e-02, -7.59003824e-03, 2.42410805e-02, ...,\n", + " -1.33883534e-02, 2.77286209e-02, 3.09803784e-02],\n", + " [ 1.27141853e-03, -1.63340699e-02, -4.49707452e-03, ...,\n", + " -1.56458020e-02, 5.17506152e-03, -2.21844930e-02],\n", + " ...,\n", + " [ 6.39585825e-03, 3.53232883e-02, 1.00959348e-03, ...,\n", + " 9.18802433e-03, -2.83297040e-02, -2.12442875e-02],\n", + " [ 2.15159580e-02, 7.05921371e-03, -2.35589258e-02, ...,\n", + " -2.06220932e-02, 3.91857177e-02, -3.53739783e-02],\n", + " [-4.15303605e-03, -1.16510410e-02, 1.38698295e-02, ...,\n", + " -1.40645690e-02, 2.97055617e-02, -4.08164710e-02]],\n", + " \n", + " [[-3.05376071e-02, 1.55272828e-02, 1.84483603e-02, ...,\n", + " -4.67726635e-03, -8.48224945e-03, -1.22447480e-02],\n", + " [ 4.33716998e-02, -1.16816917e-02, -6.10151212e-04, ...,\n", + " 2.18748022e-02, -4.94001396e-02, -2.81436015e-02],\n", + " [-2.35400032e-02, 1.10280747e-03, -2.27291640e-02, ...,\n", + " 5.13769165e-02, -2.17145234e-02, -6.94895396e-03],\n", + " ...,\n", + " [-3.02998759e-02, 4.54241410e-03, 1.39345182e-02, ...,\n", + " -1.45768784e-02, 2.66641490e-02, -1.92480292e-02],\n", + " [ 6.25158660e-03, -1.20887309e-02, -1.59647372e-02, ...,\n", + " 1.82863064e-02, 1.49304233e-02, -8.66006315e-03],\n", + " [-2.17061229e-02, 7.07721943e-03, 2.68889349e-02, ...,\n", + " -4.10804003e-02, 3.40832323e-02, -2.15050410e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 3.36539485e-02, 4.11366113e-02, 3.30055249e-03, ...,\n", + 
" 4.55072075e-02, -4.09245156e-02, -3.35720628e-02],\n", + " [ 1.23279830e-02, -3.17452066e-02, -1.41952904e-02, ...,\n", + " -2.83970672e-04, 3.97219881e-03, -7.91779440e-03],\n", + " [ 5.08769834e-03, -2.69417912e-02, 5.20863989e-03, ...,\n", + " -1.25344815e-02, -1.83151371e-03, -1.13596898e-02],\n", + " ...,\n", + " [ 9.15902574e-03, -4.73445393e-02, 1.75720900e-02, ...,\n", + " 1.07421428e-02, -6.49937848e-03, 4.14044550e-03],\n", + " [-1.40790374e-03, -1.59456525e-02, 5.83324069e-03, ...,\n", + " 2.82617062e-02, 1.76429171e-02, 4.65353066e-03],\n", + " [ 1.28202168e-02, 4.83296439e-03, 2.40389109e-02, ...,\n", + " 3.52953821e-02, -3.09269130e-02, 3.29072587e-02]],\n", + " \n", + " [[ 8.57099053e-03, -4.58021794e-04, 2.64320858e-02, ...,\n", + " 5.95255941e-02, -5.05150072e-02, -1.17604621e-02],\n", + " [-1.85479708e-02, 3.04677300e-02, -2.71590538e-02, ...,\n", + " -2.56795995e-02, 1.50768459e-02, 2.48940382e-02],\n", + " [-2.43159272e-02, -3.38991992e-02, -7.97818322e-03, ...,\n", + " 8.37430917e-03, 8.93322751e-03, 1.99266797e-04],\n", + " ...,\n", + " [ 1.59750916e-02, 2.16094661e-03, 1.78365912e-02, ...,\n", + " 1.34351449e-02, 1.85928424e-04, -2.35226825e-02],\n", + " [ 1.62189566e-02, -8.88638385e-03, -4.31055427e-02, ...,\n", + " 1.29797533e-02, 1.37055805e-02, 6.95320312e-03],\n", + " [ 3.24853463e-03, -1.37587907e-02, 1.60833336e-02, ...,\n", + " 2.20150892e-02, 9.29446146e-03, -2.00364329e-02]],\n", + " \n", + " [[-1.31356549e-02, 2.71320306e-02, 1.67173278e-02, ...,\n", + " -1.96653549e-02, 2.22931560e-02, -1.42931556e-02],\n", + " [ 4.78752190e-03, 3.20651685e-03, 1.92375090e-02, ...,\n", + " -2.60417387e-02, 3.57824825e-02, 3.00234780e-02],\n", + " [-1.53056290e-02, -3.27066779e-02, -8.34650174e-03, ...,\n", + " -1.14975059e-02, -3.99734788e-02, -2.92718410e-02],\n", + " ...,\n", + " [-2.92373765e-02, -5.84430713e-03, -4.38026935e-02, ...,\n", + " -1.44839135e-03, -7.65370950e-03, -2.85166763e-02],\n", + " [-1.45522235e-02, 6.76052994e-04, -2.38911714e-02, ...,\n", + " -1.94044542e-02, -1.98307056e-02, -4.34963824e-03],\n", + " [-2.31107175e-02, 2.11544465e-02, -9.99872107e-03, ...,\n", + " -7.38644600e-03, -1.42173497e-02, 3.14076385e-03]]]], dtype=float32)\n", + " ), names=('embed', 'layers', 'kv_heads', 'kv_head_dim'), mesh=None, rules=None)}}},\n", + " 'logits_dense': {'kernel': LogicallyPartitioned(value=Param(\n", + " value=Array([[ 0.00208639, 0.00065333, -0.01520132, ..., 0.02358625,\n", + " 0.01595038, 0.02565785],\n", + " [-0.01363917, 0.00701551, 0.02157802, ..., -0.02228342,\n", + " -0.032594 , -0.00968419],\n", + " [-0.01002284, -0.00648022, -0.0233501 , ..., 0.01772123,\n", + " -0.02631799, -0.02947669],\n", + " ...,\n", + " [ 0.00455413, 0.01614293, 0.00397578, ..., 0.00327452,\n", + " -0.0232591 , -0.00470191],\n", + " [-0.01218283, -0.01479897, 0.02100204, ..., 0.00534702,\n", + " 0.04443625, 0.0248835 ],\n", + " [-0.02129968, -0.00154596, 0.03655213, ..., -0.02471166,\n", + " -0.02714085, -0.01690293]], dtype=float32)\n", + " ), names=('embed', 'vocab'), mesh=None, rules=None)}}}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.decoder.linen_state" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + " 'embedding': VariableState(\n", + " type=Param,\n", + " value=Array([[ 0.0575558 , -0.31785473, 0.11529455, ..., 1.0821939 ,\n", + " 1.4235774 , -1.2933688 ],\n", + " [-2.0068665 , -0.06486757, 
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "State({\n",
+       "  'embedding': VariableState(\n",
+       "    type=Param,\n",
+       "    value=Array([[ 0.0575558 , -0.31785473,  0.11529455, ...,  1.0821939 ,\n",
+       "             1.4235774 , -1.2933688 ],\n",
+       "           [-2.0068665 , -0.06486757,  0.1310754 , ..., -1.5467689 ,\n",
+       "             0.37397835,  0.41232687],\n",
+       "           [-0.57422966,  0.1731033 ,  0.9584525 , ...,  0.07480869,\n",
+       "             0.15087242,  0.41225332],\n",
+       "           ...,\n",
+       "           [ 0.7054784 , -0.4994459 ,  0.07542419, ..., -1.2780907 ,\n",
+       "            -0.12462003,  0.4509493 ],\n",
+       "           [ 1.3809816 , -1.2765152 ,  0.77147233, ...,  1.7020334 ,\n",
+       "             0.6716798 , -0.24864346],\n",
+       "           [ 1.4495107 ,  0.41864708,  1.412156  , ..., -1.0488809 ,\n",
+       "             0.12066022,  1.5232936 ]], dtype=float32)\n",
+       "  ),\n",
+       "  'rngs': {\n",
+       "    'default': {\n",
+       "      'count': VariableState(\n",
+       "        type=RngCount,\n",
+       "        value=Array(2, dtype=uint32),\n",
+       "        tag='default'\n",
+       "      ),\n",
+       "      'key': VariableState(\n",
+       "        type=RngKey,\n",
+       "        value=Array((), dtype=key) overlaying:\n",
+       "        [0 0],\n",
+       "        tag='default'\n",
+       "      )\n",
+       "    }\n",
+       "  }\n",
+       "})"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "nnx.state(model.shared_embedding)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Num_devices: 2, shape (1, 1, 2, 1, 1, 1, 1)\n",
+      "Setting up checkpoint logger...\n",
+      "Creating checkpoint manager...\n",
+      "Checkpoint manager created!\n"
+     ]
+    }
+   ],
+   "source": [
+    "init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = train.setup_mesh_and_model(pyconfig.config)\n",
+    "with jax.profiler.trace(\"linen_init\"):\n",
+    "    params = model.init(init_rng, input_tokens, input_positions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'embedding': LogicallyPartitioned(value=Array([[ 1.4774086 ,  0.29882422, -0.34196898, ...,  0.50766706,\n",
+       "        -1.3415289 ,  1.7924749 ],\n",
+       "       [ 0.47449157, -1.9720819 ,  0.38987246, ...,  1.4013296 ,\n",
+       "        -0.44650635,  0.5926745 ],\n",
+       "       [-0.74953204, -1.6305867 , -0.9119016 , ...,  0.9319291 ,\n",
+       "        -0.24040265,  0.36705223],\n",
+       "       ...,\n",
+       "       [ 1.0246398 ,  0.14891908,  1.2458084 , ..., -0.3720324 ,\n",
+       "         1.6434435 , -0.6255694 ],\n",
+       "       [-1.3611012 , -1.3714991 ,  0.5478506 , ...,  3.409998  ,\n",
+       "        -0.0136618 ,  0.4574892 ],\n",
+       "       [-0.4890293 , -1.4788411 ,  0.851169  , ...,  0.9657814 ,\n",
+       "        -0.2644767 ,  1.1933228 ]], dtype=float32), names=('vocab', 'embed'), mesh=None, rules=None)}"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "params[\"params\"][\"token_embedder\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "state = nnx.state(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Update keys: []\n",
+      "Update keys: []\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Array([[[ 0.3628558 , -1.0749931 ,  0.5897113 , ...,  0.47235385,\n",
+       "         -1.8449324 ,  1.3828944 ]]], dtype=float32)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model(input_tokens, input_positions)"
+   ]
+  },
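+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# jit with nnx.split / nnx.merge\n",
+    "\n",
+    "To run an NNX module under `jax.jit`, split it into a hashable graph definition plus a state pytree, mark the graph definition as a static argument, and re-`merge` the two inside the traced function, as in the next cell."
+   ]
+  },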
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@functools.partial(jax.jit, static_argnums=(0,))\n",
+    "def fwd_fn(gdef, state, *inputs):\n",
+    "    return nnx.merge(gdef, state)(*inputs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Array([[[ 0.36080366, -1.0749974 ,  0.59387594, ...,  0.46911016,\n",
+       "         -1.8451215 ,  1.3860557 ]]], dtype=float32)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "fwd_fn(*nnx.split(model), input_tokens, input_positions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rng = nnx.Rngs(time.time_ns() % 2 ** 31)()\n",
+    "params = model.init(rng, input_tokens, input_positions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{0: (2048,),\n",
+      " 1: (2048, 16, 7168),\n",
+      " 2: (2048, 16, 7168),\n",
+      " 3: (7168, 16, 2048),\n",
+      " 4: (2048, 16),\n",
+      " 5: (2048, 16),\n",
+      " 6: (2048, 16, 16, 128),\n",
+      " 7: (16, 16, 128, 2048),\n",
+      " 8: (2048, 16, 16, 128),\n",
+      " 9: (2048, 16, 16, 128),\n",
+      " 10: (2048, 32000),\n",
+      " 11: (32000, 2048)}\n"
+     ]
+    }
+   ],
+   "source": [
+    "params_flat = jax.tree.flatten(params)[0]\n",
+    "pprint({i: x.shape for i, x in enumerate(params_flat)})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Sharding"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mesh = sharding.Mesh(jax.devices(\"cpu\"), (\"x\",))\n",
+    "shard = sharding.NamedSharding(mesh, sharding.PartitionSpec(\"x\", None))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "arr = nnx.with_partitioning(nnx.initializers.ones, shard)(nnx.Rngs(0), (128, 16))\n",
+    "arr = nnx.Variable(jnp.ones((128, 20)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 49,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "o = nn.with_logical_partitioning(jnp.zeros, (\"x\", None), mesh)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# LinenToNNX"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jaxfi as jaxm\n",
+    "mod = LinenToNNX(nn.BatchNorm(use_running_average=False), rngs=nnx.Rngs(0))\n",
+    "y = mod(jaxm.randn((10, 100)))"
+   ]
+  },
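+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, the same layer written natively in NNX (a sketch, assuming the imports above; `nnx.BatchNorm` takes the feature count up front and updates its batch statistics in place):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# sketch: native NNX BatchNorm for comparison with the LinenToNNX wrapper\n",
+    "ref_bn = nnx.BatchNorm(100, use_running_average=False, rngs=nnx.Rngs(0))\n",
+    "y_ref = ref_bn(jaxm.randn((10, 100)))"
+   ]
+  },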
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "296 μs ± 8.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "r = jaxm.randn((1, 100))\n",
+    "%timeit y = mod(r)"
+   ]
+  },
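+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With `use_running_average=False` the wrapped `BatchNorm` should mutate its Linen `batch_stats` collection on every call; the next cell prints the stats before and after one more forward pass to check that they change."
+   ]
+  },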
5.1530160e-04, 4.4251932e-03, 3.9685285e-03,\n", + " -4.5122160e-03, 1.1648417e-03, 5.6889076e-03, 8.0464687e-03,\n", + " -1.2249277e-03, 1.3089568e-03, -4.3154065e-03, 1.8955907e-03,\n", + " -5.2792724e-04, -7.7994156e-04, 7.4503319e-03, 2.3446612e-03,\n", + " 5.7236715e-03, 8.8289576e-03, 1.8565931e-03, -4.0224856e-03,\n", + " -9.0854121e-03, 3.5154931e-03, 2.6152381e-03, 2.8177741e-04,\n", + " 5.1246164e-03, -8.4920449e-04, -2.9435810e-03, 6.2172529e-03,\n", + " 1.1228498e-03, 4.1098618e-03, 6.5024028e-04, 5.6879362e-03,\n", + " 3.8265260e-03, 4.6147103e-03, 4.2743082e-03, -3.9841952e-03,\n", + " -4.7089322e-03, 2.3843069e-03, -2.7297654e-03, 7.8746444e-03,\n", + " 9.0348283e-03, -6.3487482e-03, 2.9564318e-03, -6.2560937e-03,\n", + " 1.3615972e-04, -5.1170862e-03, 4.7581387e-03, 6.4143995e-03,\n", + " 4.8178812e-03, -2.2680149e-03, -7.4463087e-04, 3.7080129e-03,\n", + " -6.3487394e-05, 1.2278042e-04, 4.9312841e-03, -3.5790501e-03,\n", + " 6.8373820e-03, 7.0673223e-03, -4.7117486e-03, -2.0354297e-03,\n", + " -1.2460124e-03, -2.3877760e-03, -1.7705668e-03, 1.6839999e-03,\n", + " 3.4304420e-03, 7.1969694e-03, 7.4808377e-05, 3.6029713e-03,\n", + " 7.8885294e-03, -3.8607048e-03, -5.1787621e-03, -2.7489376e-03,\n", + " 1.7790885e-03, 5.2325344e-03, -6.4765373e-03, -4.3493495e-03,\n", + " -5.5444185e-03, 8.3459392e-03, -2.9847890e-03, 5.6494232e-03,\n", + " -7.1711997e-03, -4.5145974e-03, -6.6269690e-04, -1.1403039e-02,\n", + " 1.5156582e-03, 5.7607614e-03, 2.8842646e-03, 6.0570636e-03,\n", + " -3.9650500e-03, 9.1814566e-03, 4.9557528e-03, 7.2096945e-03,\n", + " -2.3154306e-04, -6.5131253e-03, 5.6542759e-03, 2.3517762e-03,\n", + " -7.1707079e-03, 3.6808623e-03, 4.9370120e-04, 3.4018788e-03], dtype=float32)\n", + "), 'var': Variable(\n", + " value=Array([1.0023384 , 0.9964494 , 0.99732035, 0.99333113, 1.0048463 ,\n", + " 0.99717253, 1.0118234 , 0.9916133 , 1.0008204 , 0.99870276,\n", + " 0.9894864 , 0.9942008 , 0.9934598 , 0.99853426, 0.9946769 ,\n", + " 1.000786 , 0.99866813, 0.99252796, 0.9906824 , 0.9978659 ,\n", + " 0.988714 , 0.99534506, 0.99482954, 0.9910098 , 0.9910871 ,\n", + " 1.0073289 , 0.9946512 , 0.9903945 , 0.99340665, 0.9921589 ,\n", + " 1.0136465 , 0.99685293, 1.0036296 , 1.0014584 , 0.99795246,\n", + " 0.99337196, 1.001991 , 0.99551386, 1.0015377 , 0.9967534 ,\n", + " 0.99382496, 0.997434 , 0.9905447 , 0.9987997 , 0.9977848 ,\n", + " 0.9952865 , 0.99183404, 0.99689627, 0.991924 , 0.9907353 ,\n", + " 0.9956136 , 0.9948083 , 1.0019448 , 1.0068014 , 0.995616 ,\n", + " 1.0036167 , 0.9915774 , 1.0079556 , 0.99606764, 0.98830104,\n", + " 1.0129006 , 1.0011445 , 0.9961759 , 0.99989605, 0.9944935 ,\n", + " 0.99695563, 0.99888873, 0.98671234, 1.0061815 , 0.99892527,\n", + " 0.9900453 , 0.9921596 , 0.99695516, 0.9955983 , 0.99430966,\n", + " 0.9991283 , 0.99156207, 0.9987522 , 0.99729806, 0.99824554,\n", + " 0.99144685, 0.9953179 , 0.9969685 , 0.99307317, 0.998193 ,\n", + " 1.005656 , 0.989351 , 1.0140046 , 0.99514276, 0.994439 ,\n", + " 0.9950554 , 0.9921042 , 0.98929805, 0.9975349 , 0.99167955,\n", + " 0.99616444, 1.0018144 , 0.99656606, 1.0046463 , 0.9952588 ], dtype=float32)\n", + ")}\n" + ] + } + ], + "source": [ + "old_batch_stats = mod.state[\"batch_stats\"]\n", + "print(old_batch_stats)\n", + "y = mod(jaxm.randn((10, 100)))\n", + "print(mod.state[\"batch_stats\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "mod.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], 
+ "source": [ + "model = nnx.Linear(100, 100, rngs=nnx.Rngs(0))\n", + "opt = nnx.Optimizer(model, optax.adam(1e-5))\n", + "grads = nnx.split(model)[1]\n", + "grads = nnx.grad(lambda model: jnp.mean(model(jaxm.randn((4, 100)))))(model)\n", + "opt.update(grads)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Updating keys from env and command line: ['hardware', 'enable_single_controller']\n", + "Running Model: default\n", + "Updating keys from model: []\n", + "Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n", + "dataset_type set to tfds, will use keys['dataset_path']='' and keys['dataset_name']='c4/en:3.0.1'\n", + "Config param adam_b1: 0.9\n", + "Config param adam_b2: 0.95\n", + "Config param adam_eps: 1e-08\n", + "Config param adam_eps_root: 0.0\n", + "Config param adam_weight_decay: 0.1\n", + "Config param add_bos: True\n", + "Config param add_eos: True\n", + "Config param allow_split_physical_axes: False\n", + "Config param ar_cache_axis_order: 1,2,0,3\n", + "Config param async_checkpointing: True\n", + "Config param attention: autoselected\n", + "Config param attention_type: global\n", + "Config param attn_logits_soft_cap: None\n", + "Config param autoregressive_decode_assert: \n", + "Config param base_emb_dim: 2048\n", + "Config param base_mlp_dim: 7168\n", + "Config param base_num_decoder_layers: 16\n", + "Config param base_num_kv_heads: 16\n", + "Config param base_num_query_heads: 16\n", + "Config param base_output_directory: \n", + "Config param checkpoint_dir: \n", + "Config param checkpoint_is_quantized: False\n", + "Config param checkpoint_period: 10000\n", + "Config param collect_stack_trace: False\n", + "Config param compile_topology: \n", + "Config param compile_topology_num_slices: -1\n", + "Config param compiled_trainstep_file: \n", + "Config param compute_axis_order: 0,1,2,3\n", + "Config param cosine_learning_rate_final_fraction: 0.1\n", + "Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n", + "Config param data_shuffle_seed: 0\n", + "Config param dataset_name: c4/en:3.0.1\n", + "Config param dataset_path: \n", + "Config param dataset_type: tfds\n", + "Config param dcn_autoregressive_parallelism: 1\n", + "Config param dcn_data_parallelism: -1\n", + "Config param dcn_fsdp_parallelism: 1\n", + "Config param dcn_fsdp_transpose_parallelism: 1\n", + "Config param dcn_pipeline_parallelism: 1\n", + "Config param dcn_sequence_parallelism: 1\n", + "Config param dcn_tensor_parallelism: 1\n", + "Config param decode_sampling_nucleus_p: -1\n", + "Config param decode_sampling_strategy: greedy\n", + "Config param decode_sampling_temperature: 1.0\n", + "Config param decode_sampling_top_k: 0\n", + "Config param decoder_block: llama2\n", + "Config param dropout_rate: 0\n", + "Config param dtype: bfloat16\n", + "Config param emb_dim: 2048\n", + "Config param enable_checkpoint_cloud_logger: False\n", + "Config param enable_checkpoint_standard_logger: False\n", + "Config param enable_checkpointing: True\n", + "Config param enable_data_shuffling: True\n", + "Config param enable_dropout: True\n", + "Config param enable_emergency_checkpoint: False\n", + "Config param enable_goodput_recording: False\n", + "Config param enable_jax_profiler: False\n", + "Config param enable_model_warmup: False\n", + "Config param enable_single_controller: True\n", + "Config param 
enable_single_replica_ckpt_restoring: False\n", + "Config param eval_data_column: text\n", + "Config param eval_dataset_name: c4/en:3.0.1\n", + "Config param eval_interval: -1\n", + "Config param eval_per_device_batch_size: 0\n", + "Config param eval_split: validation\n", + "Config param eval_steps: -1\n", + "Config param expansion_factor_real_data: -1\n", + "Config param final_logits_soft_cap: None\n", + "Config param force_unroll: False\n", + "Config param fused_mlp: False\n", + "Config param fused_qkv: False\n", + "Config param gcs_metrics: False\n", + "Config param global_batch_size_to_load: 96\n", + "Config param global_batch_size_to_train_on: 96\n", + "Config param global_parameter_scale: 1\n", + "Config param goodput_upload_interval_seconds: 60\n", + "Config param gradient_accumulation_steps: 1\n", + "Config param gradient_clipping_threshold: 1.0\n", + "Config param grain_eval_files: \n", + "Config param grain_train_files: \n", + "Config param grain_worker_count: 1\n", + "Config param hardware: other\n", + "Config param head_dim: 128\n", + "Config param hf_access_token: \n", + "Config param hf_data_dir: \n", + "Config param hf_eval_files: \n", + "Config param hf_eval_split: \n", + "Config param hf_path: \n", + "Config param hf_train_files: \n", + "Config param ici_autoregressive_parallelism: 1\n", + "Config param ici_data_parallelism: 1\n", + "Config param ici_fsdp_parallelism: -1\n", + "Config param ici_fsdp_transpose_parallelism: 1\n", + "Config param ici_pipeline_parallelism: 1\n", + "Config param ici_sequence_parallelism: 1\n", + "Config param ici_tensor_parallelism: 1\n", + "Config param inference_metadata_file: \n", + "Config param inference_microbenchmark_log_file_path: \n", + "Config param inference_microbenchmark_loop_iters: 10\n", + "Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n", + "Config param inference_microbenchmark_stages: prefill,generate\n", + "Config param init_weights_seed: 0\n", + "Config param jax_cache_dir: ~/jax_cache\n", + "Config param jax_profiler_port: 9999\n", + "Config param kv_quant_axis: heads_and_dkv\n", + "Config param kv_quant_dtype: int8\n", + "Config param learning_rate: 3e-05\n", + "Config param learning_rate_schedule_steps: 150001\n", + "Config param load_from_prefill_dir: False\n", + "Config param load_full_state_path: \n", + "Config param load_parameters_path: \n", + "Config param local_checkpoint_directory: \n", + "Config param local_checkpoint_period: 0\n", + "Config param log_period: 100\n", + "Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('norm', 'tensor'), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', 
('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()))\n", + "Config param logits_dot_in_fp32: True\n", + "Config param logits_via_embedding: False\n", + "Config param max_checkify: False\n", + "Config param max_corpus_chars: 10000000\n", + "Config param max_prefill_predict_length: 64\n", + "Config param max_target_length: 2048\n", + "Config param megablox: True\n", + "Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n", + "Config param metrics_file: \n", + "Config param micro_batch_size_to_train_on: 96\n", + "Config param mlp_activations: ['silu', 'linear']\n", + "Config param mlp_dim: 7168\n", + "Config param model_name: default\n", + "Config param monitor_goodput: False\n", + "Config param normalization_layer_epsilon: 1e-05\n", + "Config param normalize_embedding_logits: True\n", + "Config param num_decoder_layers: 16\n", + "Config param num_experts: 1\n", + "Config param num_experts_per_tok: 1\n", + "Config param num_kv_heads: 16\n", + "Config param num_layers_per_pipeline_stage: 1\n", + "Config param num_pipeline_microbatches: -1\n", + "Config param num_pipeline_repeats: -1\n", + "Config param num_query_heads: 16\n", + "Config param num_slices: 1\n", + "Config param opt_type: adamw\n", + "Config param param_scan_axis: 1\n", + "Config param per_device_batch_size: 12.0\n", + "Config param prefill_cache_axis_order: 1,2,0,3\n", + "Config param prefill_cache_dir: \n", + "Config param profiler: \n", + "Config param profiler_steps: 5\n", + "Config param prometheus_port: 0\n", + "Config param prompt: I love to\n", + "Config param quant_cfg_path: \n", + "Config param quantization: \n", + "Config param quantization_local_shard_count: 1\n", + "Config param quantize_kvcache: False\n", + "Config param record_internal_nn_metrics: 0\n", + "Config param remat_policy: full\n", + "Config param reshape_q: False\n", + "Config param reuse_example_batch: 0\n", + "Config param rope_max_timescale: 10000\n", + "Config param rope_min_timescale: 1\n", + "Config param run_name: None\n", + "Config param save_config_to_gcs: False\n", + "Config param save_quantized_params_path: \n", + "Config param scan_layers: True\n", + "Config param scan_pipeline_iterations: True\n", + "Config param skip_first_n_steps_for_profiler: 1\n", + "Config param sliding_window_size: 0\n", + "Config param stack_trace_interval_seconds: 600\n", + "Config param stack_trace_to_cloud: False\n", + "Config param steps: 150001\n", + "Config param target_eval_loss: 0.0\n", + "Config param tensorboard_dir: \n", + "Config param tokenize_eval_data: True\n", + "Config param tokenize_train_data: True\n", + "Config param tokenizer_path: assets/tokenizer.llama2\n", + "Config param train_data_column: text\n", + "Config param trainable_position_size: -1\n", + "Config param upload_all_profiler_results: False\n", + "Config param use_iota_embed: False\n", + "Config param use_post_attn_norm: False\n", + "Config param use_post_ffw_norm: False\n", + "Config param use_untrainable_positional_embedding: False\n", + "Config param use_vertex_tensorboard: False\n", + "Config param using_pipeline_parallelism: False\n", + "Config param vertex_tensorboard_project: \n", + "Config param vertex_tensorboard_region: \n", + "Config param vocab_size: 32000\n", + "Config param warmup_steps_fraction: 0.1\n", + "Config param weight_dtype: float32\n", + "Num_devices: 8, shape (1, 1, 8, 1, 1, 1, 1)\n", + "Setting up checkpoint logger...\n", + "Creating checkpoint manager...\n", + "Checkpoint manager 
created!\n" + ] + } + ], + "source": [ + "#os.environ.setdefault(\"PROCESS_ID\", \"0\")\n", + "#os.environ.setdefault(\"JAX_PROCESS_COUNT\", \"1\")\n", + "#os.environ.setdefault(\"PROCESS_IN_JOB\", \"0\")\n", + "#os.environ.setdefault(\"JAX_COORDINATOR_ADDRESS\", \"127.0.0.1\")\n", + "#os.environ.setdefault(\"JAX_COORDINATOR_IP\", \"127.0.0.1\")\n", + "#os.environ.setdefault(\"JAX_COORDINATOR_PORT\", str(65323))\n", + "#os.environ.setdefault(\"NNODES\", str(1))\n", + "#os.environ.setdefault(\"NODE_RANK\", str(0))\n", + "#\n", + "##os.environ[\"COORDINATOR_ADDRESS\"] = \"127.0.0.1:1234\"\n", + "#os.environ.setdefault(\"JOB_INDEX\", \"0\")\n", + "#os.environ.setdefault(\"JOB_COMPLETION_INDEX\", \"0\")\n", + "#os.environ[\"PROCESSES_IN_JOB\"] = \"1\"\n", + "##os.environ[\"JAX_PROCESS_COUNT\"] = \"1\"\n", + "pyconfig.initialize([\"python3\", \"MaxText/configs/base.yml\", \"hardware=other\", \"enable_single_controller=True\"])\n", + "config = dict(pyconfig.config.get_keys())\n", + "\n", + "init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = train.setup_mesh_and_model(pyconfig.config)\n", + "input_tokens = jnp.array([[0]])\n", + "input_positions = jnp.array([[0]])\n", + "params = model.init(init_rng, input_tokens, input_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'batch_stats': {'mean': Variable(\n", + " value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\n", + " ),\n", + " 'var': Variable(\n", + " value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)\n", + " )},\n", + " 'params': {'bias': Param(\n", + " value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\n", + " ),\n", + " 'scale': Param(\n", + " value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)\n", + " )}}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "mod.state" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "mod.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nnx.Dropout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nnx.BatchNorm" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LazyNNX(\n", + " init_args=(),\n", + " initialized=True,\n", + " linen_mod=BatchNorm(\n", + " # attributes\n", + " use_running_average = True\n", + " axis = -1\n", + " momentum = 0.99\n", + " epsilon = 1e-05\n", + " dtype = None\n", + " param_dtype = float32\n", + " use_bias = True\n", + " use_scale = True\n", + " bias_init = zeros\n", + " scale_init = ones\n", + " axis_name = None\n", + " axis_index_groups = None\n", + " use_fast_variance = True\n", + " force_float32_reductions = True\n", + " ),\n", + " state={'batch_stats': {'mean': Variable(\n", + " value=Array(shape=(100,), dtype=float32)\n", + " ), 'var': Variable(\n", + " value=Array(shape=(100,), dtype=float32)\n", + " )}, 'params': {'bias': Param(\n", + " value=Array(shape=(100,), dtype=float32)\n", + " ), 'scale': Param(\n", + " value=Array(shape=(100,), dtype=float32)\n", + " )}}\n", + ")\n" + ] + } + ], + "source": [ + "nnx.display(mod)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "@functools.partial(jax.jit, static_argnames=[\"graphdef\"])\n", + "def loss_fn(graphdef, params, *args):\n", + " return jnp.sum(nnx.merge(graphdef, params)(*args))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + " 'state': {\n", + " 'params': {\n", + " 'bias': VariableState(\n", + " type=Param,\n", + " value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)\n", + " ),\n", + " 'kernel': VariableState(\n", + " type=Param,\n", + " value=Array([[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)\n", + " )\n", + " }\n", + " }\n", + "})" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graphdef, params = nnx.split(mod)\n", + "loss_fn(graphdef, params, jnp.ones(100))\n", + "grad_fn = jax.jit(jax.grad(loss_fn, argnums=1))\n", + "grad_fn(graphdef, params, jnp.ones(100))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "@functools.partial(nnx.vmap, axis_size=5, in_axes=(None, None))\n", + "def make_model(rngs: nnx.Rngs, x: jax.Array):\n", + " mod = LazyNNX(nn.Dense, 100, rngs=rngs)\n", + " y = mod(x) # run the model in accordance with linen's convention (nn.compact)\n", + " return mod" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "vmap_mod = make_model(nnx.Rngs(0), jnp.ones((1, 100)))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([[-0.22681893, 0.12948306, -0.5660833 , 0.54388836, -0.13207269,\n", + " -0.28127185, 0.13340358, 0.21858714, 0.0342543 , 0.40834817,\n", + " 0.13434777, -0.4555197 , 0.06381328, -0.58547539, -0.76521007,\n", + " 0.16313892, -0.26837662, -0.54795653, 0.18024976, 0.34415041,\n", + " -0.13140569, 0.56270905, -0.10050073, 0.38099573, -0.37621883,\n", + " 0.08581573, 0.47219038, -0.55236467, -0.43320429, -0.39558568,\n", + " 0.09934793, 0.15027676, -0.07010986, -0.22549472, -0.08241178,\n", + " 0.15171293, -0.33565122, -0.5415724 , -0.44550078, -0.02126699,\n", + " -0.00690109, -0.00483307, 0.01377263, 0.22923029, -0.30733769,\n", + " 0.13028407, -0.23582329, -0.06795412, 0.17091568, -0.48275513,\n", + " 0.45854019, -0.15280847, 0.69775289, -0.27580605, 0.41420138,\n", + " 0.4121478 , -0.21960127, -0.3110823 , 0.09149418, 0.60862518,\n", + " 0.15215905, -0.7924634 , 0.27823312, -0.56006468, -0.19038152,\n", + " -0.40971096, -0.07249114, -0.1114347 , -0.32514797, -0.58197642,\n", + " -0.21161082, -0.17796598, -0.31446279, -0.35074647, -0.12421286,\n", + " 0.26287658, 0.34816247, 0.16237974, -0.14378001, 0.39320464,\n", + " 0.04536457, 0.2551306 , 0.16496384, -0.16709305, 0.07018957,\n", + " 0.22217788, -0.53453768, -0.31078607, -0.399209 , -0.39957056,\n", + " -0.45992775, 0.59042818, -0.26991543, 0.34353292, 0.40906526,\n", + " 0.26577448, 0.19427761, 0.22063262, 0.15312979, 0.0455256 ]], dtype=float64),\n", + " None)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nnx.scan(lambda x, mod: (jax.nn.tanh(mod(x)), None))(\n", + " jnp.ones((1, 100)),\n", + " vmap_mod, \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "X = jaxm.randn((10 ** 3,), dtype=jaxm.float32)\n", + "y = jaxm.sin(X)\n", + "\n", + "@functools.partial(jax.jit, static_argnums=(0,))\n", + "def loss_fn(graphdef, params, x, y):\n", + " model = nnx.merge(graphdef, params)\n", + " x = x[..., None] + jnp.zeros(100)\n", + " yp = nnx.scan(lambda x, mod: (jax.nn.tanh(mod(x)), None))(\n", + " x,\n", + " model, \n", + " )[0]\n", + " return jaxm.mean((yp[..., 0] - y) ** 2)\n", + " \n", + "@functools.partial(jax.jit, static_argnums=(0,))\n", + "def grad_fn(graphdef, params, x, y):\n", + " return jax.grad(loss_fn, argnums=1)(graphdef, params, x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss = 7.5165e-01\n", + "loss = 6.2398e-01\n", + "loss = 5.0591e-01\n", + "loss = 4.0034e-01\n", + "loss = 3.0931e-01\n", + "loss = 2.3373e-01\n", + "loss = 1.7333e-01\n", + "loss = 1.2687e-01\n", + "loss = 9.2456e-02\n", + "loss = 6.7913e-02\n", + "loss = 5.1097e-02\n", + "loss = 4.0097e-02\n", + "loss = 3.3329e-02\n", + "loss = 2.9542e-02\n", + "loss = 2.7792e-02\n", + "loss = 2.7382e-02\n", + "loss = 2.7813e-02\n", + "loss = 2.8731e-02\n", + "loss = 2.9894e-02\n", + "loss = 3.1132e-02\n", + "loss = 3.2337e-02\n", + "loss = 3.3435e-02\n", + "loss = 3.4382e-02\n", + "loss = 3.5152e-02\n", + "loss = 3.5733e-02\n", + "loss = 3.6122e-02\n", + "loss = 3.6320e-02\n", + "loss = 3.6336e-02\n", + "loss = 
3.6179e-02\n", + "loss = 3.5860e-02\n", + "loss = 3.5394e-02\n", + "loss = 3.4793e-02\n", + "loss = 3.4072e-02\n", + "loss = 3.3246e-02\n", + "loss = 3.2330e-02\n", + "loss = 3.1339e-02\n", + "loss = 3.0288e-02\n", + "loss = 2.9192e-02\n", + "loss = 2.8068e-02\n", + "loss = 2.6928e-02\n", + "loss = 2.5789e-02\n", + "loss = 2.4666e-02\n", + "loss = 2.3570e-02\n", + "loss = 2.2517e-02\n", + "loss = 2.1518e-02\n", + "loss = 2.0584e-02\n", + "loss = 1.9723e-02\n", + "loss = 1.8945e-02\n", + "loss = 1.8253e-02\n", + "loss = 1.7652e-02\n", + "loss = 1.7141e-02\n", + "loss = 1.6719e-02\n", + "loss = 1.6381e-02\n", + "loss = 1.6119e-02\n", + "loss = 1.5925e-02\n", + "loss = 1.5786e-02\n", + "loss = 1.5692e-02\n", + "loss = 1.5629e-02\n", + "loss = 1.5584e-02\n", + "loss = 1.5546e-02\n", + "loss = 1.5504e-02\n", + "loss = 1.5452e-02\n", + "loss = 1.5382e-02\n", + "loss = 1.5291e-02\n", + "loss = 1.5179e-02\n", + "loss = 1.5046e-02\n", + "loss = 1.4894e-02\n", + "loss = 1.4728e-02\n", + "loss = 1.4552e-02\n", + "loss = 1.4370e-02\n", + "loss = 1.4187e-02\n", + "loss = 1.4007e-02\n", + "loss = 1.3834e-02\n", + "loss = 1.3669e-02\n", + "loss = 1.3514e-02\n", + "loss = 1.3370e-02\n", + "loss = 1.3238e-02\n", + "loss = 1.3117e-02\n", + "loss = 1.3005e-02\n", + "loss = 1.2902e-02\n", + "loss = 1.2806e-02\n", + "loss = 1.2716e-02\n", + "loss = 1.2630e-02\n", + "loss = 1.2547e-02\n", + "loss = 1.2466e-02\n", + "loss = 1.2386e-02\n", + "loss = 1.2307e-02\n", + "loss = 1.2227e-02\n", + "loss = 1.2148e-02\n", + "loss = 1.2068e-02\n", + "loss = 1.1989e-02\n", + "loss = 1.1910e-02\n", + "loss = 1.1831e-02\n", + "loss = 1.1754e-02\n", + "loss = 1.1678e-02\n", + "loss = 1.1604e-02\n", + "loss = 1.1532e-02\n", + "loss = 1.1463e-02\n", + "loss = 1.1396e-02\n", + "loss = 1.1331e-02\n", + "loss = 1.1269e-02\n", + "loss = 1.1209e-02\n", + "loss = 1.1152e-02\n", + "loss = 1.1096e-02\n", + "loss = 1.1043e-02\n", + "loss = 1.0991e-02\n", + "loss = 1.0940e-02\n", + "loss = 1.0891e-02\n", + "loss = 1.0843e-02\n", + "loss = 1.0796e-02\n", + "loss = 1.0750e-02\n", + "loss = 1.0705e-02\n", + "loss = 1.0661e-02\n", + "loss = 1.0617e-02\n", + "loss = 1.0575e-02\n", + "loss = 1.0534e-02\n", + "loss = 1.0493e-02\n", + "loss = 1.0454e-02\n", + "loss = 1.0416e-02\n", + "loss = 1.0379e-02\n", + "loss = 1.0343e-02\n", + "loss = 1.0308e-02\n", + "loss = 1.0274e-02\n", + "loss = 1.0242e-02\n", + "loss = 1.0210e-02\n", + "loss = 1.0179e-02\n", + "loss = 1.0149e-02\n", + "loss = 1.0120e-02\n", + "loss = 1.0091e-02\n", + "loss = 1.0064e-02\n", + "loss = 1.0037e-02\n", + "loss = 1.0011e-02\n", + "loss = 9.9856e-03\n", + "loss = 9.9609e-03\n", + "loss = 9.9368e-03\n", + "loss = 9.9135e-03\n", + "loss = 9.8907e-03\n", + "loss = 9.8687e-03\n", + "loss = 9.8472e-03\n", + "loss = 9.8264e-03\n", + "loss = 9.8062e-03\n", + "loss = 9.7865e-03\n", + "loss = 9.7675e-03\n", + "loss = 9.7490e-03\n", + "loss = 9.7310e-03\n", + "loss = 9.7136e-03\n", + "loss = 9.6966e-03\n", + "loss = 9.6802e-03\n", + "loss = 9.6642e-03\n", + "loss = 9.6486e-03\n", + "loss = 9.6335e-03\n", + "loss = 9.6188e-03\n", + "loss = 9.6045e-03\n", + "loss = 9.5906e-03\n", + "loss = 9.5771e-03\n", + "loss = 9.5640e-03\n", + "loss = 9.5512e-03\n", + "loss = 9.5388e-03\n", + "loss = 9.5268e-03\n", + "loss = 9.5150e-03\n", + "loss = 9.5037e-03\n", + "loss = 9.4926e-03\n", + "loss = 9.4818e-03\n", + "loss = 9.4713e-03\n", + "loss = 9.4611e-03\n", + "loss = 9.4512e-03\n", + "loss = 9.4416e-03\n", + "loss = 9.4322e-03\n", + "loss = 9.4231e-03\n", + "loss = 9.4142e-03\n", + "loss = 
9.4055e-03\n", + "loss = 9.3971e-03\n", + "loss = 9.3889e-03\n", + "loss = 9.3809e-03\n", + "loss = 9.3731e-03\n", + "loss = 9.3655e-03\n", + "loss = 9.3580e-03\n", + "loss = 9.3508e-03\n", + "loss = 9.3438e-03\n", + "loss = 9.3369e-03\n", + "loss = 9.3302e-03\n", + "loss = 9.3236e-03\n", + "loss = 9.3173e-03\n", + "loss = 9.3110e-03\n", + "loss = 9.3049e-03\n", + "loss = 9.2990e-03\n", + "loss = 9.2931e-03\n", + "loss = 9.2874e-03\n", + "loss = 9.2819e-03\n", + "loss = 9.2764e-03\n", + "loss = 9.2711e-03\n", + "loss = 9.2659e-03\n", + "loss = 9.2607e-03\n", + "loss = 9.2557e-03\n", + "loss = 9.2508e-03\n", + "loss = 9.2460e-03\n", + "loss = 9.2413e-03\n", + "loss = 9.2366e-03\n", + "loss = 9.2321e-03\n", + "loss = 9.2276e-03\n", + "loss = 9.2233e-03\n", + "loss = 9.2189e-03\n", + "loss = 9.2147e-03\n", + "loss = 9.2105e-03\n", + "loss = 9.2065e-03\n", + "loss = 9.2024e-03\n", + "loss = 9.1985e-03\n", + "loss = 9.1946e-03\n", + "loss = 9.1907e-03\n", + "loss = 9.1869e-03\n", + "loss = 9.1832e-03\n", + "loss = 9.1795e-03\n", + "loss = 9.1758e-03\n", + "loss = 9.1723e-03\n", + "loss = 9.1687e-03\n", + "loss = 9.1652e-03\n", + "loss = 9.1617e-03\n", + "loss = 9.1583e-03\n", + "loss = 9.1549e-03\n", + "loss = 9.1516e-03\n", + "loss = 9.1483e-03\n", + "loss = 9.1450e-03\n", + "loss = 9.1418e-03\n", + "loss = 9.1385e-03\n", + "loss = 9.1354e-03\n", + "loss = 9.1322e-03\n", + "loss = 9.1291e-03\n", + "loss = 9.1260e-03\n", + "loss = 9.1229e-03\n", + "loss = 9.1198e-03\n", + "loss = 9.1168e-03\n", + "loss = 9.1138e-03\n", + "loss = 9.1108e-03\n", + "loss = 9.1078e-03\n", + "loss = 9.1049e-03\n", + "loss = 9.1019e-03\n", + "loss = 9.0990e-03\n", + "loss = 9.0961e-03\n", + "loss = 9.0932e-03\n", + "loss = 9.0903e-03\n", + "loss = 9.0875e-03\n", + "loss = 9.0846e-03\n", + "loss = 9.0818e-03\n", + "loss = 9.0790e-03\n", + "loss = 9.0762e-03\n", + "loss = 9.0734e-03\n", + "loss = 9.0706e-03\n", + "loss = 9.0678e-03\n", + "loss = 9.0650e-03\n", + "loss = 9.0623e-03\n", + "loss = 9.0595e-03\n", + "loss = 9.0568e-03\n", + "loss = 9.0540e-03\n", + "loss = 9.0513e-03\n", + "loss = 9.0485e-03\n", + "loss = 9.0458e-03\n", + "loss = 9.0431e-03\n", + "loss = 9.0404e-03\n", + "loss = 9.0377e-03\n", + "loss = 9.0350e-03\n", + "loss = 9.0323e-03\n", + "loss = 9.0296e-03\n", + "loss = 9.0269e-03\n", + "loss = 9.0242e-03\n", + "loss = 9.0215e-03\n", + "loss = 9.0188e-03\n", + "loss = 9.0161e-03\n", + "loss = 9.0134e-03\n", + "loss = 9.0108e-03\n", + "loss = 9.0081e-03\n", + "loss = 9.0054e-03\n", + "loss = 9.0027e-03\n", + "loss = 9.0000e-03\n", + "loss = 8.9974e-03\n", + "loss = 8.9947e-03\n", + "loss = 8.9920e-03\n", + "loss = 8.9893e-03\n", + "loss = 8.9867e-03\n", + "loss = 8.9840e-03\n", + "loss = 8.9813e-03\n", + "loss = 8.9786e-03\n", + "loss = 8.9760e-03\n", + "loss = 8.9733e-03\n", + "loss = 8.9706e-03\n", + "loss = 8.9679e-03\n", + "loss = 8.9652e-03\n", + "loss = 8.9625e-03\n", + "loss = 8.9599e-03\n", + "loss = 8.9572e-03\n", + "loss = 8.9545e-03\n", + "loss = 8.9518e-03\n", + "loss = 8.9491e-03\n", + "loss = 8.9464e-03\n", + "loss = 8.9437e-03\n", + "loss = 8.9410e-03\n", + "loss = 8.9383e-03\n", + "loss = 8.9356e-03\n", + "loss = 8.9329e-03\n", + "loss = 8.9302e-03\n", + "loss = 8.9275e-03\n", + "loss = 8.9248e-03\n", + "loss = 8.9221e-03\n", + "loss = 8.9193e-03\n", + "loss = 8.9166e-03\n", + "loss = 8.9139e-03\n", + "loss = 8.9112e-03\n", + "loss = 8.9084e-03\n", + "loss = 8.9057e-03\n", + "loss = 8.9030e-03\n", + "loss = 8.9002e-03\n", + "loss = 8.8975e-03\n", + "loss = 8.8947e-03\n", + "loss = 
8.8920e-03\n", + "loss = 8.8892e-03\n", + "loss = 8.8865e-03\n", + "loss = 8.8837e-03\n", + "loss = 8.8809e-03\n", + "loss = 8.8782e-03\n", + "loss = 8.8754e-03\n", + "loss = 8.8726e-03\n", + "loss = 8.8698e-03\n", + "loss = 8.8670e-03\n", + "loss = 8.8643e-03\n", + "loss = 8.8615e-03\n", + "loss = 8.8587e-03\n", + "loss = 8.8559e-03\n", + "loss = 8.8531e-03\n", + "loss = 8.8502e-03\n", + "loss = 8.8474e-03\n", + "loss = 8.8446e-03\n", + "loss = 8.8418e-03\n", + "loss = 8.8390e-03\n", + "loss = 8.8361e-03\n", + "loss = 8.8333e-03\n", + "loss = 8.8305e-03\n", + "loss = 8.8276e-03\n", + "loss = 8.8248e-03\n", + "loss = 8.8219e-03\n", + "loss = 8.8190e-03\n", + "loss = 8.8162e-03\n", + "loss = 8.8133e-03\n", + "loss = 8.8104e-03\n", + "loss = 8.8076e-03\n", + "loss = 8.8047e-03\n", + "loss = 8.8018e-03\n", + "loss = 8.7989e-03\n", + "loss = 8.7960e-03\n", + "loss = 8.7931e-03\n", + "loss = 8.7902e-03\n", + "loss = 8.7873e-03\n", + "loss = 8.7844e-03\n", + "loss = 8.7815e-03\n", + "loss = 8.7785e-03\n", + "loss = 8.7756e-03\n", + "loss = 8.7727e-03\n", + "loss = 8.7697e-03\n", + "loss = 8.7668e-03\n", + "loss = 8.7638e-03\n", + "loss = 8.7609e-03\n", + "loss = 8.7579e-03\n", + "loss = 8.7549e-03\n", + "loss = 8.7519e-03\n", + "loss = 8.7490e-03\n", + "loss = 8.7460e-03\n", + "loss = 8.7430e-03\n", + "loss = 8.7400e-03\n", + "loss = 8.7370e-03\n", + "loss = 8.7340e-03\n", + "loss = 8.7310e-03\n", + "loss = 8.7280e-03\n", + "loss = 8.7249e-03\n", + "loss = 8.7219e-03\n", + "loss = 8.7189e-03\n", + "loss = 8.7158e-03\n", + "loss = 8.7128e-03\n", + "loss = 8.7097e-03\n", + "loss = 8.7067e-03\n", + "loss = 8.7036e-03\n", + "loss = 8.7005e-03\n", + "loss = 8.6975e-03\n", + "loss = 8.6944e-03\n", + "loss = 8.6913e-03\n", + "loss = 8.6882e-03\n", + "loss = 8.6851e-03\n", + "loss = 8.6820e-03\n", + "loss = 8.6789e-03\n", + "loss = 8.6758e-03\n", + "loss = 8.6726e-03\n", + "loss = 8.6695e-03\n", + "loss = 8.6664e-03\n", + "loss = 8.6632e-03\n", + "loss = 8.6601e-03\n", + "loss = 8.6569e-03\n", + "loss = 8.6538e-03\n", + "loss = 8.6506e-03\n", + "loss = 8.6474e-03\n", + "loss = 8.6442e-03\n", + "loss = 8.6410e-03\n", + "loss = 8.6379e-03\n", + "loss = 8.6346e-03\n", + "loss = 8.6314e-03\n", + "loss = 8.6282e-03\n", + "loss = 8.6250e-03\n", + "loss = 8.6218e-03\n", + "loss = 8.6185e-03\n", + "loss = 8.6153e-03\n", + "loss = 8.6120e-03\n", + "loss = 8.6088e-03\n", + "loss = 8.6055e-03\n", + "loss = 8.6022e-03\n", + "loss = 8.5990e-03\n", + "loss = 8.5957e-03\n", + "loss = 8.5924e-03\n", + "loss = 8.5891e-03\n", + "loss = 8.5858e-03\n", + "loss = 8.5824e-03\n", + "loss = 8.5791e-03\n", + "loss = 8.5758e-03\n", + "loss = 8.5724e-03\n", + "loss = 8.5691e-03\n", + "loss = 8.5657e-03\n", + "loss = 8.5624e-03\n", + "loss = 8.5590e-03\n", + "loss = 8.5556e-03\n", + "loss = 8.5522e-03\n", + "loss = 8.5488e-03\n", + "loss = 8.5454e-03\n", + "loss = 8.5420e-03\n", + "loss = 8.5386e-03\n", + "loss = 8.5352e-03\n", + "loss = 8.5317e-03\n", + "loss = 8.5283e-03\n", + "loss = 8.5248e-03\n", + "loss = 8.5213e-03\n", + "loss = 8.5179e-03\n", + "loss = 8.5144e-03\n", + "loss = 8.5109e-03\n", + "loss = 8.5074e-03\n", + "loss = 8.5039e-03\n", + "loss = 8.5003e-03\n", + "loss = 8.4968e-03\n", + "loss = 8.4933e-03\n", + "loss = 8.4897e-03\n", + "loss = 8.4862e-03\n", + "loss = 8.4826e-03\n", + "loss = 8.4790e-03\n", + "loss = 8.4754e-03\n", + "loss = 8.4718e-03\n", + "loss = 8.4682e-03\n", + "loss = 8.4646e-03\n", + "loss = 8.4610e-03\n", + "loss = 8.4573e-03\n", + "loss = 8.4537e-03\n", + "loss = 8.4500e-03\n", + "loss = 
8.4463e-03\n", + "loss = 8.4427e-03\n", + "loss = 8.4390e-03\n", + "loss = 8.4353e-03\n", + "loss = 8.4316e-03\n", + "loss = 8.4278e-03\n", + "loss = 8.4241e-03\n", + "loss = 8.4204e-03\n", + "loss = 8.4166e-03\n", + "loss = 8.4128e-03\n", + "loss = 8.4090e-03\n", + "loss = 8.4052e-03\n", + "loss = 8.4014e-03\n", + "loss = 8.3976e-03\n", + "loss = 8.3938e-03\n", + "loss = 8.3900e-03\n", + "loss = 8.3861e-03\n", + "loss = 8.3822e-03\n", + "loss = 8.3784e-03\n", + "loss = 8.3745e-03\n", + "loss = 8.3706e-03\n", + "loss = 8.3667e-03\n", + "loss = 8.3627e-03\n", + "loss = 8.3588e-03\n", + "loss = 8.3549e-03\n", + "loss = 8.3509e-03\n", + "loss = 8.3469e-03\n", + "loss = 8.3429e-03\n", + "loss = 8.3389e-03\n", + "loss = 8.3349e-03\n", + "loss = 8.3309e-03\n", + "loss = 8.3268e-03\n", + "loss = 8.3228e-03\n", + "loss = 8.3187e-03\n", + "loss = 8.3146e-03\n", + "loss = 8.3105e-03\n", + "loss = 8.3064e-03\n", + "loss = 8.3023e-03\n", + "loss = 8.2982e-03\n", + "loss = 8.2940e-03\n", + "loss = 8.2899e-03\n", + "loss = 8.2857e-03\n", + "loss = 8.2815e-03\n", + "loss = 8.2773e-03\n", + "loss = 8.2731e-03\n", + "loss = 8.2688e-03\n", + "loss = 8.2646e-03\n", + "loss = 8.2603e-03\n", + "loss = 8.2560e-03\n", + "loss = 8.2517e-03\n", + "loss = 8.2474e-03\n", + "loss = 8.2431e-03\n", + "loss = 8.2387e-03\n", + "loss = 8.2344e-03\n", + "loss = 8.2300e-03\n", + "loss = 8.2256e-03\n", + "loss = 8.2212e-03\n", + "loss = 8.2168e-03\n", + "loss = 8.2123e-03\n", + "loss = 8.2079e-03\n", + "loss = 8.2034e-03\n", + "loss = 8.1989e-03\n", + "loss = 8.1944e-03\n", + "loss = 8.1899e-03\n", + "loss = 8.1854e-03\n", + "loss = 8.1808e-03\n", + "loss = 8.1762e-03\n", + "loss = 8.1716e-03\n", + "loss = 8.1670e-03\n", + "loss = 8.1624e-03\n", + "loss = 8.1578e-03\n", + "loss = 8.1531e-03\n", + "loss = 8.1484e-03\n", + "loss = 8.1437e-03\n", + "loss = 8.1390e-03\n", + "loss = 8.1343e-03\n", + "loss = 8.1296e-03\n", + "loss = 8.1248e-03\n", + "loss = 8.1200e-03\n", + "loss = 8.1152e-03\n", + "loss = 8.1104e-03\n", + "loss = 8.1056e-03\n", + "loss = 8.1007e-03\n", + "loss = 8.0958e-03\n", + "loss = 8.0909e-03\n", + "loss = 8.0860e-03\n", + "loss = 8.0811e-03\n", + "loss = 8.0761e-03\n", + "loss = 8.0712e-03\n", + "loss = 8.0662e-03\n", + "loss = 8.0612e-03\n", + "loss = 8.0561e-03\n", + "loss = 8.0511e-03\n", + "loss = 8.0460e-03\n", + "loss = 8.0409e-03\n", + "loss = 8.0358e-03\n", + "loss = 8.0307e-03\n", + "loss = 8.0255e-03\n", + "loss = 8.0203e-03\n", + "loss = 8.0151e-03\n", + "loss = 8.0099e-03\n", + "loss = 8.0047e-03\n", + "loss = 7.9994e-03\n", + "loss = 7.9942e-03\n", + "loss = 7.9889e-03\n", + "loss = 7.9835e-03\n", + "loss = 7.9782e-03\n", + "loss = 7.9728e-03\n", + "loss = 7.9674e-03\n", + "loss = 7.9620e-03\n", + "loss = 7.9566e-03\n", + "loss = 7.9511e-03\n", + "loss = 7.9456e-03\n", + "loss = 7.9401e-03\n", + "loss = 7.9346e-03\n", + "loss = 7.9291e-03\n", + "loss = 7.9235e-03\n", + "loss = 7.9179e-03\n", + "loss = 7.9123e-03\n", + "loss = 7.9066e-03\n", + "loss = 7.9009e-03\n", + "loss = 7.8952e-03\n", + "loss = 7.8895e-03\n", + "loss = 7.8838e-03\n", + "loss = 7.8780e-03\n", + "loss = 7.8722e-03\n", + "loss = 7.8664e-03\n", + "loss = 7.8605e-03\n", + "loss = 7.8547e-03\n", + "loss = 7.8488e-03\n", + "loss = 7.8428e-03\n", + "loss = 7.8369e-03\n", + "loss = 7.8309e-03\n", + "loss = 7.8249e-03\n", + "loss = 7.8189e-03\n", + "loss = 7.8128e-03\n", + "loss = 7.8067e-03\n", + "loss = 7.8006e-03\n", + "loss = 7.7945e-03\n", + "loss = 7.7883e-03\n", + "loss = 7.7821e-03\n", + "loss = 7.7759e-03\n", + "loss = 
7.7696e-03\n", + "loss = 7.7633e-03\n", + "loss = 7.7570e-03\n", + "loss = 7.7507e-03\n", + "loss = 7.7443e-03\n", + "loss = 7.7379e-03\n", + "loss = 7.7314e-03\n", + "loss = 7.7250e-03\n", + "loss = 7.7185e-03\n", + "loss = 7.7119e-03\n", + "loss = 7.7054e-03\n", + "loss = 7.6988e-03\n", + "loss = 7.6921e-03\n", + "loss = 7.6855e-03\n", + "loss = 7.6788e-03\n", + "loss = 7.6721e-03\n", + "loss = 7.6653e-03\n", + "loss = 7.6585e-03\n", + "loss = 7.6517e-03\n", + "loss = 7.6448e-03\n", + "loss = 7.6379e-03\n", + "loss = 7.6310e-03\n", + "loss = 7.6240e-03\n", + "loss = 7.6170e-03\n", + "loss = 7.6100e-03\n", + "loss = 7.6029e-03\n", + "loss = 7.5958e-03\n", + "loss = 7.5886e-03\n", + "loss = 7.5815e-03\n", + "loss = 7.5742e-03\n", + "loss = 7.5670e-03\n", + "loss = 7.5597e-03\n", + "loss = 7.5523e-03\n", + "loss = 7.5449e-03\n", + "loss = 7.5375e-03\n", + "loss = 7.5301e-03\n", + "loss = 7.5226e-03\n", + "loss = 7.5150e-03\n", + "loss = 7.5074e-03\n", + "loss = 7.4998e-03\n", + "loss = 7.4922e-03\n", + "loss = 7.4845e-03\n", + "loss = 7.4767e-03\n", + "loss = 7.4689e-03\n", + "loss = 7.4611e-03\n", + "loss = 7.4532e-03\n", + "loss = 7.4453e-03\n", + "loss = 7.4373e-03\n", + "loss = 7.4293e-03\n", + "loss = 7.4212e-03\n", + "loss = 7.4131e-03\n", + "loss = 7.4050e-03\n", + "loss = 7.3968e-03\n", + "loss = 7.3886e-03\n", + "loss = 7.3803e-03\n", + "loss = 7.3720e-03\n", + "loss = 7.3636e-03\n", + "loss = 7.3552e-03\n", + "loss = 7.3467e-03\n", + "loss = 7.3382e-03\n", + "loss = 7.3296e-03\n", + "loss = 7.3210e-03\n", + "loss = 7.3123e-03\n", + "loss = 7.3036e-03\n", + "loss = 7.2949e-03\n", + "loss = 7.2861e-03\n", + "loss = 7.2772e-03\n", + "loss = 7.2683e-03\n", + "loss = 7.2594e-03\n", + "loss = 7.2504e-03\n", + "loss = 7.2414e-03\n", + "loss = 7.2323e-03\n", + "loss = 7.2231e-03\n", + "loss = 7.2139e-03\n", + "loss = 7.2047e-03\n", + "loss = 7.1954e-03\n", + "loss = 7.1861e-03\n", + "loss = 7.1767e-03\n", + "loss = 7.1673e-03\n", + "loss = 7.1578e-03\n", + "loss = 7.1483e-03\n", + "loss = 7.1387e-03\n", + "loss = 7.1291e-03\n", + "loss = 7.1195e-03\n", + "loss = 7.1098e-03\n", + "loss = 7.1000e-03\n", + "loss = 7.0902e-03\n", + "loss = 7.0804e-03\n", + "loss = 7.0705e-03\n", + "loss = 7.0606e-03\n", + "loss = 7.0506e-03\n", + "loss = 7.0406e-03\n", + "loss = 7.0305e-03\n", + "loss = 7.0204e-03\n", + "loss = 7.0103e-03\n", + "loss = 7.0001e-03\n", + "loss = 6.9898e-03\n", + "loss = 6.9795e-03\n", + "loss = 6.9692e-03\n", + "loss = 6.9588e-03\n", + "loss = 6.9484e-03\n", + "loss = 6.9380e-03\n", + "loss = 6.9275e-03\n", + "loss = 6.9169e-03\n", + "loss = 6.9063e-03\n", + "loss = 6.8957e-03\n", + "loss = 6.8850e-03\n", + "loss = 6.8743e-03\n", + "loss = 6.8636e-03\n", + "loss = 6.8528e-03\n", + "loss = 6.8419e-03\n", + "loss = 6.8310e-03\n", + "loss = 6.8201e-03\n", + "loss = 6.8091e-03\n", + "loss = 6.7981e-03\n", + "loss = 6.7870e-03\n", + "loss = 6.7759e-03\n", + "loss = 6.7648e-03\n", + "loss = 6.7536e-03\n", + "loss = 6.7424e-03\n", + "loss = 6.7311e-03\n", + "loss = 6.7198e-03\n", + "loss = 6.7084e-03\n", + "loss = 6.6970e-03\n", + "loss = 6.6856e-03\n", + "loss = 6.6741e-03\n", + "loss = 6.6625e-03\n", + "loss = 6.6509e-03\n", + "loss = 6.6393e-03\n", + "loss = 6.6276e-03\n", + "loss = 6.6159e-03\n", + "loss = 6.6042e-03\n", + "loss = 6.5924e-03\n", + "loss = 6.5805e-03\n", + "loss = 6.5686e-03\n", + "loss = 6.5567e-03\n", + "loss = 6.5447e-03\n", + "loss = 6.5327e-03\n", + "loss = 6.5206e-03\n", + "loss = 6.5085e-03\n", + "loss = 6.4963e-03\n", + "loss = 6.4841e-03\n", + "loss = 
6.4718e-03\n", + "loss = 6.4595e-03\n", + "loss = 6.4472e-03\n", + "loss = 6.4348e-03\n", + "loss = 6.4223e-03\n", + "loss = 6.4098e-03\n", + "loss = 6.3973e-03\n", + "loss = 6.3847e-03\n", + "loss = 6.3721e-03\n", + "loss = 6.3594e-03\n", + "loss = 6.3467e-03\n", + "loss = 6.3340e-03\n", + "loss = 6.3212e-03\n", + "loss = 6.3083e-03\n", + "loss = 6.2954e-03\n", + "loss = 6.2825e-03\n", + "loss = 6.2695e-03\n", + "loss = 6.2565e-03\n", + "loss = 6.2434e-03\n", + "loss = 6.2303e-03\n", + "loss = 6.2171e-03\n", + "loss = 6.2039e-03\n", + "loss = 6.1906e-03\n", + "loss = 6.1773e-03\n", + "loss = 6.1640e-03\n", + "loss = 6.1506e-03\n", + "loss = 6.1372e-03\n", + "loss = 6.1237e-03\n", + "loss = 6.1102e-03\n", + "loss = 6.0966e-03\n", + "loss = 6.0830e-03\n", + "loss = 6.0694e-03\n", + "loss = 6.0557e-03\n", + "loss = 6.0420e-03\n", + "loss = 6.0282e-03\n", + "loss = 6.0144e-03\n", + "loss = 6.0005e-03\n", + "loss = 5.9866e-03\n", + "loss = 5.9727e-03\n", + "loss = 5.9587e-03\n", + "loss = 5.9447e-03\n", + "loss = 5.9306e-03\n", + "loss = 5.9165e-03\n", + "loss = 5.9024e-03\n", + "loss = 5.8882e-03\n", + "loss = 5.8740e-03\n", + "loss = 5.8598e-03\n", + "loss = 5.8455e-03\n", + "loss = 5.8311e-03\n", + "loss = 5.8168e-03\n", + "loss = 5.8024e-03\n", + "loss = 5.7879e-03\n", + "loss = 5.7734e-03\n", + "loss = 5.7589e-03\n", + "loss = 5.7444e-03\n", + "loss = 5.7298e-03\n", + "loss = 5.7152e-03\n", + "loss = 5.7005e-03\n", + "loss = 5.6859e-03\n", + "loss = 5.6711e-03\n", + "loss = 5.6564e-03\n", + "loss = 5.6416e-03\n", + "loss = 5.6268e-03\n", + "loss = 5.6120e-03\n", + "loss = 5.5971e-03\n", + "loss = 5.5822e-03\n", + "loss = 5.5672e-03\n", + "loss = 5.5523e-03\n", + "loss = 5.5373e-03\n", + "loss = 5.5223e-03\n", + "loss = 5.5072e-03\n", + "loss = 5.4921e-03\n", + "loss = 5.4770e-03\n", + "loss = 5.4619e-03\n", + "loss = 5.4468e-03\n", + "loss = 5.4316e-03\n", + "loss = 5.4164e-03\n", + "loss = 5.4011e-03\n", + "loss = 5.3859e-03\n", + "loss = 5.3706e-03\n", + "loss = 5.3553e-03\n", + "loss = 5.3400e-03\n", + "loss = 5.3247e-03\n", + "loss = 5.3093e-03\n", + "loss = 5.2939e-03\n", + "loss = 5.2785e-03\n", + "loss = 5.2631e-03\n", + "loss = 5.2477e-03\n", + "loss = 5.2322e-03\n", + "loss = 5.2168e-03\n", + "loss = 5.2013e-03\n", + "loss = 5.1858e-03\n", + "loss = 5.1703e-03\n", + "loss = 5.1548e-03\n", + "loss = 5.1392e-03\n", + "loss = 5.1237e-03\n", + "loss = 5.1081e-03\n", + "loss = 5.0925e-03\n", + "loss = 5.0769e-03\n", + "loss = 5.0613e-03\n", + "loss = 5.0457e-03\n", + "loss = 5.0301e-03\n", + "loss = 5.0145e-03\n", + "loss = 4.9988e-03\n", + "loss = 4.9832e-03\n", + "loss = 4.9675e-03\n", + "loss = 4.9519e-03\n", + "loss = 4.9362e-03\n", + "loss = 4.9206e-03\n", + "loss = 4.9049e-03\n", + "loss = 4.8892e-03\n", + "loss = 4.8735e-03\n", + "loss = 4.8578e-03\n", + "loss = 4.8422e-03\n", + "loss = 4.8265e-03\n", + "loss = 4.8108e-03\n", + "loss = 4.7951e-03\n", + "loss = 4.7794e-03\n", + "loss = 4.7638e-03\n", + "loss = 4.7481e-03\n", + "loss = 4.7324e-03\n", + "loss = 4.7167e-03\n", + "loss = 4.7011e-03\n", + "loss = 4.6854e-03\n", + "loss = 4.6697e-03\n", + "loss = 4.6541e-03\n", + "loss = 4.6384e-03\n", + "loss = 4.6228e-03\n", + "loss = 4.6072e-03\n", + "loss = 4.5915e-03\n", + "loss = 4.5759e-03\n", + "loss = 4.5603e-03\n", + "loss = 4.5447e-03\n", + "loss = 4.5291e-03\n", + "loss = 4.5135e-03\n", + "loss = 4.4980e-03\n", + "loss = 4.4824e-03\n", + "loss = 4.4669e-03\n", + "loss = 4.4514e-03\n", + "loss = 4.4358e-03\n", + "loss = 4.4203e-03\n", + "loss = 4.4049e-03\n", + "loss = 
4.3894e-03\n", + "loss = 4.3739e-03\n", + "loss = 4.3585e-03\n", + "loss = 4.3431e-03\n", + "loss = 4.3277e-03\n", + "loss = 4.3123e-03\n", + "loss = 4.2969e-03\n", + "loss = 4.2816e-03\n", + "loss = 4.2662e-03\n", + "loss = 4.2509e-03\n", + "loss = 4.2356e-03\n", + "loss = 4.2204e-03\n", + "loss = 4.2051e-03\n", + "loss = 4.1899e-03\n", + "loss = 4.1747e-03\n", + "loss = 4.1595e-03\n", + "loss = 4.1444e-03\n", + "loss = 4.1293e-03\n", + "loss = 4.1142e-03\n", + "loss = 4.0991e-03\n", + "loss = 4.0840e-03\n", + "loss = 4.0690e-03\n", + "loss = 4.0540e-03\n", + "loss = 4.0390e-03\n", + "loss = 4.0241e-03\n", + "loss = 4.0091e-03\n", + "loss = 3.9942e-03\n", + "loss = 3.9794e-03\n", + "loss = 3.9645e-03\n", + "loss = 3.9497e-03\n", + "loss = 3.9350e-03\n", + "loss = 3.9202e-03\n", + "loss = 3.9055e-03\n", + "loss = 3.8908e-03\n", + "loss = 3.8762e-03\n", + "loss = 3.8615e-03\n", + "loss = 3.8469e-03\n", + "loss = 3.8324e-03\n", + "loss = 3.8179e-03\n", + "loss = 3.8034e-03\n", + "loss = 3.7889e-03\n", + "loss = 3.7745e-03\n", + "loss = 3.7601e-03\n", + "loss = 3.7457e-03\n", + "loss = 3.7314e-03\n", + "loss = 3.7171e-03\n", + "loss = 3.7028e-03\n", + "loss = 3.6886e-03\n", + "loss = 3.6744e-03\n", + "loss = 3.6603e-03\n", + "loss = 3.6462e-03\n", + "loss = 3.6321e-03\n", + "loss = 3.6181e-03\n", + "loss = 3.6041e-03\n", + "loss = 3.5901e-03\n", + "loss = 3.5762e-03\n", + "loss = 3.5623e-03\n", + "loss = 3.5484e-03\n", + "loss = 3.5346e-03\n", + "loss = 3.5208e-03\n", + "loss = 3.5071e-03\n", + "loss = 3.4934e-03\n", + "loss = 3.4797e-03\n", + "loss = 3.4661e-03\n", + "loss = 3.4525e-03\n", + "loss = 3.4390e-03\n", + "loss = 3.4255e-03\n", + "loss = 3.4121e-03\n", + "loss = 3.3986e-03\n", + "loss = 3.3853e-03\n", + "loss = 3.3719e-03\n", + "loss = 3.3586e-03\n", + "loss = 3.3454e-03\n", + "loss = 3.3322e-03\n", + "loss = 3.3190e-03\n", + "loss = 3.3059e-03\n", + "loss = 3.2928e-03\n", + "loss = 3.2797e-03\n", + "loss = 3.2667e-03\n", + "loss = 3.2537e-03\n", + "loss = 3.2408e-03\n", + "loss = 3.2279e-03\n", + "loss = 3.2151e-03\n", + "loss = 3.2023e-03\n", + "loss = 3.1896e-03\n", + "loss = 3.1769e-03\n", + "loss = 3.1642e-03\n", + "loss = 3.1516e-03\n", + "loss = 3.1390e-03\n", + "loss = 3.1265e-03\n", + "loss = 3.1140e-03\n", + "loss = 3.1015e-03\n", + "loss = 3.0891e-03\n", + "loss = 3.0767e-03\n", + "loss = 3.0644e-03\n", + "loss = 3.0521e-03\n", + "loss = 3.0399e-03\n", + "loss = 3.0277e-03\n", + "loss = 3.0156e-03\n", + "loss = 3.0035e-03\n", + "loss = 2.9914e-03\n", + "loss = 2.9794e-03\n", + "loss = 2.9674e-03\n", + "loss = 2.9555e-03\n", + "loss = 2.9436e-03\n", + "loss = 2.9318e-03\n", + "loss = 2.9200e-03\n", + "loss = 2.9082e-03\n", + "loss = 2.8965e-03\n", + "loss = 2.8849e-03\n", + "loss = 2.8733e-03\n", + "loss = 2.8617e-03\n", + "loss = 2.8502e-03\n", + "loss = 2.8387e-03\n", + "loss = 2.8272e-03\n", + "loss = 2.8158e-03\n", + "loss = 2.8045e-03\n", + "loss = 2.7932e-03\n", + "loss = 2.7819e-03\n", + "loss = 2.7707e-03\n" + ] + } + ], + "source": [ + "graphdef, params = nnx.split(vmap_mod)\n", + "loss_fn(graphdef, params, X, y)\n", + "optimizer = optax.adam(1e-4)\n", + "opt_state = optimizer.init(params)\n", + "optimizer_update = jax.jit(optimizer.update)\n", + "optax_apply_updates = jax.jit(optax.apply_updates)\n", + "value_and_grad = jax.jit(jax.value_and_grad(loss_fn, argnums=1))\n", + "for _ in range(1000):\n", + " l, gs = value_and_grad(graphdef, params, X, y)\n", + " updates, opt_state = optimizer.update(gs, opt_state)\n", + " params = optax_apply_updates(params, 
updates)\n", + " print(f\"loss = {l:.4e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(100,)\n", + "(50, 100)\n" + ] + } + ], + "source": [ + "print(nnx.split(vmap_mod)[1][\"state\"][\"params\"][\"bias\"].value.shape)\n", + "print(nnx.split(vmap_mod)[1][\"state\"][\"params\"][\"kernel\"].value.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[ 0.13516979, -0.05105809, 0.07632802, ..., 0.0749987 ,\n", + " -0.05567247, -0.05779112],\n", + " [-0.00094059, 0.04644845, -0.21303476, ..., 0.10651965,\n", + " -0.08589667, -0.04726676],\n", + " [-0.12583305, 0.17015733, 0.04022592, ..., -0.14773165,\n", + " -0.0417266 , -0.00611958],\n", + " ...,\n", + " [ 0.08767611, 0.11031242, -0.05278822, ..., -0.02778996,\n", + " -0.13942076, 0.12322947],\n", + " [-0.00845618, 0.15537211, 0.13054934, ..., 0.17835702,\n", + " 0.10111373, 0.16907702],\n", + " [ 0.05105361, -0.06231965, 0.04937202, ..., 0.07263377,\n", + " -0.12006826, -0.04576119]], dtype=float32)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@nnx.jit\n", + "def gen_layer(r_in):\n", + " return nnx.Linear(100, 100, rngs=r_in)\n", + " \n", + "rng = rng\n", + "gen_layer().kernel.value" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "rngs = nnx.Rngs(0, bias=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "stream = rngs.get(\"default\")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "rkey = stream.key" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def get_randn(rngs):\n", + " return nn.initializers.normal(1)(rngs, 100)" + ] + }, + { + "cell_type": "code", + "execution_count": 349, + "metadata": {}, + "outputs": [], + "source": [ + "@functools.partial(jax.tree_util.register_dataclass, data_fields=[\"r\"], meta_fields=[])\n", + "@dataclasses.dataclass\n", + "class RngWrapper:\n", + " r: nnx.Rngs\n", + " def __call__(self, key : str=\"default\"):\n", + " return self.r.get(key)() if isinstance(self.r, nnx.Rngs) else self.r" + ] + }, + { + "cell_type": "code", + "execution_count": 239, + "metadata": {}, + "outputs": [], + "source": [ + "r = RngWrapper(rngs)" + ] + }, + { + "cell_type": "code", + "execution_count": 346, + "metadata": {}, + "outputs": [], + "source": [ + "@dataclasses.dataclass\n", + "class MyModel(nnx.Module):\n", + " in_features: int\n", + " out_features: int\n", + " rngs: RngWrapper = None\n", + " \n", + " def __post_init__(self):\n", + " shape = (self.in_features, self.out_features)\n", + " #if isinstance(self.rngs, nnx.Rngs):\n", + " # self.kernel = nnx.initializers.lecun_normal()(self.rngs.get(\"default\")(), shape)\n", + " # nnx.Linear\n", + " #else:\n", + " # self.kernel = nnx.initializers.lecun_normal()(self.rngs, shape)\n", + " self.kernel = nnx.initializers.lecun_normal()(self.rngs(), shape)\n", + " #self.kernel = nnx.initializers.lecun_normal()(self.rngs(\"default\"), shape)\n", + " #self.kernel = nnx.initializers.lecun_normal()(self.rngs(\"default\"), shape)\n", + " #self.kernel = nnx.initializers.lecun_normal()(self.rngs, shape)\n", + " \n", + " def __call__(self):\n", + 
" return self.kernel\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 348, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "expected 0 arguments, got 1", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[348], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mMyModel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mRngWrapper\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnnx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRngs\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:79\u001b[0m, in \u001b[0;36mObjectMeta.__call__\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m---> 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_graph_node_meta_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:88\u001b[0m, in \u001b[0;36m_graph_node_meta_call\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 86\u001b[0m node \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mvars\u001b[39m(node)[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_object__state\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m ObjectState()\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_object_meta_construct\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m node\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:82\u001b[0m, in \u001b[0;36mObjectMeta._object_meta_construct\u001b[0;34m(cls, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_object_meta_construct\u001b[39m(\u001b[38;5;28mcls\u001b[39m, \u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 82\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m:6\u001b[0m, in \u001b[0;36m__init__\u001b[0;34m(self, in_features, out_features, rngs)\u001b[0m\n", + "Cell \u001b[0;32mIn[346], line 14\u001b[0m, in \u001b[0;36mMyModel.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 8\u001b[0m shape \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_features, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mout_features)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m#if isinstance(self.rngs, nnx.Rngs):\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# self.kernel = nnx.initializers.lecun_normal()(self.rngs.get(\"default\")(), shape)\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# nnx.Linear\u001b[39;00m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m#else:\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;66;03m# self.kernel = nnx.initializers.lecun_normal()(self.rngs, shape)\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkernel \u001b[38;5;241m=\u001b[39m \u001b[43mnnx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minitializers\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlecun_normal\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrngs\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/nn/initializers.py:335\u001b[0m, in \u001b[0;36mvariance_scaling..init\u001b[0;34m(key, shape, dtype)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m jnp\u001b[38;5;241m.\u001b[39missubdtype(dtype, jnp\u001b[38;5;241m.\u001b[39mfloating):\n\u001b[1;32m 333\u001b[0m \u001b[38;5;66;03m# constant is stddev of standard normal truncated to (-2, 2)\u001b[39;00m\n\u001b[1;32m 334\u001b[0m stddev \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39msqrt(variance) \u001b[38;5;241m/\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray(\u001b[38;5;241m.87962566103423978\u001b[39m, dtype)\n\u001b[0;32m--> 335\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtruncated_normal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m stddev\n\u001b[1;32m 336\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 337\u001b[0m \u001b[38;5;66;03m# constant is stddev of complex standard normal truncated to 2\u001b[39;00m\n\u001b[1;32m 338\u001b[0m stddev \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39msqrt(variance) \u001b[38;5;241m/\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray(\u001b[38;5;241m.95311164380491208\u001b[39m, dtype)\n", + "File 
\u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/random.py:831\u001b[0m, in \u001b[0;36mtruncated_normal\u001b[0;34m(key, lower, upper, shape, dtype)\u001b[0m\n\u001b[1;32m 828\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdtype argument to `truncated_normal` must be a float \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 829\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdtype, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 830\u001b[0m dtype \u001b[38;5;241m=\u001b[39m dtypes\u001b[38;5;241m.\u001b[39mcanonicalize_dtype(dtype)\n\u001b[0;32m--> 831\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_truncated_normal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlower\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mupper\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[0;31m[... skipping hidden 20 frame]\u001b[0m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/jax/_src/core.py:3244\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 3242\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m custom_str_eqn_compact_rules[primitive](primitive, params)\n\u001b[1;32m 3243\u001b[0m primitive_name \u001b[38;5;241m=\u001b[39m primitive\u001b[38;5;241m.\u001b[39mname\n\u001b[0;32m-> 3244\u001b[0m kvs \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mv\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m params\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 3245\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _compact_eqn_should_include(k, v))\n\u001b[1;32m 3246\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mprimitive_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m[\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkvs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m]\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(kvs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m primitive_name\n", + "\u001b[0;31mTypeError\u001b[0m: expected 0 arguments, got 1" + ] + } + ], + "source": [ + "MyModel(100, 100, RngWrapper(nnx.Rngs(0)))" + ] + }, + { + "cell_type": "code", + "execution_count": 345, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def create_linear(rngs):\n", + " return nnx.Linear(100, 100, rngs=rngs)" + ] + }, + { + "cell_type": "code", + "execution_count": 316, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "t = 2.8158e+00\n" + ] + } + ], + "source": [ + "t = time.time()\n", + "mod = nnx.vmap(lambda i, o, k: nnx.Linear(i, o, rngs=k), axis_size=50000, in_axes=(None, None, None))(100, 
+ {
+ "cell_type": "code",
+ "execution_count": 316,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "t = 2.8158e+00\n"
+ ]
+ }
+ ],
+ "source": [
+ "t = time.time()\n",
+ "mod = nnx.vmap(lambda i, o, k: nnx.Linear(i, o, rngs=k), axis_size=50000, in_axes=(None, None, None))(100, 100, rngs)\n",
+ "t = time.time() - t\n",
+ "print(f\"{t = :.4e}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 337,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graphdef, params = nnx.split(mod)\n",
+ "fn = jax.jit(nnx.scan(lambda x, p: (nnx.merge(graphdef, p)(x), None)))\n",
+ "#(jnp.zeros(100), params)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 343,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "AttributeError",
+ "evalue": "module 'flax.nnx' has no attribute 'bridge'",
+ "output_type": "error",
+ "traceback": [
+ "Cell In[343], line 1: nnx.bridge",
+ "AttributeError: module 'flax.nnx' has no attribute 'bridge'"
+ ]
+ }
+ ],
+ "source": [
+ "nnx.bridge"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 341,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+ "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+ "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+ "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+ "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+ "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float64),\n",
+ " None)"
+ ]
+ },
+ "execution_count": 341,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fn(jnp.zeros(100), params)"
+ ]
+ },
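+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, the scanned forward pass can also be written with plain `jax.lax.scan`, since the state returned by `nnx.split` behaves as a JAX pytree whose leading axis is the `nnx.vmap` stacking axis. A minimal sketch under that assumption, reusing `graphdef` and `params` from above (`apply_stack` is an illustrative name):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jax.jit\n",
+ "def apply_stack(x, params):\n",
+ "    # fold x through all 50000 stacked Linear layers, one scan step per layer\n",
+ "    def body(h, p):\n",
+ "        return nnx.merge(graphdef, p)(h), None\n",
+ "    h, _ = jax.lax.scan(body, x, params)\n",
+ "    return h\n",
+ "\n",
+ "apply_stack(jnp.zeros(100), params)"
+ ]
+ },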
+ {
+ "cell_type": "code",
+ "execution_count": 211,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Array([[ 3.55973911e-01, -2.37477006e-01, -9.74529953e-02,\n",
+ "        -3.39887647e-01,  6.99265619e-01,  4.49438212e-01,\n",
+ "        -1.17788739e-01,  4.91785924e-01,  1.43185656e-01,\n",
+ "        -1.35137965e-01],\n",
+ "       [ 2.11869889e-01, -4.85307116e-01, -1.85187564e-02,\n",
+ "         1.35506714e-01,  3.02765362e-02,  9.85176910e-02,\n",
+ "        -8.09381560e-02,  2.50044546e-01,  9.01392114e-03,\n",
+ "        -2.77441520e-02],\n",
+ "       [ 6.12978399e-01, -5.05180316e-01, -1.20650054e-01,\n",
+ "         3.75843736e-01,  2.17106242e-01,  1.02653012e-02,\n",
+ "         5.78397549e-01, -4.85449128e-01,  2.70678648e-01,\n",
+ "         2.44499461e-01],\n",
+ "       [-4.31290899e-01,  8.96568564e-02, -5.51868309e-01,\n",
+ "        -2.96766079e-01, -7.05394023e-01,  2.57234510e-01,\n",
+ "        -2.29387100e-01, -7.16125931e-01, -6.00397803e-01,\n",
+ "        -5.67827435e-01],\n",
+ "       [ 2.23155164e-01, -1.99485335e-01, -2.55706195e-01,\n",
+ "         4.28001152e-01, -3.65027100e-01, -1.38338403e-01,\n",
+ "        -5.10901823e-01,  4.56128621e-01,  5.56484914e-01,\n",
+ "        -5.56046998e-01],\n",
+ "       [ 1.18075951e-01, -1.69895637e-01, -4.09282855e-01,\n",
+ "        -1.93506271e-01,  4.71324021e-04,  4.26806842e-01,\n",
+ "         5.72788747e-03,  2.13779474e-01,  2.99276780e-01,\n",
+ "         2.82478080e-01],\n",
+ "       [ 1.84209731e-01, -6.11981531e-02, -3.40476559e-01,\n",
+ "         4.29254537e-01,  1.21000041e-01,  1.66407238e-01,\n",
+ "         5.64191031e-01, -6.41666040e-02,  5.26531006e-01,\n",
+ "         6.42945781e-01],\n",
+ "       [ 1.18356967e-01,  6.88516232e-01,  2.42346186e-01,\n",
+ "         1.13217974e-03, -7.06617960e-02,  1.74556856e-01,\n",
+ "        -7.46310141e-03,  6.56408665e-01, -4.58758064e-01,\n",
+ "        -6.02616799e-01],\n",
+ "       [ 1.71379340e-01,  3.02302362e-02, -1.63445600e-01,\n",
+ "         8.96689669e-02, -2.52574914e-01, -4.21846373e-01,\n",
+ "        -4.77525850e-01,  4.51567715e-01, -1.79644584e-01,\n",
+ "        -4.44087654e-01],\n",
+ "       [ 2.92692156e-02,  2.98974156e-01,  5.55897741e-01,\n",
+ "        -5.18448380e-01,  3.86063011e-01, -4.65159316e-01,\n",
+ "        -1.41970684e-02,  2.96357553e-01,  1.44150289e-01,\n",
+ "         1.46604166e-01]], dtype=float64)"
+ ]
+ },
+ "execution_count": 211,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "under_jit(r)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 129,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Array([-1.39944734e+00,  1.42729388e-01, -2.04663887e+00,  6.84886398e-01,\n",
+ "        7.53044294e-01,  1.75674106e+00, -2.62389780e-01, -4.68049738e-01,\n",
+ "       -4.79173378e-01,  1.64071374e-03,  1.82383563e+00, -2.44221580e+00,\n",
+ "       -1.25232444e+00,  3.04795824e-01,  1.15632370e+00,  7.60857420e-01,\n",
+ "       -1.33744763e+00, -1.77406268e+00,  5.06903744e-01,  1.72165095e+00,\n",
+ "        1.41896385e+00,  9.19465409e-01, -5.67572806e-01,  1.10987088e+00,\n",
+ "       -4.98479686e-01,  1.89477152e+00, -1.52960825e-01, -1.53472124e-01,\n",
+ "        2.82643017e+00, -2.21983111e+00,  2.00269146e-01,  1.09333757e+00,\n",
+ "       -1.66862222e+00, -8.56525357e-01, -1.02235849e+00, -7.24134529e-01,\n",
+ "        4.48094968e-01, -7.93246631e-01,  1.40813146e-01, -2.54582938e+00,\n",
+ "       -1.13855538e+00, -4.76501459e-01, -3.93253085e-01, -1.18342087e+00,\n",
+ "        2.13279447e+00, -1.40168059e+00,  6.49652544e-03, -4.12427051e-01,\n",
+ "       -6.57155787e-01, -1.19050183e+00,  6.96635743e-01,  2.27212476e-01,\n",
+ "       -3.50703119e-01,  5.13223595e-01, -1.44912848e+00,  3.24963613e+00,\n",
+ "        5.63140810e-01,  6.77694287e-02, -6.63633567e-01, -1.18094659e+00,\n",
+ "       -5.97535397e-01, -6.27794198e-01, -6.96481752e-01,  2.12944636e-01,\n",
+ "       -6.66431936e-01,  2.09924889e+00, -2.12541596e-01, -1.03803163e+00,\n",
+ "        1.49710488e+00,  1.32619456e+00, -1.10585359e+00,  4.20273501e-01,\n",
+ "       -4.56547942e-02, -5.87798800e-01,  9.63563868e-01, -8.74024262e-01,\n",
+ "        3.05793155e+00,  2.07996328e+00,  2.50519016e-01,  8.94426973e-01,\n",
+ "       -1.42202002e+00, -7.51182080e-01,  3.96023350e-01, -3.86676841e-01,\n",
+ "        1.21138883e+00, -1.04336148e+00,  1.95922769e-01,  1.28663718e+00,\n",
+ "        1.16980698e-01, -3.38606962e-01, -2.76363003e+00, -1.49801837e+00,\n",
+ "        1.29559500e+00, -5.41882792e-01,  3.39621748e-01, -2.20668680e-01,\n",
+ "        3.88775811e+00,  1.24475108e+00, -4.00992678e-01,  2.03779212e+00], dtype=float64)"
+ ]
+ },
+ "execution_count": 129,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_randn(rngs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "err = 0.0000e+00\n"
+ ]
+ }
+ ],
+ "source": [
+ "y1 = call_linen_model(layer, x)\n",
+ "y2 = nnx_layer(x)\n",
+ "err = jnp.linalg.norm(y1 - y2) / (jnp.linalg.norm(y1) + 1e-7)\n",
+ "print(f\"{err = :.4e}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params = layer.init(jrandom.key(time.time_ns() % 2 ** 31), x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nnx_layer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jax.jit\n",
+ "def grad(graphdef, state, *args):\n",
+ "    gs = jax.grad(lambda state: jnp.sum(nnx.merge(graphdef, state)(*args)))(state)\n",
+ "    return gs"
+ ]
+ },
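+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick sanity check of the jitted `grad` above; a sketch that assumes `nnx_layer` and `x` from the earlier cells are still in scope, and that the gradient pytree mirrors the parameter shapes:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graphdef, state = nnx.split(nnx_layer)\n",
+ "gs = grad(graphdef, state, x)\n",
+ "# gradients come back with the same tree structure as the parameters\n",
+ "print(jax.tree.map(jnp.shape, gs))"
+ ]
+ },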
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "optimizer = optax.adam(1e-5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "opt_state = optimizer.init(nnx.split(nnx_layer)[1])\n",
+ "optimizer_update = jax.jit(optimizer.update)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params = nnx.split(nnx_layer)[1]\n",
+ "\n",
+ "@jax.jit\n",
+ "def learning_step(graphdef, params, *args, opt_state=None):\n",
+ "    if opt_state is None:\n",
+ "        opt_state = optimizer.init(params)\n",
+ "    gs = grad(graphdef, params, *args)\n",
+ "    updates, opt_state = optimizer.update(gs, opt_state)\n",
+ "    new_params = optax.apply_updates(params, updates)\n",
+ "    return new_params, opt_state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graphdef, params = nnx.split(nnx_layer)\n",
+ "opt_state = None\n",
+ "for _ in range(10000):\n",
+ "    params, opt_state = learning_step(graphdef, params, x, opt_state=opt_state)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "devel",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}