From a213be4fbb1fd764535413c04038dadaf6f3d950 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AF=92=E5=85=89?= <2510399607@qq.com> Date: Wed, 11 Feb 2026 21:14:30 +0800 Subject: [PATCH] feat(conversation_service): implement core session management with OTS backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit introduces the Conversation Service module, which provides session state persistence capabilities using Alibaba Cloud TableStore (OTS). Key components include: SessionStore: Manages session CRUD operations and cascading deletes. OTSBackend: Encapsulates OTS SDK operations for table creation and data manipulation. Data Models: Defines core data structures for sessions, events, and states. Utilities: Provides helper functions for state serialization and timestamp generation. Documentation: Includes usage examples and design documentation. Additionally, a new dependency on tablestore has been added to pyproject.toml to support the OTS operations. 
Signed-off-by: 寒光 <2510399607@qq.com> --- .gitignore | 9 + Makefile | 2 +- agentrun/conversation_service/README.md | 288 ++ agentrun/conversation_service/__init__.py | 44 + .../__ots_backend_async_template.py | 1272 +++++++++ .../__session_store_async_template.py | 764 ++++++ .../conversation_service/adapters/__init__.py | 21 + .../adapters/adk_adapter.py | 674 +++++ .../adapters/langchain_adapter.py | 248 ++ .../conversation_design.md | 262 ++ agentrun/conversation_service/model.py | 138 + agentrun/conversation_service/ots_backend.py | 2314 +++++++++++++++++ .../conversation_service/session_store.py | 1485 +++++++++++ agentrun/conversation_service/utils.py | 99 + agentrun/integration/utils/tool.py | 4 +- agentrun/memory_collection/README.md | 406 +++ agentrun/toolset/api/mcp.py | 2 +- codegen/codegen.py | 1 + examples/conversation_service_adk_agent.py | 128 + pyproject.toml | 4 + .../conversation_service/__init__.py | 0 .../conversation_service/test_adk_adapter.py | 824 ++++++ .../test_langchain_adapter.py | 544 ++++ .../conversation_service/test_model.py | 224 ++ .../conversation_service/test_ots_backend.py | 1953 ++++++++++++++ .../test_session_store.py | 1404 ++++++++++ .../conversation_service/test_utils.py | 147 ++ .../test_langchain_agui_integration.py | 8 +- tests/unittests/toolset/api/test_openapi.py | 20 +- 29 files changed, 13278 insertions(+), 11 deletions(-) create mode 100644 agentrun/conversation_service/README.md create mode 100644 agentrun/conversation_service/__init__.py create mode 100644 agentrun/conversation_service/__ots_backend_async_template.py create mode 100644 agentrun/conversation_service/__session_store_async_template.py create mode 100644 agentrun/conversation_service/adapters/__init__.py create mode 100644 agentrun/conversation_service/adapters/adk_adapter.py create mode 100644 agentrun/conversation_service/adapters/langchain_adapter.py create mode 100644 agentrun/conversation_service/conversation_design.md create mode 100644 
agentrun/conversation_service/model.py create mode 100644 agentrun/conversation_service/ots_backend.py create mode 100644 agentrun/conversation_service/session_store.py create mode 100644 agentrun/conversation_service/utils.py create mode 100644 agentrun/memory_collection/README.md create mode 100644 examples/conversation_service_adk_agent.py create mode 100644 tests/unittests/conversation_service/__init__.py create mode 100644 tests/unittests/conversation_service/test_adk_adapter.py create mode 100644 tests/unittests/conversation_service/test_langchain_adapter.py create mode 100644 tests/unittests/conversation_service/test_model.py create mode 100644 tests/unittests/conversation_service/test_ots_backend.py create mode 100644 tests/unittests/conversation_service/test_session_store.py create mode 100644 tests/unittests/conversation_service/test_utils.py diff --git a/.gitignore b/.gitignore index eba76a7..7a2da69 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,12 @@ dmypy.json uv.lock coverage.json coverage.json + +# examples +examples/conversation_service_adk_example.py +examples/conversation_service_adk_data.py +examples/conversation_service_langchain_example.py +examples/conversation_service_langchain_data.py +examples/conversation_service_verify.py +examples/Langchain_His_example.py +examples/agent-quickstart-langchain/ \ No newline at end of file diff --git a/Makefile b/Makefile index a00124f..b43aad4 100644 --- a/Makefile +++ b/Makefile @@ -149,5 +149,5 @@ mypy-check: ## 运行 mypy 类型检查 .PHONY: coverage coverage: ## 运行测试并显示覆盖率报告(全量代码 + 增量代码) @echo "📊 运行覆盖率测试..." 
- @uv run python scripts/check_coverage.py + @uv run --python ${PYTHON_VERSION} --all-extras python scripts/check_coverage.py $(COVERAGE_ARGS) diff --git a/agentrun/conversation_service/README.md b/agentrun/conversation_service/README.md new file mode 100644 index 0000000..da51c0f --- /dev/null +++ b/agentrun/conversation_service/README.md @@ -0,0 +1,288 @@ +# Conversation Service + +为不同 Agent 开发框架提供**统一的会话状态持久化**能力,底层存储选用阿里云 TableStore(OTS,宽表模型)。 + +## 架构概览 + +采用 **统一存储 + 中心 Service + 薄 Adapter** 的三层设计: + +``` +ADK Agent ──→ OTSSessionService ──┐ + │ ┌─────────────┐ ┌─────────┐ +LangChain ──→ OTSChatMessageHistory ──→│ SessionStore │───→│ OTS │ + │ │ (业务逻辑层) │───→│ Tables │ +LangGraph ──→ (LG Adapter) ─────┘ └─────────────┘ └─────────┘ + │ + OTSBackend + (存储操作层) +``` + +- **SessionStore**:核心业务层,理解 OTS 表结构,提供 Session / Event / State 的 CRUD、级联删除、三级状态合并等统一接口。 +- **OTSBackend**:存储操作层,封装 TableStore SDK 的底层调用。 +- **Adapter**:薄适配层,仅负责框架数据模型转换。 + +## 快速开始 + +### 前置条件 + +- 阿里云账号,配置 AK/SK 环境变量 +- AgentRun 平台上已创建 MemoryCollection(包含 OTS 实例配置) + +### 安装 + +```bash +pip install agentrun +``` + +### 初始化 + +**方式一(推荐):通过 MemoryCollection 自动获取 OTS 连接信息** + +```python +from agentrun.conversation_service import SessionStore + +# 环境变量:AGENTRUN_ACCESS_KEY_ID / AGENTRUN_ACCESS_KEY_SECRET +store = SessionStore.from_memory_collection("your-memory-collection-name") + +# 首次使用时创建表 +store.init_tables() +``` + +`from_memory_collection()` 内部自动完成: +1. 调用 AgentRun API 获取 MemoryCollection 配置 +2. 从中提取 OTS 的 endpoint 和 instance_name +3. 从 `Config` 读取 AK/SK 凭证 +4. 
构建 OTSClient 和 OTSBackend + +**方式二:手动传入 OTSClient** + +```python +import tablestore +from agentrun.conversation_service import SessionStore, OTSBackend + +ots_client = tablestore.OTSClient( + endpoint, access_key_id, access_key_secret, instance_name, + retry_policy=tablestore.WriteRetryPolicy(), +) +backend = OTSBackend(ots_client) +store = SessionStore(backend) +store.init_tables() +``` + +### 表初始化策略 + +表和索引按用途分组创建,避免创建不必要的表: + +| 方法 | 创建的资源 | 适用场景 | +|------|-----------|---------| +| `init_core_tables()` | Conversation + Event + 二级索引 | 所有框架 | +| `init_state_tables()` | State + App_state + User_state | ADK 三级 State | +| `init_search_index()` | 多元索引(conversation_search_index) | 需要搜索/过滤 | +| `init_tables()` | 以上全部 | 快速开发 | + +> 多元索引创建耗时较长(数秒级),建议与核心表创建分离,不阻塞核心流程。 + +## 使用示例 + +### Google ADK 集成 + +```python +import asyncio +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSSessionService +from google.adk.agents import Agent +from google.adk.runners import Runner + +# 初始化 +store = SessionStore.from_memory_collection("my-collection") +store.init_tables() +session_service = OTSSessionService(session_store=store) + +# 创建 Agent + Runner +agent = Agent(name="assistant", model=my_model, instruction="...") +runner = Runner(agent=agent, app_name="my_app", session_service=session_service) + +# 对话自动持久化到 OTS +async def chat(): + session = await session_service.create_session( + app_name="my_app", user_id="user_1" + ) + async for event in runner.run_async( + user_id="user_1", session_id=session.id, new_message=content + ): + ... 
+ +asyncio.run(chat()) +``` + +### LangChain 集成 + +```python +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSChatMessageHistory +from langchain_core.messages import HumanMessage, AIMessage + +# 初始化 +store = SessionStore.from_memory_collection("my-collection") +store.init_core_tables() + +# 创建消息历史(自动关联 Session) +history = OTSChatMessageHistory( + session_store=store, + agent_id="my_agent", + user_id="user_1", + session_id="session_1", +) + +# 添加消息(自动持久化到 OTS) +history.add_message(HumanMessage(content="你好")) +history.add_message(AIMessage(content="你好!有什么可以帮你的?")) + +# 读取历史消息 +for msg in history.messages: + print(f"{msg.type}: {msg.content}") +``` + +### 直接使用 SessionStore + +```python +from agentrun.conversation_service import SessionStore + +store = SessionStore.from_memory_collection("my-collection") +store.init_tables() + +# Session CRUD +session = store.create_session("agent_1", "user_1", "sess_1", summary="测试会话") +sessions = store.list_sessions("agent_1", "user_1") + +# Event CRUD +event = store.append_event("agent_1", "user_1", "sess_1", "message", {"text": "hello"}) +events = store.get_events("agent_1", "user_1", "sess_1") +recent = store.get_recent_events("agent_1", "user_1", "sess_1", n=10) + +# 三级 State 管理(ADK 概念) +store.update_app_state("agent_1", {"model": "qwen-max"}) +store.update_user_state("agent_1", "user_1", {"language": "zh-CN"}) +store.update_session_state("agent_1", "user_1", "sess_1", {"topic": "weather"}) +merged = store.get_merged_state("agent_1", "user_1", "sess_1") +# merged = app_state <- user_state <- session_state(浅合并) + +# 多元索引搜索 +results, total = store.search_sessions( + "agent_1", + summary_keyword="天气", + updated_after=1700000000000000, + limit=20, +) + +# 级联删除(Event → State → Session 行) +store.delete_session("agent_1", "user_1", "sess_1") +``` + +## API 参考 + +### SessionStore + +核心业务层,所有方法同时提供同步和异步(`_async` 后缀)版本。 + +**工厂方法** + +| 方法 | 说明 | +|------|------| +| 
`from_memory_collection(name, *, config, table_prefix)` | 通过 MemoryCollection 名称创建实例 | + +**初始化** + +| 方法 | 说明 | +|------|------| +| `init_tables()` | 创建所有表和索引 | +| `init_core_tables()` | 创建核心表 + 二级索引 | +| `init_state_tables()` | 创建三张 State 表 | +| `init_search_index()` | 创建多元索引 | + +**Session 管理** + +| 方法 | 说明 | +|------|------| +| `create_session(agent_id, user_id, session_id, ...)` | 创建新会话 | +| `get_session(agent_id, user_id, session_id)` | 获取单个会话 | +| `list_sessions(agent_id, user_id, limit)` | 列出用户会话(按 updated_at 倒序) | +| `list_all_sessions(agent_id, limit)` | 列出 agent 下所有会话 | +| `search_sessions(agent_id, *, user_id, summary_keyword, ...)` | 多元索引搜索会话 | +| `update_session(agent_id, user_id, session_id, *, version, ...)` | 更新会话属性(乐观锁) | +| `delete_session(agent_id, user_id, session_id)` | 级联删除会话 | + +**Event 管理** + +| 方法 | 说明 | +|------|------| +| `append_event(agent_id, user_id, session_id, event_type, content)` | 追加事件 | +| `get_events(agent_id, user_id, session_id)` | 获取全部事件(正序) | +| `get_recent_events(agent_id, user_id, session_id, n)` | 获取最近 N 条事件 | +| `delete_events(agent_id, user_id, session_id)` | 删除会话下所有事件 | + +**State 管理** + +| 方法 | 说明 | +|------|------| +| `get_session_state / update_session_state` | 会话级状态读写 | +| `get_app_state / update_app_state` | 应用级状态读写 | +| `get_user_state / update_user_state` | 用户级状态读写 | +| `get_merged_state(agent_id, user_id, session_id)` | 三级状态浅合并 | + +### 框架适配器 + +| 适配器 | 框架 | 基类 | +|--------|------|------| +| `OTSSessionService` | Google ADK | `BaseSessionService` | +| `OTSChatMessageHistory` | LangChain | `BaseChatMessageHistory` | + +### 领域模型 + +| 模型 | 说明 | +|------|------| +| `ConversationSession` | 会话对象(含 agent_id, user_id, session_id, summary, labels 等) | +| `ConversationEvent` | 事件对象(含 seq_id 自增序号、type、content、raw_event) | +| `StateData` | 状态数据对象(含 state 字典、version 乐观锁) | +| `StateScope` | 状态作用域枚举:APP / USER / SESSION | + +## OTS 表结构 + +共五张表 + 一个二级索引 + 一个多元索引: + +| 表名 | 主键 | 用途 | +|------|------|------| +| 
`conversation` | agent_id, user_id, session_id | 会话元信息 | +| `event` | agent_id, user_id, session_id, seq_id (自增) | 事件/消息流 | +| `state` | agent_id, user_id, session_id | 会话级状态 | +| `app_state` | agent_id | 应用级状态 | +| `user_state` | agent_id, user_id | 用户级状态 | +| `conversation_secondary_index` | agent_id, user_id, updated_at, session_id | 二级索引(list 热路径) | +| `conversation_search_index` | 多元索引 | 全文搜索 / 标签过滤 / 组合查询 | + +> 表名支持通过 `table_prefix` 参数添加前缀,实现多租户隔离。 + +## 示例代码 + +| 文件 | 说明 | +|------|------| +| [`conversation_service_adk_agent.py`](../../examples/conversation_service_adk_agent.py) | ADK Agent 完整对话示例,自动持久化到 OTS | +| [`conversation_service_adk_example.py`](../../examples/conversation_service_adk_example.py) | ADK 数据读写验证(Session / Event / State) | +| [`conversation_service_adk_data.py`](../../examples/conversation_service_adk_data.py) | ADK 模拟数据填充 + 多元索引搜索验证 | +| [`conversation_service_langchain_example.py`](../../examples/conversation_service_langchain_example.py) | LangChain 消息历史读写验证 | +| [`conversation_service_langchain_data.py`](../../examples/conversation_service_langchain_data.py) | LangChain 模拟数据填充 | +| [`conversation_service_verify.py`](../../examples/conversation_service_verify.py) | 端到端 CRUD 验证脚本 | + +## 环境变量 + +| 变量 | 说明 | 必填 | +|------|------|------| +| `AGENTRUN_ACCESS_KEY_ID` | 阿里云 Access Key ID | 是(使用 `from_memory_collection` 时) | +| `AGENTRUN_ACCESS_KEY_SECRET` | 阿里云 Access Key Secret | 是(使用 `from_memory_collection` 时) | +| `ALIBABA_CLOUD_ACCESS_KEY_ID` | 备选 AK 环境变量 | 否(AK 候选) | +| `ALIBABA_CLOUD_ACCESS_KEY_SECRET` | 备选 SK 环境变量 | 否(SK 候选) | +| `MEMORY_COLLECTION_NAME` | MemoryCollection 名称(示例脚本使用) | 否 | + +## 设计文档 + +详细的表设计、访问模式分析和分层架构说明见 [conversation_design.md](./conversation_design.md)。 diff --git a/agentrun/conversation_service/__init__.py b/agentrun/conversation_service/__init__.py new file mode 100644 index 0000000..a43e700 --- /dev/null +++ b/agentrun/conversation_service/__init__.py @@ -0,0 +1,44 @@ +"""Conversation Service 模块。 + +为不同 
Agent 开发框架提供会话状态持久化能力, +持久化数据库选用阿里云 TableStore(OTS,宽表模型)。 + +使用方式:: + + # 方式一(推荐):通过 MemoryCollection 自动获取 OTS 连接信息 + from agentrun.conversation_service import SessionStore + + store = SessionStore.from_memory_collection("your-memory-collection-name") + store.init_tables() + + # 方式二:手动传入 OTSClient + import tablestore + from agentrun.conversation_service import SessionStore, OTSBackend + + ots_client = tablestore.OTSClient( + endpoint, access_key_id, access_key_secret, instance_name, + ) + backend = OTSBackend(ots_client) + store = SessionStore(backend) + store.init_tables() +""" + +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, + StateData, + StateScope, +) +from agentrun.conversation_service.ots_backend import OTSBackend +from agentrun.conversation_service.session_store import SessionStore + +__all__ = [ + # 核心服务 + "SessionStore", + "OTSBackend", + # 领域模型 + "ConversationSession", + "ConversationEvent", + "StateData", + "StateScope", +] diff --git a/agentrun/conversation_service/__ots_backend_async_template.py b/agentrun/conversation_service/__ots_backend_async_template.py new file mode 100644 index 0000000..f4821bf --- /dev/null +++ b/agentrun/conversation_service/__ots_backend_async_template.py @@ -0,0 +1,1272 @@ +"""OTS 存储后端。 + +封装 TableStore SDK 的底层操作,负责五张表的建表和 CRUD。 +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Optional + +from tablestore import AsyncOTSClient # type: ignore[import-untyped] +from tablestore import BatchWriteRowRequest # type: ignore[import-untyped] +from tablestore import ( + CapacityUnit, + ComparatorType, + Condition, + DeleteRowItem, + Direction, + INF_MAX, + INF_MIN, + OTSClient, + OTSServiceError, + PK_AUTO_INCR, + ReservedThroughput, + ReturnType, + Row, + RowExistenceExpectation, + SecondaryIndexMeta, + SecondaryIndexType, + SingleColumnCondition, + TableInBatchWriteRowItem, + TableMeta, + TableOptions, +) + +from 
async def init_tables_async(self) -> None:
    """Create the five tables and the Conversation secondary index (async).

    Delegates to ``init_core_tables_async`` (Conversation table together
    with its secondary index, then the Event table) followed by
    ``init_state_tables_async`` (State, App_state and User_state tables),
    so each table definition lives in exactly one place. The creation
    order is the same as before: Conversation, Event, State, App_state,
    User_state.

    The Conversation search index is not created here; call
    ``init_search_index_async`` separately when it is needed.

    Tables that already exist are skipped by the underlying ``_create_*``
    helpers, which log a warning instead of raising.
    """
    await self.init_core_tables_async()
    await self.init_state_tables_async()
async def _create_state_table_async(
    self,
    table_name: str,
    pk_schema: list[tuple[str, str]],
) -> None:
    """Create one State-style table (shared helper, async).

    Args:
        table_name: Name of the table to create.
        pk_schema: Primary-key schema as ``(column_name, type_name)``
            pairs, passed straight through to ``TableMeta``.

    Raises:
        OTSServiceError: Re-raised for any service error that does not
            indicate the table already exists.
    """
    meta = TableMeta(table_name, pk_schema)
    options = TableOptions()
    throughput = ReservedThroughput(CapacityUnit(0, 0))

    try:
        await self._async_client.create_table(meta, options, throughput)
    except OTSServiceError as err:
        # An "already exists" failure is recognised either from the
        # message text or from the error code carried by the exception.
        already_there = (
            "already exist" in str(err).lower()
            or getattr(err, "code", None) == "OTSObjectAlreadyExist"
        )
        if not already_there:
            raise
        logger.warning(
            "Table %s already exists, skipping.",
            table_name,
        )
    else:
        logger.info("Created table: %s", table_name)
async def put_session_async(self, session: ConversationSession) -> None:
    """Write a Session row with PutRow, overwriting any existing row (async).

    ``created_at``, ``updated_at``, ``is_pinned`` and ``version`` are
    always written. ``summary``, ``labels`` and ``framework`` are written
    only when they are not ``None``; ``extensions`` is additionally
    JSON-encoded (non-ASCII characters kept as-is) before being stored.

    Args:
        session: The session whose fields are persisted.
    """
    pk = [
        ("agent_id", session.agent_id),
        ("user_id", session.user_id),
        ("session_id", session.session_id),
    ]

    columns = [
        ("created_at", session.created_at),
        ("updated_at", session.updated_at),
        ("is_pinned", session.is_pinned),
        ("version", session.version),
    ]

    # Optional plain columns, appended in a fixed order when present.
    for column_name in ("summary", "labels", "framework"):
        value = getattr(session, column_name)
        if value is not None:
            columns.append((column_name, value))

    extensions = session.extensions
    if extensions is not None:
        columns.append(
            ("extensions", json.dumps(extensions, ensure_ascii=False))
        )

    session_row = Row(pk, columns)
    # IGNORE: no existence check, so this is an unconditional upsert.
    no_check = Condition(RowExistenceExpectation.IGNORE)
    await self._async_client.put_row(
        self._conversation_table, session_row, no_check
    )
async def update_session_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    columns_to_put: dict[str, Any],
    version: int,
) -> None:
    """Update a Session row with UpdateRow under an optimistic lock (async).

    The update is conditional: the row must already exist and its stored
    ``version`` column must equal ``version``. This method writes exactly
    the columns given in ``columns_to_put``; it does not modify
    ``version`` on its own.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        columns_to_put: Mapping of column name to the new value.
        version: Version the caller last observed (optimistic-lock check).
    """
    # NOTE(review): a failed condition presumably surfaces as an SDK
    # exception from update_row -- confirm against the Tablestore SDK.
    pk = [
        ("agent_id", agent_id),
        ("user_id", user_id),
        ("session_id", session_id),
    ]
    changes = {"PUT": [*columns_to_put.items()]}
    target_row = Row(pk, changes)

    version_matches = SingleColumnCondition(
        "version",
        version,
        ComparatorType.EQUAL,
    )
    guard = Condition(RowExistenceExpectation.EXPECT_EXIST, version_matches)

    await self._async_client.update_row(
        self._conversation_table, target_row, guard
    )
async def list_all_sessions_async(
    self,
    agent_id: str,
    limit: Optional[int] = None,
) -> list[ConversationSession]:
    """Scan every user's Sessions under ``agent_id`` (main-table GetRange, async).

    Reads the Conversation table directly rather than the secondary
    index, paging through GetRange until the range is exhausted or
    ``limit`` rows have been collected.

    Args:
        agent_id: Agent ID.
        limit: Maximum number of sessions to return; ``None`` means all.
            A value of zero or less returns an empty list without
            contacting the service (previously it was forwarded verbatim
            to GetRange).

    Returns:
        List of ``ConversationSession`` objects in forward primary-key
        order.
    """
    if limit is not None and limit <= 0:
        return []

    range_start = [
        ("agent_id", agent_id),
        ("user_id", INF_MIN),
        ("session_id", INF_MIN),
    ]
    range_end = [
        ("agent_id", agent_id),
        ("user_id", INF_MAX),
        ("session_id", INF_MAX),
    ]

    sessions: list[ConversationSession] = []
    next_start = range_start

    while True:
        # Ask only for the rows still needed, so later pages do not
        # over-fetch. This stays >= 1 because the loop returns as soon as
        # the limit is reached.
        remaining = None if limit is None else limit - len(sessions)
        _, next_token, rows, _ = await self._async_client.get_range(
            self._conversation_table,
            Direction.FORWARD,
            next_start,
            range_end,
            max_version=1,
            limit=remaining,
        )

        for row in rows:
            sessions.append(self._row_to_session(row))
            if limit is not None and len(sessions) >= limit:
                return sessions

        if next_token is None:
            break
        next_start = next_token

    return sessions
