
Error when exporting LiquidAI/LFM2.5-1.2B-Instruct #17439

@msluszniak

Description

🐛 Describe the bug

I tried to export LiquidAI/LFM2.5-1.2B-Instruct. When input_ids is set to a fixed size, e.g. 128, both the export and model inference work. However, the export fails when I add a dynamic shape to the script below (it is a regular export; I only applied a monkey patch to remove data-dependent control flow and made the function return a tensor instead of a complex type):

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from transformers.models.lfm2.modeling_lfm2 import Lfm2ShortConv
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from torch.export import Dim

# Patched helper (part of the monkey patch described above): the check below
# uses only static shape information, so tracing hits no data-dependent branch.
def apply_mask_to_padding_states(hidden_states, attention_mask):
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states

# Patched Lfm2ShortConv.slow_forward: avoids data-dependent control flow and
# returns a plain tensor instead of a complex type.
def patched_slow_forward(
    self,
    x: torch.Tensor,
    past_key_values = None,
    cache_position: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
):
    seqlen = x.shape[1]
    x = apply_mask_to_padding_states(x, attention_mask)
    BCx = self.in_proj(x).transpose(-1, -2)
    B, C, x = BCx.chunk(3, dim=-2)
    Bx = B * x

    if past_key_values is not None:
        conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
        past_key_values.conv_cache[self.layer_idx].copy_(conv_state)

    conv_out = self.conv(Bx)[..., :seqlen]
    y = C * conv_out
    y = y.transpose(-1, -2).contiguous()
    y = self.out_proj(y)
    return y

# Wrap the model so export sees a tuple of tensors rather than a ModelOutput.
class ExportWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        outputs = self.model(input_ids=input_ids, use_cache=False)
        return (outputs.logits,)

print("Applying monkey patch...")
original_slow_forward = Lfm2ShortConv.slow_forward
Lfm2ShortConv.slow_forward = patched_slow_forward

try:
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        "LiquidAI/LFM2.5-1.2B-Instruct", 
        trust_remote_code=True,
        torch_dtype=torch.float32 
    )
    model.eval()
    wrapper = ExportWrapper(model)

    print("Preparing inputs...")
    sample_inputs = (
        torch.randint(1, model.config.vocab_size, (1, 128), dtype=torch.long), 
    )

    dynamic_shapes = {
        "input_ids": {1: Dim("tokens", min=1, max=2048)},
    }

    print("Starting export (this captures the graph)...")
    exported_program = torch.export.export(wrapper, sample_inputs, dynamic_shapes=dynamic_shapes)
    
    print("Lowering to ExecuTorch (XNNPACK)...")
    et_program = to_edge_transform_and_lower(
        exported_program,
        partitioner=[XnnpackPartitioner()]
    ).to_executorch()

    output_filename = "LFM2.5-1.2B-Instruct.pte"
    with open(output_filename, "wb") as f:
        f.write(et_program.buffer)

    print(f"✅ Success! Exported to {output_filename}")

finally:
    Lfm2ShortConv.slow_forward = original_slow_forward
    print("Restored original model methods.")

Running the dynamic-shape script above, I got the following error:

---------------------------------------------------------------------------
ConstraintViolationError                  Traceback (most recent call last)
File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:1933, in _export_to_aten_ir_make_fx(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform)
   1932 try:
-> 1933     produce_guards_callback(gm)
   1934 except (ConstraintViolationError, ValueRangeError) as e:

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:2080, in _non_strict_export.<locals>._produce_guards_callback(gm)
   2079 def _produce_guards_callback(gm):
-> 2080     return produce_guards_and_solve_constraints(
   2081         fake_mode=fake_mode,
   2082         gm=gm,
   2083         dynamic_shapes=dynamic_shapes,
   2084         equalities_inputs=equalities_inputs,
   2085         original_signature=original_signature,
   2086     )

File /opt/miniconda3/lib/python3.13/site-packages/torch/_export/non_strict_utils.py:603, in produce_guards_and_solve_constraints(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature)
    602 if constraint_violation_error:
--> 603     raise constraint_violation_error

File /opt/miniconda3/lib/python3.13/site-packages/torch/_export/non_strict_utils.py:565, in produce_guards_and_solve_constraints(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature)
    564 try:
--> 565     shape_env.produce_guards(
    566         placeholders,
    567         sources,
    568         input_contexts=input_contexts,
    569         equalities_inputs=equalities_inputs,
    570         ignore_static=False,
    571     )
    572 except ConstraintViolationError as e:

File /opt/miniconda3/lib/python3.13/site-packages/torch/fx/experimental/symbolic_shapes.py:5315, in ShapeEnv.produce_guards(self, *args, **kwargs)
   5311 """
   5312 Like produce_guards_verbose, but only returns the non-verbose python guard expressions
   5313 (no verbose guards produced.)
   5314 """
-> 5315 return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs

File /opt/miniconda3/lib/python3.13/site-packages/torch/fx/experimental/symbolic_shapes.py:6056, in ShapeEnv.produce_guards_verbose(self, placeholders, sources, source_ref, guards, input_contexts, equalities_inputs, _simplified, ignore_static, langs)
   6055     err = "\n".join(error_msgs)
-> 6056     raise ConstraintViolationError(
   6057         f"Constraints violated ({debug_names_str})! "
   6058         'For more information, run with TORCH_LOGS="+dynamic".\n'
   6059         f"{err}"
   6060     )
   6061 elif len(warn_msgs) > 0:

ConstraintViolationError: Constraints violated (tokens)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of tokens = L['input_ids'].size()[1] in the specified range tokens <= 2048 satisfy the generated guard min(64*L['input_ids'].size()[1], 256*L['input_ids'].size()[1]) == (64*L['input_ids'].size()[1]).

During handling of the above exception, another exception occurred:

UserError                                 Traceback (most recent call last)
Cell In[5], line 89
     87 print("Starting export (this captures the graph)...")
     88 # Export the WRAPPER, not the raw model
---> 89 exported_program = torch.export.export(wrapper, sample_inputs, dynamic_shapes=dynamic_shapes)
     91 print("Lowering to ExecuTorch (XNNPACK)...")
     92 et_program = to_edge_transform_and_lower(
     93     exported_program,
     94     partitioner=[XnnpackPartitioner()]
     95 ).to_executorch()

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/__init__.py:311, in export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, prefer_deferred_runtime_asserts_over_guards)
    309     new_msg = str(e) + "\n\n" + draft_export_msg
    310     e.args = (new_msg,)
--> 311 raise e

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/__init__.py:277, in export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, prefer_deferred_runtime_asserts_over_guards)
    270     raise ValueError(
    271         "Exporting a ScriptModule is not supported. "
    272         "Maybe try converting your ScriptModule to an ExportedProgram "
    273         "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
    274     )
    276 try:
--> 277     return _export(
    278         mod,
    279         args,
    280         kwargs,
    281         dynamic_shapes,
    282         strict=strict,
    283         preserve_module_call_signature=preserve_module_call_signature,
    284         pre_dispatch=True,
    285         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
    286     )
    287 except Exception as e:
    288     draft_export_msg = (
    289         "The error above occurred when calling torch.export.export. If you would "
    290         "like to view some more information about this error, and get a list "
    291         "of all other errors that may occur in your export call, you can "
    292         "replace your `export()` call with `draft_export()`."
    293     )

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:1271, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1265     if hasattr(e, "partial_fx_graph"):
   1266         print(
   1267             e.partial_fx_graph,
   1268             file=sys.stderr,
   1269         )
-> 1271     raise e
   1272 finally:
   1273     _EXPORT_FLAGS = None

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:1237, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1235 try:
   1236     start = time.time()
-> 1237     ep = fn(*args, **kwargs)
   1238     end = time.time()
   1239     log_export_usage(
   1240         event="export.time",
   1241         metrics=end - start,
   1242         flags=_EXPORT_FLAGS,
   1243         **get_ep_stats(ep),
   1244     )

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/exported_program.py:124, in _disable_prexisiting_fake_mode.<locals>.wrapper(*args, **kwargs)
    121 @functools.wraps(fn)
    122 def wrapper(*args, **kwargs):
    123     with unset_fake_temporarily():
--> 124         return fn(*args, **kwargs)

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:2377, in _export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, prefer_deferred_runtime_asserts_over_guards)
   2369 # NOTE Export training IR rollout
   2370 # Old export calls export._trace(pre_dispatch=True)
   2371 # and there are still lot of internal/OSS callsites that
   (...)   2374 # export_training_ir_rollout_check returns True in OSS
   2375 # while internally it returns False UNLESS otherwise specified.
   2376 if pre_dispatch and export_training_ir_rollout_check():
-> 2377     ep = _export_for_training(
   2378         mod,
   2379         args,
   2380         kwargs,
   2381         dynamic_shapes,
   2382         strict=strict,
   2383         preserve_module_call_signature=preserve_module_call_signature,
   2384         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
   2385     )
   2386     dtrace_structured("exported_program", payload_fn=lambda: str(ep))
   2387     return ep

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:1271, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1265     if hasattr(e, "partial_fx_graph"):
   1266         print(
   1267             e.partial_fx_graph,
   1268             file=sys.stderr,
   1269         )
-> 1271     raise e
   1272 finally:
   1273     _EXPORT_FLAGS = None

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:1237, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1235 try:
   1236     start = time.time()
-> 1237     ep = fn(*args, **kwargs)
   1238     end = time.time()
   1239     log_export_usage(
   1240         event="export.time",
   1241         metrics=end - start,
   1242         flags=_EXPORT_FLAGS,
   1243         **get_ep_stats(ep),
   1244     )

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/exported_program.py:124, in _disable_prexisiting_fake_mode.<locals>.wrapper(*args, **kwargs)
    121 @functools.wraps(fn)
    122 def wrapper(*args, **kwargs):
    123     with unset_fake_temporarily():
--> 124         return fn(*args, **kwargs)

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:2185, in _export_for_training(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, prefer_deferred_runtime_asserts_over_guards)
   2181     from torch._subclasses.fake_tensor import fake_tensor_tls
   2183     fake_tensor_tls.non_strict_export_fake_tensor_tracker.clear()
-> 2185 export_artifact = export_func(
   2186     mod=mod,
   2187     args=args,
   2188     kwargs=kwargs,
   2189     dynamic_shapes=dynamic_shapes,
   2190     preserve_module_call_signature=preserve_module_call_signature,
   2191     orig_in_spec=orig_in_spec,
   2192     prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
   2193     _to_aten_func=_export_to_aten_ir_make_fx,
   2194 )
   2196 # If we are tracing with fake inputs, it is expected to
   2197 # see fake tensor constants.
   2198 if not strict and not has_ambient_mode:

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:2116, in _non_strict_export(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, orig_in_spec, prefer_deferred_runtime_asserts_over_guards, _to_aten_func)
   2097 with (
   2098     fake_mode,
   2099     _NonStrictTorchFunctionHandler(),
   2100     tracing(tx),
   2101     torch._dynamo.config.patch(dynamo_config),
   2102 ):
   2103     with (
   2104         _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
   2105             patched_mod,
   (...)   2114         # _to_aten_func is _export_to_aten_ir when using the default non-strict export
   2115         # We need to pass positional args correctly
-> 2116         aten_export_artifact = _to_aten_func(
   2117             patched_mod,
   2118             new_fake_args,
   2119             new_fake_kwargs,
   2120             fake_params_buffers,
   2121             new_fake_constant_attrs,
   2122             produce_guards_callback=_produce_guards_callback,
   2123             transform=_tuplify_outputs,
   2124         )
   2125         # aten_export_artifact.constants contains only fake script objects, we need to map them back
   2126         aten_export_artifact.constants = {
   2127             fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj
   2128             for fqn, obj in aten_export_artifact.constants.items()
   2129         }

File /opt/miniconda3/lib/python3.13/site-packages/torch/export/_trace.py:1935, in _export_to_aten_ir_make_fx(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform)
   1933         produce_guards_callback(gm)
   1934     except (ConstraintViolationError, ValueRangeError) as e:
-> 1935         raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
   1937 return _produce_aten_artifact(
   1938     gm=gm,
   1939     mod=mod,
   (...)   1945     fake_params_buffers=fake_params_buffers,
   1946 )

UserError: Constraints violated (tokens)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of tokens = L['input_ids'].size()[1] in the specified range tokens <= 2048 satisfy the generated guard min(64*L['input_ids'].size()[1], 256*L['input_ids'].size()[1]) == (64*L['input_ids'].size()[1]).

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
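
Following the suggestion at the end of that message, replacing export() with draft_export() (a sketch with the same arguments, per the hint; not verified on this model) may report all constraint violations at once instead of failing on the first guard:

from torch.export import draft_export
exported_program = draft_export(wrapper, sample_inputs, dynamic_shapes=dynamic_shapes)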

Note that this error message makes no sense:

Not all values of tokens = L['input_ids'].size()[1] in the specified range tokens <= 2048 satisfy the generated guard min(64*L['input_ids'].size()[1], 256*L['input_ids'].size()[1]) == (64*L['input_ids'].size()[1]).

Assuming that L['input_ids'].size()[1] is non-negative, this condition is always true.
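
A quick check (a minimal sketch, writing s for L['input_ids'].size()[1]) confirms the guard is a tautology over the entire declared range:

# For any s >= 0, 64*s <= 256*s, so min(64*s, 256*s) == 64*s always holds.
for s in range(0, 2049):
    assert min(64 * s, 256 * s) == 64 * s  # never fails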

Versions

Collecting environment information...
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: version 4.1.0
Libc version: N/A

Python version: 3.13.5 | packaged by Anaconda, Inc. | (main, Jun 12 2025, 11:23:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] executorch==1.1.0
[pip3] numpy==2.4.2
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[conda] executorch 1.1.0 pypi_0 pypi
[conda] numpy 2.3.5 pypi_0 pypi
[conda] pytorch-tokenizers 1.1.0 pypi_0 pypi
[conda] torch 2.10.0 pypi_0 pypi
[conda] torchao 0.15.0 pypi_0 pypi
[conda] torchvision 0.25.0 pypi_0 pypi
