RuntimeError: MixtralForCausalLM float32 model errors at grouped_mm op during torch dynamo tracing #43541

@vakumar1

Description

System Info

  • transformers version: 5.0.0
  • Platform: Linux-6.8.0-1040-aws-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 1.3.4
  • Safetensors version: 0.7.0
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.9.1+cu128 (NA)
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker @Cyrilvallez
(text models)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Code to reproduce

(note that torch.compile is invoked with the default inductor backend here)

import torch
import transformers
from transformers import MixtralForCausalLM, MixtralConfig, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

config = MixtralConfig(
    vocab_size=256,
    hidden_size=256,
    intermediate_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=4,
    hidden_act="silu",
    max_position_embeddings=256,
    initializer_range=0.02,
    rms_norm_eps=1e-5,
    use_cache=True,
    rope_theta=1000000.0,
    attention_dropout=0.0,
    num_local_experts=8,
    num_experts_per_tok=2,
    output_router_logits=False,
    router_aux_loss_coef=0.001,
)

model = MixtralForCausalLM(config)
compiled_model = torch.compile(model, backend="inductor", mode="reduce-overhead")
batch_size = 1
seq_length = 128
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
attention_mask = torch.ones((batch_size, seq_length))
with torch.no_grad():
    outputs = compiled_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=False,
    )
    logits = outputs.logits
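
For reference, the BF16 check appears to be reachable at the op level without building the model, by calling torch._grouped_mm on meta tensors with the same shapes/dtypes as the fake tensors in the traceback below. This is a sketch based on my reading of the meta registration (I assume meta tensors dispatch to the same meta kernel that dynamo's fake-tensor tracing hits), not something the Transformers code does:

import torch

# Minimal sketch: trigger the BF16 dtype check in the grouped_mm meta registration directly.
mat_a = torch.empty(256, 256, device="meta")              # float32 activations
mat_b = torch.empty(8, 256, 512, device="meta")           # float32 expert weights (8 experts)
offs = torch.empty(8, dtype=torch.int32, device="meta")   # per-expert row offsets

try:
    torch._grouped_mm(mat_a, mat_b, offs=offs)
except RuntimeError as e:
    print(e)  # Expected inputs of BF16 type but got mat_a.dtype=torch.float32 ...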

Error traceback

  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2096, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1511, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2755, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/_ops.py", line 841, in __call__
    return self._op(*args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 309, in _fn
    result = fn(*args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/_meta_registrations.py", line 7626, in meta_grouped_mm
    return _meta_grouped_mm_common(
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/_meta_registrations.py", line 7430, in _meta_grouped_mm_common
    torch._check(
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/__init__.py", line 1695, in _check
    _check_with(RuntimeError, cond, message)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/__init__.py", line 1677, in _check_with
    raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in method _grouped_mm of type object at 0x7115a37dba00>(*(FakeTensor(..., size=(256, 256)), FakeTensor(..., size=(8, 256, 512), requires_grad=True)), **{'offs': FakeTensor(..., size=(8,), dtype=torch.int32)}): got RuntimeError('Expected inputs of BF16 type but got mat_a.dtype=torch.float32 and mat_b.dtype=torch.float32.')

from user code:
   File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 835, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 655, in forward
    outputs: MoeModelOutputWithPast = self.model(
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 1002, in wrapper
    outputs = func(self, *args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 492, in forward
    hidden_states = decoder_layer(
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 394, in forward
    hidden_states = self.mlp(hidden_states)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 134, in forward
    hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/integrations/moe.py", line 349, in forward
    return experts_forward(self, *args, **kwargs)
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/integrations/moe.py", line 263, in grouped_mm_experts_forward
    gate_up_out = _grouped_linear(
  File "/home/ubuntu/workplace/tc_moduscope/src/TorchNeuronEager/.venv/lib/python3.10/site-packages/transformers/integrations/moe.py", line 191, in _grouped_linear
    out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs)

Expected behavior

When torch dynamo traces the model (in this case, a model with float32 weights), it should not hit any errors. However, (i) the new MoE util code uses the grouped_mm op here, and (ii) the Mixtral model code does not cast its inputs to bfloat16, which the grouped_mm operands are required to be, so tracing fails with the dtype mismatch error

'Expected inputs of BF16 type but got mat_a.dtype=torch.float32 and mat_b.dtype=torch.float32.'

Is this error expected behavior? I.e., are these models only expected/guaranteed to work when the weight dtype is bfloat16? We have also seen the same error when using torch.compile with several other MoE language models (DeepseekV3, GptOss, Qwen3, Olmoe).
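
For what it's worth, the obvious workaround on our side (assuming the BF16 constraint in the traceback is the only issue, which I have not verified end to end) is to instantiate the model in bfloat16 before compiling:

# Same config and inputs as the reproduction above; only the weight dtype changes.
model = MixtralForCausalLM(config).to(torch.bfloat16)
compiled_model = torch.compile(model, backend="inductor", mode="reduce-overhead")
with torch.no_grad():
    outputs = compiled_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=False,
    )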
