diff --git a/src/art/cli.py b/src/art/cli.py index 62218dbd..1afbba65 100644 --- a/src/art/cli.py +++ b/src/art/cli.py @@ -245,9 +245,9 @@ def migrate( model_dir, delete_originals=not keep_jsonl, dry_run=dry_run, - progress_callback=lambda f: typer.echo(f" {f}") - if verbose - else None, + progress_callback=lambda f: ( + typer.echo(f" {f}") if verbose else None + ), ) result = result + model_result else: diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py index 22ee9bb9..8a553409 100644 --- a/src/art/tinker/server.py +++ b/src/art/tinker/server.py @@ -4,7 +4,7 @@ import os import socket import time -from typing import Annotated +from typing import Annotated, cast import uuid from fastapi import FastAPI, HTTPException, Request @@ -62,10 +62,13 @@ async def prompt_tokens( messages: list[ChatCompletionMessageParam], tools: list[ChatCompletionToolUnionParam] | None, ) -> list[int]: - return self._get_renderer(base_model).tokenizer.apply_chat_template( - messages, # type: ignore - tools=tools, # type: ignore - add_generation_prompt=True, + return cast( + list[int], + self._get_renderer(base_model).tokenizer.apply_chat_template( + messages, # type: ignore + tools=tools, # type: ignore + add_generation_prompt=True, + ), ) async def chat_completion_and_token_discrepancies( @@ -80,9 +83,9 @@ async def chat_completion_and_token_discrepancies( token_discrepancies: list[tuple[list[int], list[int]]] = [] for i, sequence in enumerate(sample_response.sequences): assert sequence.logprobs is not None, "Logprobs are required" - assert len(sequence.tokens) == len( - sequence.logprobs - ), "Tokens and logprobs must have the same length" + assert len(sequence.tokens) == len(sequence.logprobs), ( + "Tokens and logprobs must have the same length" + ) rendered_response_tokens = renderer.tokenizer.encode( renderer.tokenizer.decode(sequence.tokens) ) @@ -222,10 +225,11 @@ async def chat_completions( detail="Missing or invalid Authorization header", headers={"WWW-Authenticate": "Bearer"}, ) - sampling_client, base_model = ( - await self._get_sampling_client_and_base_model( - api_key, self.models.get(body["model"], body["model"]) - ) + ( + sampling_client, + base_model, + ) = await self._get_sampling_client_and_base_model( + api_key, self.models.get(body["model"], body["model"]) ) rendered_prompt_tokens = await worker.prompt_tokens( base_model=base_model, @@ -265,13 +269,14 @@ async def chat_completions( else: detail = str(e) raise HTTPException(status_code=e.status_code, detail=detail) from e - chat_completion, token_discrepancies = ( - await worker.chat_completion_and_token_discrepancies( - base_model=base_model, - sample_response=sample_response, - model_name=body["model"], - prompt_tokens=len(prompt_tokens), - ) + ( + chat_completion, + token_discrepancies, + ) = await worker.chat_completion_and_token_discrepancies( + base_model=base_model, + sample_response=sample_response, + model_name=body["model"], + prompt_tokens=len(prompt_tokens), ) for rendered_response_tokens, raw_response_tokens in token_discrepancies: self._prefix_cache.insert(