diff --git a/src/art/dev/__init__.py b/src/art/dev/__init__.py index 8e029139..11153d88 100644 --- a/src/art/dev/__init__.py +++ b/src/art/dev/__init__.py @@ -10,6 +10,7 @@ ) from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config from .train import TrainConfig, TrainSFTConfig +from .validate import is_dedicated_mode, validate_dedicated_config __all__ = [ "EngineArgs", @@ -21,8 +22,10 @@ "TinkerTrainingClientArgs", "TrainerArgs", "get_openai_server_config", + "is_dedicated_mode", "OpenAIServerConfig", "ServerArgs", "TrainSFTConfig", "TrainConfig", + "validate_dedicated_config", ] diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py index 71229464..ed75112d 100644 --- a/src/art/dev/get_model_config.py +++ b/src/art/dev/get_model_config.py @@ -1,5 +1,6 @@ from .engine import EngineArgs from .model import InitArgs, InternalModelConfig, PeftArgs, TrainerArgs +from .validate import is_dedicated_mode def get_model_config( @@ -12,13 +13,22 @@ def get_model_config( if config is None: config = InternalModelConfig() - enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True) + dedicated = is_dedicated_mode(config) + + if dedicated: + enable_sleep_mode = False + else: + enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True) + init_args = InitArgs( - fast_inference=False, load_in_4bit=True, max_seq_length=32768, model_name=base_model, ) + # fast_inference triggers in-process vLLM via Unsloth; dedicated mode runs vLLM as a subprocess + if not dedicated: + init_args["fast_inference"] = False + engine_args = EngineArgs( allowed_local_media_path="/tmp", enable_sleep_mode=enable_sleep_mode, @@ -63,10 +73,15 @@ def get_model_config( weight_decay=0.1, ) trainer_args.update(config.get("trainer_args", {})) - return InternalModelConfig( + result = InternalModelConfig( init_args=init_args, engine_args=engine_args, peft_args=peft_args, tinker_args=config.get("tinker_args"), trainer_args=trainer_args, ) + if "trainer_gpu_ids" in config: + result["trainer_gpu_ids"] = config["trainer_gpu_ids"] + if "inference_gpu_ids" in config: + result["inference_gpu_ids"] = config["inference_gpu_ids"] + return result diff --git a/src/art/dev/model.py b/src/art/dev/model.py index 8bd342b8..84a13f1d 100644 --- a/src/art/dev/model.py +++ b/src/art/dev/model.py @@ -115,6 +115,11 @@ class InternalModelConfig(TypedDict, total=False): peft: Arguments for creating an Unsloth PEFT model wrapper. tinker: Arguments for the Tinker training client. trainer: Arguments for the GRPO trainer. + trainer_gpu_ids: GPU IDs for training (e.g., [0]). When set with + inference_gpu_ids, enables dedicated mode where training and + inference run on separate GPUs. + inference_gpu_ids: GPU IDs for vLLM inference (e.g., [1]). When set + with trainer_gpu_ids, enables dedicated mode. 
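+
+        A minimal illustrative example of a dedicated layout (the GPU IDs are
+        placeholders; any non-overlapping assignment accepted by
+        validate_dedicated_config behaves the same way):
+
+            config = InternalModelConfig(
+                trainer_gpu_ids=[0],     # training runs in-process on GPU 0
+                inference_gpu_ids=[1],   # vLLM subprocess is pinned to GPU 1
+            )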
""" init_args: "InitArgs" @@ -123,6 +128,8 @@ class InternalModelConfig(TypedDict, total=False): tinker_args: "TinkerArgs | None" tinker_native_args: "TinkerNativeArgs | None" trainer_args: "TrainerArgs" + trainer_gpu_ids: list[int] + inference_gpu_ids: list[int] class TinkerArgs(TypedDict, total=False): diff --git a/src/art/dev/validate.py b/src/art/dev/validate.py new file mode 100644 index 00000000..031464e0 --- /dev/null +++ b/src/art/dev/validate.py @@ -0,0 +1,67 @@ +"""Validation functions for model configuration.""" + +from .model import InternalModelConfig + + +def is_dedicated_mode(config: InternalModelConfig) -> bool: + """Return True if the config specifies dedicated mode (separate training and inference GPUs).""" + return "trainer_gpu_ids" in config and "inference_gpu_ids" in config + + +def validate_dedicated_config(config: InternalModelConfig) -> None: + """Validate dedicated mode GPU configuration. + + Raises ValueError if the configuration is invalid. + Does nothing if neither trainer_gpu_ids nor inference_gpu_ids is set (shared mode). + """ + has_trainer = "trainer_gpu_ids" in config + has_inference = "inference_gpu_ids" in config + + if has_trainer != has_inference: + raise ValueError( + "trainer_gpu_ids and inference_gpu_ids must both be set or both unset" + ) + + if not has_trainer: + return + + trainer_gpu_ids = config["trainer_gpu_ids"] + inference_gpu_ids = config["inference_gpu_ids"] + + if not trainer_gpu_ids: + raise ValueError("trainer_gpu_ids must be non-empty") + + if not inference_gpu_ids: + raise ValueError("inference_gpu_ids must be non-empty") + + if set(trainer_gpu_ids) & set(inference_gpu_ids): + raise ValueError("trainer_gpu_ids and inference_gpu_ids must not overlap") + + if len(inference_gpu_ids) > 1: + raise ValueError( + "Multi-GPU inference not yet supported; inference_gpu_ids must have exactly one GPU" + ) + + if trainer_gpu_ids[0] != 0: + raise ValueError( + "trainer_gpu_ids must start at GPU 0 (training runs in-process)" + ) + + expected = list(range(len(trainer_gpu_ids))) + if trainer_gpu_ids != expected: + raise ValueError( + "trainer_gpu_ids must be contiguous starting from 0 (e.g., [0], [0,1])" + ) + + # Reject settings that are incompatible with dedicated mode + if config.get("init_args", {}).get("fast_inference"): + raise ValueError( + "fast_inference is incompatible with dedicated mode " + "(dedicated mode runs vLLM as a subprocess, not in-process)" + ) + + if config.get("engine_args", {}).get("enable_sleep_mode"): + raise ValueError( + "enable_sleep_mode is incompatible with dedicated mode " + "(dedicated mode runs vLLM on a separate GPU, sleep/wake is not needed)" + ) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 35615124..b35727e9 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -1,5 +1,6 @@ import asyncio import json +import logging import math import os import shutil @@ -9,6 +10,8 @@ from typing import AsyncIterator, Iterable, Literal, cast import warnings +logger = logging.getLogger(__name__) + import aiohttp import numpy as np from openai import AsyncOpenAI @@ -97,6 +100,9 @@ async def close(self) -> None: def _close(self) -> None: for _, service in self._services.items(): + close = getattr(service, "close", None) + if close is not None: + close() close_proxy(service) async def register( @@ -140,11 +146,29 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str: # For LocalBackend, vLLM always serves LoRA adapters with @step suffix # Default to step 0 when not 
specified (the initial checkpoint created at registration) - actual_step = step if step is not None else self.__get_step(model) - return f"{model.name}@{actual_step}" + if step is not None: + actual_step = step + elif model.name in self._services: + # In dedicated mode the service tracks which adapter vLLM has + # actually loaded. Reading the filesystem would race: the + # checkpoint directory appears before the HTTP reload completes. + svc = self._services[model.name] + loaded_step = getattr(svc, "_latest_step", None) + actual_step = ( + loaded_step if loaded_step is not None else self.__get_step(model) + ) + else: + actual_step = self.__get_step(model) + name = f"{model.name}@{actual_step}" + logger.debug( + f"[BACKEND] _model_inference_name: step_arg={step} " + f"actual_step={actual_step} -> {name}" + ) + return name async def _get_service(self, model: TrainableModel) -> ModelService: from ..dev.get_model_config import get_model_config + from ..dev.validate import is_dedicated_mode, validate_dedicated_config if model.name not in self._services: config = get_model_config( @@ -152,6 +176,9 @@ async def _get_service(self, model: TrainableModel) -> ModelService: output_dir=get_model_dir(model=model, art_path=self._path), config=model._internal_config, ) + validate_dedicated_config(config) + dedicated = is_dedicated_mode(config) + is_tinker = config.get("tinker_args") is not None if is_tinker: from ..tinker.service import TinkerService @@ -164,13 +191,19 @@ async def _get_service(self, model: TrainableModel) -> ModelService: # When moving the service to a child process, import unsloth # early to maximize optimizations os.environ["IMPORT_UNSLOTH"] = "1" + + if dedicated: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + str(g) for g in config["trainer_gpu_ids"] + ) + self._services[model.name] = service_class( model_name=model.name, base_model=model.base_model, config=config, output_dir=get_model_dir(model=model, art_path=self._path), ) - if not self._in_process: + if not dedicated and not self._in_process: # Kill all "model-service" processes to free up GPU memory subprocess.run(["pkill", "-9", "model-service"]) self._services[model.name] = move_to_child_process( @@ -585,6 +618,10 @@ async def _train_model( # Still advance the step by renaming the checkpoint directory current_step = self.__get_step(model) next_step = current_step + 1 + logger.info( + f"[BACKEND] _train_model SKIP: current_step={current_step} " + f"next_step={next_step} (all rewards equal)" + ) current_checkpoint_dir = get_step_checkpoint_dir( get_model_dir(model=model, art_path=self._path), current_step ) @@ -599,8 +636,9 @@ async def _train_model( next_checkpoint_dir, dirs_exist_ok=True, ) - print( - f"Advanced step from {current_step} to {next_step} (no training occurred)" + logger.info( + f"[BACKEND] _train_model SKIP: copied checkpoint " + f"{current_step} -> {next_step}, calling register_lora_for_step..." 
) try: @@ -610,6 +648,10 @@ async def _train_model( await service.register_lora_for_step( # type: ignore[attr-defined] next_step, next_checkpoint_dir ) + logger.info( + f"[BACKEND] _train_model SKIP: register_lora_for_step " + f"completed for step {next_step}" + ) except ModuleNotFoundError: pass # Unsloth is not installed diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index 2417cff9..cfb95b3c 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -1,9 +1,13 @@ """Unsloth training service with decoupled vLLM inference.""" import asyncio -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property +import json +import logging import os +import subprocess +import sys from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, Protocol, cast from datasets import Dataset @@ -17,6 +21,7 @@ from vllm.v1.engine.async_llm import AsyncLLM from .. import dev, types +from ..dev.validate import is_dedicated_mode from ..local.checkpoints import get_last_checkpoint_dir from ..preprocessing.inputs import TrainInputs, create_train_inputs from ..preprocessing.pack import ( @@ -30,6 +35,8 @@ from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers from .train import gc_and_empty_cuda_cache, train +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from peft.peft_model import PeftModelForCausalLM from trl import GRPOTrainer @@ -265,27 +272,179 @@ class UnslothService: _last_training_mode: Literal["sft", "rl"] | None = None _latest_step: int = 0 _lora_id_counter: int = 1 # Start from 1 since 0 is reserved + # Dedicated mode subprocess state + _vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg] + _vllm_log_file: Any = field(default=None, repr=False) + _vllm_host: str = "127.0.0.1" + _vllm_port: int = 0 + + @property + def is_dedicated(self) -> bool: + return is_dedicated_mode(self.config) def _next_lora_id(self) -> int: """Return a new unique LoRA ID to avoid collisions in vLLM.""" self._lora_id_counter += 1 return self._lora_id_counter + # ========================================================================= + # Dedicated mode: vLLM subprocess lifecycle + # ========================================================================= + + async def _start_vllm_subprocess( + self, + lora_path: str, + port: int, + config: dev.OpenAIServerConfig | None = None, + ) -> tuple[str, int]: + """Launch vLLM as a subprocess on inference GPUs. 
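+
+        The subprocess is pinned to inference_gpu_ids via CUDA_VISIBLE_DEVICES
+        and starts with the current checkpoint preloaded under the served name
+        "{model_name}@{latest_step}".
+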
Returns (host, port).""" + import atexit + + inference_gpu_ids = self.config["inference_gpu_ids"] + cuda_devices = ",".join(str(g) for g in inference_gpu_ids) + + # Build server_args: ART defaults, then user overrides, strip CLI-handled keys + server_args: dict[str, object] = { + "return_tokens_as_token_ids": True, + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + if config and "server_args" in config: + server_args.update(dict(config["server_args"])) + for key in ("port", "host", "lora_modules", "api_key"): + server_args.pop(key, None) + + # Build engine_args: model-level config, then user server overrides, + # add dedicated-mode defaults, strip CLI-handled keys + engine_args = dict(self.config.get("engine_args", {})) + if config and "engine_args" in config: + engine_args.update(dict(config["engine_args"])) + engine_args.setdefault("generation_config", "vllm") + engine_args["enable_lora"] = True + engine_args.setdefault("max_loras", 2) + for key in ("model", "served_model_name", "enable_sleep_mode"): + engine_args.pop(key, None) + + cmd = [ + sys.executable, + "-m", + "art.vllm.dedicated_server", + f"--model={self.base_model}", + f"--port={port}", + f"--host={self._vllm_host}", + f"--cuda-visible-devices={cuda_devices}", + f"--lora-path={lora_path}", + f"--served-model-name={self.model_name}@{self._latest_step}", + f"--engine-args-json={json.dumps(engine_args)}", + f"--server-args-json={json.dumps(server_args)}", + ] + + log_dir = os.path.join(self.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + self._vllm_log_file = open( + os.path.join(log_dir, "vllm-dedicated.log"), "w", buffering=1 + ) + + self._vllm_process = subprocess.Popen( + cmd, stdout=self._vllm_log_file, stderr=subprocess.STDOUT, bufsize=1 + ) + self._vllm_port = port + + import httpx + + timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 600)) + poll_interval = 1.0 + elapsed = 0.0 + async with httpx.AsyncClient() as client: + while elapsed < timeout: + if self._vllm_process.poll() is not None: + raise RuntimeError( + f"vLLM subprocess exited with code {self._vllm_process.returncode}. " + f"Check logs at {log_dir}/vllm-dedicated.log" + ) + try: + resp = await client.get( + f"http://{self._vllm_host}:{self._vllm_port}/v1/models", + timeout=5.0, + ) + if resp.status_code == 200: + break + except (httpx.ConnectError, httpx.ReadTimeout): + pass + await asyncio.sleep(poll_interval) + elapsed += poll_interval + else: + self.close() + raise TimeoutError( + f"vLLM subprocess did not become ready within {timeout}s. 
" + f"Check logs at {log_dir}/vllm-dedicated.log" + ) + + atexit.register(self.close) + logger.info("vLLM subprocess ready on port %d (GPUs: %s)", port, cuda_devices) + return self._vllm_host, self._vllm_port + + async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: + """Reload LoRA adapter in vLLM subprocess via HTTP.""" + import httpx + + lora_name = f"{self.model_name}@{step}" + logger.info( + f"[DEDICATED] _reload_adapter START: lora_name={lora_name} " + f"path={checkpoint_path}" + ) + async with httpx.AsyncClient() as client: + response = await client.post( + f"http://{self._vllm_host}:{self._vllm_port}/v1/load_lora_adapter", + json={ + "lora_name": lora_name, + "lora_path": checkpoint_path, + "load_inplace": True, + }, + timeout=60.0, + ) + response.raise_for_status() + logger.info( + f"[DEDICATED] _reload_adapter DONE: lora_name={lora_name} " + f"status={response.status_code}" + ) + + def close(self) -> None: + """Terminate vLLM subprocess if running.""" + if self._vllm_process is None: + return + self._vllm_process.terminate() + try: + self._vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self._vllm_process.kill() + self._vllm_process.wait() + self._vllm_process = None + if self._vllm_log_file is not None: + self._vllm_log_file.close() + self._vllm_log_file = None + + # ========================================================================= + # start_openai_server + # ========================================================================= + async def start_openai_server( self, config: dev.OpenAIServerConfig | None ) -> tuple[str, int]: lora_path = get_last_checkpoint_dir(self.output_dir) if lora_path is None: - # Create initial LoRA checkpoint if none exists lora_path = get_step_checkpoint_dir(self.output_dir, 0) os.makedirs(os.path.dirname(lora_path), exist_ok=True) self._state.trainer.save_model(lora_path) self._latest_step = 0 else: - # Extract step from checkpoint path self._latest_step = get_step_from_dir(self.output_dir) - # Offload training model to CPU before vLLM starts to free GPU memory + if self.is_dedicated: + port = (config or {}).get("server_args", {}).get("port", 8000) + return await self._start_vllm_subprocess(lora_path, port, config=config) + + # Shared mode: in-process vLLM self._state.offload_to_cpu() server_config = dev.get_openai_server_config( @@ -304,12 +463,23 @@ async def start_openai_server( ) or "0.0.0.0", server_config.get("server_args", {}).get("port", 8000) async def vllm_engine_is_sleeping(self) -> bool: + if self.is_dedicated: + return False return self._is_sleeping async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: """Register a LoRA adapter for a specific checkpoint step. This is called when training is skipped but the checkpoint is renamed. 
""" + logger.info( + f"[DEDICATED] register_lora_for_step called: step={step} " + f"checkpoint_dir={checkpoint_dir} is_dedicated={self.is_dedicated}" + ) + if self.is_dedicated: + await self._reload_adapter(checkpoint_dir, step) + self._latest_step = step + return + llm = await self.llm await llm.pause_generation() added = await llm.add_lora( @@ -353,6 +523,86 @@ async def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: + if self.is_dedicated: + async for result in self._train_dedicated( + disk_packed_tensors, config, _config, verbose + ): + yield result + return + + async for result in self._train_shared( + disk_packed_tensors, config, _config, verbose + ): + yield result + + async def _train_dedicated( + self, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU.""" + self._reset_optimizer_if_mode_changed("rl") + + rl_weight_decay = 0.1 + for param_group in self._state.trainer.optimizer.param_groups: + param_group["weight_decay"] = rl_weight_decay + + packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) + + await self._state.results_queue.join() + + if not hasattr(self, "_train_task") or self._train_task is None: + self._train_task = asyncio.create_task( + train( + trainer=self._state.trainer, + results_queue=self._state.results_queue, + ) + ) + warmup = True + else: + warmup = False + + async for result in process_train_batch( + packed_tensors=packed_tensors, + config=config, + _config=_config, + inputs_queue=self._state.inputs_queue, + results_queue=self._state.results_queue, + train_task=self._train_task, + trainer=self._state.trainer, + peft_model=self._state.peft_model, + warmup=warmup, + verbose=verbose, + ): + yield result + + checkpoint_dir = save_checkpoint( + trainer=self._state.trainer, + output_dir=self.output_dir, + verbose=verbose, + ) + + new_step = int(os.path.basename(checkpoint_dir)) + logger.info( + f"[DEDICATED] _train_dedicated: saved checkpoint step={new_step}, " + f"reloading adapter..." + ) + await self._reload_adapter(checkpoint_dir, new_step) + self._latest_step = new_step + logger.info( + f"[DEDICATED] _train_dedicated: adapter reloaded for step {new_step}" + ) + + async def _train_shared( + self, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + """Train in shared mode — sleep/wake cycle with in-process vLLM.""" llm = await self.llm # Pause generation to prevent new requests during training @@ -481,6 +731,10 @@ async def train_sft( Yields: Dictionary containing training metrics for each batch. """ + if self.is_dedicated: + raise NotImplementedError( + "train_sft is not yet supported in dedicated mode" + ) import time llm = await self.llm diff --git a/src/art/vllm/dedicated_server.py b/src/art/vllm/dedicated_server.py new file mode 100644 index 00000000..72e60cae --- /dev/null +++ b/src/art/vllm/dedicated_server.py @@ -0,0 +1,98 @@ +"""Dedicated vLLM subprocess entry point. + +Launched by UnslothService in dedicated mode as: + python -m art.vllm.dedicated_server --model --port ... + +Sets CUDA_VISIBLE_DEVICES and applies ART patches before starting vLLM. +Must be imported/run as a standalone process — not imported into the main training process. 
+""" + +import argparse +import asyncio +import json +import os + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="ART dedicated vLLM server") + parser.add_argument("--model", required=True, help="Base model name or path") + parser.add_argument("--port", type=int, required=True) + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--cuda-visible-devices", required=True) + parser.add_argument("--lora-path", required=True, help="Initial LoRA adapter path") + parser.add_argument("--served-model-name", required=True) + parser.add_argument( + "--engine-args-json", default="{}", help="Additional engine args as JSON" + ) + parser.add_argument( + "--server-args-json", + default="{}", + help="Additional server args as JSON (tool_call_parser, etc.)", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + + # Must set CUDA_VISIBLE_DEVICES before any torch/CUDA import + os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices + os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1" + + # Patches must be applied before vLLM's api_server is imported + from .patches import ( + patch_listen_for_disconnect, + patch_tool_parser_manager, + subclass_chat_completion_request, + ) + + subclass_chat_completion_request() + patch_listen_for_disconnect() + patch_tool_parser_manager() + + from vllm.entrypoints.openai import api_server + from vllm.entrypoints.openai.cli_args import ( + make_arg_parser, + validate_parsed_serve_args, + ) + from vllm.utils.argparse_utils import FlexibleArgumentParser + + engine_args = json.loads(args.engine_args_json) + server_args = json.loads(args.server_args_json) + + vllm_args = [ + f"--model={args.model}", + f"--port={args.port}", + f"--host={args.host}", + f"--served-model-name={args.served_model_name}", + "--enable-lora", + f"--lora-modules={args.served_model_name}={args.lora_path}", + ] + for extra_args in (engine_args, server_args): + for key, value in extra_args.items(): + if value is None: + continue + cli_key = f"--{key.replace('_', '-')}" + if isinstance(value, bool): + if value: + vllm_args.append(cli_key) + elif isinstance(value, list): + for item in value: + vllm_args.append(f"{cli_key}={item}") + else: + vllm_args.append(f"{cli_key}={value}") + + vllm_parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server." + ) + vllm_parser = make_arg_parser(vllm_parser) + namespace = vllm_parser.parse_args(vllm_args) + validate_parsed_serve_args(namespace) + + # stdout/stderr are captured to a log file by the parent process, + # so no separate uvicorn file handler is needed here. 
+ asyncio.run(api_server.run_server(namespace)) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_dedicated_config.py b/tests/unit/test_dedicated_config.py new file mode 100644 index 00000000..3de780ef --- /dev/null +++ b/tests/unit/test_dedicated_config.py @@ -0,0 +1,175 @@ +"""Unit tests for dedicated mode config validation and get_model_config integration.""" + +import tempfile + +import pytest + +from art.dev.model import InternalModelConfig +from art.dev.validate import is_dedicated_mode, validate_dedicated_config + + +def test_shared_mode_empty_config(): + config = InternalModelConfig() + assert is_dedicated_mode(config) is False + + +def test_shared_mode_with_other_keys(): + config = InternalModelConfig(init_args={"model_name": "test"}) # type: ignore[typeddict-item] + assert is_dedicated_mode(config) is False + + +def test_dedicated_mode_detected(): + config = InternalModelConfig(trainer_gpu_ids=[0], inference_gpu_ids=[1]) + assert is_dedicated_mode(config) is True + + +def test_valid_shared_mode(): + validate_dedicated_config(InternalModelConfig()) + + +def test_valid_dedicated_two_gpus(): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0], inference_gpu_ids=[1]) + ) + + +def test_valid_dedicated_three_gpus(): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0, 1], inference_gpu_ids=[2]) + ) + + +def test_valid_dedicated_four_gpus(): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0, 1, 2], inference_gpu_ids=[3]) + ) + + +def test_only_trainer_gpu_ids(): + with pytest.raises(ValueError, match="must both be set or both unset"): + validate_dedicated_config(InternalModelConfig(trainer_gpu_ids=[0])) + + +def test_only_inference_gpu_ids(): + with pytest.raises(ValueError, match="must both be set or both unset"): + validate_dedicated_config(InternalModelConfig(inference_gpu_ids=[1])) + + +def test_empty_trainer_gpu_ids(): + with pytest.raises(ValueError, match="trainer_gpu_ids must be non-empty"): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[], inference_gpu_ids=[1]) + ) + + +def test_empty_inference_gpu_ids(): + with pytest.raises(ValueError, match="inference_gpu_ids must be non-empty"): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0], inference_gpu_ids=[]) + ) + + +def test_overlapping_gpu_ids(): + with pytest.raises(ValueError, match="must not overlap"): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0, 1], inference_gpu_ids=[1]) + ) + + +def test_multi_gpu_inference(): + with pytest.raises(ValueError, match="Multi-GPU inference not yet supported"): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0], inference_gpu_ids=[1, 2]) + ) + + +def test_trainer_not_starting_at_zero(): + with pytest.raises(ValueError, match="must start at GPU 0"): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[1], inference_gpu_ids=[0]) + ) + + +def test_trainer_not_contiguous(): + with pytest.raises(ValueError, match="must be contiguous starting from 0"): + validate_dedicated_config( + InternalModelConfig(trainer_gpu_ids=[0, 2], inference_gpu_ids=[1]) + ) + + +def test_dedicated_rejects_fast_inference(): + with pytest.raises( + ValueError, match="fast_inference is incompatible with dedicated" + ): + validate_dedicated_config( + InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + init_args={"fast_inference": True}, # type: ignore[typeddict-item] + ) + ) + + +def test_dedicated_rejects_enable_sleep_mode(): + 
with pytest.raises( + ValueError, match="enable_sleep_mode is incompatible with dedicated" + ): + validate_dedicated_config( + InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + engine_args={"enable_sleep_mode": True}, # type: ignore[typeddict-item] + ) + ) + + +def test_dedicated_allows_fast_inference_false(): + """fast_inference=False is fine in dedicated mode (it's the intended state).""" + validate_dedicated_config( + InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + init_args={"fast_inference": False}, # type: ignore[typeddict-item] + ) + ) + + +def test_get_model_config_shared_mode(): + from art.dev.get_model_config import get_model_config + + with tempfile.TemporaryDirectory() as tmpdir: + result = get_model_config("test-model", tmpdir, None) + assert "trainer_gpu_ids" not in result + assert "inference_gpu_ids" not in result + assert result["engine_args"]["enable_sleep_mode"] is True + assert result["init_args"].get("fast_inference") is False + + +def test_get_model_config_dedicated_mode(): + from art.dev.get_model_config import get_model_config + + with tempfile.TemporaryDirectory() as tmpdir: + config = InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ) + result = get_model_config("test-model", tmpdir, config) + assert result["trainer_gpu_ids"] == [0] + assert result["inference_gpu_ids"] == [1] + assert result["engine_args"]["enable_sleep_mode"] is False + assert "fast_inference" not in result["init_args"] + + +def test_get_model_config_dedicated_preserves_user_engine_args(): + from art.dev.get_model_config import get_model_config + + with tempfile.TemporaryDirectory() as tmpdir: + config = InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + engine_args={"max_model_len": 4096}, # type: ignore[typeddict-item] + ) + result = get_model_config("test-model", tmpdir, config) + assert result["engine_args"]["max_model_len"] == 4096 + # Sleep mode should still be disabled even if user didn't set it + assert result["engine_args"]["enable_sleep_mode"] is False diff --git a/tests/unit/test_dedicated_server.py b/tests/unit/test_dedicated_server.py new file mode 100644 index 00000000..0acf7baa --- /dev/null +++ b/tests/unit/test_dedicated_server.py @@ -0,0 +1,97 @@ +"""Unit tests for dedicated vLLM server entry point.""" + +import pytest + +pytest.importorskip("cloudpickle") +pytest.importorskip("vllm") + +from art.vllm.dedicated_server import parse_args + + +def test_parse_args_required(): + args = parse_args( + [ + "--model", + "Qwen/Qwen3-14B", + "--port", + "8000", + "--cuda-visible-devices", + "1", + "--lora-path", + "/tmp/checkpoints/0000", + "--served-model-name", + "my-model@0", + ] + ) + assert args.model == "Qwen/Qwen3-14B" + assert args.port == 8000 + assert args.cuda_visible_devices == "1" + assert args.lora_path == "/tmp/checkpoints/0000" + assert args.served_model_name == "my-model@0" + assert args.host == "127.0.0.1" + assert args.engine_args_json == "{}" + assert args.server_args_json == "{}" + + +def test_parse_args_with_engine_args(): + args = parse_args( + [ + "--model", + "test-model", + "--port", + "9000", + "--cuda-visible-devices", + "2", + "--lora-path", + "/tmp/lora", + "--served-model-name", + "test@1", + "--engine-args-json", + '{"max_model_len": 4096}', + ] + ) + assert args.engine_args_json == '{"max_model_len": 4096}' + + +def test_parse_args_custom_host(): + args = parse_args( + [ + "--model", + "test-model", + "--port", + "8000", + "--cuda-visible-devices", + "0", + "--lora-path", + 
"/tmp/lora", + "--served-model-name", + "test@0", + "--host", + "0.0.0.0", + ] + ) + assert args.host == "0.0.0.0" + + +def test_parse_args_with_server_args(): + args = parse_args( + [ + "--model", + "test-model", + "--port", + "8000", + "--cuda-visible-devices", + "1", + "--lora-path", + "/tmp/lora", + "--served-model-name", + "test@0", + "--server-args-json", + '{"enable_auto_tool_choice": true, "tool_call_parser": "hermes"}', + ] + ) + import json + + server_args = json.loads(args.server_args_json) + assert server_args["enable_auto_tool_choice"] is True + assert server_args["tool_call_parser"] == "hermes"