Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
✨ Description
Refactor MTP models to make them closer to non-MTP models, I.e. so that the common subset of config parameters, module name and parameter names matches exactly. This avoids lots of situations where we would otherwise have to take different code paths depending on whether MTP is enabled or not.
The MTP config is now just a standard LM config, with
prediction_headsenabling it. This is essentially identical to what it used to be (before #370). The MTP block is configured from the decoder, using the last layer config, which removes a bit of generality but makes things way simpler.As for the modules and weights, next-token-prediction head is standardized to
base_model.head, whilebase_model.multi_token_prediction` optionally contains the MTP stuff. This makes it easier to compare weights between MTP and non-MTP models, and to use logit distillation (which should now fully support MTP).