Is it possible to include instructions on how to run it on GPUs #4
Comments
The JAX code should also run on GPUs. We have tested this on a virtual machine on Google Cloud, so it should work without any special instructions. |
Thanks for the reply! But it would throw an OOM error on a single Titan X GPU. It'd be nice if there were a flag like accumulate-gradients/update-freq to be able to reproduce the results on a single GPU. (Sorry if this is a dumb question, but I'm not very familiar with TensorFlow/JAX.) |
Thanks for the feedback! @ppham27 ran this on the cloud vm, so I'm looping him in and wondering if he has any thoughts on this. |
A single Titan X doesn't have enough HBM. For our GPU setup, we had 8 V100s for a total of 128GB of HBM. For a single Titan X, I think you could max out at a batch size of 3, which is probably too small. Adding an outer loop and doing gradient accumulation is probably the right way to address this. If there's a lot of interest in being able to train on a single GPU, we can look into this. |
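To make the batch-size arithmetic behind an outer accumulation loop concrete, here is a minimal sketch. The helper name `plan_accumulation` and the example numbers (an effective batch of 32 with at most 3 examples fitting in memory) are illustrative assumptions, not values taken from the LRA configs.

```python
def plan_accumulation(target_batch_size, max_micro_batch):
    """Return (micro_batch, accum_steps) with micro_batch * accum_steps == target_batch_size.

    Picks the largest micro-batch that fits in memory and divides the target
    evenly, so every accumulation step sees the same number of examples.
    """
    if target_batch_size < 1 or max_micro_batch < 1:
        raise ValueError("batch sizes must be positive")
    # A micro-batch of 1 always divides the target, so the loop always returns.
    for micro_batch in range(min(max_micro_batch, target_batch_size), 0, -1):
        if target_batch_size % micro_batch == 0:
            return micro_batch, target_batch_size // micro_batch


# Hypothetical numbers: effective batch of 32, at most 3 examples fit at once.
micro_batch, accum_steps = plan_accumulation(32, 3)
print(micro_batch, accum_steps)  # 2 16
```

With these assumed numbers, 3 does not divide 32, so the sketch falls back to micro-batches of 2 and 16 accumulation steps per optimizer update.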
Having a way to turn on accumulate-gradients/update-freq would be amazing for reproducibility on GPUs. What is the best approach for doing this in JAX? |
@MostafaDehghani has an example for this. Do you mind sharing it? |
Hi, thanks for the question. Yes, I also think using gradient accumulation is the way to go. Here is an example of implementing it in JAX, which we used in another project, but I'm sure it's easily portable to LRA. Adding gradient accumulation to LRA is on our TODO list, but currently there are a few higher-priority fixes and feature requests that we should take care of. In the meantime, a PR that adds it to our training loops is extremely welcome :) |
Hi, Mostafa! Thank you for the quick response. I was able to adapt your code for text classification and it seems like the gradient accumulation is working fine. Since

import jax
import jax.numpy as jnp
from jax import random

# nn, train_utils, compute_metrics, FLAGS and CLASS_MAP are assumed to be
# provided by the surrounding text classification training script.


def train_step(optimizer, batch, learning_rate_fn, accum_steps, dropout_rng=None):
  train_keys = ['inputs', 'targets']
  (inputs, targets) = [batch.get(k, None) for k in train_keys]
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model, x, y):
    """Loss function used for training."""
    with nn.stochastic(dropout_rng):
      logits = model(x, train=True)
    loss, weight_sum = train_utils.compute_weighted_cross_entropy(
        logits, y, num_classes=CLASS_MAP[FLAGS.task_name], weights=None)
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = optimizer.state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  # compute gradients over accum_steps micro-batches
  _, grad = accumulate_gradient(grad_fn, optimizer.target, inputs, targets, accum_steps)
  # average gradients across devices
  grad = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), grad)
  # separate forward pass on the full batch to get logits for the metrics
  logits = optimizer.target(inputs, train=False)
  # to save memory:
  # logits = optimizer.target(inputs[0][jnp.newaxis, ...], train=False)
  # for i in range(1, inputs.shape[0]):
  #   y_hat = optimizer.target(inputs[i][jnp.newaxis, ...], train=False)
  #   logits = jnp.concatenate((logits, y_hat), axis=0)
  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
  metrics = compute_metrics(logits, targets, None)
  metrics['learning_rate'] = lr
  return new_optimizer, metrics, new_dropout_rng


def accumulate_gradient(loss_and_grad_fn, params, inputs, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory."""
  if accum_steps and accum_steps > 1:
    assert inputs.shape[0] % accum_steps == 0, (
        f'Bad accum_steps {accum_steps} for batch size {inputs.shape[0]}')
    step_size = inputs.shape[0] // accum_steps
    # The first micro-batch initializes the running loss and gradient.
    (l, _), g = loss_and_grad_fn(params, inputs[:step_size], labels[:step_size])

    def acc_grad_and_loss(i, l_and_g):
      # Assumes 2-D inputs [batch, length] and 1-D labels [batch].
      inps = jax.lax.dynamic_slice(inputs, (i * step_size, 0),
                                   (step_size,) + inputs.shape[1:])
      # The added axis has size 1, so its start index must be 0 (the original
      # start of 1 only worked because dynamic_slice clamps out-of-range starts).
      lbls = jax.lax.dynamic_slice(labels[..., jnp.newaxis], (i * step_size, 0),
                                   (step_size, 1)).squeeze(axis=-1)
      (li, _), gi = loss_and_grad_fn(params, inps, lbls)
      l, g = l_and_g
      return l + li, jax.tree_multimap(lambda x, y: x + y, g, gi)

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
    # Average the summed loss and gradients over the micro-batches.
    l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
    return l, g
  else:
    return loss_and_grad_fn(params, inputs, labels) |
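As a sanity check on the idea in the code above, the following self-contained sketch compares accumulated gradients against full-batch gradients on a toy problem. It is an illustration under stated assumptions, not the LRA implementation: the linear model, the mean-squared-error loss, and the names `loss_fn` and `accumulated_grads` are made up for this example, and it uses `jax.lax.scan` over reshaped micro-batches rather than the `fori_loop` with `dynamic_slice` shown in the comment. Dropout and the cross-device `pmean` are left out.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a toy linear model (stand-in for the real loss)."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, accum_steps):
    """Average loss and gradients over `accum_steps` equal-sized micro-batches."""
    assert x.shape[0] % accum_steps == 0, "batch must split evenly"
    # Reshape [batch, ...] -> [accum_steps, micro_batch, ...] so scan can iterate.
    xs = x.reshape((accum_steps, -1) + x.shape[1:])
    ys = y.reshape((accum_steps, -1) + y.shape[1:])
    grad_fn = jax.value_and_grad(loss_fn)

    def step(carry, micro_batch):
        loss_sum, grad_sum = carry
        loss, grads = grad_fn(params, *micro_batch)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    (loss_sum, grad_sum), _ = jax.lax.scan(step, (jnp.zeros(()), zeros), (xs, ys))
    scale = 1.0 / accum_steps
    return loss_sum * scale, jax.tree_util.tree_map(lambda g: g * scale, grad_sum)


# Arbitrary toy data: 8 examples with 4 features each.
key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (8, 4))
y = jax.random.normal(key_y, (8,))
params = {"w": jax.random.normal(key_w, (4,)), "b": jnp.zeros(())}

full_loss, full_grads = jax.value_and_grad(loss_fn)(params, x, y)
acc_loss, acc_grads = accumulated_grads(params, x, y, accum_steps=4)
```

Because the loss is a mean over examples and the micro-batches have equal size, `acc_loss` and `acc_grads` should match `full_loss` and `full_grads` up to floating-point error, while only one micro-batch of activations is live at a time.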
Hi! I got the following results on the test set by using a single GPU (24GB) and setting
@vanzytay @MostafaDehghani Any idea on why? Best, |
I'm also running into memory issues. I've given up on the vanilla Transformer (this is a benchmark for efficient Transformers, after all), but even for the Performer, I need 2× Tesla V100 (32GB each). Do you think it's possible to reproduce your results with, say, a batch size of 16 or 8 (and without changing the code)? |
In Table 2 you give some insights on the 'peak memory usage' per device with a batch size of 32. Can I expect to have a similar memory consumption on a single GPU with a batch size of 32 or 2? |
Hi, I've also met the OOM problem with a V100 32GB card and really need the gradient accumulation

def loss_fn(model, inputs, targets):
  with nn.stateful(state) as new_state:
    with nn.stochastic(dropout_rng):
      logits = model(inputs, train=True)
  ...
  return mean_loss, (new_state, logits)

and the returned new_state is used for the next train_step by the train_loop method

for step, batch in zip(range(start_step, num_train_steps), train_iter):
  batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
  optimizer, state, metrics, dropout_rngs = p_train_step(
      optimizer, state, batch, dropout_rng=dropout_rngs)

Would simply deleting this variable, as in your implementation, cause some problem in the training? |
I'd also like to second @La-SilverLand's question. Currently I'm trying to fit the Pathfinder model code onto a V100 GPU, and you have provided all the tools for that except the answer about |
Sorry for the delay in my reply to this issue. If you need a ResNet baseline that has BatchNorm, I recommend using the version with GroupNorm to avoid the complication of handling batch statistics when using gradient accumulation. |
Thanks for the response. It seems that in this case the Transformer implementations in the repo should be fine (at least most of them), since LayerNorm doesn't use batch-wise statistics. |
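The point about batch statistics can be illustrated with a small sketch. The functions `batch_norm`, `layer_norm`, and `in_micro_batches` below are simplified stand-ins written for this example (no learned scale or bias, no running averages), not the Flax layers used in the repo. Layer norm stands in for per-example normalization in general; GroupNorm likewise computes its statistics within each example.

```python
import jax.numpy as jnp


def batch_norm(x, eps=1e-5):
    """Normalize each feature using statistics taken across the batch axis."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def layer_norm(x, eps=1e-5):
    """Normalize each example using statistics taken across its own features."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def in_micro_batches(norm_fn, x, num_splits):
    """Apply norm_fn to each micro-batch separately, then stitch the results."""
    return jnp.concatenate([norm_fn(part) for part in jnp.split(x, num_splits)])


# Toy activations: 4 examples with 3 features each (values are arbitrary).
x = jnp.array([[1.0, 2.0, 4.0],
               [0.0, 5.0, 1.0],
               [3.0, 3.0, 9.0],
               [7.0, 1.0, 2.0]])

layer_same = bool(jnp.allclose(layer_norm(x), in_micro_batches(layer_norm, x, 2)))
batch_same = bool(jnp.allclose(batch_norm(x), in_micro_batches(batch_norm, x, 2)))
print(layer_same, batch_same)  # True False
```

Splitting the batch leaves the per-example normalization unchanged but alters the batch-norm output, which is why accumulating over micro-batches is not equivalent to a full-batch step for models with batch statistics.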
This code seems to be using 4x4 TPUs, but since I don't have access to TPUs, I wonder if you could release instructions on how to replicate the results on GPUs, which would make this code more accessible for people without abundant computational resources.