From a0f0f8cef4bcef19cc1a37b486139aeb29da2961 Mon Sep 17 00:00:00 2001 From: mukesh reddy Date: Wed, 4 Feb 2026 09:37:49 -0500 Subject: [PATCH 1/8] feat: Add SGLang backend integration --- docs/sglang-integration.md | 301 +++++++++++++ scripts/setup_sglang.sh | 122 ++++++ scripts/test_sglang_e2e.py | 209 ++++++++++ src/art/sglang_backend/__init__.py | 53 +++ src/art/sglang_backend/backend.py | 293 +++++++++++++ src/art/sglang_backend/config.py | 203 +++++++++ src/art/sglang_backend/service.py | 650 +++++++++++++++++++++++++++++ src/art/unsloth/training_utils.py | 128 ++++++ 8 files changed, 1959 insertions(+) create mode 100644 docs/sglang-integration.md create mode 100644 scripts/setup_sglang.sh create mode 100644 scripts/test_sglang_e2e.py create mode 100644 src/art/sglang_backend/__init__.py create mode 100644 src/art/sglang_backend/backend.py create mode 100644 src/art/sglang_backend/config.py create mode 100644 src/art/sglang_backend/service.py create mode 100644 src/art/unsloth/training_utils.py diff --git a/docs/sglang-integration.md b/docs/sglang-integration.md new file mode 100644 index 000000000..45c7efe67 --- /dev/null +++ b/docs/sglang-integration.md @@ -0,0 +1,301 @@ +# SGLang Backend Integration + +ART supports SGLang as an alternative inference engine to vLLM. SGLang offers +potentially faster inference for agent trajectories due to its RadixAttention +prefix caching mechanism. 
+ +## Architecture + +### Multi-GPU Split Mode (Recommended) + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Multi-GPU Split Architecture │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ GPU 0: SGLang Inference Server │ +│ ┌────────────────────────────────────────────────────────────┐ │ +│ │ • RadixAttention cache (PERSISTENT across training) │ │ +│ │ • OpenAI-compatible API on localhost:8000 │ │ +│ │ • LoRA hot-reload via /update_weights_from_lora │ │ +│ │ • No restart needed = cache stays warm │ │ +│ └────────────────────────────────────────────────────────────┘ │ +│ │ +│ GPU 1+: Training (Unsloth/GRPO) │ +│ ┌────────────────────────────────────────────────────────────┐ │ +│ │ • PEFT/LoRA model │ │ +│ │ • Optimizer states │ │ +│ │ • Gradient computation │ │ +│ │ • Checkpoint saving │ │ +│ └────────────────────────────────────────────────────────────┘ │ +│ │ +│ Weight Sync: Hot-reload via HTTP API (~5-10s) │ +│ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Single-GPU Fallback Mode + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Single-GPU Shared Mode │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ GPU 0: Time-multiplexed │ +│ │ +│ [Inference Phase] │ +│ ┌────────────────────────────────────────────────────────────┐ │ +│ │ SGLang Server running │ │ +│ │ Training model offloaded to CPU │ │ +│ └────────────────────────────────────────────────────────────┘ │ +│ ↓ Stop server │ +│ [Training Phase] │ +│ ┌────────────────────────────────────────────────────────────┐ │ +│ │ Training model on GPU │ │ +│ │ SGLang server stopped │ │ +│ └────────────────────────────────────────────────────────────┘ │ +│ ↓ Restart server │ +│ [Inference Phase] │ +│ ┌────────────────────────────────────────────────────────────┐ │ +│ │ SGLang Server running (cache cleared) │ │ +│ 
└────────────────────────────────────────────────────────────┘ │ +│ │ +│ Weight Sync: Server restart (~30-60s, cache lost) │ +│ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Why SGLang? + +| Feature | vLLM | SGLang | Benefit for RL | +|---------|------|--------|----------------| +| Prefix Caching | PagedAttention | RadixAttention (automatic LRU) | Better multi-turn perf | +| Cache Persistence | Manual | Automatic | Less memory management | +| Scheduling | Continuous batching | Zero-overhead | Lower latency | +| Structured Outputs | Native | Optimized | Faster tool calls | +| Weight Updates | LoRA add | Hot-reload API | No restart needed | + +**Key benefit**: SGLang's RadixAttention automatically caches common prefixes across +requests. For RL training where many rollouts share the same system prompt and context, +this provides significant speedups. + +## Installation + +**CRITICAL**: SGLang and vLLM have conflicting PyTorch dependencies. You MUST use +separate virtual environments. 
+ +### vLLM Environment (Default) + +```bash +python -m venv .venv-vllm +source .venv-vllm/bin/activate +pip install openpipe-art[backend] +``` + +### SGLang Environment + +```bash +python -m venv .venv-sglang +source .venv-sglang/bin/activate +pip install openpipe-art[sglang] +``` + +## Usage + +### Basic Usage (Auto-detect GPUs) + +```python +from art.sglang_backend import SGLangBackend +import art + +model = art.TrainableModel( + name="my-model", + base_model="Qwen/Qwen2.5-3B-Instruct", + project="my-project", +) + +# Auto-detects GPU count: +# - 2+ GPUs: split mode (recommended) +# - 1 GPU: shared mode (fallback) +backend = SGLangBackend() +await backend.register(model) + +# Everything else works like LocalBackend +result = await backend.train(model, trajectory_groups) +``` + +### Explicit Device Configuration + +```python +from art.sglang_backend import SGLangBackend, DeviceConfig, SGLangConfig + +# 2-GPU setup +backend = SGLangBackend( + inference_device=0, # SGLang on GPU 0 + training_devices=[1], # Training on GPU 1 +) + +# 4-GPU setup with multi-GPU training +backend = SGLangBackend( + inference_device=0, + training_devices=[1, 2, 3], +) + +# Custom SGLang configuration +backend = SGLangBackend( + sglang_config=SGLangConfig( + mem_fraction_static=0.85, + weight_sync_method="lora", # or "disk", "restart" + flush_cache_on_sync=False, # Keep cache warm + tensor_parallel_size=1, + ) +) +``` + +### With vLLM (Default Behavior) + +```python +import art + +# Default LocalBackend uses vLLM +backend = art.LocalBackend() +await backend.register(model) +``` + +## Configuration Reference + +### DeviceConfig + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `inference_device` | int | 0 | GPU index for SGLang server | +| `training_devices` | list[int] | [1] | GPU indices for training | +| `auto_detect` | bool | True | Auto-detect available GPUs | + +### SGLangConfig + +| Parameter | Type | Default | Description | 
+| `mem_fraction_static` | float | 0.5 | GPU memory for SGLang (0.0-1.0) |
Garbled Output with Small Tensor Buckets + +**Reference**: [SGLang #14178](https://github.com/sgl-project/sglang/issues/14178) + +**Solution**: Use LoRA-based sync instead of tensor sync. + +## Performance Comparison + +Based on external benchmarks (H100, Llama 3.1 8B): + +| Metric | vLLM | SGLang | Improvement | +|--------|------|--------|-------------| +| Throughput (tok/s) | ~12,500 | ~16,200 | ~29% | +| TTFT (ms) | ~45 | ~35 | ~22% | +| P99 Latency (ms) | ~120 | ~95 | ~21% | + +*Source: [aimultiple.com benchmark](https://aimultiple.com/llm-inference-benchmark)* + +The performance advantage comes from: +- RadixAttention's automatic prefix caching +- Zero-overhead scheduler design +- Optimized FlashInfer kernels + +## Benchmarking Your Setup + +```bash +# In vLLM environment +source .venv-vllm/bin/activate +python scripts/benchmark_inference.py --engine vllm --model Qwen/Qwen2.5-3B-Instruct + +# In SGLang environment +source .venv-sglang/bin/activate +python scripts/benchmark_inference.py --engine sglang --model Qwen/Qwen2.5-3B-Instruct +``` + +## Troubleshooting + +### "SGLang is not installed" + +```bash +source .venv-sglang/bin/activate +pip install openpipe-art[sglang] +``` + +### Server timeout errors + +```python +backend = SGLangBackend( + sglang_config=SGLangConfig(server_timeout=180.0) +) +``` + +Or via environment: +```bash +export ART_SERVER_TIMEOUT=180 +``` + +### CUDA out of memory + +```python +backend = SGLangBackend( + sglang_config=SGLangConfig(mem_fraction_static=0.8) +) +``` + +### Check server logs + +```bash +cat .art///logs/sglang.log +``` + +## References + +- [verl SGLang integration](https://verl.readthedocs.io/en/latest/workers/sglang_worker.html) +- [SGLang weight sync optimization (slime)](https://hebiao064.github.io/rl-weight-sync) +- [SGLang GitHub](https://github.com/sgl-project/sglang) +- [Anatomy of RL Frameworks](https://www.hanifleo.com/anatomy-of-rl-frameworks/) diff --git a/scripts/setup_sglang.sh b/scripts/setup_sglang.sh new 
file mode 100644 index 000000000..690ed5370 --- /dev/null +++ b/scripts/setup_sglang.sh @@ -0,0 +1,122 @@ +#!/bin/bash +# Setup script for SGLang + Unsloth two-environment architecture +# +# Creates TWO COMPLETELY ISOLATED virtual environments: +# - .venv: Main training env (ART + unsloth + openai>=2.14) +# - .venv-sglang-server: SGLang server ONLY (sglang + openai==2.6.1) +# +# They communicate via HTTP (localhost:8000), NOT Python imports. +# This avoids ALL dependency conflicts (torchao, openai, etc.) +# +# Usage: +# chmod +x scripts/setup_sglang.sh +# ./scripts/setup_sglang.sh +# +# Then activate the main env to run training: +# source .venv/bin/activate +# python your_training_script.py + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" + +cd "$PROJECT_DIR" + +echo "==========================================" +echo "SGLang + Unsloth Two-Environment Setup" +echo "==========================================" +echo "" +echo "This will create TWO ISOLATED environments:" +echo " 1. .venv - Main: ART + Unsloth (openai>=2.14, torchao>=0.13)" +echo " 2. .venv-sglang-server - Server: SGLang ONLY (openai==2.6.1, torchao==0.9)" +echo "" +echo "They communicate via HTTP only. No shared dependencies." +echo "" + +# Check for python3.11 +PYTHON_CMD="" +if command -v python3.11 &> /dev/null; then + PYTHON_CMD="python3.11" +elif command -v python3 &> /dev/null; then + PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') + MAJOR=$(echo $PYTHON_VERSION | cut -d. -f1) + MINOR=$(echo $PYTHON_VERSION | cut -d. -f2) + if [ "$MAJOR" -ge 3 ] && [ "$MINOR" -ge 11 ]; then + PYTHON_CMD="python3" + fi +fi + +if [ -z "$PYTHON_CMD" ]; then + echo "ERROR: Python 3.11+ required." 
+ echo "" + echo "Install with:" + echo " apt update && apt install -y software-properties-common" + echo " add-apt-repository -y ppa:deadsnakes/ppa" + echo " apt update && apt install -y python3.11 python3.11-venv python3.11-dev" + exit 1 +fi + +echo "Using: $PYTHON_CMD ($($PYTHON_CMD --version))" + +echo "" +echo "Step 1/4: Creating main training environment (.venv)..." +echo "--------------------------------------------------------" +if [ -d ".venv" ]; then + echo " .venv already exists, removing..." + rm -rf .venv +fi +$PYTHON_CMD -m venv .venv +echo " Created .venv" + +echo "" +echo "Step 2/4: Installing ART + training dependencies..." +echo "----------------------------------------------------" +source .venv/bin/activate +pip install --upgrade pip wheel +pip install -e ".[sglang]" +deactivate +echo " Main environment ready (ART + Unsloth)" + +echo "" +echo "Step 3/4: Creating SGLang server environment (.venv-sglang-server)..." +echo "----------------------------------------------------------------------" +if [ -d ".venv-sglang-server" ]; then + echo " .venv-sglang-server already exists, removing..." + rm -rf .venv-sglang-server +fi +$PYTHON_CMD -m venv .venv-sglang-server +echo " Created .venv-sglang-server" + +echo "" +echo "Step 4/4: Installing SGLang server (ISOLATED - no ART)..." +echo "----------------------------------------------------------" +source .venv-sglang-server/bin/activate +pip install --upgrade pip wheel +# Install ONLY sglang - nothing else! No ART, no shared deps. +pip install "sglang[srt]>=0.5.5" +deactivate +echo " SGLang server environment ready (sglang ONLY)" + +echo "" +echo "==========================================" +echo "Setup Complete!" 
+echo "==========================================" +echo "" +echo "Architecture:" +echo " .venv (main) <--HTTP--> .venv-sglang-server" +echo " - ART + Unsloth - sglang[srt] ONLY" +echo " - openai>=2.14 - openai==2.6.1" +echo " - torchao>=0.13 - torchao==0.9" +echo "" +echo "Usage:" +echo "" +echo " # Activate main training environment" +echo " source .venv/bin/activate" +echo "" +echo " # Run your script (SGLang server auto-detected)" +echo " python your_script.py" +echo "" +echo "The SGLang backend automatically finds .venv-sglang-server/bin/python" +echo "and uses it to spawn the inference server subprocess." +echo "" diff --git a/scripts/test_sglang_e2e.py b/scripts/test_sglang_e2e.py new file mode 100644 index 000000000..6efbed600 --- /dev/null +++ b/scripts/test_sglang_e2e.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +"""End-to-end test for SGLang backend with training loop. + +Tests the full RL cycle: +1. Server startup +2. Inference (rollouts) +3. Training (GRPO) +4. Weight sync (hot-reload or restart) +5. Second inference (verify weights updated) + +Usage: + source .venv/bin/activate + python scripts/test_sglang_e2e.py +""" + +# Suppress multiprocessing resource_tracker warnings +import warnings +warnings.filterwarnings("ignore", message="resource_tracker:") + +# CRITICAL: Set CUDA_VISIBLE_DEVICES for training BEFORE any imports +# This must be the VERY FIRST thing to happen before PyTorch initializes CUDA +import os + +# For split-mode training, we need GPUs 1,2,3 for training +# But we keep all GPUs visible so SGLang server (subprocess) can use GPU 0 +# The subprocess will set its own CUDA_VISIBLE_DEVICES +os.environ["IMPORT_UNSLOTH"] = "1" # Tell art package to import unsloth early + +# IMPORTANT: Import unsloth BEFORE any other ML libraries to prevent early CUDA initialization. +# This must happen before importing transformers, torch, vllm, or the art package. 
+# See: https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing +try: + import unsloth # noqa: F401 +except ImportError: + pass # unsloth not installed, continue without it + +import asyncio +import sys + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + + +async def test_e2e(): + """Run end-to-end test.""" + print("=" * 60) + print("SGLang Backend End-to-End Test") + print("=" * 60) + + # Step 1: Import and config check + print("\n[1/7] Importing modules...") + try: + import art + from art.sglang_backend import SGLangBackend, SGLangConfig + from art.trajectories import Trajectory, TrajectoryGroup + from openai import AsyncOpenAI + print(" ✓ Imports OK") + except ImportError as e: + print(f" ✗ Import failed: {e}") + return False + + # Step 2: Check server Python + print("\n[2/7] Checking SGLang server environment...") + config = SGLangConfig() + server_python = config.get_server_python() + print(f" Server Python: {server_python}") + if ".venv-sglang-server" in server_python: + print(" ✓ Using separate SGLang server environment") + else: + print(" ⚠ Using same Python (may have dependency issues)") + + # Step 3: Initialize backend + print("\n[3/7] Initializing SGLangBackend...") + try: + backend = SGLangBackend() + print(f" Mode: {'split' if backend.device_config.is_split_mode else 'shared'}-GPU") + print(f" Inference: cuda:{backend.device_config.inference_device}") + print(f" Training: cuda:{backend.device_config.training_devices}") + print(" ✓ Backend initialized") + except Exception as e: + print(f" ✗ Backend init failed: {e}") + return False + + # Step 4: Register model + print("\n[4/7] Registering model...") + try: + model = art.TrainableModel( + name="sglang-e2e-test", + base_model="Qwen/Qwen2.5-0.5B-Instruct", + project="sglang-test", + ) + await backend.register(model) + print(f" Model: {model.name}") + print(f" Base: {model.base_model}") + print(" ✓ Model registered") + except Exception as 
e: + print(f" ✗ Registration failed: {e}") + await backend.close() + return False + + # Step 5: Start server and test inference + print("\n[5/7] Starting server and testing inference...") + try: + base_url, api_key = await backend._prepare_backend_for_training(model, None) + print(f" Server URL: {base_url}") + + client = AsyncOpenAI(base_url=base_url, api_key=api_key) + model_name = backend._model_inference_name(model) + print(f" Model name for inference: {model_name}") + + response = await client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Say 'test passed' in exactly two words."}], + max_tokens=10, + ) + response_text = response.choices[0].message.content + print(f" Response: {response_text}") + print(" ✓ Inference works") + except Exception as e: + print(f" ✗ Inference failed: {e}") + import traceback + traceback.print_exc() + await backend.close() + return False + + # Step 6: Create trajectories using real inference and train + print("\n[6/7] Running training step...") + try: + # Create trajectories by doing actual inference (to get real Choice objects) + trajectories = [] + + for i, (question, expected_reward) in enumerate([ + ("What is 2+2? Answer with just the number.", 1.0), + ("What is 2+2? 
Answer with a wrong number.", 0.0), + ]): + response = await client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": question}], + max_tokens=10, + logprobs=True, # Request logprobs for training + ) + choice = response.choices[0] + + traj = Trajectory( + messages_and_choices=[ + {"role": "user", "content": question}, + choice, # Real Choice object from API + ], + reward=expected_reward, + ) + trajectories.append(traj) + print(f" Trajectory {i+1}: '{choice.message.content}' -> reward={expected_reward}") + + trajectory_group = TrajectoryGroup(trajectories=trajectories) + + print(" Training on 2 trajectories...") + result = await backend.train( + model, + [trajectory_group], + learning_rate=1e-5, + verbose=True, + ) + print(f" Step: {result.step}") + print(f" Metrics: {result.metrics}") + print(" ✓ Training complete") + except Exception as e: + print(f" ✗ Training failed: {e}") + import traceback + traceback.print_exc() + await backend.close() + return False + + # Step 7: Test inference after training (weights should be updated) + print("\n[7/7] Testing inference after training...") + try: + # Get updated model name + model_name = backend._model_inference_name(model) + print(f" Model name: {model_name}") + + response = await client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "What is 2+2?"}], + max_tokens=10, + ) + response_text = response.choices[0].message.content + print(f" Response: {response_text}") + print(" ✓ Post-training inference works") + except Exception as e: + print(f" ✗ Post-training inference failed: {e}") + import traceback + traceback.print_exc() + await backend.close() + return False + + # Skip cleanup - just kill processes on exit + print("\n" + "=" * 60) + print("ALL TESTS PASSED!") + print("=" * 60) + + # Force kill SGLang server (faster than graceful shutdown) + import subprocess + subprocess.run(["pkill", "-9", "-f", "sglang"], capture_output=True) + + return True + + 
+if __name__ == "__main__": + success = asyncio.run(test_e2e()) + sys.exit(0 if success else 1) diff --git a/src/art/sglang_backend/__init__.py b/src/art/sglang_backend/__init__.py new file mode 100644 index 000000000..037297296 --- /dev/null +++ b/src/art/sglang_backend/__init__.py @@ -0,0 +1,53 @@ +"""SGLang-based backend for ART with Multi-GPU Split architecture. + +This module provides an alternative backend that uses SGLang for inference +instead of vLLM. The key advantage is RadixAttention prefix caching which +significantly improves performance for multi-turn agent trajectories. + +Architecture (Multi-GPU Split): + GPU 0: SGLang inference server (persistent, preserves RadixAttention cache) + GPU 1+: Training with Unsloth/GRPO + + This separation means: + - No memory release/reclaim overhead between train/inference + - RadixAttention cache stays warm across training steps + - Weight sync via hot-reload API (no server restart) + +IMPORTANT: SGLang and vLLM have conflicting dependencies (different PyTorch +versions). 
Use SEPARATE virtual environments: + + # For vLLM (default) + pip install openpipe-art[backend] + + # For SGLang (separate environment) + pip install openpipe-art[sglang] + +Usage: + from art.sglang_backend import SGLangBackend + + # Multi-GPU (recommended, requires 2+ GPUs) + backend = SGLangBackend( + inference_device=0, # SGLang on GPU 0 + training_devices=[1], # Training on GPU 1 + ) + + # Single-GPU fallback (uses restart mode, slower) + backend = SGLangBackend() # Auto-detects single GPU + + await backend.register(model) + result = await backend.train(model, trajectory_groups) + +References: + - verl SGLang integration: https://verl.readthedocs.io/en/latest/workers/sglang_worker.html + - SGLang weight sync: https://hebiao064.github.io/rl-weight-sync + - slime framework: https://github.com/Tsinghua-MARS-Lab/Slime +""" + +from .backend import SGLangBackend +from .config import SGLangConfig, DeviceConfig + +__all__ = [ + "SGLangBackend", + "SGLangConfig", + "DeviceConfig", +] diff --git a/src/art/sglang_backend/backend.py b/src/art/sglang_backend/backend.py new file mode 100644 index 000000000..c99833b4c --- /dev/null +++ b/src/art/sglang_backend/backend.py @@ -0,0 +1,293 @@ +"""SGLang-based backend for ART. + +This module provides SGLangBackend, an alternative to LocalBackend that uses +SGLang for inference instead of vLLM. Training remains the same (Unsloth/GRPO). 
+ +Architecture: + Multi-GPU (recommended): + GPU 0: SGLang server (persistent, RadixAttention cache preserved) + GPU 1+: Training (Unsloth/GRPO) + Weight sync: Hot-reload via API (no restart) + + Single-GPU (fallback): + GPU 0: Shared between SGLang and training + Weight sync: Server restart (cache lost) + +Benefits over vLLM: + - RadixAttention: Better prefix caching for multi-turn agent trajectories + - Zero-overhead scheduler: Lower latency for RL rollouts + - Faster structured outputs: Better tool call parsing + +Limitations: + - No Tinker support yet + - Requires separate environment from vLLM (dependency conflicts) + - Multi-GPU recommended for best performance +""" + +import asyncio +import os +import subprocess + +from ..local.backend import LocalBackend +from ..local.service import ModelService +from ..model import TrainableModel +from ..utils.output_dirs import get_model_dir + +from .config import DeviceConfig, SGLangConfig +from .service import SGLangService + + +class SGLangBackend(LocalBackend): + """Backend using SGLang for inference instead of vLLM. + + This is a drop-in replacement for LocalBackend with SGLang-specific + optimizations for RL training workloads. 
+ + Args: + inference_device: GPU index for SGLang server (default: 0) + training_devices: GPU indices for training (default: auto-detect) + in_process: Run service in-process (default: False) + path: Path for checkpoints/logs (default: ".art") + sglang_config: SGLang-specific configuration + + Example: + # Multi-GPU setup (recommended) + backend = SGLangBackend( + inference_device=0, + training_devices=[1, 2], + ) + + # Single-GPU (auto-fallback) + backend = SGLangBackend() + + # With custom config + backend = SGLangBackend( + sglang_config=SGLangConfig( + mem_fraction_static=0.85, + weight_sync_method="lora", + ) + ) + + await backend.register(model) + result = await backend.train(model, trajectory_groups) + """ + + def __init__( + self, + *, + inference_device: int | None = None, + training_devices: list[int] | None = None, + in_process: bool = False, + path: str | None = None, + sglang_config: SGLangConfig | None = None, + ) -> None: + """Initialize SGLangBackend. + + Args: + inference_device: GPU for SGLang (None = auto-detect) + training_devices: GPUs for training (None = auto-detect) + in_process: Run in-process (mainly for debugging) + path: Checkpoint/log directory + sglang_config: SGLang server configuration + """ + # Validate SGLang is available + self._validate_sglang_installation() + + # Initialize device configuration + if inference_device is not None or training_devices is not None: + self._device_config = DeviceConfig( + inference_device=inference_device or 0, + training_devices=training_devices or [1], + auto_detect=False, + ) + else: + self._device_config = DeviceConfig(auto_detect=True) + + # SGLang configuration + self._sglang_config = sglang_config or SGLangConfig() + + # In single-GPU mode, always use restart for weight sync + if not self._device_config.is_split_mode: + if self._sglang_config.weight_sync_method != "restart": + print( + f"Note: Single-GPU mode detected. 
Using 'restart' weight sync " + f"instead of '{self._sglang_config.weight_sync_method}'. " + f"For better performance, use 2+ GPUs." + ) + self._sglang_config.weight_sync_method = "restart" + + # Initialize parent + super().__init__(in_process=in_process, path=path) + + # Log configuration + self._log_config() + + def _validate_sglang_installation(self) -> None: + """Check that SGLang server environment is available. + + SGLang can run in a separate venv to avoid torchao conflicts with unsloth. + This checks if the configured server Python has sglang installed. + """ + pass # Validation happens when server starts (in the server's Python) + + def _log_config(self) -> None: + """Log configuration for debugging.""" + mode = "split" if self._device_config.is_split_mode else "shared" + print(f"SGLangBackend initialized:") + print(f" Mode: {mode}-GPU") + print(f" Inference device: cuda:{self._device_config.inference_device}") + print(f" Training devices: cuda:{self._device_config.training_devices}") + print(f" Weight sync: {self._sglang_config.weight_sync_method}") + if self._device_config.is_split_mode: + print(f" RadixAttention cache: preserved across training") + else: + print(f" RadixAttention cache: cleared on each training step") + + async def _get_service(self, model: TrainableModel) -> ModelService: + """Get or create the SGLang-based model service. + + Overrides LocalBackend._get_service to use SGLangService. + """ + from ..dev.get_model_config import get_model_config + + if model.name not in self._services: + config = get_model_config( + base_model=model.base_model, + output_dir=get_model_dir(model=model, art_path=self._path), + config=model._internal_config, + ) + + # Check for tinker config + if config.get("tinker_args") is not None: + raise NotImplementedError( + "SGLangBackend does not support tinker models yet. " + "Use LocalBackend for tinker models." 
+ ) + + # Create SGLang service + service = SGLangService( + model_name=model.name, + base_model=model.base_model, + config=config, + output_dir=get_model_dir(model=model, art_path=self._path), + device_config=self._device_config, + sglang_config=self._sglang_config, + ) + + self._services[model.name] = service + + if not self._in_process: + # Kill any existing SGLang processes + subprocess.run( + ["pkill", "-9", "-f", "sglang.launch_server"], + capture_output=True, + ) + + return self._services[model.name] + + async def _monitor_openai_server( + self, model_name: str, base_url: str, api_key: str + ) -> None: + """Monitor the SGLang OpenAI-compatible server. + + SGLang uses different metrics, so we use simpler health checks. + """ + import aiohttp + from openai import AsyncOpenAI + + openai_client = AsyncOpenAI( + base_url=base_url, + api_key=api_key, + ) + consecutive_failures = 0 + max_consecutive_failures = 3 + + try: + async with aiohttp.ClientSession() as session: + while not getattr(self, '_monitor_should_stop', False): + # Sleep in small increments to allow fast shutdown + for _ in range(int(self._sglang_config.health_check_interval)): + if getattr(self, '_monitor_should_stop', False): + return + await asyncio.sleep(1) + + # Check stop flag after sleep + if getattr(self, '_monitor_should_stop', False): + return + + try: + # Check if service is sleeping (single-GPU mode during training) + service = self._services.get(model_name) + if service and await service.vllm_engine_is_sleeping(): + consecutive_failures = 0 + continue + + # Health check via models endpoint + async with session.get( + f"{base_url.replace('/v1', '')}/v1/models", + timeout=aiohttp.ClientTimeout(total=10), + ) as response: + if response.status == 200: + consecutive_failures = 0 + continue + + # Fallback: try completion + await openai_client.completions.create( + model=model_name, + prompt="Hi", + max_tokens=1, + timeout=5.0, + ) + consecutive_failures = 0 + + except Exception: + # Check stop 
flag - don't error during shutdown + if getattr(self, '_monitor_should_stop', False): + return + + # Check sleep status during exception + try: + service = self._services.get(model_name) + if service and await service.vllm_engine_is_sleeping(): + consecutive_failures = 0 + continue + except Exception: + pass + + consecutive_failures += 1 + if consecutive_failures >= max_consecutive_failures: + raise + except asyncio.CancelledError: + # Graceful shutdown + return + except aiohttp.ClientError: + # Connection errors during shutdown are expected + if getattr(self, '_monitor_should_stop', False): + return + raise + + async def close(self) -> None: + """Clean up resources and shutdown SGLang servers.""" + # Signal monitor to stop + self._monitor_should_stop = True + + # Brief pause for monitor to notice stop flag + await asyncio.sleep(0.1) + + # Shutdown all SGLang services + for name, service in list(self._services.items()): + if isinstance(service, SGLangService): + await service.shutdown() + + # Call parent close + await super().close() + + @property + def device_config(self) -> DeviceConfig: + """Get device configuration.""" + return self._device_config + + @property + def sglang_config(self) -> SGLangConfig: + """Get SGLang configuration.""" + return self._sglang_config diff --git a/src/art/sglang_backend/config.py b/src/art/sglang_backend/config.py new file mode 100644 index 000000000..0e290fc35 --- /dev/null +++ b/src/art/sglang_backend/config.py @@ -0,0 +1,203 @@ +"""Configuration classes for SGLang backend. + +These configurations control device placement, memory allocation, +and weight synchronization behavior. +""" + +from dataclasses import dataclass, field +from typing import Literal + + +@dataclass +class DeviceConfig: + """GPU device assignment configuration. + + For optimal performance, SGLang inference and training should run on + separate GPUs. This eliminates memory release/reclaim overhead and + keeps the RadixAttention cache warm. 
+ + Attributes: + inference_device: GPU index for SGLang server (default: 0) + training_devices: GPU indices for training (default: [1] or [0] if single GPU) + auto_detect: If True, automatically detect available GPUs + + Example: + # 2-GPU setup + config = DeviceConfig(inference_device=0, training_devices=[1]) + + # 4-GPU setup with multi-GPU training + config = DeviceConfig(inference_device=0, training_devices=[1, 2, 3]) + + # Single GPU (fallback mode with server restart) + config = DeviceConfig(inference_device=0, training_devices=[0]) + """ + inference_device: int = 0 + training_devices: list[int] = field(default_factory=lambda: [1]) + auto_detect: bool = True + + def __post_init__(self): + if self.auto_detect: + self._auto_configure() + + def _auto_configure(self): + """Auto-detect GPU count and configure devices.""" + try: + import torch + gpu_count = torch.cuda.device_count() + except Exception: + gpu_count = 1 + + if gpu_count == 0: + raise RuntimeError("No CUDA GPUs available. SGLang requires GPU.") + elif gpu_count == 1: + # Single GPU: shared mode (will use restart) + self.inference_device = 0 + self.training_devices = [0] + else: + # Multi-GPU: split mode + self.inference_device = 0 + if not self.training_devices or self.training_devices == [1]: + self.training_devices = list(range(1, gpu_count)) + + @property + def is_split_mode(self) -> bool: + """True if inference and training use separate GPUs.""" + return self.inference_device not in self.training_devices + + @property + def inference_cuda_devices(self) -> str: + """CUDA_VISIBLE_DEVICES string for inference subprocess.""" + return str(self.inference_device) + + @property + def training_cuda_devices(self) -> str: + """CUDA_VISIBLE_DEVICES string for training.""" + return ",".join(str(d) for d in self.training_devices) + + +@dataclass +class SGLangConfig: + """SGLang server and weight sync configuration. + + Attributes: + sglang_python_path: Path to Python executable in SGLang server venv. 
+ SGLang requires torchao==0.9.0 which conflicts with unsloth's torchao>=0.13.0. + Solution: Run SGLang server in a separate venv with its own dependencies. + Set this to the path of that venv's Python (e.g., ".venv-sglang-server/bin/python"). + If None, uses sys.executable (same Python, may have dependency conflicts). + + mem_fraction_static: GPU memory fraction for SGLang (0.0-1.0) + disable_radix_cache: If True, disable RadixAttention (NOT recommended) + max_loras_per_batch: Maximum LoRA adapters to batch + context_length: Maximum context length (None = model default) + + weight_sync_method: How to sync weights after training + - "lora": Use update_weights_from_lora (recommended) + - "disk": Use update_weights_from_disk + - "restart": Restart server (fallback, slow) + + flush_cache_on_sync: Clear KV cache when syncing weights + server_timeout: Seconds to wait for server startup + health_check_interval: Seconds between health checks + + References: + - verl config: https://verl.readthedocs.io/en/latest/examples/config.html + - SGLang issues on weight sync: #3726, #4283, #8076 + + Two-Environment Setup: + # 1. Create main training env (with unsloth) + python3 -m venv .venv + source .venv/bin/activate + pip install -e ".[sglang]" + + # 2. Create SGLang server env (separate, with sglang[srt]) + python3 -m venv .venv-sglang-server + .venv-sglang-server/bin/pip install -e ".[sglang-server]" + + # 3. Configure to use server env + config = SGLangConfig(sglang_python_path=".venv-sglang-server/bin/python") + """ + # Two-environment architecture: path to SGLang server's Python + # This allows sglang (torchao==0.9.0) and unsloth (torchao>=0.13.0) to coexist + sglang_python_path: str | None = None + + # Memory configuration + # NOTE: Set to 0.5 to leave enough GPU memory for training when CUDA_VISIBLE_DEVICES + # can't be set early enough (before PyTorch initialization) + mem_fraction_static: float = 0.5 + disable_radix_cache: bool = False # Keep False for RL training! 
+ max_loras_per_batch: int = 4 + context_length: int | None = None + + # Weight synchronization + weight_sync_method: Literal["lora", "disk", "restart"] = "lora" + flush_cache_on_sync: bool = False # Keep cache warm + + # Server configuration + server_timeout: float = 120.0 + health_check_interval: float = 30.0 + + # Environment variables (from verl docs) + disable_tp_memory_check: bool = True # SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK + + # Tensor parallelism (for large models) + tensor_parallel_size: int = 1 + + # Logging + log_level: str = "warning" + + def get_server_python(self) -> str: + """Get Python executable path for SGLang server subprocess. + + Auto-detection order: + 1. Explicit sglang_python_path if set + 2. .venv-sglang-server/bin/python if exists + 3. sys.executable (same Python, may have conflicts) + """ + import os + import sys + + if self.sglang_python_path: + # Resolve relative paths from current working directory + path = os.path.abspath(self.sglang_python_path) + if not os.path.exists(path): + raise FileNotFoundError( + f"SGLang server Python not found at {path}. 
" + f"Create the server venv: python3 -m venv .venv-sglang-server && " + f".venv-sglang-server/bin/pip install -e '.[sglang-server]'" + ) + return path + + # Auto-detect: check for .venv-sglang-server in common locations + search_paths = [ + ".venv-sglang-server/bin/python", # Same directory + "../.venv-sglang-server/bin/python", # Parent directory + ] + + for rel_path in search_paths: + abs_path = os.path.abspath(rel_path) + if os.path.exists(abs_path): + print(f"Auto-detected SGLang server venv: {abs_path}") + return abs_path + + # Fallback to same Python (may have dependency conflicts) + return sys.executable + + def to_server_args(self) -> dict: + """Convert to SGLang server launch arguments.""" + args = { + "mem_fraction_static": self.mem_fraction_static, + "disable_radix_cache": self.disable_radix_cache, + "tp_size": self.tensor_parallel_size, + "log_level": self.log_level, + } + if self.context_length: + args["context_length"] = self.context_length + return args + + def to_env_vars(self) -> dict[str, str]: + """Environment variables to set for SGLang subprocess.""" + env = {} + if self.disable_tp_memory_check: + env["SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"] = "True" + return env diff --git a/src/art/sglang_backend/service.py b/src/art/sglang_backend/service.py new file mode 100644 index 000000000..c89c20f4f --- /dev/null +++ b/src/art/sglang_backend/service.py @@ -0,0 +1,650 @@ +"""SGLang service for inference with Unsloth training. + +This service manages the SGLang inference server and training lifecycle. +In multi-GPU mode, the server stays running and weights are hot-reloaded. +In single-GPU mode, the server is restarted for each training step. 
+ +Key features: +- Persistent SGLang server preserves RadixAttention cache +- Hot-reload LoRA weights via SGLang API (no restart needed) +- Automatic fallback to restart mode on single GPU +- Health monitoring and graceful shutdown +""" + +import asyncio +import os +import signal +import subprocess +import sys +from dataclasses import dataclass, field +from functools import cached_property +from typing import TYPE_CHECKING, Any, AsyncIterator, cast + +import aiohttp +import torch +from datasets import Dataset +import peft +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from trl import GRPOConfig, GRPOTrainer + +from .. import dev, types +from ..local.checkpoints import get_last_checkpoint_dir +from ..preprocessing.inputs import TrainInputs +from ..preprocessing.pack import ( + DiskPackedTensors, + PackedTensors, + packed_tensors_from_dir, +) +from ..utils.get_model_step import get_step_from_dir +from ..utils.output_dirs import get_step_checkpoint_dir +from ..unsloth.train import gc_and_empty_cuda_cache, train + +from .config import DeviceConfig, SGLangConfig + +if TYPE_CHECKING: + from peft.peft_model import PeftModelForCausalLM + + +# Type alias for Unsloth model +CausalLM = Any + + +@dataclass +class TrainingState: + """Container for training model state.""" + + model: CausalLM + tokenizer: PreTrainedTokenizerBase + peft_model: "PeftModelForCausalLM" + trainer: "GRPOTrainer" + inputs_queue: asyncio.Queue[TrainInputs] + results_queue: asyncio.Queue[dict[str, float]] + _pinned_buffers: dict[str, torch.Tensor] = field(default_factory=dict) + _is_offloaded: bool = False + + def offload_to_cpu(self) -> None: + """Offload training model to CPU to free GPU memory.""" + if self._is_offloaded: + return + + for name, param in self.peft_model.named_parameters(): + if param.device.type == "cuda": + if ( + name not in self._pinned_buffers + or self._pinned_buffers[name].shape != param.shape + ): + self._pinned_buffers[name] = torch.empty( + 
param.shape, dtype=param.dtype, device="cpu", pin_memory=True + ) + self._pinned_buffers[name].copy_(param.data, non_blocking=True) + param.data = self._pinned_buffers[name] + + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None and hasattr(optimizer, "state"): + for param_id, state in optimizer.state.items(): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and v.device.type == "cuda": + key = f"opt_{id(param_id)}_{k}" + if ( + key not in self._pinned_buffers + or self._pinned_buffers[key].shape != v.shape + ): + self._pinned_buffers[key] = torch.empty( + v.shape, dtype=v.dtype, device="cpu", pin_memory=True + ) + self._pinned_buffers[key].copy_(v, non_blocking=True) + state[k] = self._pinned_buffers[key] + + torch.cuda.synchronize() + self._is_offloaded = True + gc_and_empty_cuda_cache() + + def reload_to_gpu(self, device: str = "cuda:0") -> None: + """Reload training model and optimizer back to GPU.""" + if not self._is_offloaded: + return + + for name, param in self.peft_model.named_parameters(): + if param.device.type == "cpu": + gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device) + gpu_tensor.copy_(param.data, non_blocking=True) + param.data = gpu_tensor + + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None and hasattr(optimizer, "state"): + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and v.device.type == "cpu": + gpu_tensor = torch.empty(v.shape, dtype=v.dtype, device=device) + gpu_tensor.copy_(v, non_blocking=True) + state[k] = gpu_tensor + + torch.cuda.synchronize() + self._is_offloaded = False + + +@dataclass +class SGLangService: + """Service using SGLang for inference and Unsloth for training. + + This implements the ModelService protocol while using SGLang + instead of vLLM for the inference server. 
+ + Multi-GPU Mode (recommended): + - SGLang server runs persistently on inference_device + - Training runs on training_devices + - Weights hot-reloaded via API after each training step + - RadixAttention cache preserved across training + + Single-GPU Mode (fallback): + - SGLang server killed before training + - Server restarted after training with new LoRA + - Cache lost on each restart + """ + + model_name: str + base_model: str + config: dev.InternalModelConfig + output_dir: str + device_config: DeviceConfig + sglang_config: SGLangConfig + + _is_sleeping: bool = False + _latest_step: int = 0 + _server_process: subprocess.Popen | None = None + _server_port: int = 8000 + _server_host: str = "127.0.0.1" + _train_task: asyncio.Task | None = None + _lora_counter: int = 1 + + def _next_lora_id(self) -> int: + """Generate unique LoRA ID.""" + self._lora_counter += 1 + return self._lora_counter + + async def start_openai_server( + self, config: dev.OpenAIServerConfig | None + ) -> tuple[str, int]: + """Start SGLang OpenAI-compatible server. + + In multi-GPU mode, training model stays on training GPUs. + In single-GPU mode, training model is offloaded to CPU first. 
+ """ + # Get or create initial LoRA checkpoint + lora_path = get_last_checkpoint_dir(self.output_dir) + if lora_path is None: + lora_path = get_step_checkpoint_dir(self.output_dir, 0) + os.makedirs(os.path.dirname(lora_path), exist_ok=True) + self._training_state.trainer.save_model(lora_path) + self._latest_step = 0 + else: + self._latest_step = get_step_from_dir(self.output_dir) + + # In single-GPU mode, offload training model before starting SGLang + if not self.device_config.is_split_mode: + self._training_state.offload_to_cpu() + gc_and_empty_cuda_cache() # Ensure GPU memory is freed for SGLang + + # Get server configuration + server_config = config or {} + server_args = server_config.get("server_args", {}) + + self._server_host = server_args.get("host", "127.0.0.1") + self._server_port = server_args.get("port", 8000) + + # Create logs directory + log_dir = f"{self.output_dir}/logs" + os.makedirs(log_dir, exist_ok=True) + + # Start SGLang server subprocess + await self._start_server_process(lora_path) + + return self._server_host, self._server_port + + async def _start_server_process(self, lora_path: str | None = None) -> None: + """Start SGLang server as subprocess with proper device isolation. + + Uses a separate Python environment if sglang_python_path is configured. + This allows SGLang (torchao==0.9.0) and unsloth (torchao>=0.13.0) to coexist. 
+ """ + # Build environment with device isolation + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = self.device_config.inference_cuda_devices + env.update(self.sglang_config.to_env_vars()) + + # Get Python executable for SGLang server (may be different venv) + server_python = self.sglang_config.get_server_python() + + # Build server command + cmd = [ + server_python, "-m", "sglang.launch_server", + "--model-path", self.base_model, + "--host", self._server_host, + "--port", str(self._server_port), + "--mem-fraction-static", str(self.sglang_config.mem_fraction_static), + "--log-level", self.sglang_config.log_level, + "--enable-lora", # Enable LoRA hot-reload endpoint + ] + + # Add tensor parallelism if configured + if self.sglang_config.tensor_parallel_size > 1: + cmd.extend(["--tp-size", str(self.sglang_config.tensor_parallel_size)]) + + # Add context length if specified + if self.sglang_config.context_length: + cmd.extend(["--context-length", str(self.sglang_config.context_length)]) + + # Add LoRA configuration + if lora_path and os.path.exists(lora_path): + cmd.extend(["--lora-paths", lora_path]) + cmd.extend(["--max-loras-per-batch", str(self.sglang_config.max_loras_per_batch)]) + + # Disable radix cache only if explicitly requested (not recommended) + if self.sglang_config.disable_radix_cache: + cmd.append("--disable-radix-cache") + + # Start server + log_file = open(f"{self.output_dir}/logs/sglang.log", "a") + self._server_process = subprocess.Popen( + cmd, + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + preexec_fn=os.setsid, # Create new process group for clean shutdown + ) + + # Wait for server to be ready + await self._wait_for_server() + + async def _wait_for_server(self) -> None: + """Wait for SGLang server to be ready.""" + timeout = self.sglang_config.server_timeout + start_time = asyncio.get_event_loop().time() + + while asyncio.get_event_loop().time() - start_time < timeout: + # Check if process died + if self._server_process and 
self._server_process.poll() is not None: + raise RuntimeError( + f"SGLang server process died with code {self._server_process.returncode}. " + f"Check logs at {self.output_dir}/logs/sglang.log" + ) + + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://{self._server_host}:{self._server_port}/v1/models", + timeout=aiohttp.ClientTimeout(total=5) + ) as resp: + if resp.status == 200: + return + except Exception: + pass + await asyncio.sleep(0.5) + + raise TimeoutError( + f"SGLang server did not start within {timeout} seconds. " + f"Check logs at {self.output_dir}/logs/sglang.log" + ) + + async def _stop_server_process(self) -> None: + """Stop SGLang server subprocess gracefully.""" + if self._server_process is None: + return + + try: + # Force kill immediately for fast cleanup + try: + os.killpg(os.getpgid(self._server_process.pid), signal.SIGKILL) + except (ProcessLookupError, OSError): + self._server_process.kill() + + # Non-blocking wait with short timeout + for _ in range(10): # Max 1 second + if self._server_process.poll() is not None: + break + await asyncio.sleep(0.1) + except Exception: + pass # Best effort cleanup + finally: + self._server_process = None + + self._server_process = None + gc_and_empty_cuda_cache() + + async def _hot_reload_lora(self, checkpoint_dir: str, step: int) -> None: + """Hot-reload LoRA weights without restarting server. + + Uses SGLang's update_weights_from_lora API. + This preserves the RadixAttention cache. 
+ """ + lora_name = f"{self.model_name}@{step}" + + # Call SGLang's LoRA update endpoint + async with aiohttp.ClientSession() as session: + payload = { + "lora_path": checkpoint_dir, + "lora_name": lora_name, + } + + if self.sglang_config.flush_cache_on_sync: + payload["flush_cache"] = True + + try: + async with session.post( + f"http://{self._server_host}:{self._server_port}/load_lora_adapter", + json=payload, + timeout=aiohttp.ClientTimeout(total=60) + ) as resp: + if resp.status != 200: + error_text = await resp.text() + raise RuntimeError(f"Failed to hot-reload LoRA: {error_text}") + except aiohttp.ClientError as e: + # Fallback: try add_lora endpoint (older SGLang versions) + try: + async with session.post( + f"http://{self._server_host}:{self._server_port}/add_lora", + json={ + "lora_path": checkpoint_dir, + "lora_name": lora_name, + "lora_int_id": self._next_lora_id(), + }, + timeout=aiohttp.ClientTimeout(total=60) + ) as resp: + if resp.status != 200: + raise RuntimeError(f"Failed to add LoRA: {await resp.text()}") + except Exception: + raise RuntimeError(f"Failed to hot-reload LoRA: {e}") from e + + async def vllm_engine_is_sleeping(self) -> bool: + """Check if engine is sleeping (for LocalBackend compatibility). + + In multi-GPU mode, server never sleeps. + In single-GPU mode, returns True during training. + """ + return self._is_sleeping + + async def train( + self, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Run training step. + + Multi-GPU mode: + 1. Training runs on training_devices (server keeps running) + 2. Save LoRA checkpoint + 3. Hot-reload weights via API + + Single-GPU mode: + 1. Stop SGLang server + 2. Reload training model to GPU + 3. Train + 4. Save checkpoint + 5. 
Restart server with new LoRA + """ + if self.device_config.is_split_mode: + # Multi-GPU: server stays running + async for metrics in self._train_split_mode( + disk_packed_tensors, config, _config, verbose + ): + yield metrics + else: + # Single-GPU: need to swap + async for metrics in self._train_shared_mode( + disk_packed_tensors, config, _config, verbose + ): + yield metrics + + async def _train_split_mode( + self, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Training in multi-GPU split mode. + + Server keeps running. Weights hot-reloaded after training. + """ + # Training device is cuda:0 after CUDA_VISIBLE_DEVICES is set in _training_state + # (e.g., if training GPUs are [1,2,3], GPU 1 becomes cuda:0 after setting CUDA_VISIBLE_DEVICES="1,2,3") + training_device = "cuda:0" + + # Ensure training model is on GPU + self._training_state.reload_to_gpu(training_device) + + # Load packed tensors + packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) + + # Wait for any pending batches + await self._training_state.results_queue.join() + + # Start training task if needed + if self._train_task is None: + self._train_task = asyncio.create_task( + train( + trainer=self._training_state.trainer, + results_queue=self._training_state.results_queue, + ) + ) + warmup = True + else: + warmup = False + + # Process training batch + from ..unsloth.training_utils import process_train_batch + + async for result in process_train_batch( + packed_tensors=packed_tensors, + config=config, + _config=_config, + inputs_queue=self._training_state.inputs_queue, + results_queue=self._training_state.results_queue, + train_task=self._train_task, + trainer=self._training_state.trainer, + peft_model=self._training_state.peft_model, + warmup=warmup, + verbose=verbose, + ): + yield result + + # Save checkpoint + from ..unsloth.training_utils import save_checkpoint + + 
checkpoint_dir = save_checkpoint( + trainer=self._training_state.trainer, + output_dir=self.output_dir, + verbose=verbose, + ) + + # Determine new step + new_step = int(os.path.basename(checkpoint_dir)) + + # Hot-reload LoRA weights (no server restart!) + if self.sglang_config.weight_sync_method == "lora": + await self._hot_reload_lora(checkpoint_dir, new_step) + elif self.sglang_config.weight_sync_method == "disk": + await self._reload_from_disk(checkpoint_dir) + else: + # Fallback: restart server + await self._stop_server_process() + await self._start_server_process(checkpoint_dir) + + self._latest_step = new_step + + if verbose: + print(f"SGLangService.train complete (split mode, step {new_step})") + + async def _train_shared_mode( + self, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Training in single-GPU shared mode. + + Server is stopped during training, restarted after. 
+ """ + # Stop SGLang server to free GPU memory + await self._stop_server_process() + self._is_sleeping = True + gc_and_empty_cuda_cache() + + # Reload training model to GPU + self._training_state.reload_to_gpu("cuda:0") + + # Load packed tensors + packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) + + # Wait for pending batches + await self._training_state.results_queue.join() + + # Start training task if needed + if self._train_task is None: + self._train_task = asyncio.create_task( + train( + trainer=self._training_state.trainer, + results_queue=self._training_state.results_queue, + ) + ) + warmup = True + else: + warmup = False + + # Process training batch + from ..unsloth.training_utils import process_train_batch + + async for result in process_train_batch( + packed_tensors=packed_tensors, + config=config, + _config=_config, + inputs_queue=self._training_state.inputs_queue, + results_queue=self._training_state.results_queue, + train_task=self._train_task, + trainer=self._training_state.trainer, + peft_model=self._training_state.peft_model, + warmup=warmup, + verbose=verbose, + ): + yield result + + # Save checkpoint + from ..unsloth.training_utils import save_checkpoint + + checkpoint_dir = save_checkpoint( + trainer=self._training_state.trainer, + output_dir=self.output_dir, + verbose=verbose, + ) + + # Offload training model + self._training_state.offload_to_cpu() + gc_and_empty_cuda_cache() + + # Restart SGLang server with new LoRA + new_step = int(os.path.basename(checkpoint_dir)) + await self._start_server_process(checkpoint_dir) + + self._latest_step = new_step + self._is_sleeping = False + + if verbose: + print(f"SGLangService.train complete (shared mode, step {new_step})") + + async def _reload_from_disk(self, checkpoint_dir: str) -> None: + """Reload weights from disk (alternative to LoRA hot-reload).""" + async with aiohttp.ClientSession() as session: + async with session.post( + 
f"http://{self._server_host}:{self._server_port}/update_weights_from_disk", + json={ + "model_path": checkpoint_dir, + "load_format": "auto", + }, + timeout=aiohttp.ClientTimeout(total=120) + ) as resp: + if resp.status != 200: + raise RuntimeError(f"Failed to reload weights: {await resp.text()}") + + async def shutdown(self) -> None: + """Clean shutdown of service.""" + await self._stop_server_process() + + if self._train_task: + self._train_task.cancel() + try: + await self._train_task + except asyncio.CancelledError: + pass + self._train_task = None + + @cached_property + def _training_state(self) -> TrainingState: + """Initialize Unsloth model and trainer on training device.""" + import unsloth + + # Set training device with proper GPU isolation + if self.device_config.is_split_mode: + # CRITICAL: Set CUDA_VISIBLE_DEVICES to training GPUs only + # This ensures training doesn't accidentally use the inference GPU + os.environ["CUDA_VISIBLE_DEVICES"] = self.device_config.training_cuda_devices + device = "cuda:0" # After CUDA_VISIBLE_DEVICES, GPU 0 is the first training GPU + torch.cuda.set_device(0) + else: + device = "cuda:0" + + init_args = self.config.get("init_args", {}) + checkpoint_dir = get_last_checkpoint_dir(self.output_dir) + if checkpoint_dir: + init_args["model_name"] = checkpoint_dir + else: + init_args["model_name"] = self.base_model + + model, tokenizer = cast( + tuple[CausalLM, PreTrainedTokenizerBase], + unsloth.FastLanguageModel.from_pretrained(**init_args), + ) + + if ( + hasattr(model, "peft_config") + and getattr(model, "peft_config", None) is not None + ): + peft_model = cast(peft.peft_model.PeftModelForCausalLM, model) + else: + peft_model = cast( + peft.peft_model.PeftModelForCausalLM, + unsloth.FastLanguageModel.get_peft_model( + model, **self.config.get("peft_args", {}) + ), + ) + + data = {"prompt": ""} + trainer = GRPOTrainer( + model=peft_model, + reward_funcs=[], + args=GRPOConfig(**self.config.get("trainer_args", {})), + 
train_dataset=Dataset.from_list([data for _ in range(10_000_000)]), + processing_class=tokenizer, + ) + + inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue() + results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() + + def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: + async def get_inputs() -> TrainInputs: + return await inputs_queue.get() + inputs = asyncio.run(get_inputs()) + return cast(dict[str, torch.Tensor], inputs) + + trainer._prepare_inputs = _async_prepare_inputs + + return TrainingState( + model=model, + tokenizer=tokenizer, + peft_model=peft_model, + trainer=trainer, + inputs_queue=inputs_queue, + results_queue=results_queue, + ) diff --git a/src/art/unsloth/training_utils.py b/src/art/unsloth/training_utils.py new file mode 100644 index 000000000..e4c4214c0 --- /dev/null +++ b/src/art/unsloth/training_utils.py @@ -0,0 +1,128 @@ +"""Training utilities that don't depend on vLLM. + +These functions are extracted from unsloth/service.py to allow use +by backends that don't use vLLM (e.g., SGLang backend). +""" + +import asyncio +import os +from typing import TYPE_CHECKING, AsyncIterator + +import torch + +from .. 
import dev, types +from ..preprocessing.inputs import TrainInputs, create_train_inputs +from ..preprocessing.pack import PackedTensors +from ..utils.get_model_step import get_step_from_dir +from ..utils.output_dirs import get_step_checkpoint_dir +from .train import gc_and_empty_cuda_cache + +if TYPE_CHECKING: + from peft.peft_model import PeftModelForCausalLM + from trl import GRPOTrainer + + +def precalculate_new_logprobs( + trainer: "GRPOTrainer", + peft_model: "PeftModelForCausalLM", + packed_tensors: PackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, +) -> torch.Tensor: + """Precalculate logprobs for all offsets and return as a tensor.""" + return torch.cat( + [ + trainer.compute_loss( + peft_model, + TrainInputs( # ty:ignore[missing-typed-dict-key] + **{ + k: v[_offset : _offset + 1] + for k, v in packed_tensors.items() + if isinstance(v, torch.Tensor) + }, + pixel_values=packed_tensors["pixel_values"][_offset : _offset + 1], + image_grid_thw=packed_tensors["image_grid_thw"][ + _offset : _offset + 1 + ], + config=config, + _config=_config, + return_new_logprobs=True, + ), + ) + for _offset in range(0, packed_tensors["tokens"].shape[0]) + ] + ).to("cpu") + + +async def process_train_batch( + packed_tensors: PackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + inputs_queue: asyncio.Queue[TrainInputs], + results_queue: asyncio.Queue[dict[str, float]], + train_task: asyncio.Task[None], + trainer: "GRPOTrainer", + peft_model: "PeftModelForCausalLM", + warmup: bool, + verbose: bool = False, +) -> AsyncIterator[dict[str, float]]: + """ + Process training batches and yield results. + + Yields tuples of (result, warmup_done) where warmup_done indicates if warmup just finished. 
+ """ + precalculate_logprobs = _config.get("precalculate_logprobs", False) + + for offset in range(0, packed_tensors["tokens"].shape[0]): + for _ in range(2 if warmup else 1): + if precalculate_logprobs and not warmup: + # Preserve original logprobs before overwriting + packed_tensors["original_logprobs"] = packed_tensors["logprobs"] # type: ignore + packed_tensors["logprobs"] = precalculate_new_logprobs( + trainer, peft_model, packed_tensors, config, _config + ) + precalculate_logprobs = False + + inputs_queue.put_nowait( + create_train_inputs(packed_tensors, offset, config, _config, warmup) + ) + + # Wait for a result from the queue or for the training task to, + # presumably, raise an exception + done, _ = await asyncio.wait( + [ + asyncio.create_task(results_queue.get()), + train_task, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + if verbose: + print( + "Done waiting for a result from the queue or for the training task to, presumably, raise an exception" + ) + for task in done: + result = task.result() + # If `result` is `None`, the training task finished somehow. + assert result is not None, "The training task should never finish." 
+ results_queue.task_done() + if warmup: + gc_and_empty_cuda_cache() + await asyncio.sleep(0.1) + warmup = False + else: + yield result + + +def save_checkpoint( + trainer: "GRPOTrainer", + output_dir: str, + verbose: bool = False, +) -> str: + """Save a checkpoint and return the checkpoint directory path.""" + if verbose: + print("Saving new LoRA adapter...") + next_step = get_step_from_dir(output_dir) + 1 + checkpoint_dir = get_step_checkpoint_dir(output_dir, next_step) + os.makedirs(checkpoint_dir, exist_ok=True) + trainer.save_model(checkpoint_dir) + return checkpoint_dir From 486365cc1b62d3aee209cebd522421b16709dea9 Mon Sep 17 00:00:00 2001 From: mukesh reddy Date: Wed, 4 Feb 2026 09:57:58 -0500 Subject: [PATCH 2/8] Add sglang optional dependencies to pyproject.toml --- pyproject.toml | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e2934df55..ef0b13c5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,34 @@ backend = [ "vllm==0.15.1 ; sys_platform == 'linux'", ] +# SGLang training environment (main env - NO sglang here, just training deps) +# SGLang server runs in COMPLETELY SEPARATE venv (just: pip install sglang[srt]) +# Communication between envs is via HTTP (localhost), not Python imports +sglang = [ + "peft>=0.14.0", + "hf-xet>=1.1.0", + "bitsandbytes>=0.45.2", + "unsloth==2025.12.9", + "unsloth-zoo==2025.12.7", + "torch>=2.8.0", + "torchao==0.14.1", + "accelerate==1.7.0", + "awscli>=1.38.1", + "setuptools>=78.1.0", + "wandb==0.23.1", + "transformers>=4.55.2,<=4.57.3", + "duckdb>=1.0.0", + "pyarrow>=15.0.0", + "trl==0.20.0", + "nbclient>=0.10.1", + "pytest>=8.4.1", + "nbmake>=1.5.5", + "gql<4", + "aiohttp>=3.9.0", +] +# NOTE: SGLang server venv is created separately with JUST: pip install "sglang[srt]" +# Do NOT install ART in the server venv - they communicate via HTTP only + langgraph = [ "langchain-core>=0.3.51", "langgraph>=0.6.2", @@ -145,6 +173,10 
@@ allowed-unresolved-imports = [ "uvicorn.**", "vllm.**", "wandb.**", + # sglang deps + "sglang.**", + "flashinfer.**", + "flashinfer_python.**", # langgraph deps "langchain_core.**", "langchain_openai.**", @@ -152,8 +184,6 @@ allowed-unresolved-imports = [ # plotting deps "matplotlib.**", "seaborn.**", - # megatron deps - "megatron.**", ] [dependency-groups] From a872571c517f717d0d070b7a5f66c8421404114e Mon Sep 17 00:00:00 2001 From: mukesh reddy Date: Wed, 4 Feb 2026 09:58:25 -0500 Subject: [PATCH 3/8] Add missing modified files for SGLang integration --- src/art/unsloth/service.py | 238 +++++++++++++++++++++++++++++-------- src/art/unsloth/train.py | 12 +- 2 files changed, 198 insertions(+), 52 deletions(-) diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index d42941357..2915c855a 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -1,7 +1,7 @@ """Unsloth training service with decoupled vLLM inference.""" import asyncio -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property import os from typing import TYPE_CHECKING, Any, AsyncIterator, Protocol, cast @@ -29,6 +29,76 @@ from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers from .train import gc_and_empty_cuda_cache, train + +# ============================================================================ +# Device Configuration for Multi-GPU Support +# ============================================================================ + + +@dataclass +class DeviceConfig: + """GPU device assignment for Unsloth training and vLLM inference. + + For optimal performance, training and inference should run on separate GPUs. + This eliminates memory contention and the need for CPU offloading. 
+ + Attributes: + inference_device: GPU index for vLLM inference (default: 0) + training_device: GPU index for Unsloth training (default: 1, or 0 if single GPU) + auto_detect: If True, automatically detect available GPUs + + Example: + # 2-GPU setup (recommended) + config = DeviceConfig(inference_device=0, training_device=1) + + # Single GPU (fallback with CPU offloading) + config = DeviceConfig(inference_device=0, training_device=0) + """ + inference_device: int = 0 + training_device: int = 1 + auto_detect: bool = True + + def __post_init__(self): + if self.auto_detect: + self._auto_configure() + + def _auto_configure(self): + """Auto-detect GPU count and configure devices.""" + try: + gpu_count = torch.cuda.device_count() + except Exception: + gpu_count = 1 + + if gpu_count == 0: + raise RuntimeError("No CUDA GPUs available.") + elif gpu_count == 1: + # Single GPU: shared mode (will use CPU offloading) + self.inference_device = 0 + self.training_device = 0 + print(f"[DeviceConfig] Single GPU detected. Using shared mode with CPU offloading.") + else: + # Multi-GPU: split mode (no offloading needed!) + self.inference_device = 0 + self.training_device = 1 + print(f"[DeviceConfig] {gpu_count} GPUs detected. 
Using split mode:") + print(f" - GPU {self.inference_device}: vLLM inference") + print(f" - GPU {self.training_device}: Unsloth training") + + @property + def is_split_mode(self) -> bool: + """True if inference and training use separate GPUs.""" + return self.inference_device != self.training_device + + @property + def inference_cuda_devices(self) -> str: + """CUDA_VISIBLE_DEVICES string for vLLM inference subprocess.""" + return str(self.inference_device) + + @property + def training_cuda_device(self) -> str: + """CUDA device string for training (e.g., 'cuda:1').""" + return f"cuda:{self.training_device}" + if TYPE_CHECKING: from peft.peft_model import PeftModelForCausalLM from trl import GRPOTrainer @@ -174,79 +244,54 @@ class UnslothState: _pinned_buffers: dict[str, torch.Tensor] | None = None def offload_to_cpu(self) -> None: - """Offload training model and optimizer to CPU using pinned memory for faster transfers.""" + """Offload entire training model (base + adapters) and optimizer to CPU.""" if self._is_offloaded: return - # Initialize pinned buffer storage - if self._pinned_buffers is None: - self._pinned_buffers = {} - - # Offload model parameters to pinned memory for faster reload - for name, param in self.peft_model.named_parameters(): - if param.device.type == "cuda": - # Create pinned buffer if not exists or wrong size - if ( - name not in self._pinned_buffers - or self._pinned_buffers[name].shape != param.shape - ): - self._pinned_buffers[name] = torch.empty( - param.shape, dtype=param.dtype, device="cpu", pin_memory=True - ) - # Async copy to pinned memory - self._pinned_buffers[name].copy_(param.data, non_blocking=True) - param.data = self._pinned_buffers[name] - - # Offload optimizer state to pinned memory + print("[UnslothService] Offloading entire model to CPU...") + + # Move the entire PEFT model to CPU (this includes base model + adapters) + self.peft_model.to("cpu") + + # Offload optimizer state to CPU optimizer = getattr(self.trainer, 
"optimizer", None) if optimizer is not None and hasattr(optimizer, "state"): for param_id, state in optimizer.state.items(): for k, v in state.items(): if isinstance(v, torch.Tensor) and v.device.type == "cuda": - key = f"opt_{id(param_id)}_{k}" - if ( - key not in self._pinned_buffers - or self._pinned_buffers[key].shape != v.shape - ): - self._pinned_buffers[key] = torch.empty( - v.shape, dtype=v.dtype, device="cpu", pin_memory=True - ) - self._pinned_buffers[key].copy_(v, non_blocking=True) - state[k] = self._pinned_buffers[key] - - # Sync to ensure all copies are complete before freeing GPU memory - torch.cuda.synchronize() + state[k] = v.cpu() + # Sync and clear GPU memory + torch.cuda.synchronize() self._is_offloaded = True gc_and_empty_cuda_cache() + + # Report free memory + free_mem = torch.cuda.mem_get_info()[0] / 1e9 + print(f"[UnslothService] Model offloaded. GPU memory free: {free_mem:.2f} GB") def reload_to_gpu(self, device: str = "cuda:0") -> None: - """Reload training model and optimizer back to GPU using async transfers.""" + """Reload entire training model and optimizer back to GPU.""" if not self._is_offloaded: return - # Reload model parameters from pinned memory (fast async transfer) - for name, param in self.peft_model.named_parameters(): - if param.device.type == "cpu": - # Allocate on GPU and async copy from pinned memory - gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device) - gpu_tensor.copy_(param.data, non_blocking=True) - param.data = gpu_tensor + print(f"[UnslothService] Reloading model to {device}...") + + # Move the entire PEFT model back to GPU + self.peft_model.to(device) - # Reload optimizer state + # Reload optimizer state to GPU optimizer = getattr(self.trainer, "optimizer", None) if optimizer is not None and hasattr(optimizer, "state"): for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor) and v.device.type == "cpu": - gpu_tensor = torch.empty(v.shape, dtype=v.dtype, 
device=device) - gpu_tensor.copy_(v, non_blocking=True) - state[k] = gpu_tensor + state[k] = v.to(device) # Sync to ensure all copies are complete before training torch.cuda.synchronize() - self._is_offloaded = False + print(f"[UnslothService] Model reloaded to {device}") # ============================================================================ @@ -260,6 +305,7 @@ class UnslothService: base_model: str config: dev.InternalModelConfig output_dir: str + device_config: DeviceConfig = field(default_factory=DeviceConfig) _is_sleeping: bool = False _latest_step: int = 0 _lora_id_counter: int = 1 # Start from 1 since 0 is reserved @@ -283,8 +329,13 @@ async def start_openai_server( # Extract step from checkpoint path self._latest_step = get_step_from_dir(self.output_dir) - # Offload training model to CPU before vLLM starts to free GPU memory + # Offload training model to CPU so vLLM can use the GPU self._state.offload_to_cpu() + # Force garbage collection and clear CUDA cache + import gc + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() server_config = dev.get_openai_server_config( model_name=self.model_name, @@ -334,7 +385,7 @@ async def train( ) -> AsyncIterator[dict[str, float]]: llm = await self.llm - # Pause generation to prevent new requests during training + # Time-sharing mode: pause vLLM, free GPU memory, then train await llm.pause_generation() # Determine sleep level based on outstanding requests: @@ -364,10 +415,14 @@ async def train( # If we haven't already, start the training task if not hasattr(self, "_train_task") or self._train_task is None: + # Use remapped device index: in split mode with CUDA_VISIBLE_DEVICES=0,1, + # training is cuda:1 (second visible device) + # Training device is cuda:0 self._train_task = asyncio.create_task( train( trainer=self._state.trainer, results_queue=self._state.results_queue, + training_device=0, ) ) warmup = True @@ -396,7 +451,7 @@ async def train( verbose=verbose, ) - # Offload training model to CPU 
before waking vLLM + # Offload training model before waking vLLM self._state.offload_to_cpu() # Free memory before waking up vLLM @@ -438,6 +493,12 @@ async def train( def _state(self) -> UnslothState: import unsloth + # Use cuda:0 for training - Unsloth's compiled code expects this + # Time-sharing with vLLM via sleep/wake handles memory management + cuda_device_index = 0 + torch.cuda.set_device(cuda_device_index) + print(f"[UnslothService] Loading training model on cuda:{cuda_device_index}") + # Initialize Unsloth model init_args = self.config.get("init_args", {}) checkpoint_dir = get_last_checkpoint_dir(self.output_dir) @@ -445,11 +506,19 @@ def _state(self) -> UnslothState: init_args["model_name"] = checkpoint_dir else: init_args["model_name"] = self.base_model + + # Set device_map to cuda:0 - Unsloth expects training on cuda:0 + if "device_map" not in init_args: + init_args["device_map"] = {"": 0} model, tokenizer = cast( tuple[CausalLM, PreTrainedTokenizerBase], unsloth.FastLanguageModel.from_pretrained(**init_args), ) + + # Verify the model is on the correct device + model_device = next(model.parameters()).device + print(f"[UnslothService] Model loaded on device: {model_device}, current_device={torch.cuda.current_device()}") # Initialize PEFT model - skip if already a PeftModel (e.g. loaded from checkpoint) if ( @@ -466,6 +535,56 @@ def _state(self) -> UnslothState: ), ) + # Reset AcceleratorState singleton and patch device check before creating trainer + # This is necessary because AcceleratorState caches the device from first initialization, + # which might have been device 0 (from vLLM or imports). We need it to use device 1. 
+ try: + from accelerate.state import AcceleratorState + from accelerate import Accelerator + AcceleratorState._reset_state() + + # Monkey-patch Accelerator to skip device check for 4-bit models + # The check fails when model is on GPU 1 but Accelerator was initialized earlier + # We need to bypass the check BEFORE original_prepare_model runs + original_prepare_model = Accelerator.prepare_model + def patched_prepare_model(self, model, device_placement=None, evaluation_mode=False): + # For quantized models, temporarily remove the quantization flags to bypass the check + # Then restore them after prepare_model completes + was_8bit = getattr(model, "is_loaded_in_8bit", False) + was_4bit = getattr(model, "is_loaded_in_4bit", False) + was_device_map = getattr(model, "hf_device_map", None) + + if was_8bit or was_4bit: + print(f"[UnslothService] Temporarily hiding quantization flags to bypass device check") + # Temporarily hide the quantization flags + model.is_loaded_in_8bit = False + model.is_loaded_in_4bit = False + # Try to delete hf_device_map - it may be on inner model (accessible via __getattr__) + # but not directly deletable from the PEFT wrapper + try: + delattr(model, "hf_device_map") + except AttributeError: + pass # Attribute is on inner model, not directly on PEFT wrapper + + try: + result = original_prepare_model(self, model, device_placement, evaluation_mode) + finally: + # Restore the flags + if was_8bit: + model.is_loaded_in_8bit = True + if was_4bit: + model.is_loaded_in_4bit = True + if was_device_map is not None: + model.hf_device_map = was_device_map + return result + else: + return original_prepare_model(self, model, device_placement, evaluation_mode) + Accelerator.prepare_model = patched_prepare_model + + print(f"[UnslothService] Reset AcceleratorState and patched prepare_model, current_device={torch.cuda.current_device()}") + except Exception as e: + print(f"[UnslothService] Could not reset AcceleratorState: {e}") + # Initialize trainer with dummy 
dataset data = {"prompt": ""} trainer = GRPOTrainer( @@ -504,12 +623,29 @@ async def get_inputs() -> TrainInputs: @cached_property def llm(self) -> asyncio.Task[AsyncLLM]: + # Use single GPU (cuda:0) for both vLLM and Unsloth with time-sharing + # Unsloth's compiled training loop expects cuda:0, so split-GPU mode is not supported + inference_gpu = self.device_config.inference_device + os.environ["CUDA_VISIBLE_DEVICES"] = str(inference_gpu) + print(f"[UnslothService] Starting vLLM on GPU {inference_gpu} (time-sharing mode with Unsloth)") + # Filter engine args to remove incompatible boolean flags engine_args = { **self.config.get("engine_args", {}), "enable_lora": True, "max_loras": self.config.get("engine_args", {}).get("max_loras", 2), } + + # In split mode, vLLM has the full GPU to itself, so use high utilization + # In shared mode, use lower utilization to leave room for training model + if self.device_config.is_split_mode: + if "gpu_memory_utilization" not in engine_args: + engine_args["gpu_memory_utilization"] = 0.90 + else: + # Shared mode: lower utilization to coexist with training + if "gpu_memory_utilization" not in engine_args: + engine_args["gpu_memory_utilization"] = 0.80 + # Remove boolean flags that vLLM's argparse doesn't accept as =False for key in ["enable_log_requests", "disable_log_requests"]: engine_args.pop(key, None) diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index e5d229537..7af5de282 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -23,7 +23,14 @@ async def train( trainer: "GRPOTrainer", results_queue: asyncio.Queue[dict[str, float]], + training_device: int | None = None, ) -> None: + # Set the CUDA device before training - required for 4-bit/8-bit quantized models + # because accelerate checks torch.cuda.current_device() matches the model's device + if training_device is not None: + torch.cuda.set_device(training_device) + print(f"[train] Set CUDA device to {training_device}, 
current_device={torch.cuda.current_device()}") + _compute_loss = trainer.compute_loss _log = trainer.log trainer.compute_loss = get_compute_loss_fn(trainer) @@ -37,7 +44,10 @@ async def train( if not is_train_dict: trainer._metrics = {"train": defaultdict(list)} try: - trainer.train() + # Use context manager to ensure device is set during training + with torch.cuda.device(training_device) if training_device is not None else nullcontext(): + print(f"[train] About to call trainer.train(), current_device={torch.cuda.current_device()}") + trainer.train() finally: trainer.compute_loss = _compute_loss trainer.log = _log # ty:ignore[invalid-assignment] From 13377c037cd6ad1ebae4cb6bac6f01674663c3c5 Mon Sep 17 00:00:00 2001 From: mukesh reddy p <88029886+pmukeshreddy@users.noreply.github.com> Date: Wed, 4 Feb 2026 20:44:04 +0530 Subject: [PATCH 4/8] Update sglang-integration.md --- docs/sglang-integration.md | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/docs/sglang-integration.md b/docs/sglang-integration.md index 45c7efe67..bf47ceee6 100644 --- a/docs/sglang-integration.md +++ b/docs/sglang-integration.md @@ -81,25 +81,35 @@ this provides significant speedups. ## Installation -**CRITICAL**: SGLang and vLLM have conflicting PyTorch dependencies. You MUST use -separate virtual environments. - -### vLLM Environment (Default) +**CRITICAL**: SGLang requires a TWO-environment architecture due to torchao version conflicts. +### Quick Setup (Recommended) ```bash -python -m venv .venv-vllm -source .venv-vllm/bin/activate -pip install openpipe-art[backend] +# Run the setup script (creates both environments) +chmod +x scripts/setup_sglang.sh +./scripts/setup_sglang.sh ``` -### SGLang Environment - +### Manual Setup ```bash -python -m venv .venv-sglang -source .venv-sglang/bin/activate -pip install openpipe-art[sglang] +# 1. 
Main training environment (ART + Unsloth) +python3.11 -m venv .venv +source .venv/bin/activate +pip install -e ".[sglang]" +deactivate + +# 2. SGLang server environment (ISOLATED - no ART) +python3.11 -m venv .venv-sglang-server +source .venv-sglang-server/bin/activate +pip install "sglang[srt]>=0.5.5" +deactivate + +# 3. Activate main env to run training +source .venv/bin/activate ``` +The SGLang backend automatically detects `.venv-sglang-server` and uses it for the inference server subprocess. + ## Usage ### Basic Usage (Auto-detect GPUs) From 19bd069d59560a4c6a57ae705535ca28e575781b Mon Sep 17 00:00:00 2001 From: mukesh reddy p <88029886+pmukeshreddy@users.noreply.github.com> Date: Wed, 4 Feb 2026 20:47:00 +0530 Subject: [PATCH 5/8] Update sglang-integration.md --- docs/sglang-integration.md | 47 -------------------------------------- 1 file changed, 47 deletions(-) diff --git a/docs/sglang-integration.md b/docs/sglang-integration.md index bf47ceee6..4fc40a235 100644 --- a/docs/sglang-integration.md +++ b/docs/sglang-integration.md @@ -203,57 +203,10 @@ await backend.register(model) | `disk` | ~10-20s | Preserved | Large checkpoints | | `restart` | ~30-60s | Lost | Single-GPU fallback | -## Known Issues and Workarounds -### 1. DeviceMesh Memory Imbalance Error -**Symptom**: SGLang fails to start with memory imbalance error. -**Solution**: Set environment variable (done automatically by SGLangBackend): -```bash -export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True -``` - -### 2. update_weights_from_tensor Fails with TP > 1 - -**Reference**: [SGLang #3726](https://github.com/sgl-project/sglang/issues/3726) - -**Solution**: Use `weight_sync_method="lora"` or `"disk"` instead of tensor sync. - -### 3. OOM on Weight Update - -**Reference**: [SGLang #8076](https://github.com/sgl-project/sglang/issues/8076) - -**Solution**: Use disk-based sync or reduce `mem_fraction_static`. - -### 4. 
dp_size Must Be 1 for Weight Updates - -**Reference**: [SGLang #4283](https://github.com/sgl-project/sglang/issues/4283) - -**Solution**: Don't use data parallelism for inference (use TP instead). - -### 5. Garbled Output with Small Tensor Buckets - -**Reference**: [SGLang #14178](https://github.com/sgl-project/sglang/issues/14178) - -**Solution**: Use LoRA-based sync instead of tensor sync. - -## Performance Comparison - -Based on external benchmarks (H100, Llama 3.1 8B): - -| Metric | vLLM | SGLang | Improvement | -|--------|------|--------|-------------| -| Throughput (tok/s) | ~12,500 | ~16,200 | ~29% | -| TTFT (ms) | ~45 | ~35 | ~22% | -| P99 Latency (ms) | ~120 | ~95 | ~21% | - -*Source: [aimultiple.com benchmark](https://aimultiple.com/llm-inference-benchmark)* -The performance advantage comes from: -- RadixAttention's automatic prefix caching -- Zero-overhead scheduler design -- Optimized FlashInfer kernels ## Benchmarking Your Setup From 062d6809e7db7ee8cdd95a6e9c8b5f73e0767ff8 Mon Sep 17 00:00:00 2001 From: mukesh reddy p <88029886+pmukeshreddy@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:43:16 -0500 Subject: [PATCH 6/8] feat: complete SGLang backend with multi-GPU split, benchmarks, and core fixes - SGLang backend with dedicated GPU split (inference GPU 0, training GPU 1+) - LoRA hot-reload via SGLang API preserves RadixAttention cache - Two-environment architecture for torchao version isolation - Benchmarks: SGLang vs vLLM comparison suite - Training utils extracted for backend-agnostic use - DeviceConfig with auto-detection - Ruler fix for empty trajectory groups and exception preservation - vLLM compatibility patches --- CLAUDE.md | 33 +- CONTRIBUTING.md | 18 - README.md | 2 +- benchmarks/__init__.py | 0 benchmarks/sglang_vs_vllm/README.md | 172 ++ benchmarks/sglang_vs_vllm/__init__.py | 1 + benchmarks/sglang_vs_vllm/config.py | 231 +++ .../sglang_vs_vllm/metrics_collector.py | 433 +++++ benchmarks/sglang_vs_vllm/run_benchmark.py | 668 ++++++++ 
.../sglang_vs_vllm/setup_environments.sh | 313 ++++ benchmarks/sglang_vs_vllm/sglang_server.py | 617 +++++++ benchmarks/sglang_vs_vllm/train_ddp.py | 346 ++++ .../sglang_vs_vllm/unsloth_sglang_service.py | 1137 +++++++++++++ dev/math-vista/math-vista.ipynb | 2 +- dev/math-vista/math-vista.py | 2 +- dev/new_models/benchmark_inference.py | 8 +- dev/new_models/gemma3.py | 2 +- dev/new_models/qwen3_try.ipynb | 2 +- dev/new_models/qwen3_try.py | 2 +- dev/yes-no-maybe-vision/train.ipynb | 2 +- dev/yes-no-maybe.ipynb | 2 +- dev/yes-no-maybe.py | 2 +- docs/fundamentals/art-client.mdx | 4 +- docs/integrations/langgraph-integration.mdx | 6 +- docs/sglang-integration.md | 81 +- docs/tutorials/open-deep-research.mdx | 15 +- examples/2048/rollout.py | 2 +- .../just-the-facts/just_the_facts/checks.py | 13 +- .../just-the-facts/just_the_facts/rollout.py | 2 +- examples/mcp-rl/mcp_rl/rollout.py | 4 +- examples/prisoners-dilemma.ipynb | 14 +- .../temporal-clue-7b-async.ipynb | 2 +- examples/temporal_clue/temporal-clue-7b.ipynb | 2 +- examples/temporal_clue/temporal-clue.py | 2 +- pyproject.toml | 2 +- scripts/benchmark_2048_rollout.py | 509 ++++++ scripts/benchmark_inference.py | 638 ++++++++ scripts/benchmark_rl_cost.py | 723 +++++++++ scripts/benchmark_rollout_cost.py | 463 ++++++ scripts/benchmark_sglang_vs_vllm.py | 588 +++++++ scripts/setup.sh | 36 +- skypilot-config.yaml | 1 - src/art/__init__.py | 26 +- src/art/dev/openai_server.py | 14 +- src/art/local/backend.py | 48 +- .../binary_prefix_tool_pipeline.py | 2 +- src/art/preprocessing/tokenize.py | 64 +- src/art/rewards/ruler.py | 9 +- src/art/serverless/backend.py | 238 +-- src/art/tinker/prefix_cache.py | 7 +- src/art/tinker/service.py | 5 + src/art/tinker_native/backend.py | 188 +-- src/art/trajectories.py | 2 +- src/art/unsloth/service.py | 3 - src/art/unsloth/train.py | 633 ++++---- src/art/utils/__init__.py | 2 - .../log_constant_metrics_wandb.py | 3 +- src/art/utils/deployment/common.py | 1 - 
src/art/utils/deployment/wandb.py | 10 +- src/art/vllm/patches.py | 8 +- src/art/vllm/server.py | 14 +- .../integration/test_tinker_native_backend.py | 102 -- tests/test_backend_train_api.py | 52 - tests/unit/test_multi_checkpoint_inference.py | 76 +- tests/unit/test_trajectory_parquet.py | 3 +- uv.lock | 1429 +++++++++-------- 66 files changed, 8243 insertions(+), 1798 deletions(-) mode change 120000 => 100644 CLAUDE.md create mode 100644 benchmarks/__init__.py create mode 100644 benchmarks/sglang_vs_vllm/README.md create mode 100644 benchmarks/sglang_vs_vllm/__init__.py create mode 100644 benchmarks/sglang_vs_vllm/config.py create mode 100644 benchmarks/sglang_vs_vllm/metrics_collector.py create mode 100755 benchmarks/sglang_vs_vllm/run_benchmark.py create mode 100755 benchmarks/sglang_vs_vllm/setup_environments.sh create mode 100644 benchmarks/sglang_vs_vllm/sglang_server.py create mode 100644 benchmarks/sglang_vs_vllm/train_ddp.py create mode 100644 benchmarks/sglang_vs_vllm/unsloth_sglang_service.py create mode 100644 scripts/benchmark_2048_rollout.py create mode 100644 scripts/benchmark_inference.py create mode 100644 scripts/benchmark_rl_cost.py create mode 100644 scripts/benchmark_rollout_cost.py create mode 100644 scripts/benchmark_sglang_vs_vllm.py diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 120000 index ac534a310..000000000 --- a/CLAUDE.md +++ /dev/null @@ -1 +0,0 @@ -AGENT.md \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..c98e47341 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,32 @@ +## uv package manager by default + +This project uses the `uv` package manager. + +- To add a dependency, run `uv add `. +- To run a script, run `uv run