None, + summary_keyword: Optional[str] = None, + labels: Optional[str] = None, + framework: Optional[str] = None, + updated_after: Optional[int] = None, + updated_before: Optional[int] = None, + is_pinned: Optional[bool] = None, + limit: int = 20, + offset: int = 0, + ) -> tuple[list[ConversationSession], int]: + """通过多元索引搜索 Session(异步)。 + + 支持全文搜索 summary、精确过滤 labels/framework/is_pinned、 + 范围查询 updated_at 以及跨 user_id 查询。 + + Args: + agent_id: 智能体 ID(必填,作为 routing 键优化查询)。 + user_id: 用户 ID(可选,精确匹配)。 + summary_keyword: summary 关键词(全文搜索)。 + labels: 标签 JSON 字符串(精确匹配)。 + framework: 框架标识(精确匹配)。 + updated_after: 仅返回 updated_at >= 此值的记录。 + updated_before: 仅返回 updated_at < 此值的记录。 + is_pinned: 是否置顶过滤。 + limit: 最多返回条数,默认 20。 + offset: 分页偏移量,默认 0。 + + Returns: + (结果列表, 总匹配数) 二元组。 + """ + from tablestore import BoolQuery # type: ignore[import-untyped] + from tablestore import MatchQuery # type: ignore[import-untyped] + from tablestore import SortOrder # type: ignore[import-untyped] + from tablestore import TermQuery # type: ignore[import-untyped] + from tablestore import ColumnReturnType, ColumnsToGet + from tablestore import ( + FieldSort as OTSFieldSort, + ) # type: ignore[import-untyped] + from tablestore import RangeQuery, SearchQuery + from tablestore import Sort as OTSSort # type: ignore[import-untyped] + + must_queries: list[Any] = [ + TermQuery("agent_id", agent_id), + ] + + if user_id is not None: + must_queries.append(TermQuery("user_id", user_id)) + if summary_keyword is not None: + must_queries.append(MatchQuery("summary", summary_keyword)) + if labels is not None: + must_queries.append(TermQuery("labels", labels)) + if framework is not None: + must_queries.append(TermQuery("framework", framework)) + if is_pinned is not None: + must_queries.append( + TermQuery("is_pinned", "true" if is_pinned else "false") + ) + if updated_after is not None or updated_before is not None: + must_queries.append( + RangeQuery( + "updated_at", + range_from=updated_after, + 
include_lower=True if updated_after is not None else None, + range_to=updated_before, + include_upper=False if updated_before is not None else None, + ) + ) + + query = BoolQuery(must_queries=must_queries) + + search_query = SearchQuery( + query, + sort=OTSSort( + sorters=[OTSFieldSort("updated_at", sort_order=SortOrder.DESC)] + ), + limit=limit, + offset=offset, + get_total_count=True, + ) + + columns_to_get = ColumnsToGet( + return_type=ColumnReturnType.ALL, + ) + + search_response = await self._async_client.search( + self._conversation_table, + self._conversation_search_index, + search_query, + columns_to_get, + ) + + sessions: list[ConversationSession] = [] + for row in search_response.rows: + # search API 返回 (primary_key, attribute_columns) 元组, + # 需要包装为 Row 对象以复用 _row_to_session + if isinstance(row, tuple): + row = Row(row[0], row[1]) + sessions.append(self._row_to_session(row)) + + return sessions, search_response.total_count or 0 + + # ----------------------------------------------------------------------- + # Event CRUD(异步)/ Event CRUD (async) + # ----------------------------------------------------------------------- + + async def put_event_async( + self, + agent_id: str, + user_id: str, + session_id: str, + event_type: str, + content: dict[str, Any], + created_at: Optional[int] = None, + updated_at: Optional[int] = None, + raw_event: Optional[str] = None, + ) -> int: + """PutRow 写入事件(seq_id AUTO_INCREMENT),返回 OTS 生成的 seq_id(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + event_type: 事件类型。 + content: 事件数据。 + created_at: 创建时间(纳秒时间戳),默认当前时间。 + updated_at: 更新时间(纳秒时间戳),默认当前时间。 + raw_event: 框架原生 Event 的完整 JSON 序列化(可选)。 + 用于精确还原框架特定的 Event 对象(如 ADK Event)。 + + Returns: + OTS 生成的 seq_id。 + """ + now = nanoseconds_timestamp() + if created_at is None: + created_at = now + if updated_at is None: + updated_at = now + + primary_key = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", 
PK_AUTO_INCR), + ] + + content_json = json.dumps(content, ensure_ascii=False) + attribute_columns = [ + ("type", event_type), + ("content", content_json), + ("created_at", created_at), + ("updated_at", updated_at), + ("version", 0), + ] + + if raw_event is not None: + attribute_columns.append(("raw_event", raw_event)) + + row = Row(primary_key, attribute_columns) + condition = Condition(RowExistenceExpectation.IGNORE) + + # put_row 返回 (consumed, return_row) + # 使用 ReturnType.RT_PK 让 OTS 返回自增 PK 值 + _, return_row = await self._async_client.put_row( + self._event_table, + row, + condition, + return_type=ReturnType.RT_PK, + ) + + # 从返回的主键中提取 seq_id + seq_id: int = 0 + if return_row is not None and return_row.primary_key is not None: + for pk_col in return_row.primary_key: + if pk_col[0] == "seq_id": + seq_id = pk_col[1] # type: ignore[assignment] + break + + return seq_id + + async def get_events_async( + self, + agent_id: str, + user_id: str, + session_id: str, + direction: str = "FORWARD", + limit: Optional[int] = None, + ) -> list[ConversationEvent]: + """GetRange 扫描事件列表(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + direction: 'FORWARD'(正序)或 'BACKWARD'(倒序)。 + limit: 最多返回条数。 + """ + if direction == "BACKWARD": + inclusive_start = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MAX), + ] + exclusive_end = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MIN), + ] + ots_direction = Direction.BACKWARD + else: + inclusive_start = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MIN), + ] + exclusive_end = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MAX), + ] + ots_direction = Direction.FORWARD + + events: list[ConversationEvent] = [] + next_start = inclusive_start + + while True: + ( + _, + next_token, + rows, + _, + ) = await 
self._async_client.get_range( + self._event_table, + ots_direction, + next_start, + exclusive_end, + max_version=1, + limit=limit, + ) + + for row in rows: + event = self._row_to_event(row) + events.append(event) + if limit is not None and len(events) >= limit: + return events + + if next_token is None: + break + next_start = next_token + + return events + + async def delete_events_by_session_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> int: + """批量删除 Session 下所有 Event,返回删除条数(异步)。 + + 先 GetRange 扫出所有 PK,再分批 BatchWriteRow 删除。 + """ + # 1. 扫描所有 Event 的 PK + inclusive_start = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MIN), + ] + exclusive_end = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MAX), + ] + + all_pks: list[list[tuple[str, Any]]] = [] + next_start = inclusive_start + + while True: + ( + _, + next_token, + rows, + _, + ) = await self._async_client.get_range( + self._event_table, + Direction.FORWARD, + next_start, + exclusive_end, + columns_to_get=[], # 只取 PK,不读属性列 + max_version=1, + ) + + for row in rows: + all_pks.append(row.primary_key) + + if next_token is None: + break + next_start = next_token + + if not all_pks: + return 0 + + # 2. 
分批 BatchWriteRow 删除 + deleted = 0 + for i in range(0, len(all_pks), _BATCH_WRITE_LIMIT): + batch = all_pks[i : i + _BATCH_WRITE_LIMIT] + delete_items = [] + for pk in batch: + row = Row(pk) + condition = Condition(RowExistenceExpectation.IGNORE) + delete_items.append(DeleteRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add( + TableInBatchWriteRowItem(self._event_table, delete_items) + ) + await self._async_client.batch_write_row(request) + deleted += len(batch) + + return deleted + + # ----------------------------------------------------------------------- + # State CRUD(JSON 字符串存储 + 列分片)(异步) + # ----------------------------------------------------------------------- + + async def put_state_async( + self, + scope: StateScope, + agent_id: str, + user_id: str, + session_id: str, + state: dict[str, Any], + version: int, + ) -> None: + """序列化 + 列分片写入 State(异步)。 + + State 以 JSON 字符串(STRING 类型)存储,不压缩。 + 当 JSON 字符串超过 1.5M 字符时自动分片。 + + Args: + scope: 状态作用域(APP / USER / SESSION)。 + agent_id: 智能体 ID。 + user_id: 用户 ID(APP scope 时忽略)。 + session_id: 会话 ID(APP/USER scope 时忽略)。 + state: 状态字典。 + version: 当前版本号(乐观锁校验,首次写入传 0)。 + """ + table_name, primary_key = self._resolve_state_table_and_pk( + scope, agent_id, user_id, session_id + ) + + now = nanoseconds_timestamp() + state_json = serialize_state(state) + + put_cols: list[tuple[str, Any]] = [ + ("updated_at", now), + ("version", version + 1), + ] + + # 首次写入需要 created_at + if version == 0: + put_cols.append(("created_at", now)) + + if len(state_json) <= MAX_COLUMN_SIZE: + # 不分片 + new_chunk_count = 0 + put_cols.append(("chunk_count", 0)) + put_cols.append(("state", state_json)) + else: + # 分片 + chunks = to_chunks(state_json) + new_chunk_count = len(chunks) + put_cols.append(("chunk_count", new_chunk_count)) + for idx, chunk in enumerate(chunks): + put_cols.append((f"state_{idx}", chunk)) + + update_of_attribute_columns: dict[str, Any] = {"PUT": put_cols} + + # 如果是更新(version > 0),需要清理旧的分片列 + delete_cols: 
list[str] = [] + if version > 0: + old_chunk_count = await self._get_chunk_count_async( + table_name, primary_key + ) + + if new_chunk_count == 0 and old_chunk_count > 0: + # 旧的有分片,新的不分片:删除所有 state_N 列 + for i in range(old_chunk_count): + delete_cols.append(f"state_{i}") + elif new_chunk_count > 0 and old_chunk_count == 0: + # 旧的不分片,新的有分片:删除 state 列 + delete_cols.append("state") + elif new_chunk_count > 0 and old_chunk_count > new_chunk_count: + # 都分片,但旧的分片更多:删除多余分片列 + for i in range(new_chunk_count, old_chunk_count): + delete_cols.append(f"state_{i}") + + if delete_cols: + update_of_attribute_columns["DELETE_ALL"] = delete_cols + + row = Row(primary_key, update_of_attribute_columns) + + if version == 0: + # 首次写入 + condition = Condition(RowExistenceExpectation.IGNORE) + else: + condition = Condition( + RowExistenceExpectation.EXPECT_EXIST, + SingleColumnCondition( + "version", + version, + ComparatorType.EQUAL, + ), + ) + + await self._async_client.update_row(table_name, row, condition) + + async def get_state_async( + self, + scope: StateScope, + agent_id: str, + user_id: str, + session_id: str, + ) -> Optional[StateData]: + """读取 + 拼接分片 + 反序列化 State(异步)。""" + table_name, primary_key = self._resolve_state_table_and_pk( + scope, agent_id, user_id, session_id + ) + + _, row, _ = await self._async_client.get_row( + table_name, + primary_key, + max_version=1, + ) + + if row is None or row.primary_key is None: + return None + + attrs = self._attrs_to_dict(row.attribute_columns) + + chunk_count = attrs.get("chunk_count", 0) + if chunk_count == 0: + raw_state = attrs.get("state") + if raw_state is None: + return None + state = deserialize_state(str(raw_state)) + else: + chunks: list[str] = [] + for i in range(chunk_count): + chunk = attrs.get(f"state_{i}") + if chunk is None: + raise ValueError(f"Missing state chunk: state_{i}") + chunks.append(str(chunk)) + merged_str = from_chunks(chunks) + state = deserialize_state(merged_str) + + return StateData( + state=state, + 
created_at=attrs.get("created_at", 0), + updated_at=attrs.get("updated_at", 0), + version=attrs.get("version", 0), + ) + + async def delete_state_row_async( + self, + scope: StateScope, + agent_id: str, + user_id: str, + session_id: str, + ) -> None: + """删除 State 行(异步)。""" + table_name, primary_key = self._resolve_state_table_and_pk( + scope, agent_id, user_id, session_id + ) + row = Row(primary_key) + condition = Condition(RowExistenceExpectation.IGNORE) + await self._async_client.delete_row(table_name, row, condition) + + # ----------------------------------------------------------------------- + # 内部辅助方法(I/O 相关,异步) + # ----------------------------------------------------------------------- + + async def _get_chunk_count_async( + self, + table_name: str, + primary_key: list[tuple[str, str]], + ) -> int: + """读取指定行的 chunk_count 值(异步)。""" + _, row, _ = await self._async_client.get_row( + table_name, + primary_key, + columns_to_get=["chunk_count"], + max_version=1, + ) + if row is None or row.primary_key is None: + return 0 + + attrs = self._attrs_to_dict(row.attribute_columns) + return attrs.get("chunk_count", 0) + + # ----------------------------------------------------------------------- + # 内部辅助方法(纯计算,不涉及 I/O,保持同步) + # ----------------------------------------------------------------------- + + def _resolve_state_table_and_pk( + self, + scope: StateScope, + agent_id: str, + user_id: str, + session_id: str, + ) -> tuple[str, list[tuple[str, str]]]: + """根据 scope 返回对应的表名和主键列表。""" + if scope == StateScope.APP: + return self._app_state_table, [ + ("agent_id", agent_id), + ] + elif scope == StateScope.USER: + return self._user_state_table, [ + ("agent_id", agent_id), + ("user_id", user_id), + ] + else: # SESSION + return self._state_table, [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ] + + @staticmethod + def _attrs_to_dict( + attribute_columns: list[Any], + ) -> dict[str, Any]: + """将 OTS 属性列列表转换为字典。 + + OTS 返回的属性列格式为 [(name, 
value, timestamp), ...] + """ + result: dict[str, Any] = {} + if attribute_columns is None: + return result + for col in attribute_columns: + # col 格式: (name, value, timestamp) + name = col[0] + value = col[1] + result[name] = value + return result + + @staticmethod + def _pk_to_dict( + primary_key: list[Any], + ) -> dict[str, Any]: + """将 OTS 主键列表转换为字典。""" + result: dict[str, Any] = {} + if primary_key is None: + return result + for col in primary_key: + name = col[0] + value = col[1] + result[name] = value + return result + + def _row_to_session(self, row: Row) -> ConversationSession: + """将 OTS Row 转换为 ConversationSession。""" + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + + extensions = None + ext_raw = attrs.get("extensions") + if ext_raw is not None and isinstance(ext_raw, str): + extensions = json.loads(ext_raw) + + return ConversationSession( + agent_id=pk["agent_id"], + user_id=pk["user_id"], + session_id=pk["session_id"], + created_at=attrs.get("created_at", 0), + updated_at=attrs.get("updated_at", 0), + is_pinned=attrs.get("is_pinned", False), + summary=attrs.get("summary"), + labels=attrs.get("labels"), + framework=attrs.get("framework"), + extensions=extensions, + version=attrs.get("version", 0), + ) + + def _row_to_session_from_index(self, row: Row) -> ConversationSession: + """将二级索引 Row 转换为 ConversationSession。 + + 二级索引的 PK 包含 updated_at,属性列只有预定义的列。 + """ + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + + extensions = None + ext_raw = attrs.get("extensions") + if ext_raw is not None and isinstance(ext_raw, str): + extensions = json.loads(ext_raw) + + return ConversationSession( + agent_id=pk["agent_id"], + user_id=pk["user_id"], + session_id=pk["session_id"], + created_at=0, # 二级索引不含 created_at + updated_at=pk.get("updated_at", 0), + summary=attrs.get("summary"), + labels=attrs.get("labels"), + framework=attrs.get("framework"), + extensions=extensions, + ) 
+ + def _row_to_event(self, row: Row) -> ConversationEvent: + """将 OTS Row 转换为 ConversationEvent。""" + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + + content_raw = attrs.get("content", "{}") + if isinstance(content_raw, str): + content = json.loads(content_raw) + else: + content = {} + + return ConversationEvent( + agent_id=pk["agent_id"], + user_id=pk["user_id"], + session_id=pk["session_id"], + seq_id=pk.get("seq_id"), + type=attrs.get("type", ""), + content=content, + created_at=attrs.get("created_at", 0), + updated_at=attrs.get("updated_at", 0), + version=attrs.get("version", 0), + raw_event=attrs.get("raw_event"), + ) diff --git a/agentrun/conversation_service/__session_store_async_template.py b/agentrun/conversation_service/__session_store_async_template.py new file mode 100644 index 0000000..7f827a0 --- /dev/null +++ b/agentrun/conversation_service/__session_store_async_template.py @@ -0,0 +1,764 @@ +"""SessionStore 核心业务逻辑层。 + +提供框架无关的统一会话管理接口,包括 Session、Event、State 的 CRUD, +以及级联删除和三级状态合并。 +""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, + StateScope, +) +from agentrun.conversation_service.ots_backend import OTSBackend +from agentrun.conversation_service.utils import nanoseconds_timestamp + +logger = logging.getLogger(__name__) + + +class SessionStore: + """核心业务逻辑层。 + + 封装 OTSBackend,实现级联删除、状态合并等业务逻辑, + 向上暴露框架无关的统一接口。 + 同时提供异步(_async 后缀)和同步方法。 + + Args: + ots_backend: OTS 存储后端实例。 + """ + + def __init__(self, ots_backend: OTSBackend) -> None: + self._backend = ots_backend + + async def init_tables_async(self) -> None: + """创建所有 OTS 表和索引(异步)。代理到 OTSBackend.init_tables_async()。""" + await self._backend.init_tables_async() + + async def init_core_tables_async(self) -> None: + """创建核心表(Conversation + Event)和二级索引(异步)。""" + await self._backend.init_core_tables_async() + + 
async def init_state_tables_async(self) -> None: + """创建三张 State 表(异步)。""" + await self._backend.init_state_tables_async() + + async def init_search_index_async(self) -> None: + """创建 Conversation 多元索引(异步)。按需调用。""" + await self._backend.init_search_index_async() + + # ------------------------------------------------------------------- + # Session 管理(异步)/ Session management (async) + # ------------------------------------------------------------------- + + async def create_session_async( + self, + agent_id: str, + user_id: str, + session_id: str, + *, + is_pinned: bool = False, + summary: Optional[str] = None, + labels: Optional[str] = None, + framework: Optional[str] = None, + extensions: Optional[dict[str, Any]] = None, + ) -> ConversationSession: + """创建新 Session(异步)。 + + 自动设置 created_at 和 updated_at 为当前纳秒时间戳。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + is_pinned: 是否置顶。 + summary: 会话摘要。 + labels: 会话标签。 + framework: 框架标识。 + extensions: 框架扩展数据。 + + Returns: + 创建完成的 ConversationSession 对象。 + """ + now = nanoseconds_timestamp() + session = ConversationSession( + agent_id=agent_id, + user_id=user_id, + session_id=session_id, + created_at=now, + updated_at=now, + is_pinned=is_pinned, + summary=summary, + labels=labels, + framework=framework, + extensions=extensions, + version=0, + ) + await self._backend.put_session_async(session) + return session + + async def get_session_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> Optional[ConversationSession]: + """获取单个 Session(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + + Returns: + ConversationSession 对象,不存在时返回 None。 + """ + return await self._backend.get_session_async( + agent_id, user_id, session_id + ) + + async def list_sessions_async( + self, + agent_id: str, + user_id: str, + limit: Optional[int] = None, + ) -> list[ConversationSession]: + """列出用户的 Session(按 updated_at 倒序)(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + limit: 
最多返回条数,None 表示全部。 + + Returns: + ConversationSession 列表。 + """ + return await self._backend.list_sessions_async( + agent_id, user_id, limit=limit, order_desc=True + ) + + async def list_all_sessions_async( + self, + agent_id: str, + limit: Optional[int] = None, + ) -> list[ConversationSession]: + """列出 agent_id 下所有用户的 Session(异步)。 + + 不要求 user_id,扫描主表全量返回。 + 适用于 ADK list_sessions(user_id=None) 场景。 + + Args: + agent_id: 智能体 ID。 + limit: 最多返回条数,None 表示全部。 + + Returns: + ConversationSession 列表。 + """ + return await self._backend.list_all_sessions_async( + agent_id, limit=limit + ) + + async def search_sessions_async( + self, + agent_id: str, + *, + user_id: Optional[str] = None, + summary_keyword: Optional[str] = None, + labels: Optional[str] = None, + framework: Optional[str] = None, + updated_after: Optional[int] = None, + updated_before: Optional[int] = None, + is_pinned: Optional[bool] = None, + limit: int = 20, + offset: int = 0, + ) -> tuple[list[ConversationSession], int]: + """搜索会话(多元索引)(异步)。 + + 通过多元索引实现全文搜索 summary、标签过滤、跨 user 查询等高级查询。 + + Args: + agent_id: 智能体 ID(必填)。 + user_id: 用户 ID(可选,精确匹配)。 + summary_keyword: summary 关键词(全文搜索)。 + labels: 标签 JSON 字符串(精确匹配)。 + framework: 框架标识(精确匹配)。 + updated_after: 仅返回 updated_at >= 此值的记录。 + updated_before: 仅返回 updated_at < 此值的记录。 + is_pinned: 是否置顶过滤。 + limit: 最多返回条数,默认 20。 + offset: 分页偏移量,默认 0。 + + Returns: + (结果列表, 总匹配数) 二元组。 + """ + return await self._backend.search_sessions_async( + agent_id, + user_id=user_id, + summary_keyword=summary_keyword, + labels=labels, + framework=framework, + updated_after=updated_after, + updated_before=updated_before, + is_pinned=is_pinned, + limit=limit, + offset=offset, + ) + + async def delete_events_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> int: + """只删除 Session 下所有 Event,不删 Session 本身(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + + Returns: + 删除的事件条数。 + """ + deleted = await self._backend.delete_events_by_session_async( + 
agent_id, user_id, session_id + ) + logger.debug( + "Deleted %d events for session %s/%s/%s", + deleted, + agent_id, + user_id, + session_id, + ) + return deleted + + async def delete_session_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> None: + """级联删除 Session(异步)。 + + 删除顺序:Event → State → Session 行。 + 先删 Event(量最大),再删 State,最后删 Session 行。 + 如果中间失败,Session 行仍在,下次重试可继续清理(幂等安全)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + """ + # 1. 删除所有 Event + deleted_events = await self._backend.delete_events_by_session_async( + agent_id, user_id, session_id + ) + logger.debug( + "Deleted %d events for session %s/%s/%s", + deleted_events, + agent_id, + user_id, + session_id, + ) + + # 2. 删除 Session 级 State + await self._backend.delete_state_row_async( + StateScope.SESSION, + agent_id, + user_id, + session_id, + ) + + # 3. 删除 Session 行 + await self._backend.delete_session_row_async( + agent_id, user_id, session_id + ) + + logger.info( + "Cascade deleted session %s/%s/%s", + agent_id, + user_id, + session_id, + ) + + async def update_session_async( + self, + agent_id: str, + user_id: str, + session_id: str, + *, + is_pinned: Optional[bool] = None, + summary: Optional[str] = None, + labels: Optional[str] = None, + extensions: Optional[dict[str, Any]] = None, + version: int, + ) -> None: + """更新 Session 属性(乐观锁)(异步)。 + + 只更新提供的字段,未提供的字段不变。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + is_pinned: 是否置顶。 + summary: 会话摘要。 + labels: 会话标签。 + extensions: 框架扩展数据。 + version: 当前版本号(乐观锁校验)。 + """ + columns_to_put: dict[str, Any] = { + "updated_at": nanoseconds_timestamp(), + "version": version + 1, + } + + if is_pinned is not None: + columns_to_put["is_pinned"] = is_pinned + if summary is not None: + columns_to_put["summary"] = summary + if labels is not None: + columns_to_put["labels"] = labels + if extensions is not None: + import json + + columns_to_put["extensions"] = json.dumps( + extensions, ensure_ascii=False 
+ ) + + await self._backend.update_session_async( + agent_id, + user_id, + session_id, + columns_to_put, + version, + ) + + # ------------------------------------------------------------------- + # Event 管理(异步)/ Event management (async) + # ------------------------------------------------------------------- + + async def append_event_async( + self, + agent_id: str, + user_id: str, + session_id: str, + event_type: str, + content: dict[str, Any], + raw_event: Optional[str] = None, + ) -> ConversationEvent: + """追加事件,同时更新 Session 的 updated_at(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + event_type: 事件类型。 + content: 事件数据。 + raw_event: 框架原生 Event 的完整 JSON 序列化(可选)。 + 用于精确还原框架特定的 Event 对象(如 ADK Event)。 + + Returns: + 包含 OTS 生成的 seq_id 的 ConversationEvent 对象。 + """ + now = nanoseconds_timestamp() + + # 1. 写入 Event + seq_id = await self._backend.put_event_async( + agent_id, + user_id, + session_id, + event_type, + content, + created_at=now, + updated_at=now, + raw_event=raw_event, + ) + + # 2. 
更新 Session 的 updated_at(保证二级索引排序正确) + # 先读取当前 Session 获取 version + session = await self._backend.get_session_async( + agent_id, user_id, session_id + ) + if session is not None: + try: + await self._backend.update_session_async( + agent_id, + user_id, + session_id, + { + "updated_at": now, + "version": session.version + 1, + }, + session.version, + ) + except Exception: + # 更新 Session 时间戳失败不应阻断事件写入 + logger.warning( + "Failed to update session updated_at " + "for %s/%s/%s, event was still written.", + agent_id, + user_id, + session_id, + exc_info=True, + ) + + return ConversationEvent( + agent_id=agent_id, + user_id=user_id, + session_id=session_id, + seq_id=seq_id, + type=event_type, + content=content, + created_at=now, + updated_at=now, + version=0, + raw_event=raw_event, + ) + + async def get_events_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> list[ConversationEvent]: + """获取 Session 全部事件(正序)(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + + Returns: + 按 seq_id 正序排列的事件列表。 + """ + return await self._backend.get_events_async( + agent_id, + user_id, + session_id, + direction="FORWARD", + ) + + async def get_recent_events_async( + self, + agent_id: str, + user_id: str, + session_id: str, + n: int, + ) -> list[ConversationEvent]: + """获取最近 N 条事件(异步)。 + + 倒序取 N 条,返回时翻转为正序。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + n: 需要获取的事件数量。 + + Returns: + 按 seq_id 正序排列的最近 N 条事件。 + """ + events = await self._backend.get_events_async( + agent_id, + user_id, + session_id, + direction="BACKWARD", + limit=n, + ) + events.reverse() + return events + + # ------------------------------------------------------------------- + # State 管理(异步)/ State management (async) + # ------------------------------------------------------------------- + + async def get_session_state_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> dict[str, Any]: + """获取 session 级 state,不存在返回 {}(异步)。""" + 
state_data = await self._backend.get_state_async( + StateScope.SESSION, + agent_id, + user_id, + session_id, + ) + return state_data.state if state_data else {} + + async def update_session_state_async( + self, + agent_id: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """增量更新 session state(异步)。 + + 浅合并语义:top-level key 覆盖,值为 None 表示删除该 key。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + delta: 增量更新字典。 + """ + await self._apply_delta_async( + StateScope.SESSION, + agent_id, + user_id, + session_id, + delta, + ) + + async def get_app_state_async(self, agent_id: str) -> dict[str, Any]: + """获取 app 级 state,不存在返回 {}(异步)。""" + state_data = await self._backend.get_state_async( + StateScope.APP, agent_id, "", "" + ) + return state_data.state if state_data else {} + + async def update_app_state_async( + self, + agent_id: str, + delta: dict[str, Any], + ) -> None: + """增量更新 app state(异步)。 + + 浅合并语义:top-level key 覆盖,值为 None 表示删除该 key。 + """ + await self._apply_delta_async(StateScope.APP, agent_id, "", "", delta) + + async def get_user_state_async( + self, agent_id: str, user_id: str + ) -> dict[str, Any]: + """获取 user 级 state,不存在返回 {}(异步)。""" + state_data = await self._backend.get_state_async( + StateScope.USER, agent_id, user_id, "" + ) + return state_data.state if state_data else {} + + async def update_user_state_async( + self, + agent_id: str, + user_id: str, + delta: dict[str, Any], + ) -> None: + """增量更新 user state(异步)。 + + 浅合并语义:top-level key 覆盖,值为 None 表示删除该 key。 + """ + await self._apply_delta_async( + StateScope.USER, + agent_id, + user_id, + "", + delta, + ) + + async def get_merged_state_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> dict[str, Any]: + """三级状态浅合并:app_state <- user_state <- session_state(异步)。 + + 后者覆盖前者,任意层不存在视为空 dict。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + + Returns: + 合并后的状态字典。 + """ + merged: dict[str, Any] = {} + 
merged.update(await self.get_app_state_async(agent_id)) + merged.update(await self.get_user_state_async(agent_id, user_id)) + merged.update( + await self.get_session_state_async(agent_id, user_id, session_id) + ) + return merged + + # ------------------------------------------------------------------- + # 内部辅助方法(异步) + # ------------------------------------------------------------------- + + async def _apply_delta_async( + self, + scope: StateScope, + agent_id: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """增量更新 State(通用逻辑)(异步)。 + + - 首次写入:过滤 None 值后整体写入,version=0 + - 后续更新:读取现有 state → 浅合并 delta(None 删除 key)→ 写回 + + Args: + scope: 状态作用域。 + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + delta: 增量更新字典。 + """ + existing = await self._backend.get_state_async( + scope, agent_id, user_id, session_id + ) + + if existing is None: + # 首次写入,过滤 None 值 + new_state = {k: v for k, v in delta.items() if v is not None} + await self._backend.put_state_async( + scope, + agent_id, + user_id, + session_id, + state=new_state, + version=0, + ) + else: + # 增量合并 + merged = dict(existing.state) + for k, v in delta.items(): + if v is None: + merged.pop(k, None) # None 表示删除 + else: + merged[k] = v # 浅覆盖 + await self._backend.put_state_async( + scope, + agent_id, + user_id, + session_id, + state=merged, + version=existing.version, + ) + + # ------------------------------------------------------------------- + # 工厂方法(异步)/ Factory methods (async) + # ------------------------------------------------------------------- + + @classmethod + async def from_memory_collection_async( + cls, + memory_collection_name: str, + *, + config: Optional[Any] = None, + table_prefix: str = "", + ) -> "SessionStore": + """通过 MemoryCollection 名称创建 SessionStore(异步)。 + + 从 AgentRun 平台获取 MemoryCollection 配置,自动提取 OTS 实例 + 的 endpoint 和 instance_name,结合 Config 中的 AK/SK 凭证, + 构建 OTSClient 和 OTSBackend,返回即用的 SessionStore。 + + Args: + memory_collection_name: AgentRun 平台上的 
MemoryCollection 名称。 + config: agentrun Config 对象(可选)。 + 未提供时自动从环境变量读取凭证。 + table_prefix: 表名前缀,用于多租户隔离,默认不添加。 + + Returns: + 配置完成的 SessionStore 实例。 + + Raises: + ImportError: 未安装 agentrun 主包时抛出。 + ValueError: MemoryCollection 缺少 OTS 配置或凭证为空时抛出。 + + Example:: + + store = await SessionStore.from_memory_collection_async( + "my-memory-collection", + ) + await store.init_tables_async() + """ + # 延迟导入,避免 conversation_service 强依赖 agentrun 主包 + try: + from agentrun.memory_collection import MemoryCollection + from agentrun.utils.config import Config + except ImportError as e: + raise ImportError( + "agentrun 主包未安装。请先安装: pip install agentrun" + ) from e + + from tablestore import AsyncOTSClient # type: ignore[import-untyped] + from tablestore import OTSClient # type: ignore[import-untyped] + from tablestore import WriteRetryPolicy + + from agentrun.conversation_service.utils import ( + convert_vpc_endpoint_to_public, + ) + + # 1. 获取 MemoryCollection 配置 + mc = await MemoryCollection.get_by_name_async( + memory_collection_name, config=config + ) + + # 2. 提取 OTS 连接信息 + if not mc.vector_store_config or not mc.vector_store_config.config: + raise ValueError( + f"MemoryCollection '{memory_collection_name}' 缺少 " + "vector_store_config 配置,无法获取 OTS 连接信息。" + ) + + vs_config = mc.vector_store_config.config + endpoint = convert_vpc_endpoint_to_public(vs_config.endpoint or "") + instance_name = vs_config.instance_name or "" + + if not endpoint: + raise ValueError( + f"MemoryCollection '{memory_collection_name}' 的 " + "vector_store_config.endpoint 为空。" + ) + if not instance_name: + raise ValueError( + f"MemoryCollection '{memory_collection_name}' 的 " + "vector_store_config.instance_name 为空。" + ) + + # 3. 
获取凭证 + effective_config = config if isinstance(config, Config) else Config() + access_key_id = effective_config.get_access_key_id() + access_key_secret = effective_config.get_access_key_secret() + + if not access_key_id or not access_key_secret: + raise ValueError( + "AK/SK 凭证为空。请通过 Config 参数传入或设置环境变量 " + "AGENTRUN_ACCESS_KEY_ID / AGENTRUN_ACCESS_KEY_SECRET。" + ) + + security_token = effective_config.get_security_token() + sts_token = security_token if security_token else None + + # 4. 构建 OTSClient + AsyncOTSClient 和 OTSBackend + ots_client = OTSClient( + endpoint, + access_key_id, + access_key_secret, + instance_name, + sts_token=sts_token, + retry_policy=WriteRetryPolicy(), + ) + async_ots_client = AsyncOTSClient( + endpoint, + access_key_id, + access_key_secret, + instance_name, + sts_token=sts_token, + retry_policy=WriteRetryPolicy(), + ) + + backend = OTSBackend( + ots_client, + table_prefix=table_prefix, + async_ots_client=async_ots_client, + ) + return cls(backend) diff --git a/agentrun/conversation_service/adapters/__init__.py b/agentrun/conversation_service/adapters/__init__.py new file mode 100644 index 0000000..a67d248 --- /dev/null +++ b/agentrun/conversation_service/adapters/__init__.py @@ -0,0 +1,21 @@ +"""Conversation Service 框架适配器。 + +提供不同 Agent 开发框架的会话持久化适配器。 +""" + +from agentrun.conversation_service.adapters.langchain_adapter import ( + OTSChatMessageHistory, +) + +# ADK adapter 依赖 google-adk,仅在安装了 google-adk 时可用 +try: + from agentrun.conversation_service.adapters.adk_adapter import ( + OTSSessionService, + ) +except ImportError: + pass + +__all__ = [ + "OTSChatMessageHistory", + "OTSSessionService", +] diff --git a/agentrun/conversation_service/adapters/adk_adapter.py b/agentrun/conversation_service/adapters/adk_adapter.py new file mode 100644 index 0000000..e3e4e90 --- /dev/null +++ b/agentrun/conversation_service/adapters/adk_adapter.py @@ -0,0 +1,674 @@ +"""Google ADK BaseSessionService 适配器。 + +将 Google ADK 的会话管理持久化到 OTS,通过 SessionStore 实现。 + 
+使用方式:: + + from agentrun.conversation_service import SessionStore, OTSBackend + from agentrun.conversation_service.adapters import OTSSessionService + + store = SessionStore(OTSBackend(ots_client, async_ots_client=async_ots_client)) + + # 作为 ADK Runner 的 session_service + from google.adk.runners import Runner + + runner = Runner( + agent=my_agent, + app_name="my_app", + session_service=OTSSessionService(session_store=store), + ) +""" + +from __future__ import annotations + +import logging +import time +from typing import Any, Optional +import uuid + +from google.adk.events.event import Event # type: ignore[import-untyped] +from google.adk.sessions.base_session_service import ( # type: ignore[import-untyped] + BaseSessionService, + GetSessionConfig, + ListSessionsResponse, +) +from google.adk.sessions.session import Session # type: ignore[import-untyped] +from google.adk.sessions.state import State # type: ignore[import-untyped] +from typing_extensions import override + +from agentrun.conversation_service.session_store import SessionStore + +logger = logging.getLogger(__name__) + +# ADK 使用 key 前缀区分 state 作用域 +_APP_PREFIX = State.APP_PREFIX # "app:" +_USER_PREFIX = State.USER_PREFIX # "user:" +_TEMP_PREFIX = State.TEMP_PREFIX # "temp:" + +# 事件类型标识 +_EVENT_TYPE = "adk_event" + + +# ------------------------------------------------------------------- +# 工具函数 +# ------------------------------------------------------------------- + + +def _extract_state_delta( + state: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """从 state 字典中按前缀拆分出 app / user / session 三级 delta。 + + 自行实现,避免依赖 google.adk.sessions._session_util(私有模块)。 + + Args: + state: 包含前缀标识的 state 字典。 + + Returns: + 包含 'app'、'user'、'session' 三个 key 的字典。 + """ + deltas: dict[str, dict[str, Any]] = { + "app": {}, + "user": {}, + "session": {}, + } + if state: + for key in state.keys(): + if key.startswith(_APP_PREFIX): + deltas["app"][key.removeprefix(_APP_PREFIX)] = state[key] + elif key.startswith(_USER_PREFIX): 
+ deltas["user"][key.removeprefix(_USER_PREFIX)] = state[key] + elif not key.startswith(_TEMP_PREFIX): + deltas["session"][key] = state[key] + return deltas + + +def _extract_display_content( + event: Event, +) -> dict[str, Any]: + """从 ADK Event 提取用于展示的简化内容。 + + 存入 OTS Event 表的 content 列,供跨框架展示使用。 + """ + result: dict[str, Any] = {"author": event.author} + if event.content and event.content.parts: + texts: list[str] = [] + for part in event.content.parts: + if part.text: + texts.append(part.text) + elif part.function_call: + texts.append(f"[call:{part.function_call.name}]") + elif part.function_response: + texts.append(f"[response:{part.function_response.name}]") + result["text"] = "\n".join(texts) + return result + + +# ------------------------------------------------------------------- +# OTSSessionService +# ------------------------------------------------------------------- + + +class OTSSessionService(BaseSessionService): + """基于 OTS 的 Google ADK SessionService 实现。 + + async 公共方法使用原生 ``await self._store.xxx_async(...)`` 调用, + sync ``_impl`` 方法使用 ``self._store.xxx(...)`` 调用。 + + Args: + session_store: SessionStore 实例。 + """ + + def __init__(self, session_store: SessionStore) -> None: + self._store = session_store + + # --------------------------------------------------------------- + # create_session + # --------------------------------------------------------------- + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + + # 1. 拆分初始 state 为三级 + state_deltas = _extract_state_delta(state or {}) + + # 2. 创建 OTS session + await self._store.create_session_async( + app_name, + user_id, + session_id, + framework="adk", + ) + + # 3. 
持久化三级 state + if state_deltas["app"]: + await self._store.update_app_state_async( + app_name, state_deltas["app"] + ) + if state_deltas["user"]: + await self._store.update_user_state_async( + app_name, user_id, state_deltas["user"] + ) + if state_deltas["session"]: + await self._store.update_session_state_async( + app_name, + user_id, + session_id, + state_deltas["session"], + ) + + # 4. 构造 ADK Session 返回(含合并 state) + return await self._build_adk_session_async( + app_name, user_id, session_id, events=[] + ) + + def create_session_sync( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + """同步版 create_session。""" + return self._create_session_impl( + app_name=app_name, + user_id=user_id, + state=state, + session_id=session_id, + ) + + def _create_session_impl( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + + # 1. 拆分初始 state 为三级 + state_deltas = _extract_state_delta(state or {}) + + # 2. 创建 OTS session + self._store.create_session( + app_name, + user_id, + session_id, + framework="adk", + ) + + # 3. 持久化三级 state + if state_deltas["app"]: + self._store.update_app_state(app_name, state_deltas["app"]) + if state_deltas["user"]: + self._store.update_user_state( + app_name, user_id, state_deltas["user"] + ) + if state_deltas["session"]: + self._store.update_session_state( + app_name, + user_id, + session_id, + state_deltas["session"], + ) + + # 4. 
构造 ADK Session 返回(含合并 state) + return self._build_adk_session(app_name, user_id, session_id, events=[]) + + # --------------------------------------------------------------- + # get_session + # --------------------------------------------------------------- + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + # 1. 读 session 元数据 + ots_session = await self._store.get_session_async( + app_name, user_id, session_id + ) + if ots_session is None: + return None + + # 2. 读 events(考虑 config.num_recent_events) + if config and config.num_recent_events: + ots_events = await self._store.get_recent_events_async( + app_name, + user_id, + session_id, + config.num_recent_events, + ) + else: + ots_events = await self._store.get_events_async( + app_name, user_id, session_id + ) + + # 3. 从 raw_event 列反序列化为 ADK Event + adk_events: list[Event] = [] + for e in ots_events: + if e.raw_event is not None: + try: + adk_events.append(Event.model_validate_json(e.raw_event)) + except Exception: + logger.warning( + "Failed to deserialize ADK Event seq_id=%s, skipping.", + e.seq_id, + exc_info=True, + ) + + # 4. 如有 after_timestamp,过滤 + if config and config.after_timestamp: + adk_events = [ + e for e in adk_events if e.timestamp >= config.after_timestamp + ] + + # 5. 
构造带 merged state 的 ADK Session + return await self._build_adk_session_async( + app_name, + user_id, + session_id, + events=adk_events, + updated_at=ots_session.updated_at, + ) + + def get_session_sync( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + """同步版 get_session。""" + return self._get_session_impl( + app_name=app_name, + user_id=user_id, + session_id=session_id, + config=config, + ) + + def _get_session_impl( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + # 1. 读 session 元数据 + ots_session = self._store.get_session(app_name, user_id, session_id) + if ots_session is None: + return None + + # 2. 读 events(考虑 config.num_recent_events) + if config and config.num_recent_events: + ots_events = self._store.get_recent_events( + app_name, + user_id, + session_id, + config.num_recent_events, + ) + else: + ots_events = self._store.get_events(app_name, user_id, session_id) + + # 3. 从 raw_event 列反序列化为 ADK Event + adk_events: list[Event] = [] + for e in ots_events: + if e.raw_event is not None: + try: + adk_events.append(Event.model_validate_json(e.raw_event)) + except Exception: + logger.warning( + "Failed to deserialize ADK Event seq_id=%s, skipping.", + e.seq_id, + exc_info=True, + ) + + # 4. 如有 after_timestamp,过滤 + if config and config.after_timestamp: + adk_events = [ + e for e in adk_events if e.timestamp >= config.after_timestamp + ] + + # 5. 
构造带 merged state 的 ADK Session + return self._build_adk_session( + app_name, + user_id, + session_id, + events=adk_events, + updated_at=ots_session.updated_at, + ) + + # --------------------------------------------------------------- + # list_sessions + # --------------------------------------------------------------- + + @override + async def list_sessions( + self, + *, + app_name: str, + user_id: Optional[str] = None, + ) -> ListSessionsResponse: + if user_id is not None: + ots_sessions = await self._store.list_sessions_async( + app_name, user_id + ) + else: + ots_sessions = await self._store.list_all_sessions_async(app_name) + + sessions: list[Session] = [] + for s in ots_sessions: + sessions.append( + Session( + id=s.session_id, + app_name=app_name, + user_id=s.user_id, + state={}, + events=[], + last_update_time=s.updated_at / 1_000_000_000.0, + ) + ) + + return ListSessionsResponse(sessions=sessions) + + def list_sessions_sync( + self, + *, + app_name: str, + user_id: Optional[str] = None, + ) -> ListSessionsResponse: + """同步版 list_sessions。""" + return self._list_sessions_impl(app_name=app_name, user_id=user_id) + + def _list_sessions_impl( + self, + *, + app_name: str, + user_id: Optional[str] = None, + ) -> ListSessionsResponse: + if user_id is not None: + ots_sessions = self._store.list_sessions(app_name, user_id) + else: + ots_sessions = self._store.list_all_sessions(app_name) + + sessions: list[Session] = [] + for s in ots_sessions: + sessions.append( + Session( + id=s.session_id, + app_name=app_name, + user_id=s.user_id, + state={}, + events=[], + last_update_time=s.updated_at / 1_000_000_000.0, + ) + ) + + return ListSessionsResponse(sessions=sessions) + + # --------------------------------------------------------------- + # delete_session + # --------------------------------------------------------------- + + @override + async def delete_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + ) -> None: + await 
self._store.delete_session_async(app_name, user_id, session_id) + + def delete_session_sync( + self, + *, + app_name: str, + user_id: str, + session_id: str, + ) -> None: + """同步版 delete_session。""" + self._delete_session_impl( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + + def _delete_session_impl( + self, + *, + app_name: str, + user_id: str, + session_id: str, + ) -> None: + self._store.delete_session(app_name, user_id, session_id) + + # --------------------------------------------------------------- + # append_event + # --------------------------------------------------------------- + + @override + async def append_event(self, session: Session, event: Event) -> Event: + if event.partial: + return event + + # 1. 调用父类 sync 辅助方法更新内存 session + # (trim temp state delta + update session state) + event = self._trim_temp_delta_state(event) + self._update_session_state(session, event) + session.events.append(event) + session.last_update_time = event.timestamp + + # 2. 序列化 Event,写入 content(简化文本)和 + # raw_event(完整 JSON) + raw_event_str = event.model_dump_json(by_alias=False) + content_dict = _extract_display_content(event) + + await self._store.append_event_async( + session.app_name, + session.user_id, + session.id, + event_type=_EVENT_TYPE, + content=content_dict, + raw_event=raw_event_str, + ) + + # 3. 
持久化 state delta 到三级 state 表 + if event.actions and event.actions.state_delta: + state_deltas = _extract_state_delta(event.actions.state_delta) + if state_deltas["app"]: + await self._store.update_app_state_async( + session.app_name, state_deltas["app"] + ) + if state_deltas["user"]: + await self._store.update_user_state_async( + session.app_name, + session.user_id, + state_deltas["user"], + ) + if state_deltas["session"]: + await self._store.update_session_state_async( + session.app_name, + session.user_id, + session.id, + state_deltas["session"], + ) + + return event + + def _append_event_impl(self, session: Session, event: Event) -> Event: + """同步版 append_event 的内部实现。""" + if event.partial: + return event + + # 1. 调用父类 sync 辅助方法更新内存 session + event = self._trim_temp_delta_state(event) + self._update_session_state(session, event) + session.events.append(event) + session.last_update_time = event.timestamp + + # 2. 序列化 Event,写入 content(简化文本)和 raw_event(完整 JSON) + raw_event_str = event.model_dump_json(by_alias=False) + content_dict = _extract_display_content(event) + + self._store.append_event( + session.app_name, + session.user_id, + session.id, + event_type=_EVENT_TYPE, + content=content_dict, + raw_event=raw_event_str, + ) + + # 3. 
持久化 state delta 到三级 state 表 + if event.actions and event.actions.state_delta: + state_deltas = _extract_state_delta(event.actions.state_delta) + if state_deltas["app"]: + self._store.update_app_state( + session.app_name, state_deltas["app"] + ) + if state_deltas["user"]: + self._store.update_user_state( + session.app_name, + session.user_id, + state_deltas["user"], + ) + if state_deltas["session"]: + self._store.update_session_state( + session.app_name, + session.user_id, + session.id, + state_deltas["session"], + ) + + return event + + # --------------------------------------------------------------- + # 内部辅助方法 + # --------------------------------------------------------------- + + async def _build_adk_session_async( + self, + app_name: str, + user_id: str, + session_id: str, + events: list[Event], + updated_at: Optional[int] = None, + ) -> Session: + """构造 ADK Session 对象,合并三级 state(异步)。 + + Args: + app_name: 应用名(对应 OTS agent_id)。 + user_id: 用户 ID。 + session_id: 会话 ID。 + events: ADK Event 列表。 + updated_at: OTS 中的 updated_at(纳秒), + 用于设置 last_update_time。 + """ + merged: dict[str, Any] = {} + + # session state(无前缀) + session_state = await self._store.get_session_state_async( + app_name, user_id, session_id + ) + merged.update(session_state) + + # user state(加 user: 前缀) + user_state = await self._store.get_user_state_async(app_name, user_id) + for k, v in user_state.items(): + merged[_USER_PREFIX + k] = v + + # app state(加 app: 前缀) + app_state = await self._store.get_app_state_async(app_name) + for k, v in app_state.items(): + merged[_APP_PREFIX + k] = v + + last_update = ( + updated_at / 1_000_000_000.0 + if updated_at is not None + else time.time() + ) + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged, + events=events, + last_update_time=last_update, + ) + + def _build_adk_session( + self, + app_name: str, + user_id: str, + session_id: str, + events: list[Event], + updated_at: Optional[int] = None, + ) -> Session: + """构造 
ADK Session 对象,合并三级 state(同步)。 + + Args: + app_name: 应用名(对应 OTS agent_id)。 + user_id: 用户 ID。 + session_id: 会话 ID。 + events: ADK Event 列表。 + updated_at: OTS 中的 updated_at(纳秒), + 用于设置 last_update_time。 + """ + merged: dict[str, Any] = {} + + # session state(无前缀) + session_state = self._store.get_session_state( + app_name, user_id, session_id + ) + merged.update(session_state) + + # user state(加 user: 前缀) + user_state = self._store.get_user_state(app_name, user_id) + for k, v in user_state.items(): + merged[_USER_PREFIX + k] = v + + # app state(加 app: 前缀) + app_state = self._store.get_app_state(app_name) + for k, v in app_state.items(): + merged[_APP_PREFIX + k] = v + + last_update = ( + updated_at / 1_000_000_000.0 + if updated_at is not None + else time.time() + ) + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged, + events=events, + last_update_time=last_update, + ) diff --git a/agentrun/conversation_service/adapters/langchain_adapter.py b/agentrun/conversation_service/adapters/langchain_adapter.py new file mode 100644 index 0000000..74d64d0 --- /dev/null +++ b/agentrun/conversation_service/adapters/langchain_adapter.py @@ -0,0 +1,248 @@ +"""LangChain BaseChatMessageHistory 适配器。 + +将 LangChain 的消息历史持久化到 OTS,通过 SessionStore 实现。 + +使用方式:: + + from agentrun.conversation_service import SessionStore, OTSBackend + from agentrun.conversation_service.adapters import OTSChatMessageHistory + + store = SessionStore(OTSBackend(ots_client)) + + history = OTSChatMessageHistory( + session_store=store, + agent_id="my_agent", + user_id="user_1", + session_id="session_1", + ) + + # 配合 RunnableWithMessageHistory + from langchain_core.runnables.history import RunnableWithMessageHistory + + chain_with_history = RunnableWithMessageHistory( + chain, + lambda session_id: OTSChatMessageHistory( + session_store=store, + agent_id="my_agent", + user_id="user_1", + session_id=session_id, + ), + ) +""" + +from __future__ import annotations + +import 
class OTSChatMessageHistory(BaseChatMessageHistory):
    """LangChain message history persisted through a ``SessionStore``.

    Each LangChain ``BaseMessage`` is serialised to a dict and stored as one
    event; reading turns the stored events back into messages.

    Attributes:
        session_store: The ``SessionStore`` used for all reads and writes.
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
    """

    def __init__(
        self,
        session_store: SessionStore,
        agent_id: str,
        user_id: str,
        session_id: str,
        *,
        auto_create_session: bool = True,
    ) -> None:
        """Bind the history to one session.

        Args:
            session_store: The ``SessionStore`` instance.
            agent_id: Agent ID.
            user_id: User ID.
            session_id: Session ID.
            auto_create_session: When True (the default), create the session
                if the store does not have it yet.
        """
        self.session_store = session_store
        self.agent_id = agent_id
        self.user_id = user_id
        self.session_id = session_id

        # The existence lookup only happens when auto-creation is requested.
        if (
            auto_create_session
            and session_store.get_session(agent_id, user_id, session_id)
            is None
        ):
            session_store.create_session(
                agent_id,
                user_id,
                session_id,
                framework="langchain",
            )

    # ---------------------------------------------------------------
    # BaseChatMessageHistory interface
    # ---------------------------------------------------------------

    @property
    def messages(self) -> list[BaseMessage]:  # type: ignore[override]
        """Read all stored events and return them as messages.

        Messages are returned in the order the store yields the events.
        Events that cannot be converted are logged and skipped.
        """
        stored_events = self.session_store.get_events(
            self.agent_id, self.user_id, self.session_id
        )
        history: list[BaseMessage] = []
        for stored in stored_events:
            try:
                restored = _event_to_message(stored)
            except Exception:
                logger.warning(
                    "Failed to deserialize event seq_id=%s, skipping.",
                    stored.seq_id,
                    exc_info=True,
                )
            else:
                history.append(restored)
        return history

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Append the given messages to the store, one event per message."""
        for item in messages:
            self.session_store.append_event(
                self.agent_id,
                self.user_id,
                self.session_id,
                event_type=_EVENT_TYPE,
                content=_message_to_dict(item),
            )

    def clear(self) -> None:
        """Delete the events of the current session via ``delete_events``."""
        self.session_store.delete_events(
            self.agent_id, self.user_id, self.session_id
        )
def _event_to_message(
    event: ConversationEvent,
) -> BaseMessage:
    """Rebuild a LangChain message from a stored ``ConversationEvent``.

    The message class is selected by the ``lc_type`` entry of the event
    content (default ``"human"``). An unrecognised type is logged and
    rebuilt as a ``HumanMessage``.

    Args:
        event: Stored event whose ``content`` holds the serialised message.

    Returns:
        The reconstructed message.
    """
    # Copy first: the type tag is popped and the event must stay untouched.
    payload = dict(event.content)
    type_tag = payload.pop("lc_type", "human")

    message_cls = _TYPE_TO_CLASS.get(type_tag)
    if message_cls is None:
        logger.warning(
            "Unknown message type '%s', falling back to HumanMessage.",
            type_tag,
        )
        message_cls = HumanMessage

    # Optional fields are forwarded only when they were stored.
    optional_keys = ["additional_kwargs", "response_metadata", "name", "id"]
    if message_cls is AIMessage:
        optional_keys += ["tool_calls", "invalid_tool_calls"]

    init_args: dict[str, Any] = {"content": payload.get("content", "")}
    init_args.update(
        (key, payload[key]) for key in optional_keys if key in payload
    )

    # tool_call_id is always passed for ToolMessage, defaulting to "".
    if message_cls is ToolMessage:
        init_args["tool_call_id"] = payload.get("tool_call_id", "")

    return message_cls(**init_args)
实现业务逻辑(级联删除、状态合并…) + ④ 暴露框架无关的统一接口 + + Adapter 职责: + 仅做 ④ 框架数据模型转换 + + +## 访问模式分析 + +访问模式 频率 操作类型 +───────────────────────────────────────────────────── +1. 创建 session 中 PutRow +2. 获取 session (app, user, sid) 高 GetRow(点读) +3. 列出用户所有 session 中 GetRange(二级索引扫描) +4. 删除 session + 所有消息 低 BatchWrite + GetRange +5. 追加消息/事件 高 PutRow(自增排序) +6. 获取 session 全部消息 高 GetRange +7. 获取最近 N 条消息 高 GetRange(反向 + limit) +8. 按时间过滤消息 中 GetRange + Filter / 多元索引 +9. 读写 app/user 级状态 中 GetRow / UpdateRow +10. 全文搜索 summary / 标签过滤 中 多元索引 SearchQuery +11. 跨 user_id 查询(管理后台场景) 低 多元索引 BoolQuery + +## 表设计 +### 会话表 +Conversation 表 +PK: + agent_id (String, 分区键) + user_id (String) + session_id (String) + +Defined Columns(二级索引 / 多元索引引用的非 PK 列,建表时需声明): + updated_at : Integer + summary : String + labels : String + framework : String + extensions : String + +Attributes: + created_at : Integer -- 纳秒时间戳 + updated_at : Integer -- 纳秒时间戳 + is_pinned : Boolean -- 是否置顶 + summary : String -- 会话摘要 + labels : String -- 会话标签(JSON 字符串) + framework : String -- 'adk' / 'langchain' / … + extensions : String -- JSON 框架扩展数据 + version : Integer -- 版本号, 用于乐观锁 + + +二级索引(GLOBAL_INDEX): +conversation_secondary_index: +PK: + agent_id (String, 分区键) + user_id (String) + updated_at (Integer) -- 纳秒时间戳,支持按更新时间排序 + session_id (String) + +Attributes: + summary (String) -- 会话摘要 + labels (String) -- 会话标签 + framework (String) -- 'adk' / 'langchain' / … + extensions (String) -- JSON 框架扩展数据 + +用途:list_sessions(agent_id, user_id) 热路径,低延迟(毫秒级)。 + + +多元索引(Search Index): +conversation_search_index: + +字段 OTS FieldType index enable_sort_and_agg 说明 +────────────────────────────────────────────────────────────────── +agent_id KEYWORD True True 精确匹配 + routing_field +user_id KEYWORD True True 精确匹配 +session_id KEYWORD True True 精确匹配 +updated_at LONG True True 范围查询 + 排序 +created_at LONG True True 范围查询 + 排序 +is_pinned KEYWORD True True 过滤置顶("true"/"false") +framework KEYWORD True True 按框架过滤 +summary TEXT True False 全文检索(SINGLEWORD 分词) +labels KEYWORD 
True True 精确匹配标签 JSON + +高级配置: + routing_fields = ["agent_id"] -- 同一 agent 数据路由到同一分区 + index_sort = updated_at DESC -- 预排序,最近更新优先 + +用途:全文搜索 summary、标签过滤、跨 user_id 查询、组合条件搜索。 +延迟稍高(10-50ms),适合搜索/过滤场景,与二级索引并行保留。 + + +### Event 表 + +Event 表 +PK: + agent_id (String, 分区键) + user_id (String) + session_id (String) + seq_id (Integer, AUTO_INCREMENT) -- 事件序号,OTS 自增列 + +Attributes: + type : String -- 事件类型 + content : String -- JSON 序列化的事件数据 + created_at : Integer -- 纳秒时间戳 + updated_at : Integer -- 纳秒时间戳 + version : Integer -- 版本号, 用于乐观锁 + raw_event : String -- 框架原生 Event 的完整 JSON 序列化(可选) + 用于精确还原框架特定的 Event 对象(如 ADK Event) + +说明:统一用 Event 抽象,Message 是 Event 的子集 + +### State 表 +State 表 +PK: + agent_id (String, 分区键) + user_id (String) + session_id (String) + +Attributes: + state : String -- JSON 序列化的状态数据(未分片时使用) + chunk_count : Integer -- 分片数量,0 表示未分片 + state_0..N : String -- 分片列(当 JSON 超过 1.5M 字符时自动分片) + created_at : Integer -- 纳秒时间戳 + updated_at : Integer -- 纳秒时间戳 + version : Integer -- 版本号, 用于乐观锁 + +说明:State 以 JSON 字符串存储。当 JSON 超过 1.5M 字符时, +自动拆分为 state_0, state_1, ... 
多列存储(列分片), +读取时按 chunk_count 拼接还原。 + +### App_state 表 +App_state 表 +PK: + agent_id (String, 分区键) + +Attributes: + state : String -- JSON 序列化的状态数据(未分片时使用) + chunk_count : Integer -- 分片数量,0 表示未分片 + state_0..N : String -- 分片列(当 JSON 超过 1.5M 字符时自动分片) + created_at : Integer -- 纳秒时间戳 + updated_at : Integer -- 纳秒时间戳 + version : Integer -- 版本号, 用于乐观锁 + +说明:三级 State 是 ADK 的概念,其他框架按需使用 + +### User_state 表 +User_state 表 +PK: + agent_id (String, 分区键) + user_id (String) + +Attributes: + state : String -- JSON 序列化的状态数据(未分片时使用) + chunk_count : Integer -- 分片数量,0 表示未分片 + state_0..N : String -- 分片列(当 JSON 超过 1.5M 字符时自动分片) + created_at : Integer -- 纳秒时间戳 + updated_at : Integer -- 纳秒时间戳 + version : Integer -- 版本号, 用于乐观锁 + +说明:三级 State 是 ADK 的概念,其他框架按需使用 + +## 初始化策略 + +表和索引按用途分组创建,避免为未使用的框架创建不必要的表: + +方法 创建的资源 适用场景 +───────────────────────────────────────────────────────────────── +init_core_tables() Conversation + Event + 二级索引 所有框架 +init_state_tables() State + App_state + User_state ADK 三级 State +init_search_index() conversation_search_index (多元索引) 需要搜索/过滤 +init_tables() 以上全部(向后兼容) 快速开发 +多元索引创建耗时较长(数秒级),建议与核心表创建分离,不阻塞核心流程。 + +## 分层架构 + +┌─────────────────────────────────────────────────────────┐ +│ Layer 1: Framework Adapters(薄,只做模型转换) │ +│ │ +│ ┌─────────────┐ ┌───────────────┐ ┌──────────────┐ │ +│ │ ADK Adapter │ │ LC Adapter │ │ LG Adapter │ │ +│ │ implements │ │ implements │ │ implements │ │ +│ │ BaseSession │ │ BaseChatMsg │ │ BaseCheck │ │ +│ │ Service │ │ History │ │ pointSaver │ │ +│ └──────┬───────┘ └──────┬────────┘ └──────┬───────┘ │ +│ │ │ │ │ +│ ▼ ▼ ▼ │ +├─────────────────────────────────────────────────────────┤ +│ Layer 2: SessionStore(厚,核心业务逻辑) │ +│ │ +│ ┌─────────────────────────────────────────────────┐ │ +│ │ class SessionStore: │ │ +│ │ # 统一领域对象 │ │ +│ │ @dataclass ConversationSession │ │ +│ │ @dataclass ConversationEvent │ │ +│ │ @dataclass StateData │ │ +│ │ │ │ +│ │ # 工厂方法 │ │ +│ │ from_memory_collection(name) # 推荐入口 │ │ +│ │ → 从 MemoryCollection 获取 OTS 连接信息 │ 
│ +│ │ → 自动构建 OTSClient + OTSBackend │ │ +│ │ │ │ +│ │ # 初始化 │ │ +│ │ init_tables() # 全量建表(向后兼容) │ │ +│ │ init_core_tables() # 核心表 + 二级索引 │ │ +│ │ init_state_tables() # 三级 State 表 │ │ +│ │ init_search_index() # 多元索引(按需) │ │ +│ │ │ │ +│ │ # Session CRUD │ │ +│ │ create_session(...) → ConversationSession │ │ +│ │ get_session(...) → ConversationSession? │ │ +│ │ list_sessions(...) → [ConversationSession] │ │ +│ │ list_all_sessions(...)→ [ConversationSession] │ │ +│ │ search_sessions(...) → ([Session], total) │ │ +│ │ update_session(...) # 乐观锁 │ │ +│ │ delete_session(...) # 级联删除 Event→State→Row │ │ +│ │ │ │ +│ │ # Event CRUD │ │ +│ │ append_event(...) → ConversationEvent │ │ +│ │ get_events(...) → [ConversationEvent] │ │ +│ │ get_recent_events(...)→ [ConversationEvent] │ │ +│ │ delete_events(...) → int │ │ +│ │ │ │ +│ │ # 三级 State 管理 │ │ +│ │ get_session_state / update_session_state │ │ +│ │ get_app_state / update_app_state │ │ +│ │ get_user_state / update_user_state │ │ +│ │ get_merged_state(...)→ dict # 三级浅合并 │ │ +│ └──────────────┬──────────────────────────────────┘ │ +│ │ │ +├─────────────────┼───────────────────────────────────────┤ +│ Layer 3: Storage Backend │ +│ ▼ │ +│ ┌──────────────────────┐ │ +│ │ OTSBackend │ ← 当前实现 │ +│ │ (OTS SDK 调用) │ │ +│ └──────────────────────┘ │ +│ │ +│ OTS 连接信息来源: │ +│ SessionStore.from_memory_collection(name) │ +│ → MemoryCollection.get_by_name(name) │ +│ → vector_store_config.endpoint / instance_name │ +│ → Config (AK/SK) │ +│ → OTSClient → OTSBackend │ +│ │ +│ 也可手动传入 OTSClient 构建 OTSBackend(向后兼容) │ +└─────────────────────────────────────────────────────────┘ \ No newline at end of file diff --git a/agentrun/conversation_service/model.py b/agentrun/conversation_service/model.py new file mode 100644 index 0000000..96aceef --- /dev/null +++ b/agentrun/conversation_service/model.py @@ -0,0 +1,138 @@ +"""Conversation Service 领域模型。 + +定义会话、事件、状态等核心数据结构,以及表名常量。 +""" + +from __future__ import annotations + +from dataclasses import 
@dataclass
class ConversationEvent:
    """A single event in a conversation.

    Event is the unified abstraction; a chat message is a subset of it.

    Attributes:
        agent_id: Agent ID (partition key).
        user_id: User ID.
        session_id: Session ID.
        seq_id: Event sequence number, generated by the OTS AUTO_INCREMENT
            column; ``None`` before the event has been written.
        type: Event type.
        content: Event data (stored as serialised JSON).
        created_at: Creation time (nanosecond timestamp).
        updated_at: Last update time (nanosecond timestamp).
        version: Version number for optimistic locking.
        raw_event: Optional full JSON serialisation of the framework-native
            event, used to restore framework-specific objects exactly
            (e.g. an ADK Event). ``None`` for frameworks that do not use it.
    """

    agent_id: str
    user_id: str
    session_id: str
    seq_id: Optional[int]
    type: str
    content: dict[str, Any] = field(default_factory=dict)
    created_at: int = 0
    updated_at: int = 0
    version: int = 0
    raw_event: Optional[str] = None

    def content_as_json(self) -> str:
        """Serialise ``content`` to JSON, keeping non-ASCII text readable."""
        return json.dumps(self.content, ensure_ascii=False)

    @staticmethod
    def content_from_json(raw: str) -> dict[str, Any]:
        """Parse a JSON string produced for the ``content`` column."""
        decoded: dict[str, Any] = json.loads(raw)
        return decoded
+ +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/conversation_service/__ots_backend_async_template.py + +OTS 存储后端。 + +封装 TableStore SDK 的底层操作,负责五张表的建表和 CRUD。 +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Optional + +from tablestore import AsyncOTSClient # type: ignore[import-untyped] +from tablestore import BatchWriteRowRequest # type: ignore[import-untyped] +from tablestore import ( + CapacityUnit, + ComparatorType, + Condition, + DeleteRowItem, + Direction, + INF_MAX, + INF_MIN, + OTSClient, + OTSServiceError, + PK_AUTO_INCR, + ReservedThroughput, + ReturnType, + Row, + RowExistenceExpectation, + SecondaryIndexMeta, + SecondaryIndexType, + SingleColumnCondition, + TableInBatchWriteRowItem, + TableMeta, + TableOptions, +) + +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, + DEFAULT_APP_STATE_TABLE, + DEFAULT_CONVERSATION_SEARCH_INDEX, + DEFAULT_CONVERSATION_SECONDARY_INDEX, + DEFAULT_CONVERSATION_TABLE, + DEFAULT_EVENT_TABLE, + DEFAULT_STATE_TABLE, + DEFAULT_USER_STATE_TABLE, + StateData, + StateScope, +) +from agentrun.conversation_service.utils import ( + deserialize_state, + from_chunks, + MAX_COLUMN_SIZE, + nanoseconds_timestamp, + serialize_state, + to_chunks, +) + +logger = logging.getLogger(__name__) + +# OTS BatchWriteRow 每批最多 200 行 +_BATCH_WRITE_LIMIT = 200 + + +class OTSBackend: + """TableStore 存储后端。 + + 封装 OTS SDK 底层操作,理解表结构,提供五张表的 CRUD。 + 同时提供异步(_async 后缀)和同步方法。 + + Args: + ots_client: 预构建的 OTS SDK 同步客户端实例(同步方法使用)。 + table_prefix: 表名前缀,用于多租户隔离。 + async_ots_client: 预构建的 OTS SDK 异步客户端实例(异步方法使用)。 + """ + + def __init__( + self, + ots_client: Optional[OTSClient] = None, + table_prefix: str = "", + *, + async_ots_client: Optional[AsyncOTSClient] = None, + ) -> None: + self._client = ots_client + self._async_client = async_ots_client + self._table_prefix = table_prefix + + # 根据前缀生成实际表名 + self._conversation_table = 
f"{table_prefix}{DEFAULT_CONVERSATION_TABLE}" + self._event_table = f"{table_prefix}{DEFAULT_EVENT_TABLE}" + self._state_table = f"{table_prefix}{DEFAULT_STATE_TABLE}" + self._app_state_table = f"{table_prefix}{DEFAULT_APP_STATE_TABLE}" + self._user_state_table = f"{table_prefix}{DEFAULT_USER_STATE_TABLE}" + self._conversation_secondary_index = ( + f"{table_prefix}{DEFAULT_CONVERSATION_SECONDARY_INDEX}" + ) + self._conversation_search_index = ( + f"{table_prefix}{DEFAULT_CONVERSATION_SEARCH_INDEX}" + ) + + # ----------------------------------------------------------------------- + # 建表(异步)/ Table creation (async) + # ----------------------------------------------------------------------- + + async def init_tables_async(self) -> None: + """创建五张表和 Conversation 二级索引(异步)。 + + 表已存在时跳过(catch OTSServiceError 并 log warning)。 + """ + await self._create_conversation_table_async() + await self._create_event_table_async() + await self._create_state_table_async( + self._state_table, + [ + ("agent_id", "STRING"), + ("user_id", "STRING"), + ("session_id", "STRING"), + ], + ) + await self._create_state_table_async( + self._app_state_table, + [("agent_id", "STRING")], + ) + await self._create_state_table_async( + self._user_state_table, + [("agent_id", "STRING"), ("user_id", "STRING")], + ) + + def init_tables(self) -> None: + """创建五张表和 Conversation 二级索引(同步)。 + + 表已存在时跳过(catch OTSServiceError 并 log warning)。 + """ + self._create_conversation_table() + self._create_event_table() + self._create_state_table( + self._state_table, + [ + ("agent_id", "STRING"), + ("user_id", "STRING"), + ("session_id", "STRING"), + ], + ) + self._create_state_table( + self._app_state_table, + [("agent_id", "STRING")], + ) + self._create_state_table( + self._user_state_table, + [("agent_id", "STRING"), ("user_id", "STRING")], + ) + + async def init_core_tables_async(self) -> None: + """创建核心表(Conversation + Event)和二级索引(异步)。""" + await self._create_conversation_table_async() + await 
async def init_state_tables_async(self) -> None:
    """Create the three State tables (async).

    Awaits ``_create_state_table_async`` once per table, in this order:
    ``_state_table``, ``_app_state_table``, ``_user_state_table``.
    """
    # (table name, primary-key schema) pairs; each schema is a fresh list
    # so the callee may keep or mutate it freely.
    table_plan = [
        (
            self._state_table,
            [
                ("agent_id", "STRING"),
                ("user_id", "STRING"),
                ("session_id", "STRING"),
            ],
        ),
        (self._app_state_table, [("agent_id", "STRING")]),
        (
            self._user_state_table,
            [("agent_id", "STRING"), ("user_id", "STRING")],
        ),
    ]
    for table_name, pk_schema in table_plan:
        await self._create_state_table_async(table_name, pk_schema)
"updated_at", + "session_id", + ], + [ + "summary", + "labels", + "framework", + "extensions", + ], + index_type=SecondaryIndexType.GLOBAL_INDEX, + ) + + try: + await self._async_client.create_table( + table_meta, + table_options, + reserved_throughput, + secondary_indexes=[secondary_index_meta], + ) + logger.info( + "Created table: %s with secondary index: %s", + self._conversation_table, + self._conversation_secondary_index, + ) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._conversation_table, + ) + else: + raise + + def _create_conversation_table(self) -> None: + """创建 Conversation 表 + 二级索引(同步)。""" + table_meta = TableMeta( + self._conversation_table, + [ + ("agent_id", "STRING"), + ("user_id", "STRING"), + ("session_id", "STRING"), + ], + # 二级索引引用的非 PK 列必须声明为 defined_columns + defined_columns=[ + ("updated_at", "INTEGER"), + ("summary", "STRING"), + ("labels", "STRING"), + ("framework", "STRING"), + ("extensions", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + # 二级索引:按 updated_at 排序 + secondary_index_meta = SecondaryIndexMeta( + self._conversation_secondary_index, + [ + "agent_id", + "user_id", + "updated_at", + "session_id", + ], + [ + "summary", + "labels", + "framework", + "extensions", + ], + index_type=SecondaryIndexType.GLOBAL_INDEX, + ) + + try: + self._client.create_table( + table_meta, + table_options, + reserved_throughput, + secondary_indexes=[secondary_index_meta], + ) + logger.info( + "Created table: %s with secondary index: %s", + self._conversation_table, + self._conversation_secondary_index, + ) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._conversation_table, + 
async def _create_event_table_async(self) -> None:
    """Create the Event table (async).

    The primary key is ``agent_id`` / ``user_id`` / ``session_id`` plus an
    auto-increment ``seq_id`` column. An "already exists" service error is
    logged as a warning and swallowed; any other ``OTSServiceError`` is
    re-raised.
    """
    pk_schema = [
        ("agent_id", "STRING"),
        ("user_id", "STRING"),
        ("session_id", "STRING"),
        ("seq_id", "INTEGER", PK_AUTO_INCR),
    ]
    meta = TableMeta(self._event_table, pk_schema)
    options = TableOptions()
    throughput = ReservedThroughput(CapacityUnit(0, 0))

    try:
        await self._async_client.create_table(meta, options, throughput)
    except OTSServiceError as exc:
        # Recognise "already exists" by message text first, then by code.
        already_exists = "already exist" in str(exc).lower() or (
            hasattr(exc, "code") and exc.code == "OTSObjectAlreadyExist"
        )
        if not already_exists:
            raise
        logger.warning(
            "Table %s already exists, skipping.",
            self._event_table,
        )
    else:
        logger.info("Created table: %s", self._event_table)
def _create_state_table(
    self,
    table_name: str,
    pk_schema: list[tuple[str, str]],
) -> None:
    """Create one State-style table (sync, shared by all state scopes).

    Args:
        table_name: Full name of the table to create.
        pk_schema: Primary-key schema passed straight to ``TableMeta``.

    An "already exists" service error is logged as a warning and
    swallowed; any other ``OTSServiceError`` is re-raised.
    """
    meta = TableMeta(table_name, pk_schema)
    options = TableOptions()
    throughput = ReservedThroughput(CapacityUnit(0, 0))

    try:
        self._client.create_table(meta, options, throughput)
    except OTSServiceError as exc:
        # Recognise "already exists" by message text first, then by code.
        already_exists = "already exist" in str(exc).lower() or (
            hasattr(exc, "code") and exc.code == "OTSObjectAlreadyExist"
        )
        if not already_exists:
            raise
        logger.warning(
            "Table %s already exists, skipping.",
            table_name,
        )
    else:
        logger.info("Created table: %s", table_name)
def _create_conversation_search_index(self) -> None:
    """Create the search index on the Conversation table (sync).

    The index declares keyword fields for the primary-key columns and for
    ``is_pinned`` / ``framework`` / ``labels``, long fields for
    ``updated_at`` / ``created_at``, and a text field for ``summary``. It
    routes on ``agent_id`` and is pre-sorted by ``updated_at`` descending.
    An "already exists" service error is logged as a warning and
    swallowed; any other ``OTSServiceError`` is re-raised.
    """
    from tablestore import (  # type: ignore[import-untyped]
        AnalyzerType,
        FieldSchema,
        FieldSort as OTSFieldSort,
        FieldType,
        IndexSetting,
        SearchIndexMeta,
        Sort as OTSSort,
        SortOrder,
    )

    def sortable(name: str, field_type: Any) -> Any:
        # Indexed field that can also be used for sorting and aggregation.
        return FieldSchema(
            name,
            field_type,
            index=True,
            enable_sort_and_agg=True,
        )

    # NOTE(review): is_pinned is declared KEYWORD here while put_session
    # writes session.is_pinned unconverted; confirm the stored column type
    # matches, otherwise the field may not be indexed.
    fields = [
        sortable("agent_id", FieldType.KEYWORD),
        sortable("user_id", FieldType.KEYWORD),
        sortable("session_id", FieldType.KEYWORD),
        sortable("updated_at", FieldType.LONG),
        sortable("created_at", FieldType.LONG),
        sortable("is_pinned", FieldType.KEYWORD),
        sortable("framework", FieldType.KEYWORD),
        FieldSchema(
            "summary",
            FieldType.TEXT,
            index=True,
            analyzer=AnalyzerType.SINGLEWORD,
        ),
        sortable("labels", FieldType.KEYWORD),
    ]

    newest_first = OTSSort(
        sorters=[OTSFieldSort("updated_at", sort_order=SortOrder.DESC)]
    )
    index_meta = SearchIndexMeta(
        fields,
        index_setting=IndexSetting(routing_fields=["agent_id"]),
        index_sort=newest_first,
    )

    try:
        self._client.create_search_index(
            self._conversation_table,
            self._conversation_search_index,
            index_meta,
        )
    except OTSServiceError as exc:
        # Recognise "already exists" by message text first, then by code.
        already_exists = "already exist" in str(exc).lower() or (
            hasattr(exc, "code") and exc.code == "OTSObjectAlreadyExist"
        )
        if not already_exists:
            raise
        logger.warning(
            "Search index %s already exists, skipping.",
            self._conversation_search_index,
        )
    else:
        logger.info(
            "Created search index: %s on table: %s",
            self._conversation_search_index,
            self._conversation_table,
        )
async def put_session_async(self, session: ConversationSession) -> None:
    """Write (or overwrite) one Session row with PutRow (async).

    ``created_at``, ``updated_at``, ``is_pinned`` and ``version`` are
    always written. ``summary``, ``labels`` and ``framework`` are written
    only when not ``None``; ``extensions`` is written as a JSON string
    when not ``None``. The row-existence condition is ``IGNORE``.

    Args:
        session: The session whose fields are persisted.
    """
    row_key = [
        ("agent_id", session.agent_id),
        ("user_id", session.user_id),
        ("session_id", session.session_id),
    ]

    columns = [
        ("created_at", session.created_at),
        ("updated_at", session.updated_at),
        ("is_pinned", session.is_pinned),
        ("version", session.version),
    ]

    # Optional columns stored as-is; None means "do not write the column".
    for column_name in ("summary", "labels", "framework"):
        value = getattr(session, column_name)
        if value is not None:
            columns.append((column_name, value))

    # extensions is the only column that needs JSON encoding.
    if session.extensions is not None:
        encoded = json.dumps(session.extensions, ensure_ascii=False)
        columns.append(("extensions", encoded))

    await self._async_client.put_row(
        self._conversation_table,
        Row(row_key, columns),
        Condition(RowExistenceExpectation.IGNORE),
    )
async def get_session_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
) -> Optional[ConversationSession]:
    """Point-read one Session row with GetRow (async).

    Args:
        agent_id: Agent ID (first primary-key column).
        user_id: User ID (second primary-key column).
        session_id: Session ID (third primary-key column).

    Returns:
        The row converted by ``_row_to_session``, or ``None`` when the
        client returns no row or a row without a primary key.
    """
    row_key = [
        ("agent_id", agent_id),
        ("user_id", user_id),
        ("session_id", session_id),
    ]

    # get_row yields a 3-tuple; only the middle element (the row) is used.
    _consumed, found, _token = await self._async_client.get_row(
        self._conversation_table,
        row_key,
        max_version=1,
    )

    if found is not None and found.primary_key is not None:
        return self._row_to_session(found)
    return None
async def list_sessions_async(
    self,
    agent_id: str,
    user_id: str,
    limit: Optional[int] = None,
    order_desc: bool = True,
) -> list[ConversationSession]:
    """List one user's sessions ordered by ``updated_at`` (async).

    Scans the conversation secondary index with GetRange and follows the
    continuation token until the range is exhausted or ``limit`` rows
    have been collected.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        limit: Maximum number of sessions to return; ``None`` returns all.
        order_desc: ``True`` scans backward (largest ``updated_at``
            first); ``False`` scans forward.

    Returns:
        Sessions converted by ``_row_to_session_from_index``, in scan
        order.
    """
    lowest = [
        ("agent_id", agent_id),
        ("user_id", user_id),
        ("updated_at", INF_MIN),
        ("session_id", INF_MIN),
    ]
    highest = [
        ("agent_id", agent_id),
        ("user_id", user_id),
        ("updated_at", INF_MAX),
        ("session_id", INF_MAX),
    ]
    if order_desc:
        cursor, stop, direction = highest, lowest, Direction.BACKWARD
    else:
        cursor, stop, direction = lowest, highest, Direction.FORWARD

    sessions: list[ConversationSession] = []
    while True:
        # Ask only for what is still missing: a follow-up page must not
        # re-request the full limit after some rows were already collected.
        remaining = None if limit is None else limit - len(sessions)
        _, next_token, rows, _ = await self._async_client.get_range(
            self._conversation_secondary_index,
            direction,
            cursor,
            stop,
            max_version=1,
            limit=remaining,
        )

        for row in rows:
            sessions.append(self._row_to_session_from_index(row))
            if limit is not None and len(sessions) >= limit:
                return sessions

        if next_token is None:
            break
        cursor = next_token

    return sessions
async def list_all_sessions_async(
    self,
    agent_id: str,
    limit: Optional[int] = None,
) -> list[ConversationSession]:
    """List the sessions of every user under one agent (async).

    Runs a forward GetRange directly on the conversation table (not the
    secondary index), following the continuation token until the range is
    exhausted or ``limit`` rows have been collected.

    Args:
        agent_id: Agent ID.
        limit: Maximum number of sessions to return; ``None`` returns all.

    Returns:
        Sessions converted by ``_row_to_session``, in primary-key order.
    """
    cursor = [
        ("agent_id", agent_id),
        ("user_id", INF_MIN),
        ("session_id", INF_MIN),
    ]
    stop = [
        ("agent_id", agent_id),
        ("user_id", INF_MAX),
        ("session_id", INF_MAX),
    ]

    sessions: list[ConversationSession] = []
    while True:
        # Ask only for what is still missing: a follow-up page must not
        # re-request the full limit after some rows were already collected.
        remaining = None if limit is None else limit - len(sessions)
        _, next_token, rows, _ = await self._async_client.get_range(
            self._conversation_table,
            Direction.FORWARD,
            cursor,
            stop,
            max_version=1,
            limit=remaining,
        )

        for row in rows:
            sessions.append(self._row_to_session(row))
            if limit is not None and len(sessions) >= limit:
                return sessions

        if next_token is None:
            break
        cursor = next_token

    return sessions
async def search_sessions_async(
    self,
    agent_id: str,
    *,
    user_id: Optional[str] = None,
    summary_keyword: Optional[str] = None,
    labels: Optional[str] = None,
    framework: Optional[str] = None,
    updated_after: Optional[int] = None,
    updated_before: Optional[int] = None,
    is_pinned: Optional[bool] = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[ConversationSession], int]:
    """Search sessions through the conversation search index (async).

    All supplied filters are combined as ``must`` clauses of one bool
    query; results are sorted by ``updated_at`` descending.

    Args:
        agent_id: Agent ID (required, exact match).
        user_id: User ID (optional, exact match).
        summary_keyword: Keyword matched against ``summary``.
        labels: Labels string (exact match).
        framework: Framework identifier (exact match).
        updated_after: Keep rows with ``updated_at`` >= this value.
        updated_before: Keep rows with ``updated_at`` < this value.
        is_pinned: Filter on the pinned flag.
        limit: Maximum number of rows to return (default 20).
        offset: Paging offset (default 0).

    Returns:
        ``(sessions, total_count)``; ``total_count`` falls back to 0 when
        the response carries no count.
    """
    from tablestore import (  # type: ignore[import-untyped]
        BoolQuery,
        ColumnReturnType,
        ColumnsToGet,
        FieldSort as OTSFieldSort,
        MatchQuery,
        RangeQuery,
        SearchQuery,
        Sort as OTSSort,
        SortOrder,
        TermQuery,
    )

    must_queries: list[Any] = [TermQuery("agent_id", agent_id)]

    if user_id is not None:
        must_queries.append(TermQuery("user_id", user_id))
    if summary_keyword is not None:
        must_queries.append(MatchQuery("summary", summary_keyword))
    if labels is not None:
        must_queries.append(TermQuery("labels", labels))
    if framework is not None:
        must_queries.append(TermQuery("framework", framework))
    if is_pinned is not None:
        # NOTE(review): matched as the strings "true"/"false"; confirm
        # this agrees with how the is_pinned column is actually stored.
        pinned_term = "true" if is_pinned else "false"
        must_queries.append(TermQuery("is_pinned", pinned_term))

    has_lower = updated_after is not None
    has_upper = updated_before is not None
    if has_lower or has_upper:
        # Lower bound is inclusive, upper bound exclusive; an absent bound
        # leaves its include flag as None.
        must_queries.append(
            RangeQuery(
                "updated_at",
                range_from=updated_after,
                include_lower=True if has_lower else None,
                range_to=updated_before,
                include_upper=False if has_upper else None,
            )
        )

    newest_first = OTSSort(
        sorters=[OTSFieldSort("updated_at", sort_order=SortOrder.DESC)]
    )
    search_query = SearchQuery(
        BoolQuery(must_queries=must_queries),
        sort=newest_first,
        limit=limit,
        offset=offset,
        get_total_count=True,
    )
    columns_to_get = ColumnsToGet(return_type=ColumnReturnType.ALL)

    search_response = await self._async_client.search(
        self._conversation_table,
        self._conversation_search_index,
        search_query,
        columns_to_get,
    )

    # Hits that arrive as (primary_key, attribute_columns) tuples are
    # wrapped in Row so _row_to_session can be reused unchanged.
    sessions: list[ConversationSession] = [
        self._row_to_session(
            Row(hit[0], hit[1]) if isinstance(hit, tuple) else hit
        )
        for hit in search_response.rows
    ]

    return sessions, search_response.total_count or 0
Session(同步)。 + + 支持全文搜索 summary、精确过滤 labels/framework/is_pinned、 + 范围查询 updated_at 以及跨 user_id 查询。 + + Args: + agent_id: 智能体 ID(必填,作为 routing 键优化查询)。 + user_id: 用户 ID(可选,精确匹配)。 + summary_keyword: summary 关键词(全文搜索)。 + labels: 标签 JSON 字符串(精确匹配)。 + framework: 框架标识(精确匹配)。 + updated_after: 仅返回 updated_at >= 此值的记录。 + updated_before: 仅返回 updated_at < 此值的记录。 + is_pinned: 是否置顶过滤。 + limit: 最多返回条数,默认 20。 + offset: 分页偏移量,默认 0。 + + Returns: + (结果列表, 总匹配数) 二元组。 + """ + from tablestore import BoolQuery # type: ignore[import-untyped] + from tablestore import MatchQuery # type: ignore[import-untyped] + from tablestore import SortOrder # type: ignore[import-untyped] + from tablestore import TermQuery # type: ignore[import-untyped] + from tablestore import ColumnReturnType, ColumnsToGet + from tablestore import ( + FieldSort as OTSFieldSort, + ) # type: ignore[import-untyped] + from tablestore import RangeQuery, SearchQuery + from tablestore import Sort as OTSSort # type: ignore[import-untyped] + + must_queries: list[Any] = [ + TermQuery("agent_id", agent_id), + ] + + if user_id is not None: + must_queries.append(TermQuery("user_id", user_id)) + if summary_keyword is not None: + must_queries.append(MatchQuery("summary", summary_keyword)) + if labels is not None: + must_queries.append(TermQuery("labels", labels)) + if framework is not None: + must_queries.append(TermQuery("framework", framework)) + if is_pinned is not None: + must_queries.append( + TermQuery("is_pinned", "true" if is_pinned else "false") + ) + if updated_after is not None or updated_before is not None: + must_queries.append( + RangeQuery( + "updated_at", + range_from=updated_after, + include_lower=True if updated_after is not None else None, + range_to=updated_before, + include_upper=False if updated_before is not None else None, + ) + ) + + query = BoolQuery(must_queries=must_queries) + + search_query = SearchQuery( + query, + sort=OTSSort( + sorters=[OTSFieldSort("updated_at", sort_order=SortOrder.DESC)] + ), + 
async def put_event_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    event_type: str,
    content: dict[str, Any],
    created_at: Optional[int] = None,
    updated_at: Optional[int] = None,
    raw_event: Optional[str] = None,
) -> int:
    """Append one event row and return its generated ``seq_id`` (async).

    The row is written with PutRow using an auto-increment ``seq_id``
    primary-key column and ``ReturnType.RT_PK`` so the generated key is
    returned by the service.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        event_type: Event type, stored in the ``type`` column.
        content: Event payload; stored as a JSON string.
        created_at: Creation timestamp; defaults to
            ``nanoseconds_timestamp()``.
        updated_at: Update timestamp; defaults to the same "now" value.
        raw_event: Optional serialized framework-native event, stored in
            ``raw_event`` only when not ``None``.

    Returns:
        The ``seq_id`` found in the returned primary key, or 0 when the
        response carries no primary key or no ``seq_id`` column.
    """
    now = nanoseconds_timestamp()
    created_at = now if created_at is None else created_at
    updated_at = now if updated_at is None else updated_at

    row_key = [
        ("agent_id", agent_id),
        ("user_id", user_id),
        ("session_id", session_id),
        ("seq_id", PK_AUTO_INCR),
    ]

    columns = [
        ("type", event_type),
        ("content", json.dumps(content, ensure_ascii=False)),
        ("created_at", created_at),
        ("updated_at", updated_at),
        ("version", 0),
    ]
    if raw_event is not None:
        columns.append(("raw_event", raw_event))

    # put_row returns (consumed, return_row); RT_PK makes return_row carry
    # the primary key including the generated seq_id.
    _, returned = await self._async_client.put_row(
        self._event_table,
        Row(row_key, columns),
        Condition(RowExistenceExpectation.IGNORE),
        return_type=ReturnType.RT_PK,
    )

    if returned is None or returned.primary_key is None:
        return 0

    seq_id: int = next(
        (col[1] for col in returned.primary_key if col[0] == "seq_id"),
        0,
    )
    return seq_id
async def get_events_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    direction: str = "FORWARD",
    limit: Optional[int] = None,
) -> list[ConversationEvent]:
    """Scan the events of one session with GetRange (async).

    Follows the continuation token until the range is exhausted or
    ``limit`` rows have been collected.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        direction: ``"BACKWARD"`` scans from the largest ``seq_id`` down;
            any other value scans forward.
        limit: Maximum number of events to return; ``None`` returns all.

    Returns:
        Events converted by ``_row_to_event``, in scan order.
    """
    lowest = [
        ("agent_id", agent_id),
        ("user_id", user_id),
        ("session_id", session_id),
        ("seq_id", INF_MIN),
    ]
    highest = [
        ("agent_id", agent_id),
        ("user_id", user_id),
        ("session_id", session_id),
        ("seq_id", INF_MAX),
    ]
    if direction == "BACKWARD":
        cursor, stop, ots_direction = highest, lowest, Direction.BACKWARD
    else:
        cursor, stop, ots_direction = lowest, highest, Direction.FORWARD

    events: list[ConversationEvent] = []
    while True:
        # Ask only for what is still missing: a follow-up page must not
        # re-request the full limit after some rows were already collected.
        remaining = None if limit is None else limit - len(events)
        _, next_token, rows, _ = await self._async_client.get_range(
            self._event_table,
            ots_direction,
            cursor,
            stop,
            max_version=1,
            limit=remaining,
        )

        for row in rows:
            events.append(self._row_to_event(row))
            if limit is not None and len(events) >= limit:
                return events

        if next_token is None:
            break
        cursor = next_token

    return events
'BACKWARD'(倒序)。 + limit: 最多返回条数。 + """ + if direction == "BACKWARD": + inclusive_start = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MAX), + ] + exclusive_end = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MIN), + ] + ots_direction = Direction.BACKWARD + else: + inclusive_start = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MIN), + ] + exclusive_end = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MAX), + ] + ots_direction = Direction.FORWARD + + events: list[ConversationEvent] = [] + next_start = inclusive_start + + while True: + ( + _, + next_token, + rows, + _, + ) = self._client.get_range( + self._event_table, + ots_direction, + next_start, + exclusive_end, + max_version=1, + limit=limit, + ) + + for row in rows: + event = self._row_to_event(row) + events.append(event) + if limit is not None and len(events) >= limit: + return events + + if next_token is None: + break + next_start = next_token + + return events + + async def delete_events_by_session_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> int: + """批量删除 Session 下所有 Event,返回删除条数(异步)。 + + 先 GetRange 扫出所有 PK,再分批 BatchWriteRow 删除。 + """ + # 1. 
扫描所有 Event 的 PK + inclusive_start = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MIN), + ] + exclusive_end = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MAX), + ] + + all_pks: list[list[tuple[str, Any]]] = [] + next_start = inclusive_start + + while True: + ( + _, + next_token, + rows, + _, + ) = await self._async_client.get_range( + self._event_table, + Direction.FORWARD, + next_start, + exclusive_end, + columns_to_get=[], # 只取 PK,不读属性列 + max_version=1, + ) + + for row in rows: + all_pks.append(row.primary_key) + + if next_token is None: + break + next_start = next_token + + if not all_pks: + return 0 + + # 2. 分批 BatchWriteRow 删除 + deleted = 0 + for i in range(0, len(all_pks), _BATCH_WRITE_LIMIT): + batch = all_pks[i : i + _BATCH_WRITE_LIMIT] + delete_items = [] + for pk in batch: + row = Row(pk) + condition = Condition(RowExistenceExpectation.IGNORE) + delete_items.append(DeleteRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add( + TableInBatchWriteRowItem(self._event_table, delete_items) + ) + await self._async_client.batch_write_row(request) + deleted += len(batch) + + return deleted + + # ----------------------------------------------------------------------- + # State CRUD(JSON 字符串存储 + 列分片)(异步) + # ----------------------------------------------------------------------- + + def delete_events_by_session( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> int: + """批量删除 Session 下所有 Event,返回删除条数(同步)。 + + 先 GetRange 扫出所有 PK,再分批 BatchWriteRow 删除。 + """ + # 1. 
扫描所有 Event 的 PK + inclusive_start = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MIN), + ] + exclusive_end = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", INF_MAX), + ] + + all_pks: list[list[tuple[str, Any]]] = [] + next_start = inclusive_start + + while True: + ( + _, + next_token, + rows, + _, + ) = self._client.get_range( + self._event_table, + Direction.FORWARD, + next_start, + exclusive_end, + columns_to_get=[], # 只取 PK,不读属性列 + max_version=1, + ) + + for row in rows: + all_pks.append(row.primary_key) + + if next_token is None: + break + next_start = next_token + + if not all_pks: + return 0 + + # 2. 分批 BatchWriteRow 删除 + deleted = 0 + for i in range(0, len(all_pks), _BATCH_WRITE_LIMIT): + batch = all_pks[i : i + _BATCH_WRITE_LIMIT] + delete_items = [] + for pk in batch: + row = Row(pk) + condition = Condition(RowExistenceExpectation.IGNORE) + delete_items.append(DeleteRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add( + TableInBatchWriteRowItem(self._event_table, delete_items) + ) + self._client.batch_write_row(request) + deleted += len(batch) + + return deleted + + # ----------------------------------------------------------------------- + # State CRUD(JSON 字符串存储 + 列分片)(同步) + # ----------------------------------------------------------------------- + + async def put_state_async( + self, + scope: StateScope, + agent_id: str, + user_id: str, + session_id: str, + state: dict[str, Any], + version: int, + ) -> None: + """序列化 + 列分片写入 State(异步)。 + + State 以 JSON 字符串(STRING 类型)存储,不压缩。 + 当 JSON 字符串超过 1.5M 字符时自动分片。 + + Args: + scope: 状态作用域(APP / USER / SESSION)。 + agent_id: 智能体 ID。 + user_id: 用户 ID(APP scope 时忽略)。 + session_id: 会话 ID(APP/USER scope 时忽略)。 + state: 状态字典。 + version: 当前版本号(乐观锁校验,首次写入传 0)。 + """ + table_name, primary_key = self._resolve_state_table_and_pk( + scope, agent_id, user_id, session_id + ) + + now = 
def put_state(
    self,
    scope: StateScope,
    agent_id: str,
    user_id: str,
    session_id: str,
    state: dict[str, Any],
    version: int,
) -> None:
    """Serialize ``state`` and write it to the state row of ``scope`` (sync).

    The state is serialized with ``serialize_state`` and stored as a string.
    When the string is no longer than ``MAX_COLUMN_SIZE`` characters it goes
    into the single ``state`` column with ``chunk_count`` 0; otherwise
    ``to_chunks`` splits it into ``state_0`` .. ``state_{N-1}`` and
    ``chunk_count`` is set to N. (The original documentation gives the
    threshold as 1.5M characters.)

    Args:
        scope: State scope (APP / USER / SESSION); selects table and key.
        agent_id: Agent ID.
        user_id: User ID (not part of the key for APP scope).
        session_id: Session ID (not part of the key for APP/USER scope).
        state: State dictionary to persist.
        version: Current version used for the optimistic-lock check; pass 0
            for the first write. The row is written with ``version + 1``.
    """
    table_name, primary_key = self._resolve_state_table_and_pk(
        scope, agent_id, user_id, session_id
    )

    now = nanoseconds_timestamp()
    state_json = serialize_state(state)

    put_cols: list[tuple[str, Any]] = [
        ("updated_at", now),
        ("version", version + 1),
    ]

    # created_at is only written on the first write.
    if version == 0:
        put_cols.append(("created_at", now))

    if len(state_json) <= MAX_COLUMN_SIZE:
        # Fits in one column: no chunking.
        new_chunk_count = 0
        put_cols.append(("chunk_count", 0))
        put_cols.append(("state", state_json))
    else:
        # Too long for one column: store as numbered chunk columns.
        chunks = to_chunks(state_json)
        new_chunk_count = len(chunks)
        put_cols.append(("chunk_count", new_chunk_count))
        for idx, chunk in enumerate(chunks):
            put_cols.append((f"state_{idx}", chunk))

    update_of_attribute_columns: dict[str, Any] = {"PUT": put_cols}

    # On an update (version > 0) remove columns left over from the previous
    # layout so a later read does not see stale data.
    delete_cols: list[str] = []
    if version > 0:
        old_chunk_count = self._get_chunk_count(table_name, primary_key)

        if new_chunk_count == 0 and old_chunk_count > 0:
            # Was chunked, now a single column: drop every state_N column.
            for i in range(old_chunk_count):
                delete_cols.append(f"state_{i}")
        elif new_chunk_count > 0 and old_chunk_count == 0:
            # Was a single column, now chunked: drop the state column.
            delete_cols.append("state")
        elif new_chunk_count > 0 and old_chunk_count > new_chunk_count:
            # Chunked before and after, but fewer chunks now: drop the tail.
            for i in range(new_chunk_count, old_chunk_count):
                delete_cols.append(f"state_{i}")

    if delete_cols:
        update_of_attribute_columns["DELETE_ALL"] = delete_cols

    row = Row(primary_key, update_of_attribute_columns)

    if version == 0:
        # First write: no row-existence or version check is applied.
        # NOTE(review): because of IGNORE, a version-0 write onto an existing
        # row is not rejected and the clean-up branch above is skipped, so old
        # state / state_N columns may remain. Confirm this is intended.
        condition = Condition(RowExistenceExpectation.IGNORE)
    else:
        # Optimistic lock: the row must exist and still carry `version`.
        condition = Condition(
            RowExistenceExpectation.EXPECT_EXIST,
            SingleColumnCondition(
                "version",
                version,
                ComparatorType.EQUAL,
            ),
        )

    self._client.update_row(table_name, row, condition)
def get_state(
    self,
    scope: StateScope,
    agent_id: str,
    user_id: str,
    session_id: str,
) -> Optional[StateData]:
    """Read a state row, join its chunks and deserialize it (sync).

    Returns:
        A ``StateData`` built from the row, or ``None`` when the row does
        not exist or an unchunked row has no ``state`` column.

    Raises:
        ValueError: If ``chunk_count`` is non-zero and one of the
            ``state_N`` columns is missing.
    """
    table, key = self._resolve_state_table_and_pk(
        scope, agent_id, user_id, session_id
    )
    _consumed, fetched, _token = self._client.get_row(
        table,
        key,
        max_version=1,
    )
    if fetched is None or fetched.primary_key is None:
        return None

    columns = self._attrs_to_dict(fetched.attribute_columns)
    total_chunks = columns.get("chunk_count", 0)

    if total_chunks != 0:
        # Chunked layout: every state_N column must be present, in order.
        pieces: list[str] = []
        for index in range(total_chunks):
            piece = columns.get(f"state_{index}")
            if piece is None:
                raise ValueError(f"Missing state chunk: state_{index}")
            pieces.append(str(piece))
        payload = from_chunks(pieces)
    else:
        # Single-column layout.
        single = columns.get("state")
        if single is None:
            return None
        payload = str(single)

    return StateData(
        state=deserialize_state(payload),
        created_at=columns.get("created_at", 0),
        updated_at=columns.get("updated_at", 0),
        version=columns.get("version", 0),
    )
def _resolve_state_table_and_pk(
    self,
    scope: StateScope,
    agent_id: str,
    user_id: str,
    session_id: str,
) -> tuple[str, list[tuple[str, str]]]:
    """Return the table name and primary key list for a state scope.

    APP rows are keyed by agent_id, USER rows by agent_id + user_id, and any
    other scope (SESSION) by agent_id + user_id + session_id.
    """
    key: list[tuple[str, str]] = [("agent_id", agent_id)]
    if scope == StateScope.APP:
        return self._app_state_table, key

    key.append(("user_id", user_id))
    if scope == StateScope.USER:
        return self._user_state_table, key

    # SESSION
    key.append(("session_id", session_id))
    return self._state_table, key


@staticmethod
def _attrs_to_dict(
    attribute_columns: list[Any],
) -> dict[str, Any]:
    """Turn an attribute column list into a ``{name: value}`` dict.

    Each column is indexed as ``(name, value, ...)``; only the first two
    items are read. ``None`` yields an empty dict, and a repeated name
    keeps the last value.
    """
    if attribute_columns is None:
        return {}
    return {column[0]: column[1] for column in attribute_columns}


@staticmethod
def _pk_to_dict(
    primary_key: list[Any],
) -> dict[str, Any]:
    """Turn a primary key column list into a ``{name: value}`` dict.

    ``None`` yields an empty dict.
    """
    if primary_key is None:
        return {}
    return {column[0]: column[1] for column in primary_key}
_row_to_event(self, row: Row) -> ConversationEvent: + """将 OTS Row 转换为 ConversationEvent。""" + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + + content_raw = attrs.get("content", "{}") + if isinstance(content_raw, str): + content = json.loads(content_raw) + else: + content = {} + + return ConversationEvent( + agent_id=pk["agent_id"], + user_id=pk["user_id"], + session_id=pk["session_id"], + seq_id=pk.get("seq_id"), + type=attrs.get("type", ""), + content=content, + created_at=attrs.get("created_at", 0), + updated_at=attrs.get("updated_at", 0), + version=attrs.get("version", 0), + raw_event=attrs.get("raw_event"), + ) diff --git a/agentrun/conversation_service/session_store.py b/agentrun/conversation_service/session_store.py new file mode 100644 index 0000000..49062a3 --- /dev/null +++ b/agentrun/conversation_service/session_store.py @@ -0,0 +1,1485 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. 
+ +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/conversation_service/__session_store_async_template.py + +SessionStore 核心业务逻辑层。 + +提供框架无关的统一会话管理接口,包括 Session、Event、State 的 CRUD, +以及级联删除和三级状态合并。 +""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, + StateScope, +) +from agentrun.conversation_service.ots_backend import OTSBackend +from agentrun.conversation_service.utils import nanoseconds_timestamp + +logger = logging.getLogger(__name__) + + +class SessionStore: + """核心业务逻辑层。 + + 封装 OTSBackend,实现级联删除、状态合并等业务逻辑, + 向上暴露框架无关的统一接口。 + 同时提供异步(_async 后缀)和同步方法。 + + Args: + ots_backend: OTS 存储后端实例。 + """ + + def __init__(self, ots_backend: OTSBackend) -> None: + self._backend = ots_backend + + async def init_tables_async(self) -> None: + """创建所有 OTS 表和索引(异步)。代理到 OTSBackend.init_tables_async()。""" + await self._backend.init_tables_async() + + def init_tables(self) -> None: + """创建所有 OTS 表和索引(同步)。代理到 OTSBackend.init_tables()。""" + self._backend.init_tables() + + async def init_core_tables_async(self) -> None: + """创建核心表(Conversation + Event)和二级索引(异步)。""" + await self._backend.init_core_tables_async() + + def init_core_tables(self) -> None: + """创建核心表(Conversation + Event)和二级索引(同步)。""" + self._backend.init_core_tables() + + async def init_state_tables_async(self) -> None: + """创建三张 State 表(异步)。""" + await self._backend.init_state_tables_async() + + def init_state_tables(self) -> None: + """创建三张 State 表(同步)。""" + self._backend.init_state_tables() + + async def init_search_index_async(self) -> None: + """创建 Conversation 多元索引(异步)。按需调用。""" + await self._backend.init_search_index_async() + + # ------------------------------------------------------------------- + # Session 管理(异步)/ Session management (async) + # ------------------------------------------------------------------- + + def init_search_index(self) -> 
async def create_session_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    *,
    is_pinned: bool = False,
    summary: Optional[str] = None,
    labels: Optional[str] = None,
    framework: Optional[str] = None,
    extensions: Optional[dict[str, Any]] = None,
) -> ConversationSession:
    """Create a new session and persist it through the backend (async).

    ``created_at`` and ``updated_at`` are both set to the same
    ``nanoseconds_timestamp()`` value and ``version`` starts at 0.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        is_pinned: Whether the session is pinned.
        summary: Session summary.
        labels: Session labels.
        framework: Framework identifier.
        extensions: Framework-specific extension data.

    Returns:
        The ``ConversationSession`` that was written.
    """
    # One timestamp for both fields so a fresh session has equal times.
    now = nanoseconds_timestamp()
    session = ConversationSession(
        agent_id=agent_id,
        user_id=user_id,
        session_id=session_id,
        created_at=now,
        updated_at=now,
        is_pinned=is_pinned,
        summary=summary,
        labels=labels,
        framework=framework,
        extensions=extensions,
        version=0,
    )
    await self._backend.put_session_async(session)
    return session
labels=labels, + framework=framework, + extensions=extensions, + version=0, + ) + self._backend.put_session(session) + return session + + async def get_session_async( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> Optional[ConversationSession]: + """获取单个 Session(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + + Returns: + ConversationSession 对象,不存在时返回 None。 + """ + return await self._backend.get_session_async( + agent_id, user_id, session_id + ) + + def get_session( + self, + agent_id: str, + user_id: str, + session_id: str, + ) -> Optional[ConversationSession]: + """获取单个 Session(同步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + session_id: 会话 ID。 + + Returns: + ConversationSession 对象,不存在时返回 None。 + """ + return self._backend.get_session(agent_id, user_id, session_id) + + async def list_sessions_async( + self, + agent_id: str, + user_id: str, + limit: Optional[int] = None, + ) -> list[ConversationSession]: + """列出用户的 Session(按 updated_at 倒序)(异步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + limit: 最多返回条数,None 表示全部。 + + Returns: + ConversationSession 列表。 + """ + return await self._backend.list_sessions_async( + agent_id, user_id, limit=limit, order_desc=True + ) + + def list_sessions( + self, + agent_id: str, + user_id: str, + limit: Optional[int] = None, + ) -> list[ConversationSession]: + """列出用户的 Session(按 updated_at 倒序)(同步)。 + + Args: + agent_id: 智能体 ID。 + user_id: 用户 ID。 + limit: 最多返回条数,None 表示全部。 + + Returns: + ConversationSession 列表。 + """ + return self._backend.list_sessions( + agent_id, user_id, limit=limit, order_desc=True + ) + + async def list_all_sessions_async( + self, + agent_id: str, + limit: Optional[int] = None, + ) -> list[ConversationSession]: + """列出 agent_id 下所有用户的 Session(异步)。 + + 不要求 user_id,扫描主表全量返回。 + 适用于 ADK list_sessions(user_id=None) 场景。 + + Args: + agent_id: 智能体 ID。 + limit: 最多返回条数,None 表示全部。 + + Returns: + ConversationSession 列表。 + """ + return await 
async def search_sessions_async(
    self,
    agent_id: str,
    *,
    user_id: Optional[str] = None,
    summary_keyword: Optional[str] = None,
    labels: Optional[str] = None,
    framework: Optional[str] = None,
    updated_after: Optional[int] = None,
    updated_before: Optional[int] = None,
    is_pinned: Optional[bool] = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[ConversationSession], int]:
    """Search sessions by delegating to the backend's search (async).

    Every argument is forwarded unchanged to
    ``OTSBackend.search_sessions_async``. The filter semantics below are
    taken from the original documentation of this method; the matching
    itself happens in the backend.

    Args:
        agent_id: Agent ID (required).
        user_id: User ID (optional, exact match).
        summary_keyword: Keyword for full-text search on the summary.
        labels: Labels JSON string (exact match).
        framework: Framework identifier (exact match).
        updated_after: Only return records with updated_at >= this value.
        updated_before: Only return records with updated_at < this value.
        is_pinned: Filter on the pinned flag.
        limit: Maximum number of results, default 20.
        offset: Pagination offset, default 0.

    Returns:
        A ``(results, total_match_count)`` tuple.
    """
    return await self._backend.search_sessions_async(
        agent_id,
        user_id=user_id,
        summary_keyword=summary_keyword,
        labels=labels,
        framework=framework,
        updated_after=updated_after,
        updated_before=updated_before,
        is_pinned=is_pinned,
        limit=limit,
        offset=offset,
    )
async def delete_session_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
) -> None:
    """Cascade-delete a session (async).

    Deletion order: events, then the session-scope state row, then the
    session row itself. Because the session row goes last, a failure in an
    earlier step leaves it in place and the call can be retried to finish
    the clean-up.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
    """
    # 1. Delete all events of the session.
    deleted_events = await self._backend.delete_events_by_session_async(
        agent_id, user_id, session_id
    )
    logger.debug(
        "Deleted %d events for session %s/%s/%s",
        deleted_events,
        agent_id,
        user_id,
        session_id,
    )

    # 2. Delete the session-scope state row.
    await self._backend.delete_state_row_async(
        StateScope.SESSION,
        agent_id,
        user_id,
        session_id,
    )

    # 3. Delete the session row last so a retry can still find it.
    await self._backend.delete_session_row_async(
        agent_id, user_id, session_id
    )

    logger.info(
        "Cascade deleted session %s/%s/%s",
        agent_id,
        user_id,
        session_id,
    )
async def update_session_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    *,
    is_pinned: Optional[bool] = None,
    summary: Optional[str] = None,
    labels: Optional[str] = None,
    extensions: Optional[dict[str, Any]] = None,
    version: int,
) -> None:
    """Update session attributes under an optimistic lock (async).

    Only the fields that are not ``None`` are written; ``updated_at`` and
    ``version`` (set to ``version + 1``) are always written.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        is_pinned: Whether the session is pinned.
        summary: Session summary.
        labels: Session labels.
        extensions: Framework extension data; stored as a JSON string.
        version: Current version, passed to the backend for the lock check.
    """
    changes: dict[str, Any] = {
        "updated_at": nanoseconds_timestamp(),
        "version": version + 1,
    }

    # Plain columns are copied as-is when the caller supplied them.
    supplied = (
        ("is_pinned", is_pinned),
        ("summary", summary),
        ("labels", labels),
    )
    changes.update(
        {column: value for column, value in supplied if value is not None}
    )

    # extensions is a dict and has to be serialized before storage.
    if extensions is not None:
        import json

        changes["extensions"] = json.dumps(extensions, ensure_ascii=False)

    await self._backend.update_session_async(
        agent_id,
        user_id,
        session_id,
        changes,
        version,
    )
async def append_event_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    event_type: str,
    content: dict[str, Any],
    raw_event: Optional[str] = None,
) -> ConversationEvent:
    """Append an event and refresh the session's updated_at (async).

    The event is written first. The session row is then read and, if it
    exists, updated with the same timestamp and ``version + 1``. A failure
    of that update is logged and swallowed, so the event is still returned.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        event_type: Event type.
        content: Event payload.
        raw_event: Optional JSON serialization of the framework-native
            event (the original documentation names ADK Event as an
            example), kept so the native object can be restored.

    Returns:
        A ``ConversationEvent`` carrying the seq_id returned by the backend.
    """
    now = nanoseconds_timestamp()

    # 1. Write the event.
    seq_id = await self._backend.put_event_async(
        agent_id,
        user_id,
        session_id,
        event_type,
        content,
        created_at=now,
        updated_at=now,
        raw_event=raw_event,
    )

    # 2. Refresh the session's updated_at (the original comment says this
    #    keeps the secondary-index ordering correct). The session is read
    #    first to obtain the version needed for the conditional update.
    session = await self._backend.get_session_async(
        agent_id, user_id, session_id
    )
    if session is not None:
        try:
            await self._backend.update_session_async(
                agent_id,
                user_id,
                session_id,
                {
                    "updated_at": now,
                    "version": session.version + 1,
                },
                session.version,
            )
        except Exception:
            # Deliberately best-effort: failing to bump the timestamp must
            # not fail the append, because the event is already written.
            logger.warning(
                "Failed to update session updated_at "
                "for %s/%s/%s, event was still written.",
                agent_id,
                user_id,
                session_id,
                exc_info=True,
            )

    return ConversationEvent(
        agent_id=agent_id,
        user_id=user_id,
        session_id=session_id,
        seq_id=seq_id,
        type=event_type,
        content=content,
        created_at=now,
        updated_at=now,
        version=0,
        raw_event=raw_event,
    )
async def get_recent_events_async(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    n: int,
) -> list[ConversationEvent]:
    """Return the most recent ``n`` events of a session (async).

    The backend is queried in ``BACKWARD`` direction with ``limit=n``; the
    returned list is then reversed in place before being handed back.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        n: Number of events to fetch.

    Returns:
        The fetched events in the reverse of the order the backend
        returned them.
    """
    newest_first = await self._backend.get_events_async(
        agent_id,
        user_id,
        session_id,
        direction="BACKWARD",
        limit=n,
    )
    # In-place reversal: the very list object the backend produced is returned.
    newest_first.reverse()
    return newest_first
def get_recent_events(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
    n: int,
) -> list[ConversationEvent]:
    """Return the most recent ``n`` events of a session (sync).

    The backend is queried in ``BACKWARD`` direction with ``limit=n``; the
    returned list is then reversed in place before being handed back.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        n: Number of events to fetch.

    Returns:
        The fetched events in the reverse of the order the backend
        returned them.
    """
    newest_first = self._backend.get_events(
        agent_id,
        user_id,
        session_id,
        direction="BACKWARD",
        limit=n,
    )
    # In-place reversal: the very list object the backend produced is returned.
    newest_first.reverse()
    return newest_first
def get_user_state(self, agent_id: str, user_id: str) -> dict[str, Any]:
    """Return the user-scoped state, or ``{}`` when nothing is stored (sync).

    Args:
        agent_id: Agent ID.
        user_id: User ID.

    Returns:
        The ``state`` attribute of the record the backend returns for
        ``StateScope.USER``, or an empty dict when that record is falsy.
    """
    # User scope has no session component, so an empty session_id is passed.
    record = self._backend.get_state(
        StateScope.USER, agent_id, user_id, ""
    )
    if not record:
        return {}
    return record.state
def get_merged_state(
    self,
    agent_id: str,
    user_id: str,
    session_id: str,
) -> dict[str, Any]:
    """Shallow-merge the three state levels: app <- user <- session (sync).

    Later levels override earlier ones key by key. Each getter returns
    ``{}`` for a missing level, so an absent level contributes nothing.

    Args:
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.

    Returns:
        A new dict holding the merged state.
    """
    # Ordered from lowest to highest precedence.
    layers = (
        self.get_app_state(agent_id),
        self.get_user_state(agent_id, user_id),
        self.get_session_state(agent_id, user_id, session_id),
    )
    combined: dict[str, Any] = {}
    for layer in layers:
        combined.update(layer)
    return combined
def _apply_delta(
    self,
    scope: StateScope,
    agent_id: str,
    user_id: str,
    session_id: str,
    delta: dict[str, Any],
) -> None:
    """Apply an incremental update to a state record (shared logic, sync).

    - First write (no record yet): the delta minus its ``None`` values is
      stored with ``version=0``.
    - Later writes: a copy of the stored state is shallow-merged with the
      delta (a ``None`` value deletes the key) and written back together
      with the version that was read.

    Args:
        scope: State scope.
        agent_id: Agent ID.
        user_id: User ID.
        session_id: Session ID.
        delta: Incremental update; top-level keys only.
    """
    stored = self._backend.get_state(scope, agent_id, user_id, session_id)

    if stored is None:
        # Nothing persisted yet: start from an empty state at version 0.
        next_state: dict[str, Any] = {}
        version = 0
    else:
        # Work on a copy so the record returned by the backend is untouched.
        next_state = dict(stored.state)
        version = stored.version

    for key, value in delta.items():
        if value is None:
            next_state.pop(key, None)  # None means "delete this key"
        else:
            next_state[key] = value  # shallow overwrite

    # NOTE(review): version is presumably used by the backend as an
    # expected-version check on write; confirm against put_state.
    self._backend.put_state(
        scope,
        agent_id,
        user_id,
        session_id,
        state=next_state,
        version=version,
    )
def from_memory_collection_async( + cls, + memory_collection_name: str, + *, + config: Optional[Any] = None, + table_prefix: str = "", + ) -> "SessionStore": + """通过 MemoryCollection 名称创建 SessionStore(异步)。 + + 从 AgentRun 平台获取 MemoryCollection 配置,自动提取 OTS 实例 + 的 endpoint 和 instance_name,结合 Config 中的 AK/SK 凭证, + 构建 OTSClient 和 OTSBackend,返回即用的 SessionStore。 + + Args: + memory_collection_name: AgentRun 平台上的 MemoryCollection 名称。 + config: agentrun Config 对象(可选)。 + 未提供时自动从环境变量读取凭证。 + table_prefix: 表名前缀,用于多租户隔离,默认不添加。 + + Returns: + 配置完成的 SessionStore 实例。 + + Raises: + ImportError: 未安装 agentrun 主包时抛出。 + ValueError: MemoryCollection 缺少 OTS 配置或凭证为空时抛出。 + + Example:: + + store = await SessionStore.from_memory_collection_async( + "my-memory-collection", + ) + await store.init_tables_async() + """ + # 延迟导入,避免 conversation_service 强依赖 agentrun 主包 + try: + from agentrun.memory_collection import MemoryCollection + from agentrun.utils.config import Config + except ImportError as e: + raise ImportError( + "agentrun 主包未安装。请先安装: pip install agentrun" + ) from e + + from tablestore import AsyncOTSClient # type: ignore[import-untyped] + from tablestore import OTSClient # type: ignore[import-untyped] + from tablestore import WriteRetryPolicy + + from agentrun.conversation_service.utils import ( + convert_vpc_endpoint_to_public, + ) + + # 1. 获取 MemoryCollection 配置 + mc = await MemoryCollection.get_by_name_async( + memory_collection_name, config=config + ) + + # 2. 
提取 OTS 连接信息 + if not mc.vector_store_config or not mc.vector_store_config.config: + raise ValueError( + f"MemoryCollection '{memory_collection_name}' 缺少 " + "vector_store_config 配置,无法获取 OTS 连接信息。" + ) + + vs_config = mc.vector_store_config.config + endpoint = convert_vpc_endpoint_to_public(vs_config.endpoint or "") + instance_name = vs_config.instance_name or "" + + if not endpoint: + raise ValueError( + f"MemoryCollection '{memory_collection_name}' 的 " + "vector_store_config.endpoint 为空。" + ) + if not instance_name: + raise ValueError( + f"MemoryCollection '{memory_collection_name}' 的 " + "vector_store_config.instance_name 为空。" + ) + + # 3. 获取凭证 + effective_config = config if isinstance(config, Config) else Config() + access_key_id = effective_config.get_access_key_id() + access_key_secret = effective_config.get_access_key_secret() + + if not access_key_id or not access_key_secret: + raise ValueError( + "AK/SK 凭证为空。请通过 Config 参数传入或设置环境变量 " + "AGENTRUN_ACCESS_KEY_ID / AGENTRUN_ACCESS_KEY_SECRET。" + ) + + security_token = effective_config.get_security_token() + sts_token = security_token if security_token else None + + # 4. 
@classmethod
def from_memory_collection(
    cls,
    memory_collection_name: str,
    *,
    config: Optional[Any] = None,
    table_prefix: str = "",
) -> "SessionStore":
    """Create a SessionStore from a MemoryCollection name (sync).

    Looks the MemoryCollection up by name, reads the OTS ``endpoint`` and
    ``instance_name`` from its ``vector_store_config.config``, combines them
    with the AK/SK (and optional security token) from ``Config``, and builds
    an ``OTSClient``, an ``AsyncOTSClient`` and an ``OTSBackend`` around them.

    Args:
        memory_collection_name: Name of the MemoryCollection to look up.
        config: Optional agentrun ``Config``. Anything that is not a
            ``Config`` instance is replaced by a default ``Config()`` when
            reading credentials.
        table_prefix: Table-name prefix passed to ``OTSBackend``; empty by
            default.

    Returns:
        A SessionStore wrapping the newly built backend.

    Raises:
        ImportError: The agentrun main package cannot be imported.
        ValueError: The MemoryCollection lacks a vector store config, its
            endpoint or instance name is empty, or the AK/SK is empty.

    Example::

        store = SessionStore.from_memory_collection(
            "my-memory-collection",
        )
        store.init_tables()
    """
    # Lazy import so conversation_service does not hard-depend on the
    # agentrun main package.
    try:
        from agentrun.memory_collection import MemoryCollection
        from agentrun.utils.config import Config
    except ImportError as e:
        raise ImportError(
            "agentrun 主包未安装。请先安装: pip install agentrun"
        ) from e

    from tablestore import AsyncOTSClient  # type: ignore[import-untyped]
    from tablestore import OTSClient  # type: ignore[import-untyped]
    from tablestore import WriteRetryPolicy

    from agentrun.conversation_service.utils import (
        convert_vpc_endpoint_to_public,
    )

    # 1. Fetch the MemoryCollection definition.
    mc = MemoryCollection.get_by_name(memory_collection_name, config=config)

    # 2. Extract the OTS connection details.
    store_config = mc.vector_store_config
    if not (store_config and store_config.config):
        raise ValueError(
            f"MemoryCollection '{memory_collection_name}' 缺少 "
            "vector_store_config 配置,无法获取 OTS 连接信息。"
        )

    vs_config = store_config.config
    endpoint = convert_vpc_endpoint_to_public(vs_config.endpoint or "")
    instance_name = vs_config.instance_name or ""

    if not endpoint:
        raise ValueError(
            f"MemoryCollection '{memory_collection_name}' 的 "
            "vector_store_config.endpoint 为空。"
        )
    if not instance_name:
        raise ValueError(
            f"MemoryCollection '{memory_collection_name}' 的 "
            "vector_store_config.instance_name 为空。"
        )

    # 3. Resolve credentials.
    effective_config = config if isinstance(config, Config) else Config()
    access_key_id = effective_config.get_access_key_id()
    access_key_secret = effective_config.get_access_key_secret()

    if not (access_key_id and access_key_secret):
        raise ValueError(
            "AK/SK 凭证为空。请通过 Config 参数传入或设置环境变量 "
            "AGENTRUN_ACCESS_KEY_ID / AGENTRUN_ACCESS_KEY_SECRET。"
        )

    # A falsy token is normalised to None before being handed to the clients.
    sts_token = effective_config.get_security_token() or None

    # 4. Build both clients from the same connection tuple; each one gets
    #    its own WriteRetryPolicy instance.
    connection = (endpoint, access_key_id, access_key_secret, instance_name)
    ots_client = OTSClient(
        *connection,
        sts_token=sts_token,
        retry_policy=WriteRetryPolicy(),
    )
    async_ots_client = AsyncOTSClient(
        *connection,
        sts_token=sts_token,
        retry_policy=WriteRetryPolicy(),
    )

    return cls(
        OTSBackend(
            ots_client,
            table_prefix=table_prefix,
            async_ots_client=async_ots_client,
        )
    )
def nanoseconds_timestamp() -> int:
    """Return the current Unix time as an integer count of nanoseconds.

    ``time.time_ns()`` is used instead of ``int(time.time() * 1_000_000_000)``.
    The float returned by ``time.time()`` cannot represent present-day epoch
    values at nanosecond granularity, so the multiplied result is quantised
    to steps of a few hundred nanoseconds. ``time_ns()`` returns the clock
    reading as an exact integer.

    Returns:
        Nanoseconds since the Unix epoch.
    """
    return time.time_ns()
MemoryCollection 模块开发参考 + +## 目录 + +- [模块概述](#模块概述) +- [目录结构](#目录结构) +- [架构分层](#架构分层) +- [数据模型](#数据模型) +- [API 使用指南](#api-使用指南) +- [mem0 集成](#mem0-集成) +- [代码生成机制](#代码生成机制) +- [开发注意事项](#开发注意事项) + +--- + +## 模块概述 + +`memory_collection` 模块提供 **记忆集合 (MemoryCollection)** 资源的完整生命周期管理,包括创建、获取、更新、删除和列表查询。同时支持将 MemoryCollection 转换为 `agentrun-mem0ai` 客户端,以便直接操作记忆数据。 + +所有 API 均提供 **同步** 和 **异步** 两种调用方式。 + +--- + +## 目录结构 + +``` +memory_collection/ +├── __init__.py # 模块入口,导出公开 API +├── model.py # 数据模型定义(手动维护) +├── client.py # 客户端封装(自动生成,勿手动修改) +├── memory_collection.py # 高层资源 API(自动生成,勿手动修改) +├── __client_async_template.py # client.py 的异步模板(手动维护) +├── __memory_collection_async_template.py # memory_collection.py 的异步模板(手动维护) +└── api/ + ├── __init__.py + └── control.py # 底层 SDK 交互层(自动生成,勿手动修改) +``` + +### 文件职责速查 + +| 文件 | 可编辑 | 职责 | +|------|--------|------| +| `model.py` | 是 | 定义所有数据模型(输入/输出/属性) | +| `__client_async_template.py` | 是 | Client 的异步模板,用于生成 `client.py` | +| `__memory_collection_async_template.py` | 是 | MemoryCollection 的异步模板,用于生成 `memory_collection.py` | +| `client.py` | **否** | 自动生成的客户端代码(包含同步+异步方法) | +| `memory_collection.py` | **否** | 自动生成的高层 API(包含同步+异步方法) | +| `api/control.py` | **否** | 自动生成的底层 SDK 调用封装 | + +--- + +## 架构分层 + +模块采用三层架构设计: + +``` +┌────────────────────────────────────────────────────────────────┐ +│ 用户代码 │ +└────────────────────┬───────────────────────────────────────────┘ + │ + ┌────────────────▼────────────────────┐ + │ MemoryCollection (高层 API) │ memory_collection.py + │ - 类方法: create / get_by_name │ 继承自 ResourceBase + │ - 实例方法: update / delete │ + MutableProps + │ - 列表: list_all │ + ImmutableProps + │ - 转换: to_mem0_memory │ + SystemProps + └────────────────┬────────────────────┘ + │ + ┌────────────────▼────────────────────┐ + │ MemoryCollectionClient (客户端) │ client.py + │ - create / delete │ 封装输入输出转换 + │ - update / get / list │ 错误处理转换 + └────────────────┬────────────────────┘ + │ + ┌────────────────▼────────────────────┐ + │ 
MemoryCollectionControlAPI (底层) │ api/control.py + │ - 直接调用阿里云底层 SDK │ 继承自 ControlAPI + │ - 处理 HTTP 异常映射 │ + └────────────────────────────────────┘ + │ + ┌────────────────▼────────────────────┐ + │ alibabacloud_agentrun20250910 │ 底层 SDK + └────────────────────────────────────┘ +``` + +### 各层职责 + +1. **api/control.py(底层 API 层)** + - 直接调用 `alibabacloud_agentrun20250910` 底层 SDK + - 将 `ClientException` / `ServerException` 转换为 `ClientError` / `ServerError` + - 提供日志输出(debug 级别) + +2. **client.py(客户端层)** + - 将 AgentRun SDK 的 Model(如 `MemoryCollectionCreateInput`)转换为底层 SDK 的 Input + - 将底层 SDK 返回的对象转换为 `MemoryCollection` 高层对象 + - 将 `HTTPError` 转换为语义化的资源错误(如 `ResourceAlreadyExistError`、`ResourceNotExistError`) + +3. **memory_collection.py(高层 API 层)** + - 提供类方法(静态操作):`create`、`get_by_name`、`delete_by_name`、`update_by_name`、`list_all` + - 提供实例方法(基于当前对象):`update`、`delete`、`get`、`refresh` + - 提供 mem0 集成能力:`to_mem0_memory`、`to_mem0_memory_async` + +--- + +## 数据模型 + +### 模型继承关系 + +``` +BaseModel +├── MemoryCollectionMutableProps # 可变属性(可通过 update 修改) +│ ├── description +│ ├── embedder_config +│ ├── execution_role_arn +│ ├── llm_config +│ ├── network_configuration +│ └── vector_store_config +│ +├── MemoryCollectionImmutableProps # 不可变属性(创建时指定,不可修改) +│ ├── memory_collection_name +│ └── type +│ +├── MemoryCollectionSystemProps # 系统属性(只读,由服务端生成) +│ ├── memory_collection_id +│ ├── created_at +│ └── last_updated_at +│ +├── MemoryCollectionCreateInput # 创建输入 = Immutable + Mutable +├── MemoryCollectionUpdateInput # 更新输入 = Mutable +├── MemoryCollectionListInput # 列表查询输入(含分页) +└── MemoryCollectionListOutput # 列表查询输出(摘要信息) +``` + +### 配置子模型 + +| 模型 | 说明 | 关键字段 | +|------|------|----------| +| `EmbedderConfig` | 嵌入模型配置 | `config.model`, `model_service_name` | +| `LLMConfig` | LLM 配置 | `config.model`, `model_service_name` | +| `VectorStoreConfig` | 向量存储配置 | `provider`, `config` (TableStore), `mysql_config` (MySQL) | +| `VectorStoreConfigConfig` | TableStore 向量存储内部配置 | `endpoint`, 
`instance_name`, `collection_name`, `vector_dimension` | +| `VectorStoreConfigMysqlConfig` | MySQL 向量存储配置 | `host`, `port`, `db_name`, `user`, `credential_name`, `collection_name`, `vector_dimension` | +| `NetworkConfiguration` | 网络配置 | `vpc_id`, `vswitch_ids`, `security_group_id`, `network_mode` | + +### 完整的 MemoryCollection 属性 + +`MemoryCollection` 类同时继承了三组属性和 `ResourceBase`: + +```python +class MemoryCollection( + MemoryCollectionMutableProps, # 可变属性 + MemoryCollectionImmutableProps, # 不可变属性 + MemoryCollectionSystemProps, # 系统属性 + ResourceBase, # 资源基类(提供 from_inner_object 等) +): +``` + +--- + +## API 使用指南 + +### 1. 创建记忆集合 + +```python +from agentrun.memory_collection import ( + MemoryCollection, + MemoryCollectionCreateInput, + EmbedderConfig, + EmbedderConfigConfig, + LLMConfig, + LLMConfigConfig, + VectorStoreConfig, + VectorStoreConfigConfig, +) + +# 方式一:通过高层 API(推荐) +mc = MemoryCollection.create( + MemoryCollectionCreateInput( + memory_collection_name="my-collection", + type="mem0", + description="示例记忆集合", + embedder_config=EmbedderConfig( + model_service_name="my-embedder-service", + config=EmbedderConfigConfig(model="text-embedding-v3"), + ), + llm_config=LLMConfig( + model_service_name="my-llm-service", + config=LLMConfigConfig(model="qwen-plus"), + ), + vector_store_config=VectorStoreConfig( + provider="aliyun_tablestore", + config=VectorStoreConfigConfig( + endpoint="https://xxx.cn-hangzhou.ots.aliyuncs.com", + instance_name="my-instance", + collection_name="my-collection", + vector_dimension=1024, + ), + ), + ) +) + +# 方式二:通过 Client +from agentrun.memory_collection import MemoryCollectionClient + +client = MemoryCollectionClient() +mc = client.create(input=MemoryCollectionCreateInput(...)) +``` + +### 2. 获取记忆集合 + +```python +# 类方法(通过名称) +mc = MemoryCollection.get_by_name("my-collection") + +# 实例方法(刷新当前对象) +mc.refresh() # 等价于 mc.get() +``` + +### 3. 
更新记忆集合 + +```python +from agentrun.memory_collection import MemoryCollectionUpdateInput + +# 类方法 +mc = MemoryCollection.update_by_name( + "my-collection", + MemoryCollectionUpdateInput(description="更新后的描述"), +) + +# 实例方法(就地更新) +mc.update(MemoryCollectionUpdateInput(description="更新后的描述")) +# mc 对象的属性会被自动更新 +``` + +### 4. 删除记忆集合 + +```python +# 类方法 +MemoryCollection.delete_by_name("my-collection") + +# 实例方法 +mc.delete() +``` + +### 5. 列出记忆集合 + +```python +# 列出所有(自动分页) +collections = MemoryCollection.list_all() + +# 带过滤条件 +collections = MemoryCollection.list_all( + memory_collection_name="my-collection", + status="READY", + type="mem0", +) + +# 列表项转完整对象 +for item in collections: + full_mc = item.to_memory_collection() +``` + +### 6. 异步调用 + +所有方法都有对应的 `_async` 版本: + +```python +import asyncio + +async def main(): + mc = await MemoryCollection.create_async(input=...) + mc = await MemoryCollection.get_by_name_async("my-collection") + await mc.update_async(MemoryCollectionUpdateInput(description="新描述")) + await mc.delete_async() + collections = await MemoryCollection.list_all_async() + +asyncio.run(main()) +``` + +--- + +## mem0 集成 + +MemoryCollection 提供了与 `agentrun-mem0ai` 包的集成能力,可以将平台上的 MemoryCollection 配置直接转换为可操作的 mem0 Memory 客户端。 + +### 前置条件 + +```bash +pip install agentrun-mem0ai +``` + +### 使用方式 + +```python +# 同步 +memory = MemoryCollection.to_mem0_memory("my-collection") +memory.add("用户喜欢吃苹果", user_id="user123") +results = memory.search("用户喜欢什么水果", user_id="user123") + +# 异步 +memory = await MemoryCollection.to_mem0_memory_async("my-collection") +await memory.add("用户喜欢吃苹果", user_id="user123") +``` + +### 内部工作流程 + +`to_mem0_memory` 方法内部会: + +1. 通过 `get_by_name` 获取 MemoryCollection 的完整配置 +2. 调用 `_build_mem0_config` 构建 mem0 兼容的配置字典,包括: + - **vector_store 配置**:支持 `aliyun_tablestore` 和 `alibabacloud_mysql` 两种 provider + - **llm 配置**:通过 `ModelService` 解析 base_url 和 api_key + - **embedder 配置**:通过 `ModelService` 解析 base_url 和 api_key,同步 vector_dimension +3. 
使用配置字典创建 `Memory` / `AsyncMemory` 实例 + +### 向量存储 Provider 支持 + +| Provider | 配置来源 | 地址转换 | 认证方式 | +|----------|---------|---------|---------| +| `aliyun_tablestore` | `VectorStoreConfigConfig` | VPC 内网自动转公网 | AK/SK + 可选 STS Token | +| `alibabacloud_mysql` | `VectorStoreConfigMysqlConfig` | 环境变量 `AGENTRUN_MYSQL_PUBLIC_HOST` | Credential 获取密码 | + +### 相关环境变量 + +| 环境变量 | 说明 | +|----------|------| +| `AGENTRUN_MYSQL_PUBLIC_HOST` | MySQL 公网地址覆盖(当内网地址不可达时使用) | + +### 跨模块依赖 + +`to_mem0_memory` 会依赖以下其他模块: + +- `agentrun.model.ModelService` - 解析 LLM/Embedder 的 base_url +- `agentrun.credential.Credential` - 获取 API Key 或 MySQL 密码 + +--- + +## 代码生成机制 + +本模块使用 **模板 + 代码生成** 的模式来同时维护同步和异步代码。 + +### 工作流 + +``` +模板文件(手动编写异步代码) + │ + │ make codegen + ▼ +生成文件(自动生成同步+异步代码) +``` + +### 模板与生成文件的对应关系 + +| 模板文件 (手动维护) | 生成文件 (自动生成) | +|-------|---------| +| `__client_async_template.py` | `client.py` | +| `__memory_collection_async_template.py` | `memory_collection.py` | +| `codegen/configs/memory_collection_control_api.yaml` | `api/control.py` | + +### 开发流程 + +1. 修改对应的 `_async_template.py` 模板文件(只需要编写 async 方法) +2. 运行 `make codegen` 自动生成包含同步和异步方法的完整文件 +3. **切勿直接修改** `client.py`、`memory_collection.py`、`api/control.py` + +--- + +## 开发注意事项 + +### 新增/修改字段 + +1. 先检查底层 SDK (`alibabacloud_agentrun20250910`) 的输入输出参数定义 +2. 在 `model.py` 中添加/修改对应的数据模型字段 +3. 如需修改 Client 或 MemoryCollection 的行为,编辑对应的 `_async_template.py` 模板 +4. 运行 `make codegen` 重新生成代码 +5. 
运行单元测试验证 + +### 错误处理链路 + +``` +底层 SDK 异常 + → ClientException / ServerException + → ClientError / ServerError (api/control.py) + → HTTPError (client.py) + → ResourceAlreadyExistError / ResourceNotExistError (client.py) +``` + +### 类型系统 + +- 所有 Model 继承自 `agentrun.utils.model.BaseModel` +- `BaseModel` 提供 `model_dump()`(序列化)和 `from_inner_object()`(从底层 SDK 对象反序列化) +- `ResourceBase` 提供 `update_self()` 方法,用于就地更新实例属性 +- `PageableInput` 提供分页参数(`page_number`、`page_size` 等) + +### 常用命令 + +```bash +# 代码生成 +make codegen + +# 类型检查 +uv run mypy --config-file mypy.ini . + +# 运行测试 +uv run pytest tests/ +``` diff --git a/agentrun/toolset/api/mcp.py b/agentrun/toolset/api/mcp.py index ebefddf..d2c382a 100644 --- a/agentrun/toolset/api/mcp.py +++ b/agentrun/toolset/api/mcp.py @@ -26,7 +26,7 @@ async def __aenter__(self): headers=self.config.get_headers(), timeout=timeout if timeout else 60, ) - (read, write) = await self.client.__aenter__() + read, write = await self.client.__aenter__() self.client_session = ClientSession(read, write) session = await self.client_session.__aenter__() diff --git a/codegen/codegen.py b/codegen/codegen.py index c4ff00b..21ea67a 100644 --- a/codegen/codegen.py +++ b/codegen/codegen.py @@ -205,6 +205,7 @@ def _generate_sync_code_for_file(async_file): if in_async_def_indent >= 0: content = ( line.replace("AsyncClient", "Client") + .replace("AsyncOTSClient", "OTSClient") .replace("AsyncOpenAI", "OpenAI") .replace("AsyncMemory", "Memory") .replace("async_playwright", "sync_playwright") diff --git a/examples/conversation_service_adk_agent.py b/examples/conversation_service_adk_agent.py new file mode 100644 index 0000000..2e2791a --- /dev/null +++ b/examples/conversation_service_adk_agent.py @@ -0,0 +1,128 @@ +"""Google ADK Agent —— 使用 OTSSessionService 持久化会话。 + +集成步骤: + Step 1: 初始化 SessionStore(OTS 后端) + Step 2: 创建 OTSSessionService + Step 3: 创建 ADK Agent + Runner,传入 session_service + Step 4: 通过 runner.run_async() 对话(自动持久化) + +使用方式: + export 
def get_weather(city: str) -> dict[str, Any]:
    """查询指定城市的天气信息。"""
    # NOTE: the docstring above is left verbatim on purpose. This function is
    # registered as an agent tool, and the tool framework presumably exposes
    # the docstring to the model as the tool description (confirm).
    # Static demo data keyed by city name; there is no real weather lookup.
    forecasts = {
        "北京": {"weather": "晴", "temperature": "5~15°C"},
        "上海": {"weather": "多云", "temperature": "12~20°C"},
    }
    try:
        return forecasts[city]
    except KeyError:
        # Unknown city: return an error payload instead of raising.
        return {"error": "暂无该城市数据"}
+ session_service=session_service, +) + + +# ── Step 4: 对话(自动持久化到 OTS) ──────────────────────── + + +async def chat(session_id: str, text: str) -> str: + """发送消息并返回 Agent 回复。""" + content = types.Content( + role="user", + parts=[types.Part(text=text)], + ) + reply_parts: list[str] = [] + async for event in runner.run_async( + user_id=USER_ID, + session_id=session_id, + new_message=content, + ): + if event.is_final_response() and event.content and event.content.parts: + for part in event.content.parts: + if part.text: + reply_parts.append(part.text) + return "\n".join(reply_parts) + + +async def main() -> None: + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + state={"app:model_name": custom_model.model, "user:language": "zh-CN"}, + ) + print(f"会话已创建: {session.id}\n输入 /quit 退出\n") + + while True: + try: + user_input = input("你: ").strip() + except (EOFError, KeyboardInterrupt): + break + if not user_input or user_input == "/quit": + break + reply = await chat(session.id, user_input) + print(f"Agent: {reply}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 78aef5e..ad3e0d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,10 @@ mcp = [ "mcp>=1.21.2; python_version >= '3.10'", ] +tablestore = [ + "tablestore>=6.1.0", +] + [dependency-groups] dev = [ "coverage>=7.10.7", diff --git a/tests/unittests/conversation_service/__init__.py b/tests/unittests/conversation_service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/conversation_service/test_adk_adapter.py b/tests/unittests/conversation_service/test_adk_adapter.py new file mode 100644 index 0000000..c9643fe --- /dev/null +++ b/tests/unittests/conversation_service/test_adk_adapter.py @@ -0,0 +1,824 @@ +"""ADK OTSSessionService 适配器单元测试。 + +通过 Mock SessionStore 测试 OTSSessionService 的核心逻辑: +- Event 序列化 round-trip(raw_event 列) +- 三级 state 映射(app / user / session) +- CRUD 操作 
+- list_sessions(含 user_id=None 场景) +""" + +from __future__ import annotations + +import json +from typing import Any, Optional +from unittest.mock import AsyncMock, MagicMock, patch +import uuid + +from google.adk.events.event import Event # type: ignore[import-untyped] +from google.adk.events.event_actions import ( + EventActions, +) # type: ignore[import-untyped] +from google.adk.sessions.base_session_service import ( + GetSessionConfig, +) # type: ignore[import-untyped] +from google.adk.sessions.session import Session # type: ignore[import-untyped] +from google.genai import types # type: ignore[import-untyped] +import pytest + +from agentrun.conversation_service.adapters.adk_adapter import ( + _extract_display_content, + _extract_state_delta, + OTSSessionService, +) +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, +) +from agentrun.conversation_service.session_store import SessionStore + +# ------------------------------------------------------------------- +# 工具函数测试 +# ------------------------------------------------------------------- + + +class TestExtractStateDelta: + """_extract_state_delta 单元测试。""" + + def test_empty_state(self) -> None: + result = _extract_state_delta({}) + assert result == { + "app": {}, + "user": {}, + "session": {}, + } + + def test_session_only(self) -> None: + result = _extract_state_delta({"key1": "val1", "key2": 42}) + assert result["session"] == { + "key1": "val1", + "key2": 42, + } + assert result["app"] == {} + assert result["user"] == {} + + def test_app_prefix(self) -> None: + result = _extract_state_delta({"app:config": "value"}) + assert result["app"] == {"config": "value"} + assert result["session"] == {} + + def test_user_prefix(self) -> None: + result = _extract_state_delta({"user:name": "Alice"}) + assert result["user"] == {"name": "Alice"} + assert result["session"] == {} + + def test_temp_prefix_excluded(self) -> None: + result = _extract_state_delta({"temp:cache": "data", 
"real_key": "val"}) + assert result["session"] == {"real_key": "val"} + assert result["app"] == {} + assert result["user"] == {} + + def test_mixed_prefixes(self) -> None: + result = _extract_state_delta({ + "app:setting": True, + "user:pref": "dark", + "session_var": 123, + "temp:scratch": "ignored", + }) + assert result["app"] == {"setting": True} + assert result["user"] == {"pref": "dark"} + assert result["session"] == {"session_var": 123} + + +class TestExtractDisplayContent: + """_extract_display_content 单元测试。""" + + def test_text_event(self) -> None: + event = Event( + author="user", + content=types.Content( + role="user", + parts=[types.Part(text="Hello world")], + ), + ) + result = _extract_display_content(event) + assert result["author"] == "user" + assert result["text"] == "Hello world" + + def test_function_call_event(self) -> None: + event = Event( + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall( + name="get_weather", + args={"city": "Shanghai"}, + ) + ) + ], + ), + ) + result = _extract_display_content(event) + assert result["author"] == "agent" + assert "[call:get_weather]" in result["text"] + + def test_function_response_event(self) -> None: + event = Event( + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part( + function_response=types.FunctionResponse( + name="get_weather", + response={"result": "sunny"}, + ) + ) + ], + ), + ) + result = _extract_display_content(event) + assert "[response:get_weather]" in result["text"] + + def test_empty_content(self) -> None: + event = Event(author="user") + result = _extract_display_content(event) + assert result["author"] == "user" + assert "text" not in result + + +# ------------------------------------------------------------------- +# Event 序列化 round-trip 测试 +# ------------------------------------------------------------------- + + +class TestEventSerialization: + """ADK Event 的 model_dump_json / 
model_validate_json round-trip。""" + + def test_text_event_roundtrip(self) -> None: + original = Event( + invocation_id="inv-1", + author="user", + content=types.Content( + role="user", + parts=[types.Part(text="Hello, how are you?")], + ), + ) + json_str = original.model_dump_json(by_alias=False) + restored = Event.model_validate_json(json_str) + + assert restored.author == "user" + assert restored.invocation_id == "inv-1" + assert restored.content is not None + assert restored.content.parts is not None + assert len(restored.content.parts) == 1 + assert restored.content.parts[0].text == "Hello, how are you?" + + def test_function_call_roundtrip(self) -> None: + original = Event( + invocation_id="inv-2", + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall( + name="search", + args={ + "query": "weather", + "count": 5, + }, + ) + ) + ], + ), + ) + json_str = original.model_dump_json(by_alias=False) + restored = Event.model_validate_json(json_str) + + assert restored.author == "agent" + fc = restored.get_function_calls() + assert len(fc) == 1 + assert fc[0].name == "search" + assert fc[0].args == { + "query": "weather", + "count": 5, + } + + def test_function_response_roundtrip(self) -> None: + original = Event( + invocation_id="inv-3", + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part( + function_response=types.FunctionResponse( + name="search", + response={ + "results": [ + "sunny", + "warm", + ] + }, + ) + ) + ], + ), + ) + json_str = original.model_dump_json(by_alias=False) + restored = Event.model_validate_json(json_str) + + fr = restored.get_function_responses() + assert len(fr) == 1 + assert fr[0].name == "search" + + def test_event_with_state_delta_roundtrip(self) -> None: + original = Event( + invocation_id="inv-4", + author="agent", + actions=EventActions( + state_delta={ + "counter": 42, + "app:global_count": 100, + "user:preference": "dark", + } + ), + 
content=types.Content( + role="model", + parts=[types.Part(text="Updated state")], + ), + ) + json_str = original.model_dump_json(by_alias=False) + restored = Event.model_validate_json(json_str) + + assert restored.actions.state_delta == { + "counter": 42, + "app:global_count": 100, + "user:preference": "dark", + } + + def test_multipart_event_roundtrip(self) -> None: + """多 Part 事件的 round-trip。""" + original = Event( + invocation_id="inv-5", + author="model", + content=types.Content( + role="model", + parts=[ + types.Part(text="Let me search..."), + types.Part( + function_call=types.FunctionCall( + name="web_search", + args={"q": "test"}, + ) + ), + ], + ), + ) + json_str = original.model_dump_json(by_alias=False) + restored = Event.model_validate_json(json_str) + + assert restored.content is not None + assert restored.content.parts is not None + assert len(restored.content.parts) == 2 + assert restored.content.parts[0].text == "Let me search..." + assert restored.content.parts[1].function_call.name == "web_search" + + +# ------------------------------------------------------------------- +# OTSSessionService Mock 测试 +# ------------------------------------------------------------------- + + +def _make_mock_store() -> MagicMock: + """创建 Mock SessionStore。 + + 同时设置同步方法(MagicMock)和异步方法(AsyncMock)的返回值。 + """ + store = MagicMock(spec=SessionStore) + + # 同步方法默认返回值 + store.get_app_state.return_value = {} + store.get_user_state.return_value = {} + store.get_session_state.return_value = {} + + # 异步方法使用 AsyncMock + store.create_session_async = AsyncMock() + store.get_session_async = AsyncMock(return_value=None) + store.list_sessions_async = AsyncMock(return_value=[]) + store.list_all_sessions_async = AsyncMock(return_value=[]) + store.delete_session_async = AsyncMock() + store.delete_events_async = AsyncMock() + store.update_session_async = AsyncMock() + store.append_event_async = AsyncMock() + store.get_events_async = AsyncMock(return_value=[]) + 
store.get_recent_events_async = AsyncMock(return_value=[]) + store.get_app_state_async = AsyncMock(return_value={}) + store.get_user_state_async = AsyncMock(return_value={}) + store.get_session_state_async = AsyncMock(return_value={}) + store.update_app_state_async = AsyncMock() + store.update_user_state_async = AsyncMock() + store.update_session_state_async = AsyncMock() + store.init_tables_async = AsyncMock() + + return store + + +class TestOTSSessionServiceCreateSession: + """create_session 测试。""" + + @pytest.mark.asyncio + async def test_create_basic(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = await service.create_session( + app_name="test_app", + user_id="user_1", + ) + + assert session.app_name == "test_app" + assert session.user_id == "user_1" + assert session.id # 自动生成 UUID + assert session.events == [] + store.create_session_async.assert_called_once() + + @pytest.mark.asyncio + async def test_create_with_session_id(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = await service.create_session( + app_name="test_app", + user_id="user_1", + session_id="my-session-id", + ) + + assert session.id == "my-session-id" + + @pytest.mark.asyncio + async def test_create_with_state(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + await service.create_session( + app_name="test_app", + user_id="user_1", + state={ + "app:config": "val", + "user:pref": "dark", + "local": 123, + }, + ) + + store.update_app_state_async.assert_called_once_with( + "test_app", {"config": "val"} + ) + store.update_user_state_async.assert_called_once_with( + "test_app", "user_1", {"pref": "dark"} + ) + store.update_session_state_async.assert_called_once() + + def test_create_sync(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = service.create_session_sync( + 
app_name="test_app", + user_id="user_1", + ) + + assert session.app_name == "test_app" + store.create_session.assert_called_once() + + +class TestOTSSessionServiceGetSession: + """get_session 测试。""" + + @pytest.mark.asyncio + async def test_get_nonexistent(self) -> None: + store = _make_mock_store() + store.get_session_async.return_value = None + service = OTSSessionService(session_store=store) + + result = await service.get_session( + app_name="test_app", + user_id="user_1", + session_id="nonexistent", + ) + + assert result is None + + @pytest.mark.asyncio + async def test_get_with_events(self) -> None: + store = _make_mock_store() + + ots_session = ConversationSession( + agent_id="test_app", + user_id="user_1", + session_id="s1", + created_at=1000000000, + updated_at=2000000000, + ) + store.get_session_async.return_value = ots_session + + # 构造一个 ADK Event 并序列化 + adk_event = Event( + invocation_id="inv-1", + author="user", + content=types.Content( + role="user", + parts=[types.Part(text="Hello")], + ), + ) + raw_json = adk_event.model_dump_json(by_alias=False) + + ots_event = ConversationEvent( + agent_id="test_app", + user_id="user_1", + session_id="s1", + seq_id=1, + type="adk_event", + content={"author": "user", "text": "Hello"}, + raw_event=raw_json, + ) + store.get_events_async.return_value = [ots_event] + + service = OTSSessionService(session_store=store) + result = await service.get_session( + app_name="test_app", + user_id="user_1", + session_id="s1", + ) + + assert result is not None + assert len(result.events) == 1 + assert result.events[0].author == "user" + parts = result.events[0].content.parts + assert parts is not None + assert parts[0].text == "Hello" + + @pytest.mark.asyncio + async def test_get_skips_events_without_raw_event( + self, + ) -> None: + """LangChain 事件(无 raw_event)应被跳过。""" + store = _make_mock_store() + + ots_session = ConversationSession( + agent_id="test_app", + user_id="user_1", + session_id="s1", + created_at=1000000000, + 
updated_at=2000000000, + ) + store.get_session_async.return_value = ots_session + + lc_event = ConversationEvent( + agent_id="test_app", + user_id="user_1", + session_id="s1", + seq_id=1, + type="message", + content={"lc_type": "human", "content": "Hi"}, + raw_event=None, + ) + store.get_events_async.return_value = [lc_event] + + service = OTSSessionService(session_store=store) + result = await service.get_session( + app_name="test_app", + user_id="user_1", + session_id="s1", + ) + + assert result is not None + assert len(result.events) == 0 + + @pytest.mark.asyncio + async def test_get_with_num_recent_events(self) -> None: + store = _make_mock_store() + + ots_session = ConversationSession( + agent_id="test_app", + user_id="user_1", + session_id="s1", + created_at=1000000000, + updated_at=2000000000, + ) + store.get_session_async.return_value = ots_session + store.get_recent_events_async.return_value = [] + + service = OTSSessionService(session_store=store) + config = GetSessionConfig(num_recent_events=5) + await service.get_session( + app_name="test_app", + user_id="user_1", + session_id="s1", + config=config, + ) + + store.get_recent_events_async.assert_called_once_with( + "test_app", "user_1", "s1", 5 + ) + + @pytest.mark.asyncio + async def test_get_with_merged_state(self) -> None: + store = _make_mock_store() + + ots_session = ConversationSession( + agent_id="test_app", + user_id="user_1", + session_id="s1", + created_at=1000000000, + updated_at=2000000000, + ) + store.get_session_async.return_value = ots_session + store.get_events_async.return_value = [] + + store.get_app_state_async.return_value = {"setting": "A"} + store.get_user_state_async.return_value = {"pref": "dark"} + store.get_session_state_async.return_value = {"counter": 42} + + service = OTSSessionService(session_store=store) + result = await service.get_session( + app_name="test_app", + user_id="user_1", + session_id="s1", + ) + + assert result is not None + # session state: 无前缀 + assert 
result.state["counter"] == 42 + # user state: user: 前缀 + assert result.state["user:pref"] == "dark" + # app state: app: 前缀 + assert result.state["app:setting"] == "A" + + +class TestOTSSessionServiceListSessions: + """list_sessions 测试。""" + + @pytest.mark.asyncio + async def test_list_with_user_id(self) -> None: + store = _make_mock_store() + store.list_sessions_async.return_value = [ + ConversationSession( + agent_id="app", + user_id="u1", + session_id="s1", + created_at=0, + updated_at=1000000000, + ) + ] + service = OTSSessionService(session_store=store) + + response = await service.list_sessions(app_name="app", user_id="u1") + + assert len(response.sessions) == 1 + assert response.sessions[0].id == "s1" + store.list_sessions_async.assert_called_once_with("app", "u1") + + @pytest.mark.asyncio + async def test_list_all_users(self) -> None: + store = _make_mock_store() + store.list_all_sessions_async.return_value = [ + ConversationSession( + agent_id="app", + user_id="u1", + session_id="s1", + created_at=0, + updated_at=1000000000, + ), + ConversationSession( + agent_id="app", + user_id="u2", + session_id="s2", + created_at=0, + updated_at=2000000000, + ), + ] + service = OTSSessionService(session_store=store) + + response = await service.list_sessions(app_name="app", user_id=None) + + assert len(response.sessions) == 2 + store.list_all_sessions_async.assert_called_once_with("app") + + +class TestOTSSessionServiceDeleteSession: + """delete_session 测试。""" + + @pytest.mark.asyncio + async def test_delete(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + await service.delete_session( + app_name="app", + user_id="u1", + session_id="s1", + ) + + store.delete_session_async.assert_called_once_with("app", "u1", "s1") + + def test_delete_sync(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + service.delete_session_sync( + app_name="app", + user_id="u1", + session_id="s1", + ) 
+ + store.delete_session.assert_called_once_with("app", "u1", "s1") + + +class TestOTSSessionServiceAppendEvent: + """append_event 测试。""" + + @pytest.mark.asyncio + async def test_append_text_event(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = Session( + id="s1", + app_name="app", + user_id="u1", + state={}, + events=[], + ) + + event = Event( + invocation_id="inv-1", + author="user", + content=types.Content( + role="user", + parts=[types.Part(text="Hello")], + ), + ) + + result = await service.append_event(session, event) + + # 事件被追加到 session.events + assert len(session.events) == 1 + assert session.events[0] is result + + # store.append_event_async 被调用,且传递了 raw_event + store.append_event_async.assert_called_once() + call_kwargs = store.append_event_async.call_args + # raw_event 参数不为 None + assert call_kwargs.kwargs.get("raw_event") is not None or ( + len(call_kwargs.args) > 5 and call_kwargs.args[5] is not None + ) + + @pytest.mark.asyncio + async def test_append_skips_partial(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = Session( + id="s1", + app_name="app", + user_id="u1", + state={}, + events=[], + ) + + event = Event( + invocation_id="inv-1", + author="user", + partial=True, + content=types.Content( + role="user", + parts=[types.Part(text="partial")], + ), + ) + + result = await service.append_event(session, event) + + assert result.partial is True + store.append_event_async.assert_not_called() + + @pytest.mark.asyncio + async def test_append_persists_state_delta(self) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = Session( + id="s1", + app_name="app", + user_id="u1", + state={}, + events=[], + ) + + event = Event( + invocation_id="inv-1", + author="agent", + actions=EventActions( + state_delta={ + "counter": 1, + "app:global": "yes", + "user:theme": "dark", + } + ), + 
content=types.Content( + role="model", + parts=[types.Part(text="Done")], + ), + ) + + await service.append_event(session, event) + + # 三级 state 分别被更新 + store.update_app_state_async.assert_called_once_with( + "app", {"global": "yes"} + ) + store.update_user_state_async.assert_called_once_with( + "app", "u1", {"theme": "dark"} + ) + store.update_session_state_async.assert_called_once_with( + "app", "u1", "s1", {"counter": 1} + ) + + @pytest.mark.asyncio + async def test_append_updates_session_state_in_memory( + self, + ) -> None: + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = Session( + id="s1", + app_name="app", + user_id="u1", + state={"existing": "value"}, + events=[], + ) + + event = Event( + invocation_id="inv-1", + author="agent", + actions=EventActions(state_delta={"new_key": "new_value"}), + content=types.Content( + role="model", + parts=[types.Part(text="ok")], + ), + ) + + await service.append_event(session, event) + + # 内存中的 session.state 应该已更新 + assert session.state["new_key"] == "new_value" + assert session.state["existing"] == "value" + + @pytest.mark.asyncio + async def test_append_raw_event_roundtrip(self) -> None: + """验证 append_event 写入的 raw_event 可以被 get_session 还原。""" + store = _make_mock_store() + service = OTSSessionService(session_store=store) + + session = Session( + id="s1", + app_name="app", + user_id="u1", + state={}, + events=[], + ) + + original_event = Event( + invocation_id="inv-roundtrip", + author="agent", + content=types.Content( + role="model", + parts=[ + types.Part(text="Answer"), + types.Part( + function_call=types.FunctionCall( + name="tool1", + args={"x": 1}, + ) + ), + ], + ), + ) + + await service.append_event(session, original_event) + + # 获取 store.append_event_async 被调用时的 raw_event 参数 + call_args = store.append_event_async.call_args + raw_event_str: str = call_args.kwargs["raw_event"] + + # 还原 + restored = Event.model_validate_json(raw_event_str) + assert 
restored.invocation_id == "inv-roundtrip" + assert restored.author == "agent" + parts = restored.content.parts + assert parts is not None + assert parts[0].text == "Answer" + assert parts[1].function_call.name == "tool1" diff --git a/tests/unittests/conversation_service/test_langchain_adapter.py b/tests/unittests/conversation_service/test_langchain_adapter.py new file mode 100644 index 0000000..5904a82 --- /dev/null +++ b/tests/unittests/conversation_service/test_langchain_adapter.py @@ -0,0 +1,544 @@ +"""conversation_service.adapters.langchain_adapter 单元测试。 + +通过 Mock SessionStore 测试 OTSChatMessageHistory 的核心逻辑: +- messages 属性(读取事件并反序列化为 LangChain BaseMessage) +- add_messages(写入消息) +- clear(清空事件) +- auto_create_session +- _message_to_dict / _event_to_message 序列化/反序列化 +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + +from agentrun.conversation_service.adapters.langchain_adapter import ( + _event_to_message, + _message_to_dict, + OTSChatMessageHistory, +) +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, +) +from agentrun.conversation_service.session_store import SessionStore + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_store() -> MagicMock: + """创建 Mock SessionStore。""" + store = MagicMock(spec=SessionStore) + store.get_session.return_value = ConversationSession( + agent_id="a", + user_id="u", + session_id="s", + created_at=100, + updated_at=200, + ) + store.get_events.return_value = [] + store.delete_events.return_value = 0 + return store + + +# --------------------------------------------------------------------------- +# _message_to_dict 序列化 +# 
--------------------------------------------------------------------------- + + +class TestMessageToDict: + """_message_to_dict 测试。""" + + def test_human_message(self) -> None: + msg = HumanMessage(content="hello") + result = _message_to_dict(msg) + assert result["lc_type"] == "human" + assert result["content"] == "hello" + + def test_ai_message(self) -> None: + msg = AIMessage(content="response") + result = _message_to_dict(msg) + assert result["lc_type"] == "ai" + assert result["content"] == "response" + + def test_system_message(self) -> None: + msg = SystemMessage(content="you are a helper") + result = _message_to_dict(msg) + assert result["lc_type"] == "system" + + def test_tool_message(self) -> None: + msg = ToolMessage(content="result", tool_call_id="tc-1") + result = _message_to_dict(msg) + assert result["lc_type"] == "tool" + assert result["tool_call_id"] == "tc-1" + + def test_ai_message_with_tool_calls(self) -> None: + msg = AIMessage( + content="", + tool_calls=[ + { + "name": "search", + "args": {"q": "test"}, + "id": "tc-1", + "type": "tool_call", + }, + ], + ) + result = _message_to_dict(msg) + assert "tool_calls" in result + assert len(result["tool_calls"]) == 1 + + def test_with_additional_kwargs(self) -> None: + msg = HumanMessage( + content="hi", + additional_kwargs={"extra": "data"}, + ) + result = _message_to_dict(msg) + assert result["additional_kwargs"] == {"extra": "data"} + + def test_with_name_and_id(self) -> None: + msg = HumanMessage(content="hi", name="user", id="msg-1") + result = _message_to_dict(msg) + assert result["name"] == "user" + assert result["id"] == "msg-1" + + def test_minimal_fields(self) -> None: + """空 additional_kwargs 等不应出现在结果中。""" + msg = HumanMessage(content="hi") + result = _message_to_dict(msg) + assert "additional_kwargs" not in result + assert "name" not in result + + def test_ai_message_with_response_metadata(self) -> None: + msg = AIMessage( + content="ok", + response_metadata={"model": "gpt-4"}, + ) + result = 
_message_to_dict(msg) + assert result["response_metadata"] == {"model": "gpt-4"} + + +# --------------------------------------------------------------------------- +# _event_to_message 反序列化 +# --------------------------------------------------------------------------- + + +class TestEventToMessage: + """_event_to_message 测试。""" + + def test_human_message(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={"lc_type": "human", "content": "hello"}, + ) + msg = _event_to_message(event) + assert isinstance(msg, HumanMessage) + assert msg.content == "hello" + + def test_ai_message(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={"lc_type": "ai", "content": "response"}, + ) + msg = _event_to_message(event) + assert isinstance(msg, AIMessage) + assert msg.content == "response" + + def test_system_message(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={"lc_type": "system", "content": "be helpful"}, + ) + msg = _event_to_message(event) + assert isinstance(msg, SystemMessage) + + def test_tool_message(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={ + "lc_type": "tool", + "content": "result", + "tool_call_id": "tc-1", + }, + ) + msg = _event_to_message(event) + assert isinstance(msg, ToolMessage) + assert msg.tool_call_id == "tc-1" + + def test_tool_message_without_tool_call_id(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={"lc_type": "tool", "content": "result"}, + ) + msg = _event_to_message(event) + assert isinstance(msg, ToolMessage) + assert msg.tool_call_id == "" + + def test_unknown_type_fallback(self) -> None: + """未知类型回退到 HumanMessage。""" + event = 
ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={"lc_type": "unknown_type", "content": "hi"}, + ) + msg = _event_to_message(event) + assert isinstance(msg, HumanMessage) + + def test_missing_lc_type(self) -> None: + """无 lc_type 默认 human。""" + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={"content": "hi"}, + ) + msg = _event_to_message(event) + assert isinstance(msg, HumanMessage) + + def test_ai_with_tool_calls(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={ + "lc_type": "ai", + "content": "", + "tool_calls": [ + { + "name": "fn", + "args": {}, + "id": "tc-1", + "type": "tool_call", + }, + ], + }, + ) + msg = _event_to_message(event) + assert isinstance(msg, AIMessage) + assert len(msg.tool_calls) == 1 + + def test_with_additional_kwargs(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={ + "lc_type": "human", + "content": "hi", + "additional_kwargs": {"extra": True}, + }, + ) + msg = _event_to_message(event) + assert msg.additional_kwargs == {"extra": True} + + def test_with_name_and_id(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={ + "lc_type": "ai", + "content": "ok", + "name": "assistant", + "id": "msg-1", + }, + ) + msg = _event_to_message(event) + assert msg.name == "assistant" + assert msg.id == "msg-1" + + def test_with_response_metadata(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + content={ + "lc_type": "ai", + "content": "ok", + "response_metadata": {"model": "gpt-4"}, + }, + ) + msg = _event_to_message(event) + assert msg.response_metadata == {"model": "gpt-4"} + + +# 
--------------------------------------------------------------------------- +# _message_to_dict + _event_to_message round-trip +# --------------------------------------------------------------------------- + + +class TestMessageRoundTrip: + """消息序列化/反序列化 round-trip。""" + + def test_human_roundtrip(self) -> None: + original = HumanMessage(content="hello world") + data = _message_to_dict(original) + event = ConversationEvent("a", "u", "s", 1, "message", content=data) + restored = _event_to_message(event) + assert isinstance(restored, HumanMessage) + assert restored.content == "hello world" + + def test_ai_with_tool_calls_roundtrip(self) -> None: + original = AIMessage( + content="let me search", + tool_calls=[ + { + "name": "search", + "args": {"q": "test"}, + "id": "tc-1", + "type": "tool_call", + }, + ], + ) + data = _message_to_dict(original) + event = ConversationEvent("a", "u", "s", 1, "message", content=data) + restored = _event_to_message(event) + assert isinstance(restored, AIMessage) + assert len(restored.tool_calls) == 1 + + def test_tool_message_roundtrip(self) -> None: + original = ToolMessage(content="result data", tool_call_id="tc-1") + data = _message_to_dict(original) + event = ConversationEvent("a", "u", "s", 1, "message", content=data) + restored = _event_to_message(event) + assert isinstance(restored, ToolMessage) + assert restored.tool_call_id == "tc-1" + + def test_system_roundtrip(self) -> None: + original = SystemMessage(content="be helpful") + data = _message_to_dict(original) + event = ConversationEvent("a", "u", "s", 1, "message", content=data) + restored = _event_to_message(event) + assert isinstance(restored, SystemMessage) + + +# --------------------------------------------------------------------------- +# OTSChatMessageHistory +# --------------------------------------------------------------------------- + + +class TestOTSChatMessageHistoryInit: + """OTSChatMessageHistory 初始化测试。""" + + def test_auto_create_session_new(self) -> None: + 
"""Session 不存在时自动创建。""" + store = _make_mock_store() + store.get_session.return_value = None # Session 不存在 + + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + + store.create_session.assert_called_once_with( + "a", "u", "s", framework="langchain" + ) + + def test_auto_create_session_exists(self) -> None: + """Session 已存在时不创建。""" + store = _make_mock_store() + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + store.create_session.assert_not_called() + + def test_auto_create_disabled(self) -> None: + """禁用自动创建。""" + store = _make_mock_store() + store.get_session.return_value = None + + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + auto_create_session=False, + ) + + store.get_session.assert_not_called() + store.create_session.assert_not_called() + + +class TestOTSChatMessageHistoryMessages: + """messages 属性测试。""" + + def test_empty(self) -> None: + store = _make_mock_store() + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + assert history.messages == [] + + def test_with_events(self) -> None: + store = _make_mock_store() + events = [ + ConversationEvent( + "a", + "u", + "s", + 1, + "message", + content={"lc_type": "human", "content": "hi"}, + ), + ConversationEvent( + "a", + "u", + "s", + 2, + "message", + content={"lc_type": "ai", "content": "hello"}, + ), + ] + store.get_events.return_value = events + + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + messages = history.messages + + assert len(messages) == 2 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + + def test_skips_bad_events(self) -> None: + """反序列化失败的事件应被跳过。""" + store = _make_mock_store() + events = [ + ConversationEvent( + "a", + "u", + "s", + 1, + 
"message", + content={"lc_type": "human", "content": "hi"}, + ), + ConversationEvent( + "a", + "u", + "s", + 2, + "bad_type", + content={"invalid": True}, # 缺少 content 字段不会报错 + ), + ] + store.get_events.return_value = events + + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + messages = history.messages + # 两个都应成功(第二个会 fallback 到 HumanMessage) + assert len(messages) == 2 + + +class TestOTSChatMessageHistoryAddMessages: + """add_messages 测试。""" + + def test_add_single(self) -> None: + store = _make_mock_store() + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + + history.add_messages([HumanMessage(content="hello")]) + + store.append_event.assert_called_once() + call_args = store.append_event.call_args + assert call_args[0][0] == "a" # agent_id + assert call_args[0][1] == "u" # user_id + assert call_args[0][2] == "s" # session_id + + def test_add_multiple(self) -> None: + store = _make_mock_store() + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + + history.add_messages([ + HumanMessage(content="hello"), + AIMessage(content="hi"), + SystemMessage(content="be kind"), + ]) + + assert store.append_event.call_count == 3 + + +class TestOTSChatMessageHistoryClear: + """clear 测试。""" + + def test_clear(self) -> None: + store = _make_mock_store() + history = OTSChatMessageHistory( + session_store=store, + agent_id="a", + user_id="u", + session_id="s", + ) + + history.clear() + + store.delete_events.assert_called_once_with("a", "u", "s") diff --git a/tests/unittests/conversation_service/test_model.py b/tests/unittests/conversation_service/test_model.py new file mode 100644 index 0000000..027caeb --- /dev/null +++ b/tests/unittests/conversation_service/test_model.py @@ -0,0 +1,224 @@ +"""conversation_service.model 单元测试。 + +覆盖 ConversationSession、ConversationEvent、StateData 数据类 +以及 StateScope 
枚举。 +""" + +from __future__ import annotations + +import json + +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, + DEFAULT_APP_STATE_TABLE, + DEFAULT_CONVERSATION_SEARCH_INDEX, + DEFAULT_CONVERSATION_SECONDARY_INDEX, + DEFAULT_CONVERSATION_TABLE, + DEFAULT_EVENT_TABLE, + DEFAULT_STATE_TABLE, + DEFAULT_USER_STATE_TABLE, + StateData, + StateScope, +) + +# --------------------------------------------------------------------------- +# 表名常量 +# --------------------------------------------------------------------------- + + +class TestTableConstants: + """表名常量校验。""" + + def test_default_table_names(self) -> None: + assert DEFAULT_CONVERSATION_TABLE == "conversation" + assert DEFAULT_EVENT_TABLE == "event" + assert DEFAULT_STATE_TABLE == "state" + assert DEFAULT_APP_STATE_TABLE == "app_state" + assert DEFAULT_USER_STATE_TABLE == "user_state" + assert ( + DEFAULT_CONVERSATION_SECONDARY_INDEX + == "conversation_secondary_index" + ) + assert DEFAULT_CONVERSATION_SEARCH_INDEX == "conversation_search_index" + + +# --------------------------------------------------------------------------- +# StateScope 枚举 +# --------------------------------------------------------------------------- + + +class TestStateScope: + """StateScope 枚举测试。""" + + def test_values(self) -> None: + assert StateScope.APP.value == "app" + assert StateScope.USER.value == "user" + assert StateScope.SESSION.value == "session" + + def test_is_str_enum(self) -> None: + assert isinstance(StateScope.APP, str) + assert StateScope.APP == "app" + + +# --------------------------------------------------------------------------- +# ConversationSession +# --------------------------------------------------------------------------- + + +class TestConversationSession: + """ConversationSession 数据类测试。""" + + def test_required_fields(self) -> None: + session = ConversationSession( + agent_id="agent1", + user_id="user1", + session_id="sess1", + created_at=1000, + updated_at=2000, + ) 
+ assert session.agent_id == "agent1" + assert session.user_id == "user1" + assert session.session_id == "sess1" + assert session.created_at == 1000 + assert session.updated_at == 2000 + + def test_default_values(self) -> None: + session = ConversationSession( + agent_id="a", + user_id="u", + session_id="s", + created_at=0, + updated_at=0, + ) + assert session.is_pinned is False + assert session.summary is None + assert session.labels is None + assert session.framework is None + assert session.extensions is None + assert session.version == 0 + + def test_all_fields(self) -> None: + session = ConversationSession( + agent_id="a", + user_id="u", + session_id="s", + created_at=100, + updated_at=200, + is_pinned=True, + summary="hello", + labels='["tag1"]', + framework="adk", + extensions={"key": "val"}, + version=3, + ) + assert session.is_pinned is True + assert session.summary == "hello" + assert session.labels == '["tag1"]' + assert session.framework == "adk" + assert session.extensions == {"key": "val"} + assert session.version == 3 + + +# --------------------------------------------------------------------------- +# ConversationEvent +# --------------------------------------------------------------------------- + + +class TestConversationEvent: + """ConversationEvent 数据类测试。""" + + def test_required_fields(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="message", + ) + assert event.agent_id == "a" + assert event.seq_id == 1 + assert event.type == "message" + + def test_default_values(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=None, + type="test", + ) + assert event.content == {} + assert event.created_at == 0 + assert event.updated_at == 0 + assert event.version == 0 + assert event.raw_event is None + + def test_content_as_json(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + 
type="msg", + content={"key": "值", "num": 42}, + ) + result = event.content_as_json() + parsed = json.loads(result) + assert parsed == {"key": "值", "num": 42} + # ensure_ascii=False 应保留中文 + assert "值" in result + + def test_content_from_json(self) -> None: + raw = '{"key": "value", "nested": {"a": 1}}' + result = ConversationEvent.content_from_json(raw) + assert result == {"key": "value", "nested": {"a": 1}} + + def test_content_as_json_empty(self) -> None: + event = ConversationEvent( + agent_id="a", + user_id="u", + session_id="s", + seq_id=1, + type="msg", + ) + assert event.content_as_json() == "{}" + + def test_content_from_json_empty(self) -> None: + result = ConversationEvent.content_from_json("{}") + assert result == {} + + +# --------------------------------------------------------------------------- +# StateData +# --------------------------------------------------------------------------- + + +class TestStateData: + """StateData 数据类测试。""" + + def test_default_values(self) -> None: + sd = StateData() + assert sd.state == {} + assert sd.created_at == 0 + assert sd.updated_at == 0 + assert sd.version == 0 + + def test_with_values(self) -> None: + sd = StateData( + state={"counter": 42}, + created_at=100, + updated_at=200, + version=3, + ) + assert sd.state == {"counter": 42} + assert sd.created_at == 100 + assert sd.updated_at == 200 + assert sd.version == 3 + + def test_state_default_factory_isolation(self) -> None: + """确保不同实例的 state 字典是独立的。""" + sd1 = StateData() + sd2 = StateData() + sd1.state["key"] = "val" + assert "key" not in sd2.state diff --git a/tests/unittests/conversation_service/test_ots_backend.py b/tests/unittests/conversation_service/test_ots_backend.py new file mode 100644 index 0000000..ac21dba --- /dev/null +++ b/tests/unittests/conversation_service/test_ots_backend.py @@ -0,0 +1,1953 @@ +"""conversation_service.ots_backend 单元测试。 + +通过 Mock OTSClient 测试 OTSBackend 的同步和异步方法: +- 建表(含表已存在跳过) +- Session CRUD +- Event CRUD(含 batch 删除) +- 
State CRUD(含分片/拼接逻辑) +- 内部辅助方法 +""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import AsyncMock, call, MagicMock, patch + +import pytest +from tablestore import OTSServiceError, Row # type: ignore[import-untyped] + +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, + StateData, + StateScope, +) +from agentrun.conversation_service.ots_backend import ( + _BATCH_WRITE_LIMIT, + OTSBackend, +) +from agentrun.conversation_service.utils import MAX_COLUMN_SIZE + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_client() -> MagicMock: + """创建 mock OTSClient。""" + return MagicMock() + + +def _make_backend( + client: MagicMock | None = None, + table_prefix: str = "", +) -> OTSBackend: + """创建带 mock client 的 OTSBackend。""" + if client is None: + client = _make_mock_client() + return OTSBackend(client, table_prefix=table_prefix) + + +def _make_session_row( + agent_id: str = "agent1", + user_id: str = "user1", + session_id: str = "sess1", + created_at: int = 1000, + updated_at: int = 2000, + is_pinned: bool = False, + summary: str | None = None, + labels: str | None = None, + framework: str | None = None, + extensions: dict | None = None, + version: int = 0, +) -> Row: + """构造 OTS 返回的 Session Row。""" + pk = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ] + attrs = [ + ("created_at", created_at, 0), + ("updated_at", updated_at, 0), + ("is_pinned", is_pinned, 0), + ("version", version, 0), + ] + if summary is not None: + attrs.append(("summary", summary, 0)) + if labels is not None: + attrs.append(("labels", labels, 0)) + if framework is not None: + attrs.append(("framework", framework, 0)) + if extensions is not None: + attrs.append(("extensions", json.dumps(extensions), 0)) + return Row(pk, attrs) + + +def 
_make_event_row( + agent_id: str = "agent1", + user_id: str = "user1", + session_id: str = "sess1", + seq_id: int = 1, + event_type: str = "message", + content: dict | None = None, + created_at: int = 1000, + updated_at: int = 2000, + version: int = 0, + raw_event: str | None = None, +) -> Row: + """构造 OTS 返回的 Event Row。""" + pk = [ + ("agent_id", agent_id), + ("user_id", user_id), + ("session_id", session_id), + ("seq_id", seq_id), + ] + content_json = json.dumps(content or {}) + attrs = [ + ("type", event_type, 0), + ("content", content_json, 0), + ("created_at", created_at, 0), + ("updated_at", updated_at, 0), + ("version", version, 0), + ] + if raw_event is not None: + attrs.append(("raw_event", raw_event, 0)) + return Row(pk, attrs) + + +def _make_state_row( + pk: list[tuple[str, Any]], + state: dict | None = None, + chunk_count: int = 0, + chunks: list[str] | None = None, + created_at: int = 1000, + updated_at: int = 2000, + version: int = 1, +) -> Row: + """构造 OTS 返回的 State Row。""" + attrs = [ + ("chunk_count", chunk_count, 0), + ("created_at", created_at, 0), + ("updated_at", updated_at, 0), + ("version", version, 0), + ] + if chunk_count == 0 and state is not None: + attrs.append(("state", json.dumps(state), 0)) + if chunks is not None: + for idx, chunk in enumerate(chunks): + attrs.append((f"state_{idx}", chunk, 0)) + return Row(pk, attrs) + + +def _make_null_row() -> Row: + """构造空 Row(模拟行不存在)。""" + row = MagicMock(spec=Row) + row.primary_key = None + return row + + +# --------------------------------------------------------------------------- +# 建表测试 +# --------------------------------------------------------------------------- + + +class TestInitTables: + """建表方法测试。""" + + def test_init_tables_success(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.init_tables() + + # 应创建 5 张表 + assert client.create_table.call_count == 5 + + def test_init_tables_already_exist(self) -> None: + client = _make_mock_client() + err = 
OTSServiceError( + 409, "OTSObjectAlreadyExist", "table already exist" + ) + client.create_table.side_effect = err + + backend = _make_backend(client) + # 不应抛异常 + backend.init_tables() + + def test_init_tables_other_error(self) -> None: + client = _make_mock_client() + err = OTSServiceError(500, "InternalError", "internal error") + client.create_table.side_effect = err + + backend = _make_backend(client) + with pytest.raises(OTSServiceError): + backend.init_tables() + + def test_create_event_table_other_error(self) -> None: + """Event 表创建非已存在错误应抛出。""" + client = _make_mock_client() + # conversation table 正常,event table 抛异常 + err = OTSServiceError(500, "InternalError", "internal error") + client.create_table.side_effect = [None, err] + + backend = _make_backend(client) + with pytest.raises(OTSServiceError): + backend.init_core_tables() + + def test_create_state_table_other_error(self) -> None: + """State 表创建非已存在错误应抛出。""" + client = _make_mock_client() + err = OTSServiceError(500, "InternalError", "internal error") + client.create_table.side_effect = err + + backend = _make_backend(client) + with pytest.raises(OTSServiceError): + backend.init_state_tables() + + def test_init_core_tables(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.init_core_tables() + # Conversation + Event = 2 次 + assert client.create_table.call_count == 2 + + def test_init_state_tables(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.init_state_tables() + # state + app_state + user_state = 3 次 + assert client.create_table.call_count == 3 + + def test_init_search_index_success(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.init_search_index() + client.create_search_index.assert_called_once() + + def test_init_search_index_already_exist(self) -> None: + client = _make_mock_client() + err = OTSServiceError( + 409, "OTSObjectAlreadyExist", "index already exist" + ) + 
client.create_search_index.side_effect = err + + backend = _make_backend(client) + backend.init_search_index() # 不抛异常 + + def test_init_search_index_other_error(self) -> None: + client = _make_mock_client() + err = OTSServiceError(500, "InternalError", "internal error") + client.create_search_index.side_effect = err + + backend = _make_backend(client) + with pytest.raises(OTSServiceError): + backend.init_search_index() + + def test_table_prefix(self) -> None: + client = _make_mock_client() + backend = _make_backend(client, table_prefix="myprefix_") + assert backend._conversation_table == "myprefix_conversation" + assert backend._event_table == "myprefix_event" + assert backend._state_table == "myprefix_state" + assert backend._app_state_table == "myprefix_app_state" + assert backend._user_state_table == "myprefix_user_state" + + +# --------------------------------------------------------------------------- +# Session CRUD +# --------------------------------------------------------------------------- + + +class TestPutSession: + """put_session 测试。""" + + def test_basic(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + + session = ConversationSession( + agent_id="a", + user_id="u", + session_id="s", + created_at=100, + updated_at=200, + ) + backend.put_session(session) + client.put_row.assert_called_once() + + def test_with_optional_fields(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + + session = ConversationSession( + agent_id="a", + user_id="u", + session_id="s", + created_at=100, + updated_at=200, + summary="hello", + labels='["tag"]', + framework="adk", + extensions={"key": "val"}, + ) + backend.put_session(session) + client.put_row.assert_called_once() + + +class TestGetSession: + """get_session 测试。""" + + def test_found(self) -> None: + client = _make_mock_client() + row = _make_session_row( + summary="test", + framework="adk", + extensions={"k": "v"}, + ) + client.get_row.return_value = (None, 
row, None) + + backend = _make_backend(client) + result = backend.get_session("agent1", "user1", "sess1") + + assert result is not None + assert result.agent_id == "agent1" + assert result.summary == "test" + assert result.framework == "adk" + assert result.extensions == {"k": "v"} + + def test_not_found_none(self) -> None: + client = _make_mock_client() + client.get_row.return_value = (None, None, None) + + backend = _make_backend(client) + result = backend.get_session("a", "u", "s") + assert result is None + + def test_not_found_null_pk(self) -> None: + client = _make_mock_client() + null_row = _make_null_row() + client.get_row.return_value = (None, null_row, None) + + backend = _make_backend(client) + result = backend.get_session("a", "u", "s") + assert result is None + + +class TestDeleteSessionRow: + """delete_session_row 测试。""" + + def test_delete(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.delete_session_row("a", "u", "s") + client.delete_row.assert_called_once() + + +class TestUpdateSession: + """update_session 乐观锁更新测试。""" + + def test_update(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.update_session( + "a", + "u", + "s", + {"updated_at": 999, "version": 2}, + version=1, + ) + client.update_row.assert_called_once() + + +class TestListSessions: + """list_sessions 通过二级索引测试。""" + + def test_list_desc(self) -> None: + client = _make_mock_client() + row = _make_session_row() + # 二级索引 PK 包含 updated_at + idx_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", 2000), + ("session_id", "s1"), + ], + [ + ("summary", "test", 0), + ], + ) + client.get_range.return_value = (None, None, [idx_row], None) + + backend = _make_backend(client) + result = backend.list_sessions("a", "u", order_desc=True) + + assert len(result) == 1 + assert result[0].session_id == "s1" + + def test_list_asc(self) -> None: + client = _make_mock_client() + client.get_range.return_value 
= (None, None, [], None) + + backend = _make_backend(client) + result = backend.list_sessions("a", "u", order_desc=False) + assert result == [] + + def test_list_with_limit(self) -> None: + client = _make_mock_client() + idx_rows = [ + Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", i), + ("session_id", f"s{i}"), + ], + [], + ) + for i in range(5) + ] + client.get_range.return_value = (None, None, idx_rows, None) + + backend = _make_backend(client) + result = backend.list_sessions("a", "u", limit=3) + assert len(result) == 3 + + def test_list_with_pagination(self) -> None: + """模拟分页:第一次返回 next_token,第二次返回完。""" + client = _make_mock_client() + row1 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", 200), + ("session_id", "s1"), + ], + [], + ) + row2 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", 100), + ("session_id", "s2"), + ], + [], + ) + client.get_range.side_effect = [ + (None, "token", [row1], None), + (None, None, [row2], None), + ] + + backend = _make_backend(client) + result = backend.list_sessions("a", "u") + assert len(result) == 2 + + +class TestListAllSessions: + """list_all_sessions 主表扫描测试。""" + + def test_list_all(self) -> None: + client = _make_mock_client() + row = _make_session_row() + client.get_range.return_value = (None, None, [row], None) + + backend = _make_backend(client) + result = backend.list_all_sessions("agent1") + assert len(result) == 1 + + def test_list_all_with_limit(self) -> None: + client = _make_mock_client() + rows = [_make_session_row(session_id=f"s{i}") for i in range(5)] + client.get_range.return_value = (None, None, rows, None) + + backend = _make_backend(client) + result = backend.list_all_sessions("agent1", limit=2) + assert len(result) == 2 + + def test_list_all_with_pagination(self) -> None: + client = _make_mock_client() + r1 = _make_session_row(session_id="s1") + r2 = _make_session_row(session_id="s2") + client.get_range.side_effect = [ + (None, "token", [r1], 
None), + (None, None, [r2], None), + ] + + backend = _make_backend(client) + result = backend.list_all_sessions("agent1") + assert len(result) == 2 + + +class TestSearchSessions: + """search_sessions 多元索引搜索测试。""" + + def test_basic_search(self) -> None: + client = _make_mock_client() + # search 返回格式 + response = MagicMock() + response.rows = [( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s1")], + [ + ("created_at", 100, 0), + ("updated_at", 200, 0), + ("is_pinned", False, 0), + ("version", 0, 0), + ], + )] + response.total_count = 1 + client.search.return_value = response + + backend = _make_backend(client) + sessions, total = backend.search_sessions("a") + + assert len(sessions) == 1 + assert total == 1 + assert sessions[0].agent_id == "a" + + def test_search_with_all_filters(self) -> None: + client = _make_mock_client() + response = MagicMock() + response.rows = [] + response.total_count = 0 + client.search.return_value = response + + backend = _make_backend(client) + sessions, total = backend.search_sessions( + "a", + user_id="u", + summary_keyword="hello", + labels="tag", + framework="adk", + updated_after=100, + updated_before=200, + is_pinned=True, + ) + + assert sessions == [] + assert total == 0 + + def test_search_is_pinned_false(self) -> None: + client = _make_mock_client() + response = MagicMock() + response.rows = [] + response.total_count = 0 + client.search.return_value = response + + backend = _make_backend(client) + backend.search_sessions("a", is_pinned=False) + client.search.assert_called_once() + + def test_search_with_row_objects(self) -> None: + """测试 search 返回 Row 对象而非 tuple 的情况。""" + client = _make_mock_client() + response = MagicMock() + row = _make_session_row() + response.rows = [row] + response.total_count = 1 + client.search.return_value = response + + backend = _make_backend(client) + sessions, total = backend.search_sessions("agent1") + assert len(sessions) == 1 + + def test_search_total_count_none(self) -> None: + client = 
_make_mock_client() + response = MagicMock() + response.rows = [] + response.total_count = None + client.search.return_value = response + + backend = _make_backend(client) + _, total = backend.search_sessions("a") + assert total == 0 + + +# --------------------------------------------------------------------------- +# Event CRUD +# --------------------------------------------------------------------------- + + +class TestPutEvent: + """put_event 测试。""" + + def test_basic(self) -> None: + client = _make_mock_client() + return_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 42), + ], + [], + ) + client.put_row.return_value = (None, return_row) + + backend = _make_backend(client) + seq_id = backend.put_event("a", "u", "s", "msg", {"key": "val"}) + + assert seq_id == 42 + client.put_row.assert_called_once() + + def test_with_raw_event(self) -> None: + client = _make_mock_client() + return_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 1), + ], + [], + ) + client.put_row.return_value = (None, return_row) + + backend = _make_backend(client) + seq_id = backend.put_event( + "a", + "u", + "s", + "msg", + {}, + raw_event='{"raw": "data"}', + ) + assert seq_id == 1 + + def test_with_timestamps(self) -> None: + client = _make_mock_client() + return_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 5), + ], + [], + ) + client.put_row.return_value = (None, return_row) + + backend = _make_backend(client) + seq_id = backend.put_event( + "a", + "u", + "s", + "msg", + {}, + created_at=100, + updated_at=200, + ) + assert seq_id == 5 + + def test_return_row_none(self) -> None: + client = _make_mock_client() + client.put_row.return_value = (None, None) + + backend = _make_backend(client) + seq_id = backend.put_event("a", "u", "s", "msg", {}) + assert seq_id == 0 + + def test_return_row_no_pk(self) -> None: + client = _make_mock_client() + return_row = 
MagicMock(spec=Row) + return_row.primary_key = None + client.put_row.return_value = (None, return_row) + + backend = _make_backend(client) + seq_id = backend.put_event("a", "u", "s", "msg", {}) + assert seq_id == 0 + + +class TestGetEvents: + """get_events 测试。""" + + def test_forward(self) -> None: + client = _make_mock_client() + row = _make_event_row(seq_id=1, content={"msg": "hi"}) + client.get_range.return_value = (None, None, [row], None) + + backend = _make_backend(client) + events = backend.get_events("a", "u", "s", direction="FORWARD") + + assert len(events) == 1 + assert events[0].seq_id == 1 + assert events[0].content == {"msg": "hi"} + + def test_backward(self) -> None: + client = _make_mock_client() + client.get_range.return_value = (None, None, [], None) + + backend = _make_backend(client) + events = backend.get_events("a", "u", "s", direction="BACKWARD") + assert events == [] + + def test_with_limit(self) -> None: + client = _make_mock_client() + rows = [_make_event_row(seq_id=i) for i in range(5)] + client.get_range.return_value = (None, None, rows, None) + + backend = _make_backend(client) + events = backend.get_events("a", "u", "s", limit=2) + assert len(events) == 2 + + def test_with_raw_event(self) -> None: + client = _make_mock_client() + row = _make_event_row(raw_event='{"raw": "data"}') + client.get_range.return_value = (None, None, [row], None) + + backend = _make_backend(client) + events = backend.get_events("a", "u", "s") + assert events[0].raw_event == '{"raw": "data"}' + + def test_pagination(self) -> None: + client = _make_mock_client() + r1 = _make_event_row(seq_id=1) + r2 = _make_event_row(seq_id=2) + client.get_range.side_effect = [ + (None, "token", [r1], None), + (None, None, [r2], None), + ] + + backend = _make_backend(client) + events = backend.get_events("a", "u", "s") + assert len(events) == 2 + + def test_content_non_string(self) -> None: + """content 列为非 string 的情况。""" + client = _make_mock_client() + pk = [ + ("agent_id", 
"a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 1), + ] + attrs = [ + ("type", "msg", 0), + ("content", 12345, 0), # 非 string + ("created_at", 100, 0), + ("updated_at", 200, 0), + ("version", 0, 0), + ] + row = Row(pk, attrs) + client.get_range.return_value = (None, None, [row], None) + + backend = _make_backend(client) + events = backend.get_events("a", "u", "s") + assert events[0].content == {} + + +class TestDeleteEventsBySession: + """delete_events_by_session 测试。""" + + def test_no_events(self) -> None: + client = _make_mock_client() + client.get_range.return_value = (None, None, [], None) + + backend = _make_backend(client) + deleted = backend.delete_events_by_session("a", "u", "s") + assert deleted == 0 + client.batch_write_row.assert_not_called() + + def test_batch_delete(self) -> None: + client = _make_mock_client() + # 3 个 event + rows = [] + for i in range(3): + row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", i), + ], + [], + ) + rows.append(row) + client.get_range.return_value = (None, None, rows, None) + + backend = _make_backend(client) + deleted = backend.delete_events_by_session("a", "u", "s") + assert deleted == 3 + client.batch_write_row.assert_called_once() + + def test_batch_delete_pagination(self) -> None: + """模拟 event 扫描分页。""" + client = _make_mock_client() + r1 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 1), + ], + [], + ) + r2 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 2), + ], + [], + ) + client.get_range.side_effect = [ + (None, "token", [r1], None), + (None, None, [r2], None), + ] + + backend = _make_backend(client) + deleted = backend.delete_events_by_session("a", "u", "s") + assert deleted == 2 + + +# --------------------------------------------------------------------------- +# State CRUD +# --------------------------------------------------------------------------- + + +class TestPutState: 
+ """put_state 测试。""" + + def test_first_write_no_chunk(self) -> None: + """首次写入,不分片。""" + client = _make_mock_client() + backend = _make_backend(client) + + backend.put_state( + StateScope.SESSION, + "a", + "u", + "s", + state={"key": "val"}, + version=0, + ) + client.update_row.assert_called_once() + + def test_first_write_with_chunk(self) -> None: + """首次写入,需要分片。""" + client = _make_mock_client() + backend = _make_backend(client) + + big_state = {"data": "x" * (MAX_COLUMN_SIZE + 100)} + backend.put_state( + StateScope.SESSION, + "a", + "u", + "s", + state=big_state, + version=0, + ) + client.update_row.assert_called_once() + + def test_update_no_chunk_to_no_chunk(self) -> None: + """更新:旧无分片 → 新无分片。""" + client = _make_mock_client() + # _get_chunk_count 需要 get_row + chunk_row = Row( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")], + [("chunk_count", 0, 0)], + ) + client.get_row.return_value = (None, chunk_row, None) + + backend = _make_backend(client) + backend.put_state( + StateScope.SESSION, + "a", + "u", + "s", + state={"key": "new"}, + version=1, + ) + client.update_row.assert_called_once() + + def test_update_chunk_to_no_chunk(self) -> None: + """更新:旧有分片 → 新无分片,应删除旧 state_N 列。""" + client = _make_mock_client() + chunk_row = Row( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")], + [("chunk_count", 2, 0)], + ) + client.get_row.return_value = (None, chunk_row, None) + + backend = _make_backend(client) + backend.put_state( + StateScope.SESSION, + "a", + "u", + "s", + state={"key": "small"}, + version=1, + ) + # 检查 update_row 被调用,且包含 DELETE_ALL + call_args = client.update_row.call_args + row_arg = call_args[0][1] # Row 参数 + # row.attribute_columns 是 update_of_attribute_columns dict + assert "DELETE_ALL" in row_arg.attribute_columns + + def test_update_no_chunk_to_chunk(self) -> None: + """更新:旧无分片 → 新有分片,应删除 state 列。""" + client = _make_mock_client() + chunk_row = Row( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")], + 
[("chunk_count", 0, 0)], + ) + client.get_row.return_value = (None, chunk_row, None) + + backend = _make_backend(client) + big_state = {"data": "x" * (MAX_COLUMN_SIZE + 100)} + backend.put_state( + StateScope.SESSION, + "a", + "u", + "s", + state=big_state, + version=1, + ) + call_args = client.update_row.call_args + row_arg = call_args[0][1] + assert "DELETE_ALL" in row_arg.attribute_columns + assert "state" in row_arg.attribute_columns["DELETE_ALL"] + + def test_update_more_chunks_to_fewer(self) -> None: + """更新:旧 4 个分片 → 新 2 个分片,应删除多余分片。""" + client = _make_mock_client() + chunk_row = Row( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")], + [("chunk_count", 4, 0)], + ) + client.get_row.return_value = (None, chunk_row, None) + + backend = _make_backend(client) + # 构造刚好 2 个分片的数据 + data_size = MAX_COLUMN_SIZE + 10 # 刚好超过 1 个分片 + big_state = {"d": "a" * data_size} + backend.put_state( + StateScope.SESSION, + "a", + "u", + "s", + state=big_state, + version=1, + ) + call_args = client.update_row.call_args + row_arg = call_args[0][1] + if "DELETE_ALL" in row_arg.attribute_columns: + # 应删除 state_2 和 state_3 + deleted = row_arg.attribute_columns["DELETE_ALL"] + assert "state_2" in deleted + assert "state_3" in deleted + + def test_scope_app(self) -> None: + """APP scope 使用 app_state 表。""" + client = _make_mock_client() + backend = _make_backend(client) + backend.put_state(StateScope.APP, "a", "", "", state={}, version=0) + call_args = client.update_row.call_args + assert call_args[0][0] == "app_state" + + def test_scope_user(self) -> None: + """USER scope 使用 user_state 表。""" + client = _make_mock_client() + backend = _make_backend(client) + backend.put_state(StateScope.USER, "a", "u", "", state={}, version=0) + call_args = client.update_row.call_args + assert call_args[0][0] == "user_state" + + +class TestGetState: + """get_state 测试。""" + + def test_not_found(self) -> None: + client = _make_mock_client() + client.get_row.return_value = (None, None, None) + + 
backend = _make_backend(client) + result = backend.get_state(StateScope.SESSION, "a", "u", "s") + assert result is None + + def test_not_found_null_pk(self) -> None: + client = _make_mock_client() + null_row = _make_null_row() + client.get_row.return_value = (None, null_row, None) + + backend = _make_backend(client) + result = backend.get_state(StateScope.SESSION, "a", "u", "s") + assert result is None + + def test_no_chunk(self) -> None: + """无分片正常读取。""" + client = _make_mock_client() + pk = [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + row = _make_state_row(pk, state={"key": "val"}, chunk_count=0) + client.get_row.return_value = (None, row, None) + + backend = _make_backend(client) + result = backend.get_state(StateScope.SESSION, "a", "u", "s") + + assert result is not None + assert result.state == {"key": "val"} + assert result.version == 1 + + def test_with_chunks(self) -> None: + """有分片,拼接读取。""" + client = _make_mock_client() + pk = [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + state_json = json.dumps({"data": "hello"}) + chunk1 = state_json[:5] + chunk2 = state_json[5:] + row = _make_state_row(pk, chunk_count=2, chunks=[chunk1, chunk2]) + client.get_row.return_value = (None, row, None) + + backend = _make_backend(client) + result = backend.get_state(StateScope.SESSION, "a", "u", "s") + + assert result is not None + assert result.state == {"data": "hello"} + + def test_missing_chunk_raises(self) -> None: + """分片缺失应抛异常。""" + client = _make_mock_client() + pk = [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + # chunk_count=2 但只有 state_0 + row = _make_state_row(pk, chunk_count=2, chunks=["partial"]) + client.get_row.return_value = (None, row, None) + + backend = _make_backend(client) + with pytest.raises(ValueError, match="Missing state chunk"): + backend.get_state(StateScope.SESSION, "a", "u", "s") + + def test_no_state_column(self) -> None: + """chunk_count=0 但无 state 列,返回 None。""" + client = _make_mock_client() + pk = 
[("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + row = Row(pk, [("chunk_count", 0, 0), ("version", 1, 0)]) + client.get_row.return_value = (None, row, None) + + backend = _make_backend(client) + result = backend.get_state(StateScope.SESSION, "a", "u", "s") + assert result is None + + def test_scope_app(self) -> None: + client = _make_mock_client() + pk = [("agent_id", "a")] + row = _make_state_row(pk, state={"app": True}) + client.get_row.return_value = (None, row, None) + + backend = _make_backend(client) + result = backend.get_state(StateScope.APP, "a", "", "") + assert result is not None + assert result.state == {"app": True} + + +class TestDeleteStateRow: + """delete_state_row 测试。""" + + def test_delete_session_state(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.delete_state_row(StateScope.SESSION, "a", "u", "s") + client.delete_row.assert_called_once() + + def test_delete_app_state(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.delete_state_row(StateScope.APP, "a", "", "") + call_args = client.delete_row.call_args + assert call_args[0][0] == "app_state" + + def test_delete_user_state(self) -> None: + client = _make_mock_client() + backend = _make_backend(client) + backend.delete_state_row(StateScope.USER, "a", "u", "") + call_args = client.delete_row.call_args + assert call_args[0][0] == "user_state" + + +# --------------------------------------------------------------------------- +# 内部辅助方法 +# --------------------------------------------------------------------------- + + +class TestHelperMethods: + """内部辅助方法测试。""" + + def test_attrs_to_dict(self) -> None: + attrs = [("name", "val1", 0), ("count", 42, 0)] + result = OTSBackend._attrs_to_dict(attrs) + assert result == {"name": "val1", "count": 42} + + def test_attrs_to_dict_none(self) -> None: + result = OTSBackend._attrs_to_dict(None) # type: ignore[arg-type] + assert result == {} + + def test_pk_to_dict(self) 
-> None: + pk = [("agent_id", "a"), ("user_id", "u")] + result = OTSBackend._pk_to_dict(pk) + assert result == {"agent_id": "a", "user_id": "u"} + + def test_pk_to_dict_none(self) -> None: + result = OTSBackend._pk_to_dict(None) # type: ignore[arg-type] + assert result == {} + + def test_resolve_state_table_app(self) -> None: + backend = _make_backend() + table, pk = backend._resolve_state_table_and_pk( + StateScope.APP, "a", "u", "s" + ) + assert table == "app_state" + assert pk == [("agent_id", "a")] + + def test_resolve_state_table_user(self) -> None: + backend = _make_backend() + table, pk = backend._resolve_state_table_and_pk( + StateScope.USER, "a", "u", "s" + ) + assert table == "user_state" + assert pk == [("agent_id", "a"), ("user_id", "u")] + + def test_resolve_state_table_session(self) -> None: + backend = _make_backend() + table, pk = backend._resolve_state_table_and_pk( + StateScope.SESSION, "a", "u", "s" + ) + assert table == "state" + assert pk == [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + + def test_row_to_session_with_extensions(self) -> None: + backend = _make_backend() + row = _make_session_row(extensions={"k": "v"}) + session = backend._row_to_session(row) + assert session.extensions == {"k": "v"} + + def test_row_to_session_without_extensions(self) -> None: + backend = _make_backend() + row = _make_session_row() + session = backend._row_to_session(row) + assert session.extensions is None + + def test_row_to_event(self) -> None: + backend = _make_backend() + row = _make_event_row( + content={"msg": "hello"}, + raw_event='{"raw": true}', + ) + event = backend._row_to_event(row) + assert event.content == {"msg": "hello"} + assert event.raw_event == '{"raw": true}' + + def test_row_to_session_from_index(self) -> None: + backend = _make_backend() + idx_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", 2000), + ("session_id", "s"), + ], + [("summary", "test", 0), ("extensions", '{"k": "v"}', 0)], + ) + session 
= backend._row_to_session_from_index(idx_row) + assert session.session_id == "s" + assert session.updated_at == 2000 + assert session.created_at == 0 # 二级索引不含 created_at + assert session.extensions == {"k": "v"} + + def test_get_chunk_count(self) -> None: + client = _make_mock_client() + pk = [("agent_id", "a")] + row = Row(pk, [("chunk_count", 3, 0)]) + client.get_row.return_value = (None, row, None) + + backend = _make_backend(client) + count = backend._get_chunk_count("app_state", pk) + assert count == 3 + + def test_get_chunk_count_no_row(self) -> None: + client = _make_mock_client() + client.get_row.return_value = (None, None, None) + + backend = _make_backend(client) + count = backend._get_chunk_count("app_state", [("agent_id", "a")]) + assert count == 0 + + def test_get_chunk_count_null_pk(self) -> None: + client = _make_mock_client() + null_row = _make_null_row() + client.get_row.return_value = (None, null_row, None) + + backend = _make_backend(client) + count = backend._get_chunk_count("app_state", [("agent_id", "a")]) + assert count == 0 + + +# =========================================================================== +# 异步方法测试 +# =========================================================================== + + +def _make_async_backend( + async_client: MagicMock | None = None, + table_prefix: str = "", +) -> OTSBackend: + """创建带 async mock client 的 OTSBackend。""" + if async_client is None: + async_client = MagicMock() + # 让所有方法返回 AsyncMock + async_client.create_table = AsyncMock() + async_client.create_search_index = AsyncMock() + async_client.put_row = AsyncMock(return_value=(None, None)) + async_client.get_row = AsyncMock(return_value=(None, None, None)) + async_client.get_range = AsyncMock(return_value=(None, None, [], None)) + async_client.update_row = AsyncMock() + async_client.delete_row = AsyncMock() + async_client.batch_write_row = AsyncMock() + async_client.search = AsyncMock() + return OTSBackend( + ots_client=None, + table_prefix=table_prefix, + 
async_ots_client=async_client, + ) + + +class TestInitTablesAsync: + """异步建表测试。""" + + @pytest.mark.asyncio + async def test_init_tables(self) -> None: + backend = _make_async_backend() + await backend.init_tables_async() + assert backend._async_client.create_table.call_count == 5 + + @pytest.mark.asyncio + async def test_init_tables_already_exist(self) -> None: + async_client = MagicMock() + err = OTSServiceError(409, "OTSObjectAlreadyExist", "already exist") + async_client.create_table = AsyncMock(side_effect=err) + backend = _make_async_backend(async_client) + await backend.init_tables_async() + + @pytest.mark.asyncio + async def test_init_tables_other_error(self) -> None: + async_client = MagicMock() + err = OTSServiceError(500, "InternalError", "error") + async_client.create_table = AsyncMock(side_effect=err) + backend = _make_async_backend(async_client) + with pytest.raises(OTSServiceError): + await backend.init_tables_async() + + @pytest.mark.asyncio + async def test_init_core_tables(self) -> None: + backend = _make_async_backend() + await backend.init_core_tables_async() + assert backend._async_client.create_table.call_count == 2 + + @pytest.mark.asyncio + async def test_init_state_tables(self) -> None: + backend = _make_async_backend() + await backend.init_state_tables_async() + assert backend._async_client.create_table.call_count == 3 + + @pytest.mark.asyncio + async def test_init_search_index(self) -> None: + backend = _make_async_backend() + await backend.init_search_index_async() + backend._async_client.create_search_index.assert_called_once() + + @pytest.mark.asyncio + async def test_init_search_index_already_exist(self) -> None: + async_client = MagicMock() + async_client.create_table = AsyncMock() + err = OTSServiceError(409, "OTSObjectAlreadyExist", "already exist") + async_client.create_search_index = AsyncMock(side_effect=err) + backend = _make_async_backend(async_client) + await backend.init_search_index_async() + + @pytest.mark.asyncio + async 
def test_init_search_index_other_error(self) -> None: + async_client = MagicMock() + async_client.create_table = AsyncMock() + err = OTSServiceError(500, "InternalError", "error") + async_client.create_search_index = AsyncMock(side_effect=err) + backend = _make_async_backend(async_client) + with pytest.raises(OTSServiceError): + await backend.init_search_index_async() + + +class TestSessionCrudAsync: + """异步 Session CRUD 测试。""" + + @pytest.mark.asyncio + async def test_put_session(self) -> None: + backend = _make_async_backend() + session = ConversationSession("a", "u", "s", 100, 200) + await backend.put_session_async(session) + backend._async_client.put_row.assert_called_once() + + @pytest.mark.asyncio + async def test_put_session_with_optional(self) -> None: + backend = _make_async_backend() + session = ConversationSession( + "a", + "u", + "s", + 100, + 200, + summary="hi", + labels='["t"]', + framework="adk", + extensions={"k": "v"}, + ) + await backend.put_session_async(session) + backend._async_client.put_row.assert_called_once() + + @pytest.mark.asyncio + async def test_get_session_found(self) -> None: + backend = _make_async_backend() + row = _make_session_row() + backend._async_client.get_row = AsyncMock( + return_value=(None, row, None) + ) + + result = await backend.get_session_async("agent1", "user1", "sess1") + assert result is not None + assert result.agent_id == "agent1" + + @pytest.mark.asyncio + async def test_get_session_not_found(self) -> None: + backend = _make_async_backend() + result = await backend.get_session_async("a", "u", "s") + assert result is None + + @pytest.mark.asyncio + async def test_get_session_null_pk(self) -> None: + backend = _make_async_backend() + null_row = _make_null_row() + backend._async_client.get_row = AsyncMock( + return_value=(None, null_row, None) + ) + result = await backend.get_session_async("a", "u", "s") + assert result is None + + @pytest.mark.asyncio + async def test_delete_session_row(self) -> None: + backend 
= _make_async_backend() + await backend.delete_session_row_async("a", "u", "s") + backend._async_client.delete_row.assert_called_once() + + @pytest.mark.asyncio + async def test_update_session(self) -> None: + backend = _make_async_backend() + await backend.update_session_async("a", "u", "s", {"version": 2}, 1) + backend._async_client.update_row.assert_called_once() + + @pytest.mark.asyncio + async def test_list_sessions_desc(self) -> None: + backend = _make_async_backend() + idx_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", 2000), + ("session_id", "s1"), + ], + [], + ) + backend._async_client.get_range = AsyncMock( + return_value=(None, None, [idx_row], None) + ) + result = await backend.list_sessions_async("a", "u", order_desc=True) + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_list_sessions_asc(self) -> None: + backend = _make_async_backend() + result = await backend.list_sessions_async("a", "u", order_desc=False) + assert result == [] + + @pytest.mark.asyncio + async def test_list_sessions_with_limit(self) -> None: + backend = _make_async_backend() + rows = [ + Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", i), + ("session_id", f"s{i}"), + ], + [], + ) + for i in range(5) + ] + backend._async_client.get_range = AsyncMock( + return_value=(None, None, rows, None) + ) + result = await backend.list_sessions_async("a", "u", limit=3) + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_list_sessions_pagination(self) -> None: + backend = _make_async_backend() + r1 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", 200), + ("session_id", "s1"), + ], + [], + ) + r2 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("updated_at", 100), + ("session_id", "s2"), + ], + [], + ) + backend._async_client.get_range = AsyncMock( + side_effect=[ + (None, "token", [r1], None), + (None, None, [r2], None), + ] + ) + result = await backend.list_sessions_async("a", "u") + assert 
len(result) == 2 + + @pytest.mark.asyncio + async def test_list_all_sessions(self) -> None: + backend = _make_async_backend() + row = _make_session_row() + backend._async_client.get_range = AsyncMock( + return_value=(None, None, [row], None) + ) + result = await backend.list_all_sessions_async("agent1") + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_list_all_sessions_with_limit(self) -> None: + backend = _make_async_backend() + rows = [_make_session_row(session_id=f"s{i}") for i in range(5)] + backend._async_client.get_range = AsyncMock( + return_value=(None, None, rows, None) + ) + result = await backend.list_all_sessions_async("agent1", limit=2) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_list_all_sessions_pagination(self) -> None: + backend = _make_async_backend() + r1 = _make_session_row(session_id="s1") + r2 = _make_session_row(session_id="s2") + backend._async_client.get_range = AsyncMock( + side_effect=[ + (None, "token", [r1], None), + (None, None, [r2], None), + ] + ) + result = await backend.list_all_sessions_async("agent1") + assert len(result) == 2 + + +class TestSearchSessionsAsync: + """异步 search_sessions 测试。""" + + @pytest.mark.asyncio + async def test_basic(self) -> None: + backend = _make_async_backend() + response = MagicMock() + response.rows = [( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s1")], + [ + ("created_at", 100, 0), + ("updated_at", 200, 0), + ("is_pinned", False, 0), + ("version", 0, 0), + ], + )] + response.total_count = 1 + backend._async_client.search = AsyncMock(return_value=response) + + sessions, total = await backend.search_sessions_async("a") + assert len(sessions) == 1 + assert total == 1 + + @pytest.mark.asyncio + async def test_with_all_filters(self) -> None: + backend = _make_async_backend() + response = MagicMock() + response.rows = [] + response.total_count = 0 + backend._async_client.search = AsyncMock(return_value=response) + + sessions, total = await 
backend.search_sessions_async( + "a", + user_id="u", + summary_keyword="hi", + labels="t", + framework="adk", + updated_after=100, + updated_before=200, + is_pinned=True, + ) + assert total == 0 + + @pytest.mark.asyncio + async def test_total_count_none(self) -> None: + backend = _make_async_backend() + response = MagicMock() + response.rows = [] + response.total_count = None + backend._async_client.search = AsyncMock(return_value=response) + + _, total = await backend.search_sessions_async("a") + assert total == 0 + + +class TestEventCrudAsync: + """异步 Event CRUD 测试。""" + + @pytest.mark.asyncio + async def test_put_event(self) -> None: + backend = _make_async_backend() + return_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 42), + ], + [], + ) + backend._async_client.put_row = AsyncMock( + return_value=(None, return_row) + ) + + seq_id = await backend.put_event_async( + "a", "u", "s", "msg", {"key": "val"} + ) + assert seq_id == 42 + + @pytest.mark.asyncio + async def test_put_event_with_raw_event(self) -> None: + backend = _make_async_backend() + return_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 1), + ], + [], + ) + backend._async_client.put_row = AsyncMock( + return_value=(None, return_row) + ) + + seq_id = await backend.put_event_async( + "a", + "u", + "s", + "msg", + {}, + raw_event='{"raw": true}', + ) + assert seq_id == 1 + + @pytest.mark.asyncio + async def test_put_event_with_timestamps(self) -> None: + backend = _make_async_backend() + return_row = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 5), + ], + [], + ) + backend._async_client.put_row = AsyncMock( + return_value=(None, return_row) + ) + + seq_id = await backend.put_event_async( + "a", + "u", + "s", + "msg", + {}, + created_at=100, + updated_at=200, + ) + assert seq_id == 5 + + @pytest.mark.asyncio + async def test_put_event_return_none(self) -> None: + backend = 
_make_async_backend() + seq_id = await backend.put_event_async("a", "u", "s", "msg", {}) + assert seq_id == 0 + + @pytest.mark.asyncio + async def test_put_event_return_no_pk(self) -> None: + backend = _make_async_backend() + return_row = MagicMock(spec=Row) + return_row.primary_key = None + backend._async_client.put_row = AsyncMock( + return_value=(None, return_row) + ) + + seq_id = await backend.put_event_async("a", "u", "s", "msg", {}) + assert seq_id == 0 + + @pytest.mark.asyncio + async def test_get_events_forward(self) -> None: + backend = _make_async_backend() + row = _make_event_row(seq_id=1) + backend._async_client.get_range = AsyncMock( + return_value=(None, None, [row], None) + ) + + events = await backend.get_events_async( + "a", "u", "s", direction="FORWARD" + ) + assert len(events) == 1 + assert events[0].seq_id == 1 + + @pytest.mark.asyncio + async def test_get_events_backward(self) -> None: + backend = _make_async_backend() + events = await backend.get_events_async( + "a", "u", "s", direction="BACKWARD" + ) + assert events == [] + + @pytest.mark.asyncio + async def test_get_events_with_limit(self) -> None: + backend = _make_async_backend() + rows = [_make_event_row(seq_id=i) for i in range(5)] + backend._async_client.get_range = AsyncMock( + return_value=(None, None, rows, None) + ) + + events = await backend.get_events_async("a", "u", "s", limit=2) + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_get_events_pagination(self) -> None: + backend = _make_async_backend() + r1 = _make_event_row(seq_id=1) + r2 = _make_event_row(seq_id=2) + backend._async_client.get_range = AsyncMock( + side_effect=[ + (None, "token", [r1], None), + (None, None, [r2], None), + ] + ) + events = await backend.get_events_async("a", "u", "s") + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_delete_events_no_events(self) -> None: + backend = _make_async_backend() + deleted = await backend.delete_events_by_session_async("a", "u", "s") + 
assert deleted == 0 + + @pytest.mark.asyncio + async def test_delete_events_batch(self) -> None: + backend = _make_async_backend() + rows = [ + Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", i), + ], + [], + ) + for i in range(3) + ] + backend._async_client.get_range = AsyncMock( + return_value=(None, None, rows, None) + ) + + deleted = await backend.delete_events_by_session_async("a", "u", "s") + assert deleted == 3 + backend._async_client.batch_write_row.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_events_pagination(self) -> None: + backend = _make_async_backend() + r1 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 1), + ], + [], + ) + r2 = Row( + [ + ("agent_id", "a"), + ("user_id", "u"), + ("session_id", "s"), + ("seq_id", 2), + ], + [], + ) + backend._async_client.get_range = AsyncMock( + side_effect=[ + (None, "token", [r1], None), + (None, None, [r2], None), + ] + ) + deleted = await backend.delete_events_by_session_async("a", "u", "s") + assert deleted == 2 + + +class TestStateCrudAsync: + """异步 State CRUD 测试。""" + + @pytest.mark.asyncio + async def test_put_state_first_write(self) -> None: + backend = _make_async_backend() + await backend.put_state_async( + StateScope.SESSION, "a", "u", "s", {"k": "v"}, 0 + ) + backend._async_client.update_row.assert_called_once() + + @pytest.mark.asyncio + async def test_put_state_with_chunks(self) -> None: + backend = _make_async_backend() + big_state = {"d": "x" * (MAX_COLUMN_SIZE + 100)} + await backend.put_state_async( + StateScope.SESSION, "a", "u", "s", big_state, 0 + ) + backend._async_client.update_row.assert_called_once() + + @pytest.mark.asyncio + async def test_put_state_update_clean_old_chunks(self) -> None: + backend = _make_async_backend() + chunk_row = Row( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")], + [("chunk_count", 2, 0)], + ) + backend._async_client.get_row = AsyncMock( + 
return_value=(None, chunk_row, None) + ) + + await backend.put_state_async( + StateScope.SESSION, "a", "u", "s", {"k": "v"}, 1 + ) + call_args = backend._async_client.update_row.call_args + row_arg = call_args[0][1] + assert "DELETE_ALL" in row_arg.attribute_columns + + @pytest.mark.asyncio + async def test_put_state_update_no_chunk_to_chunk(self) -> None: + backend = _make_async_backend() + chunk_row = Row( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")], + [("chunk_count", 0, 0)], + ) + backend._async_client.get_row = AsyncMock( + return_value=(None, chunk_row, None) + ) + + big_state = {"d": "x" * (MAX_COLUMN_SIZE + 100)} + await backend.put_state_async( + StateScope.SESSION, "a", "u", "s", big_state, 1 + ) + call_args = backend._async_client.update_row.call_args + row_arg = call_args[0][1] + assert "DELETE_ALL" in row_arg.attribute_columns + assert "state" in row_arg.attribute_columns["DELETE_ALL"] + + @pytest.mark.asyncio + async def test_get_state_not_found(self) -> None: + backend = _make_async_backend() + result = await backend.get_state_async( + StateScope.SESSION, "a", "u", "s" + ) + assert result is None + + @pytest.mark.asyncio + async def test_get_state_null_pk(self) -> None: + backend = _make_async_backend() + null_row = _make_null_row() + backend._async_client.get_row = AsyncMock( + return_value=(None, null_row, None) + ) + result = await backend.get_state_async( + StateScope.SESSION, "a", "u", "s" + ) + assert result is None + + @pytest.mark.asyncio + async def test_get_state_no_chunk(self) -> None: + backend = _make_async_backend() + pk = [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + row = _make_state_row(pk, state={"k": "v"}) + backend._async_client.get_row = AsyncMock( + return_value=(None, row, None) + ) + + result = await backend.get_state_async( + StateScope.SESSION, "a", "u", "s" + ) + assert result is not None + assert result.state == {"k": "v"} + + @pytest.mark.asyncio + async def test_get_state_with_chunks(self) 
-> None: + backend = _make_async_backend() + pk = [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + state_json = json.dumps({"data": "hello"}) + c1 = state_json[:5] + c2 = state_json[5:] + row = _make_state_row(pk, chunk_count=2, chunks=[c1, c2]) + backend._async_client.get_row = AsyncMock( + return_value=(None, row, None) + ) + + result = await backend.get_state_async( + StateScope.SESSION, "a", "u", "s" + ) + assert result is not None + assert result.state == {"data": "hello"} + + @pytest.mark.asyncio + async def test_get_state_missing_chunk(self) -> None: + backend = _make_async_backend() + pk = [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + row = _make_state_row(pk, chunk_count=2, chunks=["partial"]) + backend._async_client.get_row = AsyncMock( + return_value=(None, row, None) + ) + + with pytest.raises(ValueError, match="Missing state chunk"): + await backend.get_state_async(StateScope.SESSION, "a", "u", "s") + + @pytest.mark.asyncio + async def test_get_state_no_state_column(self) -> None: + backend = _make_async_backend() + pk = [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")] + row = Row(pk, [("chunk_count", 0, 0), ("version", 1, 0)]) + backend._async_client.get_row = AsyncMock( + return_value=(None, row, None) + ) + + result = await backend.get_state_async( + StateScope.SESSION, "a", "u", "s" + ) + assert result is None + + @pytest.mark.asyncio + async def test_delete_state_row(self) -> None: + backend = _make_async_backend() + await backend.delete_state_row_async(StateScope.SESSION, "a", "u", "s") + backend._async_client.delete_row.assert_called_once() + + @pytest.mark.asyncio + async def test_get_chunk_count_async(self) -> None: + backend = _make_async_backend() + pk = [("agent_id", "a")] + row = Row(pk, [("chunk_count", 3, 0)]) + backend._async_client.get_row = AsyncMock( + return_value=(None, row, None) + ) + + count = await backend._get_chunk_count_async("app_state", pk) + assert count == 3 + + @pytest.mark.asyncio + 
async def test_get_chunk_count_async_no_row(self) -> None: + backend = _make_async_backend() + count = await backend._get_chunk_count_async( + "app_state", [("agent_id", "a")] + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_get_chunk_count_async_null_pk(self) -> None: + backend = _make_async_backend() + null_row = _make_null_row() + backend._async_client.get_row = AsyncMock( + return_value=(None, null_row, None) + ) + count = await backend._get_chunk_count_async( + "app_state", [("agent_id", "a")] + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_put_state_async_more_chunks_to_fewer(self) -> None: + """异步:旧 4 个分片 → 新 2 个分片,应删除多余分片。""" + backend = _make_async_backend() + chunk_row = Row( + [("agent_id", "a"), ("user_id", "u"), ("session_id", "s")], + [("chunk_count", 4, 0)], + ) + backend._async_client.get_row = AsyncMock( + return_value=(None, chunk_row, None) + ) + + data_size = MAX_COLUMN_SIZE + 10 + big_state = {"d": "a" * data_size} + await backend.put_state_async( + StateScope.SESSION, "a", "u", "s", big_state, 1 + ) + call_args = backend._async_client.update_row.call_args + row_arg = call_args[0][1] + if "DELETE_ALL" in row_arg.attribute_columns: + deleted = row_arg.attribute_columns["DELETE_ALL"] + assert "state_2" in deleted + assert "state_3" in deleted + + @pytest.mark.asyncio + async def test_create_event_table_other_error(self) -> None: + """异步:Event 表创建非已存在错误应抛出。""" + async_client = MagicMock() + err = OTSServiceError(500, "InternalError", "internal error") + async_client.create_table = AsyncMock(side_effect=[None, err]) + backend = _make_async_backend(async_client) + with pytest.raises(OTSServiceError): + await backend.init_core_tables_async() + + @pytest.mark.asyncio + async def test_create_state_table_other_error(self) -> None: + """异步:State 表创建非已存在错误应抛出。""" + async_client = MagicMock() + err = OTSServiceError(500, "InternalError", "internal error") + async_client.create_table = AsyncMock(side_effect=err) + backend 
= _make_async_backend(async_client) + with pytest.raises(OTSServiceError): + await backend.init_state_tables_async() + + @pytest.mark.asyncio + async def test_search_sessions_is_pinned_false(self) -> None: + """异步搜索 is_pinned=False。""" + backend = _make_async_backend() + response = MagicMock() + response.rows = [] + response.total_count = 0 + backend._async_client.search = AsyncMock(return_value=response) + await backend.search_sessions_async("a", is_pinned=False) + backend._async_client.search.assert_called_once() + + @pytest.mark.asyncio + async def test_search_sessions_with_row_objects(self) -> None: + """异步搜索返回 Row 对象而非 tuple。""" + backend = _make_async_backend() + response = MagicMock() + row = _make_session_row() + response.rows = [row] + response.total_count = 1 + backend._async_client.search = AsyncMock(return_value=response) + + sessions, total = await backend.search_sessions_async("agent1") + assert len(sessions) == 1 diff --git a/tests/unittests/conversation_service/test_session_store.py b/tests/unittests/conversation_service/test_session_store.py new file mode 100644 index 0000000..47862b9 --- /dev/null +++ b/tests/unittests/conversation_service/test_session_store.py @@ -0,0 +1,1404 @@ +"""conversation_service.session_store 单元测试。 + +通过 Mock OTSBackend 测试 SessionStore 的业务逻辑: +- Session CRUD(含级联删除) +- Event 追加/获取 +- State 三级管理(app / user / session) +- 三级状态合并 +- _apply_delta 增量更新逻辑 +- from_memory_collection 工厂方法 +""" + +from __future__ import annotations + +import json +import os +from typing import Any, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.conversation_service.model import ( + ConversationEvent, + ConversationSession, + StateData, + StateScope, +) +from agentrun.conversation_service.ots_backend import OTSBackend +from agentrun.conversation_service.session_store import SessionStore + +# --------------------------------------------------------------------------- +# Helpers +# 
--------------------------------------------------------------------------- + + +def _make_mock_backend() -> MagicMock: + """创建 Mock OTSBackend。""" + backend = MagicMock(spec=OTSBackend) + + # 同步方法默认返回值 + backend.get_session.return_value = None + backend.list_sessions.return_value = [] + backend.list_all_sessions.return_value = [] + backend.get_events.return_value = [] + backend.get_state.return_value = None + backend.put_event.return_value = 1 + backend.delete_events_by_session.return_value = 0 + backend.search_sessions.return_value = ([], 0) + + return backend + + +def _make_store(backend: MagicMock | None = None) -> SessionStore: + """创建带 Mock backend 的 SessionStore。""" + if backend is None: + backend = _make_mock_backend() + return SessionStore(backend) + + +# --------------------------------------------------------------------------- +# init 方法 +# --------------------------------------------------------------------------- + + +class TestInitMethods: + """init_tables / init_core_tables 等代理方法测试。""" + + def test_init_tables(self) -> None: + backend = _make_mock_backend() + store = _make_store(backend) + store.init_tables() + backend.init_tables.assert_called_once() + + def test_init_core_tables(self) -> None: + backend = _make_mock_backend() + store = _make_store(backend) + store.init_core_tables() + backend.init_core_tables.assert_called_once() + + def test_init_state_tables(self) -> None: + backend = _make_mock_backend() + store = _make_store(backend) + store.init_state_tables() + backend.init_state_tables.assert_called_once() + + def test_init_search_index(self) -> None: + backend = _make_mock_backend() + store = _make_store(backend) + store.init_search_index() + backend.init_search_index.assert_called_once() + + +# --------------------------------------------------------------------------- +# Session 管理 +# --------------------------------------------------------------------------- + + +class TestCreateSession: + """create_session 测试。""" + + def 
test_basic(self) -> None: + backend = _make_mock_backend() + store = _make_store(backend) + + session = store.create_session("agent", "user", "sess") + + assert session.agent_id == "agent" + assert session.user_id == "user" + assert session.session_id == "sess" + assert session.created_at > 0 + assert session.updated_at == session.created_at + assert session.version == 0 + backend.put_session.assert_called_once() + + def test_with_optional_fields(self) -> None: + backend = _make_mock_backend() + store = _make_store(backend) + + session = store.create_session( + "a", + "u", + "s", + is_pinned=True, + summary="hello", + labels='["tag"]', + framework="adk", + extensions={"key": "val"}, + ) + + assert session.is_pinned is True + assert session.summary == "hello" + assert session.labels == '["tag"]' + assert session.framework == "adk" + assert session.extensions == {"key": "val"} + + +class TestGetSession: + """get_session 测试。""" + + def test_found(self) -> None: + backend = _make_mock_backend() + expected = ConversationSession( + agent_id="a", + user_id="u", + session_id="s", + created_at=100, + updated_at=200, + ) + backend.get_session.return_value = expected + + store = _make_store(backend) + result = store.get_session("a", "u", "s") + + assert result is expected + backend.get_session.assert_called_once_with("a", "u", "s") + + def test_not_found(self) -> None: + backend = _make_mock_backend() + store = _make_store(backend) + result = store.get_session("a", "u", "s") + assert result is None + + +class TestListSessions: + """list_sessions 测试。""" + + def test_basic(self) -> None: + backend = _make_mock_backend() + sessions = [ + ConversationSession("a", "u", "s1", 100, 200), + ConversationSession("a", "u", "s2", 100, 300), + ] + backend.list_sessions.return_value = sessions + + store = _make_store(backend) + result = store.list_sessions("a", "u") + + assert len(result) == 2 + backend.list_sessions.assert_called_once_with( + "a", "u", limit=None, order_desc=True + ) + + 
def test_with_limit(self) -> None:
    """An explicit ``limit`` is forwarded to ``backend.list_sessions``."""
    # NOTE: this is the last method of TestListSessions; the class header sits
    # immediately before this span.
    mock_backend = _make_mock_backend()
    session_store = _make_store(mock_backend)

    session_store.list_sessions("a", "u", limit=5)

    mock_backend.list_sessions.assert_called_once_with(
        "a", "u", limit=5, order_desc=True
    )


class TestListAllSessions:
    """Tests for ``list_all_sessions``."""

    def test_basic(self) -> None:
        """Without a limit the backend is called with ``limit=None``."""
        mock_backend = _make_mock_backend()
        session_store = _make_store(mock_backend)

        session_store.list_all_sessions("a")

        mock_backend.list_all_sessions.assert_called_once_with("a", limit=None)

    def test_with_limit(self) -> None:
        """An explicit limit is forwarded unchanged."""
        mock_backend = _make_mock_backend()
        session_store = _make_store(mock_backend)

        session_store.list_all_sessions("a", limit=10)

        mock_backend.list_all_sessions.assert_called_once_with("a", limit=10)


class TestSearchSessions:
    """Tests for ``search_sessions``."""

    def test_basic(self) -> None:
        """The (sessions, total) pair produced by the backend is returned."""
        mock_backend = _make_mock_backend()
        mock_backend.search_sessions.return_value = ([], 0)
        session_store = _make_store(mock_backend)

        found, total_count = session_store.search_sessions("a")

        assert found == []
        assert total_count == 0

    def test_with_all_filters(self) -> None:
        """Every filter keyword reaches the backend with the same value."""
        mock_backend = _make_mock_backend()
        mock_backend.search_sessions.return_value = ([], 0)
        session_store = _make_store(mock_backend)
        filters = {
            "user_id": "u",
            "summary_keyword": "hello",
            "labels": "tag",
            "framework": "adk",
            "updated_after": 100,
            "updated_before": 200,
            "is_pinned": True,
            "limit": 10,
            "offset": 5,
        }

        session_store.search_sessions("a", **filters)

        mock_backend.search_sessions.assert_called_once_with("a", **filters)


class TestUpdateSession:
    """Tests for the optimistic-lock ``update_session``."""

    def test_update_all_fields(self) -> None:
        """All optional fields are written; version 1 is stored as 2."""
        mock_backend = _make_mock_backend()
        session_store = _make_store(mock_backend)

        session_store.update_session(
            "a",
            "u",
            "s",
            is_pinned=True,
            summary="new summary",
            labels='["new"]',
            extensions={"new": "ext"},
            version=1,
        )

        mock_backend.update_session.assert_called_once()
        # Fourth positional argument is columns_to_put.
        columns = mock_backend.update_session.call_args.args[3]
        assert columns["version"] == 2
        assert columns["is_pinned"] is True
        assert columns["summary"] == "new summary"
        assert columns["labels"] == '["new"]'
        assert json.loads(columns["extensions"]) == {"new": "ext"}

    def test_update_partial_fields(self) -> None:
        """Only the fields that were passed show up in the written columns."""
        mock_backend = _make_mock_backend()
        session_store = _make_store(mock_backend)

        session_store.update_session("a", "u", "s", is_pinned=True, version=0)

        columns = mock_backend.update_session.call_args.args[3]
        assert "is_pinned" in columns
        assert "summary" not in columns
        assert "labels" not in columns
        assert "extensions" not in columns


class TestDeleteSession:
    """Tests for the cascading ``delete_session``."""

    def test_cascade_delete(self) -> None:
        """Events, session state and the session row are each deleted once."""
        mock_backend = _make_mock_backend()
        mock_backend.delete_events_by_session.return_value = 3
        session_store = _make_store(mock_backend)

        session_store.delete_session("a", "u", "s")

        # 1. Events are removed.
        mock_backend.delete_events_by_session.assert_called_once_with(
            "a", "u", "s"
        )
        # 2. The session-scoped state row is removed.
        mock_backend.delete_state_row.assert_called_once_with(
            StateScope.SESSION, "a", "u", "s"
        )
        # 3. The session row itself is removed.
        mock_backend.delete_session_row.assert_called_once_with("a", "u", "s")


class TestDeleteEvents:
    """Tests for ``delete_events``."""

    def test_basic(self) -> None:
        """The backend's deleted-row count is returned to the caller."""
        mock_backend = _make_mock_backend()
        mock_backend.delete_events_by_session.return_value = 5
        session_store = _make_store(mock_backend)

        removed = session_store.delete_events("a", "u", "s")

        assert removed == 5
        mock_backend.delete_events_by_session.assert_called_once_with(
            "a", "u", "s"
        )


# ---------------------------------------------------------------------------
# Event management
# ---------------------------------------------------------------------------


class TestAppendEvent:
    """Tests for ``append_event``."""

    def test_basic(self) -> None:
        """The event carries the backend seq_id and the session is updated."""
        mock_backend = _make_mock_backend()
        mock_backend.put_event.return_value = 42
        mock_backend.get_session.return_value = ConversationSession(
            "a", "u", "s", 100, 200, version=1
        )
        session_store = _make_store(mock_backend)

        appended = session_store.append_event(
            "a",
            "u",
            "s",
            event_type="message",
            content={"msg": "hi"},
        )

        assert appended.seq_id == 42
        assert appended.type == "message"
        assert appended.content == {"msg": "hi"}
        mock_backend.put_event.assert_called_once()
        # The session's updated_at should be refreshed as well.
        mock_backend.update_session.assert_called_once()

    def test_with_raw_event(self) -> None:
        """``raw_event`` is kept on the returned event."""
        mock_backend = _make_mock_backend()
        mock_backend.put_event.return_value = 1
        mock_backend.get_session.return_value = None
        session_store = _make_store(mock_backend)

        appended = session_store.append_event(
            "a",
            "u",
            "s",
            event_type="adk_event",
            content={},
            raw_event='{"raw": true}',
        )

        assert appended.raw_event == '{"raw": true}'

    def test_session_not_found_skips_update(self) -> None:
        """No session update is issued when the session does not exist."""
        mock_backend = _make_mock_backend()
        mock_backend.put_event.return_value = 1
        mock_backend.get_session.return_value = None
        session_store = _make_store(mock_backend)

        session_store.append_event("a", "u", "s", "msg", {})

        mock_backend.update_session.assert_not_called()

    def test_update_session_failure_ignored(self) -> None:
        """A failing session update must not block the event write."""
        mock_backend = _make_mock_backend()
        mock_backend.put_event.return_value = 1
        mock_backend.get_session.return_value = ConversationSession(
            "a", "u", "s", 100, 200, version=0
        )
        mock_backend.update_session.side_effect = Exception("OTS error")
        session_store = _make_store(mock_backend)

        appended = session_store.append_event("a", "u", "s", "msg", {})

        # The event is still returned.
        assert appended.seq_id == 1


class TestGetEvents:
    """Tests for ``get_events`` / ``get_recent_events``."""

    def test_get_events(self) -> None:
        """Events are read from the backend in FORWARD direction."""
        mock_backend = _make_mock_backend()
        mock_backend.get_events.return_value = [
            ConversationEvent("a", "u", "s", 1, "msg", {"text": "1"}),
            ConversationEvent("a", "u", "s", 2, "msg", {"text": "2"}),
        ]
        session_store = _make_store(mock_backend)

        fetched = session_store.get_events("a", "u", "s")

        assert len(fetched) == 2
        mock_backend.get_events.assert_called_once_with(
            "a", "u", "s", direction="FORWARD"
        )

    def test_get_recent_events(self) -> None:
        """Recent events are read BACKWARD and handed back oldest-first."""
        mock_backend = _make_mock_backend()
        # The backend answers newest-first.
        newest_first = [
            ConversationEvent("a", "u", "s", 3, "msg"),
            ConversationEvent("a", "u", "s", 2, "msg"),
        ]
        mock_backend.get_events.return_value = newest_first
        session_store = _make_store(mock_backend)

        recent = session_store.get_recent_events("a", "u", "s", n=2)

        # The result should be flipped back into ascending order.
        assert recent[0].seq_id == 2
        assert recent[1].seq_id == 3
        mock_backend.get_events.assert_called_once_with(
            "a", "u", "s", direction="BACKWARD", limit=2
        )


# ---------------------------------------------------------------------------
# State management
# ---------------------------------------------------------------------------


class TestGetSessionState:
    """Tests for ``get_session_state``."""

    def test_exists(self) -> None:
        """The stored state dict is returned for an existing row."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = StateData(state={"counter": 42})
        session_store = _make_store(mock_backend)

        state = session_store.get_session_state("a", "u", "s")

        assert state == {"counter": 42}
        mock_backend.get_state.assert_called_once_with(
            StateScope.SESSION, "a", "u", "s"
        )

    def test_not_exists(self) -> None:
        """A missing row yields an empty dict."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = None
        session_store = _make_store(mock_backend)

        state = session_store.get_session_state("a", "u", "s")

        assert state == {}


class TestGetAppState:
    """Tests for ``get_app_state``."""

    def test_exists(self) -> None:
        """APP scope is queried with empty user and session ids."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = StateData(state={"config": "val"})
        session_store = _make_store(mock_backend)

        state = session_store.get_app_state("a")

        assert state == {"config": "val"}
        mock_backend.get_state.assert_called_once_with(
            StateScope.APP, "a", "", ""
        )

    def test_not_exists(self) -> None:
        """With the default mock backend the app state is an empty dict."""
        session_store = _make_store(_make_mock_backend())

        assert session_store.get_app_state("a") == {}


class TestGetUserState:
    """Tests for ``get_user_state``."""

    def test_exists(self) -> None:
        """USER scope is queried with an empty session id."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = StateData(state={"pref": "dark"})
        session_store = _make_store(mock_backend)

        state = session_store.get_user_state("a", "u")

        assert state == {"pref": "dark"}
        mock_backend.get_state.assert_called_once_with(
            StateScope.USER, "a", "u", ""
        )

    def test_not_exists(self) -> None:
        """With the default mock backend the user state is an empty dict."""
        session_store = _make_store(_make_mock_backend())

        assert session_store.get_user_state("a", "u") == {}


class TestUpdateSessionState:
    """Tests for the incremental ``update_session_state``."""

    def test_first_write(self) -> None:
        """On first write, keys whose value is None are filtered out."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = None
        session_store = _make_store(mock_backend)

        session_store.update_session_state(
            "a", "u", "s", {"key": "val", "null_key": None}
        )

        mock_backend.put_state.assert_called_once()
        put_kwargs = mock_backend.put_state.call_args.kwargs
        # The written state must not contain null_key.
        assert put_kwargs["state"] == {"key": "val"}
        assert put_kwargs["version"] == 0

    def test_merge_update(self) -> None:
        """The delta is merged into the existing state; None deletes a key."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = StateData(
            state={"existing": "val", "to_delete": "old"},
            version=2,
        )
        session_store = _make_store(mock_backend)

        session_store.update_session_state(
            "a",
            "u",
            "s",
            {"new_key": "new", "to_delete": None},
        )

        mock_backend.put_state.assert_called_once()
        put_kwargs = mock_backend.put_state.call_args.kwargs
        merged_state = put_kwargs["state"]
        assert merged_state == {"existing": "val", "new_key": "new"}
        assert "to_delete" not in merged_state
        assert put_kwargs["version"] == 2


class TestUpdateAppState:
    """Tests for ``update_app_state``."""

    def test_first_write(self) -> None:
        """The state is put under APP scope for the given agent."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = None
        session_store = _make_store(mock_backend)

        session_store.update_app_state("a", {"config": "val"})

        mock_backend.put_state.assert_called_once()
        put_args = mock_backend.put_state.call_args.args
        assert put_args[0] == StateScope.APP
        assert put_args[1] == "a"


class TestUpdateUserState:
    """Tests for ``update_user_state``."""

    def test_first_write(self) -> None:
        """The state is put under USER scope for the given agent and user."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.return_value = None
        session_store = _make_store(mock_backend)

        session_store.update_user_state("a", "u", {"pref": "dark"})

        mock_backend.put_state.assert_called_once()
        put_args = mock_backend.put_state.call_args.args
        assert put_args[0] == StateScope.USER
        assert put_args[1] == "a"
        assert put_args[2] == "u"


class TestGetMergedState:
    """Tests for the three-level merge done by ``get_merged_state``."""

    def test_all_levels(self) -> None:
        """Keys from the three successive backend reads are all present."""
        mock_backend = _make_mock_backend()
        # One answer per level, in the order the reads are consumed.
        mock_backend.get_state.side_effect = [
            StateData(state={"app_key": "app_val"}),  # APP
            StateData(state={"user_key": "user_val"}),  # USER
            StateData(state={"sess_key": "sess_val"}),  # SESSION
        ]
        session_store = _make_store(mock_backend)

        merged_state = session_store.get_merged_state("a", "u", "s")

        assert merged_state == {
            "app_key": "app_val",
            "user_key": "user_val",
            "sess_key": "sess_val",
        }

    def test_override_order(self) -> None:
        """Later levels override earlier ones."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.side_effect = [
            StateData(state={"key": "app"}),  # APP
            StateData(state={"key": "user"}),  # USER
            StateData(state={"key": "session"}),  # SESSION
        ]
        session_store = _make_store(mock_backend)

        merged_state = session_store.get_merged_state("a", "u", "s")

        assert merged_state["key"] == "session"

    def test_missing_levels(self) -> None:
        """A level that does not exist is treated as an empty dict."""
        mock_backend = _make_mock_backend()
        mock_backend.get_state.side_effect = [
            None,  # APP
            StateData(state={"user_key": "val"}),  # USER
            None,  # SESSION
        ]
        session_store = _make_store(mock_backend)

        merged_state = session_store.get_merged_state("a", "u", "s")

        assert merged_state == {"user_key": "val"}


# ---------------------------------------------------------------------------
# from_memory_collection factory method
# ---------------------------------------------------------------------------


class TestFromMemoryCollection:
    """Tests for the ``from_memory_collection`` factory method."""

    # Patch target and environment fixtures shared by the tests below.
    _GET_BY_NAME = "agentrun.memory_collection.MemoryCollection.get_by_name"
    _BLOCKED_MODULES = {
        "agentrun.memory_collection": None,
        "agentrun.utils.config": None,
    }
    _VALID_AK_ENV = {
        "AGENTRUN_ACCESS_KEY_ID": "ak_id",
        "AGENTRUN_ACCESS_KEY_SECRET": "ak_secret",
    }
    _EMPTY_AK_ENV = {
        "AGENTRUN_ACCESS_KEY_ID": "",
        "AGENTRUN_ACCESS_KEY_SECRET": "",
    }

    def test_import_error(self) -> None:
        """ImportError is raised when the agentrun modules are unimportable."""
        with patch.dict("sys.modules", self._BLOCKED_MODULES), pytest.raises(
            ImportError, match="agentrun 主包未安装"
        ):
            SessionStore.from_memory_collection("test-mc")

    def _make_mock_mc(
        self,
        endpoint: str = "https://inst.cn-hangzhou.ots.aliyuncs.com",
        instance_name: str = "test_instance",
        has_vs_config: bool = True,
    ) -> MagicMock:
        """Build a mock MemoryCollection.

        Args:
            endpoint: Value exposed as ``vector_store_config.config.endpoint``.
            instance_name: Value exposed as
                ``vector_store_config.config.instance_name``.
            has_vs_config: When False, ``vector_store_config`` is ``None``.
        """
        collection = MagicMock()
        if has_vs_config:
            collection.vector_store_config = MagicMock()
            collection.vector_store_config.config = MagicMock()
            collection.vector_store_config.config.endpoint = endpoint
            collection.vector_store_config.config.instance_name = instance_name
        else:
            collection.vector_store_config = None
        return collection

    @patch("tablestore.WriteRetryPolicy")
    @patch("tablestore.AsyncOTSClient")
    @patch("tablestore.OTSClient")
    def test_success(
        self,
        mock_ots_cls: MagicMock,
        mock_async_ots_cls: MagicMock,
        mock_wrp: MagicMock,
    ) -> None:
        """A store is created and both OTS client classes are built once."""
        collection = self._make_mock_mc()

        with patch(self._GET_BY_NAME, return_value=collection), patch.dict(
            "os.environ", self._VALID_AK_ENV
        ):
            created = SessionStore.from_memory_collection(
                "test-mc",
                table_prefix="p_",
            )

        assert isinstance(created, SessionStore)
        mock_ots_cls.assert_called_once()
        mock_async_ots_cls.assert_called_once()

    def test_missing_vector_store_config(self) -> None:
        """ValueError when the collection has no vector_store_config."""
        collection = self._make_mock_mc(has_vs_config=False)

        with patch(self._GET_BY_NAME, return_value=collection), pytest.raises(
            ValueError, match="缺少"
        ):
            SessionStore.from_memory_collection("test-mc")

    def test_empty_endpoint(self) -> None:
        """ValueError when the configured endpoint is empty."""
        collection = self._make_mock_mc(endpoint="")

        with patch(self._GET_BY_NAME, return_value=collection), pytest.raises(
            ValueError, match="endpoint 为空"
        ):
            SessionStore.from_memory_collection("test-mc")

    def test_empty_instance_name(self) -> None:
        """ValueError when the configured instance name is empty."""
        collection = self._make_mock_mc(instance_name="")

        with patch(self._GET_BY_NAME, return_value=collection), pytest.raises(
            ValueError, match="instance_name 为空"
        ):
            SessionStore.from_memory_collection("test-mc")

    def test_empty_credentials(self) -> None:
        """ValueError when both AK/SK environment variables are empty."""
        collection = self._make_mock_mc()

        with patch(self._GET_BY_NAME, return_value=collection), patch.dict(
            "os.environ", self._EMPTY_AK_ENV, clear=False
        ), pytest.raises(ValueError, match="AK/SK 凭证为空"):
            SessionStore.from_memory_collection("test-mc")

    @patch("tablestore.WriteRetryPolicy")
    @patch("tablestore.AsyncOTSClient")
    @patch("tablestore.OTSClient")
    def test_with_sts_token(
        self,
        mock_ots_cls: MagicMock,
        mock_async_ots_cls: MagicMock,
        mock_wrp: MagicMock,
    ) -> None:
        """The STS token from the environment is passed to OTSClient."""
        collection = self._make_mock_mc()
        env_with_token = {
            **self._VALID_AK_ENV,
            "AGENTRUN_SECURITY_TOKEN": "sts_token",
        }

        with patch(self._GET_BY_NAME, return_value=collection), patch.dict(
            "os.environ", env_with_token
        ):
            created = SessionStore.from_memory_collection("test-mc")

        assert isinstance(created, SessionStore)
        client_kwargs = mock_ots_cls.call_args.kwargs
        assert client_kwargs.get("sts_token") == "sts_token"

    @patch("tablestore.WriteRetryPolicy")
    @patch("tablestore.AsyncOTSClient")
    @patch("tablestore.OTSClient")
    def test_vpc_endpoint_conversion(
        self,
        mock_ots_cls: MagicMock,
        mock_async_ots_cls: MagicMock,
        mock_wrp: MagicMock,
    ) -> None:
        """A VPC endpoint is rewritten to the public endpoint."""
        collection = self._make_mock_mc(
            endpoint="https://inst.cn-hangzhou.vpc.tablestore.aliyuncs.com",
        )

        with patch(self._GET_BY_NAME, return_value=collection), patch.dict(
            "os.environ", self._VALID_AK_ENV
        ):
            SessionStore.from_memory_collection("test-mc")

        client_args = mock_ots_cls.call_args.args
        assert client_args[0] == "https://inst.cn-hangzhou.ots.aliyuncs.com"


class TestFromMemoryCollectionAsync:
    """Tests for the async factory ``from_memory_collection_async``."""

    # Patch target and environment fixtures shared by the tests below.
    _GET_BY_NAME_ASYNC = (
        "agentrun.memory_collection.MemoryCollection.get_by_name_async"
    )
    _BLOCKED_MODULES = {
        "agentrun.memory_collection": None,
        "agentrun.utils.config": None,
    }
    _VALID_AK_ENV = {
        "AGENTRUN_ACCESS_KEY_ID": "ak_id",
        "AGENTRUN_ACCESS_KEY_SECRET": "ak_secret",
    }
    _EMPTY_AK_ENV = {
        "AGENTRUN_ACCESS_KEY_ID": "",
        "AGENTRUN_ACCESS_KEY_SECRET": "",
    }

    @pytest.mark.asyncio
    async def test_import_error(self) -> None:
        """ImportError is raised when the agentrun modules are unimportable."""
        with patch.dict("sys.modules", self._BLOCKED_MODULES), pytest.raises(
            ImportError, match="agentrun 主包未安装"
        ):
            await SessionStore.from_memory_collection_async("test-mc")

    def _make_mock_mc(
        self,
        endpoint: str = "https://inst.cn-hangzhou.ots.aliyuncs.com",
        instance_name: str = "inst",
        has_vs_config: bool = True,
    ) -> MagicMock:
        """Build a mock MemoryCollection (``vector_store_config`` optional)."""
        collection = MagicMock()
        if has_vs_config:
            collection.vector_store_config = MagicMock()
            collection.vector_store_config.config = MagicMock()
            collection.vector_store_config.config.endpoint = endpoint
            collection.vector_store_config.config.instance_name = instance_name
        else:
            collection.vector_store_config = None
        return collection

    @pytest.mark.asyncio
    @patch("tablestore.WriteRetryPolicy")
    @patch("tablestore.AsyncOTSClient")
    @patch("tablestore.OTSClient")
    async def test_success(
        self,
        mock_ots_cls: MagicMock,
        mock_async_ots_cls: MagicMock,
        mock_wrp: MagicMock,
    ) -> None:
        """A SessionStore is returned for a well-formed collection."""
        collection = self._make_mock_mc()

        with patch(
            self._GET_BY_NAME_ASYNC,
            new=AsyncMock(return_value=collection),
        ), patch.dict("os.environ", self._VALID_AK_ENV):
            created = await SessionStore.from_memory_collection_async("test-mc")

        assert isinstance(created, SessionStore)

    @pytest.mark.asyncio
    async def test_missing_config(self) -> None:
        """ValueError when the collection has no vector_store_config."""
        collection = self._make_mock_mc(has_vs_config=False)

        with patch(
            self._GET_BY_NAME_ASYNC,
            new=AsyncMock(return_value=collection),
        ), pytest.raises(ValueError, match="缺少"):
            await SessionStore.from_memory_collection_async("test-mc")

    @pytest.mark.asyncio
    async def test_empty_endpoint(self) -> None:
        """ValueError when the configured endpoint is empty."""
        collection = self._make_mock_mc(endpoint="")

        with patch(
            self._GET_BY_NAME_ASYNC,
            new=AsyncMock(return_value=collection),
        ), pytest.raises(ValueError, match="endpoint 为空"):
            await SessionStore.from_memory_collection_async("test-mc")

    @pytest.mark.asyncio
    async def test_empty_instance_name(self) -> None:
        """ValueError when the configured instance name is empty."""
        collection = self._make_mock_mc(instance_name="")

        with patch(
            self._GET_BY_NAME_ASYNC,
            new=AsyncMock(return_value=collection),
        ), pytest.raises(ValueError, match="instance_name 为空"):
            await SessionStore.from_memory_collection_async("test-mc")

    @pytest.mark.asyncio
    async def test_empty_credentials(self) -> None:
        """ValueError when both AK/SK environment variables are empty."""
        collection = self._make_mock_mc()

        with patch(
            self._GET_BY_NAME_ASYNC,
            new=AsyncMock(return_value=collection),
        ), patch.dict(
            "os.environ", self._EMPTY_AK_ENV, clear=False
        ), pytest.raises(ValueError, match="AK/SK 凭证为空"):
            await SessionStore.from_memory_collection_async("test-mc")


# ---------------------------------------------------------------------------
# Async method tests
# ---------------------------------------------------------------------------


def _make_async_mock_backend() -> MagicMock:
    """Create a mock OTSBackend whose ``*_async`` methods are AsyncMocks.

    Returns:
        A ``MagicMock(spec=OTSBackend)`` with every async method replaced by
        an ``AsyncMock``; methods listed in ``default_results`` resolve to the
        given value when awaited.
    """
    mock_backend = MagicMock(spec=OTSBackend)

    # Async methods created without an explicit return value.
    for method_name in (
        "init_tables_async",
        "init_core_tables_async",
        "init_state_tables_async",
        "init_search_index_async",
        "put_session_async",
        "delete_session_row_async",
        "update_session_async",
        "put_state_async",
        "delete_state_row_async",
    ):
        setattr(mock_backend, method_name, AsyncMock())

    # Async methods with an explicit default result. The dict is rebuilt on
    # every call, so the mutable defaults are never shared between backends.
    default_results = {
        "get_session_async": None,
        "list_sessions_async": [],
        "list_all_sessions_async": [],
        "search_sessions_async": ([], 0),
        "put_event_async": 1,
        "get_events_async": [],
        "delete_events_by_session_async": 0,
        "get_state_async": None,
    }
    for method_name, default in default_results.items():
        setattr(mock_backend, method_name, AsyncMock(return_value=default))

    return mock_backend


class TestInitMethodsAsync:
    """Tests for the async ``init_*`` methods."""

    @pytest.mark.asyncio
    async def test_init_tables_async(self) -> None:
        """``init_tables_async`` calls the backend counterpart once."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.init_tables_async()

        mock_backend.init_tables_async.assert_called_once()

    @pytest.mark.asyncio
    async def test_init_core_tables_async(self) -> None:
        """``init_core_tables_async`` calls the backend counterpart once."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.init_core_tables_async()

        mock_backend.init_core_tables_async.assert_called_once()

    @pytest.mark.asyncio
    async def test_init_state_tables_async(self) -> None:
        """``init_state_tables_async`` calls the backend counterpart once."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.init_state_tables_async()

        mock_backend.init_state_tables_async.assert_called_once()

    @pytest.mark.asyncio
    async def test_init_search_index_async(self) -> None:
        """``init_search_index_async`` calls the backend counterpart once."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.init_search_index_async()

        mock_backend.init_search_index_async.assert_called_once()


class TestCreateSessionAsync:
    """Tests for ``create_session_async``."""

    @pytest.mark.asyncio
    async def test_basic(self) -> None:
        """The session is populated and written through the backend once."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        created = await session_store.create_session_async("a", "u", "s")

        assert created.agent_id == "a"
        assert created.created_at > 0
        mock_backend.put_session_async.assert_called_once()

    @pytest.mark.asyncio
    async def test_with_optional_fields(self) -> None:
        """Optional fields are reflected on the returned session."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        created = await session_store.create_session_async(
            "a",
            "u",
            "s",
            is_pinned=True,
            summary="test",
            labels='["tag"]',
            framework="adk",
            extensions={"k": "v"},
        )

        assert created.is_pinned is True
        assert created.extensions == {"k": "v"}


class TestGetSessionAsync:
    """Tests for ``get_session_async``."""

    @pytest.mark.asyncio
    async def test_found(self) -> None:
        """The very object produced by the backend is returned."""
        mock_backend = _make_async_mock_backend()
        stored = ConversationSession("a", "u", "s", 100, 200)
        mock_backend.get_session_async.return_value = stored
        session_store = SessionStore(mock_backend)

        fetched = await session_store.get_session_async("a", "u", "s")

        assert fetched is stored

    @pytest.mark.asyncio
    async def test_not_found(self) -> None:
        """``None`` from the backend is passed through."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        fetched = await session_store.get_session_async("a", "u", "s")

        assert fetched is None


class TestListSessionsAsync:
    """Tests for ``list_sessions_async`` / ``list_all_sessions_async``."""

    @pytest.mark.asyncio
    async def test_list(self) -> None:
        """The limit is forwarded and ``order_desc=True`` is used."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.list_sessions_async("a", "u", limit=5)

        mock_backend.list_sessions_async.assert_called_once_with(
            "a", "u", limit=5, order_desc=True
        )

    @pytest.mark.asyncio
    async def test_list_all(self) -> None:
        """The limit is forwarded unchanged."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.list_all_sessions_async("a", limit=10)

        mock_backend.list_all_sessions_async.assert_called_once_with(
            "a", limit=10
        )


class TestSearchSessionsAsync:
    """Tests for ``search_sessions_async``."""

    @pytest.mark.asyncio
    async def test_basic(self) -> None:
        """The backend's default ``([], 0)`` result is returned."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        found, total_count = await session_store.search_sessions_async("a")

        assert found == []
        assert total_count == 0


class TestUpdateSessionAsync:
    """Tests for ``update_session_async``."""

    @pytest.mark.asyncio
    async def test_update_all_fields(self) -> None:
        """All optional fields are written; version 1 is stored as 2."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.update_session_async(
            "a",
            "u",
            "s",
            is_pinned=True,
            summary="new",
            labels='["t"]',
            extensions={"e": 1},
            version=1,
        )

        mock_backend.update_session_async.assert_called_once()
        columns = mock_backend.update_session_async.call_args.args[3]
        assert columns["version"] == 2
        assert columns["is_pinned"] is True
        assert columns["summary"] == "new"
        assert columns["labels"] == '["t"]'
        assert json.loads(columns["extensions"]) == {"e": 1}

    @pytest.mark.asyncio
    async def test_update_partial_fields(self) -> None:
        """Only the fields that were passed show up in the written columns."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.update_session_async(
            "a", "u", "s", is_pinned=True, version=0
        )

        columns = mock_backend.update_session_async.call_args.args[3]
        assert "is_pinned" in columns
        assert "summary" not in columns
        assert "labels" not in columns
        assert "extensions" not in columns

    @pytest.mark.asyncio
    async def test_update_no_optional_fields(self) -> None:
        """No optional field is passed, so none appears in the columns."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.update_session_async("a", "u", "s", version=0)

        columns = mock_backend.update_session_async.call_args.args[3]
        assert "is_pinned" not in columns
        assert "summary" not in columns


class TestDeleteSessionAsync:
    """Tests for the cascading ``delete_session_async``."""

    @pytest.mark.asyncio
    async def test_cascade(self) -> None:
        """Events, session state and the session row are each deleted once."""
        mock_backend = _make_async_mock_backend()
        mock_backend.delete_events_by_session_async.return_value = 3
        session_store = SessionStore(mock_backend)

        await session_store.delete_session_async("a", "u", "s")

        mock_backend.delete_events_by_session_async.assert_called_once()
        mock_backend.delete_state_row_async.assert_called_once_with(
            StateScope.SESSION, "a", "u", "s"
        )
        mock_backend.delete_session_row_async.assert_called_once_with(
            "a", "u", "s"
        )


class TestDeleteEventsAsync:
    """Tests for ``delete_events_async``."""

    @pytest.mark.asyncio
    async def test_basic(self) -> None:
        """The backend's deleted-row count is returned to the caller."""
        mock_backend = _make_async_mock_backend()
        mock_backend.delete_events_by_session_async.return_value = 5
        session_store = SessionStore(mock_backend)

        removed = await session_store.delete_events_async("a", "u", "s")

        assert removed == 5


class TestAppendEventAsync:
    """Tests for ``append_event_async``."""

    @pytest.mark.asyncio
    async def test_basic(self) -> None:
        """The event carries the backend seq_id and the session is updated."""
        mock_backend = _make_async_mock_backend()
        mock_backend.put_event_async.return_value = 42
        mock_backend.get_session_async.return_value = ConversationSession(
            "a", "u", "s", 100, 200, version=1
        )
        session_store = SessionStore(mock_backend)

        appended = await session_store.append_event_async(
            "a",
            "u",
            "s",
            "msg",
            {"key": "val"},
        )

        assert appended.seq_id == 42
        mock_backend.update_session_async.assert_called_once()

    @pytest.mark.asyncio
    async def test_session_not_found(self) -> None:
        """No session update is issued when the session does not exist."""
        mock_backend = _make_async_mock_backend()
        mock_backend.put_event_async.return_value = 1
        mock_backend.get_session_async.return_value = None
        session_store = SessionStore(mock_backend)

        appended = await session_store.append_event_async(
            "a", "u", "s", "msg", {}
        )

        assert appended.seq_id == 1
        mock_backend.update_session_async.assert_not_called()

    @pytest.mark.asyncio
    async def test_update_failure_ignored(self) -> None:
        """A failing session update does not prevent the event's return."""
        mock_backend = _make_async_mock_backend()
        mock_backend.put_event_async.return_value = 1
        mock_backend.get_session_async.return_value = ConversationSession(
            "a", "u", "s", 100, 200, version=0
        )
        mock_backend.update_session_async.side_effect = Exception("fail")
        session_store = SessionStore(mock_backend)

        appended = await session_store.append_event_async(
            "a", "u", "s", "msg", {}
        )

        assert appended.seq_id == 1

    @pytest.mark.asyncio
    async def test_with_raw_event(self) -> None:
        """``raw_event`` is kept on the returned event."""
        mock_backend = _make_async_mock_backend()
        mock_backend.put_event_async.return_value = 1
        mock_backend.get_session_async.return_value = None
        session_store = SessionStore(mock_backend)

        appended = await session_store.append_event_async(
            "a",
            "u",
            "s",
            "msg",
            {},
            raw_event='{"raw": true}',
        )

        assert appended.raw_event == '{"raw": true}'


class TestGetEventsAsync:
    """Tests for ``get_events_async`` / ``get_recent_events_async``."""

    @pytest.mark.asyncio
    async def test_get_events(self) -> None:
        """Events are read from the backend in FORWARD direction."""
        mock_backend = _make_async_mock_backend()
        session_store = SessionStore(mock_backend)

        await session_store.get_events_async("a", "u", "s")

        mock_backend.get_events_async.assert_called_once_with(
            "a", "u", "s", direction="FORWARD"
        )

    @pytest.mark.asyncio
    async def test_get_recent_events(self) -> None:
        """A newest-first backend answer is handed back oldest-first."""
        mock_backend = _make_async_mock_backend()
        newest_first = [
            ConversationEvent("a", "u", "s", 3, "msg"),
            ConversationEvent("a", "u", "s", 2, "msg"),
        ]
        mock_backend.get_events_async.return_value = newest_first
        session_store = SessionStore(mock_backend)

        recent = await session_store.get_recent_events_async("a", "u", "s", 2)

        assert recent[0].seq_id == 2
        assert recent[1].seq_id == 3


class TestStateAsync:
    """Async state-management tests."""
@pytest.mark.asyncio + async def test_get_session_state(self) -> None: + backend = _make_async_mock_backend() + backend.get_state_async.return_value = StateData(state={"k": "v"}) + + store = SessionStore(backend) + result = await store.get_session_state_async("a", "u", "s") + assert result == {"k": "v"} + + @pytest.mark.asyncio + async def test_get_session_state_empty(self) -> None: + backend = _make_async_mock_backend() + store = SessionStore(backend) + result = await store.get_session_state_async("a", "u", "s") + assert result == {} + + @pytest.mark.asyncio + async def test_get_app_state(self) -> None: + backend = _make_async_mock_backend() + backend.get_state_async.return_value = StateData(state={"app": True}) + + store = SessionStore(backend) + result = await store.get_app_state_async("a") + assert result == {"app": True} + + @pytest.mark.asyncio + async def test_get_app_state_empty(self) -> None: + backend = _make_async_mock_backend() + store = SessionStore(backend) + result = await store.get_app_state_async("a") + assert result == {} + + @pytest.mark.asyncio + async def test_get_user_state(self) -> None: + backend = _make_async_mock_backend() + backend.get_state_async.return_value = StateData(state={"user": True}) + + store = SessionStore(backend) + result = await store.get_user_state_async("a", "u") + assert result == {"user": True} + + @pytest.mark.asyncio + async def test_get_user_state_empty(self) -> None: + backend = _make_async_mock_backend() + store = SessionStore(backend) + result = await store.get_user_state_async("a", "u") + assert result == {} + + @pytest.mark.asyncio + async def test_update_session_state(self) -> None: + backend = _make_async_mock_backend() + store = SessionStore(backend) + await store.update_session_state_async("a", "u", "s", {"k": "v"}) + backend.put_state_async.assert_called_once() + + @pytest.mark.asyncio + async def test_update_app_state(self) -> None: + backend = _make_async_mock_backend() + store = SessionStore(backend) + 
await store.update_app_state_async("a", {"k": "v"}) + backend.put_state_async.assert_called_once() + + @pytest.mark.asyncio + async def test_update_user_state(self) -> None: + backend = _make_async_mock_backend() + store = SessionStore(backend) + await store.update_user_state_async("a", "u", {"k": "v"}) + backend.put_state_async.assert_called_once() + + +class TestGetMergedStateAsync: + """get_merged_state_async 测试。""" + + @pytest.mark.asyncio + async def test_all_levels(self) -> None: + backend = _make_async_mock_backend() + backend.get_state_async.side_effect = [ + StateData(state={"app": 1}), + StateData(state={"user": 2}), + StateData(state={"sess": 3}), + ] + + store = SessionStore(backend) + result = await store.get_merged_state_async("a", "u", "s") + assert result == {"app": 1, "user": 2, "sess": 3} + + +class TestApplyDeltaAsync: + """_apply_delta_async 增量更新逻辑测试。""" + + @pytest.mark.asyncio + async def test_first_write(self) -> None: + backend = _make_async_mock_backend() + store = SessionStore(backend) + + await store._apply_delta_async( + StateScope.SESSION, + "a", + "u", + "s", + {"key": "val", "null": None}, + ) + + backend.put_state_async.assert_called_once() + call_args = backend.put_state_async.call_args + assert call_args.kwargs["state"] == {"key": "val"} + assert call_args.kwargs["version"] == 0 + + @pytest.mark.asyncio + async def test_merge_update(self) -> None: + backend = _make_async_mock_backend() + backend.get_state_async.return_value = StateData( + state={"existing": "val", "to_remove": "old"}, + version=2, + ) + + store = SessionStore(backend) + await store._apply_delta_async( + StateScope.SESSION, + "a", + "u", + "s", + {"new": "val", "to_remove": None}, + ) + + call_args = backend.put_state_async.call_args + merged = call_args.kwargs["state"] + assert merged == {"existing": "val", "new": "val"} + assert call_args.kwargs["version"] == 2 diff --git a/tests/unittests/conversation_service/test_utils.py 
b/tests/unittests/conversation_service/test_utils.py new file mode 100644 index 0000000..ade7d2f --- /dev/null +++ b/tests/unittests/conversation_service/test_utils.py @@ -0,0 +1,147 @@ +"""conversation_service.utils 单元测试。 + +覆盖 convert_vpc_endpoint_to_public、nanoseconds_timestamp、 +serialize_state、deserialize_state、to_chunks、from_chunks。 +""" + +from __future__ import annotations + +import time + +import pytest + +from agentrun.conversation_service.utils import ( + convert_vpc_endpoint_to_public, + deserialize_state, + from_chunks, + MAX_COLUMN_SIZE, + nanoseconds_timestamp, + serialize_state, + to_chunks, +) + +# --------------------------------------------------------------------------- +# convert_vpc_endpoint_to_public +# --------------------------------------------------------------------------- + + +class TestConvertVpcEndpoint: + """VPC 地址转公网地址。""" + + def test_vpc_endpoint(self) -> None: + result = convert_vpc_endpoint_to_public( + "https://inst.cn-hangzhou.vpc.tablestore.aliyuncs.com" + ) + assert result == "https://inst.cn-hangzhou.ots.aliyuncs.com" + + def test_non_vpc_endpoint(self) -> None: + ep = "https://inst.cn-hangzhou.ots.aliyuncs.com" + assert convert_vpc_endpoint_to_public(ep) == ep + + def test_empty_string(self) -> None: + assert convert_vpc_endpoint_to_public("") == "" + + def test_other_domain(self) -> None: + ep = "https://example.com" + assert convert_vpc_endpoint_to_public(ep) == ep + + +# --------------------------------------------------------------------------- +# nanoseconds_timestamp +# --------------------------------------------------------------------------- + + +class TestNanosecondsTimestamp: + """纳秒时间戳。""" + + def test_returns_int(self) -> None: + ts = nanoseconds_timestamp() + assert isinstance(ts, int) + + def test_roughly_correct(self) -> None: + before = int(time.time() * 1_000_000_000) + ts = nanoseconds_timestamp() + after = int(time.time() * 1_000_000_000) + assert before <= ts <= after + + +# 
--------------------------------------------------------------------------- +# serialize_state / deserialize_state +# --------------------------------------------------------------------------- + + +class TestStateSerialization: + """状态序列化/反序列化。""" + + def test_roundtrip(self) -> None: + state = {"key": "value", "num": 42, "nested": {"a": [1, 2]}} + serialized = serialize_state(state) + deserialized = deserialize_state(serialized) + assert deserialized == state + + def test_unicode(self) -> None: + state = {"中文": "值"} + serialized = serialize_state(state) + assert "中文" in serialized + assert deserialize_state(serialized) == state + + def test_empty(self) -> None: + serialized = serialize_state({}) + assert deserialize_state(serialized) == {} + + +# --------------------------------------------------------------------------- +# to_chunks / from_chunks +# --------------------------------------------------------------------------- + + +class TestChunking: + """字符串分片/拼接。""" + + def test_small_data_single_chunk(self) -> None: + data = "hello" + chunks = to_chunks(data, max_size=100) + assert chunks == ["hello"] + assert from_chunks(chunks) == data + + def test_exact_size(self) -> None: + data = "abcdef" + chunks = to_chunks(data, max_size=6) + assert chunks == ["abcdef"] + assert from_chunks(chunks) == data + + def test_split_into_multiple_chunks(self) -> None: + data = "abcdefghij" # 10 chars + chunks = to_chunks(data, max_size=3) + assert chunks == ["abc", "def", "ghi", "j"] + assert from_chunks(chunks) == data + + def test_empty_string(self) -> None: + assert to_chunks("", max_size=10) == [] + assert from_chunks([]) == "" + + def test_max_size_one(self) -> None: + data = "abc" + chunks = to_chunks(data, max_size=1) + assert chunks == ["a", "b", "c"] + assert from_chunks(chunks) == data + + def test_invalid_max_size(self) -> None: + with pytest.raises(ValueError, match="max_size must be positive"): + to_chunks("data", max_size=0) + with pytest.raises(ValueError, 
match="max_size must be positive"): + to_chunks("data", max_size=-1) + + def test_default_max_size(self) -> None: + """默认使用 MAX_COLUMN_SIZE。""" + data = "x" * 10 + chunks = to_chunks(data) + assert len(chunks) == 1 + assert MAX_COLUMN_SIZE == 1_500_000 + + def test_large_data(self) -> None: + """模拟大数据分片场景。""" + data = "a" * 100 + chunks = to_chunks(data, max_size=30) + assert len(chunks) == 4 # 30+30+30+10 + assert from_chunks(chunks) == data diff --git a/tests/unittests/integration/test_langchain_agui_integration.py b/tests/unittests/integration/test_langchain_agui_integration.py index 6cfb32b..b5dba63 100644 --- a/tests/unittests/integration/test_langchain_agui_integration.py +++ b/tests/unittests/integration/test_langchain_agui_integration.py @@ -664,7 +664,9 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", + "content": ( + "查询当前的时间,并获取天气信息,同时输出我的密钥信息" + ), }], "stream": True, }, @@ -727,7 +729,9 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", + "content": ( + "查询当前的时间,并获取天气信息,同时输出我的密钥信息" + ), }], "stream": True, }, diff --git a/tests/unittests/toolset/api/test_openapi.py b/tests/unittests/toolset/api/test_openapi.py index 3a60866..ab9acc6 100644 --- a/tests/unittests/toolset/api/test_openapi.py +++ b/tests/unittests/toolset/api/test_openapi.py @@ -545,7 +545,9 @@ def test_post_with_ref_schema(self): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateOrderRequest" + "$ref": ( + "#/components/schemas/CreateOrderRequest" + ) } } }, @@ -756,7 +758,9 @@ def test_invalid_ref_gracefully_handled(self): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/NonExistent" + "$ref": ( + "#/components/schemas/NonExistent" + ) } } } @@ -789,7 +793,9 @@ def test_external_ref_not_resolved(self): "content": { "application/json": { "schema": { - "$ref": 
"https://example.com/schemas/external.json" + "$ref": ( + "https://example.com/schemas/external.json" + ) } } } @@ -909,7 +915,9 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateOrderRequest" + "$ref": ( + "#/components/schemas/CreateOrderRequest" + ) } } }, @@ -945,7 +953,9 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateOrderStatusRequest" + "$ref": ( + "#/components/schemas/UpdateOrderStatusRequest" + ) } } },