Shape Error when running on GPU #187

Open

LasseRogers opened this issue Mar 21, 2024 · 2 comments

@LasseRogers
Hello @sanchit-gandhi,

I am trying to run Whisper-JAX at a more granular level, as illustrated in your example.

However, I am getting a ScopeParamShapeError that I can't seem to fix.

This is the example that I am trying to run:

import jax.numpy as jnp
from datasets import load_dataset
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import device_get, pmap
from transformers import WhisperProcessor

from whisper_jax import FlaxWhisperForConditionalGeneration

# load the processor and model
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2", dtype=jnp.bfloat16, _do_init=False,
)

def generate_fn(input_features):
    pred_ids = model.generate(
        input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
    )
    return pred_ids.sequences

# pmap the generate function for data parallelism
p_generate = pmap(generate_fn, "input_features")
# replicate the parameters across devices
params = replicate(params)

# load a dummy sample from the LibriSpeech dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]

# pre-process: convert the audio array to log-mel input features
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np").input_features

num_devices = 2
input_features = input_features.repeat(num_devices, axis=0)

# replicate the input features across devices for DP
input_features = shard(input_features)

# run the forward pass (JIT compiled the first time it is called)
pred_ids = p_generate(input_features)
output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))

# post-process: convert tokens ids to text string
transcription = processor.batch_decode(output_ids, skip_special_tokens=True)

and this is the error that I keep getting:

ScopeParamShapeError: Initializer expected to generate shape (2, 3, 80, 768) but got shape (3, 80, 768) instead for parameter "kernel" in "/model/encoder/conv1".

I am running the script in a Kaggle notebook with 2x T4 GPUs.
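
As far as I can tell, the leading 2 in the expected shape matches the two devices: flax.jax_utils.replicate stacks every parameter along a new leading device axis. A quick check (reusing params from the script above, and assuming the params dict mirrors the module path shown in the error) makes the extra axis visible:

import jax
from flax.jax_utils import replicate

# number of local devices: 2 on this Kaggle node (2x T4)
print(jax.local_device_count())

# un-replicated conv1 kernel: no device axis (shape as reported in the error)
kernel = params["model"]["encoder"]["conv1"]["kernel"]
print(kernel.shape)  # (3, 80, 768)

# replicate() stacks every leaf along a new leading axis of size num_devices
replicated = replicate(params)
print(replicated["model"]["encoder"]["conv1"]["kernel"].shape)  # (2, 3, 80, 768)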

@LasseRogers (Author)

This is the full traceback I'm getting:

---------------------------------------------------------------------------
ScopeParamShapeError                      Traceback (most recent call last)
Cell In[16], line 2
      1 # run the forward pass (JIT compiled the first time it is called)
----> 2 pred_ids = p_generate(input_features)
      3 output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))

    [... skipping hidden 12 frame]

Cell In[9], line 8, in generate_fn(input_features)
      7 def generate_fn(input_features):
----> 8     pred_ids = model.generate(
      9         input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
     10     )
     11     return pred_ids.sequences

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1588, in FlaxWhisperForConditionalGeneration.generate(self, input_features, generation_config, logits_processor, return_timestamps, task, language, is_multilingual, **kwargs)
   1585 if len(forced_decoder_ids) > 0:
   1586     generation_config.forced_decoder_ids = forced_decoder_ids
-> 1588 return super().generate(
   1589     input_features,
   1590     generation_config,
   1591     logits_processor=logits_processor,
   1592     **kwargs,
   1593 )

File /opt/conda/lib/python3.10/site-packages/transformers/generation/flax_utils.py:372, in FlaxGenerationMixin.generate(self, input_ids, generation_config, prng_key, trace, params, logits_processor, **kwargs)
    369 if self.config.is_encoder_decoder:
    370     # add encoder_outputs to model_kwargs
    371     if model_kwargs.get("encoder_outputs") is None:
--> 372         model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
    373     # prepare decoder_input_ids for generation
    374     input_ids = self._prepare_decoder_input_ids_for_generation(
    375         batch_size,
    376         decoder_start_token_id=generation_config.decoder_start_token_id,
    377         bos_token_id=generation_config.bos_token_id,
    378         model_kwargs=model_kwargs,
    379     )

File /opt/conda/lib/python3.10/site-packages/transformers/generation/flax_utils.py:167, in FlaxGenerationMixin._prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs)
    161 def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
    162     encoder_kwargs = {
    163         argument: value
    164         for argument, value in model_kwargs.items()
    165         if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
    166     }
--> 167     model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
    168     return model_kwargs

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1132, in FlaxWhisperPreTrainedModel.encode(self, input_features, attention_mask, output_attentions, output_hidden_states, return_dict, train, params, dropout_rng, **kwargs)
   1129     encode_module = module._get_encoder_module()
   1130     return encode_module(input_features, **kwargs)
-> 1132 return self.module.apply(
   1133     {"params": params or self.params},
   1134     input_features=jnp.array(input_features, dtype="f4"),
   1135     output_attentions=output_attentions,
   1136     output_hidden_states=output_hidden_states,
   1137     return_dict=return_dict,
   1138     deterministic=not train,
   1139     rngs=rngs,
   1140     method=_encoder_forward,
   1141 )

    [... skipping hidden 4 frame]

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1130, in FlaxWhisperPreTrainedModel.encode.<locals>._encoder_forward(module, input_features, **kwargs)
   1128 def _encoder_forward(module, input_features, **kwargs):
   1129     encode_module = module._get_encoder_module()
-> 1130     return encode_module(input_features, **kwargs)

    [... skipping hidden 2 frame]

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:824, in FlaxWhisperEncoder.__call__(self, input_features, output_attentions, output_hidden_states, return_dict, deterministic)
    817     raise ValueError(
    818         "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
    819         f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be"
    820         f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))"
    821     )
    823 input_features = input_features.transpose(0, 2, 1)
--> 824 hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
    825 hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
    826 hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)

    [... skipping hidden 2 frame]

File /opt/conda/lib/python3.10/site-packages/whisper_jax/layers.py:1205, in _Conv.__call__(self, inputs)
   1200 if self.mask is not None and self.mask.shape != kernel_shape:
   1201     raise ValueError(
   1202         "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
   1203     )
-> 1205 kernel = param_with_axes(
   1206     "kernel",
   1207     self.kernel_init,
   1208     kernel_shape,
   1209     self.params_dtype,
   1210     axes=self.kernel_axes,
   1211 )
   1213 if self.mask is not None:
   1214     kernel *= self.mask

File /opt/conda/lib/python3.10/site-packages/flax/linen/partitioning.py:159, in param_with_axes(name, init_fn, axes, module, *init_args, **init_kwargs)
    157   assert module is not None
    158 # define/fetch parameter on that module
--> 159 module_param = module.param(name, init_fn, *init_args, **init_kwargs)
    160 if axes is not None:
    161   # apply logical axis constraint immediately
    162   module_param = with_sharding_constraint(
    163       module_param, jax.sharding.PartitionSpec(*axes)
    164   )

    [... skipping hidden 1 frame]

File /opt/conda/lib/python3.10/site-packages/flax/core/scope.py:982, in Scope.param(self, name, init_fn, unbox, *init_args, **init_kwargs)
    977   for val, abs_val in zip(value_flat, abs_value_flat):
    978     # NOTE: We could check dtype consistency here as well but it's
    979     # usefuleness is less obvious. We might intentionally change the dtype
    980     # for inference to a half float type for example.
    981     if jnp.shape(val) != jnp.shape(abs_val):
--> 982       raise errors.ScopeParamShapeError(
    983         name, self.path_text, jnp.shape(abs_val), jnp.shape(val)
    984       )
    985 else:
    986   if not self.is_mutable_collection('params'):

ScopeParamShapeError: Initializer expected to generate shape (2, 3, 80, 768) but got shape (3, 80, 768) instead for parameter "kernel" in "/model/encoder/conv1"

@nairajay2k

I am also getting this error. I tried downgrading whisper-jax to an older commit, but then other things start breaking.
@sanchit-gandhi, please let us know how to solve this.
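
One thing that might be worth trying (I haven't confirmed it fixes this): pass the replicated params to the pmapped function as an explicit argument instead of capturing them in the closure, so that pmap maps over the leading device axis of both inputs. A sketch only, based on the script above:

# params becomes an explicit pmapped argument
def generate_fn(input_features, params):
    pred_ids = model.generate(
        input_features, task="transcribe", return_timestamps=False,
        max_length=model.config.max_length, params=params,
    )
    return pred_ids.sequences

p_generate = pmap(generate_fn, "input_features")
params = replicate(params)

# both arguments now carry a leading device axis for pmap to map over
pred_ids = p_generate(input_features, params)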
