import jax.numpy as np
import numpy as onp
a = onp.arange(12).reshape((6, 2)) # NumPy array
b = onp.arange(6).reshape((2, 3)) # NumPy array
a_ = np.asarray(a) # JAX array
b_ = np.asarray(b) # JAX array
a[b] # success
a_[b_] # success
a_[b] # success
a[b_] # error: index 3 is out of bounds for axis 1 with size 2
Generally speaking, JAX supports NumPy arrays, but NumPy does not support JAX arrays.
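If the data lives in NumPy but the indices come from JAX, one workaround is to convert the indices back to a NumPy array before indexing; a minimal sketch, reusing the arrays above:
a[onp.asarray(b_)] # success: same result as a[b]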
import numpy as onp
import torch
a = onp.random.rand(3, 4, 5)
b = onp.random.rand(4, 5, 6)
onp.dot(a, b) # success
a_ = torch.from_numpy(a)
b_ = torch.from_numpy(b)
torch.dot(a_, b_) # error: 1D tensors expected, but got 3D and 3D tensors
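For multi-dimensional arrays, np.dot contracts the last axis of the first argument with the second-to-last axis of the second, while torch.dot only accepts 1D tensors. A sketch of the equivalent contraction in PyTorch, using torch.tensordot on the tensors above:
c_ = torch.tensordot(a_, b_, dims=([2], [1])) # contract axis 2 of a_ with axis 1 of b_
print(c_.shape) # torch.Size([3, 4, 4, 6]), the same shape as onp.dot(a, b)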
import torch
x = torch.tensor([[-1., 1.]])
print(x.std(-1).numpy()) # [1.4142135]
print(x.numpy().std(-1)) # [1.]
This is because in np.std the denominator is n, while in torch.std it is n-1. See pytorch/pytorch#1854 for details.
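The two can be reconciled by specifying the correction explicitly on either side:
import torch
x = torch.tensor([[-1., 1.]])
print(x.std(-1, unbiased=False).numpy()) # [1.], divide by n, matching NumPy's default
print(x.numpy().std(-1, ddof=1)) # [1.4142135], divide by n-1, matching PyTorch's default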
By default, JAX uses bfloat16 for matrix multiplication on TPU, even if the input data type is float32.
import jax.numpy as np
print(4176 * 5996) # 25039296, so the exact value of 0.4176 * 0.5996 is 0.25039296
a = np.array(0.4176, dtype=np.float32)
b = np.array(0.5996, dtype=np.float32)
print((a * b).item()) # 0.25039297342300415
To do matrix multiplication in float32, you need to add these lines at the top of the script:
import jax
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
Other precision values can be found in jax.lax.Precision. See google/jax#9973 for details.
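The precision can also be requested per operation rather than globally, since functions such as np.matmul accept a precision argument:
import jax
import jax.numpy as np
a = np.ones((128, 128), dtype=np.float32)
b = np.ones((128, 128), dtype=np.float32)
c = np.matmul(a, b, precision=jax.lax.Precision.HIGHEST) # full precision for this call only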
The weight matrix of a linear layer is stored transposed in PyTorch relative to Flax: PyTorch uses the shape (out_features, in_features), while Flax uses (in_features, out_features). Therefore, if you want to convert model parameters between PyTorch and Flax, you need to transpose the weight matrices.
In Flax:
import flax.linen as nn
import jax.numpy as np
import jax.random as rand
linear = nn.Dense(5)
key = rand.PRNGKey(42)
params = linear.init(key, np.zeros((3,)))
print(params['params']['kernel'].shape) # (3, 5)
In PyTorch:
import torch.nn as nn
linear = nn.Linear(3, 5)
print(linear.weight.shape) # (5, 3), not (3, 5)
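Accordingly, a PyTorch weight can be converted into a Flax kernel with a single transpose; a minimal sketch (the bias vector, if present, has the same layout in both frameworks):
import torch.nn as nn
linear = nn.Linear(3, 5)
kernel = linear.weight.detach().numpy().T # (3, 5), the Flax kernel layout
bias = linear.bias.detach().numpy() # (5,), no transpose needed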