From bff6a175cc4d49182d7ea6fc5a4adc180c76222d Mon Sep 17 00:00:00 2001 From: XyLearningProgramming Date: Mon, 23 Feb 2026 14:59:25 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=90=9B=20fixed=20tool=20call=20by=20a?= =?UTF-8?q?dding=20true=20state=20machine?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + Makefile | 6 +- pyproject.toml | 5 +- scripts/smoke.sh | 102 ++- slm_server/app.py | 66 +- slm_server/config.py | 15 + slm_server/model.py | 340 +++++++++- slm_server/utils/__init__.py | 1 + slm_server/utils/constants.py | 2 + slm_server/utils/ids.py | 20 + slm_server/utils/postprocess.py | 357 ++++++++++ slm_server/utils/spans.py | 41 +- swagger/openapi.yaml | 1080 +++++++++++++++++++++++++++++++ tests/test_app.py | 2 +- tests/test_postprocess.py | 503 ++++++++++++++ uv.lock | 7 +- 16 files changed, 2495 insertions(+), 55 deletions(-) create mode 100644 slm_server/utils/ids.py create mode 100644 slm_server/utils/postprocess.py create mode 100644 swagger/openapi.yaml create mode 100644 tests/test_postprocess.py diff --git a/.gitignore b/.gitignore index 377327f..f3726a3 100644 --- a/.gitignore +++ b/.gitignore @@ -214,3 +214,6 @@ models/* # IDE related .vscode + +# Custom logs +logs/ diff --git a/Makefile b/Makefile index 1f29c4f..c658360 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: dev run download install lint format check test smoke clean help +.PHONY: dev run download install lint format check test smoke swagger clean help help: ## Show this help @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}' @@ -30,6 +30,10 @@ smoke: ## Smoke-test the running server APIs with curl test: ## Run tests with coverage uv run pytest tests/ -v --cov=slm_server --cov-report=term-missing +swagger: ## Refresh OpenAPI spec from running server + curl -sf http://localhost:8000/openapi.json | uv run python -c "import sys,json,yaml;yaml.dump(json.load(sys.stdin),sys.stdout,default_flow_style=False,sort_keys=False,allow_unicode=True)" > swagger/openapi.yaml + @echo "swagger/openapi.yaml updated" + clean: ## Remove caches and build artifacts rm -rf __pycache__ .pytest_cache .ruff_cache .coverage htmlcov build dist *.egg-info find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true diff --git a/pyproject.toml b/pyproject.toml index 0aae74b..3c33a42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,10 @@ readme = "README.md" requires-python = ">=3.13" dependencies = [ "fastapi>=0.116.1", - "llama-cpp-python>=0.3.13", + # Cherry-picked PR #1884 (streaming tool use) onto latest upstream. + # Upstream llama-cpp-python silently ignores tool_choice when stream=True; + # this fork adds streaming support. + "llama-cpp-python @ git+https://github.com/XyLearningProgramming/llama-cpp-python.git@main", "opentelemetry-instrumentation-logging>=0.50b0", "opentelemetry-instrumentation-fastapi>=0.50b0", "pydantic-settings>=2.10.1", diff --git a/scripts/smoke.sh b/scripts/smoke.sh index 23094fc..3e43ee3 100755 --- a/scripts/smoke.sh +++ b/scripts/smoke.sh @@ -17,7 +17,7 @@ curl -sf "$BASE_URL/api/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ "messages": [{"role": "user", "content": "Say hello in one sentence."}], - "max_tokens": 64 + "max_tokens": 512 }' | python3 -m json.tool echo @@ -26,11 +26,109 @@ curl -sf "$BASE_URL/api/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ "messages": [{"role": "user", "content": "What is 2+2?"}], - "max_tokens": 32, + "max_tokens": 512, "stream": true }' echo +echo "=== Tool call (no tool_choice, defaults to auto) ===" +TOOL_RESP=$(curl -sf "$BASE_URL/api/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": "What is the weather in San Francisco? /no_think"}], + "max_tokens": 256, + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + } + }, + "required": ["location"] + } + } + } + ] + }') +echo "$TOOL_RESP" | python3 -m json.tool + +# Verify response has structured tool_calls (not raw in content) +echo "$TOOL_RESP" | python3 -c " +import sys, json +resp = json.load(sys.stdin) +choice = resp['choices'][0] +msg = choice['message'] +has_tool = 'tool_calls' in msg and msg['tool_calls'] +has_content = 'content' in msg and msg['content'] +if not has_tool and not has_content: + print('FAIL: no tool_calls and no content'); sys.exit(1) +if has_tool: + tc = msg['tool_calls'][0] + assert tc['type'] == 'function', f'bad type: {tc[\"type\"]}' + assert 'name' in tc['function'], 'missing function name' + assert 'arguments' in tc['function'], 'missing arguments' + assert '' not in (msg.get('content') or ''), 'raw leaked into content' + assert choice['finish_reason'] == 'tool_calls', f'bad finish_reason: {choice[\"finish_reason\"]}' + print(f'tool_calls: {tc[\"function\"][\"name\"]}({tc[\"function\"][\"arguments\"]})') +else: + print(f'content_only: {msg[\"content\"][:80]}...') +" +echo + +echo "=== Tool call streaming (no tool_choice, defaults to auto) ===" +STREAM_RESP=$(curl -sf "$BASE_URL/api/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": "What is the weather in San Francisco? /no_think"}], + "max_tokens": 256, + "stream": true, + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + } + }, + "required": ["location"] + } + } + } + ] + }') +echo "$STREAM_RESP" + +# Verify no raw tags leaked through as content +echo "$STREAM_RESP" | python3 -c " +import sys +raw = sys.stdin.read() +assert '' not in raw, 'raw tag leaked into stream' +assert '' not in raw, 'raw tag leaked into stream' +# Check for structured tool_calls in at least one chunk +has_tool_calls = '\"tool_calls\"' in raw +has_content = '\"content\"' in raw +if has_tool_calls: + print('streaming tool_calls: structured delta found') +elif has_content: + print('streaming tool_calls: content_only (model chose not to call tool)') +else: + print('FAIL: no tool_calls and no content in stream'); sys.exit(1) +" +echo + echo "=== Embeddings (single) ===" curl -sf "$BASE_URL/api/v1/embeddings" \ -H "Content-Type: application/json" \ diff --git a/slm_server/app.py b/slm_server/app.py index c3ebb29..a5a10ea 100644 --- a/slm_server/app.py +++ b/slm_server/app.py @@ -9,12 +9,14 @@ from fastapi.responses import StreamingResponse from llama_cpp import CreateChatCompletionStreamResponse, Llama -from slm_server.config import Settings, get_settings +from slm_server.config import Settings, get_model_id, get_settings from slm_server.embedding import OnnxEmbeddingModel from slm_server.logging import setup_logging from slm_server.metrics import setup_metrics from slm_server.model import ( + ChatCompletionChunkResponse, ChatCompletionRequest, + ChatCompletionResponse, EmbeddingData, EmbeddingRequest, EmbeddingResponse, @@ -29,14 +31,16 @@ slm_embedding_span, slm_span, ) +from slm_server.utils.postprocess import StreamPostProcessor, postprocess_completion # MAX_CONCURRENCY decides how many threads are calling model. # Default to 1 since llama cpp is designed to be at most efficiency # for single thread. Meanwhile, value larger than 1 allows # threads to compete for same resources. MAX_CONCURRENCY = 1 -# Keeps function calling and also compatible with ReAct agents. -CHAT_FORMAT = "chatml-function-calling" +# Use the model's built-in Jinja chat template from the GGUF metadata, +# which handles tool formatting natively (e.g. Qwen3, Llama 3, etc.). +CHAT_FORMAT = None # Default timeout message in detail field. DETAIL_SEM_TIMEOUT = "Server is busy, please try again later." # Status code for semaphore timeout. @@ -130,7 +134,7 @@ def raise_as_http_exception() -> Generator[Literal[True], None, None]: async def run_llm_streaming( - llm: Llama, req: ChatCompletionRequest + llm: Llama, req: ChatCompletionRequest, *, model_id: str ) -> AsyncGenerator[str, None]: """Generator that runs the LLM and yields SSE chunks under lock.""" with slm_span(req, is_streaming=True) as span: @@ -139,58 +143,79 @@ async def run_llm_streaming( **req.model_dump(), ) - # Use traced iterator that automatically handles chunk spans - # and parent span updates + processor = StreamPostProcessor(model_id=model_id) chunk: CreateChatCompletionStreamResponse for chunk in completion_stream: - set_atrribute_response_stream(span, chunk) - yield f"data: {json.dumps(chunk)}\n\n" - # NOTE: This is a workaround to yield control back to the event loop - # to allow checking for socket after yield and pop in CancelledError. - # Ref: https://github.com/encode/starlette/discussions/1776#discussioncomment-3207518 - await asyncio.sleep(0) + for out_chunk in processor.process_chunk(chunk): + set_atrribute_response_stream(span, out_chunk) + yield f"data: {json.dumps(out_chunk)}\n\n" + # NOTE: yield control back to the event loop so starlette + # can detect client disconnects between chunks. + # Ref: https://github.com/encode/starlette/discussions/1776#discussioncomment-3207518 + await asyncio.sleep(0) + + for out_chunk in processor.flush(): + set_atrribute_response_stream(span, out_chunk) + yield f"data: {json.dumps(out_chunk)}\n\n" yield "data: [DONE]\n\n" -async def run_llm_non_streaming(llm: Llama, req: ChatCompletionRequest): +async def run_llm_non_streaming( + llm: Llama, req: ChatCompletionRequest, *, model_id: str +): """Runs the LLM for a non-streaming request under lock.""" with slm_span(req, is_streaming=False) as span: completion_result = await asyncio.to_thread( llm.create_chat_completion, **req.model_dump(), ) + postprocess_completion(completion_result, model_id=model_id) set_atrribute_response(span, completion_result) return completion_result -@app.post("/api/v1/chat/completions") +@app.post( + "/api/v1/chat/completions", + response_model=ChatCompletionResponse, + responses={ + 200: { + "content": { + STREAM_RESPONSE_MEDIA_TYPE: { + "schema": ChatCompletionChunkResponse.model_json_schema(), + } + }, + }, + }, +) async def create_chat_completion( req: ChatCompletionRequest, llm: Annotated[Llama, Depends(get_llm)], + model_id: Annotated[str, Depends(get_model_id)], _: Annotated[None, Depends(lock_llm_semaphor)], __: Annotated[None, Depends(raise_as_http_exception)], -): +) -> ChatCompletionResponse: """ Generates a chat completion, handling both streaming and non-streaming cases. Concurrency is managed by the `locked_llm_session` context manager. """ if req.stream: return StreamingResponse( - run_llm_streaming(llm, req), media_type=STREAM_RESPONSE_MEDIA_TYPE + run_llm_streaming(llm, req, model_id=model_id), + media_type=STREAM_RESPONSE_MEDIA_TYPE, ) else: - return await run_llm_non_streaming(llm, req) + return await run_llm_non_streaming(llm, req, model_id=model_id) -@app.post("/api/v1/embeddings") +@app.post("/api/v1/embeddings", response_model=EmbeddingResponse) async def create_embeddings( req: EmbeddingRequest, emb_model: Annotated[OnnxEmbeddingModel, Depends(get_embedding_model)], _: Annotated[None, Depends(lock_llm_semaphor)], __: Annotated[None, Depends(raise_as_http_exception)], -): +) -> EmbeddingResponse: """Create embeddings using the dedicated ONNX embedding model.""" with slm_embedding_span(req) as span: inputs = req.input if isinstance(req.input, list) else [req.input] @@ -211,7 +236,6 @@ async def list_models( settings: Annotated[Settings, Depends(get_settings)], ) -> ModelListResponse: """List available models (OpenAI-compatible).""" - chat_model_id = Path(settings.model_path).stem try: chat_created = int(Path(settings.model_path).stat().st_mtime) except (OSError, ValueError): @@ -225,7 +249,7 @@ async def list_models( return ModelListResponse( data=[ ModelInfo( - id=chat_model_id, + id=settings.chat_model_id, created=chat_created, owned_by=settings.model_owner, ), diff --git a/slm_server/config.py b/slm_server/config.py index f0ad876..9942ccc 100644 --- a/slm_server/config.py +++ b/slm_server/config.py @@ -2,6 +2,8 @@ from pydantic import BaseModel, Field from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import Annotated +from fastapi import Depends ENV_PREFIX = "SLM_" @@ -85,6 +87,11 @@ class Settings(BaseSettings): ) model_path: str = Field(MODEL_PATH_DEFAULT, description="Model path for llama_cpp.") + model_id: str = Field( + "", + description="Short model name in API responses (e.g. 'Qwen3-0.6B'). " + "Defaults to the GGUF filename stem when empty.", + ) model_owner: str = Field( MODEL_OWNER_DEFAULT, description="Owner label for /models list. Set SLM_MODEL_OWNER to override.", @@ -103,6 +110,11 @@ class Settings(BaseSettings): 1, description="Seconds to wait if undergoing another inference." ) + @property + def chat_model_id(self) -> str: + """Resolved model identifier: explicit ``model_id`` or GGUF stem.""" + return self.model_id or Path(self.model_path).stem + embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings) logging: LoggingSettings = Field(default_factory=LoggingSettings) metrics: MetricsSettings = Field(default_factory=MetricsSettings) @@ -113,3 +125,6 @@ def get_settings() -> Settings: if not hasattr(get_settings, "_instance"): get_settings._instance = Settings() return get_settings._instance + +def get_model_id(settings: Annotated[Settings, Depends(get_settings)]) -> str: + return settings.chat_model_id diff --git a/slm_server/model.py b/slm_server/model.py index 9168c37..8ffee23 100644 --- a/slm_server/model.py +++ b/slm_server/model.py @@ -1,3 +1,5 @@ +from typing import Any, Literal, Self + from llama_cpp.llama_types import ( ChatCompletionFunction, ChatCompletionRequestFunctionCall, @@ -6,7 +8,12 @@ ChatCompletionTool, ChatCompletionToolChoiceOption, ) -from pydantic import BaseModel, Field, conlist +from pydantic import BaseModel, ConfigDict, Field, conlist, model_validator + + +# --------------------------------------------------------------------------- +# Chat completion request +# --------------------------------------------------------------------------- class ChatCompletionRequest(BaseModel): @@ -81,9 +88,274 @@ class ChatCompletionRequest(BaseModel): default=None, description="Number of top log probabilities to return" ) + @model_validator(mode="after") + def _default_tool_choice_auto(self) -> Self: + """Match OpenAI: default to "auto" when tools are present.""" + if self.tools and self.tool_choice is None: + self.tool_choice = "auto" + return self + + +# --------------------------------------------------------------------------- +# Shared field types — aligned with llama_cpp.llama_types +# --------------------------------------------------------------------------- + +# ChatCompletionResponseMessage.role (non-streaming response) +MessageRole = Literal["assistant", "function"] +# ChatCompletionStreamResponseDelta.role (streaming delta) +DeltaRole = Literal["system", "user", "assistant", "tool"] +# ChatCompletionStreamResponseChoice.finish_reason (both paths; our +# postprocessor can also set "tool_calls") +FinishReason = Literal["stop", "length", "tool_calls", "function_call"] + + +# --------------------------------------------------------------------------- +# Chat completion response (non-streaming) +# --------------------------------------------------------------------------- + + +class ToolCallFunction(BaseModel): + name: str + arguments: str + + +class ToolCall(BaseModel): + id: str + type: Literal["function"] = "function" + function: ToolCallFunction + + +class ChatMessage(BaseModel): + role: MessageRole + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[ToolCall] | None = None + + +class CompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionChoice(BaseModel): + index: int + message: ChatMessage + logprobs: Any | None = None + finish_reason: FinishReason | None = None + + +class ChatCompletionResponse(BaseModel): + model_config = ConfigDict( + json_schema_extra={ + "examples": [ + { + "id": "chatcmpl_MInMfQVLOPebdbjF", + "object": "chat.completion", + "created": 1771828648, + "model": "Qwen3-0.6B-Q4_K_M", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hi there!", + "reasoning_content": "The user wants a greeting.", + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 14, + "completion_tokens": 149, + "total_tokens": 163, + }, + }, + { + "id": "chatcmpl_6HEgD7zMer6W3ob8", + "object": "chat.completion", + "created": 1771828660, + "model": "Qwen3-0.6B-Q4_K_M", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_2HWeVAmARMukopsB", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + } + ], + }, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + "usage": { + "prompt_tokens": 166, + "completion_tokens": 24, + "total_tokens": 190, + }, + }, + ] + } + ) + + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int + model: str + choices: list[ChatCompletionChoice] + usage: CompletionUsage + + +# --------------------------------------------------------------------------- +# Chat completion response (streaming chunks) +# --------------------------------------------------------------------------- + + +class DeltaToolCall(BaseModel): + """Streaming tool call emitted by ``StreamPostProcessor``. + + Our postprocessor accumulates the full ```` block before + emitting, so every field is always present (never partial/incremental). + ``index`` is the ordinal position of this tool call in the response. + """ + + index: int + id: str + type: Literal["function"] = "function" + function: ToolCallFunction + + +class ChatDelta(BaseModel): + role: DeltaRole | None = None + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[DeltaToolCall] | None = None + + +class ChatCompletionChunkChoice(BaseModel): + index: int + delta: ChatDelta + logprobs: Any | None = None + finish_reason: FinishReason | None = None + + +class ChatCompletionChunkResponse(BaseModel): + """Schema for each SSE ``data:`` payload in a streaming chat completion. + + Not used for runtime validation (chunks are pre-serialised dicts), + but exposed here so OpenAPI / Swagger documents the streaming format. + """ + + model_config = ConfigDict( + json_schema_extra={ + "examples": [ + { + "id": "chatcmpl_vGxUUAi7KYLIEaNGb", + "object": "chat.completion.chunk", + "created": 1771828652, + "model": "Qwen3-0.6B-Q4_K_M", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant"}, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl_vGxUUAi7KYLIEaNGb", + "object": "chat.completion.chunk", + "created": 1771828652, + "model": "Qwen3-0.6B-Q4_K_M", + "choices": [ + { + "index": 0, + "delta": {"reasoning_content": "Let me think."}, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl_vGxUUAi7KYLIEaNGb", + "object": "chat.completion.chunk", + "created": 1771828652, + "model": "Qwen3-0.6B-Q4_K_M", + "choices": [ + { + "index": 0, + "delta": {"content": "2 + 2 equals **4**."}, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl_uXmbuefsXElrLqSkb", + "object": "chat.completion.chunk", + "created": 1771828662, + "model": "Qwen3-0.6B-Q4_K_M", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_1MSB4PoE3eerSEJu", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + } + ] + }, + "logprobs": None, + "finish_reason": None, + } + ], + }, + { + "id": "chatcmpl_vGxUUAi7KYLIEaNGb", + "object": "chat.completion.chunk", + "created": 1771828652, + "model": "Qwen3-0.6B-Q4_K_M", + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop", + } + ], + }, + ] + } + ) + + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int + model: str + choices: list[ChatCompletionChunkChoice] + + +# --------------------------------------------------------------------------- +# Embeddings API +# --------------------------------------------------------------------------- -# Embeddings API Models -# OpenAI allows up to 2048 inputs per request. MAX_EMBEDDING_INPUTS = 2048 @@ -93,23 +365,51 @@ class EmbeddingRequest(BaseModel): class EmbeddingData(BaseModel): - object: str = Field(default="embedding") + object: Literal["embedding"] = "embedding" embedding: list[float] index: int class EmbeddingResponse(BaseModel): - object: str = Field(default="list") + model_config = ConfigDict( + json_schema_extra={ + "examples": [ + { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ + -0.046052, + 0.028006, + 0.014284, + 0.025734, + -0.034211, + ], + "index": 0, + } + ], + "model": "all-MiniLM-L6-v2", + } + ] + } + ) + + object: Literal["list"] = "list" data: list[EmbeddingData] model: str -# OpenAI-compatible list models API +# --------------------------------------------------------------------------- +# List models API +# --------------------------------------------------------------------------- + + class ModelInfo(BaseModel): """Single model entry for GET /api/v1/models.""" id: str = Field(description="Model identifier for use in API endpoints") - object: str = Field(default="model", description="Object type") + object: Literal["model"] = "model" created: int = Field(description="Unix timestamp when the model was created") owned_by: str = Field(description="Organization that owns the model") @@ -117,5 +417,29 @@ class ModelInfo(BaseModel): class ModelListResponse(BaseModel): """Response for GET /api/v1/models.""" - object: str = Field(default="list", description="Object type") + model_config = ConfigDict( + json_schema_extra={ + "examples": [ + { + "object": "list", + "data": [ + { + "id": "Qwen3-0.6B-Q4_K_M", + "object": "model", + "created": 1771811606, + "owned_by": "second-state", + }, + { + "id": "all-MiniLM-L6-v2", + "object": "model", + "created": 1771811614, + "owned_by": "sentence-transformers", + }, + ], + } + ] + } + ) + + object: Literal["list"] = "list" data: list[ModelInfo] = Field(description="List of available models") diff --git a/slm_server/utils/__init__.py b/slm_server/utils/__init__.py index 015f0b2..0f5a660 100644 --- a/slm_server/utils/__init__.py +++ b/slm_server/utils/__init__.py @@ -4,3 +4,4 @@ from .processors import * # noqa: F403, F401 from .sampler import * # noqa: F403, F401 from .spans import * # noqa: F403, F401 +from .postprocess import * # noqa: F403, F401 diff --git a/slm_server/utils/constants.py b/slm_server/utils/constants.py index 14247c1..744a21c 100644 --- a/slm_server/utils/constants.py +++ b/slm_server/utils/constants.py @@ -13,6 +13,7 @@ # Event attribute names EVENT_ATTR_CHUNK_SIZE = f"{SPAN_PREFIX}.chunk_size" EVENT_ATTR_CHUNK_CONTENT_SIZE = f"{SPAN_PREFIX}.chunk_content_size" +EVENT_ATTR_FINISH_REASON = f"{SPAN_PREFIX}.finish_reason" # Attribute names ATTR_MODEL = f"{SPAN_PREFIX}.model" @@ -27,6 +28,7 @@ ATTR_PROMPT_TOKENS = f"{SPAN_PREFIX}.usage.prompt_tokens" ATTR_COMPLETION_TOKENS = f"{SPAN_PREFIX}.usage.completion_tokens" ATTR_TOTAL_TOKENS = f"{SPAN_PREFIX}.usage.total_tokens" +ATTR_FINISH_REASON = f"{SPAN_PREFIX}.finish_reason" ATTR_FORCE_SAMPLE = f"{SPAN_PREFIX}.force_sample" # Embedding attributes diff --git a/slm_server/utils/ids.py b/slm_server/utils/ids.py new file mode 100644 index 0000000..d75c14d --- /dev/null +++ b/slm_server/utils/ids.py @@ -0,0 +1,20 @@ +"""Compact ID generation for OpenAI-compatible responses.""" + +from __future__ import annotations + +import os +import string + +_LENGTH=12 +_ALPHABET = string.ascii_letters + string.digits +_BASE = len(_ALPHABET) + + +def gen_id(prefix: str) -> str: + """Return ``prefix_``, e.g. ``chatcmpl_BxK9mQ7r2pNw``.""" + raw = int.from_bytes(os.urandom(_LENGTH)) + chars: list[str] = [] + while raw: + raw, idx = divmod(raw, _BASE) + chars.append(_ALPHABET[idx]) + return f"{prefix}_{''.join(chars)}" diff --git a/slm_server/utils/postprocess.py b/slm_server/utils/postprocess.py new file mode 100644 index 0000000..6b6fe2e --- /dev/null +++ b/slm_server/utils/postprocess.py @@ -0,0 +1,357 @@ +"""Post-process raw model output into OpenAI-compatible response structures. + +Handles two concerns for both non-streaming and streaming responses: +1. **Reasoning** – extracts ```` blocks into ``reasoning_content`` + (de facto standard from DeepSeek, supported by LangChain / LiteLLM / OpenRouter). +2. **Tool calls** – extracts ```` blocks into structured ``tool_calls``. +""" + +from __future__ import annotations + +import copy +import json +import re +from enum import Enum, auto +from typing import Any + +from slm_server.utils.ids import gen_id + +_THINK_RE = re.compile(r"(.*?)", re.DOTALL) +_TOOL_CALL_RE = re.compile( + r"\s*(\{.*?\})\s*", re.DOTALL +) + + +def extract_reasoning(content: str) -> tuple[str | None, str]: + """Split ```` blocks from visible content. + + Returns (reasoning_content, cleaned_content): + * *reasoning_content* – concatenated think-block text, or ``None``. + * *cleaned_content* – original text with think blocks removed. + """ + matches = _THINK_RE.findall(content) + if not matches: + return None, content + + reasoning = "\n".join(m.strip() for m in matches) + cleaned = _THINK_RE.sub("", content).strip() + return reasoning or None, cleaned + + +def parse_tool_calls( + content: str, +) -> tuple[list[dict[str, Any]], str | None]: + """Extract ```` blocks from raw model output. + + Returns (tool_calls, remaining_content): + * *tool_calls* – list of OpenAI-style tool-call dicts, empty when none found. + * *remaining_content* – text left after stripping tool_call blocks, + or ``None`` when nothing meaningful remains. + """ + matches = _TOOL_CALL_RE.findall(content) + if not matches: + return [], content + + tool_calls: list[dict[str, Any]] = [] + for raw_json in matches: + parsed = json.loads(raw_json) + arguments = parsed.get("arguments", {}) + tool_calls.append( + { + "id": gen_id("call"), + "type": "function", + "function": { + "name": parsed["name"], + "arguments": json.dumps(arguments) + if not isinstance(arguments, str) + else arguments, + }, + } + ) + + remaining = _TOOL_CALL_RE.sub("", content).strip() + return tool_calls, remaining or None + + +def postprocess_completion( + response: dict[str, Any], + *, + model_id: str | None = None, +) -> dict[str, Any]: + """Rewrite a chat-completion response in place: + + 1. Normalise ``id`` and ``model`` fields. + 2. Extract ```` blocks → ``message.reasoning_content`` + 3. Extract ```` blocks → ``message.tool_calls`` + """ + response["id"] = gen_id("chatcmpl") + if model_id: + response["model"] = model_id + + for choice in response.get("choices", []): + message = choice.get("message") + if message is None or not message.get("content"): + continue + + content = message["content"] + + reasoning, content = extract_reasoning(content) + if reasoning is not None: + message["reasoning_content"] = reasoning + + tool_calls, content = parse_tool_calls(content) + if tool_calls: + message["tool_calls"] = tool_calls + choice["finish_reason"] = "tool_calls" + + message["content"] = content + + return response + + +# --------------------------------------------------------------------------- +# Streaming post-processing +# --------------------------------------------------------------------------- + +_KNOWN_TAGS = frozenset({"", "", "", ""}) +_MAX_TAG_LEN = max(len(t) for t in _KNOWN_TAGS) + + +class _State(Enum): + CONTENT = auto() + THINKING = auto() + TOOL_CALL = auto() + + +class StreamPostProcessor: + """State machine that rewrites streaming chunks in-flight. + + Raw ``delta.content`` tokens containing ````/```` tags + are intercepted and re-emitted as ``delta.reasoning_content`` or + ``delta.tool_calls`` respectively. + """ + + _TAG_TRANSITIONS: dict[str, _State] = { + "": _State.THINKING, + "": _State.CONTENT, + "": _State.TOOL_CALL, + "": _State.CONTENT, + } + + def __init__(self, *, model_id: str | None = None) -> None: + self._state = _State.CONTENT + self._tag_buf = "" + self._tool_call_buf = "" + self._had_tool_calls = False + self._tool_call_index = 0 + self._last_chunk_template: dict[str, Any] | None = None + self._strip_leading = False + self._model_id = model_id + self._stream_id = gen_id("chatcmpl") + + # -- public API ---------------------------------------------------------- + + def process_chunk( + self, chunk: dict[str, Any] + ) -> list[dict[str, Any]]: + """Process one streaming chunk, returning zero or more output chunks.""" + self._last_chunk_template = chunk + + delta = _get_delta(chunk) + if delta is None: + return [self._stamp(chunk)] + + text = delta.get("content") + if text is None: + finish = _get_finish_reason(chunk) + if finish is not None and self._had_tool_calls: + return [self._rewrite_finish(chunk)] + return [self._stamp(chunk)] + + result = self._consume_text(text, chunk) + + # Propagate finish_reason from the original chunk onto the last output + orig_finish = _get_finish_reason(chunk) + if orig_finish is not None: + if result: + result[-1]["choices"][0]["finish_reason"] = ( + "tool_calls" if self._had_tool_calls else orig_finish + ) + else: + result.append(self._rewrite_finish(chunk) if self._had_tool_calls + else self._stamp(copy.deepcopy(chunk))) + + return result + + def flush(self) -> list[dict[str, Any]]: + """Flush any buffered content at end-of-stream.""" + if not self._tag_buf or self._last_chunk_template is None: + return [] + text = self._tag_buf + self._tag_buf = "" + return self._emit_batch(text) + + # -- internals ----------------------------------------------------------- + + def _consume_text( + self, text: str, chunk: dict[str, Any] + ) -> list[dict[str, Any]]: + """Feed *text* through the tag-detection state machine. + + Batches consecutive characters for the same state into single chunks + to avoid per-character emission. + """ + output: list[dict[str, Any]] = [] + batch = "" + pos = 0 + + while pos < len(text): + ch = text[pos] + + if ch == "<" and not self._tag_buf: + if batch: + output.extend(self._emit_batch(batch)) + batch = "" + self._tag_buf = ch + pos += 1 + continue + + if self._tag_buf: + self._tag_buf += ch + pos += 1 + + matched_tag = self._try_match_tag() + if matched_tag is not None: + new_state = self._TAG_TRANSITIONS[matched_tag] + if ( + matched_tag == "" + and self._state == _State.TOOL_CALL + ): + output.extend(self._finish_tool_call(chunk)) + self._state = new_state + self._tag_buf = "" + self._strip_leading = True + elif not self._could_be_tag(): + batch += self._tag_buf + self._tag_buf = "" + continue + + batch += ch + pos += 1 + + if batch: + output.extend(self._emit_batch(batch)) + + return output + + def _try_match_tag(self) -> str | None: + """Return the tag string if ``_tag_buf`` exactly matches one.""" + if self._tag_buf in _KNOWN_TAGS: + return self._tag_buf + return None + + def _could_be_tag(self) -> bool: + """Check if ``_tag_buf`` is a valid prefix of any known tag.""" + return any(t.startswith(self._tag_buf) for t in _KNOWN_TAGS) + + def _emit_batch(self, text: str) -> list[dict[str, Any]]: + """Emit *text* as a single chunk for the current state.""" + if not text or self._last_chunk_template is None: + return [] + + if self._strip_leading: + text = text.lstrip() + if not text: + return [] + self._strip_leading = False + + if self._state == _State.CONTENT: + return [self._make_chunk(content=text)] + elif self._state == _State.THINKING: + return [self._make_chunk(reasoning_content=text)] + else: + self._tool_call_buf += text + return [] + + def _finish_tool_call( + self, chunk: dict[str, Any] + ) -> list[dict[str, Any]]: + """Parse the accumulated JSON and emit a ``tool_calls`` delta.""" + raw = self._tool_call_buf.strip() + self._tool_call_buf = "" + + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return [self._make_chunk(content=raw)] + + arguments = parsed.get("arguments", {}) + tool_id = gen_id("call") + idx = self._tool_call_index + self._tool_call_index += 1 + self._had_tool_calls = True + + return [ + self._make_chunk( + tool_calls=[ + { + "index": idx, + "id": tool_id, + "type": "function", + "function": { + "name": parsed["name"], + "arguments": json.dumps(arguments) + if not isinstance(arguments, str) + else arguments, + }, + } + ] + ) + ] + + def _stamp(self, chunk: dict[str, Any]) -> dict[str, Any]: + """Apply consistent id/model to a passthrough chunk (mutates).""" + chunk["id"] = self._stream_id + if self._model_id: + chunk["model"] = self._model_id + return chunk + + def _rewrite_finish(self, chunk: dict[str, Any]) -> dict[str, Any]: + out = copy.deepcopy(chunk) + out["id"] = self._stream_id + if self._model_id: + out["model"] = self._model_id + choice = out["choices"][0] + choice["finish_reason"] = "tool_calls" + return out + + def _make_chunk(self, **delta_fields: Any) -> dict[str, Any]: + """Build an output chunk from the last seen template.""" + out = copy.deepcopy(self._last_chunk_template) + out["id"] = self._stream_id + if self._model_id: + out["model"] = self._model_id + choice = out["choices"][0] + # Preserve role from original delta if present + orig_delta = choice.get("delta", {}) + new_delta: dict[str, Any] = {} + if "role" in orig_delta and not delta_fields.get("tool_calls"): + new_delta["role"] = orig_delta["role"] + new_delta.update(delta_fields) + choice["delta"] = new_delta + choice["finish_reason"] = None + return out + + +def _get_delta(chunk: dict[str, Any]) -> dict[str, Any] | None: + choices = chunk.get("choices") + if not choices: + return None + return choices[0].get("delta") + + +def _get_finish_reason(chunk: dict[str, Any]) -> str | None: + choices = chunk.get("choices") + if not choices: + return None + return choices[0].get("finish_reason") diff --git a/slm_server/utils/spans.py b/slm_server/utils/spans.py index 9cbf34e..88d8baf 100644 --- a/slm_server/utils/spans.py +++ b/slm_server/utils/spans.py @@ -20,6 +20,7 @@ from .constants import ( ATTR_CHUNK_COUNT, ATTR_COMPLETION_TOKENS, + ATTR_FINISH_REASON, ATTR_FORCE_SAMPLE, ATTR_INPUT_CONTENT_LENGTH, ATTR_INPUT_COUNT, @@ -35,6 +36,7 @@ EMBEDDING_MODEL_NAME, EVENT_ATTR_CHUNK_CONTENT_SIZE, EVENT_ATTR_CHUNK_SIZE, + EVENT_ATTR_FINISH_REASON, EVENT_CHUNK_GENERATED, MODEL_NAME, SPAN_CHAT_COMPLETION, @@ -48,9 +50,7 @@ def set_atrribute_response(span: Span, response: ChatCompletionResponse | dict): """Set response attributes automatically.""" - # Non-streaming response - handle both dict and object responses if isinstance(response, dict): - # Handle dict response usage = response.get("usage") if usage: span.set_attribute(ATTR_PROMPT_TOKENS, usage.get("prompt_tokens", 0)) @@ -60,19 +60,26 @@ def set_atrribute_response(span: Span, response: ChatCompletionResponse | dict): span.set_attribute(ATTR_TOTAL_TOKENS, usage.get("total_tokens", 0)) choices = response.get("choices", []) - if choices and choices[0].get("message"): - content = choices[0]["message"].get("content") or "" - span.set_attribute(ATTR_OUTPUT_CONTENT_LENGTH, len(content)) + if choices: + if choices[0].get("message"): + content = choices[0]["message"].get("content") or "" + span.set_attribute(ATTR_OUTPUT_CONTENT_LENGTH, len(content)) + finish = choices[0].get("finish_reason") + if finish: + span.set_attribute(ATTR_FINISH_REASON, finish) else: - # Handle object response (original code) if response.usage: span.set_attribute(ATTR_PROMPT_TOKENS, response.usage.prompt_tokens) span.set_attribute(ATTR_COMPLETION_TOKENS, response.usage.completion_tokens) span.set_attribute(ATTR_TOTAL_TOKENS, response.usage.total_tokens) - if response.choices and response.choices[0].message: - content = response.choices[0].message.content or "" - span.set_attribute(ATTR_OUTPUT_CONTENT_LENGTH, len(content)) + if response.choices: + if response.choices[0].message: + content = response.choices[0].message.content or "" + span.set_attribute(ATTR_OUTPUT_CONTENT_LENGTH, len(content)) + if response.choices[0].finish_reason: + reason = response.choices[0].finish_reason + span.set_attribute(ATTR_FINISH_REASON, reason) def set_atrribute_response_stream( @@ -80,31 +87,31 @@ def set_atrribute_response_stream( ): """Record streaming chunk as an event and accumulate tokens.""" chunk_content = "" + finish_reason = "" if isinstance(response, dict): - # Handle dict response choices = response.get("choices", []) if choices and choices[0].get("delta") and choices[0]["delta"].get("content"): chunk_content = choices[0]["delta"]["content"] - chunk_json = str(response) # Simple string representation for dict + if choices: + finish_reason = choices[0].get("finish_reason") or "" + chunk_json = str(response) else: - # Handle object response (original code) if ( response.choices and response.choices[0].delta and response.choices[0].delta.content ): chunk_content = response.choices[0].delta.content + if response.choices: + finish_reason = response.choices[0].finish_reason or "" chunk_json = response.model_dump_json() - # Record chunk as an event chunk_event = { EVENT_ATTR_CHUNK_SIZE: len(chunk_json), EVENT_ATTR_CHUNK_CONTENT_SIZE: len(chunk_content), - # EVENT_ATTR_CHUNK_CONTENT: chunk_content, - # EVENT_ATTR_FINISH_REASON: response.choices[0].finish_reason or 0 - # if response.choices - # else None, } + if finish_reason: + chunk_event[EVENT_ATTR_FINISH_REASON] = finish_reason span.add_event(EVENT_CHUNK_GENERATED, chunk_event) # Only count chunks with actual content diff --git a/swagger/openapi.yaml b/swagger/openapi.yaml new file mode 100644 index 0000000..7ea8348 --- /dev/null +++ b/swagger/openapi.yaml @@ -0,0 +1,1080 @@ +openapi: 3.1.0 +info: + title: OpenAI-compatible SLM Server + description: A simple API server for serving a Small Language Model, compatible + with the OpenAI Chat Completions format. + version: 0.1.0 +paths: + /metrics: + get: + summary: Metrics + description: Endpoint that serves Prometheus metrics. + operationId: metrics_metrics_get + responses: + '200': + description: Successful Response + content: + application/json: + schema: {} + /api/v1/chat/completions: + post: + summary: Create Chat Completion + description: 'Generates a chat completion, handling both streaming and non-streaming + cases. + + Concurrency is managed by the `locked_llm_session` context manager.' + operationId: create_chat_completion_api_v1_chat_completions_post + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ChatCompletionRequest' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ChatCompletionResponse' + text/event-stream: + schema: + $defs: + ChatCompletionChunkChoice: + properties: + index: + type: integer + title: Index + delta: + $ref: '#/$defs/ChatDelta' + logprobs: + anyOf: + - {} + - type: 'null' + title: Logprobs + finish_reason: + anyOf: + - type: string + enum: + - stop + - length + - tool_calls + - function_call + - type: 'null' + title: Finish Reason + type: object + required: + - index + - delta + title: ChatCompletionChunkChoice + ChatDelta: + properties: + role: + anyOf: + - type: string + enum: + - system + - user + - assistant + - tool + - type: 'null' + title: Role + content: + anyOf: + - type: string + - type: 'null' + title: Content + reasoning_content: + anyOf: + - type: string + - type: 'null' + title: Reasoning Content + tool_calls: + anyOf: + - items: + $ref: '#/$defs/DeltaToolCall' + type: array + - type: 'null' + title: Tool Calls + type: object + title: ChatDelta + DeltaToolCall: + properties: + index: + type: integer + title: Index + id: + type: string + title: Id + type: + type: string + const: function + title: Type + default: function + function: + $ref: '#/$defs/ToolCallFunction' + type: object + required: + - index + - id + - function + title: DeltaToolCall + description: 'Streaming tool call emitted by ``StreamPostProcessor``. + + + Our postprocessor accumulates the full ```` block + before + + emitting, so every field is always present (never partial/incremental). + + ``index`` is the ordinal position of this tool call in the response.' + ToolCallFunction: + properties: + name: + type: string + title: Name + arguments: + type: string + title: Arguments + type: object + required: + - name + - arguments + title: ToolCallFunction + properties: + id: + type: string + title: Id + object: + type: string + const: chat.completion.chunk + title: Object + default: chat.completion.chunk + created: + type: integer + title: Created + model: + type: string + title: Model + choices: + items: + $ref: '#/$defs/ChatCompletionChunkChoice' + type: array + title: Choices + type: object + required: + - id + - created + - model + - choices + title: ChatCompletionChunkResponse + description: 'Schema for each SSE ``data:`` payload in a streaming + chat completion. + + + Not used for runtime validation (chunks are pre-serialised dicts), + + but exposed here so OpenAPI / Swagger documents the streaming format.' + examples: + - choices: + - delta: + role: assistant + index: 0 + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: + reasoning_content: Let me think. + index: 0 + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: + content: 2 + 2 equals **4**. + index: 0 + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: + tool_calls: + - function: + arguments: '{"location": "San Francisco"}' + name: get_weather + id: call_1MSB4PoE3eerSEJu + index: 0 + type: function + index: 0 + created: 1771828662 + id: chatcmpl_uXmbuefsXElrLqSkb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: {} + finish_reason: stop + index: 0 + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/v1/embeddings: + post: + summary: Create Embeddings + description: Create embeddings using the dedicated ONNX embedding model. + operationId: create_embeddings_api_v1_embeddings_post + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/EmbeddingRequest' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/EmbeddingResponse' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/v1/models: + get: + summary: List Models + description: List available models (OpenAI-compatible). + operationId: list_models_api_v1_models_get + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ModelListResponse' + /health: + get: + summary: Health + operationId: health_health_get + responses: + '200': + description: Successful Response + content: + application/json: + schema: {} +components: + schemas: + ChatCompletionChoice: + properties: + index: + type: integer + title: Index + message: + $ref: '#/components/schemas/ChatMessage' + logprobs: + anyOf: + - {} + - type: 'null' + title: Logprobs + finish_reason: + anyOf: + - type: string + enum: + - stop + - length + - tool_calls + - function_call + - type: 'null' + title: Finish Reason + type: object + required: + - index + - message + title: ChatCompletionChoice + ChatCompletionFunction: + properties: + name: + type: string + title: Name + description: + type: string + title: Description + parameters: + additionalProperties: + anyOf: + - type: integer + - type: string + - type: boolean + - items: {} + type: array + - additionalProperties: true + type: object + - type: 'null' + type: object + title: Parameters + type: object + required: + - name + - parameters + title: ChatCompletionFunction + ChatCompletionMessageToolCall: + properties: + id: + type: string + title: Id + type: + type: string + const: function + title: Type + function: + $ref: '#/components/schemas/ChatCompletionMessageToolCallFunction' + type: object + required: + - id + - type + - function + title: ChatCompletionMessageToolCall + ChatCompletionMessageToolCallFunction: + properties: + name: + type: string + title: Name + arguments: + type: string + title: Arguments + type: object + required: + - name + - arguments + title: ChatCompletionMessageToolCallFunction + ChatCompletionNamedToolChoice: + properties: + type: + type: string + const: function + title: Type + function: + $ref: '#/components/schemas/ChatCompletionNamedToolChoiceFunction' + type: object + required: + - type + - function + title: ChatCompletionNamedToolChoice + ChatCompletionNamedToolChoiceFunction: + properties: + name: + type: string + title: Name + type: object + required: + - name + title: ChatCompletionNamedToolChoiceFunction + ChatCompletionRequest: + properties: + messages: + items: + anyOf: + - $ref: '#/components/schemas/ChatCompletionRequestSystemMessage' + - $ref: '#/components/schemas/ChatCompletionRequestUserMessage' + - $ref: '#/components/schemas/ChatCompletionRequestAssistantMessage' + - $ref: '#/components/schemas/ChatCompletionRequestToolMessage' + - $ref: '#/components/schemas/ChatCompletionRequestFunctionMessage' + type: array + title: Messages + description: List of chat completion messages in the conversation + functions: + anyOf: + - items: + $ref: '#/components/schemas/ChatCompletionFunction' + type: array + - type: 'null' + title: Functions + description: List of functions available for the model to call + function_call: + anyOf: + - type: string + enum: + - none + - auto + - $ref: '#/components/schemas/ChatCompletionRequestFunctionCallOption' + - type: 'null' + title: Function Call + description: Controls which function the model should call + tools: + anyOf: + - items: + $ref: '#/components/schemas/ChatCompletionTool' + type: array + - type: 'null' + title: Tools + description: List of tools available for the model to use + tool_choice: + anyOf: + - type: string + enum: + - none + - auto + - required + - $ref: '#/components/schemas/ChatCompletionNamedToolChoice' + - type: 'null' + title: Tool Choice + description: Controls which tool the model should use + temperature: + type: number + title: Temperature + description: Sampling temperature (0.0 to 2.0) + default: 0.2 + top_p: + type: number + title: Top P + description: Nucleus sampling parameter + default: 0.95 + top_k: + type: integer + title: Top K + description: Top-k sampling parameter + default: 40 + min_p: + type: number + title: Min P + description: Minimum probability threshold for sampling + default: 0.05 + typical_p: + type: number + title: Typical P + description: Typical sampling parameter + default: 1.0 + stream: + type: boolean + title: Stream + description: Whether to stream the response + default: false + stop: + anyOf: + - type: string + - items: + type: string + type: array + - type: 'null' + title: Stop + description: Stop sequences to end generation + seed: + anyOf: + - type: integer + - type: 'null' + title: Seed + description: Random seed for reproducible generation + response_format: + anyOf: + - $ref: '#/components/schemas/ChatCompletionRequestResponseFormat' + - type: 'null' + description: Response format specification + max_tokens: + anyOf: + - type: integer + - type: 'null' + title: Max Tokens + description: Maximum number of tokens to generate + presence_penalty: + type: number + title: Presence Penalty + description: Presence penalty (-2.0 to 2.0) + default: 0.0 + frequency_penalty: + type: number + title: Frequency Penalty + description: Frequency penalty (-2.0 to 2.0) + default: 0.0 + repeat_penalty: + type: number + title: Repeat Penalty + description: Repetition penalty (1.0 = no penalty) + default: 1.0 + tfs_z: + type: number + title: Tfs Z + description: Tail free sampling parameter + default: 1.0 + mirostat_mode: + type: integer + title: Mirostat Mode + description: Mirostat sampling mode (0=disabled, 1=v1, 2=v2) + default: 0 + mirostat_tau: + type: number + title: Mirostat Tau + description: Mirostat target entropy + default: 5.0 + mirostat_eta: + type: number + title: Mirostat Eta + description: Mirostat learning rate + default: 0.1 + model: + anyOf: + - type: string + - type: 'null' + title: Model + description: Model identifier + logit_bias: + anyOf: + - additionalProperties: + type: number + type: object + - type: 'null' + title: Logit Bias + description: Logit bias adjustments for specific tokens + logprobs: + anyOf: + - type: boolean + - type: 'null' + title: Logprobs + description: Whether to return log probabilities + top_logprobs: + anyOf: + - type: integer + - type: 'null' + title: Top Logprobs + description: Number of top log probabilities to return + type: object + required: + - messages + title: ChatCompletionRequest + ChatCompletionRequestAssistantMessage: + properties: + role: + type: string + const: assistant + title: Role + content: + type: string + title: Content + tool_calls: + items: + $ref: '#/components/schemas/ChatCompletionMessageToolCall' + type: array + title: Tool Calls + function_call: + $ref: '#/components/schemas/ChatCompletionRequestAssistantMessageFunctionCall' + type: object + required: + - role + title: ChatCompletionRequestAssistantMessage + ChatCompletionRequestAssistantMessageFunctionCall: + properties: + name: + type: string + title: Name + arguments: + type: string + title: Arguments + type: object + required: + - name + - arguments + title: ChatCompletionRequestAssistantMessageFunctionCall + ChatCompletionRequestFunctionCallOption: + properties: + name: + type: string + title: Name + type: object + required: + - name + title: ChatCompletionRequestFunctionCallOption + ChatCompletionRequestFunctionMessage: + properties: + role: + type: string + const: function + title: Role + content: + anyOf: + - type: string + - type: 'null' + title: Content + name: + type: string + title: Name + type: object + required: + - role + - content + - name + title: ChatCompletionRequestFunctionMessage + ChatCompletionRequestMessageContentPartImage: + properties: + type: + type: string + const: image_url + title: Type + image_url: + anyOf: + - type: string + - $ref: '#/components/schemas/ChatCompletionRequestMessageContentPartImageImageUrl' + title: Image Url + type: object + required: + - type + - image_url + title: ChatCompletionRequestMessageContentPartImage + ChatCompletionRequestMessageContentPartImageImageUrl: + properties: + url: + type: string + title: Url + detail: + type: string + enum: + - auto + - low + - high + title: Detail + type: object + required: + - url + title: ChatCompletionRequestMessageContentPartImageImageUrl + ChatCompletionRequestMessageContentPartText: + properties: + type: + type: string + const: text + title: Type + text: + type: string + title: Text + type: object + required: + - type + - text + title: ChatCompletionRequestMessageContentPartText + ChatCompletionRequestResponseFormat: + properties: + type: + type: string + enum: + - text + - json_object + title: Type + schema: + anyOf: + - type: integer + - type: string + - type: boolean + - items: {} + type: array + - additionalProperties: true + type: object + - type: 'null' + title: Schema + type: object + required: + - type + title: ChatCompletionRequestResponseFormat + ChatCompletionRequestSystemMessage: + properties: + role: + type: string + const: system + title: Role + content: + anyOf: + - type: string + - type: 'null' + title: Content + type: object + required: + - role + - content + title: ChatCompletionRequestSystemMessage + ChatCompletionRequestToolMessage: + properties: + role: + type: string + const: tool + title: Role + content: + anyOf: + - type: string + - type: 'null' + title: Content + tool_call_id: + type: string + title: Tool Call Id + type: object + required: + - role + - content + - tool_call_id + title: ChatCompletionRequestToolMessage + ChatCompletionRequestUserMessage: + properties: + role: + type: string + const: user + title: Role + content: + anyOf: + - type: string + - items: + anyOf: + - $ref: '#/components/schemas/ChatCompletionRequestMessageContentPartText' + - $ref: '#/components/schemas/ChatCompletionRequestMessageContentPartImage' + type: array + - type: 'null' + title: Content + type: object + required: + - role + - content + title: ChatCompletionRequestUserMessage + ChatCompletionResponse: + properties: + id: + type: string + title: Id + object: + type: string + const: chat.completion + title: Object + default: chat.completion + created: + type: integer + title: Created + model: + type: string + title: Model + choices: + items: + $ref: '#/components/schemas/ChatCompletionChoice' + type: array + title: Choices + usage: + $ref: '#/components/schemas/CompletionUsage' + type: object + required: + - id + - created + - model + - choices + - usage + title: ChatCompletionResponse + examples: + - choices: + - finish_reason: stop + index: 0 + message: + content: Hi there! + reasoning_content: The user wants a greeting. + role: assistant + created: 1771828648 + id: chatcmpl_MInMfQVLOPebdbjF + model: Qwen3-0.6B-Q4_K_M + object: chat.completion + usage: + completion_tokens: 149 + prompt_tokens: 14 + total_tokens: 163 + - choices: + - finish_reason: tool_calls + index: 0 + message: + role: assistant + tool_calls: + - function: + arguments: '{"location": "San Francisco"}' + name: get_weather + id: call_2HWeVAmARMukopsB + type: function + created: 1771828660 + id: chatcmpl_6HEgD7zMer6W3ob8 + model: Qwen3-0.6B-Q4_K_M + object: chat.completion + usage: + completion_tokens: 24 + prompt_tokens: 166 + total_tokens: 190 + ChatCompletionTool: + properties: + type: + type: string + const: function + title: Type + function: + $ref: '#/components/schemas/ChatCompletionToolFunction' + type: object + required: + - type + - function + title: ChatCompletionTool + ChatCompletionToolFunction: + properties: + name: + type: string + title: Name + description: + type: string + title: Description + parameters: + additionalProperties: + anyOf: + - type: integer + - type: string + - type: boolean + - items: {} + type: array + - additionalProperties: true + type: object + - type: 'null' + type: object + title: Parameters + type: object + required: + - name + - parameters + title: ChatCompletionToolFunction + ChatMessage: + properties: + role: + type: string + enum: + - assistant + - function + title: Role + content: + anyOf: + - type: string + - type: 'null' + title: Content + reasoning_content: + anyOf: + - type: string + - type: 'null' + title: Reasoning Content + tool_calls: + anyOf: + - items: + $ref: '#/components/schemas/ToolCall' + type: array + - type: 'null' + title: Tool Calls + type: object + required: + - role + title: ChatMessage + CompletionUsage: + properties: + prompt_tokens: + type: integer + title: Prompt Tokens + completion_tokens: + type: integer + title: Completion Tokens + total_tokens: + type: integer + title: Total Tokens + type: object + required: + - prompt_tokens + - completion_tokens + - total_tokens + title: CompletionUsage + EmbeddingData: + properties: + object: + type: string + const: embedding + title: Object + default: embedding + embedding: + items: + type: number + type: array + title: Embedding + index: + type: integer + title: Index + type: object + required: + - embedding + - index + title: EmbeddingData + EmbeddingRequest: + properties: + input: + anyOf: + - type: string + - items: + type: string + type: array + maxItems: 2048 + title: Input + model: + anyOf: + - type: string + - type: 'null' + title: Model + description: Model identifier + type: object + required: + - input + title: EmbeddingRequest + EmbeddingResponse: + properties: + object: + type: string + const: list + title: Object + default: list + data: + items: + $ref: '#/components/schemas/EmbeddingData' + type: array + title: Data + model: + type: string + title: Model + type: object + required: + - data + - model + title: EmbeddingResponse + examples: + - data: + - embedding: + - -0.046052 + - 0.028006 + - 0.014284 + - 0.025734 + - -0.034211 + index: 0 + object: embedding + model: all-MiniLM-L6-v2 + object: list + HTTPValidationError: + properties: + detail: + items: + $ref: '#/components/schemas/ValidationError' + type: array + title: Detail + type: object + title: HTTPValidationError + ModelInfo: + properties: + id: + type: string + title: Id + description: Model identifier for use in API endpoints + object: + type: string + const: model + title: Object + default: model + created: + type: integer + title: Created + description: Unix timestamp when the model was created + owned_by: + type: string + title: Owned By + description: Organization that owns the model + type: object + required: + - id + - created + - owned_by + title: ModelInfo + description: Single model entry for GET /api/v1/models. + ModelListResponse: + properties: + object: + type: string + const: list + title: Object + default: list + data: + items: + $ref: '#/components/schemas/ModelInfo' + type: array + title: Data + description: List of available models + type: object + required: + - data + title: ModelListResponse + description: Response for GET /api/v1/models. + examples: + - data: + - created: 1771811606 + id: Qwen3-0.6B-Q4_K_M + object: model + owned_by: second-state + - created: 1771811614 + id: all-MiniLM-L6-v2 + object: model + owned_by: sentence-transformers + object: list + ToolCall: + properties: + id: + type: string + title: Id + type: + type: string + const: function + title: Type + default: function + function: + $ref: '#/components/schemas/ToolCallFunction' + type: object + required: + - id + - function + title: ToolCall + ToolCallFunction: + properties: + name: + type: string + title: Name + arguments: + type: string + title: Arguments + type: object + required: + - name + - arguments + title: ToolCallFunction + ValidationError: + properties: + loc: + items: + anyOf: + - type: string + - type: integer + type: array + title: Location + msg: + type: string + title: Message + type: + type: string + title: Error Type + type: object + required: + - loc + - msg + - type + title: ValidationError diff --git a/tests/test_app.py b/tests/test_app.py index 91ab31a..ca98152 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -204,7 +204,7 @@ def test_streaming_stops_on_client_disconnect(): cancellation_triggered = False - async def mock_run_llm_streaming_with_cancellation(llm, req): + async def mock_run_llm_streaming_with_cancellation(llm, req, *, model_id="test"): """Mock that yields some chunks then gets cancelled by client disconnect.""" nonlocal cancellation_triggered from slm_server.utils.spans import slm_span, set_atrribute_response_stream diff --git a/tests/test_postprocess.py b/tests/test_postprocess.py new file mode 100644 index 0000000..3635ae8 --- /dev/null +++ b/tests/test_postprocess.py @@ -0,0 +1,503 @@ +"""Tests for response post-processing: reasoning extraction and tool call parsing.""" + +import json + +from slm_server.utils.ids import gen_id +from slm_server.utils.postprocess import ( + StreamPostProcessor, + extract_reasoning, + parse_tool_calls, + postprocess_completion, +) + + +class TestExtractReasoning: + + def test_extracts_reasoning(self): + content = "\nStep 1: analyze.\nStep 2: respond.\n\n\nHello!" + reasoning, cleaned = extract_reasoning(content) + + assert reasoning == "Step 1: analyze.\nStep 2: respond." + assert cleaned == "Hello!" + + def test_empty_think_block(self): + content = "\n\n\n\nHello!" + reasoning, cleaned = extract_reasoning(content) + + assert reasoning is None + assert cleaned == "Hello!" + + def test_no_think_block(self): + content = "Just a regular response." + reasoning, cleaned = extract_reasoning(content) + + assert reasoning is None + assert cleaned == content + + def test_multiple_think_blocks(self): + content = ( + "First thought.\n" + "Middle text.\n" + "Second thought.\n" + "Final answer." + ) + reasoning, cleaned = extract_reasoning(content) + + assert reasoning == "First thought.\nSecond thought." + assert "Middle text." in cleaned + assert "Final answer." in cleaned + + def test_think_only_no_content(self): + content = "Just thinking..." + reasoning, cleaned = extract_reasoning(content) + + assert reasoning == "Just thinking..." + assert cleaned == "" + + +class TestParseToolCalls: + + def test_single_tool_call(self): + content = ( + '\n' + '{"name": "get_weather", "arguments": {"location": "San Francisco"}}\n' + '' + ) + tool_calls, remaining = parse_tool_calls(content) + + assert len(tool_calls) == 1 + tc = tool_calls[0] + assert tc["type"] == "function" + assert tc["function"]["name"] == "get_weather" + assert json.loads(tc["function"]["arguments"]) == { + "location": "San Francisco" + } + assert tc["id"].startswith("call_") + assert remaining is None + + def test_multiple_tool_calls(self): + content = ( + '\n' + '{"name": "get_weather", "arguments": {"location": "SF"}}\n' + '\n' + '\n' + '{"name": "get_time", "arguments": {"timezone": "PST"}}\n' + '' + ) + tool_calls, remaining = parse_tool_calls(content) + + assert len(tool_calls) == 2 + assert tool_calls[0]["function"]["name"] == "get_weather" + assert tool_calls[1]["function"]["name"] == "get_time" + assert tool_calls[0]["id"] != tool_calls[1]["id"] + + def test_no_tool_calls(self): + content = "The weather in San Francisco is sunny today." + tool_calls, remaining = parse_tool_calls(content) + + assert tool_calls == [] + assert remaining == content + + def test_tool_call_with_remaining_content(self): + content = ( + "Here is the result:\n" + '\n' + '{"name": "search", "arguments": {"query": "test"}}\n' + '\n' + "Let me know if you need more." + ) + tool_calls, remaining = parse_tool_calls(content) + + assert len(tool_calls) == 1 + assert remaining == "Here is the result:\n\nLet me know if you need more." + + def test_arguments_as_string(self): + content = ( + '\n' + '{"name": "run", "arguments": "{\\"key\\": \\"val\\"}"}\n' + '' + ) + tool_calls, remaining = parse_tool_calls(content) + + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["arguments"] == '{"key": "val"}' + + def test_empty_arguments(self): + content = ( + '\n' + '{"name": "ping", "arguments": {}}\n' + '' + ) + tool_calls, remaining = parse_tool_calls(content) + + assert len(tool_calls) == 1 + assert json.loads(tool_calls[0]["function"]["arguments"]) == {} + + +class TestPostprocessCompletion: + + def _make_response(self, content, finish_reason="stop"): + return { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + }, + } + + def test_extracts_reasoning_from_plain_response(self): + resp = self._make_response( + "\nAnalyzing the question.\n\n\nThe answer is 42." + ) + result = postprocess_completion(resp) + msg = result["choices"][0]["message"] + + assert result["id"].startswith("chatcmpl_") + assert msg["content"] == "The answer is 42." + assert msg["reasoning_content"] == "Analyzing the question." + assert "tool_calls" not in msg + assert result["choices"][0]["finish_reason"] == "stop" + + def test_extracts_reasoning_and_tool_calls(self): + resp = self._make_response( + '\nI should call the weather API.\n\n\n' + '\n' + '{"name": "get_weather", "arguments": {"location": "SF"}}\n' + '' + ) + result = postprocess_completion(resp) + msg = result["choices"][0]["message"] + + assert result["id"].startswith("chatcmpl_") + assert msg["content"] is None + assert msg["reasoning_content"] == "I should call the weather API." + assert len(msg["tool_calls"]) == 1 + assert msg["tool_calls"][0]["function"]["name"] == "get_weather" + assert result["choices"][0]["finish_reason"] == "tool_calls" + + def test_empty_think_block_no_reasoning_content(self): + resp = self._make_response( + "\n\n\n\nHello!" + ) + result = postprocess_completion(resp) + msg = result["choices"][0]["message"] + + assert msg["content"] == "Hello!" + assert "reasoning_content" not in msg + + def test_leaves_plain_text_unchanged(self): + resp = self._make_response("Hello, how can I help?") + result = postprocess_completion(resp) + msg = result["choices"][0]["message"] + + assert msg["content"] == "Hello, how can I help?" + assert "reasoning_content" not in msg + assert "tool_calls" not in msg + assert result["choices"][0]["finish_reason"] == "stop" + + def test_leaves_empty_content_unchanged(self): + resp = self._make_response("") + result = postprocess_completion(resp) + msg = result["choices"][0]["message"] + + assert msg["content"] == "" + assert "reasoning_content" not in msg + + def test_rewrites_id_and_preserves_usage(self): + resp = self._make_response( + '{"name": "f", "arguments": {}}' + ) + result = postprocess_completion(resp) + + assert result["id"].startswith("chatcmpl_") + assert result["id"] != "chatcmpl-123" + assert result["usage"]["total_tokens"] == 150 + + def test_rewrites_model_when_provided(self): + resp = self._make_response("Hello!") + result = postprocess_completion(resp, model_id="Qwen3-0.6B") + + assert result["model"] == "Qwen3-0.6B" + + def test_preserves_model_when_not_provided(self): + resp = self._make_response("Hello!") + result = postprocess_completion(resp) + + assert result["model"] == "test-model" + + def test_no_raw_tags_leak_into_content(self): + resp = self._make_response( + '\nThinking...\n\n\n' + '\n' + '{"name": "f", "arguments": {}}\n' + '' + ) + result = postprocess_completion(resp) + msg = result["choices"][0]["message"] + content = msg["content"] or "" + + assert "" not in content + assert "" not in content + assert "" not in content + assert "" not in content + + +# --------------------------------------------------------------------------- +# Streaming post-processing +# --------------------------------------------------------------------------- + + +def _make_chunk(content=None, finish_reason=None, **extra_delta): + """Helper to build a minimal streaming chunk dict.""" + delta = {} + if content is not None: + delta["content"] = content + delta.update(extra_delta) + return { + "id": "chatcmpl-stream", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + } + ], + } + + +def _feed(processor, tokens): + """Feed a list of token strings through the processor, return all output chunks.""" + out = [] + for tok in tokens: + chunk = _make_chunk(content=tok) + out.extend(processor.process_chunk(chunk)) + # Process a final chunk with finish_reason + final = _make_chunk(finish_reason="stop") + out.extend(processor.process_chunk(final)) + out.extend(processor.flush()) + return out + + +def _collect_field(chunks, field): + """Concatenate a delta field across all output chunks.""" + parts = [] + for c in chunks: + delta = c["choices"][0]["delta"] + val = delta.get(field) + if val: + parts.append(val) + return "".join(parts) + + +class TestStreamPostProcessor: + + def test_plain_content_passthrough(self): + proc = StreamPostProcessor() + chunks = _feed(proc, ["Hello", " world", "!"]) + + assert _collect_field(chunks, "content") == "Hello world!" + assert _collect_field(chunks, "reasoning_content") == "" + + def test_think_block_to_reasoning_content(self): + proc = StreamPostProcessor() + tokens = ["", "I need to think", "", "Answer here"] + chunks = _feed(proc, tokens) + + assert _collect_field(chunks, "reasoning_content") == "I need to think" + assert _collect_field(chunks, "content") == "Answer here" + + def test_think_tags_split_across_tokens(self): + proc = StreamPostProcessor() + tokens = ["<", "think", ">", "reasoning", "<", "/think", ">", "visible"] + chunks = _feed(proc, tokens) + + assert _collect_field(chunks, "reasoning_content") == "reasoning" + assert _collect_field(chunks, "content") == "visible" + + def test_tool_call_emits_tool_calls_delta(self): + proc = StreamPostProcessor() + tokens = [ + "", + '{"name": "get_weather",', + ' "arguments": {"location": "SF"}}', + "", + ] + chunks = _feed(proc, tokens) + + tc_chunks = [ + c for c in chunks + if c["choices"][0]["delta"].get("tool_calls") + ] + assert len(tc_chunks) == 1 + tc = tc_chunks[0]["choices"][0]["delta"]["tool_calls"][0] + assert tc["type"] == "function" + assert tc["function"]["name"] == "get_weather" + assert json.loads(tc["function"]["arguments"]) == {"location": "SF"} + assert tc["id"].startswith("call_") + + def test_tool_call_sets_finish_reason(self): + proc = StreamPostProcessor() + tokens = [ + "", + '{"name": "f", "arguments": {}}', + "", + ] + chunks = _feed(proc, tokens) + + final_chunks = [ + c for c in chunks if c["choices"][0].get("finish_reason") is not None + ] + assert any( + c["choices"][0]["finish_reason"] == "tool_calls" for c in final_chunks + ) + + def test_think_then_tool_call(self): + proc = StreamPostProcessor() + tokens = [ + "", + "Let me check", + "", + "", + '{"name": "search", "arguments": {"q": "test"}}', + "", + ] + chunks = _feed(proc, tokens) + + assert _collect_field(chunks, "reasoning_content") == "Let me check" + + tc_chunks = [ + c for c in chunks + if c["choices"][0]["delta"].get("tool_calls") + ] + assert len(tc_chunks) == 1 + assert tc_chunks[0]["choices"][0]["delta"]["tool_calls"][0][ + "function" + ]["name"] == "search" + + def test_multiple_tool_calls(self): + proc = StreamPostProcessor() + tokens = [ + "", + '{"name": "a", "arguments": {}}', + "", + "", + '{"name": "b", "arguments": {}}', + "", + ] + chunks = _feed(proc, tokens) + + tc_chunks = [ + c for c in chunks + if c["choices"][0]["delta"].get("tool_calls") + ] + assert len(tc_chunks) == 2 + assert tc_chunks[0]["choices"][0]["delta"]["tool_calls"][0]["index"] == 0 + assert tc_chunks[1]["choices"][0]["delta"]["tool_calls"][0]["index"] == 1 + + def test_partial_tag_not_matching_flushed_as_content(self): + proc = StreamPostProcessor() + tokens = ["Hello world"] + chunks = _feed(proc, tokens) + + assert _collect_field(chunks, "content") == "Hello world" + + def test_no_content_chunk_passthrough(self): + proc = StreamPostProcessor() + role_chunk = { + "id": "chatcmpl-stream", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + } + ], + } + result = proc.process_chunk(role_chunk) + assert len(result) == 1 + assert result[0]["choices"][0]["delta"]["role"] == "assistant" + + def test_empty_think_block_suppressed(self): + proc = StreamPostProcessor() + tokens = ["", "\n\n", "", "\n\nHello!"] + chunks = _feed(proc, tokens) + + content = _collect_field(chunks, "content") + assert "Hello!" in content + assert "" not in content + assert "" not in content + + def test_leading_whitespace_stripped_after_tags(self): + """Align with non-streaming .strip(): no leading \\n after tag transitions.""" + proc = StreamPostProcessor() + tokens = ["", "\n\nI think.", "\n", "\n\nAnswer here"] + chunks = _feed(proc, tokens) + + reasoning = _collect_field(chunks, "reasoning_content") + content = _collect_field(chunks, "content") + assert reasoning == "I think.\n" + assert content == "Answer here" + + def test_whitespace_only_after_close_think_suppressed(self): + proc = StreamPostProcessor() + tokens = ["", "ok", "", "\n\n", "Hello"] + chunks = _feed(proc, tokens) + + content = _collect_field(chunks, "content") + assert content == "Hello" + + def test_stream_id_consistent_across_chunks(self): + proc = StreamPostProcessor() + chunks = _feed(proc, ["Hello", " world"]) + ids = {c["id"] for c in chunks} + assert len(ids) == 1 + assert ids.pop().startswith("chatcmpl_") + + def test_stream_model_rewritten_when_provided(self): + proc = StreamPostProcessor(model_id="Qwen3-0.6B") + chunks = _feed(proc, ["Hello"]) + for c in chunks: + assert c["model"] == "Qwen3-0.6B" + + def test_stream_model_preserved_when_not_provided(self): + proc = StreamPostProcessor() + chunks = _feed(proc, ["Hello"]) + for c in chunks: + if "model" in c: + assert c["model"] == "test-model" + + +class TestGenId: + + def test_chatcmpl_prefix(self): + cid = gen_id("chatcmpl") + assert cid.startswith("chatcmpl_") + + def test_call_prefix_uses_underscore(self): + cid = gen_id("call") + assert cid.startswith("call_") + + def test_unique(self): + ids = {gen_id("chatcmpl") for _ in range(100)} + assert len(ids) == 100 + + def test_compact_length(self): + cid = gen_id("chatcmpl") + assert len(cid) < 30 diff --git a/uv.lock b/uv.lock index 11e917b..ebe46ed 100644 --- a/uv.lock +++ b/uv.lock @@ -527,15 +527,14 @@ wheels = [ [[package]] name = "llama-cpp-python" -version = "0.3.13" -source = { registry = "https://pypi.org/simple" } +version = "0.3.16" +source = { git = "https://github.com/XyLearningProgramming/llama-cpp-python.git?rev=main#f62ded4478a228cb7e3b0a069bf2d32e30fe089f" } dependencies = [ { name = "diskcache" }, { name = "jinja2" }, { name = "numpy" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e2/3d/a426f9777301569a17f3c3bf4ecc3755120531c008e4601450eec13c09ac/llama_cpp_python-0.3.13.tar.gz", hash = "sha256:307ce2abf62c7cf574234b8c633978cf92eb1c4b3cfe6babef889d812c298d84", size = 50059668, upload_time = "2025-07-15T11:43:59.734Z" } [[package]] name = "markdown-it-py" @@ -1224,7 +1223,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "fastapi", specifier = ">=0.116.1" }, - { name = "llama-cpp-python", specifier = ">=0.3.13" }, + { name = "llama-cpp-python", git = "https://github.com/XyLearningProgramming/llama-cpp-python.git?rev=main" }, { name = "onnxruntime", specifier = ">=1.17.0" }, { name = "opentelemetry-api", specifier = ">=1.35.0" }, { name = "opentelemetry-exporter-otlp", specifier = ">=1.35.0" }, From 0efbc862eeea276a032e5a0ae2317c1e095b8b66 Mon Sep 17 00:00:00 2001 From: XyLearningProgramming Date: Mon, 23 Feb 2026 15:08:49 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=F0=9F=90=9B=20fixed=20swagger=20ref=20erro?= =?UTF-8?q?r?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- slm_server/app.py | 14 +- slm_server/model.py | 42 ++++- swagger/openapi.yaml | 367 ++++++++++++++++++++++--------------------- 3 files changed, 235 insertions(+), 188 deletions(-) diff --git a/slm_server/app.py b/slm_server/app.py index a5a10ea..7c8a481 100644 --- a/slm_server/app.py +++ b/slm_server/app.py @@ -14,7 +14,6 @@ from slm_server.logging import setup_logging from slm_server.metrics import setup_metrics from slm_server.model import ( - ChatCompletionChunkResponse, ChatCompletionRequest, ChatCompletionResponse, EmbeddingData, @@ -22,6 +21,7 @@ EmbeddingResponse, ModelInfo, ModelListResponse, + register_streaming_schema, ) from slm_server.trace import setup_tracing from slm_server.utils import ( @@ -50,6 +50,12 @@ STATUS_CODE_EXCEPTION = HTTPStatus.INTERNAL_SERVER_ERROR # Media type for streaming responses. STREAM_RESPONSE_MEDIA_TYPE = "text/event-stream" +# Schema for streaming responses. +STREAM_RESPONSE_SCHEMA = { + "schema": { + "$ref": "#/components/schemas/ChatCompletionChunkResponse" + } +} def get_llm_semaphor() -> asyncio.Semaphore: @@ -100,6 +106,8 @@ def get_app() -> FastAPI: # Setup trace and OTel metrics (this will also instrument FastAPI) setup_tracing(app, settings.tracing) + register_streaming_schema(app) + return app @@ -182,9 +190,7 @@ async def run_llm_non_streaming( responses={ 200: { "content": { - STREAM_RESPONSE_MEDIA_TYPE: { - "schema": ChatCompletionChunkResponse.model_json_schema(), - } + STREAM_RESPONSE_MEDIA_TYPE: STREAM_RESPONSE_SCHEMA, }, }, }, diff --git a/slm_server/model.py b/slm_server/model.py index 8ffee23..b83dd07 100644 --- a/slm_server/model.py +++ b/slm_server/model.py @@ -1,4 +1,9 @@ -from typing import Any, Literal, Self +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Self + +if TYPE_CHECKING: + from fastapi import FastAPI from llama_cpp.llama_types import ( ChatCompletionFunction, @@ -443,3 +448,38 @@ class ModelListResponse(BaseModel): object: Literal["list"] = "list" data: list[ModelInfo] = Field(description="List of available models") + + +# --------------------------------------------------------------------------- +# OpenAPI helpers +# --------------------------------------------------------------------------- + + +def register_streaming_schema(app: FastAPI) -> None: + """Register ``ChatCompletionChunkResponse`` in OpenAPI ``components.schemas``. + + Pydantic's ``model_json_schema()`` puts nested models under a local + ``$defs`` key, but Swagger UI resolves ``$ref`` from the document root. + This patches ``app.openapi`` to hoist the chunk schema and its + dependencies into ``components.schemas`` with correct ``$ref`` paths. + """ + _original = app.openapi + + def patched_openapi(): # type: ignore[no-untyped-def] + if app.openapi_schema: + return app.openapi_schema + + schema = _original() + chunk_schema = ChatCompletionChunkResponse.model_json_schema( + ref_template="#/components/schemas/{model}", + ) + defs = chunk_schema.pop("$defs", {}) + schemas = schema.setdefault("components", {}).setdefault("schemas", {}) + schemas["ChatCompletionChunkResponse"] = chunk_schema + for name, defn in defs.items(): + schemas.setdefault(name, defn) + + app.openapi_schema = schema + return schema + + app.openapi = patched_openapi # type: ignore[method-assign] diff --git a/swagger/openapi.yaml b/swagger/openapi.yaml index 7ea8348..33230b1 100644 --- a/swagger/openapi.yaml +++ b/swagger/openapi.yaml @@ -39,189 +39,7 @@ paths: $ref: '#/components/schemas/ChatCompletionResponse' text/event-stream: schema: - $defs: - ChatCompletionChunkChoice: - properties: - index: - type: integer - title: Index - delta: - $ref: '#/$defs/ChatDelta' - logprobs: - anyOf: - - {} - - type: 'null' - title: Logprobs - finish_reason: - anyOf: - - type: string - enum: - - stop - - length - - tool_calls - - function_call - - type: 'null' - title: Finish Reason - type: object - required: - - index - - delta - title: ChatCompletionChunkChoice - ChatDelta: - properties: - role: - anyOf: - - type: string - enum: - - system - - user - - assistant - - tool - - type: 'null' - title: Role - content: - anyOf: - - type: string - - type: 'null' - title: Content - reasoning_content: - anyOf: - - type: string - - type: 'null' - title: Reasoning Content - tool_calls: - anyOf: - - items: - $ref: '#/$defs/DeltaToolCall' - type: array - - type: 'null' - title: Tool Calls - type: object - title: ChatDelta - DeltaToolCall: - properties: - index: - type: integer - title: Index - id: - type: string - title: Id - type: - type: string - const: function - title: Type - default: function - function: - $ref: '#/$defs/ToolCallFunction' - type: object - required: - - index - - id - - function - title: DeltaToolCall - description: 'Streaming tool call emitted by ``StreamPostProcessor``. - - - Our postprocessor accumulates the full ```` block - before - - emitting, so every field is always present (never partial/incremental). - - ``index`` is the ordinal position of this tool call in the response.' - ToolCallFunction: - properties: - name: - type: string - title: Name - arguments: - type: string - title: Arguments - type: object - required: - - name - - arguments - title: ToolCallFunction - properties: - id: - type: string - title: Id - object: - type: string - const: chat.completion.chunk - title: Object - default: chat.completion.chunk - created: - type: integer - title: Created - model: - type: string - title: Model - choices: - items: - $ref: '#/$defs/ChatCompletionChunkChoice' - type: array - title: Choices - type: object - required: - - id - - created - - model - - choices - title: ChatCompletionChunkResponse - description: 'Schema for each SSE ``data:`` payload in a streaming - chat completion. - - - Not used for runtime validation (chunks are pre-serialised dicts), - - but exposed here so OpenAPI / Swagger documents the streaming format.' - examples: - - choices: - - delta: - role: assistant - index: 0 - created: 1771828652 - id: chatcmpl_vGxUUAi7KYLIEaNGb - model: Qwen3-0.6B-Q4_K_M - object: chat.completion.chunk - - choices: - - delta: - reasoning_content: Let me think. - index: 0 - created: 1771828652 - id: chatcmpl_vGxUUAi7KYLIEaNGb - model: Qwen3-0.6B-Q4_K_M - object: chat.completion.chunk - - choices: - - delta: - content: 2 + 2 equals **4**. - index: 0 - created: 1771828652 - id: chatcmpl_vGxUUAi7KYLIEaNGb - model: Qwen3-0.6B-Q4_K_M - object: chat.completion.chunk - - choices: - - delta: - tool_calls: - - function: - arguments: '{"location": "San Francisco"}' - name: get_weather - id: call_1MSB4PoE3eerSEJu - index: 0 - type: function - index: 0 - created: 1771828662 - id: chatcmpl_uXmbuefsXElrLqSkb - model: Qwen3-0.6B-Q4_K_M - object: chat.completion.chunk - - choices: - - delta: {} - finish_reason: stop - index: 0 - created: 1771828652 - id: chatcmpl_vGxUUAi7KYLIEaNGb - model: Qwen3-0.6B-Q4_K_M - object: chat.completion.chunk + $ref: '#/components/schemas/ChatCompletionChunkResponse' '422': description: Validation Error content: @@ -1078,3 +896,186 @@ components: - msg - type title: ValidationError + ChatCompletionChunkResponse: + description: 'Schema for each SSE ``data:`` payload in a streaming chat completion. + + + Not used for runtime validation (chunks are pre-serialised dicts), + + but exposed here so OpenAPI / Swagger documents the streaming format.' + examples: + - choices: + - delta: + role: assistant + finish_reason: null + index: 0 + logprobs: null + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: + reasoning_content: Let me think. + finish_reason: null + index: 0 + logprobs: null + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: + content: 2 + 2 equals **4**. + finish_reason: null + index: 0 + logprobs: null + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: + tool_calls: + - function: + arguments: '{"location": "San Francisco"}' + name: get_weather + id: call_1MSB4PoE3eerSEJu + index: 0 + type: function + finish_reason: null + index: 0 + logprobs: null + created: 1771828662 + id: chatcmpl_uXmbuefsXElrLqSkb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + - choices: + - delta: {} + finish_reason: stop + index: 0 + logprobs: null + created: 1771828652 + id: chatcmpl_vGxUUAi7KYLIEaNGb + model: Qwen3-0.6B-Q4_K_M + object: chat.completion.chunk + properties: + id: + title: Id + type: string + object: + const: chat.completion.chunk + default: chat.completion.chunk + title: Object + type: string + created: + title: Created + type: integer + model: + title: Model + type: string + choices: + items: + $ref: '#/components/schemas/ChatCompletionChunkChoice' + title: Choices + type: array + required: + - id + - created + - model + - choices + title: ChatCompletionChunkResponse + type: object + ChatCompletionChunkChoice: + properties: + index: + title: Index + type: integer + delta: + $ref: '#/components/schemas/ChatDelta' + logprobs: + anyOf: + - {} + - type: 'null' + default: null + title: Logprobs + finish_reason: + anyOf: + - enum: + - stop + - length + - tool_calls + - function_call + type: string + - type: 'null' + default: null + title: Finish Reason + required: + - index + - delta + title: ChatCompletionChunkChoice + type: object + ChatDelta: + properties: + role: + anyOf: + - enum: + - system + - user + - assistant + - tool + type: string + - type: 'null' + default: null + title: Role + content: + anyOf: + - type: string + - type: 'null' + default: null + title: Content + reasoning_content: + anyOf: + - type: string + - type: 'null' + default: null + title: Reasoning Content + tool_calls: + anyOf: + - items: + $ref: '#/components/schemas/DeltaToolCall' + type: array + - type: 'null' + default: null + title: Tool Calls + title: ChatDelta + type: object + DeltaToolCall: + description: 'Streaming tool call emitted by ``StreamPostProcessor``. + + + Our postprocessor accumulates the full ```` block before + + emitting, so every field is always present (never partial/incremental). + + ``index`` is the ordinal position of this tool call in the response.' + properties: + index: + title: Index + type: integer + id: + title: Id + type: string + type: + const: function + default: function + title: Type + type: string + function: + $ref: '#/components/schemas/ToolCallFunction' + required: + - index + - id + - function + title: DeltaToolCall + type: object From e30bec3a82cf1f5cebd1495cd0a4b1c3c171e7f9 Mon Sep 17 00:00:00 2001 From: XyLearningProgramming Date: Mon, 23 Feb 2026 15:27:49 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=F0=9F=90=9B=20fixed=20lint=20issues?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- slm_server/model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/slm_server/model.py b/slm_server/model.py index b83dd07..6686e82 100644 --- a/slm_server/model.py +++ b/slm_server/model.py @@ -194,7 +194,9 @@ class ChatCompletionResponse(BaseModel): "type": "function", "function": { "name": "get_weather", - "arguments": '{"location": "San Francisco"}', + "arguments": ( + '{"location": "San Francisco"}' + ), }, } ], @@ -322,7 +324,9 @@ class ChatCompletionChunkResponse(BaseModel): "type": "function", "function": { "name": "get_weather", - "arguments": '{"location": "San Francisco"}', + "arguments": ( + '{"location": "San Francisco"}' + ), }, } ] From 79283971d088dc7417766a596cbbce11df0b49a9 Mon Sep 17 00:00:00 2001 From: XyLearningProgramming Date: Mon, 23 Feb 2026 15:30:43 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=90=9B=20fixed=20lint=20issues=20agai?= =?UTF-8?q?n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- slm_server/app.py | 4 +--- slm_server/config.py | 1 + slm_server/utils/ids.py | 2 +- slm_server/utils/postprocess.py | 23 +++++++++-------------- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/slm_server/app.py b/slm_server/app.py index 7c8a481..3325170 100644 --- a/slm_server/app.py +++ b/slm_server/app.py @@ -52,9 +52,7 @@ STREAM_RESPONSE_MEDIA_TYPE = "text/event-stream" # Schema for streaming responses. STREAM_RESPONSE_SCHEMA = { - "schema": { - "$ref": "#/components/schemas/ChatCompletionChunkResponse" - } + "schema": {"$ref": "#/components/schemas/ChatCompletionChunkResponse"} } diff --git a/slm_server/config.py b/slm_server/config.py index 9942ccc..b8547f5 100644 --- a/slm_server/config.py +++ b/slm_server/config.py @@ -126,5 +126,6 @@ def get_settings() -> Settings: get_settings._instance = Settings() return get_settings._instance + def get_model_id(settings: Annotated[Settings, Depends(get_settings)]) -> str: return settings.chat_model_id diff --git a/slm_server/utils/ids.py b/slm_server/utils/ids.py index d75c14d..c05eb67 100644 --- a/slm_server/utils/ids.py +++ b/slm_server/utils/ids.py @@ -5,7 +5,7 @@ import os import string -_LENGTH=12 +_LENGTH = 12 _ALPHABET = string.ascii_letters + string.digits _BASE = len(_ALPHABET) diff --git a/slm_server/utils/postprocess.py b/slm_server/utils/postprocess.py index 6b6fe2e..b17fb58 100644 --- a/slm_server/utils/postprocess.py +++ b/slm_server/utils/postprocess.py @@ -17,9 +17,7 @@ from slm_server.utils.ids import gen_id _THINK_RE = re.compile(r"(.*?)", re.DOTALL) -_TOOL_CALL_RE = re.compile( - r"\s*(\{.*?\})\s*", re.DOTALL -) +_TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) def extract_reasoning(content: str) -> tuple[str | None, str]: @@ -151,9 +149,7 @@ def __init__(self, *, model_id: str | None = None) -> None: # -- public API ---------------------------------------------------------- - def process_chunk( - self, chunk: dict[str, Any] - ) -> list[dict[str, Any]]: + def process_chunk(self, chunk: dict[str, Any]) -> list[dict[str, Any]]: """Process one streaming chunk, returning zero or more output chunks.""" self._last_chunk_template = chunk @@ -178,8 +174,11 @@ def process_chunk( "tool_calls" if self._had_tool_calls else orig_finish ) else: - result.append(self._rewrite_finish(chunk) if self._had_tool_calls - else self._stamp(copy.deepcopy(chunk))) + result.append( + self._rewrite_finish(chunk) + if self._had_tool_calls + else self._stamp(copy.deepcopy(chunk)) + ) return result @@ -193,9 +192,7 @@ def flush(self) -> list[dict[str, Any]]: # -- internals ----------------------------------------------------------- - def _consume_text( - self, text: str, chunk: dict[str, Any] - ) -> list[dict[str, Any]]: + def _consume_text(self, text: str, chunk: dict[str, Any]) -> list[dict[str, Any]]: """Feed *text* through the tag-detection state machine. Batches consecutive characters for the same state into single chunks @@ -273,9 +270,7 @@ def _emit_batch(self, text: str) -> list[dict[str, Any]]: self._tool_call_buf += text return [] - def _finish_tool_call( - self, chunk: dict[str, Any] - ) -> list[dict[str, Any]]: + def _finish_tool_call(self, chunk: dict[str, Any]) -> list[dict[str, Any]]: """Parse the accumulated JSON and emit a ``tool_calls`` delta.""" raw = self._tool_call_buf.strip() self._tool_call_buf = "" From 2952b870e959454576fb18dae27dcf2020f16326 Mon Sep 17 00:00:00 2001 From: XyLearningProgramming Date: Mon, 23 Feb 2026 15:34:15 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=A9=B9=20changed=20makefile=20lint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Makefile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index c658360..0427c4d 100644 --- a/Makefile +++ b/Makefile @@ -18,10 +18,11 @@ run: ## Start server via start.sh lint: ## Run ruff linter uv run ruff check slm_server/ -format: ## Run ruff formatter +format: ## Run ruff linter (--fix) and formatter + uv run ruff check slm_server/ --fix uv run ruff format slm_server/ -check: lint ## Run linter + formatter check +check: lint ## Run linter + formatter check (CI) uv run ruff format --check slm_server/ smoke: ## Smoke-test the running server APIs with curl