diff --git a/tests/unittests/integration/langchain/test_agent_invoke_methods.py b/tests/unittests/integration/langchain/test_agent_invoke_methods.py index d24404a..7aba9fe 100644 --- a/tests/unittests/integration/langchain/test_agent_invoke_methods.py +++ b/tests/unittests/integration/langchain/test_agent_invoke_methods.py @@ -81,6 +81,39 @@ def _sse(data: Dict[str, Any]) -> str: return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" +def _start_server(app: FastAPI) -> tuple: + """启动 FastAPI 服务器并返回 base_url + + Returns: + tuple: (base_url, server, thread) + """ + + # 添加健康检查端点 + @app.get("/health") + async def health(): + return {"status": "ok"} + + port = _find_free_port() + config = uvicorn.Config( + app, host="127.0.0.1", port=port, log_level="warning" + ) + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + base_url = f"http://127.0.0.1:{port}" + # 等待服务器启动 + for _ in range(50): + try: + httpx.get(f"{base_url}/health", timeout=0.2) + break + except Exception: + time.sleep(0.1) + + return base_url, server, thread + + def _build_mock_openai_app() -> FastAPI: """构建本地 OpenAI 协议兼容的简单服务""" app = FastAPI() @@ -297,20 +330,36 @@ def parse_sse_events(content: str) -> List[Dict[str, Any]]: async def request_agui_events( - server_app, + server_url_or_app: Union[str, FastAPI], messages: List[Dict[str, str]], stream: bool = True, ) -> List[Dict[str, Any]]: - """发送 AG-UI 请求并返回事件列表""" - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=server_app), - base_url="http://test", - ) as client: - response = await client.post( - "/ag-ui/agent", - json={"messages": messages, "stream": stream}, - timeout=60.0, - ) + """发送 AG-UI 请求并返回事件列表 + + Args: + server_url_or_app: 服务器 URL (如 "http://127.0.0.1:8000") 或 FastAPI app 对象 + messages: 消息列表 + stream: 是否流式响应 + """ + if isinstance(server_url_or_app, str): + # 使用真实的 HTTP 连接 + async with httpx.AsyncClient(base_url=server_url_or_app) as client: + response = await client.post( + "/ag-ui/agent", + json={"messages": messages, "stream": stream}, + timeout=60.0, + ) + else: + # 使用 ASGITransport (用于非流式测试) + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=server_url_or_app), + base_url="http://test", + ) as client: + response = await client.post( + "/ag-ui/agent", + json={"messages": messages, "stream": stream}, + timeout=60.0, + ) assert response.status_code == 200 return parse_sse_events(response.text) @@ -670,26 +719,42 @@ def assert_openai_tool_call_response( async def request_openai_events( - server_app, + server_url_or_app: Union[str, FastAPI], messages: List[Dict[str, str]], stream: bool = True, ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: - """发送 OpenAI 协议请求并返回流式事件列表或响应""" + """发送 OpenAI 协议请求并返回流式事件列表或响应 + + Args: + server_url_or_app: 服务器 URL (如 "http://127.0.0.1:8000") 或 FastAPI app 对象 + messages: 消息列表 + stream: 是否流式响应 + """ payload: Dict[str, Any] = { "model": "mock-model", "messages": messages, "stream": stream, } - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=server_app), - base_url="http://test", - ) as client: - response = await client.post( - "/openai/v1/chat/completions", - json=payload, - timeout=60.0, - ) + if isinstance(server_url_or_app, str): + # 使用真实的 HTTP 连接 + async with httpx.AsyncClient(base_url=server_url_or_app) as client: + response = await client.post( + "/openai/v1/chat/completions", + json=payload, + timeout=60.0, + ) + else: + # 使用 ASGITransport (用于非流式测试) + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=server_url_or_app), + base_url="http://test", + ) as client: + response = await client.post( + "/openai/v1/chat/completions", + json=payload, + timeout=60.0, + ) assert response.status_code == 200 @@ -779,7 +844,16 @@ async def generator(): return generator() server = AgentRunServer(invoke_agent=invoke_agent) - return server.app + app = server.app + + # 启动真实的 HTTP 服务器 + base_url, uvicorn_server, thread = _start_server(app) + + yield base_url + + # 清理服务器 + uvicorn_server.should_exit = True + thread.join(timeout=5) # ============================================================================= @@ -1643,7 +1717,16 @@ async def generator(): return await agent.ainvoke(cast(Any, input_data)) server = AgentRunServer(invoke_agent=invoke_agent) - return server.app + app = server.app + + # 启动真实的 HTTP 服务器 + base_url, uvicorn_server, thread = _start_server(app) + + yield base_url + + # 清理服务器 + uvicorn_server.should_exit = True + thread.join(timeout=5) @pytest.mark.parametrize( "case_key,prompt", @@ -1704,7 +1787,16 @@ def generator(): return agent.invoke(cast(Any, input_data)) server = AgentRunServer(invoke_agent=invoke_agent) - return server.app + app = server.app + + # 启动真实的 HTTP 服务器 + base_url, uvicorn_server, thread = _start_server(app) + + yield base_url + + # 清理服务器 + uvicorn_server.should_exit = True + thread.join(timeout=5) @pytest.mark.parametrize( "case_key,prompt", @@ -1762,7 +1854,16 @@ async def generator(): return generator() server = AgentRunServer(invoke_agent=invoke_agent) - return server.app + app = server.app + + # 启动真实的 HTTP 服务器 + base_url, uvicorn_server, thread = _start_server(app) + + yield base_url + + # 清理服务器 + uvicorn_server.should_exit = True + thread.join(timeout=5) @pytest.fixture def server_app_async(self, agent_model): @@ -1796,7 +1897,16 @@ async def generator(): return generator() server = AgentRunServer(invoke_agent=invoke_agent) - return server.app + app = server.app + + # 启动真实的 HTTP 服务器 + base_url, uvicorn_server, thread = _start_server(app) + + yield base_url + + # 清理服务器 + uvicorn_server.should_exit = True + thread.join(timeout=5) @pytest.mark.parametrize( "case_key,prompt",