@jlamypoirier (Collaborator)

✨ Description

Refactor MTP models to make them closer to non-MTP models, i.e., so that the common subset of config parameters, module names, and parameter names matches exactly. This avoids many situations where we would otherwise have to take different code paths depending on whether MTP is enabled. A sketch of the kind of branching this removes follows.
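For illustration, a minimal sketch of the branching that unified naming eliminates (function and attribute names other than `base_model.head` are hypothetical, not the actual Fast-LLM API):

```python
# Before (hypothetical names): callers had to branch on whether MTP is on.
def get_lm_head_before(model, config):
    if config.prediction_heads > 1:  # MTP enabled
        return model.multi_token_prediction.heads[0]
    return model.head

# After: the next-token head lives at the same path either way,
# so common code needs no MTP check.
def get_lm_head_after(model, config):
    return model.base_model.head
```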

The MTP config is now just a standard LM config, with `prediction_heads` enabling it. This is essentially identical to what it used to be (before #370). The MTP block is configured from the decoder, using the last layer's config, which removes a bit of generality but makes things much simpler.
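A minimal sketch of what this config shape implies, using a plain dict as a stand-in (field names other than `prediction_heads` are placeholders, not the real schema):

```python
# Hypothetical config sketch: a standard LM config becomes an MTP config
# purely by raising prediction_heads; everything else is shared.
base_lm_config = {
    "hidden_size": 4096,    # placeholder values
    "num_layers": 32,
    "prediction_heads": 1,  # 1 = plain next-token prediction
}

# Enable MTP without any separate MTP-specific config section.
mtp_config = {**base_lm_config, "prediction_heads": 4}

# Per the description above, the MTP block reuses the decoder's
# last-layer config rather than carrying its own block config.
```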

As for the modules and weights, the next-token-prediction head is standardized to `base_model.head`, while `base_model.multi_token_prediction` optionally contains the MTP-specific modules. This makes it easier to compare weights between MTP and non-MTP models, and to use logit distillation (which should now fully support MTP).
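A hedged sketch of why the layout helps weight comparison: with the standardized names, the shared keys of an MTP and a non-MTP checkpoint line up exactly, and MTP-only weights are isolated under one prefix (the helper below is illustrative, not part of the PR):

```python
import torch

def compare_shared_weights(mtp_state, plain_state, atol=1e-6):
    """Compare the weights two state dicts have in common."""
    shared = set(mtp_state) & set(plain_state)
    # MTP-only weights live under this prefix and are simply absent from
    # the non-MTP model; the common weights keep identical names.
    assert all(k.startswith("base_model.multi_token_prediction")
               for k in set(mtp_state) - shared)
    return {k: torch.allclose(mtp_state[k], plain_state[k], atol=atol)
            for k in shared}
```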

@jlamypoirier jlamypoirier marked this pull request as ready for review February 11, 2026 21:12
