6 changes: 3 additions & 3 deletions src/art/cli.py
@@ -245,9 +245,9 @@ def migrate(
             model_dir,
             delete_originals=not keep_jsonl,
             dry_run=dry_run,
-            progress_callback=lambda f: typer.echo(f" {f}")
-            if verbose
-            else None,
+            progress_callback=lambda f: (
+                typer.echo(f" {f}") if verbose else None
+            ),
         )
         result = result + model_result
     else:
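Reviewer note: the cli.py change is purely cosmetic; this is the parenthesized-conditional layout that formatters such as Ruff prefer inside a lambda. A minimal sketch of the two equivalent forms, using `print` as a stand-in for `typer.echo`:

```python
verbose = True

# Old layout: the conditional expression continues on following lines and
# visually dangles below the keyword argument it belongs to.
progress_callback = lambda f: print(f" {f}") if verbose else None

# New layout: parentheses group the conditional as the lambda's body.
progress_callback = lambda f: (
    print(f" {f}") if verbose else None
)

progress_callback("example.jsonl")  # prints " example.jsonl" when verbose
```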
43 changes: 24 additions & 19 deletions src/art/tinker/server.py
@@ -4,7 +4,7 @@
 import os
 import socket
 import time
-from typing import Annotated
+from typing import Annotated, cast
 import uuid
 
 from fastapi import FastAPI, HTTPException, Request
@@ -62,10 +62,13 @@ async def prompt_tokens(
         messages: list[ChatCompletionMessageParam],
         tools: list[ChatCompletionToolUnionParam] | None,
     ) -> list[int]:
-        return self._get_renderer(base_model).tokenizer.apply_chat_template(
-            messages,  # type: ignore
-            tools=tools,  # type: ignore
-            add_generation_prompt=True,
+        return cast(
+            list[int],
+            self._get_renderer(base_model).tokenizer.apply_chat_template(
+                messages,  # type: ignore
+                tools=tools,  # type: ignore
+                add_generation_prompt=True,
+            ),
         )
 
     async def chat_completion_and_token_discrepancies(
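Reviewer note: `transformers` annotates `apply_chat_template` with a broad return type (it can yield a string, token ids, or a batch structure depending on its flags), so pinning the result with `typing.cast` is the usual way to satisfy the checker without changing behavior. `cast` does nothing at runtime; a minimal sketch with a hypothetical loosely-typed helper:

```python
from typing import Any, cast


def tokenize_loosely(text: str) -> Any:
    # Hypothetical stand-in for an API whose annotation is broader than
    # what this particular call site actually produces.
    return [ord(c) for c in text]


# cast() performs no conversion or validation; it only tells the type
# checker what to assume about the value it is handed.
token_ids: list[int] = cast(list[int], tokenize_loosely("hi"))
print(token_ids)  # [104, 105]
```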
@@ -80,9 +80,9 @@ async def chat_completion_and_token_discrepancies(
         token_discrepancies: list[tuple[list[int], list[int]]] = []
         for i, sequence in enumerate(sample_response.sequences):
             assert sequence.logprobs is not None, "Logprobs are required"
-            assert len(sequence.tokens) == len(
-                sequence.logprobs
-            ), "Tokens and logprobs must have the same length"
+            assert len(sequence.tokens) == len(sequence.logprobs), (
+                "Tokens and logprobs must have the same length"
+            )
             rendered_response_tokens = renderer.tokenizer.encode(
                 renderer.tokenizer.decode(sequence.tokens)
             )
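Reviewer note: the assert reflow keeps the comparison on one line and parenthesizes only the message, which is the layout current Ruff/Black releases prefer over splitting the condition. The style also stays clear of the classic tuple pitfall; a short sketch:

```python
tokens = [1, 2, 3]
logprobs = [-0.1, -0.2, -0.3]

# Safe: only the message is parenthesized, so the comparison is still the
# assert statement's condition.
assert len(tokens) == len(logprobs), (
    "Tokens and logprobs must have the same length"
)

# Pitfall this layout avoids: parenthesizing condition AND message builds
# a two-element tuple, which is always truthy, so the assert can never
# fail (CPython warns "assertion is always true" at compile time).
# assert (len(tokens) == len(logprobs), "must match")
```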
@@ -222,10 +225,11 @@ async def chat_completions(
detail="Missing or invalid Authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
sampling_client, base_model = (
await self._get_sampling_client_and_base_model(
api_key, self.models.get(body["model"], body["model"])
)
(
sampling_client,
base_model,
) = await self._get_sampling_client_and_base_model(
api_key, self.models.get(body["model"], body["model"])
)
rendered_prompt_tokens = await worker.prompt_tokens(
base_model=base_model,
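Reviewer note: here (and again in the hunk below) the parentheses move from the awaited expression to the unpacking target, so each bound name gets its own line while the call keeps normal indentation. Both forms unpack the same awaited tuple; a runnable sketch with a hypothetical coroutine:

```python
import asyncio


async def get_pair() -> tuple[str, str]:
    # Hypothetical stand-in for _get_sampling_client_and_base_model.
    return "sampling-client", "base-model"


async def main() -> None:
    # Old layout: parentheses wrap the awaited expression.
    sampling_client, base_model = (
        await get_pair()
    )

    # New layout: a parenthesized unpacking target, one name per line.
    (
        sampling_client,
        base_model,
    ) = await get_pair()
    print(sampling_client, base_model)


asyncio.run(main())
```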
@@ -265,13 +269,14 @@ async def chat_completions(
             else:
                 detail = str(e)
             raise HTTPException(status_code=e.status_code, detail=detail) from e
-        chat_completion, token_discrepancies = (
-            await worker.chat_completion_and_token_discrepancies(
-                base_model=base_model,
-                sample_response=sample_response,
-                model_name=body["model"],
-                prompt_tokens=len(prompt_tokens),
-            )
+        (
+            chat_completion,
+            token_discrepancies,
+        ) = await worker.chat_completion_and_token_discrepancies(
+            base_model=base_model,
+            sample_response=sample_response,
+            model_name=body["model"],
+            prompt_tokens=len(prompt_tokens),
         )
         for rendered_response_tokens, raw_response_tokens in token_discrepancies:
             self._prefix_cache.insert(
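Reviewer note on the surrounding context: the discrepancy list appears to exist because `tokenizer.encode(tokenizer.decode(tokens))` need not return the original `tokens`; a sampled sequence is not always the tokenizer's canonical segmentation of its own decoded text. A toy illustration of the effect (real BPE vocabularies, e.g. via Hugging Face tokenizers, behave analogously):

```python
# Toy tokenizer: greedy longest-match segmentation over a tiny vocab.
vocab = {"a": 0, "b": 1, "ab": 2}
inv = {v: k for k, v in vocab.items()}


def decode(ids: list[int]) -> str:
    return "".join(inv[i] for i in ids)


def encode(text: str) -> list[int]:
    # Always prefers the merged piece "ab" over separate "a", "b".
    ids, i = [], 0
    while i < len(text):
        piece = text[i : i + 2]
        if piece in vocab:
            ids.append(vocab[piece])
            i += 2
        else:
            ids.append(vocab[text[i]])
            i += 1
    return ids


raw = [0, 1]                    # sampled as "a", "b"
rendered = encode(decode(raw))  # re-encodes canonically as ["ab"] -> [2]
print(raw, rendered)            # [0, 1] [2]  -- a token discrepancy
```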