Is it possible to include instructions on how to run it on GPUs #4

Open
da03 opened this issue Nov 16, 2020 · 15 comments

@da03 commented Nov 16, 2020

This code seems to use 4x4 TPUs, but since I don't have access to TPUs, I wonder if you could release instructions on how to replicate the results on GPUs. That would make the code more accessible to people without abundant compute resources.

@vanzytay (Collaborator)

The JAX code should also run on GPUs. We have tested this on a virtual machine on Google Cloud, so it should work without any special instructions.

@da03 (Author) commented Dec 2, 2020

Thanks for the reply! However, it throws an OOM error on a single Titan X GPU. It would be nice to have a flag like accumulate-gradients/update-freq so the results can be reproduced on a single GPU. (Sorry if this is a dumb question; I'm not very familiar with TensorFlow/JAX.)

@vanzytay (Collaborator) commented Dec 3, 2020

Thanks for the feedback!

@ppham27 ran this on the cloud vm, so I'm looping him in and wondering if he has any thoughts on this.

@ppham27 (Collaborator) commented Dec 3, 2020

A single Titan X doesn't have enough HBM. For our GPU setup, we had 8 V100s for a total of 128GB of HBM. On a single Titan X, I think you could max out at a batch size of 3, which is probably too small. Adding an outer loop and doing gradient accumulation is probably the right way to address this. If there's a lot of interest in training on a single GPU, we can look into it.
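To make the suggestion concrete, here is a minimal gradient-accumulation sketch in plain JAX (the function and argument names are illustrative, not part of the LRA codebase): split each batch into micro-batches, sum the per-micro-batch gradients, and apply a single update with their average.

```python
# Minimal gradient-accumulation sketch in plain JAX; loss_fn, params, and the
# data arrays are placeholders, not objects from the LRA codebase.
import jax

def accumulated_grads(loss_fn, params, inputs, labels, accum_steps):
  """Average gradients over accum_steps micro-batches to cut peak memory."""
  micro = inputs.shape[0] // accum_steps  # micro-batch size
  grads = jax.grad(loss_fn)(params, inputs[:micro], labels[:micro])
  for i in range(1, accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    g_i = jax.grad(loss_fn)(params, inputs[sl], labels[sl])
    grads = jax.tree_map(lambda a, b: a + b, grads, g_i)
  # Divide so the update matches a single step on the full batch.
  return jax.tree_map(lambda g: g / accum_steps, grads)
```

Dividing by accum_steps keeps the update equivalent to one large-batch step, assuming equal-sized micro-batches and a mean-reduced loss.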

@mtreviso

Having a way to turn on accumulate-gradients/update-freq would be amazing for reproducibility on GPUs. What is the best approach for doing this in JAX?

@vanzytay (Collaborator)

@MostafaDehghani has an example for this. Do you mind sharing it?

@MostafaDehghani (Collaborator)

Hi, thanks for the question.

Yes, I also think gradient accumulation is the way to go. Here is an example of implementing it in JAX that we used in another project; I'm sure it can easily be ported to LRA:
https://github.com/google-research/vision_transformer/blob/master/vit_jax/train.py#L63

Adding gradient accumulation to LRA is on our TODO list, but there are currently a few higher-priority fixes/feature requests that we need to take care of. In the meantime, a PR that adds it to our training loops is extremely welcome :)

@mtreviso commented Dec 13, 2020

Hi, Mostafa!

Thank you for the quick response. I was able to adapt your code for text classification, and the gradient accumulation seems to be working fine. Since jax.lax.fori_loop requires the loop carry to keep the same type and shape, I couldn't stack the logits during accumulation; I've circumvented this by computing the logits afterwards. Here is the code:

def train_step(optimizer, batch, learning_rate_fn, accum_steps, dropout_rng=None):
  train_keys = ['inputs', 'targets']
  (inputs, targets) = [batch.get(k, None) for k in train_keys]
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model, x, y):
    """Loss function used for training."""
    with nn.stochastic(dropout_rng):
      logits = model(x, train=True)
    loss, weight_sum = train_utils.compute_weighted_cross_entropy(
        logits, y, num_classes=CLASS_MAP[FLAGS.task_name], weights=None)
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = optimizer.state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

  # compute gradients and get logits
  _, grad = accumulate_gradient(grad_fn, optimizer.target, inputs, targets, accum_steps)
  grad = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), grad)
  # extra full-batch forward pass in eval mode, used only to compute metrics
  logits = optimizer.target(inputs, train=False)
  # to save memory:
  # logits = optimizer.target(inputs[0][jnp.newaxis, ...], train=False)
  # for i in range(1, inputs.shape[0]):
  #   y_hat = optimizer.target(inputs[i][jnp.newaxis, ...], train=False)
  #   logits = jnp.concatenate((logits, y_hat), axis=0)

  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
  metrics = compute_metrics(logits, targets, None)
  metrics['learning_rate'] = lr

  return new_optimizer, metrics, new_dropout_rng

def accumulate_gradient(loss_and_grad_fn, params, inputs, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory."""
  if accum_steps and accum_steps > 1:
    assert inputs.shape[0] % accum_steps == 0, (
        f'Bad accum_steps {accum_steps} for batch size {inputs.shape[0]}')
    step_size = inputs.shape[0] // accum_steps
    (l, _), g = loss_and_grad_fn(params, inputs[:step_size], labels[:step_size])

    def acc_grad_and_loss(i, l_and_g):
      inps = jax.lax.dynamic_slice(inputs, (i * step_size, 0),
                                   (step_size,) + inputs.shape[1:])
      # labels are 1-D class ids; slice them along the batch axis (start 0 in the dummy dim)
      lbls = jax.lax.dynamic_slice(labels[..., jnp.newaxis], (i * step_size, 0),
                                   (step_size, 1)).squeeze(axis=-1)
      (li, _), gi = loss_and_grad_fn(params, inps, lbls)
      l, g = l_and_g
      return l + li, jax.tree_multimap(lambda x, y: x + y, g, gi)

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
    l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
    return l, g
  
  else:
    return loss_and_grad_fn(params, inputs, labels)
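For context, here is a sketch of how this train_step might be hooked into a pmapped training loop. Only train_step and accum_steps come from the snippet above; the flag name and the learning_rate_fn binding are assumptions.

```python
import functools
import jax

# Assumed wiring, mirroring the training loops shown in this thread;
# --accum_steps is a hypothetical flag, learning_rate_fn is the task's schedule.
p_train_step = jax.pmap(
    functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        accum_steps=FLAGS.accum_steps),
    axis_name='batch')  # matches jax.lax.pmean(..., axis_name='batch') above

optimizer, metrics, dropout_rngs = p_train_step(
    optimizer, batch, dropout_rng=dropout_rngs)
```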

@mtreviso

Hi! I got the following results on the test set using a single GPU (24GB) and setting accum_steps = batch_size. All hyperparameters were kept intact; the only thing that changed in the training procedure was the gradient accumulation.

| Task       | 1 GPU  | TPU (paper) | Chance (baseline) |
|------------|--------|-------------|-------------------|
| ListOps    | 0.1830 | 0.3637      | 0.10              |
| Doc Class. | 0.6323 | 0.6427      | 0.50              |
| Retrieval  | 0.4752 | 0.5746      | 0.50              |

@vanzytay @MostafaDehghani Any idea why the single-GPU results fall so far short of the paper's, especially on ListOps and Retrieval?

Best,

@cifkao (Contributor) commented Jan 26, 2021

I'm also running into memory issues. I've given up on the vanilla Transformer (this is a benchmark for efficient Transformers, after all), but even for the Performer, I need 2× Tesla V100 (32GB each).

Do you think it's possible to reproduce your results with, say, a batch size of 16 or 8 (and without changing the code)?

@GregorKobsik

In Table 2 you give some insight into the 'peak memory usage' per device with a batch size of 32.
Do you refer to an effective batch size of 32, or to a batch size of 32 per device?

Can I expect similar memory consumption on a single GPU with a batch size of 32, or of 2?

@La-SilverLand commented Aug 12, 2021

> (Quoting @mtreviso's gradient-accumulation code from the earlier comment.)

Hi, I've also hit the OOM problem with a V100 32GB card, so I really need gradient accumulation.
In your implementation, the state variable is missing. Below is the original loss_fn code:

def loss_fn(model, inputs, targets):
  with nn.stateful(state) as new_state:
    with nn.stochastic(dropout_rng):
      logits = model(inputs, train=True)
  ...
  return mean_loss, (new_state, logits)

and the returned new_state is passed to the next train_step by the train_loop method:

  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, state, metrics, dropout_rngs = p_train_step(
        optimizer, state, batch, dropout_rng=dropout_rngs)

Would simply deleting this variable, as in your implementation, cause any problems during training?

@vladyorsh

I'd like to second @La-SilverLand's question. I'm currently trying to fit the Pathfinder model code onto a V100 GPU, and you have provided all the tools for that except an answer about nn.stateful. I'm very new to JAX, so I can't tell whether removing it would cripple the training process.

@MostafaDehghani (Collaborator)

Sorry for the delay in my reply to this issue.
@EternalSorrrow, as long as you don't have anything in your model that requires keeping global statistics (like BatchNorm), you can just delete the usage of state and nn.stateful and it should be all good.

If you need a ResNet baseline (which normally has BatchNorm), I recommend using the version with GroupNorm instead, to avoid the complication of handling batch statistics when using gradient accumulation.
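As a concrete sketch of that answer, the stateful loss_fn quoted above would reduce to something like the stateless version already used in @mtreviso's snippet (names reused from the earlier snippets; not a tested patch):

```python
def loss_fn(model, inputs, targets):
  """As above, but without nn.stateful: fine when the model keeps no batch
  statistics (LayerNorm and dropout only)."""
  with nn.stochastic(dropout_rng):
    logits = model(inputs, train=True)
  loss, weight_sum = train_utils.compute_weighted_cross_entropy(
      logits, targets, num_classes=CLASS_MAP[FLAGS.task_name], weights=None)
  mean_loss = loss / weight_sum
  # No new_state to thread through, so return only the logits as the aux output.
  return mean_loss, logits
```

The training loop would then also stop threading state through p_train_step, as in @mtreviso's version above.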

@vladyorsh commented Oct 7, 2021

Thanks for the response. It seems that in this case the Transformer implementations in the repo should be fine (at least most of them), since LayerNorm doesn't use batch-wise statistics.
