Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ dependencies = [
plotting = ["matplotlib>=3.10.1", "seaborn>=0.13.2"]

backend = [
"peft>=0.14.0",
"peft>=0.18.0",
"hf-xet>=1.1.0",
"bitsandbytes>=0.45.2",
"unsloth==2025.12.9",
"unsloth-zoo==2025.12.7",
"unsloth==2026.2.1",
"unsloth-zoo==2026.2.1",
"torch>=2.8.0",
"torchao==0.14.1",
"accelerate==1.7.0",
"awscli>=1.38.1",
"setuptools>=78.1.0",
"wandb==0.24.0",
"transformers>=4.55.2,<=4.57.3",
"transformers==5.1.0",
"duckdb>=1.0.0",
"pyarrow>=15.0.0",
"trl==0.20.0",
Expand Down Expand Up @@ -65,7 +65,7 @@ tinker = [
"pydantic>=2.12.5",
"tinker>=0.8.1",
"torch>=2.8.0",
"transformers>=4.55.2,<=4.57.3",
"transformers==5.1.0",
"uvicorn>=0.35.0",
"datrie>=0.8.3",
]
Expand Down Expand Up @@ -122,7 +122,13 @@ required-version = ">=0.6.15"
# Override numpy to <2.0 for compatibility with megatron-core in the training
# environment. vLLM 0.15.1 pulls opencv-python-headless>=4.13 which wants
# numpy>=2 on Python 3.9+, but megatron-core requires numpy<2.
override-dependencies = ["transformer-engine>=2.11.0", "numpy<2"]
override-dependencies = [
"transformer-engine>=2.11.0",
"numpy<2",
# Override unsloth's overly strict constraint on transformers — v5.x
# is confirmed working per unsloth February-2026 release notes
"transformers==5.1.0",
]
exclude-dependencies = ["pynvml"]
no-build-isolation-package = ["apex", "transformer-engine-torch", "nv-grouped-gemm"]

Expand Down
6 changes: 5 additions & 1 deletion src/art/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ def __init__(self, **kwargs):
import transformers

try:
from .transformers.patches import patch_preprocess_mask_arguments
from .transformers.patches import (
patch_apply_chat_template,
patch_preprocess_mask_arguments,
)

patch_preprocess_mask_arguments()
patch_apply_chat_template()
except Exception:
pass
except ImportError:
Expand Down
13 changes: 0 additions & 13 deletions src/art/dev/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ class PeftArgs(TypedDict, total=False):

class TrainerArgs(TypedDict, total=False):
output_dir: str | None
overwrite_output_dir: bool
do_train: bool
do_eval: bool
do_predict: bool
Expand Down Expand Up @@ -219,7 +218,6 @@ class TrainerArgs(TypedDict, total=False):
log_level: str
log_level_replica: str
log_on_each_node: bool
logging_dir: str | None
logging_strategy: "IntervalStrategy | str"
logging_first_step: bool
logging_steps: float
Expand All @@ -236,25 +234,21 @@ class TrainerArgs(TypedDict, total=False):
use_mps_device: bool
seed: int
data_seed: int | None
jit_mode_eval: bool
use_ipex: bool
bf16: bool
fp16: bool
fp16_opt_level: str
half_precision_backend: str
bf16_full_eval: bool
fp16_full_eval: bool
tf32: bool | None
local_rank: int
ddp_backend: str | None
tpu_num_cores: int | None
tpu_metrics_debug: bool
debug: str | list[DebugOption]
dataloader_drop_last: bool
eval_steps: float | None
dataloader_num_workers: int
dataloader_prefetch_factor: int | None
past_index: int
run_name: str | None
disable_tqdm: bool | None
remove_unused_columns: bool | None
Expand Down Expand Up @@ -295,15 +289,8 @@ class TrainerArgs(TypedDict, total=False):
include_inputs_for_metrics: bool
include_for_metrics: list[str]
eval_do_concat_batches: bool
fp16_backend: str
push_to_hub_model_id: str | None
push_to_hub_organization: str | None
push_to_hub_token: str | None
mp_parameters: str
auto_find_batch_size: bool
full_determinism: bool
torchdynamo: str | None
ray_scope: str | None
ddp_timeout: int
torch_compile: bool
torch_compile_backend: str | None
Expand Down
22 changes: 21 additions & 1 deletion src/art/transformers/patches.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import functools
from typing import TYPE_CHECKING, Optional, Union

import torch
from transformers import masking_utils
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

if TYPE_CHECKING:
from torch.nn.attention.flex_attention import BlockMask
Expand All @@ -19,7 +21,9 @@ def _patched_preprocess_mask_arguments(
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
layer_idx: Optional[int],
) -> tuple[bool, Optional[Union[torch.Tensor, "BlockMask"]], int, int]:
) -> tuple[
bool, Optional[Union[torch.Tensor, "BlockMask"]], Optional[torch.Tensor], int, int
]:
if position_ids is not None and len(position_ids.shape) == 3:
position_ids = position_ids[0]
return _preprocess_mask_arguments(
Expand All @@ -35,3 +39,19 @@ def _patched_preprocess_mask_arguments(

def patch_preprocess_mask_arguments() -> None:
    """Monkey-patch transformers' mask preprocessing to tolerate 3-D position_ids.

    Replaces ``masking_utils._preprocess_mask_arguments`` with a wrapper that
    reduces a 3-D ``position_ids`` tensor to 2-D (by taking its first slice)
    before delegating to the original implementation.
    """
    # The assignment replaces a private function inside transformers, which the
    # type stubs do not allow — hence the ignore.
    masking_utils._preprocess_mask_arguments = _patched_preprocess_mask_arguments  # ty:ignore[invalid-assignment]


def patch_apply_chat_template() -> None:
    """Default return_dict=False in apply_chat_template for transformers v5.

    Transformers v5 changed the default from list[int] to BatchEncoding.
    This restores the v4 behavior so all call sites get list[int] back.

    Safe to call more than once: a marker attribute on the installed wrapper
    prevents stacking a new wrapper on every invocation.
    """
    original = PreTrainedTokenizerBase.apply_chat_template
    # Already patched — do not wrap the wrapper again.
    if getattr(original, "_art_return_dict_patched", False):
        return

    @functools.wraps(original)
    def _patched(self, *args, **kwargs):  # type: ignore
        # Only supply the default; an explicit return_dict from the caller wins.
        kwargs.setdefault("return_dict", False)
        return original(self, *args, **kwargs)

    _patched._art_return_dict_patched = True  # type: ignore[attr-defined]
    PreTrainedTokenizerBase.apply_chat_template = _patched  # type: ignore
10 changes: 9 additions & 1 deletion src/art/unsloth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from datasets import Dataset
import peft
import torch
from transformers import GenerationMixin, PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils.dummy_pt_objects import GenerationMixin, PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from vllm import AsyncEngineArgs
from vllm.lora.request import LoRARequest
Expand All @@ -25,6 +25,7 @@
packed_tensors_from_dir,
)
from ..preprocessing.tokenize import SFTBatch
from ..utils.convert_moe_lora import convert_checkpoint_if_needed
from ..utils.get_model_step import get_step_from_dir
from ..utils.output_dirs import get_step_checkpoint_dir
from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
Expand Down Expand Up @@ -149,6 +150,7 @@ def save_checkpoint(
checkpoint_dir = get_step_checkpoint_dir(output_dir, next_step)
os.makedirs(checkpoint_dir, exist_ok=True)
trainer.save_model(checkpoint_dir)
convert_checkpoint_if_needed(checkpoint_dir)
return checkpoint_dir


Expand Down Expand Up @@ -280,6 +282,7 @@ async def start_openai_server(
lora_path = get_step_checkpoint_dir(self.output_dir, 0)
os.makedirs(os.path.dirname(lora_path), exist_ok=True)
self._state.trainer.save_model(lora_path)
convert_checkpoint_if_needed(lora_path)
self._latest_step = 0
else:
# Extract step from checkpoint path
Expand Down Expand Up @@ -667,6 +670,11 @@ def _state(self) -> UnslothState:
),
)

# Unsloth's model patching can leave the PEFT model without
# `warnings_issued`, which GRPOTrainer expects during init.
if not hasattr(peft_model, "warnings_issued"):
peft_model.warnings_issued = {} # type: ignore[attr-defined]

# Initialize trainer with dummy dataset
data = {"prompt": ""}
trainer = GRPOTrainer(
Expand Down
177 changes: 177 additions & 0 deletions src/art/utils/convert_moe_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""Convert fused MoE LoRA adapters to per-expert format for vLLM compatibility.

Unsloth with transformers v5 saves MoE expert LoRA as fused 2D tensors:
mlp.experts.base_layer.lora_A [num_experts*rank, intermediate*2] (gate_up_proj)
mlp.experts.base_layer.lora_B [hidden, num_experts*rank] (gate_up_proj)
mlp.experts.lora_A [num_experts*rank, hidden] (down_proj)
mlp.experts.lora_B [intermediate, num_experts*rank] (down_proj)

vLLM expects per-expert keys:
mlp.experts.0.gate_proj.lora_A [rank, hidden]
mlp.experts.0.gate_proj.lora_B [intermediate, rank]
...
"""

import json
import os
import re

import safetensors.torch
import torch


def _has_fused_moe_lora(tensors: dict[str, torch.Tensor]) -> bool:
"""Check if the adapter contains fused MoE LoRA tensors."""
return any(
re.search(r"mlp\.experts\.(base_layer\.)?lora_[AB]\.weight$", key)
for key in tensors
)


def _infer_moe_params(
tensors: dict[str, torch.Tensor],
adapter_config: dict,
) -> tuple[int, int, int, int]:
"""Infer num_experts, rank, intermediate_size, hidden_size from tensor shapes."""
rank = adapter_config.get("r", adapter_config.get("lora_rank", 8))

for key, tensor in tensors.items():
# gate_up_proj lora_A: [num_experts*rank, intermediate*2]
if re.search(r"mlp\.experts\.base_layer\.lora_A\.weight$", key):
num_experts_times_rank = tensor.shape[0]
intermediate_times_2 = tensor.shape[1]
num_experts = num_experts_times_rank // rank
intermediate_size = intermediate_times_2 // 2
break
# down_proj lora_B: [intermediate, num_experts*rank]
if re.search(r"mlp\.experts\.lora_B\.weight$", key):
intermediate_size = tensor.shape[0]
num_experts = tensor.shape[1] // rank
break
else:
raise ValueError("Could not find fused MoE tensors to infer parameters")

# Get hidden_size from gate_up_proj lora_B: [hidden, num_experts*rank]
for key, tensor in tensors.items():
if re.search(r"mlp\.experts\.base_layer\.lora_B\.weight$", key):
hidden_size = tensor.shape[0]
break
else:
raise ValueError("Could not find gate_up_proj lora_B to infer hidden_size")

return num_experts, rank, intermediate_size, hidden_size


def convert_fused_moe_lora(
    tensors: dict[str, torch.Tensor],
    num_experts: int,
    rank: int,
    intermediate_size: int,
    hidden_size: int,
) -> dict[str, torch.Tensor]:
    """Convert fused MoE LoRA tensors to per-expert format.

    Non-expert tensors (e.g. self_attn) are passed through unchanged.
    Note the A/B swap: the fused layout stores the factors transposed, so
    fused ``lora_A`` tensors become per-expert ``lora_B`` weights and
    vice versa, with a transpose applied to each slice.
    """
    fused_pattern = re.compile(
        r"(.*\.mlp\.experts)\.(base_layer\.lora_(A|B)|lora_(A|B))\.weight$"
    )
    converted: dict[str, torch.Tensor] = {}

    for key, tensor in tensors.items():
        match = fused_pattern.match(key)
        if match is None:
            # Not an expert tensor — keep as-is.
            converted[key] = tensor
            continue

        prefix = match.group(1)
        fused_gate_up = "base_layer" in key
        fused_a = "lora_A" in key

        if fused_gate_up and fused_a:
            # gate_up_proj lora_A: [num_experts*rank, intermediate*2]
            split = tensor.reshape(num_experts, rank, intermediate_size * 2)
            for e in range(num_experts):
                slab = split[e]  # [rank, intermediate*2]
                converted[f"{prefix}.{e}.gate_proj.lora_B.weight"] = (
                    slab[:, :intermediate_size].T.contiguous()
                )
                converted[f"{prefix}.{e}.up_proj.lora_B.weight"] = (
                    slab[:, intermediate_size:].T.contiguous()
                )
        elif fused_gate_up:
            # gate_up_proj lora_B: [hidden, num_experts*rank]
            split = tensor.reshape(hidden_size, num_experts, rank)
            for e in range(num_experts):
                slab = split[:, e, :]  # [hidden, rank]
                # Two independent copies: safetensors rejects shared storage.
                converted[f"{prefix}.{e}.gate_proj.lora_A.weight"] = (
                    slab.T.contiguous()
                )
                converted[f"{prefix}.{e}.up_proj.lora_A.weight"] = (
                    slab.T.contiguous()
                )
        elif fused_a:
            # down_proj lora_A: [num_experts*rank, hidden]
            split = tensor.reshape(num_experts, rank, hidden_size)
            for e in range(num_experts):
                converted[f"{prefix}.{e}.down_proj.lora_B.weight"] = (
                    split[e].T.contiguous()
                )
        else:
            # down_proj lora_B: [intermediate, num_experts*rank]
            split = tensor.reshape(intermediate_size, num_experts, rank)
            for e in range(num_experts):
                converted[f"{prefix}.{e}.down_proj.lora_A.weight"] = (
                    split[:, e, :].T.contiguous()
                )

    return converted


def convert_checkpoint_if_needed(checkpoint_dir: str) -> None:
    """Convert a checkpoint's MoE LoRA adapter to per-expert format if needed.

    This is a no-op for non-MoE adapters, for missing adapter files, and for
    adapters this function has already converted (the per-expert keys no
    longer match the fused pattern), so repeated calls are safe.

    Args:
        checkpoint_dir: Directory expected to contain
            ``adapter_model.safetensors`` and ``adapter_config.json``.
    """
    adapter_path = os.path.join(checkpoint_dir, "adapter_model.safetensors")
    config_path = os.path.join(checkpoint_dir, "adapter_config.json")

    if not os.path.exists(adapter_path) or not os.path.exists(config_path):
        return

    tensors = safetensors.torch.load_file(adapter_path)
    if not _has_fused_moe_lora(tensors):
        return

    with open(config_path) as f:
        adapter_config = json.load(f)

    num_experts, rank, intermediate_size, hidden_size = _infer_moe_params(
        tensors, adapter_config
    )

    new_tensors = convert_fused_moe_lora(
        tensors, num_experts, rank, intermediate_size, hidden_size
    )

    # Write-then-rename so a crash mid-save cannot leave a truncated adapter
    # behind: the original file stays intact until the atomic os.replace.
    tmp_adapter_path = adapter_path + ".tmp"
    safetensors.torch.save_file(new_tensors, tmp_adapter_path)
    os.replace(tmp_adapter_path, adapter_path)

    # Update adapter_config.json target_modules: drop fused expert entries
    # and advertise the per-expert projection names instead.
    adapter_config["target_modules"] = [
        m for m in adapter_config.get("target_modules", []) if "experts" not in m
    ] + ["gate_proj", "up_proj", "down_proj"]
    # Remove target_parameters if present (not needed for per-expert format)
    adapter_config.pop("target_parameters", None)

    tmp_config_path = config_path + ".tmp"
    with open(tmp_config_path, "w") as f:
        json.dump(adapter_config, f, indent=2)
    os.replace(tmp_config_path, config_path)
Loading