From 7bef549dc8642acae40864e517e9dc6f4def4754 Mon Sep 17 00:00:00 2001 From: Ashish-dwi99 Date: Wed, 11 Feb 2026 21:37:18 +0530 Subject: [PATCH 1/2] new archt --- .cursor/rules/engram-continuity.mdc | 12 + .gitignore | 2 - README.md | 40 +- docs/pdf/manifest.json | 840 +++++++++++++++ engram/api/app.py | 109 +- engram/api/schemas.py | 32 +- engram/benchmarks/__init__.py | 2 + engram/benchmarks/longmemeval.py | 418 ++++++++ engram/configs/base.py | 5 +- engram/core/handoff_backend.py | 445 ++++++++ engram/core/handoff_bus.py | 194 +++- engram/core/kernel.py | 260 +++-- engram/core/profile.py | 41 +- engram/core/scene.py | 62 +- engram/db/sqlite.py | 442 ++++++-- engram/embeddings/nvidia.py | 2 +- engram/llms/nvidia.py | 2 +- engram/mcp_server.py | 1373 ++++++++++++++----------- engram/memory/client.py | 46 + engram/memory/main.py | 711 +++++++------ engram/retrieval/dual_search.py | 33 +- engram/retrieval/reranker.py | 93 +- engram/utils/math.py | 29 + launch-article.md | 222 ++++ pitch-deck.md | 427 ++++++++ pyproject.toml | 5 +- scripts/__init__.py | 1 + scripts/build_doc_book.py | 180 ++++ scripts/docgen/__init__.py | 18 + scripts/docgen/analyze.py | 1094 ++++++++++++++++++++ scripts/docgen/render_pdf.py | 230 +++++ scripts/generate_deep_docs.py | 330 ++++++ tests/test_agent_policies.py | 72 ++ tests/test_api_v2.py | 105 ++ tests/test_backward_compat.py | 162 +++ tests/test_cosine_similarity.py | 54 + tests/test_docgen_analyze.py | 147 +++ tests/test_docgen_pipeline.py | 192 ++++ tests/test_dual_retrieval.py | 101 ++ tests/test_handoff.py | 496 +++++++++ tests/test_handoff_api_compat.py | 131 +++ tests/test_handoff_hosted_backend.py | 50 + tests/test_mcp_handoff_lifecycle.py | 123 +++ tests/test_mcp_tool_dispatch.py | 39 + tests/test_memory_client_v2.py | 71 ++ tests/test_migration.py | 120 +++ tests/test_namespace_system.py | 76 ++ tests/test_policy_masking.py | 72 ++ tests/test_profile.py | 205 ++++ tests/test_refaware_decay.py | 45 + tests/test_scene.py | 217 ++++ tests/test_security_handoff_strict.py | 33 + tests/test_security_sessions.py | 242 +++++ tests/test_sleep_cycle.py | 47 + tests/test_sqlite_connection_pool.py | 174 ++++ tests/test_staging.py | 299 ++++++ tests/test_trust_scores.py | 131 +++ tests/testlongmemeval_runner.py | 175 ++++ 58 files changed, 10084 insertions(+), 1195 deletions(-) create mode 100644 .cursor/rules/engram-continuity.mdc create mode 100644 docs/pdf/manifest.json create mode 100644 engram/benchmarks/__init__.py create mode 100644 engram/benchmarks/longmemeval.py create mode 100644 engram/core/handoff_backend.py create mode 100644 engram/utils/math.py create mode 100644 launch-article.md create mode 100644 pitch-deck.md create mode 100644 scripts/__init__.py create mode 100644 scripts/build_doc_book.py create mode 100644 scripts/docgen/__init__.py create mode 100644 scripts/docgen/analyze.py create mode 100644 scripts/docgen/render_pdf.py create mode 100644 scripts/generate_deep_docs.py create mode 100644 tests/test_agent_policies.py create mode 100644 tests/test_api_v2.py create mode 100644 tests/test_backward_compat.py create mode 100644 tests/test_cosine_similarity.py create mode 100644 tests/test_docgen_analyze.py create mode 100644 tests/test_docgen_pipeline.py create mode 100644 tests/test_dual_retrieval.py create mode 100644 tests/test_handoff.py create mode 100644 tests/test_handoff_api_compat.py create mode 100644 tests/test_handoff_hosted_backend.py create mode 100644 tests/test_mcp_handoff_lifecycle.py create mode 100644 tests/test_mcp_tool_dispatch.py create mode 100644 tests/test_memory_client_v2.py create mode 100644 tests/test_migration.py create mode 100644 tests/test_namespace_system.py create mode 100644 tests/test_policy_masking.py create mode 100644 tests/test_profile.py create mode 100644 tests/test_refaware_decay.py create mode 100644 tests/test_scene.py create mode 100644 tests/test_security_handoff_strict.py create mode 100644 tests/test_security_sessions.py create mode 100644 tests/test_sleep_cycle.py create mode 100644 tests/test_sqlite_connection_pool.py create mode 100644 tests/test_staging.py create mode 100644 tests/test_trust_scores.py create mode 100644 tests/testlongmemeval_runner.py diff --git a/.cursor/rules/engram-continuity.mdc b/.cursor/rules/engram-continuity.mdc new file mode 100644 index 0000000..22f46e1 --- /dev/null +++ b/.cursor/rules/engram-continuity.mdc @@ -0,0 +1,12 @@ +--- +description: "Use Engram handoff tools automatically for continuity" +alwaysApply: true +--- + +When an Engram MCP server is available: + +1) At the start of a new task/thread, call `get_last_session` first. + Do not pass `agent_id` unless the user explicitly asks for a specific source agent. +2) Use the returned handoff context to continue work naturally. +3) Before pausing or ending, call `save_session_digest`. +4) Do not use shell/SQLite probing for continuity when MCP handoff tools exist. diff --git a/.gitignore b/.gitignore index e44b78d..b2d1195 100644 --- a/.gitignore +++ b/.gitignore @@ -97,8 +97,6 @@ $RECYCLE.BIN/ htmlcov/ .tox/ .hypothesis/ -test_*.py -*_test.py # Claude Code .claude/ diff --git a/README.md b/README.md index 220d455..49dcc9f 100644 --- a/README.md +++ b/README.md @@ -251,7 +251,7 @@ Scene: "Engram v2 architecture session" #### Handoff Bus — Cross-Agent Continuity -When an agent pauses work — rate limit, crash, you switch tools — it saves a session digest: task summary, decisions made, files touched, remaining TODOs, blockers. The next agent calls `get_last_session` and gets the full context. No re-explanation needed. +Engram now defaults to a zero-intervention continuity model: MCP adapters automatically request resume context before tool execution and auto-write checkpoints on lifecycle events (`tool_complete`, `agent_pause`, `agent_end`). The legacy tools (`save_session_digest`, `get_last_session`, `list_sessions`) remain available for compatibility. ``` Claude Code (rate limited) @@ -264,6 +264,9 @@ Codex (picks up) → Continues where Claude Code stopped ``` +**Default runtime model:** hosted session bus when `ENGRAM_API_URL` is configured. +**Default security posture:** strict handoff auth (`read_handoff` / `write_handoff`) with no implicit trusted-agent bootstrap. + --- ### Key Flows @@ -329,6 +332,14 @@ project directory) so agents call handoff tools automatically: - `CURSOR.md` - `.cursor/rules/engram-continuity.mdc` +For hosted continuity, set: + +```bash +export ENGRAM_API_URL="http://127.0.0.1:8100" +export ENGRAM_ADMIN_KEY="" # for session minting +export ENGRAM_API_KEY="" # if your hosted API expects bearer auth +``` + Set `ENGRAM_INSTALL_SKIP_WORKSPACE_RULES=1` to disable this behavior. **MCP tools** give Claude reactive memory — it stores and retrieves when you ask. @@ -416,6 +427,9 @@ Once configured, your agent has access to these tools: | `get_last_session` | Load session context from the last active agent | | `list_sessions` | Browse handoff history across agents | +Auto-lifecycle behavior is server-driven: when `auto_session_bus` is enabled, +Engram writes handoff checkpoints without explicit user prompts. + --- ## API & SDK @@ -483,6 +497,30 @@ curl -X POST http://localhost:8100/v1/handoff/checkpoint \ curl "http://localhost:8100/v1/handoff/lanes?user_id=u123" ``` +### Handoff Compatibility Routes + +Hosted integrations can call compatibility routes that mirror MCP handoff tools: + +```bash +# Save digest (compat) +curl -X POST http://localhost:8100/v1/handoff/sessions/digest \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{"user_id":"u123","agent_id":"codex","task_summary":"Continue task","repo":"/repo"}' + +# Get last session (compat) +curl "http://localhost:8100/v1/handoff/sessions/last?user_id=u123&repo=/repo" + +# List sessions (compat) +curl "http://localhost:8100/v1/handoff/sessions?user_id=u123&repo=/repo&limit=20" +``` + +### Continuity Troubleshooting + +- `hosted_backend_unavailable`: verify `ENGRAM_API_URL` and network reachability. +- `missing_or_expired_token` / `missing_capability`: ensure the caller has a valid session token with `read_handoff` or `write_handoff`. +- `Storage folder ... qdrant is already accessed`: local Qdrant file mode is single-process; use hosted API mode or a shared Qdrant server for concurrent agents. + ### Python SDK ```python diff --git a/docs/pdf/manifest.json b/docs/pdf/manifest.json new file mode 100644 index 0000000..3b7abb6 --- /dev/null +++ b/docs/pdf/manifest.json @@ -0,0 +1,840 @@ +{ + "generated_at": "2026-02-11T10:08:17+00:00", + "repo_root": "/Users/ashish.dwivedi/Desktop/Engram", + "commit_hash": "888ba8306db4fabb8bff702f768cd69a060d6bf3", + "doc_depth": "deep", + "method": "deterministic_static", + "file_count": 83, + "items": [ + { + "source_path": "Dockerfile", + "source_sha256": "28ba9d941f7844ed1a882db8760d3f352e26d16e7c65dcf1e7ec276c60e2f21f", + "output_pdf": "files/Dockerfile.pdf", + "line_count": 9, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "docker-compose.yml", + "source_sha256": "d7eb28b7a409db9db4b107ef577761f7e5690cf61a10ebc425cf21c5c6320c04", + "output_pdf": "files/docker-compose.yml.pdf", + "line_count": 15, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/__init__.py", + "source_sha256": "4f58de4a769463c9be5a8f13ba2c590bbce0a803a60780134eeb0edc733c7158", + "output_pdf": "files/engram____init__.py.pdf", + "line_count": 47, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/api/__init__.py", + "source_sha256": "7bb0bb6ab09fad3067d803c79393dc7d2c58e1bcaad8d3c44258fc9938ebed49", + "output_pdf": "files/engram__api____init__.py.pdf", + "line_count": 6, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/api/app.py", + "source_sha256": "b3c5e3a4d2bd8fa1b02d68f5296b632ec75d116f41cad2094478c35ccb7b4783", + "output_pdf": "files/engram__api__app.py.pdf", + "line_count": 938, + "page_count": 8, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/api/auth.py", + "source_sha256": "b2838225fab23bc891dca54fea97174d0654c0eb62c26617d79ee52620195a11", + "output_pdf": "files/engram__api__auth.py.pdf", + "line_count": 76, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/api/schemas.py", + "source_sha256": "f557d58a140a00404fb72ff5d91bd8bd1f0694e442c4daff53b74a9ee429be83", + "output_pdf": "files/engram__api__schemas.py.pdf", + "line_count": 168, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/api/server.py", + "source_sha256": "dd87162b31ebd022bea0c585ca398c70fdaba53c97abe6e6f925f9cf09991de1", + "output_pdf": "files/engram__api__server.py.pdf", + "line_count": 34, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/api/static/dashboard.html", + "source_sha256": "be7058fb2173f9641c10bc5b8a908fb2af0dcfae939fa05256b2da4b69454f7a", + "output_pdf": "files/engram__api__static__dashboard.html.pdf", + "line_count": 736, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/cli.py", + "source_sha256": "cde4e065fc7f284383a6882fdd17a25aecbc116624a298d39dc1c7ffe327ede4", + "output_pdf": "files/engram__cli.py.pdf", + "line_count": 547, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/configs/__init__.py", + "source_sha256": "222b7d6b8e432a10a61748b7757b3db32d82a72808820a7520f40c947b879822", + "output_pdf": "files/engram__configs____init__.py.pdf", + "line_count": 19, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/configs/base.py", + "source_sha256": "670e205d2dc35b7ab6847851e1d34637e18365369b7be78435241140e5615bbd", + "output_pdf": "files/engram__configs__base.py.pdf", + "line_count": 192, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/__init__.py", + "source_sha256": "df5557e5f818c74c20e359d59d6130a3b19b2a2d659970607d346ef7634e862e", + "output_pdf": "files/engram__core____init__.py.pdf", + "line_count": 22, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/acceptance.py", + "source_sha256": "eb6b6da3b969b8d274cf3f187ee3e93c3eb4050e837c9643d80cc26259babc50", + "output_pdf": "files/engram__core__acceptance.py.pdf", + "line_count": 142, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/category.py", + "source_sha256": "70489214481e688e068627ec7df25f9dfe1446bcd52c3922283a1cd554055165", + "output_pdf": "files/engram__core__category.py.pdf", + "line_count": 741, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/conflict.py", + "source_sha256": "fb6719966ecb65090c0e4b7c95133cf5e520338666508c11da4c60c7c27a6d4b", + "output_pdf": "files/engram__core__conflict.py.pdf", + "line_count": 41, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/decay.py", + "source_sha256": "8cb7f02729660f06b81443820062030cb89317015d9050e21e6906881076f888", + "output_pdf": "files/engram__core__decay.py.pdf", + "line_count": 33, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/echo.py", + "source_sha256": "3c604cffe4dd2d87f670d25390fd91dcc9d1b2ff97ca6917c14bc0d7347d5510", + "output_pdf": "files/engram__core__echo.py.pdf", + "line_count": 470, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/fusion.py", + "source_sha256": "57ad524df251af330356ead9140781632224482ec792cc973629714d90bb7b1b", + "output_pdf": "files/engram__core__fusion.py.pdf", + "line_count": 43, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/graph.py", + "source_sha256": "8e93374af53975e0153adc95498a944aaf41b9cc0c5883bfed9ce524149e864a", + "output_pdf": "files/engram__core__graph.py.pdf", + "line_count": 470, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/handoff.py", + "source_sha256": "736a8d8828eb55c710f2adcac077cc38c44f096f1535965acc7e0d42c76f7b99", + "output_pdf": "files/engram__core__handoff.py.pdf", + "line_count": 193, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/handoff_bus.py", + "source_sha256": "9bc95ed3fcbd71aa1bfe642afa0e791bad2f1e07b83e95a6f82577d5cff22f8c", + "output_pdf": "files/engram__core__handoff_bus.py.pdf", + "line_count": 975, + "page_count": 6, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/invariants.py", + "source_sha256": "7620458858b939507613e833eacb298e0ada79c9b34aa62522424710d3c837f1", + "output_pdf": "files/engram__core__invariants.py.pdf", + "line_count": 110, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/kernel.py", + "source_sha256": "8f276aa2d4fc76e5c8a26bfda6497df125f799ab78ec5ee29f28135800ac3bfa", + "output_pdf": "files/engram__core__kernel.py.pdf", + "line_count": 1638, + "page_count": 7, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/policy.py", + "source_sha256": "e0daaba15a8678c14dba5e471f727db192bd27efaa9573372708e44ed322e8e0", + "output_pdf": "files/engram__core__policy.py.pdf", + "line_count": 149, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/profile.py", + "source_sha256": "3453b55bd34540440bbfdc3ccd475b659d898bb9983920abce44d4ee6fffdc1a", + "output_pdf": "files/engram__core__profile.py.pdf", + "line_count": 433, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/provenance.py", + "source_sha256": "a86017ed7e13c3a6e47ac1abd22fb3b9be8945523bc6f59429c9c82979293c9f", + "output_pdf": "files/engram__core__provenance.py.pdf", + "line_count": 40, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/retrieval.py", + "source_sha256": "65291ba3e9c5e3524dd16f6de398a6e239162dca6647994dec0dc496a3ff17fb", + "output_pdf": "files/engram__core__retrieval.py.pdf", + "line_count": 190, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/core/scene.py", + "source_sha256": "5bee40316bfdd14b734fd4d706196ac5d115c5f8b7c343adc9539b513a76e788", + "output_pdf": "files/engram__core__scene.py.pdf", + "line_count": 351, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/db/__init__.py", + "source_sha256": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "output_pdf": "files/engram__db____init__.py.pdf", + "line_count": 0, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/db/async_sqlite.py", + "source_sha256": "e7c9c9cf2df834e2ff4fb36fd907975a05a07232d1e68032a0889e7f6c6b00e5", + "output_pdf": "files/engram__db__async_sqlite.py.pdf", + "line_count": 439, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/db/sqlite.py", + "source_sha256": "7f520e1e773710e5f0527094322f8b74a4485dc90d634ce5b2a38ec161155097", + "output_pdf": "files/engram__db__sqlite.py.pdf", + "line_count": 3258, + "page_count": 8, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/decay/__init__.py", + "source_sha256": "dee6fffa964ec6bdfb0f64d363a084a8f9812a8f3e6285c720123716bb49a551", + "output_pdf": "files/engram__decay____init__.py.pdf", + "line_count": 5, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/decay/refcounts.py", + "source_sha256": "f9607b2f3af3dbf12a7fdfc5ce8da796cbd24c2ef852cee187a7752497f078d9", + "output_pdf": "files/engram__decay__refcounts.py.pdf", + "line_count": 45, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/__init__.py", + "source_sha256": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "output_pdf": "files/engram__embeddings____init__.py.pdf", + "line_count": 0, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/async_base.py", + "source_sha256": "d58164e9c04987f8d109088145249d91fa77703a4faf99878fb3ca11c108e308", + "output_pdf": "files/engram__embeddings__async_base.py.pdf", + "line_count": 125, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/base.py", + "source_sha256": "9037a576fd14c3d839711942fff65696e855a54ba08895acaa329582cb8ce2b8", + "output_pdf": "files/engram__embeddings__base.py.pdf", + "line_count": 11, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/gemini.py", + "source_sha256": "b48ded8e7c5d855f12791c98f16f8bb9b750fe5bbf55fc6d11670b6a60b43ca2", + "output_pdf": "files/engram__embeddings__gemini.py.pdf", + "line_count": 68, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/nvidia.py", + "source_sha256": "135f7aedc043e38cbcde0c9b1105247c72d2f732f540a0016964d5e978583fdb", + "output_pdf": "files/engram__embeddings__nvidia.py.pdf", + "line_count": 40, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/ollama.py", + "source_sha256": "8fd1de512439fa01ef73f51ccf7a2a77da6754c91823c7ce9e7998edbb3a02cd", + "output_pdf": "files/engram__embeddings__ollama.py.pdf", + "line_count": 66, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/openai.py", + "source_sha256": "2415d676ae5c636a29cd337cde42f09426ce980a4acb9fcf2c4af8e347da54ed", + "output_pdf": "files/engram__embeddings__openai.py.pdf", + "line_count": 18, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/embeddings/simple.py", + "source_sha256": "bfcf35374cd8435a5eecf7e1896811791b8f228ff6af6931e5b1261e23f80c6e", + "output_pdf": "files/engram__embeddings__simple.py.pdf", + "line_count": 27, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/exceptions.py", + "source_sha256": "112d9045bad56b258ea3481356a8d679ef1ddf50ee83365389ca636fda83a59c", + "output_pdf": "files/engram__exceptions.py.pdf", + "line_count": 19, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/integrations/__init__.py", + "source_sha256": "3b7f891a7a16ba409fd063a9da162b618f8de1533ce8a20808feec40e90861df", + "output_pdf": "files/engram__integrations____init__.py.pdf", + "line_count": 1, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/integrations/claude_code_plugin.py", + "source_sha256": "79d193779b3232fa947ca08625ba7081f2c33e87bc6d1ca1733da9e97888d053", + "output_pdf": "files/engram__integrations__claude_code_plugin.py.pdf", + "line_count": 532, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/integrations/openclaw.py", + "source_sha256": "088d3bd8ca835e13cf157d9ae009908a14697511bde3b916d7d5a51c0f63d6f9", + "output_pdf": "files/engram__integrations__openclaw.py.pdf", + "line_count": 177, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/__init__.py", + "source_sha256": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "output_pdf": "files/engram__llms____init__.py.pdf", + "line_count": 0, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/async_base.py", + "source_sha256": "96b7d1f85a569f7bde484c544958502083df681276414d4b4aafe9d23a97cd86", + "output_pdf": "files/engram__llms__async_base.py.pdf", + "line_count": 127, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/base.py", + "source_sha256": "37147c9b8091112ea0b1d04e4ab45c24bd77c7199d4a14ad059d428e8baef34f", + "output_pdf": "files/engram__llms__base.py.pdf", + "line_count": 11, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/gemini.py", + "source_sha256": "9be76ab523695706bbf17a62b4fe2d8a691284f59e97f9ddd41a7c7c80869f26", + "output_pdf": "files/engram__llms__gemini.py.pdf", + "line_count": 81, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/mock.py", + "source_sha256": "757cee8c788e2b7e5103f55c7753d6149fa38971562913bdb11ccb46786f698a", + "output_pdf": "files/engram__llms__mock.py.pdf", + "line_count": 35, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/nvidia.py", + "source_sha256": "714abc3223fc2d077075bae434ac9209cd9bbb5546b044dbecde78b36806abcb", + "output_pdf": "files/engram__llms__nvidia.py.pdf", + "line_count": 47, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/ollama.py", + "source_sha256": "2467c42a605aa5d8c55d56687c26c967152cf82fb63029ed5c302ec9273ff707", + "output_pdf": "files/engram__llms__ollama.py.pdf", + "line_count": 58, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/llms/openai.py", + "source_sha256": "90e6fcd483343d3b1b74e76815c5c4929e79d61dc825a0bd1891bf6758c006f7", + "output_pdf": "files/engram__llms__openai.py.pdf", + "line_count": 25, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/main_cli.py", + "source_sha256": "a95649e7f3ea0283debb92d6db3a4875a41a06803cb37cf8af4e7312721b0aa2", + "output_pdf": "files/engram__main_cli.py.pdf", + "line_count": 455, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/mcp_server.py", + "source_sha256": "5bf7fcd270133d9a6bec7f2beaa61a3395767a2e037394f087b75362e9637757", + "output_pdf": "files/engram__mcp_server.py.pdf", + "line_count": 1859, + "page_count": 5, + "generated_at": "2026-02-11T10:08:17+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/__init__.py", + "source_sha256": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "output_pdf": "files/engram__memory____init__.py.pdf", + "line_count": 0, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/async_memory.py", + "source_sha256": "2264132ea37d854cbaf5adf984e42a3922348320e02cd80bd93a3e5066431e4f", + "output_pdf": "files/engram__memory__async_memory.py.pdf", + "line_count": 479, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/base.py", + "source_sha256": "57fd9b33a94b7dd44a71d9e06485257d8220b677cc0a46a00870d4c0f6a0f5ce", + "output_pdf": "files/engram__memory__base.py.pdf", + "line_count": 23, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/client.py", + "source_sha256": "b1cd82636b8e39093f0cb78c705881840103ebddd4751f94f7df76485ee4d7e5", + "output_pdf": "files/engram__memory__client.py.pdf", + "line_count": 418, + "page_count": 6, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/episodic_store.py", + "source_sha256": "ca86ad8bb6c22c82f464562ef81f56c337e39ed80beaf1753a808b39da685658", + "output_pdf": "files/engram__memory__episodic_store.py.pdf", + "line_count": 290, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/main.py", + "source_sha256": "eec46c92775f4b98aea9f1cb534078a11a72a40456b690cf03bb5c8ba282520e", + "output_pdf": "files/engram__memory__main.py.pdf", + "line_count": 2626, + "page_count": 9, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/staging_store.py", + "source_sha256": "4acebc3d6a0ea160f8c658293f43b1030220cc43f9255bac1d89f9d144bc39b0", + "output_pdf": "files/engram__memory__staging_store.py.pdf", + "line_count": 97, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/memory/utils.py", + "source_sha256": "559d94114ebd6bbbbb655c6777b6e8a72945eca40d0b0774c297b4610e69d1ba", + "output_pdf": "files/engram__memory__utils.py.pdf", + "line_count": 173, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/observability.py", + "source_sha256": "857d9f05e72e0de1af661602b8b228664fd9687327dd83b452aae0ad6c30d8fe", + "output_pdf": "files/engram__observability.py.pdf", + "line_count": 392, + "page_count": 6, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/retrieval/__init__.py", + "source_sha256": "7251bab64be6f447f42ce901411962e82ccdd86d9a5daf49fe2f9b77c65091c8", + "output_pdf": "files/engram__retrieval____init__.py.pdf", + "line_count": 5, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/retrieval/context_packer.py", + "source_sha256": "7a63bcd006084cf86eef129523f4cdbbb145a787a04a566bd04ffb468d841dda", + "output_pdf": "files/engram__retrieval__context_packer.py.pdf", + "line_count": 62, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/retrieval/dual_search.py", + "source_sha256": "491a428eaa24fcd006da972be73831ce293f5219f8371949779b37cfa3a8a0bd", + "output_pdf": "files/engram__retrieval__dual_search.py.pdf", + "line_count": 136, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/retrieval/reranker.py", + "source_sha256": "e03d5946cb6eb65266eb87f3da6c6a9df3925ad0d228136a2e8913061897ab72", + "output_pdf": "files/engram__retrieval__reranker.py.pdf", + "line_count": 37, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/simple.py", + "source_sha256": "d37dc95de7d1406fa2eb5e622b5703a73ae2e92e61f4055e678afcc6b70a7555", + "output_pdf": "files/engram__simple.py.pdf", + "line_count": 354, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/utils/__init__.py", + "source_sha256": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "output_pdf": "files/engram__utils____init__.py.pdf", + "line_count": 0, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/utils/factory.py", + "source_sha256": "f4b53bd52c9f796b8a009e583a2dfa49ec977f962c35f008f9d5f246b1855b63", + "output_pdf": "files/engram__utils__factory.py.pdf", + "line_count": 66, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/utils/prompts.py", + "source_sha256": "2e18b0155f3fe3278099fd4bc166b4f9aee2b0c1caedf1df86104a5037a0b9b0", + "output_pdf": "files/engram__utils__prompts.py.pdf", + "line_count": 183, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/utils/repo_identity.py", + "source_sha256": "aa426f61c673a02272bff2ec20f4967616e762691f9c851d877a3fdaa06949a9", + "output_pdf": "files/engram__utils__repo_identity.py.pdf", + "line_count": 72, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/vector_stores/__init__.py", + "source_sha256": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "output_pdf": "files/engram__vector_stores____init__.py.pdf", + "line_count": 0, + "page_count": 3, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/vector_stores/async_qdrant.py", + "source_sha256": "4b4607a557e8829aeb011563302a4279fb32d18902ea3b0abb33626d3cf128df", + "output_pdf": "files/engram__vector_stores__async_qdrant.py.pdf", + "line_count": 210, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/vector_stores/base.py", + "source_sha256": "c9e547779d919349429709b13f63380c84977f40ce1d82fec95446cc16461757", + "output_pdf": "files/engram__vector_stores__base.py.pdf", + "line_count": 48, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/vector_stores/memory.py", + "source_sha256": "b16e4ae73d2ce8d0a95f6f78c9321887b3ed2bfee9038d32cf3df005fc34b90d", + "output_pdf": "files/engram__vector_stores__memory.py.pdf", + "line_count": 101, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "engram/vector_stores/qdrant.py", + "source_sha256": "77d119d541c4fd8a36e74530ccacd6f10157f837421b78976086f83ddf35b5a5", + "output_pdf": "files/engram__vector_stores__qdrant.py.pdf", + "line_count": 227, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "plugins/engram-memory/.claude-plugin/plugin.json", + "source_sha256": "b559f37962bb39e415bdd53d7dbbaf429e1e9c7e3cc3ca08cb8335d3d0c1a192", + "output_pdf": "files/plugins__engram-memory__.claude-plugin__plugin.json.pdf", + "line_count": 7, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "plugins/engram-memory/hooks/hooks.json", + "source_sha256": "d296be7c0515ece5a1991245606f88846db2215927ff8e65fda97c81619de8fb", + "output_pdf": "files/plugins__engram-memory__hooks__hooks.json.pdf", + "line_count": 12, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "plugins/engram-memory/hooks/prompt_context.py", + "source_sha256": "5f285e3e6254c6a1f221151f7dd471623bde889d2c51404e776320ef5ed5bbc9", + "output_pdf": "files/plugins__engram-memory__hooks__prompt_context.py.pdf", + "line_count": 201, + "page_count": 5, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + }, + { + "source_path": "pyproject.toml", + "source_sha256": "dc5bd74055cd94b8551f5ece65dc5f8cfbe1453e3fb90b4e524888021d958685", + "output_pdf": "files/pyproject.toml.pdf", + "line_count": 84, + "page_count": 4, + "generated_at": "2026-02-11T10:07:54+00:00", + "doc_depth": "deep", + "method": "deterministic_static" + } + ] +} \ No newline at end of file diff --git a/engram/api/app.py b/engram/api/app.py index 86df251..3f258d2 100644 --- a/engram/api/app.py +++ b/engram/api/app.py @@ -28,8 +28,10 @@ CommitResolutionRequest, ConflictResolutionRequest, DailyDigestResponse, + HandoffStatus, HandoffCheckpointRequest, HandoffResumeRequest, + HandoffSessionDigestRequest, NamespaceDeclareRequest, NamespacePermissionRequest, SceneSearchRequest, @@ -50,6 +52,7 @@ class SearchResultResponse(BaseModel): results: List[Dict[str, Any]] count: int context_packet: Optional[Dict[str, Any]] = None + retrieval_trace: Optional[Dict[str, Any]] = None class StatsResponse(BaseModel): @@ -248,8 +251,8 @@ async def list_handoff_lanes( http_request: Request, user_id: str = Query(default="default"), repo_path: Optional[str] = Query(default=None), - status: Optional[str] = Query(default=None), - statuses: Optional[List[str]] = Query(default=None), + status: Optional[HandoffStatus] = Query(default=None), + statuses: Optional[List[HandoffStatus]] = Query(default=None), limit: int = Query(default=20, ge=1, le=200), requester_agent_id: Optional[str] = Query(default=None), ): @@ -268,6 +271,107 @@ async def list_handoff_lanes( return {"lanes": lanes, "count": len(lanes)} except PermissionError as exc: raise require_session_error(exc) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@app.post("/v1/handoff/sessions/digest") +@app.post("/v1/handoff/sessions/digest/") +async def save_handoff_session_digest(request: HandoffSessionDigestRequest, http_request: Request): + token = get_token_from_request(http_request) + kernel = get_kernel() + digest = { + "task_summary": request.task_summary, + "repo": request.repo, + "branch": request.branch, + "lane_id": request.lane_id, + "lane_type": request.lane_type, + "agent_role": request.agent_role, + "namespace": request.namespace, + "confidentiality_scope": request.confidentiality_scope, + "status": request.status, + "decisions_made": request.decisions_made, + "files_touched": request.files_touched, + "todos_remaining": request.todos_remaining, + "blockers": request.blockers, + "key_commands": request.key_commands, + "test_results": request.test_results, + "context_snapshot": request.context_snapshot, + "started_at": request.started_at, + "ended_at": request.ended_at, + } + try: + return kernel.save_session_digest( + user_id=request.user_id, + agent_id=request.agent_id, + digest=digest, + token=token, + requester_agent_id=request.requester_agent_id or request.agent_id, + ) + except PermissionError as exc: + raise require_session_error(exc) + + +@app.get("/v1/handoff/sessions/last") +@app.get("/v1/handoff/sessions/last/") +async def get_handoff_last_session( + http_request: Request, + user_id: str = Query(default="default"), + agent_id: Optional[str] = Query(default=None), + requester_agent_id: Optional[str] = Query(default=None), + repo: Optional[str] = Query(default=None), + statuses: Optional[List[HandoffStatus]] = Query(default=None), +): + token = get_token_from_request(http_request) + kernel = get_kernel() + try: + session = kernel.get_last_session( + user_id=user_id, + agent_id=agent_id, + repo=repo, + statuses=statuses, + token=token, + requester_agent_id=requester_agent_id or agent_id, + ) + if session: + return session + return {"error": "No sessions found"} + except PermissionError as exc: + raise require_session_error(exc) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@app.get("/v1/handoff/sessions") +@app.get("/v1/handoff/sessions/") +async def list_handoff_sessions( + http_request: Request, + user_id: str = Query(default="default"), + agent_id: Optional[str] = Query(default=None), + requester_agent_id: Optional[str] = Query(default=None), + repo: Optional[str] = Query(default=None), + status: Optional[HandoffStatus] = Query(default=None), + statuses: Optional[List[HandoffStatus]] = Query(default=None), + limit: int = Query(default=20, ge=1, le=200), +): + token = get_token_from_request(http_request) + kernel = get_kernel() + try: + sessions = kernel.list_sessions( + user_id=user_id, + agent_id=agent_id, + repo=repo, + status=status, + statuses=statuses, + limit=limit, + token=token, + requester_agent_id=requester_agent_id or agent_id, + ) + return {"sessions": sessions, "count": len(sessions)} + except PermissionError as exc: + raise require_session_error(exc) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) @app.post("/v1/search", response_model=SearchResultResponse) @@ -293,6 +397,7 @@ async def search_memories(request: SearchRequestV2, http_request: Request): results=results, count=len(results), context_packet=payload.get("context_packet"), + retrieval_trace=payload.get("retrieval_trace"), ) except PermissionError as exc: raise require_session_error(exc) diff --git a/engram/api/schemas.py b/engram/api/schemas.py index a9944ce..e06ebb0 100644 --- a/engram/api/schemas.py +++ b/engram/api/schemas.py @@ -2,12 +2,14 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field from engram.core.policy import ALL_CONFIDENTIALITY_SCOPES, DEFAULT_CAPABILITIES +HandoffStatus = Literal["active", "paused", "completed", "abandoned"] + class SessionCreateRequest(BaseModel): user_id: str = Field(default="default") @@ -37,7 +39,7 @@ class HandoffResumeRequest(BaseModel): objective: Optional[str] = Field(default=None) agent_role: Optional[str] = Field(default=None) namespace: str = Field(default="default") - statuses: Optional[List[str]] = Field(default=None) + statuses: Optional[List[HandoffStatus]] = Field(default=None) auto_create: bool = Field(default=True) @@ -54,7 +56,7 @@ class HandoffCheckpointRequest(BaseModel): namespace: str = Field(default="default") confidentiality_scope: str = Field(default="work") event_type: str = Field(default="tool_complete") - status: str = Field(default="active") + status: HandoffStatus = Field(default="active") task_summary: Optional[str] = Field(default=None) decisions_made: List[str] = Field(default_factory=list) files_touched: List[str] = Field(default_factory=list) @@ -66,6 +68,30 @@ class HandoffCheckpointRequest(BaseModel): expected_version: Optional[int] = Field(default=None) +class HandoffSessionDigestRequest(BaseModel): + user_id: str = Field(default="default") + agent_id: str = Field(default="claude-code") + requester_agent_id: Optional[str] = Field(default=None) + task_summary: str + repo: Optional[str] = Field(default=None) + branch: Optional[str] = Field(default=None) + lane_id: Optional[str] = Field(default=None) + lane_type: Optional[str] = Field(default=None) + agent_role: Optional[str] = Field(default=None) + namespace: str = Field(default="default") + confidentiality_scope: str = Field(default="work") + status: HandoffStatus = Field(default="paused") + decisions_made: List[str] = Field(default_factory=list) + files_touched: List[str] = Field(default_factory=list) + todos_remaining: List[str] = Field(default_factory=list) + blockers: List[str] = Field(default_factory=list) + key_commands: List[str] = Field(default_factory=list) + test_results: List[str] = Field(default_factory=list) + context_snapshot: Optional[str] = Field(default=None) + started_at: Optional[str] = Field(default=None) + ended_at: Optional[str] = Field(default=None) + + class SearchRequestV2(BaseModel): query: str user_id: str = Field(default="default") diff --git a/engram/benchmarks/__init__.py b/engram/benchmarks/__init__.py new file mode 100644 index 0000000..05b61e5 --- /dev/null +++ b/engram/benchmarks/__init__.py @@ -0,0 +1,2 @@ +"""Benchmark runners for Engram.""" + diff --git a/engram/benchmarks/longmemeval.py b/engram/benchmarks/longmemeval.py new file mode 100644 index 0000000..70ef2b9 --- /dev/null +++ b/engram/benchmarks/longmemeval.py @@ -0,0 +1,418 @@ +"""LongMemEval runner for Engram (Colab-friendly). + +Usage: + python -m engram.benchmarks.longmemeval --dataset-path ... --output-jsonl ... +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +from dataclasses import dataclass +from pathlib import Path +from statistics import mean +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from engram import Memory +from engram.configs.base import ( + CategoryMemConfig, + EchoMemConfig, + EmbedderConfig, + KnowledgeGraphConfig, + LLMConfig, + MemoryConfig, + ProfileConfig, + SceneConfig, + VectorStoreConfig, +) + + +SESSION_ID_PATTERN = re.compile(r"^Session ID:\s*(?P\S+)\s*$", re.MULTILINE) +HISTORY_HEADER = "User Transcript:" + + +def extract_user_only_text(session_turns: Sequence[Dict[str, Any]]) -> str: + """Convert one LongMemEval session into newline-separated user text.""" + lines = [str(turn.get("content", "")).strip() for turn in session_turns if turn.get("role") == "user"] + return "\n".join([line for line in lines if line]) + + +def format_session_memory(session_id: str, session_date: str, session_turns: Sequence[Dict[str, Any]]) -> str: + """Create a memory payload that preserves session metadata in plain text.""" + user_text = extract_user_only_text(session_turns) + return ( + f"Session ID: {session_id}\n" + f"Session Date: {session_date}\n" + f"{HISTORY_HEADER}\n" + f"{user_text}" + ) + + +def parse_session_id_from_result(result: Dict[str, Any]) -> Optional[str]: + """Extract session_id from memory metadata or fallback text header.""" + metadata = result.get("metadata") or {} + sid = metadata.get("session_id") + if sid: + return str(sid) + memory_text = str(result.get("memory", "") or "") + match = SESSION_ID_PATTERN.search(memory_text) + if match: + return match.group("session_id") + return None + + +def dedupe_preserve_order(items: Iterable[str]) -> List[str]: + seen = set() + ordered: List[str] = [] + for item in items: + if item in seen: + continue + seen.add(item) + ordered.append(item) + return ordered + + +def compute_session_metrics(retrieved_session_ids: Sequence[str], answer_session_ids: Sequence[str]) -> Dict[str, float]: + """Compute simple retrieval metrics over session IDs.""" + retrieved = dedupe_preserve_order([str(x) for x in retrieved_session_ids if str(x).strip()]) + gold = {str(x) for x in answer_session_ids if str(x).strip()} + + metrics: Dict[str, float] = {} + for k in (1, 3, 5, 10): + top_k = set(retrieved[:k]) + metrics[f"recall_any@{k}"] = 1.0 if gold and bool(top_k & gold) else 0.0 + metrics[f"recall_all@{k}"] = 1.0 if gold and gold.issubset(top_k) else 0.0 + return metrics + + +def build_answer_prompt(question: str, retrieved_context: str) -> str: + return ( + "You are answering a LongMemEval memory question.\n" + "Use only the retrieved history context. If the answer is missing, say: " + "\"I don't have enough information in memory.\"\n\n" + "Retrieved history:\n" + f"{retrieved_context}\n\n" + f"Question: {question}\n" + "Answer concisely:" + ) + + +@dataclass +class HFResponder: + model_name: str + max_new_tokens: int = 128 + + def __post_init__(self) -> None: + # Lazy heavy import so module-level import stays lightweight. + try: + from transformers import pipeline + except ImportError as exc: + raise ImportError( + "HF backend requires transformers (and torch). " + "Install with: pip install transformers accelerate" + ) from exc + + self._pipeline = pipeline( + "text-generation", + model=self.model_name, + tokenizer=self.model_name, + device_map="auto", + model_kwargs={"torch_dtype": "auto"}, + ) + + def generate(self, prompt: str) -> str: + outputs = self._pipeline( + prompt, + max_new_tokens=self.max_new_tokens, + do_sample=False, + return_full_text=False, + ) + if not outputs: + return "" + text = outputs[0].get("generated_text", "") + return str(text).strip() + + +def build_memory( + *, + llm_provider: str, + embedder_provider: str, + vector_store_provider: str, + embedding_dims: int, + history_db_path: str, + qdrant_path: Optional[str] = None, + llm_model: Optional[str] = None, + embedder_model: Optional[str] = None, + full_potential: bool = True, +) -> Memory: + """Build Engram Memory for LongMemEval. By default uses full potential (echo, categories, graph, scenes, profiles).""" + vector_cfg: Dict[str, Any] = { + "collection_name": "engram_longmemeval", + "embedding_model_dims": embedding_dims, + } + if vector_store_provider == "qdrant": + vector_cfg["path"] = qdrant_path or os.path.join(os.path.expanduser("~"), ".engram", "qdrant-longmemeval") + + llm_cfg: Dict[str, Any] = {} + if llm_model: + llm_cfg["model"] = llm_model + embedder_cfg: Dict[str, Any] = {"embedding_dims": embedding_dims} + if embedder_model: + embedder_cfg["model"] = embedder_model + + config = MemoryConfig( + vector_store=VectorStoreConfig(provider=vector_store_provider, config=vector_cfg), + llm=LLMConfig(provider=llm_provider, config=llm_cfg), + embedder=EmbedderConfig(provider=embedder_provider, config=embedder_cfg), + history_db_path=history_db_path, + embedding_model_dims=embedding_dims, + echo=EchoMemConfig(enable_echo=full_potential), + category=CategoryMemConfig(use_llm_categorization=full_potential, enable_categories=full_potential), + graph=KnowledgeGraphConfig(enable_graph=full_potential), + scene=SceneConfig(use_llm_summarization=full_potential, enable_scenes=full_potential), + profile=ProfileConfig(use_llm_extraction=full_potential, enable_profiles=full_potential), + ) + return Memory(config) + + +def build_context_text(results: Sequence[Dict[str, Any]], max_chars: int) -> str: + chunks: List[str] = [] + total = 0 + for result in results: + if result.get("masked"): + continue + text = str(result.get("memory") or result.get("details") or "").strip() + if not text: + continue + if total + len(text) > max_chars and chunks: + break + chunks.append(text) + total += len(text) + if not chunks: + return "No relevant retrieved history." + return "\n\n".join(chunks) + + +def build_output_row( + *, + question_id: str, + hypothesis: str, + retrieved_session_ids: Sequence[str], + retrieval_metrics: Dict[str, float], + include_debug_fields: bool, +) -> Dict[str, Any]: + """Build evaluator-compatible output row with optional debug fields.""" + row: Dict[str, Any] = { + "question_id": question_id, + "hypothesis": hypothesis, + } + if include_debug_fields: + row["retrieved_session_ids"] = list(retrieved_session_ids) + row["retrieval_metrics"] = dict(retrieval_metrics) + return row + + +def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: + with open(args.dataset_path, "r", encoding="utf-8") as f: + dataset = json.load(f) + if not isinstance(dataset, list): + raise ValueError("Dataset file must be a JSON list of instances.") + + selected = dataset[args.start_index : args.end_index if args.end_index > 0 else None] + if args.max_questions > 0: + selected = selected[: args.max_questions] + if args.skip_abstention: + selected = [entry for entry in selected if "_abs" not in str(entry.get("question_id", ""))] + + memory = build_memory( + llm_provider=args.llm_provider, + embedder_provider=args.embedder_provider, + vector_store_provider=args.vector_store_provider, + embedding_dims=args.embedding_dims, + history_db_path=args.history_db_path, + qdrant_path=args.qdrant_path, + llm_model=args.llm_model, + embedder_model=args.embedder_model, + full_potential=args.full_potential, + ) + + hf_responder: Optional[HFResponder] = None + if args.answer_backend == "hf": + hf_responder = HFResponder(model_name=args.hf_model, max_new_tokens=args.hf_max_new_tokens) + + output_path = Path(args.output_jsonl) + output_path.parent.mkdir(parents=True, exist_ok=True) + retrieval_path = Path(args.retrieval_jsonl) if args.retrieval_jsonl else None + if retrieval_path: + retrieval_path.parent.mkdir(parents=True, exist_ok=True) + + per_question_metrics: List[Dict[str, float]] = [] + processed = 0 + + with output_path.open("w", encoding="utf-8") as out_f: + retrieval_f = retrieval_path.open("w", encoding="utf-8") if retrieval_path else None + try: + for entry in selected: + question_id = str(entry.get("question_id", "")) + if not question_id: + continue + + # Keep each question isolated. + memory.delete_all(user_id=args.user_id) + + session_ids = entry.get("haystack_session_ids") or [] + session_dates = entry.get("haystack_dates") or [] + sessions = entry.get("haystack_sessions") or [] + for sess_id, sess_date, sess_turns in zip(session_ids, session_dates, sessions): + payload = format_session_memory(str(sess_id), str(sess_date), sess_turns or []) + memory.add( + messages=payload, + user_id=args.user_id, + metadata={ + "session_id": str(sess_id), + "session_date": str(sess_date), + "question_id": question_id, + }, + categories=["longmemeval", "session"], + infer=False, + ) + + query = str(entry.get("question", "")).strip() + search_payload = memory.search_with_context( + query=query, + user_id=args.user_id, + limit=args.top_k, + ) + results = search_payload.get("results", []) + + retrieved_session_ids = dedupe_preserve_order( + [ + sid + for sid in [parse_session_id_from_result(result) for result in results] + if sid is not None + ] + ) + metrics = compute_session_metrics( + retrieved_session_ids=retrieved_session_ids, + answer_session_ids=entry.get("answer_session_ids", []), + ) + per_question_metrics.append(metrics) + + context = build_context_text(results, max_chars=args.max_context_chars) + prompt = build_answer_prompt(question=query, retrieved_context=context) + + if args.answer_backend == "hf": + assert hf_responder is not None + hypothesis = hf_responder.generate(prompt) + else: + hypothesis = str(memory.llm.generate(prompt)).strip() + + output_row = build_output_row( + question_id=question_id, + hypothesis=hypothesis, + retrieved_session_ids=retrieved_session_ids[: args.top_k], + retrieval_metrics=metrics, + include_debug_fields=args.include_debug_fields, + ) + out_f.write(json.dumps(output_row, ensure_ascii=False) + "\n") + + if retrieval_f is not None: + retrieval_row = { + "question_id": question_id, + "answer_session_ids": entry.get("answer_session_ids", []), + "retrieved_session_ids": retrieved_session_ids[: args.top_k], + "metrics": metrics, + } + retrieval_f.write(json.dumps(retrieval_row, ensure_ascii=False) + "\n") + + processed += 1 + if args.print_every > 0 and processed % args.print_every == 0: + print(f"[LongMemEval] processed={processed} question_id={question_id}") + finally: + if retrieval_f is not None: + retrieval_f.close() + + aggregate: Dict[str, float] = {} + if per_question_metrics: + for key in sorted(per_question_metrics[0].keys()): + aggregate[key] = round(mean(metric[key] for metric in per_question_metrics), 4) + + summary = { + "processed": processed, + "output_jsonl": str(output_path), + "retrieval_jsonl": str(retrieval_path) if retrieval_path else None, + "aggregate_retrieval_metrics": aggregate, + "answer_backend": args.answer_backend, + "hf_model": args.hf_model if args.answer_backend == "hf" else None, + } + print(json.dumps(summary, indent=2)) + return summary + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Engram on LongMemEval in Colab or local environments.") + parser.add_argument("--dataset-path", required=True, help="Path to LongMemEval json file.") + parser.add_argument("--output-jsonl", required=True, help="Path to write question_id/hypothesis jsonl.") + parser.add_argument("--retrieval-jsonl", default=None, help="Optional path for retrieval-only log jsonl.") + parser.add_argument( + "--include-debug-fields", + action="store_true", + help="Include retrieval debug fields in output jsonl rows (official evaluator only needs question_id/hypothesis).", + ) + + parser.add_argument( + "--minimal", + action="store_true", + help="Disable echo, categories, graph, scenes, profiles (faster but lower retrieval quality). Default is full potential.", + ) + parser.add_argument("--user-id", default="longmemeval", help="User scope used for temporary ingestion.") + parser.add_argument("--start-index", type=int, default=0, help="Start offset for dataset slicing.") + parser.add_argument("--end-index", type=int, default=-1, help="End offset for dataset slicing (exclusive).") + parser.add_argument("--max-questions", type=int, default=-1, help="Cap number of evaluated questions.") + parser.add_argument("--skip-abstention", action="store_true", help="Skip *_abs questions.") + + parser.add_argument("--top-k", type=int, default=8, help="Number of retrieved memories for context.") + parser.add_argument("--max-context-chars", type=int, default=12000, help="Maximum context size passed to reader.") + parser.add_argument("--print-every", type=int, default=25, help="Progress print interval.") + + parser.add_argument( + "--answer-backend", + choices=["hf", "engram-llm"], + default="hf", + help="Reader backend for hypothesis generation.", + ) + parser.add_argument("--hf-model", default="Qwen/Qwen2.5-1.5B-Instruct", help="HF model when --answer-backend hf.") + parser.add_argument("--hf-max-new-tokens", type=int, default=128, help="Generation cap for HF backend.") + + parser.add_argument( + "--llm-provider", + choices=["mock", "gemini", "openai", "ollama", "nvidia"], + default="mock", + help="Engram LLM provider (used for --answer-backend engram-llm).", + ) + parser.add_argument("--llm-model", default=None, help="Optional LLM model override.") + parser.add_argument( + "--embedder-provider", + choices=["simple", "gemini", "openai", "ollama", "nvidia"], + default="simple", + help="Engram embedder provider for retrieval.", + ) + parser.add_argument("--embedder-model", default=None, help="Optional embedder model override.") + parser.add_argument("--embedding-dims", type=int, default=1536, help="Embedding dimensions for simple/memory configs.") + parser.add_argument("--vector-store-provider", choices=["memory", "qdrant"], default="memory") + parser.add_argument("--qdrant-path", default="/content/qdrant-longmemeval", help="Qdrant path when using qdrant.") + parser.add_argument("--history-db-path", default="/content/engram-longmemeval.db", help="SQLite db path.") + args = parser.parse_args() + args.full_potential = not args.minimal + return args + + +def main() -> None: + args = parse_args() + run_longmemeval(args) + + +if __name__ == "__main__": + main() diff --git a/engram/configs/base.py b/engram/configs/base.py index 24d08a1..ce7bba8 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -119,6 +119,9 @@ class HandoffConfig(BaseModel): enable_handoff: bool = True auto_enrich: bool = True # LLM-enrich digests with linked memories max_sessions_per_user: int = 100 # retain last N sessions + handoff_backend: str = "hosted" # hosted|local + strict_handoff_auth: bool = True + allow_auto_trusted_bootstrap: bool = False auto_session_bus: bool = True auto_checkpoint_events: List[str] = Field( default_factory=lambda: ["tool_complete", "agent_pause", "agent_end"] @@ -183,7 +186,7 @@ class MemoryConfig(BaseModel): echo: EchoMemConfig = Field(default_factory=EchoMemConfig) category: CategoryMemConfig = Field(default_factory=CategoryMemConfig) scope: ScopeConfig = Field(default_factory=ScopeConfig) - graph: KnowledgeGraphConfig = Field(default_factory=lambda: KnowledgeGraphConfig()) + graph: KnowledgeGraphConfig = Field(default_factory=KnowledgeGraphConfig) scene: SceneConfig = Field(default_factory=SceneConfig) profile: ProfileConfig = Field(default_factory=ProfileConfig) handoff: HandoffConfig = Field(default_factory=HandoffConfig) diff --git a/engram/core/handoff_backend.py b/engram/core/handoff_backend.py new file mode 100644 index 0000000..ac18fbd --- /dev/null +++ b/engram/core/handoff_backend.py @@ -0,0 +1,445 @@ +"""Handoff backend adapters for local and hosted continuity paths.""" + +from __future__ import annotations + +import os +from typing import Any, Dict, List, Optional + +from engram.memory.client import MemoryClient + + +class HandoffBackendError(RuntimeError): + """Structured handoff backend error.""" + + def __init__(self, code: str, message: str): + self.code = str(code) + self.message = str(message) + super().__init__(self.message) + + def to_dict(self) -> Dict[str, str]: + return {"code": self.code, "message": self.message} + + +def classify_handoff_error(exc: Exception) -> HandoffBackendError: + message = str(exc).strip() or exc.__class__.__name__ + lowered = message.lower() + + if "missing required capability" in lowered or "not allowed by policy" in lowered: + return HandoffBackendError("missing_capability", message) + if "capability token required" in lowered or "invalid capability token" in lowered or "session expired" in lowered: + return HandoffBackendError("missing_or_expired_token", message) + if " 401" in lowered or " 403" in lowered or "unauthorized" in lowered or "forbidden" in lowered: + return HandoffBackendError("missing_or_expired_token", message) + if "connection" in lowered or "timed out" in lowered or "max retries exceeded" in lowered: + return HandoffBackendError("hosted_backend_unavailable", message) + if "no matching lane found" in lowered or "unable to resolve or create handoff lane" in lowered: + return HandoffBackendError("lane_resolution_failed", message) + return HandoffBackendError("handoff_error", message) + + +class LocalHandoffBackend: + """Handoff adapter using in-process Memory APIs.""" + + def __init__(self, memory): + self.memory = memory + + def _session_token( + self, + *, + user_id: str, + requester_agent_id: Optional[str], + capabilities: List[str], + namespace: str, + ) -> str: + try: + session = self.memory.create_session( + user_id=user_id, + agent_id=requester_agent_id, + allowed_confidentiality_scopes=["work", "personal", "finance", "health", "private"], + capabilities=capabilities, + namespaces=[namespace], + ttl_minutes=24 * 60, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + token = session.get("token") + if not token: + raise HandoffBackendError("missing_or_expired_token", "Session token was not issued") + return token + + def save_session_digest( + self, + *, + user_id: str, + agent_id: str, + requester_agent_id: str, + namespace: str, + digest: Dict[str, Any], + ) -> Dict[str, Any]: + token = self._session_token( + user_id=user_id, + requester_agent_id=requester_agent_id, + capabilities=["write_handoff"], + namespace=namespace, + ) + try: + return self.memory.save_session_digest( + user_id, + agent_id, + digest, + token=token, + requester_agent_id=requester_agent_id, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + def get_last_session( + self, + *, + user_id: str, + agent_id: Optional[str], + requester_agent_id: str, + namespace: str, + repo: Optional[str], + statuses: Optional[List[str]] = None, + ) -> Optional[Dict[str, Any]]: + token = self._session_token( + user_id=user_id, + requester_agent_id=requester_agent_id, + capabilities=["read_handoff"], + namespace=namespace, + ) + try: + return self.memory.get_last_session( + user_id, + agent_id=agent_id, + repo=repo, + statuses=statuses, + token=token, + requester_agent_id=requester_agent_id, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + def list_sessions( + self, + *, + user_id: str, + agent_id: Optional[str], + requester_agent_id: str, + namespace: str, + repo: Optional[str], + status: Optional[str], + statuses: Optional[List[str]], + limit: int, + ) -> List[Dict[str, Any]]: + token = self._session_token( + user_id=user_id, + requester_agent_id=requester_agent_id, + capabilities=["read_handoff"], + namespace=namespace, + ) + try: + return self.memory.list_sessions( + user_id=user_id, + agent_id=agent_id, + repo=repo, + status=status, + statuses=statuses, + limit=limit, + token=token, + requester_agent_id=requester_agent_id, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + def auto_resume_context( + self, + *, + user_id: str, + agent_id: str, + namespace: str, + repo_path: str, + branch: Optional[str], + lane_type: str, + objective: str, + agent_role: Optional[str], + statuses: Optional[List[str]] = None, + ) -> Dict[str, Any]: + token = self._session_token( + user_id=user_id, + requester_agent_id=agent_id, + capabilities=["read_handoff"], + namespace=namespace, + ) + try: + return self.memory.auto_resume_context( + user_id=user_id, + agent_id=agent_id, + repo_path=repo_path, + branch=branch, + lane_type=lane_type, + objective=objective, + agent_role=agent_role, + namespace=namespace, + statuses=statuses, + token=token, + requester_agent_id=agent_id, + auto_create=True, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + def auto_checkpoint( + self, + *, + user_id: str, + agent_id: str, + namespace: str, + repo_path: str, + branch: Optional[str], + lane_id: Optional[str], + lane_type: str, + objective: str, + agent_role: Optional[str], + confidentiality_scope: str, + payload: Dict[str, Any], + event_type: str, + ) -> Dict[str, Any]: + token = self._session_token( + user_id=user_id, + requester_agent_id=agent_id, + capabilities=["write_handoff"], + namespace=namespace, + ) + try: + return self.memory.auto_checkpoint( + user_id=user_id, + agent_id=agent_id, + payload=payload, + event_type=event_type, + repo_path=repo_path, + branch=branch, + lane_id=lane_id, + lane_type=lane_type, + objective=objective, + agent_role=agent_role, + namespace=namespace, + confidentiality_scope=confidentiality_scope, + token=token, + requester_agent_id=agent_id, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + +class HostedHandoffBackend: + """Handoff adapter using hosted Engram REST APIs.""" + + def __init__(self, api_url: str): + host = str(api_url).strip() + if not host: + raise HandoffBackendError("hosted_backend_unavailable", "ENGRAM_API_URL is not configured") + self.client = MemoryClient( + host=host, + api_key=os.environ.get("ENGRAM_API_KEY"), + org_id=os.environ.get("ENGRAM_ORG_ID"), + project_id=os.environ.get("ENGRAM_PROJECT_ID"), + admin_key=os.environ.get("ENGRAM_ADMIN_KEY"), + ) + + def _session_token( + self, + *, + user_id: str, + requester_agent_id: Optional[str], + capabilities: List[str], + namespace: str, + ) -> str: + try: + session = self.client.create_session( + user_id=user_id, + agent_id=requester_agent_id, + allowed_confidentiality_scopes=["work", "personal", "finance", "health", "private"], + capabilities=capabilities, + namespaces=[namespace], + ttl_minutes=24 * 60, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + token = session.get("token") + if not token: + raise HandoffBackendError("missing_or_expired_token", "Session token was not issued") + return token + + def save_session_digest( + self, + *, + user_id: str, + agent_id: str, + requester_agent_id: str, + namespace: str, + digest: Dict[str, Any], + ) -> Dict[str, Any]: + self._session_token( + user_id=user_id, + requester_agent_id=requester_agent_id, + capabilities=["write_handoff"], + namespace=namespace, + ) + payload = dict(digest) + payload["user_id"] = user_id + payload["agent_id"] = agent_id + payload["requester_agent_id"] = requester_agent_id + try: + return self.client.save_session_digest(**payload) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + def get_last_session( + self, + *, + user_id: str, + agent_id: Optional[str], + requester_agent_id: str, + namespace: str, + repo: Optional[str], + statuses: Optional[List[str]] = None, + ) -> Optional[Dict[str, Any]]: + self._session_token( + user_id=user_id, + requester_agent_id=requester_agent_id, + capabilities=["read_handoff"], + namespace=namespace, + ) + try: + return self.client.get_last_session( + user_id=user_id, + agent_id=agent_id, + requester_agent_id=requester_agent_id, + repo=repo, + statuses=statuses, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + def list_sessions( + self, + *, + user_id: str, + agent_id: Optional[str], + requester_agent_id: str, + namespace: str, + repo: Optional[str], + status: Optional[str], + statuses: Optional[List[str]], + limit: int, + ) -> List[Dict[str, Any]]: + self._session_token( + user_id=user_id, + requester_agent_id=requester_agent_id, + capabilities=["read_handoff"], + namespace=namespace, + ) + try: + payload = self.client.list_sessions( + user_id=user_id, + agent_id=agent_id, + requester_agent_id=requester_agent_id, + repo=repo, + status=status, + statuses=statuses, + limit=limit, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + return list(payload.get("sessions", [])) + + def auto_resume_context( + self, + *, + user_id: str, + agent_id: str, + namespace: str, + repo_path: str, + branch: Optional[str], + lane_type: str, + objective: str, + agent_role: Optional[str], + statuses: Optional[List[str]] = None, + ) -> Dict[str, Any]: + self._session_token( + user_id=user_id, + requester_agent_id=agent_id, + capabilities=["read_handoff"], + namespace=namespace, + ) + try: + return self.client.handoff_resume( + user_id=user_id, + agent_id=agent_id, + requester_agent_id=agent_id, + repo_path=repo_path, + branch=branch, + lane_type=lane_type, + objective=objective, + agent_role=agent_role, + namespace=namespace, + statuses=statuses, + auto_create=True, + ) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + def auto_checkpoint( + self, + *, + user_id: str, + agent_id: str, + namespace: str, + repo_path: str, + branch: Optional[str], + lane_id: Optional[str], + lane_type: str, + objective: str, + agent_role: Optional[str], + confidentiality_scope: str, + payload: Dict[str, Any], + event_type: str, + ) -> Dict[str, Any]: + self._session_token( + user_id=user_id, + requester_agent_id=agent_id, + capabilities=["write_handoff"], + namespace=namespace, + ) + checkpoint_payload = dict(payload) + checkpoint_payload.update( + { + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": agent_id, + "repo_path": repo_path, + "branch": branch, + "lane_id": lane_id, + "lane_type": lane_type, + "objective": objective, + "agent_role": agent_role, + "namespace": namespace, + "confidentiality_scope": confidentiality_scope, + "event_type": event_type, + } + ) + try: + return self.client.handoff_checkpoint(**checkpoint_payload) + except Exception as exc: + raise classify_handoff_error(exc) from exc + + +def create_handoff_backend(memory): + """Create the configured handoff backend for MCP continuity paths.""" + api_url = os.environ.get("ENGRAM_API_URL") + + if api_url: + return HostedHandoffBackend(api_url=api_url) + return LocalHandoffBackend(memory) diff --git a/engram/core/handoff_bus.py b/engram/core/handoff_bus.py index 7e2da78..46f8759 100644 --- a/engram/core/handoff_bus.py +++ b/engram/core/handoff_bus.py @@ -3,17 +3,21 @@ from __future__ import annotations import logging -from datetime import datetime +import json +from datetime import datetime, timezone + +_UTC = timezone.utc from typing import Any, Dict, List, Optional, Tuple from engram.core.policy import ALL_CONFIDENTIALITY_SCOPES, DEFAULT_CAPABILITIES, HANDOFF_CAPABILITIES from engram.utils.repo_identity import canonicalize_repo_identity logger = logging.getLogger(__name__) +HANDOFF_SESSION_STATUSES = {"active", "paused", "completed", "abandoned"} def _utc_now_iso() -> str: - return datetime.utcnow().isoformat() + return datetime.now(tz=_UTC).isoformat() def _safe_dt(value: Optional[str]) -> Optional[datetime]: @@ -40,6 +44,13 @@ def _merge_list_values(existing: Any, incoming: Any) -> List[str]: return merged +def _stable_json(value: Any) -> str: + try: + return json.dumps(value, sort_keys=True, default=str) + except Exception: + return str(value) + + class HandoffSessionBus: """Server-side session bus with lane routing and automatic checkpointing.""" @@ -57,9 +68,15 @@ def __init__( cfg = config or {} self.auto_enrich = bool(cfg.get("auto_enrich", True)) self.max_sessions_per_user = int(cfg.get("max_sessions", 100)) + self.handoff_backend = str(cfg.get("handoff_backend", "hosted")) + self.strict_handoff_auth = bool(cfg.get("strict_handoff_auth", True)) + self.allow_auto_trusted_bootstrap = bool(cfg.get("allow_auto_trusted_bootstrap", False)) self.max_lanes_per_user = int(cfg.get("max_lanes_per_user", 50)) self.max_checkpoints_per_lane = int(cfg.get("max_checkpoints_per_lane", 200)) - self.resume_statuses = [str(v).strip() for v in cfg.get("resume_statuses", ["active", "paused"]) if str(v).strip()] + self.resume_statuses = self._normalize_status_list( + cfg.get("resume_statuses", ["active", "paused"]), + fallback=["active", "paused"], + ) self.lane_inactivity_minutes = int(cfg.get("lane_inactivity_minutes", 240)) self.auto_trusted_agents = { str(agent).strip().lower() @@ -90,7 +107,7 @@ def auto_resume_context( ) -> Dict[str, Any]: self._bootstrap_auto_trusted_policy(user_id=user_id, agent_id=agent_id, namespace=namespace) repo_identity = canonicalize_repo_identity(repo_path, branch=branch) - allowed_statuses = statuses or list(self.resume_statuses) + allowed_statuses = self._normalize_status_list(statuses, fallback=list(self.resume_statuses)) lane, created = self._select_or_create_lane( user_id=user_id, @@ -195,6 +212,7 @@ def auto_checkpoint( ) target_version = int(lane.get("version", 0)) + 1 + persisted_version = target_version lane_status = str(normalized_payload.get("status") or lane.get("status") or "active") lane_updates = { "status": lane_status, @@ -219,7 +237,7 @@ def auto_checkpoint( fresh_lane = self.db.get_handoff_lane(lane["id"]) or lane fresh_state = dict(fresh_lane.get("current_state") or {}) resolved_state, merge_conflicts = self._merge_state(fresh_state, normalized_payload) - all_conflicts = list(conflicts) + list(merge_conflicts) + all_conflicts = self._dedupe_conflicts(list(conflicts) + list(merge_conflicts)) self.db.update_handoff_lane( lane["id"], { @@ -232,6 +250,13 @@ def auto_checkpoint( ) conflicts = all_conflicts merged_state = resolved_state + persisted = self.db.get_handoff_lane(lane["id"]) + if persisted: + persisted_version = int(persisted.get("version", persisted_version)) + else: + persisted = self.db.get_handoff_lane(lane["id"]) + if persisted: + persisted_version = int(persisted.get("version", persisted_version)) if conflicts: self.db.add_handoff_lane_conflict( @@ -254,7 +279,7 @@ def auto_checkpoint( "lane_id": lane["id"], "checkpoint_id": checkpoint_id, "status": lane_status, - "version": target_version, + "version": persisted_version, "conflicts": conflicts, "enrichment": enrichment, } @@ -272,6 +297,7 @@ def finalize_lane( agent_role: Optional[str] = None, namespace: str = "default", ) -> Dict[str, Any]: + normalized_status = self._normalize_status(status, default="paused") result = self.auto_checkpoint( user_id=user_id, agent_id=agent_id, @@ -285,8 +311,8 @@ def finalize_lane( ) lane = self.db.get_handoff_lane(lane_id) if lane: - self.db.update_handoff_lane(lane_id, {"status": status}) - result["lane_status"] = status + self.db.update_handoff_lane(lane_id, {"status": normalized_status}) + result["lane_status"] = normalized_status return result def list_lanes( @@ -299,11 +325,17 @@ def list_lanes( limit: int = 20, ) -> List[Dict[str, Any]]: repo_identity = canonicalize_repo_identity(repo_path, branch=None) if repo_path else {"repo_id": None} + normalized_status = self._normalize_optional_status(status) + normalized_statuses = ( + self._normalize_status_list(statuses, fallback=[], allow_empty=True) + if statuses is not None + else None + ) return self.db.list_handoff_lanes( user_id=user_id, repo_id=repo_identity.get("repo_id"), - status=status, - statuses=statuses, + status=normalized_status, + statuses=normalized_statuses, limit=limit, ) @@ -314,7 +346,7 @@ def list_lanes( def save_session_digest(self, user_id: str, agent_id: str, digest: Dict[str, Any]) -> Dict[str, Any]: repo_path = digest.get("repo") repo_identity = canonicalize_repo_identity(repo_path, branch=digest.get("branch")) - status = str(digest.get("status") or "paused") + status = self._normalize_status(digest.get("status"), default="paused") checkpoint_payload = { "status": status, "task_summary": digest.get("task_summary"), @@ -429,7 +461,7 @@ def get_last_session( statuses: Optional[List[str]] = None, ) -> Optional[Dict[str, Any]]: repo_identity = canonicalize_repo_identity(repo, branch=None) if repo else {"repo_id": None} - preferred_statuses = statuses or list(self.resume_statuses) + preferred_statuses = self._normalize_status_list(statuses, fallback=list(self.resume_statuses)) repo_candidates: List[Optional[str]] = [repo_identity.get("repo_id")] if repo_candidates[0] is not None: repo_candidates.append(None) @@ -445,28 +477,34 @@ def get_last_session( if session: return self.get_handoff_context(session["id"]) + # Compatibility fallback: if preferred-status legacy sessions are absent, + # derive context from active lane/checkpoint state before broadening status. for repo_id in repo_candidates: - session = self.db.get_last_handoff_session( + lane_packet = self._latest_lane_resume_packet( user_id=user_id, agent_id=agent_id, - repo=repo if repo_id is not None else None, repo_id=repo_id, - statuses=None, + statuses=preferred_statuses, ) - if session: - return self.get_handoff_context(session["id"]) + if lane_packet: + return lane_packet + + # Historical fallback is only used for default behavior. If callers pass + # explicit statuses, respect that filter strictly. + if statuses is not None: + return None - # Compatibility fallback: if legacy sessions are absent, derive context - # from the latest lane/checkpoint state so resume still works. for repo_id in repo_candidates: - lane_packet = self._latest_lane_resume_packet( + session = self.db.get_last_handoff_session( user_id=user_id, agent_id=agent_id, + repo=repo if repo_id is not None else None, repo_id=repo_id, - statuses=preferred_statuses, + statuses=None, ) - if lane_packet: - return lane_packet + if session: + return self.get_handoff_context(session["id"]) + for repo_id in repo_candidates: lane_packet = self._latest_lane_resume_packet( user_id=user_id, @@ -489,13 +527,19 @@ def list_sessions( limit: int = 20, ) -> List[Dict[str, Any]]: repo_identity = canonicalize_repo_identity(repo, branch=None) if repo else {"repo_id": None} + normalized_status = self._normalize_optional_status(status) + normalized_statuses = ( + self._normalize_status_list(statuses, fallback=[], allow_empty=True) + if statuses is not None + else None + ) sessions = self.db.list_handoff_sessions( user_id=user_id, agent_id=agent_id, repo=repo, repo_id=repo_identity.get("repo_id"), - status=status, - statuses=statuses, + status=normalized_status, + statuses=normalized_statuses, limit=limit, ) if sessions: @@ -505,8 +549,8 @@ def list_sessions( user_id=user_id, agent_id=agent_id, repo_id=repo_identity.get("repo_id"), - status=status, - statuses=statuses, + status=normalized_status, + statuses=normalized_statuses, limit=limit, ) if lane_sessions or repo_identity.get("repo_id") is None: @@ -515,8 +559,8 @@ def list_sessions( user_id=user_id, agent_id=agent_id, repo_id=None, - status=status, - statuses=statuses, + status=normalized_status, + statuses=normalized_statuses, limit=limit, ) @@ -524,18 +568,27 @@ def list_sessions( # Internal helpers # ------------------------------------------------------------------ + # Cache bootstrapped policies to avoid a DB query on every checkpoint/resume. + _bootstrapped_policies: set = set() + def _bootstrap_auto_trusted_policy(self, *, user_id: str, agent_id: Optional[str], namespace: str) -> None: + if not self.allow_auto_trusted_bootstrap: + return if not user_id or not agent_id: return normalized_agent = str(agent_id).strip().lower() if normalized_agent not in self.auto_trusted_agents: return + cache_key = f"{user_id}::{normalized_agent}" + if cache_key in self._bootstrapped_policies: + return existing = self.db.get_agent_policy( user_id=user_id, agent_id=agent_id, include_wildcard=False, ) if existing: + self._bootstrapped_policies.add(cache_key) return capabilities = sorted(set(list(DEFAULT_CAPABILITIES) + list(HANDOFF_CAPABILITIES))) namespaces = ["default"] @@ -549,6 +602,81 @@ def _bootstrap_auto_trusted_policy(self, *, user_id: str, agent_id: Optional[str allowed_capabilities=capabilities, allowed_namespaces=namespaces, ) + self._bootstrapped_policies.add(cache_key) + + @staticmethod + def _normalize_status(value: Optional[str], *, default: str) -> str: + normalized = str(value or "").strip().lower() + if normalized in HANDOFF_SESSION_STATUSES: + return normalized + return default + + @staticmethod + def _normalize_optional_status(value: Optional[str]) -> Optional[str]: + if value is None: + return None + normalized = str(value).strip().lower() + if normalized in HANDOFF_SESSION_STATUSES: + return normalized + raise ValueError( + "Invalid handoff status: " + f"{value!r}. Allowed: {', '.join(sorted(HANDOFF_SESSION_STATUSES))}" + ) + + @staticmethod + def _normalize_status_list( + values: Optional[List[str]], + *, + fallback: List[str], + allow_empty: bool = False, + ) -> List[str]: + if values is None: + return list(fallback) + + raw_values: List[str] + if isinstance(values, str): + raw_values = [v for v in values.split(",")] + else: + raw_values = [str(v) for v in values] + + normalized: List[str] = [] + invalid: List[str] = [] + for value in raw_values: + item = str(value).strip().lower() + if not item: + continue + if item not in HANDOFF_SESSION_STATUSES: + invalid.append(item) + continue + if item not in normalized: + normalized.append(item) + + if invalid: + bad = ", ".join(sorted(set(invalid))) + allowed = ", ".join(sorted(HANDOFF_SESSION_STATUSES)) + raise ValueError(f"Invalid handoff statuses: {bad}. Allowed: {allowed}") + + if normalized: + return normalized + if allow_empty: + return [] + return list(fallback) + + @staticmethod + def _dedupe_conflicts(conflicts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + deduped: List[Dict[str, Any]] = [] + seen = set() + for conflict in conflicts: + key = ( + conflict.get("field"), + _stable_json(conflict.get("previous")), + _stable_json(conflict.get("incoming")), + ) + if key in seen: + continue + seen.add(key) + deduped.append(conflict) + return deduped def _latest_lane_resume_packet( self, @@ -722,7 +850,11 @@ def _score_lane( last_checkpoint = _safe_dt(lane.get("last_checkpoint_at") or lane.get("updated_at") or lane.get("created_at")) if last_checkpoint: - age_minutes = max(0.0, (datetime.utcnow() - last_checkpoint).total_seconds() / 60.0) + now = datetime.now(tz=_UTC) + # Ensure last_checkpoint is offset-aware for comparison. + if last_checkpoint.tzinfo is None: + last_checkpoint = last_checkpoint.replace(tzinfo=_UTC) + age_minutes = max(0.0, (now - last_checkpoint).total_seconds() / 60.0) score += max(0.0, 0.1 - min(age_minutes, 24 * 60) / (24 * 60 * 10)) if age_minutes > self.lane_inactivity_minutes and lane.get("status") == "active": score -= 0.2 @@ -731,7 +863,7 @@ def _score_lane( def _normalize_checkpoint_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]: payload = dict(payload or {}) normalized = { - "status": str(payload.get("status") or "active"), + "status": self._normalize_status(payload.get("status"), default="active"), "task_summary": str(payload.get("task_summary") or "").strip(), "decisions_made": _merge_list_values([], payload.get("decisions_made", [])), "files_touched": _merge_list_values([], payload.get("files_touched", [])), diff --git a/engram/core/kernel.py b/engram/core/kernel.py index 700d5c5..b2991b8 100644 --- a/engram/core/kernel.py +++ b/engram/core/kernel.py @@ -6,7 +6,7 @@ import os import secrets import time -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Set from engram.core.invariants import InvariantEngine @@ -136,7 +136,7 @@ def create_session( token = secrets.token_urlsafe(32) token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() - expires_at = (datetime.utcnow() + timedelta(minutes=max(1, ttl_minutes))).isoformat() + expires_at = (datetime.now(timezone.utc) + timedelta(minutes=max(1, ttl_minutes))).isoformat() session_id = self.db.create_session( { @@ -184,7 +184,7 @@ def authenticate_session( expires_at = session.get("expires_at") if expires_at: exp_dt = datetime.fromisoformat(expires_at) - if datetime.utcnow() > exp_dt: + if datetime.now(timezone.utc) > exp_dt: raise PermissionError("Session expired") if user_id and session.get("user_id") not in {None, user_id}: @@ -234,6 +234,8 @@ def _bootstrap_handoff_policy_if_trusted( namespaces: Optional[List[str]], ) -> Optional[Dict[str, Any]]: handoff_cfg = getattr(self.memory, "handoff_config", None) + if not bool(getattr(handoff_cfg, "allow_auto_trusted_bootstrap", False)): + return None trusted_agents = { str(value).strip().lower() for value in getattr(handoff_cfg, "auto_trusted_agents", []) @@ -289,6 +291,28 @@ def _normalize_policy_scopes(scopes: Optional[List[str]]) -> List[str]: ) return values + @staticmethod + def _clamp_with_policy( + *, + requested: List[str], + allowed: List[str], + label: str, + user_id: str, + agent_id: Optional[str], + ) -> List[str]: + """Generic clamping: intersect requested with allowed, respecting wildcards.""" + allowed_set = set(allowed) + if "*" in allowed_set: + return sorted(set(requested)) + if not allowed_set: + raise PermissionError(f"Agent policy denies {label} for user={user_id} agent={agent_id}") + clamped = [item for item in requested if item in allowed_set] + if not clamped: + raise PermissionError( + f"Requested {label} are not allowed by policy for user={user_id} agent={agent_id}" + ) + return sorted(set(clamped)) + def _clamp_scopes_with_policy( self, *, @@ -297,19 +321,16 @@ def _clamp_scopes_with_policy( user_id: str, agent_id: Optional[str], ) -> List[str]: - allowed = self._normalize_policy_scopes(policy_scopes) - if "*" in set(str(scope).strip() for scope in (policy_scopes or [])): + # Check raw policy for wildcard before normalization strips it. + if "*" in {str(s).strip() for s in (policy_scopes or [])}: return requested_scopes - if not allowed: - raise PermissionError( - f"Agent policy denies confidentiality scopes for user={user_id} agent={agent_id}" - ) - clamped = [scope for scope in requested_scopes if scope in set(allowed)] - if not clamped: - raise PermissionError( - f"Requested confidentiality scopes are not allowed by policy for user={user_id} agent={agent_id}" - ) - return sorted(set(clamped)) + return self._clamp_with_policy( + requested=requested_scopes, + allowed=self._normalize_policy_scopes(policy_scopes), + label="confidentiality scopes", + user_id=user_id, + agent_id=agent_id, + ) def _clamp_capabilities_with_policy( self, @@ -319,19 +340,13 @@ def _clamp_capabilities_with_policy( user_id: str, agent_id: Optional[str], ) -> List[str]: - allowed = self._normalize_policy_capabilities(policy_capabilities) - if "*" in set(allowed): - return sorted(set(requested_capabilities)) - if not allowed: - raise PermissionError( - f"Agent policy denies capabilities for user={user_id} agent={agent_id}" - ) - clamped = [capability for capability in requested_capabilities if capability in set(allowed)] - if not clamped: - raise PermissionError( - f"Requested capabilities are not allowed by policy for user={user_id} agent={agent_id}" - ) - return sorted(set(clamped)) + return self._clamp_with_policy( + requested=requested_capabilities, + allowed=self._normalize_policy_capabilities(policy_capabilities), + label="capabilities", + user_id=user_id, + agent_id=agent_id, + ) def _clamp_namespaces_with_policy( self, @@ -341,17 +356,13 @@ def _clamp_namespaces_with_policy( user_id: str, agent_id: Optional[str], ) -> List[str]: - allowed = self._normalize_policy_namespaces(policy_namespaces) - if "*" in set(allowed): - return sorted(set(requested_namespaces)) - if not allowed: - raise PermissionError(f"Agent policy denies namespaces for user={user_id} agent={agent_id}") - clamped = [namespace for namespace in requested_namespaces if namespace in set(allowed)] - if not clamped: - raise PermissionError( - f"Requested namespaces are not allowed by policy for user={user_id} agent={agent_id}" - ) - return sorted(set(clamped)) + return self._clamp_with_policy( + requested=requested_namespaces, + allowed=self._normalize_policy_namespaces(policy_namespaces), + label="namespaces", + user_id=user_id, + agent_id=agent_id, + ) def _resolve_allowed_namespaces( self, @@ -389,19 +400,17 @@ def _enforce_namespaces_on_results( allowed_namespaces: List[str], ) -> List[Dict[str, Any]]: if "*" in allowed_namespaces: - visible = [] + # Fast path: no masking needed, avoid copying dicts. for item in items: - value = dict(item) - value["masked"] = False - visible.append(value) - return visible + item["masked"] = False + return items + allowed_set = set(allowed_namespaces) filtered: List[Dict[str, Any]] = [] for item in items: namespace = self._normalize_namespace(item.get("namespace")) - if self._is_namespace_allowed(namespace, allowed_namespaces): - value = dict(item) - value["masked"] = bool(value.get("masked", False)) - filtered.append(value) + if namespace in allowed_set or "*" in allowed_set: + item["masked"] = bool(item.get("masked", False)) + filtered.append(item) else: filtered.append(self._mask_for_namespace(item)) return filtered @@ -437,6 +446,68 @@ def _passes_auto_merge_guardrails(self, trust_row: Dict[str, Any]) -> bool: rejection_rate = (rejected / total) if total > 0 else 1.0 return rejection_rate <= max(0.0, max_reject_rate) + def _enforce_write_quotas( + self, + *, + user_id: str, + agent_id: Optional[str], + ) -> None: + if not feature_enabled("ENGRAM_V2_POLICY_GATEWAY", default=True): + return + + now = datetime.now(timezone.utc) + windows: List[Dict[str, Any]] = [ + { + "env": "ENGRAM_V2_POLICY_WRITE_QUOTA_PER_USER_PER_HOUR", + "label": "per-user hourly", + "user_id": user_id, + "agent_id": None, + "since": (now - timedelta(hours=1)).isoformat(), + }, + { + "env": "ENGRAM_V2_POLICY_WRITE_QUOTA_PER_USER_PER_DAY", + "label": "per-user daily", + "user_id": user_id, + "agent_id": None, + "since": (now - timedelta(days=1)).isoformat(), + }, + ] + if agent_id: + windows.extend( + [ + { + "env": "ENGRAM_V2_POLICY_WRITE_QUOTA_PER_AGENT_PER_HOUR", + "label": "per-agent hourly", + "user_id": user_id, + "agent_id": agent_id, + "since": (now - timedelta(hours=1)).isoformat(), + }, + { + "env": "ENGRAM_V2_POLICY_WRITE_QUOTA_PER_AGENT_PER_DAY", + "label": "per-agent daily", + "user_id": user_id, + "agent_id": agent_id, + "since": (now - timedelta(days=1)).isoformat(), + }, + ] + ) + + for window in windows: + limit = self._parse_int_env(window["env"], 0) + if limit <= 0: + continue + + count = self.db.count_proposal_commits( + user_id=window["user_id"], + agent_id=window["agent_id"], + since=window["since"], + ) + if count >= limit: + raise PermissionError( + f"Write quota exceeded ({window['label']}): " + f"{count}/{limit} proposals in active window" + ) + # ------------------------------------------------------------------ # Read path # ------------------------------------------------------------------ @@ -659,6 +730,7 @@ def propose_write( ) if not self._is_namespace_allowed(namespace_value, allowed_write_namespaces): raise PermissionError(f"Namespace access denied: {namespace_value}") + self._enforce_write_quotas(user_id=user_id, agent_id=agent_id) self.db.ensure_namespace(user_id=user_id, name=namespace_value) if mode == "staging" and not feature_enabled("ENGRAM_V2_STAGING_WRITES", default=True): @@ -780,6 +852,35 @@ def _apply_direct_write( metadata = dict(metadata) metadata.update(provenance) metadata["allow_sensitive"] = True + namespace_value = self._normalize_namespace(metadata.get("namespace")) + metadata["namespace"] = namespace_value + + source_event_id = str(provenance.get("source_event_id") or "").strip() + source_app = provenance.get("source_app") or source_app + if source_event_id: + existing = self.db.get_memory_by_source_event( + user_id=user_id, + source_event_id=source_event_id, + namespace=namespace_value, + source_app=source_app, + ) + if existing: + existing_text = str(existing.get("memory") or "").strip() + proposed_text = str(content or "").strip() + if existing_text != proposed_text: + raise ValueError( + f"source_event_id={source_event_id} already exists with different content" + ) + return { + "mode": "direct", + "result": { + "results": [{"id": existing.get("id"), "status": "EXISTING"}], + "count": 1, + "idempotent": True, + }, + "created_ids": [], + } + sharing_scope = str(metadata.get("sharing_scope", "global")).lower() result = self.memory.add( messages=content, @@ -792,30 +893,29 @@ def _apply_direct_write( source_app=source_app, ) - for item in result.get("results", []): - memory_id = item.get("id") - if not memory_id: - continue - created = self.memory.db.get_memory(memory_id) + # Pre-compute provenance fields to patch onto created memories in one update. + patch_fields = { + "confidentiality_scope": metadata.get("confidentiality_scope", "work"), + "source_type": provenance.get("source_type"), + "source_app": provenance.get("source_app"), + "source_event_id": provenance.get("source_event_id"), + "status": "active", + "namespace": self._normalize_namespace(metadata.get("namespace")), + } + + created_ids: List[str] = [] + result_items = result.get("results", []) + memory_ids = [item.get("id") for item in result_items if item.get("id")] + + # Batch-fetch all created memories in one query instead of N queries. + created_map = self.memory.db.get_memories_bulk(memory_ids) if memory_ids else {} + + for memory_id in memory_ids: + created = created_map.get(memory_id) if not created: continue - created["confidentiality_scope"] = metadata.get("confidentiality_scope", "work") - created["source_type"] = provenance.get("source_type") - created["source_app"] = provenance.get("source_app") - created["source_event_id"] = provenance.get("source_event_id") - created["status"] = "active" - created["namespace"] = self._normalize_namespace(metadata.get("namespace")) - self.memory.db.update_memory( - memory_id, - { - "confidentiality_scope": created["confidentiality_scope"], - "source_type": created["source_type"], - "source_app": created["source_app"], - "source_event_id": created["source_event_id"], - "status": created["status"], - "namespace": created["namespace"], - }, - ) + created_ids.append(memory_id) + self.memory.db.update_memory(memory_id, patch_fields) self.episodic_store.ingest_memory_as_view( user_id=user_id, agent_id=agent_id, @@ -833,6 +933,7 @@ def _apply_direct_write( return { "mode": "direct", "result": result, + "created_ids": created_ids, } def list_pending_commits( @@ -912,8 +1013,14 @@ def approve_commit( source_app=patch.get("source_app"), ) applied.append(outcome) - for row in outcome.get("result", {}).get("results", []): - memory_id = row.get("id") + created_ids = outcome.get("created_ids") + if created_ids is None: + created_ids = [ + row.get("id") + for row in outcome.get("result", {}).get("results", []) + if row.get("id") + ] + for memory_id in created_ids: if memory_id: created_memory_ids.append(memory_id) elif target == "memory_item" and op == "UPDATE": @@ -1139,7 +1246,7 @@ def run_sleep_cycle( require_for_agent=bool(agent_id), required_capabilities=["run_sleep_cycle"], ) - target_date = date_str or (datetime.utcnow() - timedelta(days=1)).date().isoformat() + target_date = date_str or (datetime.now(timezone.utc) - timedelta(days=1)).date().isoformat() users = [user_id] if user_id else self.db.list_user_ids() if not users: users = ["default"] @@ -1159,8 +1266,11 @@ def run_sleep_cycle( "scenes_considered": 0, "decay": {"decayed": 0, "forgotten": 0, "promoted": 0}, } - memories = self.db.get_all_memories(user_id=uid) - day_memories = [m for m in memories if str(m.get("created_at", "")).startswith(target_date)] + day_memories = self.db.get_all_memories( + user_id=uid, + created_after=day_start, + created_before=day_end, + ) # Ensure CAST views/scenes are available for the day. for memory in day_memories: diff --git a/engram/core/profile.py b/engram/core/profile.py index 45c64f7..2336e73 100644 --- a/engram/core/profile.py +++ b/engram/core/profile.py @@ -58,15 +58,7 @@ class ProfileUpdate: ) -def _cosine_similarity(a: List[float], b: List[float]) -> float: - if not a or not b or len(a) != len(b): - return 0.0 - dot = sum(x * y for x, y in zip(a, b)) - norm_a = sum(x * x for x in a) ** 0.5 - norm_b = sum(x * x for x in b) ** 0.5 - if norm_a == 0 or norm_b == 0: - return 0.0 - return dot / (norm_a * norm_b) +from engram.utils.math import cosine_similarity as _cosine_similarity class ProfileProcessor: @@ -260,22 +252,14 @@ def apply_update( def _find_profile(self, name: str, user_id: str) -> Optional[Dict[str, Any]]: """Find a profile by name or alias, with fuzzy matching.""" - # Exact match first + # Fast path: exact or alias match (uses indexed SQL query). profile = self.db.get_profile_by_name(name, user_id=user_id) if profile: return profile - # Check all profiles for partial match - all_profiles = self.db.get_all_profiles(user_id=user_id) - name_lower = name.lower() - for p in all_profiles: - p_name = p["name"].lower() - aliases = [a.lower() for a in p.get("aliases", [])] - # Substring match (e.g. "John" matches "John Smith") - if name_lower in p_name or p_name in name_lower: - return p - if any(name_lower in a or a in name_lower for a in aliases): - return p + # Slow path: substring match on name (e.g. "John" matches "John Smith"). + if hasattr(self.db, "find_profile_by_substring"): + return self.db.find_profile_by_substring(name, user_id=user_id) return None @@ -408,6 +392,9 @@ def search_profiles( if not all_profiles: return [] + query_lower = query.lower() + query_words = query_lower.split() + if self.embedder: query_embedding = self.embedder.embed(query, memory_action="search") scored = [] @@ -417,9 +404,8 @@ def search_profiles( sim = _cosine_similarity(query_embedding, p_emb) scored.append((p, sim)) else: - # Keyword fallback text = f"{p.get('name', '')} {' '.join(p.get('facts', []))} {' '.join(p.get('preferences', []))}".lower() - kw_score = sum(1 for w in query.lower().split() if w in text) * 0.1 + kw_score = sum(1 for w in query_words if w in text) * 0.1 if kw_score > 0: scored.append((p, kw_score)) scored.sort(key=lambda x: x[1], reverse=True) @@ -429,13 +415,12 @@ def search_profiles( results.append(p) return results else: - # Keyword-only search - query_lower = query.lower() scored = [] for p in all_profiles: text = f"{p.get('name', '')} {' '.join(p.get('facts', []))} {' '.join(p.get('preferences', []))}".lower() - score = sum(1 for w in query_lower.split() if w in text) - if score > 0 or query_lower in p.get("name", "").lower(): - scored.append((p, score + (1 if query_lower in p.get("name", "").lower() else 0))) + name_match = query_lower in p.get("name", "").lower() + score = sum(1 for w in query_words if w in text) + if score > 0 or name_match: + scored.append((p, score + (1 if name_match else 0))) scored.sort(key=lambda x: x[1], reverse=True) return [p for p, _ in scored[:limit]] diff --git a/engram/core/scene.py b/engram/core/scene.py index 97567cd..b98beb3 100644 --- a/engram/core/scene.py +++ b/engram/core/scene.py @@ -15,7 +15,7 @@ import re import uuid from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) @@ -38,16 +38,7 @@ class SceneDetectionResult: ) -def _cosine_similarity(a: List[float], b: List[float]) -> float: - """Compute cosine similarity between two vectors.""" - if not a or not b or len(a) != len(b): - return 0.0 - dot = sum(x * y for x, y in zip(a, b)) - norm_a = sum(x * x for x in a) ** 0.5 - norm_b = sum(x * x for x in b) ** 0.5 - if norm_a == 0 or norm_b == 0: - return 0.0 - return dot / (norm_a * norm_b) +from engram.utils.math import cosine_similarity as _cosine_similarity def _detect_location(content: str) -> Optional[str]: @@ -116,13 +107,14 @@ def detect_boundary( # 3. Topic shift (cosine similarity) scene_embedding = current_scene.get("embedding") + topic_sim: Optional[float] = None if embedding and scene_embedding: - sim = _cosine_similarity(embedding, scene_embedding) - if sim < self.topic_threshold: + topic_sim = _cosine_similarity(embedding, scene_embedding) + if topic_sim < self.topic_threshold: return SceneDetectionResult( is_new_scene=True, reason="topic_shift", - topic_similarity=sim, + topic_similarity=topic_sim, ) # 4. Location change @@ -142,11 +134,7 @@ def detect_boundary( return SceneDetectionResult( is_new_scene=False, detected_location=detected_location, - topic_similarity=( - _cosine_similarity(embedding, scene_embedding) - if embedding and scene_embedding - else None - ), + topic_similarity=topic_sim, ) # ------------------------------------------------------------------ @@ -211,14 +199,16 @@ def add_memory_to_scene( if namespace: updates["namespace"] = namespace - # Running average of embeddings + # Running average of embeddings (incremental centroid). if embedding and scene.get("embedding"): old_emb = scene["embedding"] - n = max(position, 1) - updates["embedding"] = [ - (old_emb[i] * n + embedding[i]) / (n + 1) - for i in range(len(embedding)) - ] + if len(old_emb) == len(embedding): + n = max(position, 1) + inv = 1.0 / (n + 1) + updates["embedding"] = [ + old_emb[i] * (n * inv) + embedding[i] * inv + for i in range(len(embedding)) + ] self.db.update_scene(scene_id, updates) self.db.add_scene_memory(scene_id, memory_id, position=position) @@ -235,7 +225,7 @@ def close_scene(self, scene_id: str, timestamp: Optional[str] = None) -> None: updates: Dict[str, Any] = {} if not scene.get("end_time"): - updates["end_time"] = timestamp or datetime.utcnow().isoformat() + updates["end_time"] = timestamp or datetime.now(timezone.utc).isoformat() # Generate summary if self.use_llm_summarization and self.llm: @@ -262,7 +252,11 @@ def auto_close_stale(self, user_id: str) -> List[str]: try: last_dt = datetime.fromisoformat(end_time) - if datetime.utcnow() - last_dt > timedelta(minutes=self.auto_close_minutes): + now = datetime.now(timezone.utc) + # Make last_dt offset-aware if naive. + if last_dt.tzinfo is None: + last_dt = last_dt.replace(tzinfo=timezone.utc) + if now - last_dt > timedelta(minutes=self.auto_close_minutes): self.close_scene(open_scene["id"]) return [open_scene["id"]] except (ValueError, TypeError): @@ -311,17 +305,20 @@ def search_scenes( limit: int = 10, ) -> List[Dict[str, Any]]: """Search scenes by matching query against summaries and topics.""" - all_scenes = self.db.get_scenes(user_id=user_id, limit=limit * 5) + # Fetch a bounded candidate set (3x limit is sufficient for re-ranking). + candidate_limit = min(limit * 3, 150) + all_scenes = self.db.get_scenes(user_id=user_id, limit=candidate_limit) if not all_scenes: return [] + query_lower = query.lower() + query_words = query_lower.split() + if not self.embedder: - # Fallback: keyword match - query_lower = query.lower() scored = [] for s in all_scenes: text = f"{s.get('title', '')} {s.get('summary', '')} {s.get('topic', '')}".lower() - score = sum(1 for w in query_lower.split() if w in text) + score = sum(1 for w in query_words if w in text) if score > 0: scored.append((s, score)) scored.sort(key=lambda x: x[1], reverse=True) @@ -335,9 +332,8 @@ def search_scenes( sim = _cosine_similarity(query_embedding, scene_emb) scored.append((s, sim)) else: - # Fallback to text match text = f"{s.get('title', '')} {s.get('summary', '')} {s.get('topic', '')}".lower() - keyword_score = sum(1 for w in query.lower().split() if w in text) * 0.1 + keyword_score = sum(1 for w in query_words if w in text) * 0.1 if keyword_score > 0: scored.append((s, keyword_score)) diff --git a/engram/db/sqlite.py b/engram/db/sqlite.py index dbb1da8..eba54d3 100644 --- a/engram/db/sqlite.py +++ b/engram/db/sqlite.py @@ -1,11 +1,47 @@ import json +import logging import os import sqlite3 +import threading import uuid from contextlib import contextmanager -from datetime import datetime +from datetime import datetime, timedelta, timezone from typing import Any, Dict, Iterable, List, Optional +logger = logging.getLogger(__name__) + +# Phase 5: Allowed column names for dynamic UPDATE queries to prevent SQL injection. +VALID_MEMORY_COLUMNS = frozenset({ + "memory", "metadata", "categories", "embedding", "strength", + "layer", "tombstone", "updated_at", "related_memories", "source_memories", + "confidentiality_scope", "source_type", "source_app", "source_event_id", + "decay_lambda", "status", "importance", "sensitivity", "namespace", + "access_count", "last_accessed", "immutable", "expiration_date", + "scene_id", "user_id", "agent_id", "run_id", "app_id", +}) + +VALID_SCENE_COLUMNS = frozenset({ + "title", "summary", "topic", "location", "participants", "memory_ids", + "start_time", "end_time", "embedding", "strength", "access_count", + "tombstone", "layer", "scene_strength", "topic_embedding_ref", "namespace", +}) + +VALID_PROFILE_COLUMNS = frozenset({ + "name", "profile_type", "narrative", "facts", "preferences", + "relationships", "sentiment", "theory_of_mind", "aliases", + "embedding", "strength", "updated_at", "role_bias", "profile_summary", +}) + + +def _utcnow() -> datetime: + """Return current UTC datetime (timezone-aware).""" + return datetime.now(timezone.utc) + + +def _utcnow_iso() -> str: + """Return current UTC time as ISO string.""" + return _utcnow().isoformat() + class SQLiteManager: def __init__(self, db_path: str): @@ -13,8 +49,32 @@ def __init__(self, db_path: str): db_dir = os.path.dirname(db_path) if db_dir: os.makedirs(db_dir, exist_ok=True) + # Phase 1: Persistent connection with WAL mode. + self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=5000") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache + self._conn.execute("PRAGMA temp_store=MEMORY") + self._conn.row_factory = sqlite3.Row + # Reentrant lock is required because some read helpers compose other DB + # helpers (e.g., get_proposal_commit -> get_proposal_changes). + self._lock = threading.RLock() self._init_db() + def close(self) -> None: + """Close the persistent connection for clean shutdown.""" + with self._lock: + if self._conn: + try: + self._conn.close() + except Exception: + pass + self._conn = None # type: ignore[assignment] + + def __repr__(self) -> str: + return f"SQLiteManager(db_path={self.db_path!r})" + def _init_db(self) -> None: with self._get_connection() as conn: conn.executescript( @@ -167,13 +227,14 @@ def _init_db(self) -> None: @contextmanager def _get_connection(self): - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row - try: - yield conn - conn.commit() - finally: - conn.close() + """Yield the persistent connection under the thread lock.""" + with self._lock: + try: + yield self._conn + self._conn.commit() + except Exception: + self._conn.rollback() + raise def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: """Create and migrate Engram v2 schema in-place (idempotent).""" @@ -495,6 +556,10 @@ def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: (version,), ) + # Phase 3: Skip column migrations + backfills if already complete. + if self._is_migration_applied(conn, "v2_columns_complete"): + return + # v2 columns on existing canonical tables. self._migrate_add_column_conn(conn, "memories", "confidentiality_scope", "TEXT DEFAULT 'work'") self._migrate_add_column_conn(conn, "memories", "source_type", "TEXT") @@ -525,7 +590,38 @@ def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: self._migrate_add_column_conn(conn, "handoff_sessions", "namespace", "TEXT DEFAULT 'default'") self._migrate_add_column_conn(conn, "handoff_sessions", "confidentiality_scope", "TEXT DEFAULT 'work'") + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_handoff_sessions_recent + ON handoff_sessions(user_id, last_checkpoint_at DESC, updated_at DESC, created_at DESC) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_handoff_lanes_user_recent + ON handoff_lanes(user_id, last_checkpoint_at DESC, created_at DESC) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_handoff_sessions_repo_recent + ON handoff_sessions(user_id, repo_id, last_checkpoint_at DESC, created_at DESC) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_handoff_lanes_repo_recent + ON handoff_lanes(user_id, repo_id, last_checkpoint_at DESC, created_at DESC) + """ + ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_subscribers_expires ON memory_subscribers(expires_at)") + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memories_user_source_event + ON memories(user_id, source_event_id, namespace, created_at DESC) + """ + ) # Backfills. conn.execute( @@ -633,6 +729,11 @@ def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: self._seed_default_namespaces(conn) self._seed_invariants(conn) + # Phase 3: Mark column migrations + backfills as complete. + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_columns_complete')" + ) + def _seed_default_namespaces(self, conn: sqlite3.Connection) -> None: users = conn.execute( """ @@ -640,7 +741,7 @@ def _seed_default_namespaces(self, conn: sqlite3.Connection) -> None: WHERE user_id IS NOT NULL AND user_id != '' """ ).fetchall() - now = datetime.utcnow().isoformat() + now = _utcnow_iso() for row in users: user_id = row["user_id"] conn.execute( @@ -749,6 +850,13 @@ def _is_migration_applied(self, conn: sqlite3.Connection, version: str) -> bool: ).fetchone() return row is not None + # Phase 5: Allowed table names for ALTER TABLE to prevent SQL injection. + _ALLOWED_TABLES = frozenset({ + "memories", "scenes", "profiles", "sessions", "memory_subscribers", + "handoff_sessions", "handoff_lanes", "handoff_checkpoints", + "proposal_commits", "categories", "views", + }) + def _migrate_add_column_conn( self, conn: sqlite3.Connection, @@ -757,6 +865,11 @@ def _migrate_add_column_conn( col_type: str, ) -> None: """Add a column using an existing connection, if missing.""" + if table not in self._ALLOWED_TABLES: + raise ValueError(f"Invalid table for migration: {table!r}") + # Validate column name: must be alphanumeric/underscore only. + if not column.replace("_", "").isalnum(): + raise ValueError(f"Invalid column name: {column!r}") try: conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}") except sqlite3.OperationalError: @@ -764,7 +877,7 @@ def _migrate_add_column_conn( def add_memory(self, memory_data: Dict[str, Any]) -> str: memory_id = memory_data.get("id", str(uuid.uuid4())) - now = datetime.utcnow().isoformat() + now = _utcnow_iso() metadata = memory_data.get("metadata", {}) or {} source_app = memory_data.get("source_app") or memory_data.get("app_id") or metadata.get("source_app") @@ -821,7 +934,17 @@ def add_memory(self, memory_data: Dict[str, Any]) -> str: (memory_id,), ) - self._log_event(memory_id, "ADD", new_value=memory_data.get("memory")) + # Log within the same transaction — atomic with the insert. + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None), + ) + return memory_id def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Optional[Dict[str, Any]]: @@ -836,6 +959,41 @@ def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Option return self._row_to_dict(row) return None + def get_memory_by_source_event( + self, + *, + user_id: str, + source_event_id: str, + namespace: Optional[str] = None, + source_app: Optional[str] = None, + include_tombstoned: bool = False, + ) -> Optional[Dict[str, Any]]: + normalized_event = str(source_event_id or "").strip() + if not normalized_event: + return None + query = """ + SELECT * + FROM memories + WHERE user_id = ? + AND source_event_id = ? + """ + params: List[Any] = [user_id, normalized_event] + if namespace: + query += " AND namespace = ?" + params.append(namespace) + if source_app: + query += " AND source_app = ?" + params.append(source_app) + if not include_tombstoned: + query += " AND tombstone = 0" + query += " ORDER BY created_at DESC LIMIT 1" + + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._row_to_dict(row) + return None + def get_all_memories( self, *, @@ -847,6 +1005,8 @@ def get_all_memories( namespace: Optional[str] = None, min_strength: float = 0.0, include_tombstoned: bool = False, + created_after: Optional[str] = None, + created_before: Optional[str] = None, ) -> List[Dict[str, Any]]: query = "SELECT * FROM memories WHERE strength >= ?" params: List[Any] = [min_strength] @@ -871,6 +1031,12 @@ def get_all_memories( if namespace: query += " AND namespace = ?" params.append(namespace) + if created_after: + query += " AND created_at >= ?" + params.append(created_after) + if created_before: + query += " AND created_at <= ?" + params.append(created_before) query += " ORDER BY strength DESC" @@ -879,38 +1045,53 @@ def get_all_memories( return [self._row_to_dict(row) for row in rows] def update_memory(self, memory_id: str, updates: Dict[str, Any]) -> bool: - old_memory = self.get_memory(memory_id, include_tombstoned=True) - if not old_memory: - return False - set_clauses = [] params: List[Any] = [] for key, value in updates.items(): + if key not in VALID_MEMORY_COLUMNS: + raise ValueError(f"Invalid memory column: {key!r}") if key in {"metadata", "categories", "embedding", "related_memories", "source_memories"}: value = json.dumps(value) set_clauses.append(f"{key} = ?") params.append(value) set_clauses.append("updated_at = ?") - params.append(datetime.utcnow().isoformat()) + params.append(_utcnow_iso()) params.append(memory_id) with self._get_connection() as conn: + # Read old values and update in a single transaction. + old_row = conn.execute( + "SELECT memory, strength, layer FROM memories WHERE id = ?", + (memory_id,), + ).fetchone() + if not old_row: + return False + conn.execute( f"UPDATE memories SET {', '.join(set_clauses)} WHERE id = ?", params, ) - self._log_event( - memory_id, - "UPDATE", - old_value=old_memory.get("memory"), - new_value=updates.get("memory"), - old_strength=old_memory.get("strength"), - new_strength=updates.get("strength"), - old_layer=old_memory.get("layer"), - new_layer=updates.get("layer"), - ) + # Log within the same transaction. + conn.execute( + """ + INSERT INTO memory_history ( + memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + memory_id, + "UPDATE", + old_row["memory"], + updates.get("memory"), + old_row["strength"], + updates.get("strength"), + old_row["layer"], + updates.get("layer"), + ), + ) return True def delete_memory(self, memory_id: str, use_tombstone: bool = True) -> bool: @@ -922,7 +1103,7 @@ def delete_memory(self, memory_id: str, use_tombstone: bool = True) -> bool: return True def increment_access(self, memory_id: str) -> None: - now = datetime.utcnow().isoformat() + now = _utcnow_iso() with self._get_connection() as conn: conn.execute( """ @@ -933,11 +1114,60 @@ def increment_access(self, memory_id: str) -> None: (now, memory_id), ) - def _row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + # Phase 2: Batch operations to eliminate N+1 queries in search. + + def get_memories_bulk(self, memory_ids: List[str], include_tombstoned: bool = False) -> Dict[str, Dict[str, Any]]: + """Fetch multiple memories by ID in a single query. Returns {id: memory_dict}.""" + if not memory_ids: + return {} + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + query = f"SELECT * FROM memories WHERE id IN ({placeholders})" + if not include_tombstoned: + query += " AND tombstone = 0" + rows = conn.execute(query, memory_ids).fetchall() + return {row["id"]: self._row_to_dict(row) for row in rows} + + def increment_access_bulk(self, memory_ids: List[str]) -> None: + """Increment access count for multiple memories in a single transaction.""" + if not memory_ids: + return + now = _utcnow_iso() + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in memory_ids) + conn.execute( + f""" + UPDATE memories + SET access_count = access_count + 1, last_accessed = ? + WHERE id IN ({placeholders}) + """, + [now] + list(memory_ids), + ) + + def update_strength_bulk(self, updates: Dict[str, float]) -> None: + """Batch-update strength for multiple memories. updates = {memory_id: new_strength}.""" + if not updates: + return + now = _utcnow_iso() + with self._get_connection() as conn: + conn.executemany( + "UPDATE memories SET strength = ?, updated_at = ? WHERE id = ?", + [(strength, now, memory_id) for memory_id, strength in updates.items()], + ) + + _MEMORY_JSON_FIELDS = ("metadata", "categories", "related_memories", "source_memories") + + def _row_to_dict(self, row: sqlite3.Row, *, skip_embedding: bool = False) -> Dict[str, Any]: data = dict(row) - for key in ["metadata", "categories", "embedding", "related_memories", "source_memories"]: + for key in self._MEMORY_JSON_FIELDS: if key in data and data[key]: data[key] = json.loads(data[key]) + # Embedding is the largest JSON field (~30-50KB for 3072-dim vectors). + # Skip deserialization when the caller doesn't need it. + if skip_embedding: + data.pop("embedding", None) + elif "embedding" in data and data["embedding"]: + data["embedding"] = json.loads(data["embedding"]) data["immutable"] = bool(data.get("immutable", 0)) data["tombstone"] = bool(data.get("tombstone", 0)) return data @@ -1131,6 +1361,8 @@ def update_scene(self, scene_id: str, updates: Dict[str, Any]) -> bool: set_clauses = [] params: List[Any] = [] for key, value in updates.items(): + if key not in VALID_SCENE_COLUMNS: + raise ValueError(f"Invalid scene column: {key!r}") if key in {"participants", "memory_ids", "embedding"}: value = json.dumps(value) set_clauses.append(f"{key} = ?") @@ -1230,7 +1462,7 @@ def _scene_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: def add_profile(self, profile_data: Dict[str, Any]) -> str: profile_id = profile_data.get("id", str(uuid.uuid4())) - now = datetime.utcnow().isoformat() + now = _utcnow_iso() with self._get_connection() as conn: conn.execute( """ @@ -1276,12 +1508,14 @@ def update_profile(self, profile_id: str, updates: Dict[str, Any]) -> bool: set_clauses = [] params: List[Any] = [] for key, value in updates.items(): + if key not in VALID_PROFILE_COLUMNS: + raise ValueError(f"Invalid profile column: {key!r}") if key in {"facts", "preferences", "relationships", "aliases", "theory_of_mind", "embedding"}: value = json.dumps(value) set_clauses.append(f"{key} = ?") params.append(value) set_clauses.append("updated_at = ?") - params.append(datetime.utcnow().isoformat()) + params.append(_utcnow_iso()) params.append(profile_id) with self._get_connection() as conn: conn.execute( @@ -1302,14 +1536,45 @@ def get_all_profiles(self, user_id: Optional[str] = None) -> List[Dict[str, Any] return [self._profile_row_to_dict(row) for row in rows] def get_profile_by_name(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: - """Find a profile by exact name or alias match.""" - profiles = self.get_all_profiles(user_id=user_id) - name_lower = name.lower() - for p in profiles: - if p["name"].lower() == name_lower: - return p - if name_lower in [a.lower() for a in p.get("aliases", [])]: - return p + """Find a profile by exact name match, then fall back to alias scan.""" + # Fast path: exact name match via indexed column. + query = "SELECT * FROM profiles WHERE lower(name) = ?" + params: List[Any] = [name.lower()] + if user_id: + query += " AND user_id = ?" + params.append(user_id) + query += " LIMIT 1" + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._profile_row_to_dict(row) + # Slow path: alias scan (aliases stored as JSON, can't index). + alias_query = "SELECT * FROM profiles WHERE aliases LIKE ?" + alias_params: List[Any] = [f'%"{name}"%'] + if user_id: + alias_query += " AND user_id = ?" + alias_params.append(user_id) + alias_query += " LIMIT 1" + row = conn.execute(alias_query, alias_params).fetchone() + if row: + result = self._profile_row_to_dict(row) + # Verify case-insensitive alias match. + if name.lower() in [a.lower() for a in result.get("aliases", [])]: + return result + return None + + def find_profile_by_substring(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Find a profile where the name contains the query as a substring (case-insensitive).""" + query = "SELECT * FROM profiles WHERE lower(name) LIKE ?" + params: List[Any] = [f"%{name.lower()}%"] + if user_id: + query += " AND user_id = ?" + params.append(user_id) + query += " ORDER BY strength DESC LIMIT 1" + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._profile_row_to_dict(row) return None def add_profile_memory(self, profile_id: str, memory_id: str, role: str = "mentioned") -> None: @@ -1414,7 +1679,7 @@ def revoke_session(self, session_id: str) -> bool: with self._get_connection() as conn: conn.execute( "UPDATE sessions SET revoked_at = ? WHERE id = ?", - (datetime.utcnow().isoformat(), session_id), + (_utcnow_iso(), session_id), ) return True @@ -1440,8 +1705,8 @@ def add_proposal_commit(self, commit_data: Dict[str, Any], changes: Optional[Lis json.dumps(commit_data.get("checks", {})), json.dumps(commit_data.get("preview", {})), json.dumps(commit_data.get("provenance", {})), - commit_data.get("created_at", datetime.utcnow().isoformat()), - commit_data.get("updated_at", datetime.utcnow().isoformat()), + commit_data.get("created_at", _utcnow_iso()), + commit_data.get("updated_at", _utcnow_iso()), ), ) for change in changes or []: @@ -1458,7 +1723,7 @@ def add_proposal_commit(self, commit_data: Dict[str, Any], changes: Optional[Lis change.get("target", "memory_item"), change.get("target_id"), json.dumps(change.get("patch", {})), - change.get("created_at", datetime.utcnow().isoformat()), + change.get("created_at", _utcnow_iso()), ), ) return commit_id @@ -1505,6 +1770,28 @@ def list_proposal_commits( commits.append(data) return commits + def count_proposal_commits( + self, + *, + user_id: str, + agent_id: Optional[str] = None, + since: Optional[str] = None, + ) -> int: + query = "SELECT COUNT(1) AS cnt FROM proposal_commits WHERE user_id = ?" + params: List[Any] = [user_id] + if agent_id: + query += " AND agent_id = ?" + params.append(agent_id) + if since: + query += " AND created_at >= ?" + params.append(since) + + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if not row: + return 0 + return int(row["cnt"] or 0) + def get_proposal_changes(self, commit_id: str) -> List[Dict[str, Any]]: with self._get_connection() as conn: rows = conn.execute( @@ -1525,7 +1812,7 @@ def update_proposal_commit(self, commit_id: str, updates: Dict[str, Any]) -> boo set_clauses.append(f"{key} = ?") params.append(value) set_clauses.append("updated_at = ?") - params.append(datetime.utcnow().isoformat()) + params.append(_utcnow_iso()) params.append(commit_id) with self._get_connection() as conn: conn.execute( @@ -1547,7 +1834,7 @@ def transition_proposal_commit_status( return False set_clauses = ["status = ?", "updated_at = ?"] - params: List[Any] = [str(to_status or "").upper(), datetime.utcnow().isoformat()] + params: List[Any] = [str(to_status or "").upper(), _utcnow_iso()] for key, value in (updates or {}).items(): if key in {"checks", "preview", "provenance"}: value = json.dumps(value) @@ -1590,7 +1877,7 @@ def add_conflict_stash(self, stash_data: Dict[str, Any]) -> str: json.dumps(stash_data.get("proposed", {})), stash_data.get("resolution", "UNRESOLVED"), stash_data.get("source_commit_id"), - stash_data.get("created_at", datetime.utcnow().isoformat()), + stash_data.get("created_at", _utcnow_iso()), stash_data.get("resolved_at"), ), ) @@ -1643,7 +1930,7 @@ def resolve_conflict_stash(self, stash_id: str, resolution: str) -> bool: SET resolution = ?, resolved_at = ? WHERE id = ? """, - (resolution, datetime.utcnow().isoformat(), stash_id), + (resolution, _utcnow_iso(), stash_id), ) return True @@ -1666,7 +1953,7 @@ def add_view(self, view_data: Dict[str, Any]) -> str: view_id, view_data.get("user_id"), view_data.get("agent_id"), - view_data.get("timestamp", datetime.utcnow().isoformat()), + view_data.get("timestamp", _utcnow_iso()), view_data.get("place_type"), view_data.get("place_value"), view_data.get("topic_label"), @@ -1675,7 +1962,7 @@ def add_view(self, view_data: Dict[str, Any]) -> str: view_data.get("raw_text"), json.dumps(view_data.get("signals", {})), view_data.get("scene_id"), - view_data.get("created_at", datetime.utcnow().isoformat()), + view_data.get("created_at", _utcnow_iso()), ), ) return view_id @@ -1747,7 +2034,7 @@ def adjust_memory_refcount(self, memory_id: str, strong_delta: int = 0, weak_del strong_delta, weak_delta, weak_delta, - datetime.utcnow().isoformat(), + _utcnow_iso(), memory_id, ), ) @@ -1760,25 +2047,18 @@ def add_memory_subscriber( ref_type: str = "weak", ttl_hours: Optional[int] = None, ) -> None: - now = datetime.utcnow().isoformat() + now_dt = _utcnow() + now = now_dt.isoformat() expires_at = None if ttl_hours is not None: try: ttl_value = int(ttl_hours) except Exception: ttl_value = 0 - if ttl_value > 0: - expires_at = datetime.utcfromtimestamp( - datetime.utcnow().timestamp() + ttl_value * 3600 - ).isoformat() - elif ttl_value < 0: - expires_at = datetime.utcfromtimestamp( - datetime.utcnow().timestamp() + ttl_value * 3600 - ).isoformat() + if ttl_value != 0: + expires_at = (now_dt + timedelta(hours=ttl_value)).isoformat() elif ref_type == "weak": - expires_at = datetime.utcfromtimestamp( - datetime.utcnow().timestamp() + 14 * 24 * 3600 - ).isoformat() + expires_at = (now_dt + timedelta(days=14)).isoformat() with self._get_connection() as conn: existing = conn.execute( @@ -1843,7 +2123,7 @@ def list_memory_subscribers(self, memory_id: str) -> List[str]: return [f"{row['subscriber']}:{row['ref_type']}" for row in rows] def cleanup_stale_memory_subscribers(self, now_iso: Optional[str] = None) -> int: - now_iso = now_iso or datetime.utcnow().isoformat() + now_iso = now_iso or _utcnow_iso() with self._get_connection() as conn: rows = conn.execute( """ @@ -1906,7 +2186,7 @@ def upsert_daily_digest(self, user_id: str, digest_date: str, payload: Dict[str, user_id, digest_date, json.dumps(payload), - datetime.utcnow().isoformat(), + _utcnow_iso(), ), ) return digest_id @@ -1967,7 +2247,7 @@ def _compute_trust_score( approved_dt = datetime.fromisoformat(last_approved_at) days_since = max( 0.0, - (datetime.utcnow() - approved_dt).total_seconds() / 86400.0, + (_utcnow() - approved_dt).total_seconds() / 86400.0, ) recency_score = max(0.0, 1.0 - min(days_since, 30.0) / 30.0) except Exception: @@ -2018,7 +2298,7 @@ def _upsert_agent_trust_row( last_proposed_at, last_approved_at, trust_score, - datetime.utcnow().isoformat(), + _utcnow_iso(), ), ) return self.get_agent_trust(user_id=user_id, agent_id=agent_id) @@ -2027,7 +2307,7 @@ def record_agent_proposal(self, user_id: str, agent_id: Optional[str], status: s if not user_id or not agent_id: return {} current = self.get_agent_trust(user_id=user_id, agent_id=agent_id) - now_iso = datetime.utcnow().isoformat() + now_iso = _utcnow_iso() auto_stashed = int(current.get("auto_stashed_proposals", 0)) if (status or "").upper() == "AUTO_STASHED": auto_stashed += 1 @@ -2051,7 +2331,7 @@ def record_agent_commit_outcome(self, user_id: str, agent_id: Optional[str], out rejected = int(current.get("rejected_proposals", 0)) auto_stashed = int(current.get("auto_stashed_proposals", 0)) last_approved_at = current.get("last_approved_at") - now_iso = datetime.utcnow().isoformat() + now_iso = _utcnow_iso() if outcome_upper == "APPROVED": approved += 1 last_approved_at = now_iso @@ -2091,7 +2371,7 @@ def ensure_namespace(self, user_id: str, name: str, description: Optional[str] = SET description = COALESCE(?, description), updated_at = ? WHERE id = ? """, - (description, datetime.utcnow().isoformat(), namespace_id), + (description, _utcnow_iso(), namespace_id), ) return namespace_id namespace_id = str(uuid.uuid4()) @@ -2105,8 +2385,8 @@ def ensure_namespace(self, user_id: str, name: str, description: Optional[str] = user_id, ns_name, description, - datetime.utcnow().isoformat(), - datetime.utcnow().isoformat(), + _utcnow_iso(), + _utcnow_iso(), ), ) return namespace_id @@ -2149,7 +2429,7 @@ def grant_namespace_permission( user_id, agent_id, capability, - datetime.utcnow().isoformat(), + _utcnow_iso(), expires_at, ), ) @@ -2181,7 +2461,7 @@ def list_namespace_permissions( params.append(namespace) if not include_expired: query += " AND (p.expires_at IS NULL OR p.expires_at > ?)" - params.append(datetime.utcnow().isoformat()) + params.append(_utcnow_iso()) query += " ORDER BY p.granted_at DESC" with self._get_connection() as conn: rows = conn.execute(query, params).fetchall() @@ -2202,7 +2482,7 @@ def get_agent_allowed_namespaces(self, user_id: str, agent_id: Optional[str], ca AND p.capability IN (?, '*') AND (p.expires_at IS NULL OR p.expires_at > ?) """, - (user_id, agent_id, capability, datetime.utcnow().isoformat()), + (user_id, agent_id, capability, _utcnow_iso()), ).fetchall() namespaces = [str(row["namespace_name"]) for row in rows if row["namespace_name"]] if "default" not in namespaces: @@ -2238,7 +2518,7 @@ def upsert_agent_policy( for namespace in (allowed_namespaces or []) if str(namespace).strip() }) - now_iso = datetime.utcnow().isoformat() + now_iso = _utcnow_iso() with self._get_connection() as conn: conn.execute( """ @@ -2520,7 +2800,7 @@ def get_decay_log_entries(self, limit: int = 20) -> List[Dict[str, Any]]: def add_handoff_session(self, data: Dict[str, Any]) -> str: session_id = data.get("id", str(uuid.uuid4())) - now = datetime.utcnow().isoformat() + now = _utcnow_iso() with self._get_connection() as conn: conn.execute( """ @@ -2657,7 +2937,7 @@ def update_handoff_session(self, session_id: str, updates: Dict[str, Any]) -> bo if not set_clauses: return False set_clauses.append("updated_at = ?") - params.append(datetime.utcnow().isoformat()) + params.append(_utcnow_iso()) params.append(session_id) with self._get_connection() as conn: cursor = conn.execute( @@ -2723,7 +3003,7 @@ def get_handoff_session_memories(self, session_id: str) -> List[Dict[str, Any]]: def add_handoff_lane(self, data: Dict[str, Any]) -> str: lane_id = data.get("id", str(uuid.uuid4())) - now = datetime.utcnow().isoformat() + now = _utcnow_iso() with self._get_connection() as conn: conn.execute( """ @@ -2807,7 +3087,7 @@ def update_handoff_lane( if not set_clauses: return False set_clauses.append("updated_at = ?") - params.append(datetime.utcnow().isoformat()) + params.append(_utcnow_iso()) query = f"UPDATE handoff_lanes SET {', '.join(set_clauses)} WHERE id = ?" params.append(lane_id) if expected_version is not None: @@ -2858,7 +3138,7 @@ def prune_handoff_lanes(self, user_id: str, max_lanes: int) -> int: def add_handoff_checkpoint(self, data: Dict[str, Any]) -> str: checkpoint_id = data.get("id", str(uuid.uuid4())) - now = datetime.utcnow().isoformat() + now = _utcnow_iso() with self._get_connection() as conn: conn.execute( """ @@ -3004,7 +3284,7 @@ def get_handoff_checkpoint_scenes(self, checkpoint_id: str) -> List[Dict[str, An def add_handoff_lane_conflict(self, data: Dict[str, Any]) -> str: conflict_id = data.get("id", str(uuid.uuid4())) - now = datetime.utcnow().isoformat() + now = _utcnow_iso() with self._get_connection() as conn: conn.execute( """ diff --git a/engram/embeddings/nvidia.py b/engram/embeddings/nvidia.py index 64ca9f4..524dd56 100644 --- a/engram/embeddings/nvidia.py +++ b/engram/embeddings/nvidia.py @@ -14,7 +14,7 @@ def __init__(self, config: Optional[dict] = None): except Exception as exc: raise ImportError("openai package is required for NvidiaEmbedder") from exc - api_key = self.config.get("api_key") or "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" + api_key = self.config.get("api_key") if not api_key: raise ValueError( "NVIDIA API key required. Set config['api_key'] or NVIDIA_API_KEY env var." diff --git a/engram/llms/nvidia.py b/engram/llms/nvidia.py index f657573..6b5cb2b 100644 --- a/engram/llms/nvidia.py +++ b/engram/llms/nvidia.py @@ -14,7 +14,7 @@ def __init__(self, config: Optional[dict] = None): except Exception as exc: raise ImportError("openai package is required for NvidiaLLM") from exc - api_key = self.config.get("api_key") or "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" + api_key = self.config.get("api_key") if not api_key: raise ValueError( "NVIDIA API key required. Set config['api_key'] or NVIDIA_API_KEY env var." diff --git a/engram/mcp_server.py b/engram/mcp_server.py index 02860e7..f37d6c4 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -4,16 +4,26 @@ This server exposes engram's memory capabilities as MCP tools that Claude Code can use. """ +import atexit import json +import logging import os +import signal import sys -from typing import Any, Dict, List, Optional +import threading +import time +from typing import Any, Callable, Dict, List, Optional from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import Tool, TextContent from engram.memory.main import Memory +from engram.core.handoff_backend import ( + HandoffBackendError, + classify_handoff_error, + create_handoff_backend, +) from engram.configs.base import ( MemoryConfig, VectorStoreConfig, @@ -22,6 +32,8 @@ FadeMemConfig, ) +logger = logging.getLogger(__name__) + def _get_embedding_dims_for_model(model: str, provider: str) -> int: """Get the embedding dimensions for a given model.""" @@ -149,6 +161,11 @@ def get_memory_instance() -> Memory: # Global memory instance (lazy initialized) _memory: Optional[Memory] = None +_handoff_backend = None +_lifecycle_lock = threading.Lock() +_lifecycle_state: Dict[str, Dict[str, Any]] = {} +_idle_pause_seconds = max(1, int(os.environ.get("ENGRAM_MCP_IDLE_PAUSE_SECONDS", "300"))) +_shutdown_hooks_registered = False def get_memory() -> Memory: @@ -159,13 +176,146 @@ def get_memory() -> Memory: return _memory +def _strict_handoff_enabled(memory: Memory) -> bool: + cfg = getattr(memory, "handoff_config", None) + return bool(getattr(cfg, "strict_handoff_auth", True)) + + +def get_handoff_backend(memory: Memory): + """Get or create the configured handoff backend.""" + global _handoff_backend + if _handoff_backend is None: + _handoff_backend = create_handoff_backend(memory) + return _handoff_backend + + +def _handoff_key(*, user_id: str, agent_id: str, namespace: str, repo_id: Optional[str], repo_path: Optional[str]) -> str: + scoped_repo = str(repo_id or repo_path or "").strip() or "unknown-repo" + return f"{user_id}::{agent_id}::{namespace}::{scoped_repo}" + + +def _merge_handoff_context(existing: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]: + merged = dict(existing) + for key, value in update.items(): + if value is not None: + merged[key] = value + return merged + + +def _record_handoff_context(context: Dict[str, Any]) -> None: + user_id = context.get("user_id", "default") + agent_id = context.get("agent_id", "claude-code") + namespace = context.get("namespace", "default") + repo_path = context.get("repo_path") + key = _handoff_key( + user_id=user_id, + agent_id=agent_id, + namespace=namespace, + repo_id=context.get("repo_id"), + repo_path=repo_path, + ) + alt_key = _handoff_key( + user_id=user_id, + agent_id=agent_id, + namespace=namespace, + repo_id=None, + repo_path=repo_path, + ) + with _lifecycle_lock: + now_ts = time.time() + existing = _lifecycle_state.get(key, {}) + if not existing and alt_key in _lifecycle_state: + existing = _lifecycle_state.pop(alt_key) + merged = _merge_handoff_context(existing, context) + merged["last_activity_ts"] = now_ts + _lifecycle_state[key] = merged + + +def _emit_lifecycle_checkpoint(memory: Memory, context: Dict[str, Any], *, event_type: str, task_summary: Optional[str]) -> Dict[str, Any]: + backend = get_handoff_backend(memory) + payload = { + "status": "paused" if event_type in {"agent_pause", "agent_end"} else "active", + "task_summary": task_summary or context.get("objective") or f"{event_type} checkpoint", + "decisions_made": context.get("decisions_made", []), + "files_touched": context.get("files_touched", []), + "todos_remaining": context.get("todos_remaining", []), + "blockers": context.get("blockers", []), + "key_commands": context.get("key_commands", []), + "test_results": context.get("test_results", []), + "context_snapshot": context.get("context_snapshot"), + } + return backend.auto_checkpoint( + user_id=context["user_id"], + agent_id=context["agent_id"], + namespace=context.get("namespace", "default"), + repo_path=context.get("repo_path") or os.getcwd(), + branch=context.get("branch"), + lane_id=context.get("lane_id"), + lane_type=context.get("lane_type", "general"), + objective=context.get("objective") or payload["task_summary"], + agent_role=context.get("agent_role"), + confidentiality_scope=context.get("confidentiality_scope", "work"), + payload=payload, + event_type=event_type, + ) + + +def _flush_agent_end_checkpoints() -> None: + """Best-effort final checkpoints on process shutdown.""" + try: + memory = get_memory() + except Exception: + return + with _lifecycle_lock: + contexts = list(_lifecycle_state.values()) + for context in contexts: + try: + _emit_lifecycle_checkpoint( + memory, + context, + event_type="agent_end", + task_summary=context.get("task_summary") or "Agent shutdown", + ) + except Exception as exc: # pragma: no cover - best effort shutdown path + logger.warning("Agent end checkpoint failed: %s", exc) + + +def _register_shutdown_hooks() -> None: + global _shutdown_hooks_registered + if _shutdown_hooks_registered: + return + atexit.register(_flush_agent_end_checkpoints) + + def _signal_handler(signum, _frame): # pragma: no cover - signal path + try: + _flush_agent_end_checkpoints() + finally: + raise SystemExit(0) + + for sig_name in ("SIGTERM", "SIGINT"): + sig_value = getattr(signal, sig_name, None) + if sig_value is not None: + try: + signal.signal(sig_value, _signal_handler) + except Exception: + logger.debug("Skipping signal hook registration for %s", sig_name) + + _shutdown_hooks_registered = True + + # Create the MCP server server = Server("engram-memory") +# Cached tool list — schemas are static, no need to rebuild on every call. +_tools_cache: Optional[List[Tool]] = None + @server.list_tools() async def list_tools() -> List[Tool]: """List available engram tools.""" + global _tools_cache + if _tools_cache is not None: + return list(_tools_cache) tools = [ Tool( name="add_memory", @@ -830,7 +980,10 @@ async def list_tools() -> List[Tool]: }, "statuses": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string", + "enum": ["active", "paused", "completed", "abandoned"], + }, "description": "Optional status list filter (defaults to active/paused)" }, } @@ -869,7 +1022,10 @@ async def list_tools() -> List[Tool]: }, "statuses": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "string", + "enum": ["active", "paused", "completed", "abandoned"], + }, "description": "Optional status list filter." }, } @@ -881,7 +1037,471 @@ async def list_tools() -> List[Tool]: # Keep handoff tools at the front so cross-agent continuity remains available. priority = {"save_session_digest": 0, "get_last_session": 1, "list_sessions": 2} tools.sort(key=lambda tool: priority.get(tool.name, 1000)) - return tools + _tools_cache = tools + return list(tools) + + +# Phase 6: Tool handler registry for cleaner dispatch. +_TOOL_HANDLERS: Dict[str, Callable] = {} + + +def _preview(value: Any, limit: int = 1200) -> str: + """Truncate a JSON-serialized value for checkpoint snapshots.""" + try: + text = json.dumps(value, default=str) + except Exception: + text = str(value) + if len(text) > limit: + return text[:limit] + "...(truncated)" + return text + + +def _make_session_token(memory: "Memory", *, user_id: str, agent_id: Optional[str], capabilities: List[str], namespaces: Optional[List[str]] = None) -> str: + """Create a scoped session token.""" + session = memory.create_session( + user_id=user_id, + agent_id=agent_id, + allowed_confidentiality_scopes=["work", "personal", "finance", "health", "private"], + capabilities=capabilities, + namespaces=namespaces, + ttl_minutes=24 * 60, + ) + return session["token"] + + +def _tool_handler(name: str): + """Decorator to register a tool handler function.""" + def decorator(fn): + _TOOL_HANDLERS[name] = fn + return fn + return decorator + + +@_tool_handler("get_memory") +def _handle_get_memory(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + memory_id = arguments.get("memory_id", "") + result = memory.get(memory_id) + if result: + return { + "id": result["id"], + "memory": result["memory"], + "layer": result.get("layer", "sml"), + "strength": round(result.get("strength", 1.0), 3), + "categories": result.get("categories", []), + "created_at": result.get("created_at"), + "access_count": result.get("access_count", 0), + } + return {"error": "Memory not found"} + + +@_tool_handler("update_memory") +def _handle_update_memory(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + memory_id = arguments.get("memory_id", "") + content = arguments.get("content", "") + return memory.update(memory_id, content) + + +@_tool_handler("delete_memory") +def _handle_delete_memory(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + memory_id = arguments.get("memory_id", "") + return memory.delete(memory_id) + + +@_tool_handler("get_memory_stats") +def _handle_get_memory_stats(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id") + agent_id = arguments.get("agent_id") + return memory.get_stats(user_id=user_id, agent_id=agent_id) + + +@_tool_handler("apply_memory_decay") +def _handle_apply_memory_decay(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id") + agent_id = arguments.get("agent_id") + scope = {"user_id": user_id, "agent_id": agent_id} if user_id or agent_id else None + return memory.apply_decay(scope=scope) + + +@_tool_handler("engram_context") +def _handle_engram_context(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + limit = arguments.get("limit", 15) + all_result = memory.get_all(user_id=user_id, limit=limit * 3) + all_memories = all_result.get("results", []) + layer_order = {"lml": 0, "sml": 1} + all_memories.sort(key=lambda m: ( + layer_order.get(m.get("layer", "sml"), 1), + -float(m.get("strength", 1.0)) + )) + digest = [ + { + "id": m["id"], + "memory": m.get("memory", ""), + "layer": m.get("layer", "sml"), + "strength": round(float(m.get("strength", 1.0)), 3), + "categories": m.get("categories", []), + } + for m in all_memories[:limit] + ] + return {"digest": digest, "total_in_store": len(all_memories), "returned": len(digest)} + + +@_tool_handler("get_profile") +def _handle_get_profile(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + profile_id = arguments.get("profile_id", "") + profile = memory.get_profile(profile_id) + if profile: + profile.pop("embedding", None) + return profile + return {"error": "Profile not found"} + + +@_tool_handler("list_profiles") +def _handle_list_profiles(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + profiles = memory.get_all_profiles(user_id=user_id) + return { + "profiles": [ + { + "id": p["id"], + "name": p.get("name"), + "profile_type": p.get("profile_type"), + "narrative": p.get("narrative"), + "fact_count": len(p.get("facts", [])), + "preference_count": len(p.get("preferences", [])), + } + for p in profiles + ], + "total": len(profiles), + } + + +@_tool_handler("search_profiles") +def _handle_search_profiles(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + query = arguments.get("query", "") + user_id = arguments.get("user_id", "default") + limit = arguments.get("limit", 10) + profiles = memory.search_profiles(query=query, user_id=user_id, limit=limit) + return { + "profiles": [ + { + "id": p["id"], + "name": p.get("name"), + "profile_type": p.get("profile_type"), + "narrative": p.get("narrative"), + "facts": p.get("facts", [])[:5], + "search_score": p.get("search_score"), + } + for p in profiles + ], + "total": len(profiles), + } + + +@_tool_handler("add_memory") +@_tool_handler("propose_write") +def _handle_add_memory(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + content = arguments.get("content", "") + user_id = arguments.get("user_id", "default") + agent_id = arguments.get("agent_id") + namespace = arguments.get("namespace", "default") + token = _session_token( + user_id=user_id, + agent_id=agent_id, + capabilities=["propose_write"], + namespaces=[namespace], + ) + return memory.propose_write( + content=content, + user_id=user_id, + agent_id=agent_id, + categories=arguments.get("categories"), + metadata=arguments.get("metadata"), + scope=arguments.get("scope", "work"), + namespace=namespace, + mode=arguments.get("mode", "staging"), + infer=False, + token=token, + source_app="mcp", + source_type="mcp", + source_event_id=arguments.get("source_event_id"), + ) + + +@_tool_handler("search_memory") +def _handle_search_memory(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + query = arguments.get("query", "") + user_id = arguments.get("user_id", "default") + agent_id = arguments.get("agent_id") + limit = arguments.get("limit", 10) + categories = arguments.get("categories") + if agent_id: + token = _session_token(user_id=user_id, agent_id=agent_id, capabilities=["search"]) + result = memory.search_with_context( + query=query, user_id=user_id, agent_id=agent_id, token=token, limit=limit, categories=categories, + ) + else: + result = memory.search( + query=query, user_id=user_id, agent_id=agent_id, limit=limit, categories=categories, + agent_category=arguments.get("agent_category"), + connector_ids=arguments.get("connector_ids"), + scope_filter=arguments.get("scope_filter"), + ) + if "results" in result: + result["results"] = [ + { + "id": r.get("id"), + "memory": r.get("memory", r.get("details", "")), + "score": round(r.get("composite_score", r.get("score", 0)), 3), + "layer": r.get("layer", "sml"), + "categories": r.get("categories", []), + "scope": r.get("scope"), + "agent_category": r.get("agent_category"), + "connector_id": r.get("connector_id"), + } + for r in result["results"] + ] + return result + + +@_tool_handler("get_all_memories") +def _handle_get_all_memories(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + result = memory.get_all( + user_id=arguments.get("user_id", "default"), + agent_id=arguments.get("agent_id"), + limit=arguments.get("limit", 50), + layer=arguments.get("layer"), + ) + if "results" in result: + result["results"] = [ + { + "id": r["id"], + "memory": r["memory"], + "layer": r.get("layer", "sml"), + "strength": round(r.get("strength", 1.0), 3), + "categories": r.get("categories", []), + } + for r in result["results"] + ] + return result + + +@_tool_handler("remember") +def _handle_remember(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + namespace = arguments.get("namespace", "default") + token = _session_token( + user_id="default", + agent_id="claude-code", + capabilities=["propose_write"], + namespaces=[namespace], + ) + return memory.propose_write( + content=arguments.get("content", ""), + user_id="default", + agent_id="claude-code", + categories=arguments.get("categories"), + scope=arguments.get("scope", "work"), + namespace=namespace, + mode=arguments.get("mode", "staging"), + source_app="claude-code", + source_type="mcp", + infer=False, + token=token, + ) + + +@_tool_handler("list_pending_commits") +def _handle_list_pending_commits(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + agent_id = arguments.get("agent_id", "claude-code") + token = _session_token(user_id=user_id, agent_id=agent_id, capabilities=["review_commits"]) + return memory.list_pending_commits( + user_id=user_id, agent_id=agent_id, token=token, + status=arguments.get("status"), limit=arguments.get("limit", 100), + ) + + +@_tool_handler("resolve_conflict") +def _handle_resolve_conflict(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + agent_id = arguments.get("agent_id", "claude-code") + token = _session_token( + user_id=arguments.get("user_id", "default"), + agent_id=agent_id, + capabilities=["resolve_conflicts"], + ) + return memory.resolve_conflict( + stash_id=arguments.get("stash_id", ""), + resolution=arguments.get("resolution", "UNRESOLVED"), + token=token, agent_id=agent_id, + ) + + +@_tool_handler("declare_namespace") +def _handle_declare_namespace(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + caller_agent_id = arguments.get("agent_id", "claude-code") + namespace = arguments.get("namespace", "default") + token = _session_token( + user_id=user_id, agent_id=caller_agent_id, + capabilities=["manage_namespaces"], namespaces=[namespace], + ) + return memory.declare_namespace( + user_id=user_id, namespace=namespace, + description=arguments.get("description"), token=token, agent_id=caller_agent_id, + ) + + +@_tool_handler("grant_namespace_permission") +def _handle_grant_namespace_permission(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) + namespace = arguments.get("namespace", "default") + token = _session_token( + user_id=user_id, agent_id=requester_agent_id, + capabilities=["manage_namespaces"], namespaces=[namespace], + ) + return memory.grant_namespace_permission( + user_id=user_id, namespace=namespace, + agent_id=arguments.get("agent_id", "claude-code"), + capability=arguments.get("capability", "read"), + expires_at=arguments.get("expires_at"), + token=token, requester_agent_id=requester_agent_id, + ) + + +@_tool_handler("upsert_agent_policy") +def _handle_upsert_agent_policy(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) + token = _session_token(user_id=user_id, agent_id=requester_agent_id, capabilities=["manage_namespaces"]) + return memory.upsert_agent_policy( + user_id=user_id, + agent_id=arguments.get("agent_id", "claude-code"), + allowed_confidentiality_scopes=arguments.get("allowed_confidentiality_scopes"), + allowed_capabilities=arguments.get("allowed_capabilities"), + allowed_namespaces=arguments.get("allowed_namespaces"), + token=token, requester_agent_id=requester_agent_id, + ) + + +@_tool_handler("list_agent_policies") +def _handle_list_agent_policies(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) + token = _session_token(user_id=user_id, agent_id=requester_agent_id, capabilities=["manage_namespaces"]) + lookup_agent_id = arguments.get("agent_id") + if lookup_agent_id: + return memory.get_agent_policy( + user_id=user_id, agent_id=lookup_agent_id, + include_wildcard=arguments.get("include_wildcard", True), + token=token, requester_agent_id=requester_agent_id, + ) + return memory.list_agent_policies( + user_id=user_id, token=token, requester_agent_id=requester_agent_id, + ) + + +@_tool_handler("delete_agent_policy") +def _handle_delete_agent_policy(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) + token = _session_token(user_id=user_id, agent_id=requester_agent_id, capabilities=["manage_namespaces"]) + return memory.delete_agent_policy( + user_id=user_id, agent_id=arguments.get("agent_id", "claude-code"), + token=token, requester_agent_id=requester_agent_id, + ) + + +@_tool_handler("get_agent_trust") +def _handle_get_agent_trust(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) + token = _session_token(user_id=user_id, agent_id=requester_agent_id, capabilities=["read_trust"]) + return memory.get_agent_trust( + user_id=user_id, agent_id=arguments.get("agent_id", "claude-code"), + token=token, requester_agent_id=requester_agent_id, + ) + + +@_tool_handler("run_sleep_cycle") +def _handle_run_sleep_cycle(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + caller_agent_id = arguments.get("agent_id", "claude-code") + token = _session_token(user_id=user_id, agent_id=caller_agent_id, capabilities=["run_sleep_cycle"]) + return memory.run_sleep_cycle( + user_id=arguments.get("user_id"), + date_str=arguments.get("date"), + apply_decay=arguments.get("apply_decay", True), + cleanup_stale_refs=arguments.get("cleanup_stale_refs", True), + token=token, agent_id=caller_agent_id, + ) + + +@_tool_handler("get_scene") +def _handle_get_scene(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + agent_id = arguments.get("agent_id", "claude-code") + token = _session_token(user_id=user_id, agent_id=agent_id, capabilities=["read_scene"]) + scene = memory.kernel.get_scene( + scene_id=arguments.get("scene_id", ""), + user_id=user_id, agent_id=agent_id, token=token, + ) + return scene if scene else {"error": "Scene not found"} + + +@_tool_handler("list_scenes") +def _handle_list_scenes(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + scenes = memory.get_scenes( + user_id=arguments.get("user_id", "default"), + topic=arguments.get("topic"), + start_after=arguments.get("start_after"), + start_before=arguments.get("start_before"), + limit=arguments.get("limit", 20), + ) + return { + "scenes": [ + { + "id": s["id"], + "title": s.get("title"), + "topic": s.get("topic"), + "summary": s.get("summary"), + "start_time": s.get("start_time"), + "end_time": s.get("end_time"), + "memory_count": len(s.get("memory_ids", [])), + } + for s in scenes + ], + "total": len(scenes), + } + + +@_tool_handler("search_scenes") +def _handle_search_scenes(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + user_id = arguments.get("user_id", "default") + agent_id = arguments.get("agent_id", "claude-code") + token = _session_token(user_id=user_id, agent_id=agent_id, capabilities=["read_scene"]) + payload = memory.kernel.search_scenes( + query=arguments.get("query", ""), + user_id=user_id, agent_id=agent_id, token=token, + limit=arguments.get("limit", 10), + ) + scenes = payload.get("scenes", []) + return { + "scenes": [ + { + "id": s.get("id"), + "title": s.get("title"), + "summary": s.get("summary", s.get("details")), + "topic": s.get("topic"), + "start_time": s.get("start_time", s.get("time")), + "search_score": s.get("search_score"), + "memory_count": len(s.get("memory_ids", [])), + "masked": bool(s.get("masked", False)), + } + for s in scenes + ], + "total": len(scenes), + } @server.call_tool() @@ -891,31 +1511,13 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: memory = get_memory() result: Any = None - def _session_token( - *, - user_id: str, - agent_id: Optional[str], - capabilities: List[str], - namespaces: Optional[List[str]] = None, - ) -> str: - session = memory.create_session( - user_id=user_id, - agent_id=agent_id, - allowed_confidentiality_scopes=["work", "personal", "finance", "health", "private"], - capabilities=capabilities, - namespaces=namespaces, - ttl_minutes=24 * 60, - ) - return session["token"] + def _session_token(*, user_id: str, agent_id: Optional[str], capabilities: List[str], namespaces: Optional[List[str]] = None) -> str: + return _make_session_token(memory, user_id=user_id, agent_id=agent_id, capabilities=capabilities, namespaces=namespaces) - def _preview(value: Any, limit: int = 1200) -> str: - try: - text = json.dumps(value, default=str) - except Exception: - text = str(value) - if len(text) > limit: - return text[:limit] + "...(truncated)" - return text + def _handoff_error_payload(exc: Exception) -> Dict[str, str]: + if isinstance(exc, HandoffBackendError): + return exc.to_dict() + return classify_handoff_error(exc).to_dict() auto_handoff_enabled = bool( getattr(memory, "handoff_processor", None) @@ -924,8 +1526,9 @@ def _preview(value: Any, limit: int = 1200) -> str: ) auto_handoff_skip_tools = {"save_session_digest", "get_last_session", "list_sessions"} auto_handoff_context: Dict[str, Any] = {} - auto_handoff_token: Optional[str] = None + auto_handoff_meta: Dict[str, Any] = {} auto_resume_packet: Optional[Dict[str, Any]] = None + handoff_backend = None if auto_handoff_enabled and name not in auto_handoff_skip_tools: caller_agent_id = ( @@ -956,518 +1559,91 @@ def _preview(value: Any, limit: int = 1200) -> str: "confidentiality_scope": arguments.get("confidentiality_scope", "work"), } try: - auto_handoff_token = _session_token( - user_id=user_id, - agent_id=caller_agent_id, - capabilities=["read_handoff", "write_handoff"], - namespaces=[namespace], - ) - except Exception: - auto_handoff_token = None - try: - auto_resume_packet = memory.auto_resume_context( + handoff_backend = get_handoff_backend(memory) + except Exception as backend_exc: + auto_handoff_meta["error"] = _handoff_error_payload(backend_exc) + handoff_backend = None + + if handoff_backend is not None: + auto_handoff_key = _handoff_key( user_id=user_id, agent_id=caller_agent_id, - repo_path=repo_path, - branch=arguments.get("branch"), - lane_type=arguments.get("lane_type", "general"), - objective=objective, - agent_role=arguments.get("agent_role"), namespace=namespace, - token=auto_handoff_token, - requester_agent_id=caller_agent_id, - auto_create=True, - ) - except Exception: - auto_resume_packet = None - - if name in {"add_memory", "propose_write"}: - content = arguments.get("content", "") - user_id = arguments.get("user_id", "default") - agent_id = arguments.get("agent_id") - categories = arguments.get("categories") - metadata = arguments.get("metadata") - scope = arguments.get("scope", "work") - namespace = arguments.get("namespace", "default") - mode = arguments.get("mode", "staging") - source_event_id = arguments.get("source_event_id") - token = _session_token( - user_id=user_id, - agent_id=agent_id, - capabilities=["propose_write"], - namespaces=[namespace], - ) - - result = memory.propose_write( - content=content, - user_id=user_id, - agent_id=agent_id, - categories=categories, - metadata=metadata, - scope=scope, - namespace=namespace, - mode=mode, - infer=False, - token=token, - source_app="mcp", - source_type="mcp", - source_event_id=source_event_id, - ) - - elif name == "search_memory": - query = arguments.get("query", "") - user_id = arguments.get("user_id", "default") - agent_id = arguments.get("agent_id") - limit = arguments.get("limit", 10) - categories = arguments.get("categories") - agent_category = arguments.get("agent_category") - connector_ids = arguments.get("connector_ids") - scope_filter = arguments.get("scope_filter") - if agent_id: - token = _session_token( - user_id=user_id, - agent_id=agent_id, - capabilities=["search"], - ) - result = memory.search_with_context( - query=query, - user_id=user_id, - agent_id=agent_id, - token=token, - limit=limit, - categories=categories, + repo_id=None, + repo_path=repo_path, ) - else: - result = memory.search( - query=query, - user_id=user_id, - agent_id=agent_id, - limit=limit, - categories=categories, - agent_category=agent_category, - connector_ids=connector_ids, - scope_filter=scope_filter, + now_ts = time.time() + with _lifecycle_lock: + previous_context = dict(_lifecycle_state.get(auto_handoff_key, {})) + last_activity_ts = float(previous_context.get("last_activity_ts", now_ts)) + idle_for_seconds = max(0.0, now_ts - last_activity_ts) + if ( + previous_context + and idle_for_seconds >= _idle_pause_seconds + and "agent_pause" in getattr(memory.handoff_config, "auto_checkpoint_events", []) + ): + try: + pause_result = _emit_lifecycle_checkpoint( + memory, + previous_context, + event_type="agent_pause", + task_summary=previous_context.get("task_summary") or f"Idle pause before {name}", + ) + auto_handoff_meta["pause"] = pause_result + except Exception as pause_exc: + auto_handoff_meta["pause"] = {"error": _handoff_error_payload(pause_exc)} + + try: + auto_resume_packet = handoff_backend.auto_resume_context( + user_id=user_id, + agent_id=caller_agent_id, + namespace=namespace, + repo_path=repo_path, + branch=arguments.get("branch"), + lane_type=arguments.get("lane_type", "general"), + objective=objective, + agent_role=arguments.get("agent_role"), + ) + if auto_resume_packet: + auto_handoff_context["lane_id"] = auto_resume_packet.get("lane_id") or auto_handoff_context["lane_id"] + auto_handoff_context["repo_id"] = auto_resume_packet.get("repo_id") + auto_handoff_context["task_summary"] = auto_resume_packet.get("task_summary") + except Exception as resume_exc: + auto_handoff_meta["error"] = _handoff_error_payload(resume_exc) + auto_resume_packet = None + + if handoff_backend is None and _strict_handoff_enabled(memory): + auto_handoff_meta.setdefault( + "error", + {"code": "hosted_backend_unavailable", "message": "Handoff backend is unavailable"}, ) - # Simplify results for readability - if "results" in result: - result["results"] = [ - { - "id": r.get("id"), - "memory": r.get("memory", r.get("details", "")), - "score": round(r.get("composite_score", r.get("score", 0)), 3), - "layer": r.get("layer", "sml"), - "categories": r.get("categories", []), - "scope": r.get("scope"), - "agent_category": r.get("agent_category"), - "connector_id": r.get("connector_id"), - } - for r in result["results"] - ] - elif name == "get_all_memories": - user_id = arguments.get("user_id", "default") - agent_id = arguments.get("agent_id") - limit = arguments.get("limit", 50) - layer = arguments.get("layer") - - result = memory.get_all( - user_id=user_id, - agent_id=agent_id, - limit=limit, - layer=layer, - ) - # Simplify results - if "results" in result: - result["results"] = [ - { - "id": r["id"], - "memory": r["memory"], - "layer": r.get("layer", "sml"), - "strength": round(r.get("strength", 1.0), 3), - "categories": r.get("categories", []), - } - for r in result["results"] - ] - - elif name == "get_memory": - memory_id = arguments.get("memory_id", "") - result = memory.get(memory_id) - if result: - result = { - "id": result["id"], - "memory": result["memory"], - "layer": result.get("layer", "sml"), - "strength": round(result.get("strength", 1.0), 3), - "categories": result.get("categories", []), - "created_at": result.get("created_at"), - "access_count": result.get("access_count", 0), + if handoff_backend is not None and auto_resume_packet is None and "error" not in auto_handoff_meta: + auto_handoff_meta["error"] = { + "code": "lane_resolution_failed", + "message": "Unable to build resume context", } - else: - result = {"error": "Memory not found"} - - elif name == "update_memory": - memory_id = arguments.get("memory_id", "") - content = arguments.get("content", "") - result = memory.update(memory_id, content) - - elif name == "delete_memory": - memory_id = arguments.get("memory_id", "") - result = memory.delete(memory_id) - - elif name == "get_memory_stats": - user_id = arguments.get("user_id") - agent_id = arguments.get("agent_id") - result = memory.get_stats(user_id=user_id, agent_id=agent_id) - - elif name == "apply_memory_decay": - user_id = arguments.get("user_id") - agent_id = arguments.get("agent_id") - scope = {"user_id": user_id, "agent_id": agent_id} if user_id or agent_id else None - result = memory.apply_decay(scope=scope) - - elif name == "engram_context": - user_id = arguments.get("user_id", "default") - limit = arguments.get("limit", 15) - all_result = memory.get_all(user_id=user_id, limit=limit * 3) - all_memories = all_result.get("results", []) - # Sort: LML first, then by strength descending - layer_order = {"lml": 0, "sml": 1} - all_memories.sort(key=lambda m: ( - layer_order.get(m.get("layer", "sml"), 1), - -float(m.get("strength", 1.0)) - )) - digest = [ - { - "id": m["id"], - "memory": m.get("memory", ""), - "layer": m.get("layer", "sml"), - "strength": round(float(m.get("strength", 1.0)), 3), - "categories": m.get("categories", []), - } - for m in all_memories[:limit] - ] - result = {"digest": digest, "total_in_store": len(all_memories), "returned": len(digest)} - - elif name == "remember": - content = arguments.get("content", "") - categories = arguments.get("categories") - token = _session_token( - user_id="default", - agent_id="claude-code", - capabilities=["propose_write"], - namespaces=[arguments.get("namespace", "default")], - ) - result = memory.propose_write( - content=content, - user_id="default", - agent_id="claude-code", - categories=categories, - scope=arguments.get("scope", "work"), - namespace=arguments.get("namespace", "default"), - mode=arguments.get("mode", "staging"), - source_app="claude-code", - source_type="mcp", - infer=False, - token=token, - ) - - elif name == "list_pending_commits": - user_id = arguments.get("user_id", "default") - agent_id = arguments.get("agent_id", "claude-code") - token = _session_token( - user_id=user_id, - agent_id=agent_id, - capabilities=["review_commits"], - ) - result = memory.list_pending_commits( - user_id=user_id, - agent_id=agent_id, - token=token, - status=arguments.get("status"), - limit=arguments.get("limit", 100), - ) - - elif name == "resolve_conflict": - agent_id = arguments.get("agent_id", "claude-code") - # Conflict ownership is resolved from stash; session user can stay default. - token = _session_token( - user_id=arguments.get("user_id", "default"), - agent_id=agent_id, - capabilities=["resolve_conflicts"], - ) - result = memory.resolve_conflict( - stash_id=arguments.get("stash_id", ""), - resolution=arguments.get("resolution", "UNRESOLVED"), - token=token, - agent_id=agent_id, - ) - - elif name == "declare_namespace": - user_id = arguments.get("user_id", "default") - caller_agent_id = arguments.get("agent_id", "claude-code") - token = _session_token( - user_id=user_id, - agent_id=caller_agent_id, - capabilities=["manage_namespaces"], - namespaces=[arguments.get("namespace", "default")], - ) - result = memory.declare_namespace( - user_id=user_id, - namespace=arguments.get("namespace", "default"), - description=arguments.get("description"), - token=token, - agent_id=caller_agent_id, - ) - - elif name == "grant_namespace_permission": - user_id = arguments.get("user_id", "default") - requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["manage_namespaces"], - namespaces=[arguments.get("namespace", "default")], - ) - result = memory.grant_namespace_permission( - user_id=user_id, - namespace=arguments.get("namespace", "default"), - agent_id=arguments.get("agent_id", "claude-code"), - capability=arguments.get("capability", "read"), - expires_at=arguments.get("expires_at"), - token=token, - requester_agent_id=requester_agent_id, - ) - - elif name == "upsert_agent_policy": - user_id = arguments.get("user_id", "default") - requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["manage_namespaces"], - ) - result = memory.upsert_agent_policy( - user_id=user_id, - agent_id=arguments.get("agent_id", "claude-code"), - allowed_confidentiality_scopes=arguments.get("allowed_confidentiality_scopes"), - allowed_capabilities=arguments.get("allowed_capabilities"), - allowed_namespaces=arguments.get("allowed_namespaces"), - token=token, - requester_agent_id=requester_agent_id, - ) - - elif name == "list_agent_policies": - user_id = arguments.get("user_id", "default") - requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["manage_namespaces"], - ) - lookup_agent_id = arguments.get("agent_id") - if lookup_agent_id: - result = memory.get_agent_policy( - user_id=user_id, - agent_id=lookup_agent_id, - include_wildcard=arguments.get("include_wildcard", True), - token=token, - requester_agent_id=requester_agent_id, - ) - else: - result = memory.list_agent_policies( - user_id=user_id, - token=token, - requester_agent_id=requester_agent_id, - ) - - elif name == "delete_agent_policy": - user_id = arguments.get("user_id", "default") - requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["manage_namespaces"], - ) - result = memory.delete_agent_policy( - user_id=user_id, - agent_id=arguments.get("agent_id", "claude-code"), - token=token, - requester_agent_id=requester_agent_id, - ) - - elif name == "get_agent_trust": - user_id = arguments.get("user_id", "default") - requester_agent_id = arguments.get("requester_agent_id", arguments.get("agent_id", "claude-code")) - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["read_trust"], - ) - result = memory.get_agent_trust( - user_id=user_id, - agent_id=arguments.get("agent_id", "claude-code"), - token=token, - requester_agent_id=requester_agent_id, - ) - - elif name == "run_sleep_cycle": - user_id = arguments.get("user_id", "default") - caller_agent_id = arguments.get("agent_id", "claude-code") - token = _session_token( - user_id=user_id, - agent_id=caller_agent_id, - capabilities=["run_sleep_cycle"], - ) - result = memory.run_sleep_cycle( - user_id=arguments.get("user_id"), - date_str=arguments.get("date"), - apply_decay=arguments.get("apply_decay", True), - cleanup_stale_refs=arguments.get("cleanup_stale_refs", True), - token=token, - agent_id=caller_agent_id, - ) - - # ---- Scene tools ---- - elif name == "get_scene": - scene_id = arguments.get("scene_id", "") - user_id = arguments.get("user_id", "default") - agent_id = arguments.get("agent_id", "claude-code") - token = _session_token( - user_id=user_id, - agent_id=agent_id, - capabilities=["read_scene"], - ) - scene = memory.kernel.get_scene( - scene_id=scene_id, - user_id=user_id, - agent_id=agent_id, - token=token, - ) - if scene: - result = scene - else: - result = {"error": "Scene not found"} - - elif name == "list_scenes": - user_id = arguments.get("user_id", "default") - scenes = memory.get_scenes( - user_id=user_id, - topic=arguments.get("topic"), - start_after=arguments.get("start_after"), - start_before=arguments.get("start_before"), - limit=arguments.get("limit", 20), - ) - result = { - "scenes": [ - { - "id": s["id"], - "title": s.get("title"), - "topic": s.get("topic"), - "summary": s.get("summary"), - "start_time": s.get("start_time"), - "end_time": s.get("end_time"), - "memory_count": len(s.get("memory_ids", [])), - } - for s in scenes - ], - "total": len(scenes), - } - - elif name == "search_scenes": - query = arguments.get("query", "") - user_id = arguments.get("user_id", "default") - agent_id = arguments.get("agent_id", "claude-code") - limit = arguments.get("limit", 10) - token = _session_token( - user_id=user_id, - agent_id=agent_id, - capabilities=["read_scene"], - ) - payload = memory.kernel.search_scenes( - query=query, - user_id=user_id, - agent_id=agent_id, - token=token, - limit=limit, - ) - scenes = payload.get("scenes", []) - result = { - "scenes": [ - { - "id": s.get("id"), - "title": s.get("title"), - "summary": s.get("summary", s.get("details")), - "topic": s.get("topic"), - "start_time": s.get("start_time", s.get("time")), - "search_score": s.get("search_score"), - "memory_count": len(s.get("memory_ids", [])), - "masked": bool(s.get("masked", False)), - } - for s in scenes - ], - "total": len(scenes), - } + elif auto_handoff_enabled: + try: + handoff_backend = get_handoff_backend(memory) + except Exception: + handoff_backend = None - # ---- Profile tools ---- - elif name == "get_profile": - profile_id = arguments.get("profile_id", "") - profile = memory.get_profile(profile_id) - if profile: - profile.pop("embedding", None) - result = profile - else: - result = {"error": "Profile not found"} + if auto_handoff_context: + _record_handoff_context(auto_handoff_context) - elif name == "list_profiles": - user_id = arguments.get("user_id", "default") - profiles = memory.get_all_profiles(user_id=user_id) - result = { - "profiles": [ - { - "id": p["id"], - "name": p.get("name"), - "profile_type": p.get("profile_type"), - "narrative": p.get("narrative"), - "fact_count": len(p.get("facts", [])), - "preference_count": len(p.get("preferences", [])), - } - for p in profiles - ], - "total": len(profiles), - } + # Tool dispatch: registry handles all tools except handoff tools + # (which need access to the local handoff_backend variable). + handler = _TOOL_HANDLERS.get(name) + if handler: + result = handler(memory, arguments, _session_token, _preview) - elif name == "search_profiles": - query = arguments.get("query", "") - user_id = arguments.get("user_id", "default") - limit = arguments.get("limit", 10) - profiles = memory.search_profiles(query=query, user_id=user_id, limit=limit) - result = { - "profiles": [ - { - "id": p["id"], - "name": p.get("name"), - "profile_type": p.get("profile_type"), - "narrative": p.get("narrative"), - "facts": p.get("facts", [])[:5], - "search_score": p.get("search_score"), - } - for p in profiles - ], - "total": len(profiles), - } - - # ---- Handoff tools ---- + # ---- Handoff tools (need local handoff_backend) ---- elif name == "save_session_digest": user_id = arguments.get("user_id", "default") agent_id = arguments.get("agent_id", "claude-code") requester_agent_id = arguments.get("requester_agent_id", agent_id) namespace = arguments.get("namespace", "default") - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["write_handoff"], - namespaces=[namespace], - ) task_summary = str(arguments.get("task_summary", "")).strip() if not task_summary: result = {"error": "task_summary is required"} @@ -1492,13 +1668,18 @@ def _preview(value: Any, limit: int = 1200) -> str: "started_at": arguments.get("started_at"), "ended_at": arguments.get("ended_at"), } - result = memory.save_session_digest( - user_id, - agent_id, - digest, - token=token, - requester_agent_id=requester_agent_id, - ) + try: + handoff_backend = handoff_backend or get_handoff_backend(memory) + result = handoff_backend.save_session_digest( + user_id=user_id, + agent_id=agent_id, + requester_agent_id=requester_agent_id, + namespace=namespace, + digest=digest, + ) + except Exception as handoff_exc: + error_payload = _handoff_error_payload(handoff_exc) + result = {"error": error_payload["message"], "_handoff": {"error": error_payload}} elif name == "get_last_session": user_id = arguments.get("user_id", "default") @@ -1508,25 +1689,21 @@ def _preview(value: Any, limit: int = 1200) -> str: arguments.get("agent_id", "claude-code"), ) namespace = arguments.get("namespace", "default") - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["read_handoff"], - namespaces=[namespace], - ) repo = arguments.get("repo") - session = memory.get_last_session( - user_id, - agent_id=agent_id, - repo=repo, - statuses=arguments.get("statuses"), - token=token, - requester_agent_id=requester_agent_id, - ) - if session: - result = session - else: - result = {"error": "No sessions found"} + try: + handoff_backend = handoff_backend or get_handoff_backend(memory) + session = handoff_backend.get_last_session( + user_id=user_id, + agent_id=agent_id, + requester_agent_id=requester_agent_id, + namespace=namespace, + repo=repo, + statuses=arguments.get("statuses"), + ) + result = session if session else {"error": "No sessions found"} + except Exception as handoff_exc: + error_payload = _handoff_error_payload(handoff_exc) + result = {"error": error_payload["message"], "_handoff": {"error": error_payload}} elif name == "list_sessions": user_id = arguments.get("user_id", "default") @@ -1535,39 +1712,38 @@ def _preview(value: Any, limit: int = 1200) -> str: arguments.get("agent_id", "claude-code"), ) namespace = arguments.get("namespace", "default") - token = _session_token( - user_id=user_id, - agent_id=requester_agent_id, - capabilities=["read_handoff"], - namespaces=[namespace], - ) - sessions = memory.list_sessions( - user_id=user_id, - agent_id=arguments.get("agent_id"), - repo=arguments.get("repo"), - status=arguments.get("status"), - statuses=arguments.get("statuses"), - limit=arguments.get("limit", 20), - token=token, - requester_agent_id=requester_agent_id, - ) - result = { - "sessions": [ - { - "id": s["id"], - "agent_id": s.get("agent_id"), - "repo": s.get("repo"), - "repo_id": s.get("repo_id"), - "lane_id": s.get("lane_id"), - "status": s.get("status"), - "task_summary": s.get("task_summary", "")[:200], - "last_checkpoint_at": s.get("last_checkpoint_at"), - "updated_at": s.get("updated_at"), - } - for s in sessions - ], - "total": len(sessions), - } + try: + handoff_backend = handoff_backend or get_handoff_backend(memory) + sessions = handoff_backend.list_sessions( + user_id=user_id, + agent_id=arguments.get("agent_id"), + requester_agent_id=requester_agent_id, + namespace=namespace, + repo=arguments.get("repo"), + status=arguments.get("status"), + statuses=arguments.get("statuses"), + limit=arguments.get("limit", 20), + ) + result = { + "sessions": [ + { + "id": s["id"], + "agent_id": s.get("agent_id"), + "repo": s.get("repo"), + "repo_id": s.get("repo_id"), + "lane_id": s.get("lane_id"), + "status": s.get("status"), + "task_summary": s.get("task_summary", "")[:200], + "last_checkpoint_at": s.get("last_checkpoint_at"), + "updated_at": s.get("updated_at"), + } + for s in sessions + ], + "total": len(sessions), + } + except Exception as handoff_exc: + error_payload = _handoff_error_payload(handoff_exc) + result = {"error": error_payload["message"], "_handoff": {"error": error_payload}} else: result = {"error": f"Unknown tool: {name}"} @@ -1602,31 +1778,43 @@ def _preview(value: Any, limit: int = 1200) -> str: ), } try: - checkpoint_result = memory.auto_checkpoint( + handoff_backend = handoff_backend or get_handoff_backend(memory) + checkpoint_result = handoff_backend.auto_checkpoint( user_id=auto_handoff_context["user_id"], agent_id=auto_handoff_context["agent_id"], - payload=checkpoint_payload, - event_type="tool_complete", + namespace=auto_handoff_context["namespace"], repo_path=auto_handoff_context["repo_path"], branch=auto_handoff_context["branch"], - lane_id=auto_handoff_context["lane_id"], + lane_id=auto_handoff_context.get("lane_id"), lane_type=auto_handoff_context["lane_type"], objective=auto_handoff_context["objective"], agent_role=auto_handoff_context["agent_role"], - namespace=auto_handoff_context["namespace"], confidentiality_scope=auto_handoff_context["confidentiality_scope"], - token=auto_handoff_token, - requester_agent_id=auto_handoff_context["agent_id"], + payload=checkpoint_payload, + event_type="tool_complete", ) + if isinstance(checkpoint_result, dict) and checkpoint_result.get("lane_id"): + auto_handoff_context["lane_id"] = checkpoint_result["lane_id"] + auto_handoff_context["task_summary"] = checkpoint_payload["task_summary"] + auto_handoff_context["context_snapshot"] = checkpoint_payload["context_snapshot"] + _record_handoff_context(auto_handoff_context) except Exception as checkpoint_exc: - checkpoint_result = {"error": str(checkpoint_exc)} + checkpoint_result = {"error": _handoff_error_payload(checkpoint_exc)} if isinstance(result, dict): handoff_meta: Dict[str, Any] = {"checkpoint": checkpoint_result} + if auto_handoff_meta: + handoff_meta.update(auto_handoff_meta) if auto_resume_packet: handoff_meta["resume"] = auto_resume_packet result["_handoff"] = handoff_meta + if isinstance(result, dict) and auto_handoff_meta and "_handoff" not in result: + handoff_meta = dict(auto_handoff_meta) + if auto_resume_packet: + handoff_meta["resume"] = auto_resume_packet + result["_handoff"] = handoff_meta + return [TextContent(type="text", text=json.dumps(result, indent=2, default=str))] except Exception as e: @@ -1643,6 +1831,7 @@ async def main(): def run(): """Entry point for the MCP server.""" import asyncio + _register_shutdown_hooks() asyncio.run(main()) diff --git a/engram/memory/client.py b/engram/memory/client.py index 1162a71..4904e58 100644 --- a/engram/memory/client.py +++ b/engram/memory/client.py @@ -271,6 +271,52 @@ def list_handoff_lanes( params["statuses"] = statuses return self._request("GET", "/v1/handoff/lanes", params=params) + def save_session_digest(self, **kwargs) -> Dict[str, Any]: + payload = dict(kwargs) + return self._request("POST", "/v1/handoff/sessions/digest", json_body=payload) + + def get_last_session( + self, + *, + user_id: str, + agent_id: Optional[str] = None, + requester_agent_id: Optional[str] = None, + repo: Optional[str] = None, + statuses: Optional[List[str]] = None, + ) -> Dict[str, Any]: + params: Dict[str, Any] = { + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": requester_agent_id, + "repo": repo, + } + if statuses: + params["statuses"] = statuses + return self._request("GET", "/v1/handoff/sessions/last", params=params) + + def list_sessions( + self, + *, + user_id: str, + agent_id: Optional[str] = None, + requester_agent_id: Optional[str] = None, + repo: Optional[str] = None, + status: Optional[str] = None, + statuses: Optional[List[str]] = None, + limit: int = 20, + ) -> Dict[str, Any]: + params: Dict[str, Any] = { + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": requester_agent_id, + "repo": repo, + "status": status, + "limit": limit, + } + if statuses: + params["statuses"] = statuses + return self._request("GET", "/v1/handoff/sessions", params=params) + def get_agent_trust(self, *, user_id: str, agent_id: str) -> Dict[str, Any]: return self._request("GET", "/v1/trust", params={"user_id": user_id, "agent_id": agent_id}) diff --git a/engram/memory/main.py b/engram/memory/main.py index 0cbc586..2f5b1c2 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -4,7 +4,7 @@ import logging import os import uuid -from datetime import datetime, date +from datetime import datetime, date, timezone from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -183,6 +183,9 @@ def __init__(self, config: Optional[MemoryConfig] = None): config={ "auto_enrich": self.handoff_config.auto_enrich, "max_sessions": self.handoff_config.max_sessions_per_user, + "handoff_backend": self.handoff_config.handoff_backend, + "strict_handoff_auth": self.handoff_config.strict_handoff_auth, + "allow_auto_trusted_bootstrap": self.handoff_config.allow_auto_trusted_bootstrap, "auto_session_bus": self.handoff_config.auto_session_bus, "lane_inactivity_minutes": self.handoff_config.lane_inactivity_minutes, "max_lanes_per_user": self.handoff_config.max_lanes_per_user, @@ -197,6 +200,9 @@ def __init__(self, config: Optional[MemoryConfig] = None): # v2 Personal Memory Kernel orchestration layer. self.kernel = PersonalMemoryKernel(self) + def __repr__(self) -> str: + return f"Memory(db={self.db!r}, echo={self.echo_config.enable_echo}, scenes={self.scene_config.enable_scenes})" + @classmethod def from_config(cls, config_dict: Dict[str, Any]): return cls(MemoryConfig(**config_dict)) @@ -261,324 +267,380 @@ def add( results: List[Dict[str, Any]] = [] for mem in memories_to_add: - content = mem.get("content", "").strip() - if not content: - continue + result = self._process_single_memory( + mem=mem, + processed_metadata=processed_metadata, + effective_filters=effective_filters, + categories=categories, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + app_id=app_id, + agent_category=agent_category, + connector_id=connector_id, + scope=scope, + source_app=source_app, + immutable=immutable, + expiration_date=expiration_date, + initial_layer=initial_layer, + initial_strength=initial_strength, + echo_depth=echo_depth, + ) + if result is not None: + results.append(result) - mem_categories = normalize_categories(categories or mem.get("categories")) - mem_metadata = dict(processed_metadata) - mem_metadata.update(mem.get("metadata", {})) - if app_id: - mem_metadata["app_id"] = app_id - - role = mem_metadata.get("role", "user") - explicit_intent = detect_explicit_intent(content) if role == "user" else None - explicit_action = explicit_intent.action if explicit_intent else None - explicit_remember = bool(mem_metadata.get("explicit_remember")) or explicit_action == "remember" - explicit_forget = bool(mem_metadata.get("explicit_forget")) or explicit_action == "forget" - - if explicit_forget: - query = explicit_intent.content if explicit_intent else "" - forget_filters = {"user_id": user_id} if user_id else dict(effective_filters) - forget_result = self._forget_by_query(query, forget_filters) - results.append( - { - "event": "FORGET", - "query": query, - "deleted_count": forget_result.get("deleted_count", 0), - "deleted_ids": forget_result.get("deleted_ids", []), - } - ) - continue + # Persist categories after batch + if self.category_processor: + self._persist_categories() - if explicit_remember and explicit_intent and explicit_intent.content: - content = explicit_intent.content - - blocked = detect_sensitive_categories(content) - allow_sensitive = bool(mem_metadata.get("allow_sensitive")) - if blocked and not allow_sensitive: - results.append( - { - "event": "BLOCKED", - "reason": "sensitive", - "blocked_categories": blocked, - "memory": content, - } - ) - continue + return {"results": results} - if not explicit_remember and is_ephemeral(content): - results.append( - { - "event": "SKIP", - "reason": "ephemeral", - "memory": content, - } - ) - continue + def _resolve_memory_metadata( + self, + *, + content: str, + mem_metadata: Dict[str, Any], + explicit_remember: bool, + agent_id: Optional[str], + run_id: Optional[str], + app_id: Optional[str], + effective_filters: Dict[str, Any], + agent_category: Optional[str], + connector_id: Optional[str], + scope: Optional[str], + source_app: Optional[str], + ) -> tuple: + """Resolve store identifiers, scope, and metadata for a single memory.""" + store_agent_id = agent_id + store_run_id = run_id + store_app_id = app_id + store_filters = dict(effective_filters) + if "user_id" in store_filters or "agent_id" in store_filters: + store_filters.pop("run_id", None) + + if explicit_remember: + store_agent_id = None + store_run_id = None + store_app_id = None + store_filters.pop("agent_id", None) + store_filters.pop("run_id", None) + store_filters.pop("app_id", None) + mem_metadata.pop("agent_id", None) + mem_metadata.pop("run_id", None) + mem_metadata.pop("app_id", None) + mem_metadata["policy_scope"] = "user" + else: + mem_metadata["policy_scope"] = "agent" - store_agent_id = agent_id - store_run_id = run_id - store_app_id = app_id - store_filters = dict(effective_filters) - if "user_id" in store_filters or "agent_id" in store_filters: - store_filters.pop("run_id", None) - - if explicit_remember: - store_agent_id = None - store_run_id = None - store_app_id = None - store_filters.pop("agent_id", None) - store_filters.pop("run_id", None) - store_filters.pop("app_id", None) - mem_metadata.pop("agent_id", None) - mem_metadata.pop("run_id", None) - mem_metadata.pop("app_id", None) - mem_metadata["policy_scope"] = "user" - else: - mem_metadata["policy_scope"] = "agent" + mem_metadata["policy_explicit"] = explicit_remember + resolved_agent_category = self._normalize_agent_category( + agent_category or mem_metadata.get("agent_category") + ) + resolved_connector_id = self._normalize_connector_id( + connector_id or mem_metadata.get("connector_id") + ) + resolved_scope = self._infer_scope( + scope=scope or mem_metadata.get("scope"), + connector_id=resolved_connector_id, + agent_category=resolved_agent_category, + policy_explicit=explicit_remember, + agent_id=store_agent_id, + ) + mem_metadata["scope"] = resolved_scope + if resolved_agent_category: + mem_metadata["agent_category"] = resolved_agent_category + if resolved_connector_id: + mem_metadata["connector_id"] = resolved_connector_id + if source_app or mem_metadata.get("source_app"): + mem_metadata["source_app"] = source_app or mem_metadata.get("source_app") - mem_metadata["policy_explicit"] = explicit_remember - resolved_agent_category = self._normalize_agent_category( - agent_category or mem_metadata.get("agent_category") - ) - resolved_connector_id = self._normalize_connector_id( - connector_id or mem_metadata.get("connector_id") - ) - resolved_scope = self._infer_scope( - scope=scope or mem_metadata.get("scope"), - connector_id=resolved_connector_id, - agent_category=resolved_agent_category, - policy_explicit=explicit_remember, - agent_id=store_agent_id, - ) - mem_metadata["scope"] = resolved_scope - if resolved_agent_category: - mem_metadata["agent_category"] = resolved_agent_category - if resolved_connector_id: - mem_metadata["connector_id"] = resolved_connector_id - if source_app or mem_metadata.get("source_app"): - mem_metadata["source_app"] = source_app or mem_metadata.get("source_app") - high_confidence = explicit_remember or looks_high_confidence(content, mem_metadata) - policy_repeated = False - low_confidence = False - - # CategoryMem: Auto-categorize if not provided - category_match = None - if ( - self.category_processor - and self.category_config.auto_categorize - and not mem_categories - ): - category_match = self.category_processor.detect_category( - content, - metadata=mem_metadata, - use_llm=self.category_config.use_llm_categorization, - ) - mem_categories = [category_match.category_id] - mem_metadata["category_confidence"] = category_match.confidence - mem_metadata["category_auto"] = True - - # EchoMem: Process through multi-modal echo encoding - echo_result = None - effective_strength = initial_strength - if self.echo_processor and self.echo_config.enable_echo: - depth_override = EchoDepth(echo_depth) if echo_depth else None - echo_result = self.echo_processor.process(content, depth=depth_override) - # Apply strength multiplier from echo depth - effective_strength = initial_strength * echo_result.strength_multiplier - # Add echo metadata - mem_metadata.update(echo_result.to_metadata()) - # Auto-categorize if not provided - if not mem_categories and echo_result.category: - mem_categories = [echo_result.category] - - # Choose primary embedding text (optionally question-form for query matching) - primary_text = self._select_primary_text(content, echo_result) - embedding = self.embedder.embed(primary_text, memory_action="add") - - nearest, similarity = self._nearest_memory(embedding, store_filters) - repeated_threshold = max(self.fadem_config.conflict_similarity_threshold - 0.05, 0.7) - if similarity >= repeated_threshold: - policy_repeated = True - high_confidence = True - - if not explicit_remember and not high_confidence: - low_confidence = True - - # Conflict resolution against nearest memory in scope - event = "ADD" - existing = None - if nearest and similarity >= self.fadem_config.conflict_similarity_threshold: - existing = nearest - - if existing and self.fadem_config.enable_forgetting: - resolution = resolve_conflict(existing, content, self.llm, self.config.custom_conflict_prompt) - - if resolution.classification == "CONTRADICTORY": - self._demote_existing(existing, reason="CONTRADICTORY") - event = "UPDATE" - elif resolution.classification == "SUBSUMES": - content = resolution.merged_content or content - self._demote_existing(existing, reason="SUBSUMES") - event = "UPDATE" - elif resolution.classification == "SUBSUMED": - # Boost existing memory and skip new - boosted_strength = min(1.0, float(existing.get("strength", 1.0)) + 0.05) - self.db.update_memory(existing["id"], {"strength": boosted_strength}) - self.db.increment_access(existing["id"]) - results.append( - { - "id": existing["id"], - "memory": existing.get("memory", ""), - "event": "NOOP", - "layer": existing.get("layer", "sml"), - "strength": boosted_strength, - } - ) - continue + return store_agent_id, store_run_id, store_app_id, store_filters - if existing and event == "UPDATE" and resolution.classification == "SUBSUMES": - if self.echo_processor and self.echo_config.enable_echo: - depth_override = None - if echo_depth: - depth_override = EchoDepth(echo_depth) - elif echo_result: - depth_override = echo_result.echo_depth - echo_result = self.echo_processor.process(content, depth=depth_override) - mem_metadata.update(echo_result.to_metadata()) - if not mem_categories and echo_result.category: - mem_categories = [echo_result.category] - - primary_text = self._select_primary_text(content, echo_result) - embedding = self.embedder.embed(primary_text, memory_action="add") - - if policy_repeated: - mem_metadata["policy_repeated"] = True - if low_confidence: - mem_metadata["policy_low_confidence"] = True - - if low_confidence: - effective_strength = min(effective_strength, 0.4) - - layer = initial_layer - if layer == "auto": - layer = "sml" - if low_confidence: - layer = "sml" - - confidentiality_scope = str( - mem_metadata.get("confidentiality_scope") - or mem_metadata.get("privacy_scope") - or "work" - ).lower() - source_type = ( - mem_metadata.get("source_type") - or ("cli" if (source_app or "").lower() == "cli" else "mcp") - ) - source_event_id = mem_metadata.get("source_event_id") - importance = mem_metadata.get("importance", 0.5) - sensitivity = mem_metadata.get("sensitivity", "normal") - namespace_value = str(mem_metadata.get("namespace", "default") or "default").strip() or "default" - - memory_id = str(uuid.uuid4()) - now = datetime.utcnow().isoformat() - memory_data = { - "id": memory_id, + def _encode_memory( + self, + content: str, + echo_depth: Optional[str], + mem_categories: List[str], + mem_metadata: Dict[str, Any], + initial_strength: float, + ) -> tuple: + """Run echo encoding + embedding. Returns (echo_result, effective_strength, mem_categories, embedding).""" + echo_result = None + effective_strength = initial_strength + if self.echo_processor and self.echo_config.enable_echo: + depth_override = EchoDepth(echo_depth) if echo_depth else None + echo_result = self.echo_processor.process(content, depth=depth_override) + effective_strength = initial_strength * echo_result.strength_multiplier + mem_metadata.update(echo_result.to_metadata()) + if not mem_categories and echo_result.category: + mem_categories = [echo_result.category] + + primary_text = self._select_primary_text(content, echo_result) + embedding = self.embedder.embed(primary_text, memory_action="add") + return echo_result, effective_strength, mem_categories, embedding + + def _process_single_memory( + self, + *, + mem: Dict[str, Any], + processed_metadata: Dict[str, Any], + effective_filters: Dict[str, Any], + categories: Optional[List[str]], + user_id: Optional[str], + agent_id: Optional[str], + run_id: Optional[str], + app_id: Optional[str], + agent_category: Optional[str], + connector_id: Optional[str], + scope: Optional[str], + source_app: Optional[str], + immutable: bool, + expiration_date: Optional[str], + initial_layer: str, + initial_strength: float, + echo_depth: Optional[str], + ) -> Optional[Dict[str, Any]]: + """Process and store a single memory item. Returns result dict or None if skipped.""" + content = mem.get("content", "").strip() + if not content: + return None + + mem_categories = normalize_categories(categories or mem.get("categories")) + mem_metadata = dict(processed_metadata) + mem_metadata.update(mem.get("metadata", {})) + if app_id: + mem_metadata["app_id"] = app_id + + role = mem_metadata.get("role", "user") + explicit_intent = detect_explicit_intent(content) if role == "user" else None + explicit_action = explicit_intent.action if explicit_intent else None + explicit_remember = bool(mem_metadata.get("explicit_remember")) or explicit_action == "remember" + explicit_forget = bool(mem_metadata.get("explicit_forget")) or explicit_action == "forget" + + if explicit_forget: + query = explicit_intent.content if explicit_intent else "" + forget_filters = {"user_id": user_id} if user_id else dict(effective_filters) + forget_result = self._forget_by_query(query, forget_filters) + return { + "event": "FORGET", + "query": query, + "deleted_count": forget_result.get("deleted_count", 0), + "deleted_ids": forget_result.get("deleted_ids", []), + } + + if explicit_remember and explicit_intent and explicit_intent.content: + content = explicit_intent.content + + blocked = detect_sensitive_categories(content) + allow_sensitive = bool(mem_metadata.get("allow_sensitive")) + if blocked and not allow_sensitive: + return { + "event": "BLOCKED", + "reason": "sensitive", + "blocked_categories": blocked, "memory": content, - "user_id": user_id, - "agent_id": store_agent_id, - "run_id": store_run_id, - "app_id": store_app_id, - "metadata": mem_metadata, - "categories": mem_categories, - "immutable": immutable, - "expiration_date": expiration_date, - "created_at": now, - "updated_at": now, - "layer": layer, - "strength": effective_strength, - "access_count": 0, - "last_accessed": now, - "embedding": embedding, - "confidentiality_scope": confidentiality_scope, - "source_type": source_type, - "source_app": source_app or mem_metadata.get("source_app"), - "source_event_id": source_event_id, - "decay_lambda": self.fadem_config.sml_decay_rate, - "status": "active", - "importance": importance, - "sensitivity": sensitivity, - "namespace": namespace_value, } - vectors, payloads, vector_ids = self._build_index_vectors( - memory_id=memory_id, - content=content, - primary_text=self._select_primary_text(content, echo_result), - embedding=embedding, - echo_result=echo_result, + if not explicit_remember and is_ephemeral(content): + return { + "event": "SKIP", + "reason": "ephemeral", + "memory": content, + } + + # Resolve store identifiers and scope metadata. + store_agent_id, store_run_id, store_app_id, store_filters = self._resolve_memory_metadata( + content=content, + mem_metadata=mem_metadata, + explicit_remember=explicit_remember, + agent_id=agent_id, + run_id=run_id, + app_id=app_id, + effective_filters=effective_filters, + agent_category=agent_category, + connector_id=connector_id, + scope=scope, + source_app=source_app, + ) + + high_confidence = explicit_remember or looks_high_confidence(content, mem_metadata) + policy_repeated = False + low_confidence = False + + # CategoryMem: Auto-categorize if not provided + if ( + self.category_processor + and self.category_config.auto_categorize + and not mem_categories + ): + category_match = self.category_processor.detect_category( + content, metadata=mem_metadata, - categories=mem_categories, - user_id=user_id, - agent_id=store_agent_id, - run_id=store_run_id, - app_id=store_app_id, + use_llm=self.category_config.use_llm_categorization, ) + mem_categories = [category_match.category_id] + mem_metadata["category_confidence"] = category_match.confidence + mem_metadata["category_auto"] = True - self.db.add_memory(memory_data) - self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) + # Encode memory (echo + embedding). + echo_result, effective_strength, mem_categories, embedding = self._encode_memory( + content, echo_depth, mem_categories, mem_metadata, initial_strength, + ) - # CategoryMem: Update category stats - if self.category_processor and mem_categories: - for cat_id in mem_categories: - self.category_processor.update_category_stats( - cat_id, effective_strength, is_addition=True - ) + nearest, similarity = self._nearest_memory(embedding, store_filters) + repeated_threshold = max(self.fadem_config.conflict_similarity_threshold - 0.05, 0.7) + if similarity >= repeated_threshold: + policy_repeated = True + high_confidence = True + + if not explicit_remember and not high_confidence: + low_confidence = True + + # Conflict resolution against nearest memory in scope. + event = "ADD" + existing = None + resolution = None + if nearest and similarity >= self.fadem_config.conflict_similarity_threshold: + existing = nearest + + if existing and self.fadem_config.enable_forgetting: + resolution = resolve_conflict(existing, content, self.llm, self.config.custom_conflict_prompt) + + if resolution.classification == "CONTRADICTORY": + self._demote_existing(existing, reason="CONTRADICTORY") + event = "UPDATE" + elif resolution.classification == "SUBSUMES": + content = resolution.merged_content or content + self._demote_existing(existing, reason="SUBSUMES") + event = "UPDATE" + elif resolution.classification == "SUBSUMED": + boosted_strength = min(1.0, float(existing.get("strength", 1.0)) + 0.05) + self.db.update_memory(existing["id"], {"strength": boosted_strength}) + self.db.increment_access(existing["id"]) + return { + "id": existing["id"], + "memory": existing.get("memory", ""), + "event": "NOOP", + "layer": existing.get("layer", "sml"), + "strength": boosted_strength, + } - # KnowledgeGraph: Extract entities and link memories - if self.knowledge_graph: - self.knowledge_graph.extract_entities( - content=content, - memory_id=memory_id, - use_llm=self.graph_config.use_llm_extraction, - ) - if self.graph_config.auto_link_entities: - self.knowledge_graph.link_by_shared_entities(memory_id) + if existing and event == "UPDATE" and resolution and resolution.classification == "SUBSUMES": + # Re-encode merged content. + echo_result, _, mem_categories, embedding = self._encode_memory( + content, echo_depth, mem_categories, mem_metadata, initial_strength, + ) - # SceneProcessor: Assign memory to a scene - if self.scene_processor: - try: - self._assign_to_scene(memory_id, content, embedding, user_id, now) - except Exception as e: - logger.warning(f"Scene assignment failed for {memory_id}: {e}") + if policy_repeated: + mem_metadata["policy_repeated"] = True + if low_confidence: + mem_metadata["policy_low_confidence"] = True + effective_strength = min(effective_strength, 0.4) + + layer = initial_layer + if layer == "auto": + layer = "sml" + if low_confidence: + layer = "sml" + + confidentiality_scope = str( + mem_metadata.get("confidentiality_scope") + or mem_metadata.get("privacy_scope") + or "work" + ).lower() + source_type = ( + mem_metadata.get("source_type") + or ("cli" if (source_app or "").lower() == "cli" else "mcp") + ) + namespace_value = str(mem_metadata.get("namespace", "default") or "default").strip() or "default" + + memory_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc).isoformat() + memory_data = { + "id": memory_id, + "memory": content, + "user_id": user_id, + "agent_id": store_agent_id, + "run_id": store_run_id, + "app_id": store_app_id, + "metadata": mem_metadata, + "categories": mem_categories, + "immutable": immutable, + "expiration_date": expiration_date, + "created_at": now, + "updated_at": now, + "layer": layer, + "strength": effective_strength, + "access_count": 0, + "last_accessed": now, + "embedding": embedding, + "confidentiality_scope": confidentiality_scope, + "source_type": source_type, + "source_app": source_app or mem_metadata.get("source_app"), + "source_event_id": mem_metadata.get("source_event_id"), + "decay_lambda": self.fadem_config.sml_decay_rate, + "status": "active", + "importance": mem_metadata.get("importance", 0.5), + "sensitivity": mem_metadata.get("sensitivity", "normal"), + "namespace": namespace_value, + } - # ProfileProcessor: Update profiles from content - if self.profile_processor: - try: - self._update_profiles(memory_id, content, mem_metadata, user_id) - except Exception as e: - logger.warning(f"Profile update failed for {memory_id}: {e}") + vectors, payloads, vector_ids = self._build_index_vectors( + memory_id=memory_id, + content=content, + primary_text=self._select_primary_text(content, echo_result), + embedding=embedding, + echo_result=echo_result, + metadata=mem_metadata, + categories=mem_categories, + user_id=user_id, + agent_id=store_agent_id, + run_id=store_run_id, + app_id=store_app_id, + ) - results.append( - { - "id": memory_id, - "memory": content, - "event": event, - "layer": layer, - "strength": effective_strength, - "echo_depth": echo_result.echo_depth.value if echo_result else None, - "categories": mem_categories, - "namespace": namespace_value, - "vector_nodes": len(vectors) # Info for user - } + self.db.add_memory(memory_data) + self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) + + # Post-store hooks. + if self.category_processor and mem_categories: + for cat_id in mem_categories: + self.category_processor.update_category_stats( + cat_id, effective_strength, is_addition=True + ) + + if self.knowledge_graph: + self.knowledge_graph.extract_entities( + content=content, + memory_id=memory_id, + use_llm=self.graph_config.use_llm_extraction, ) + if self.graph_config.auto_link_entities: + self.knowledge_graph.link_by_shared_entities(memory_id) - # Persist categories after batch - if self.category_processor: - self._persist_categories() + if self.scene_processor: + try: + self._assign_to_scene(memory_id, content, embedding, user_id, now) + except Exception as e: + logger.warning("Scene assignment failed for %s: %s", memory_id, e) - return {"results": results} + if self.profile_processor: + try: + self._update_profiles(memory_id, content, mem_metadata, user_id) + except Exception as e: + logger.warning("Profile update failed for %s: %s", memory_id, e) + + return { + "id": memory_id, + "memory": content, + "event": event, + "layer": layer, + "strength": effective_strength, + "echo_depth": echo_result.echo_depth.value if echo_result else None, + "categories": mem_categories, + "namespace": namespace_value, + "vector_nodes": len(vectors), + } def search( self, @@ -677,10 +739,20 @@ def search( # Record access to category self.category_processor.access_category(query_category_id) + # Phase 2: Bulk-fetch all candidate memories to eliminate N+1 queries. + candidate_ids = [self._resolve_memory_id(vr) for vr in vector_results] + vr_by_id = {self._resolve_memory_id(vr): vr for vr in vector_results} + memories_bulk = self.db.get_memories_bulk(candidate_ids) + results: List[Dict[str, Any]] = [] - for vr in vector_results: - memory_id = self._resolve_memory_id(vr) - memory = self.db.get_memory(memory_id) + access_ids: List[str] = [] + strength_updates: Dict[str, float] = {} + promotion_ids: List[str] = [] + reecho_ids: List[str] = [] + subscriber_ids: List[str] = [] + + for memory_id in candidate_ids: + memory = memories_bulk.get(memory_id) if not memory: continue @@ -709,6 +781,7 @@ def search( ): continue + vr = vr_by_id[memory_id] similarity = float(vr.score) strength = float(memory.get("strength", 1.0)) @@ -742,10 +815,8 @@ def search( memory_categories = set(memory.get("categories", [])) if use_category_boost and self.category_processor and query_category_id: if query_category_id in memory_categories: - # Direct category match category_boost = self.category_config.category_boost_weight elif memory_categories & related_category_ids: - # Related category match category_boost = self.category_config.cross_category_boost combined = combined * (1 + category_boost) @@ -753,7 +824,6 @@ def search( graph_boost = 0.0 if self.knowledge_graph: memory_entities = self.knowledge_graph.memory_entities.get(memory["id"], set()) - # Check if any query terms match entity names for entity_name in memory_entities: if entity_name.lower() in query_lower or any( term in entity_name.lower() for term in query_terms @@ -763,13 +833,13 @@ def search( combined = combined * (1 + graph_boost) if boost_on_access: - self.db.increment_access(memory["id"]) + access_ids.append(memory["id"]) if self.fadem_config.access_strength_boost > 0: boosted_strength = min(1.0, strength + self.fadem_config.access_strength_boost) if boosted_strength != strength: - self.db.update_memory(memory["id"], {"strength": boosted_strength}) + strength_updates[memory["id"]] = boosted_strength strength = boosted_strength - self._check_promotion(memory["id"]) + promotion_ids.append(memory["id"]) # EchoMem: Re-echo on frequent access if ( self.echo_processor @@ -777,9 +847,9 @@ def search( and memory.get("access_count", 0) >= self.echo_config.reecho_threshold and metadata.get("echo_depth") != "deep" ): - self._reecho_memory(memory["id"]) + reecho_ids.append(memory["id"]) if agent_id: - self.db.add_memory_subscriber(memory["id"], f"agent:{agent_id}", ref_type="weak") + subscriber_ids.append(memory["id"]) results.append( { @@ -818,6 +888,19 @@ def search( } ) + # Phase 2: Batch DB writes instead of per-result round-trips. + if access_ids: + self.db.increment_access_bulk(access_ids) + if strength_updates: + self.db.update_strength_bulk(strength_updates) + for mid in promotion_ids: + self._check_promotion(mid) + for mid in reecho_ids: + self._reecho_memory(mid) + if agent_id: + for mid in subscriber_ids: + self.db.add_memory_subscriber(mid, f"agent:{agent_id}", ref_type="weak") + # Persist category access updates if self.category_processor: self._persist_categories() @@ -876,7 +959,7 @@ def _reecho_memory(self, memory_id: str) -> None: self.db.log_event(memory_id, "REECHO", old_strength=memory.get("strength"), new_strength=new_strength) self._update_vectors_for_memory(memory_id, metadata) except Exception as e: - logger.warning(f"Re-echo failed for memory {memory_id}: {e}") + logger.warning("Re-echo failed for memory %s: %s", memory_id, e) def get(self, memory_id: str) -> Optional[Dict[str, Any]]: memory = self.db.get_memory(memory_id) @@ -1084,7 +1167,7 @@ def apply_decay(self, scope: Dict[str, Any] = None) -> Dict[str, Any]: new_strength = calculate_decayed_strength( current_strength=memory.get("strength", 1.0), - last_accessed=memory.get("last_accessed", datetime.utcnow().isoformat()), + last_accessed=memory.get("last_accessed", datetime.now(timezone.utc).isoformat()), access_count=memory.get("access_count", 0), layer=memory.get("layer", "sml"), config=self.fadem_config, @@ -1511,7 +1594,7 @@ def _extract_memories( extracted = [m for m in extracted if excludes.lower() not in m.get("content", "").lower()] return extracted except Exception as exc: - logger.warning(f"Failed to parse extraction response: {exc}") + logger.warning("Failed to parse extraction response: %s", exc) # Fallback: add last user message last_user = next((m for m in reversed(messages) if m.get("role") == "user"), None) if last_user: @@ -1813,7 +1896,7 @@ def _demote_existing(self, memory: Dict[str, Any], reason: str) -> None: metadata = dict(memory.get("metadata", {})) metadata["superseded"] = True metadata["superseded_reason"] = reason - metadata["superseded_at"] = datetime.utcnow().isoformat() + metadata["superseded_at"] = datetime.now(timezone.utc).isoformat() self.db.update_memory( memory["id"], diff --git a/engram/retrieval/dual_search.py b/engram/retrieval/dual_search.py index 01ae38f..b8d4da1 100644 --- a/engram/retrieval/dual_search.py +++ b/engram/retrieval/dual_search.py @@ -2,6 +2,7 @@ from __future__ import annotations +import os from typing import Any, Dict, Iterable, List, Optional, Set from engram.core.policy import enforce_scope_on_results @@ -16,6 +17,14 @@ def __init__(self, *, memory, episodic_store, ref_manager): self.episodic_store = episodic_store self.ref_manager = ref_manager + @staticmethod + def _parse_float_env(name: str, default: float, *, minimum: float = 0.0, maximum: float = 1.0) -> float: + try: + value = float(os.environ.get(name, default)) + except Exception: + value = float(default) + return min(maximum, max(minimum, value)) + def search( self, *, @@ -42,8 +51,14 @@ def search( limit=max(limit, 5), ) visible_scenes = self._filter_scenes_by_namespace(episodic_scenes, allowed_namespaces) - - promoted = intersection_promote(semantic_results, visible_scenes) + boost_weight = self._parse_float_env("ENGRAM_V2_DUAL_INTERSECTION_BOOST_WEIGHT", 0.22) + boost_cap = self._parse_float_env("ENGRAM_V2_DUAL_INTERSECTION_BOOST_CAP", 0.35) + promoted = intersection_promote( + semantic_results, + visible_scenes, + boost_weight=boost_weight, + max_boost=boost_cap, + ) for item in promoted: if "confidentiality_scope" not in item: row = self.memory.db.get_memory(item.get("id")) @@ -70,10 +85,24 @@ def search( visible_ids = [r.get("id") for r in final_results if r.get("id") and not r.get("masked")] self.ref_manager.record_retrieval_refs(visible_ids, agent_id=agent_id, strong=False) + promoted_intersections = sum(1 for item in promoted if item.get("episodic_match")) + boosted_items = sum(1 for item in promoted if float(item.get("intersection_boost", 0.0)) > 0.0) + return { "results": final_results, "count": len(final_results), "context_packet": context_packet, + "retrieval_trace": { + "ranking_version": "dual_intersection_v2", + "strategy": "semantic_plus_episodic_intersection", + "semantic_candidates": len(semantic_results), + "scene_candidates": len(visible_scenes), + "intersection_candidates": int(promoted_intersections), + "boosted_candidates": int(boosted_items), + "boost_weight": float(boost_weight), + "boost_cap": float(boost_cap), + "masked_count": int(masked_count), + }, "scene_hits": [ { "scene_id": s.get("id"), diff --git a/engram/retrieval/reranker.py b/engram/retrieval/reranker.py index b20023d..5e6d575 100644 --- a/engram/retrieval/reranker.py +++ b/engram/retrieval/reranker.py @@ -2,36 +2,75 @@ from __future__ import annotations -from typing import Dict, List, Set +from typing import Any, Dict, List, Set, Tuple + + +def _coerce_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except Exception: + return float(default) + + +def _build_episodic_signal( + episodic_scene_results: List[Dict[str, Any]], +) -> Tuple[Dict[str, float], Dict[str, int]]: + signal_by_memory: Dict[str, float] = {} + scene_count_by_memory: Dict[str, int] = {} + for rank, scene in enumerate(episodic_scene_results): + memory_ids = [str(mid) for mid in (scene.get("memory_ids") or []) if str(mid).strip()] + if not memory_ids: + continue + rank_weight = 1.0 / (1.0 + float(rank)) + scene_score = _coerce_float(scene.get("search_score"), 0.0) + scene_weight = max(0.15, min(1.0, scene_score)) + contribution = rank_weight * scene_weight + for memory_id in memory_ids: + signal_by_memory[memory_id] = signal_by_memory.get(memory_id, 0.0) + contribution + scene_count_by_memory[memory_id] = scene_count_by_memory.get(memory_id, 0) + 1 + + for memory_id, signal in list(signal_by_memory.items()): + signal_by_memory[memory_id] = min(1.0, signal) + return signal_by_memory, scene_count_by_memory def intersection_promote( - semantic_results: List[Dict], - episodic_scene_results: List[Dict], -) -> List[Dict]: + semantic_results: List[Dict[str, Any]], + episodic_scene_results: List[Dict[str, Any]], + *, + boost_weight: float = 0.22, + max_boost: float = 0.35, +) -> List[Dict[str, Any]]: """Promote semantic results that also appear in episodic scenes. - Relative order among promoted items follows original semantic ranking. + Uses deterministic boost calibration: + - Episodic signal is derived from scene rank + scene score. + - Final score = base_score * (1 + intersection_boost). + - Stable tie-breakers preserve semantic order. """ - episodic_memory_ids: Set[str] = set() - for scene in episodic_scene_results: - for mid in scene.get("memory_ids", []) or []: - episodic_memory_ids.add(str(mid)) - - if not episodic_memory_ids: - return semantic_results - - promoted: List[Dict] = [] - others: List[Dict] = [] - for item in semantic_results: - mid = str(item.get("id")) - if mid in episodic_memory_ids: - enriched = dict(item) - enriched["episodic_match"] = True - promoted.append(enriched) - else: - enriched = dict(item) - enriched["episodic_match"] = False - others.append(enriched) - - return promoted + others + weight = min(1.0, max(0.0, _coerce_float(boost_weight, 0.22))) + cap = min(1.0, max(0.0, _coerce_float(max_boost, 0.35))) + signal_by_memory, scene_count_by_memory = _build_episodic_signal(episodic_scene_results) + episodic_memory_ids: Set[str] = set(signal_by_memory.keys()) + + ranked: List[Tuple[float, float, int, Dict[str, Any]]] = [] + for semantic_rank, item in enumerate(semantic_results): + enriched = dict(item) + memory_id = str(item.get("id")) + base_score = _coerce_float(item.get("composite_score"), _coerce_float(item.get("score"), 0.0)) + episodic_signal = signal_by_memory.get(memory_id, 0.0) + intersection_boost = min(cap, episodic_signal * weight) + final_score = base_score * (1.0 + intersection_boost) + + enriched["episodic_match"] = memory_id in episodic_memory_ids + enriched["episodic_scene_count"] = int(scene_count_by_memory.get(memory_id, 0)) + enriched["episodic_signal"] = round(float(episodic_signal), 6) + enriched["intersection_boost"] = round(float(intersection_boost), 6) + enriched["base_composite_score"] = float(base_score) + enriched["composite_score"] = float(final_score) + + # Tie-breaking preserves semantic ranking deterministically. + ranked.append((float(final_score), float(base_score), -semantic_rank, enriched)) + + ranked.sort(key=lambda row: (row[0], row[1], row[2]), reverse=True) + return [row[3] for row in ranked] diff --git a/engram/utils/math.py b/engram/utils/math.py new file mode 100644 index 0000000..018965c --- /dev/null +++ b/engram/utils/math.py @@ -0,0 +1,29 @@ +"""Shared math utilities for engram.""" + +from typing import List + +try: + import numpy as np + + def cosine_similarity(a: List[float], b: List[float]) -> float: + """Compute cosine similarity between two vectors using NumPy.""" + if not a or not b or len(a) != len(b): + return 0.0 + arr_a = np.asarray(a, dtype=np.float64) + arr_b = np.asarray(b, dtype=np.float64) + dot = np.dot(arr_a, arr_b) + denom = np.sqrt(np.dot(arr_a, arr_a) * np.dot(arr_b, arr_b)) + return float(dot / denom) if denom else 0.0 + +except ImportError: + + def cosine_similarity(a: List[float], b: List[float]) -> float: # type: ignore[misc] + """Compute cosine similarity between two vectors (pure-Python fallback).""" + if not a or not b or len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) diff --git a/launch-article.md b/launch-article.md new file mode 100644 index 0000000..765bb1a --- /dev/null +++ b/launch-article.md @@ -0,0 +1,222 @@ +# Every AI Agent You Use Has Amnesia. I Spent Months Fixing It. + +I hit my breaking point on a Tuesday. + +I was three hours into a coding session with Claude. We'd made real progress. Refactored the auth system, decided on JWT with short-lived tokens, mapped out the middleware chain. Good stuff. + +Then my terminal crashed. + +New session. "Hi, how can I help you today?" + +Three hours of shared context. Gone. Like it never happened. + +I sat there staring at the screen thinking — we put AI in everything. Code editors. Chat apps. Email. Planning tools. And not a single one of them remembers what happened yesterday. + +That was the moment I stopped being annoyed and started being obsessed. + +--- + +I started digging. Surely someone had solved this. + +Turns out, yeah, people have tried. The standard approach is simple: store everything the user says, embed it into vectors, retrieve with similarity search. + +I tried a few. They work. Sort of. + +But something kept nagging me. Three things, specifically. And the more I thought about them, the more I realized they weren't edge cases. They were fundamental design flaws. + +--- + +**The first thing: nobody forgets.** + +I was using one of these memory layers for about two months. Worked great at first. My AI remembered my preferences, my stack, my decisions. + +Then it started getting weird. I'd ask about my current auth approach and it would pull up a decision I made in week one. Before I'd changed my mind. Before I'd learned better. That old decision was sitting right next to the current one. Same priority. Same weight. + +My context window was filling up with ghosts. Stale facts haunting my retrieval results. + +And I thought — that's not how my brain works. I don't remember what I had for lunch three Tuesdays ago. That memory decayed. Naturally. Because it wasn't important enough to keep. + +I went looking for research and found a paper called FadeMem. Bio-inspired forgetting for AI agents. The core idea: the Ebbinghaus forgetting curve isn't a flaw in human cognition. It's a feature. Important stuff gets reinforced through repeated access. Unimportant stuff fades. The result is a memory system that's always current, always relevant, and about 45% smaller than one that hoards everything. + +I read that paper three times. Then I started building. + +--- + +**The second thing: agents write whatever they want.** + +I was testing a setup where my coding agent could save notes to memory. Useful in theory. In practice, it was writing garbage. + +Half-formed thoughts. Duplicate facts phrased slightly differently. One time it contradicted something I'd explicitly told it a week earlier. And I didn't notice for days because the writes just went straight in. No review. No staging. No nothing. + +Would you give a new intern root access to your production database on day one? + +Then why does every memory system give every AI agent full write permissions from the start? + +I kept thinking about how real teams handle this. New hire submits a PR. Someone reviews it. Over time they earn trust. Eventually they get merge rights. There's a progression. + +Nobody was building memory like that. + +--- + +**The third thing: memory was just "find similar text."** + +I was trying to recall a specific debugging session. I knew roughly when it happened. I knew what we were working on. I could almost see the conversation in my head. + +But when I searched, I got scattered facts. A JWT token reference from one session. A middleware mention from another. An auth decision from a third. All "similar" to my query. None of them the actual session I was thinking of. + +Because I wasn't looking for similar text. I was looking for an episode. A scene. Time, place, what happened, what we decided. + +That's how human memory works. You don't remember Tuesday as a bag of keywords. You remember the morning meeting. The long debugging session after lunch. The architecture argument at 4pm. Scenes. + +I found another paper. CAST — Contextual Associative Scene Theory. It explains exactly this. Humans organize memory into episodes defined by shifts in time, place, and topic. When you recall something, you're pulling an entire scene, not running a text search. + +Nobody was building AI memory this way either. + +--- + +So I built it. + +I spent months on it. It's called Engram. And it works differently from anything else I've seen. + +--- + +Here's the core idea. Memory should work like memory actually works. Biologically. + +**New memories start weak.** When your AI stores something, it goes into short-term storage. Strength of 1.0. A proposal, not a permanent record. + +**Repeated access makes them stronger.** Ask about your TypeScript preference three times across different sessions? That memory gets promoted to long-term storage. It earned its place. + +**Unused memories fade.** That random note from two months ago that nobody's referenced since? Its strength decays. Gradually. Following the same curve Ebbinghaus mapped in 1885. Eventually it's gone. No manual cleanup. No "memory management." It just happens. + +The result surprised me. 45% less storage. And retrieval got better, not worse. Because when you search, you're not wading through ghosts anymore. Everything that surfaces is current and relevant. + +There's a catch though. What if Agent A stopped using a memory but Agent B still relies on it? Should it decay? + +No. So I built reference-aware decay. The system tracks who's using what. A memory stays alive as long as any agent references it. Even if the original writer forgot about it. + +--- + +Writes were the next problem to solve. + +In Engram, every write is a proposal. It lands in staging. Not in your canonical memory. Staging. + +The system runs checks. Does this contradict something already stored? Is it a duplicate? Is it from a trusted agent or a new one? + +Contradictions go to a conflict stash. You decide which version wins. + +New agents start with low trust. Everything they write waits for your approval. As you approve good writes, their trust score climbs. Eventually they earn auto-merge. Just like that new hire earning commit access. + +Your memory. Your rules. Always. + +--- + +Episodic memory was the hardest part to build. And the most satisfying. + +Engram watches the conversation flow and detects scene boundaries. Long pause? New scene. Topic shifted from frontend to deployment? New scene. Different repo? New scene. + +Each scene captures when it happened, where (which project, which repo), who was involved, what was discussed, and what decisions came out of it. Plus links to the semantic facts that were extracted. + +So when you ask "what did we decide in that auth session?" Engram doesn't fumble through scattered vector matches. It pulls the scene. The whole episode. Timeline, participants, synopsis, decisions. + +It's the difference between searching your email for "auth" and actually remembering the meeting where you made the call. + +--- + +Then I thought — why does each memory get one shot at being found? + +Standard approach: embed the text, store the vector, pray the query matches. + +I built something called EchoMem. Every memory gets encoded five ways: + +The raw text. A paraphrase. Keywords extracted from it. Implications (what does this fact suggest?). And a question form (what question would this be the answer to?). + +Five retrieval paths instead of one. Five chances to match. + +The question encoding turned out to be weirdly powerful. When an agent asks "what stack should I recommend?" it directly matches against memories stored as "What language does the user prefer?" Much stronger signal than fuzzy cosine similarity between unrelated phrasings. + +--- + +And retrieval itself is dual-path. Semantic search and episodic search run in parallel. + +If a fact shows up in both? Its confidence score gets boosted. Intersection promotion. + +This kills the most annoying failure mode in AI memory: semantically similar but contextually wrong results. "You mentioned JWT tokens" — yes, three months ago, in a different project, before I changed my mind. The episodic layer catches that. The semantic layer alone never would. + +--- + +There's one more thing that I think matters. + +When an agent queries outside its scope, it doesn't get nothing. It gets structure without details. + +```json +{ + "type": "private_event", + "time": "2026-02-10T17:00:00Z", + "importance": "high", + "details": "[REDACTED]" +} +``` + +Your scheduling agent knows you're busy. It doesn't need to know why. Your coding agent knows a decision was made. It doesn't need to see the financial discussion behind it. + +I call it "all but mask." Need-to-know, enforced at the memory layer. + +--- + +The part I didn't expect to build: cross-agent handoff. + +I was using Claude Code for a task, got halfway through, then switched to Cursor the next day. Fresh start. No memory of what Claude Code had done. + +So I built a handoff bus. When an agent pauses work, it saves a session digest. What was the task. What decisions were made. What files were touched. What's left to do. + +Next agent picks up. Calls `get_last_session`. Gets the full context. Continues from where the last agent stopped. + +No re-explanation. No copying context between tools. Your agents work like a relay team. + +--- + +All of this runs locally. `127.0.0.1:8100`. Your data never leaves your machine unless you want it to. + +Three commands to set up: + +``` +pip install engram-memory +export GEMINI_API_KEY="your-key" +engram install +``` + +Restart your agent. Done. + +`engram install` auto-configures Claude Code, Cursor, Codex. One command. All your agents. Same memory kernel underneath. + +Want fully offline? Use Ollama. No API keys. No cloud. Nothing leaves your laptop. + +Open source. MIT licensed. + +--- + +I've been running this for my own workflow for a while now. The difference is hard to overstate. + +Monday I make a decision in Claude Code. Tuesday, Cursor knows about it. Not because I told it. Because the memory is shared. + +Last week's debugging session? I can pull the whole episode. Not scattered facts. The scene. What we tried, what failed, what worked, what we decided. + +And the context window isn't full of ghosts from three months ago. Decay took care of those. What surfaces is current and relevant. + +It sounds small. It's not. The compound effect of never re-explaining yourself changes how you work with AI. + +--- + +I built this because I needed it. I'm sharing it because I think everyone does. + +Models will keep getting better. They'll get faster, cheaper, smarter. But without memory, every session still starts from zero. The smartest model in the world is useless if it can't remember what you told it yesterday. + +Memory is the missing infrastructure layer. Not a feature. Infrastructure. + +``` +pip install engram-memory +``` + +[GitHub](https://github.com/Ashish-dwi99/Engram) + +Your agents forget everything between sessions. Engram fixes that. diff --git a/pitch-deck.md b/pitch-deck.md new file mode 100644 index 0000000..5daac49 --- /dev/null +++ b/pitch-deck.md @@ -0,0 +1,427 @@ +# Engram Pitch Deck v2 + +> Story-first, light-theme, Excalidraw-style deck spec aligned with the current Engram landing page visual language. +> Use this as your source of truth for Keynote, Google Slides, Figma, or Excalidraw export. + +--- + +## 0) Visual System (Match Landing Page) + +### Theme Tokens + +- Canvas: `#F6F6F6` +- Card surface: `#FFFFFF` +- Border: `rgba(0,0,0,0.08)` +- Headline text: `#111111` +- Body text: `#525252` +- Label text: `#9CA3AF` +- Accent (sparingly): `#6366F1` at 10-20% opacity +- Divider grid: `rgba(0,0,0,0.05)` + +### Typography + +- Headlines: `Space Grotesk` (semibold) +- Body and labels: `Manrope` +- Micro-labels: uppercase, tracking `0.28em` to `0.35em` + +### Excalidraw Style Rules + +- Use hand-drawn rectangles, rounded corners, and arrows. +- Stroke width: 1.5-2 px, dark gray (`#1F2937`), roughness medium. +- Keep icon style monoline, rounded caps. +- Add subtle paper-grid background on every slide. +- Avoid heavy gradients; if needed, use very soft radial highlights. + +### Icon Language + +Use simple line icons (Lucide/Tabler style), monochrome: + +- Memory: `brain`, `database` +- Trust/Safety: `shield-check`, `lock`, `alert-triangle` +- Interop: `plug`, `network` +- Retrieval: `search`, `layers`, `clock` +- Motion/flow: `arrow-right`, `git-branch` + +--- + +## 1) Cover Slide + +### On-Slide Content + +**PERSONAL MEMORY KERNEL FOR AI AGENTS** + +# One memory store. +# Every agent, personalized. + +Engram makes AI remember your context across tools while keeping memory user-owned. + +`pip install engram-memory` + +### Excalidraw Composition + +- Center sketch: one "Memory Vault" box in the middle. +- Around it, small agent cards: Claude Code, Cursor, Codex, Custom Agent. +- Hand-drawn arrows from all cards to the vault. +- Tiny lock icon on the vault. + +### Speaker Notes + +"Today, every agent forgets you. Engram changes that with one user-owned memory kernel any agent can plug into. Same user context, across tools, under your control." + +--- + +## 2) The Context Tax (Problem) + +### On-Slide Content + +## Every agent starts from zero. + +- You repeat preferences across sessions. +- Decisions made yesterday are lost today. +- Work quality drops from context resets. + +**The hidden tax:** re-explaining what your systems should already know. + +### Excalidraw Composition + +- Left-to-right comic strip, 3 panels: + - "Tell agent your setup" + - "Agent forgets" + - "Tell it again" +- Loop arrow from panel 3 back to panel 1. +- Add clock icon + "time lost" note. + +### Speaker Notes + +"The biggest AI productivity drain is not model quality. It is memory reset. Teams keep paying a context tax in every interaction." + +--- + +## 3) Why This Gets Worse + +### On-Slide Content + +## More agents -> more memory silos. + +- Workflows are becoming multi-agent by default. +- Each tool builds isolated context. +- User identity fragments across vendors. + +**Without a memory layer, AI stacks become context-fragmented systems.** + +### Excalidraw Composition + +- One user avatar at center top. +- Five agent boxes below, each connected to separate mini-databases. +- Red cross-lines between databases to show no interoperability. + +### Speaker Notes + +"As agent count increases, fragmentation compounds. You do not have one assistant with weak memory. You have many assistants with disconnected memory." + +--- + +## 4) The Insight + Vision + +### On-Slide Content + +## Memory should be infrastructure, not a feature. + +What users need: + +- One portable memory layer +- Works across any agent runtime +- Local-first by default +- User approval on writes + +**Engram = Personal Memory Kernel (PMK)** + +### Excalidraw Composition + +- Split slide: + - Left: "Current" (silos) + - Right: "PMK" (one shared kernel) +- Use plug icons on right side to show easy integration. + +### Speaker Notes + +"We are not building another assistant. We are building the memory substrate that personalizes every assistant." + +--- + +## 5) Product Reveal + +### On-Slide Content + +## Engram in 3 commands + +```bash +pip install engram-memory +export GEMINI_API_KEY="your-key" +engram install +``` + +Then restart your agent. + +- Persistent memory across sessions +- Scoped retrieval +- Staged writes and approval + +### Excalidraw Composition + +- Terminal card on left with the 3 commands. +- Right side: before/after cards: + - Before: "stateless assistant" + - After: "context-aware assistant" +- Arrow between before and after. + +### Speaker Notes + +"This is designed for zero-friction adoption. Install, configure once, and your existing agent stack becomes memory-enabled." + +--- + +## 6) Trust and Safety by Design + +### On-Slide Content + +## Agents are untrusted writers. + +Write pipeline: + +1. Propose -> staging +2. Verify -> invariants, conflicts, risk +3. Approve/reject -> user or policy +4. Promote -> canonical memory + +**All-but-mask:** out-of-scope data returns structure only, details redacted. + +### Excalidraw Composition + +- Horizontal 4-step pipeline with boxes and arrows. +- Small side stash box labeled "Conflict Stash". +- Shield icon above the pipeline. +- Masked response bubble: + - `type` + - `time` + - `importance` + - `details: [REDACTED]` + +### Speaker Notes + +"Most memory systems optimize for write convenience. We optimize for trust. Engram treats every agent as untrusted until proven reliable." + +--- + +## 7) Retrieval Quality: Dual Memory Engine + +### On-Slide Content + +## Better recall with fewer hallucinated joins. + +Engram retrieves in parallel: + +- Semantic memory (facts, entities, preferences) +- Episodic memory (CAST scenes: time/place/topic) + +Intersection promotion boosts results appearing in both. + +**Output:** token-bounded context packet with citations. + +### Excalidraw Composition + +- Two circles (semantic, episodic) with overlap region highlighted. +- Arrow from overlap to a "Context Packet" card. +- Magnifier icon + timeline icon. + +### Speaker Notes + +"Dual retrieval reduces the classic failure mode: semantically similar, temporally wrong answers." + +--- + +## 8) The Bio-Inspired Core + +### On-Slide Content + +## Not just vector search. + +- **FadeMem:** decay and consolidation +- **EchoMem:** multi-path encoding +- **CAST:** episodic scene memory + +Result: higher signal density, lower storage bloat, better long-horizon recall. + +### Excalidraw Composition + +- Three stacked cards with icons: + - Brain + clock (FadeMem) + - Spark/echo waves (EchoMem) + - Film frames/timeline (CAST) +- Curved arrows showing loop: write -> recall -> reinforce/decay. + +### Speaker Notes + +"Engram uses memory dynamics inspired by cognitive science: reinforcement for important memory, decay for stale memory, and episodic grouping for narrative recall." + +--- + +## 9) Why We Win (Positioning) + +### On-Slide Content + +## Category wedge: user-owned memory + +Most alternatives optimize for hosted convenience. +Engram optimizes for user control + interoperability. + +### Comparison Snapshot + +| Capability | Typical memory SaaS | Engram | +|:--|:--|:--| +| Data location | Vendor cloud first | Local-first | +| Write control | Direct writes | Staged + verification | +| Cross-agent portability | Partial | Native goal | +| Episodic memory | Limited | CAST scenes | +| Scope masking | Basic ACL | Structural redaction | + +### Excalidraw Composition + +- Matrix table inside hand-drawn frame. +- Add check icons in Engram column. +- Add a bold outline around Engram column. + +### Speaker Notes + +"Our differentiation is architectural, not cosmetic. User-owned memory is the default, not an enterprise add-on." + +--- + +## 10) Adoption and GTM + +### On-Slide Content + +## Open source adoption, productized monetization. + +Top-of-funnel: + +- `pip install` developer adoption +- MCP-native integrations +- Docs + demos + benchmark narratives + +Monetization path: + +- Managed cloud for teams +- Enterprise controls and support +- Usage-based billing tied to memory operations + +### Excalidraw Composition + +- Flywheel sketch: + - OSS adoption -> integrations -> community trust -> enterprise pull -> managed revenue -> faster product +- Use circular arrows and small icon nodes. + +### Speaker Notes + +"The open core gives distribution. Managed and enterprise offerings give durable revenue without breaking the user-owned thesis." + +--- + +## 11) Roadmap (12-Month) + +### On-Slide Content + +## Execution roadmap + +- **Q1:** Benchmark publication + retrieval quality report +- **Q2:** Deeper graph/entity memory and tooling +- **Q3:** Team memory and governance controls +- **Q4:** Managed cloud GA + migration tooling + +### Excalidraw Composition + +- Horizontal timeline with 4 milestones. +- Milestone cards include one icon each. +- Add "today" marker near Q1. + +### Speaker Notes + +"Roadmap sequencing is deliberate: prove quality, expand capability, then scale distribution and revenue." + +--- + +## 12) The Ask + +### On-Slide Content + +## We are building the memory substrate for the agent era. + +**Raising:** `[amount]` at `[stage]` + +Use of funds: + +- 60% product + infrastructure +- 25% GTM + developer growth +- 15% operations + compliance + +**CTA:** If you believe memory should be user-owned and portable, we should work together. + +Contact: + +- GitHub: `Ashish-dwi99/Engram` +- PyPI: `engram-memory` +- Email: `[you@email.com]` + +### Excalidraw Composition + +- Single centered ask card with strong border. +- Background doodles: plug, shield, brain, arrow-up. + +### Speaker Notes + +"Models will commoditize. Memory will differentiate. We are building the independent memory layer that every agent stack will need." + +--- + +## Appendix A: Slide Build Checklist (Fast) + +- Create one master with grid overlay and light paper texture. +- Keep max 1 core message per slide. +- Keep max 3 visual objects per quadrant. +- Use one accent color only for emphasis. +- Keep icon stroke consistent across all slides. + +--- + +## Appendix B: Objection Handling + +### "Could larger vendors just add this?" + +They can add memory features. Harder to add user-owned portability as their default architecture and business model. + +### "Why not just store everything forever?" + +Because unbounded memory degrades retrieval quality. Intelligent decay improves precision and cost. + +### "How do you handle privacy?" + +Local-first by default, staged writes, scoped retrieval, and structural masking for out-of-scope data. + +### "What is the moat?" + +Product moat: integrations + workflow depth. +Architecture moat: trust-aware memory operations and episodic + semantic retrieval. +Data moat: user-retained memory continuity across toolchains. + +--- + +## Timing Guide (7-8 minutes) + +| Segment | Time | +|:--|:--| +| Slides 1-3 (Problem) | 2:00 | +| Slides 4-6 (Solution + Trust) | 2:15 | +| Slides 7-9 (Tech + Positioning) | 1:45 | +| Slides 10-12 (Business + Ask) | 1:30 | + +Target total: ~7:30 + optional 2-minute product demo. diff --git a/pyproject.toml b/pyproject.toml index c45344b..3970d90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "engram-memory" -version = "0.4.1" +version = "0.5.0b1" description = "Memory layer for AI agents — biologically-inspired forgetting, multi-agent trust, and plug-and-play integrations" readme = "README.md" requires-python = ">=3.9" @@ -54,6 +54,9 @@ all = [ async = [ "aiosqlite>=0.19.0", ] +docs = [ + "reportlab>=4.0.0", +] dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..0322f10 --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1 @@ +"""Utility scripts for repository tooling.""" diff --git a/scripts/build_doc_book.py b/scripts/build_doc_book.py new file mode 100644 index 0000000..e06acd9 --- /dev/null +++ b/scripts/build_doc_book.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +"""Build a single documentation book PDF from per-file doc PDFs + manifest.""" + +from __future__ import annotations + +import argparse +import json +import math +import shutil +import subprocess +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List + + +def parse_args() -> argparse.Namespace: + repo_root = Path(__file__).resolve().parents[1] + default_manifest = repo_root / "docs" / "pdf" / "manifest.json" + default_index = repo_root / "docs" / "pdf" / "BOOK_INDEX.pdf" + default_book = repo_root / "docs" / "pdf" / "BOOK.pdf" + + parser = argparse.ArgumentParser(description="Build single BOOK.pdf from doc manifest + per-file PDFs") + parser.add_argument("--manifest", default=str(default_manifest), help="Path to docgen manifest.json") + parser.add_argument("--index-output", default=str(default_index), help="Path for generated index/cover PDF") + parser.add_argument("--book-output", default=str(default_book), help="Path for final merged book PDF") + parser.add_argument("--title", default="Engram Deep Documentation Book", help="Book title on cover") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + + manifest_path = Path(args.manifest).resolve() + index_output = Path(args.index_output).resolve() + book_output = Path(args.book_output).resolve() + + if not manifest_path.exists(): + raise SystemExit(f"manifest not found: {manifest_path}") + + if shutil.which("pdfunite") is None: + raise SystemExit("pdfunite not found in PATH") + + data = json.loads(manifest_path.read_text(encoding="utf-8")) + items = sorted(data.get("items", []), key=lambda item: item["source_path"]) + if not items: + raise SystemExit("manifest contains no items") + + manifest_dir = manifest_path.parent + missing = [item["output_pdf"] for item in items if not (manifest_dir / item["output_pdf"]).exists()] + if missing: + raise SystemExit(f"missing file PDFs ({len(missing)}), first: {missing[0]}") + + cover_pages = 1 + entries_per_page = _entries_per_index_page() + index_pages = max(1, math.ceil(len(items) / entries_per_page)) + prefix_pages = cover_pages + index_pages + + current_page = prefix_pages + 1 + for item in items: + item["book_start_page"] = current_page + current_page += int(item.get("page_count", 0) or 0) + + _render_cover_and_index( + output_path=index_output, + title=args.title, + commit_hash=data.get("commit_hash", "unknown"), + generated_at=data.get("generated_at", _utc_now()), + items=items, + entries_per_page=entries_per_page, + total_pages=current_page - 1, + ) + + merge_inputs = [str(index_output)] + [str((manifest_dir / item["output_pdf"]).resolve()) for item in items] + book_output.parent.mkdir(parents=True, exist_ok=True) + subprocess.run(["pdfunite", *merge_inputs, str(book_output)], check=True) + + print(f"[book] items={len(items)}") + print(f"[book] index_pdf={index_output}") + print(f"[book] book_pdf={book_output}") + print(f"[book] total_pages={current_page - 1}") + return 0 + + +def _entries_per_index_page() -> int: + from reportlab.lib.pagesizes import LETTER # type: ignore + + _, height = LETTER + top = height - 72 + bottom = 56 + header_space = 48 + row_height = 14 + return max(1, int((top - bottom - header_space) // row_height)) + + +def _render_cover_and_index( + *, + output_path: Path, + title: str, + commit_hash: str, + generated_at: str, + items: List[Dict], + entries_per_page: int, + total_pages: int, +) -> None: + try: + from reportlab.lib.pagesizes import LETTER + from reportlab.pdfgen import canvas + except Exception as exc: # pragma: no cover + raise SystemExit("reportlab is required. Install with: pip install -e '.[docs]'") from exc + + output_path.parent.mkdir(parents=True, exist_ok=True) + c = canvas.Canvas(str(output_path), pagesize=LETTER) + width, height = LETTER + + # Cover page + c.setFont("Helvetica-Bold", 22) + c.drawString(56, height - 96, title) + c.setFont("Helvetica", 12) + c.drawString(56, height - 132, f"Generated: {generated_at}") + c.drawString(56, height - 150, f"Commit: {commit_hash}") + c.drawString(56, height - 168, f"Files documented: {len(items)}") + c.drawString(56, height - 186, f"Estimated total pages: {total_pages}") + c.setFont("Helvetica", 10) + c.drawString(56, height - 224, "This book combines all deep per-file documentation into one PDF.") + c.drawString(56, height - 240, "The next pages contain an index with starting page numbers for each source file.") + _draw_footer(c, 1) + c.showPage() + + # Index pages start at logical page 2 + logical_page = 2 + chunks = [items[i : i + entries_per_page] for i in range(0, len(items), entries_per_page)] + if not chunks: + chunks = [[]] + + for chunk_idx, chunk in enumerate(chunks, start=1): + y = height - 72 + c.setFont("Helvetica-Bold", 16) + c.drawString(56, y, f"Index ({chunk_idx}/{len(chunks)})") + y -= 28 + + c.setFont("Helvetica-Bold", 10) + c.drawString(56, y, "Source file") + c.drawRightString(width - 120, y, "Start page") + c.drawRightString(width - 56, y, "Pages") + y -= 8 + c.line(56, y, width - 56, y) + y -= 14 + + c.setFont("Helvetica", 9) + for item in chunk: + source = _ellipsize(item["source_path"], 88) + c.drawString(56, y, source) + c.drawRightString(width - 120, y, str(item.get("book_start_page", ""))) + c.drawRightString(width - 56, y, str(item.get("page_count", ""))) + y -= 14 + + _draw_footer(c, logical_page) + logical_page += 1 + c.showPage() + + c.save() + + +def _draw_footer(c, page: int) -> None: + c.setFont("Helvetica", 8) + c.drawRightString(560, 28, f"Page {page}") + + +def _ellipsize(text: str, max_chars: int) -> str: + if len(text) <= max_chars: + return text + return text[: max_chars - 3] + "..." + + +def _utc_now() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/docgen/__init__.py b/scripts/docgen/__init__.py new file mode 100644 index 0000000..5f36fec --- /dev/null +++ b/scripts/docgen/__init__.py @@ -0,0 +1,18 @@ +"""Deterministic deep documentation generation utilities.""" + +from .analyze import ( + analyze_non_python_file, + analyze_python_file, + build_doc_payload, + collect_target_files, +) +from .render_pdf import render_file_pdf, render_index_pdf + +__all__ = [ + "collect_target_files", + "analyze_python_file", + "analyze_non_python_file", + "build_doc_payload", + "render_file_pdf", + "render_index_pdf", +] diff --git a/scripts/docgen/analyze.py b/scripts/docgen/analyze.py new file mode 100644 index 0000000..18ed3d2 --- /dev/null +++ b/scripts/docgen/analyze.py @@ -0,0 +1,1094 @@ +"""Static analyzers used to produce deep deterministic file documentation.""" + +from __future__ import annotations + +import ast +import json +import re +import subprocess +from collections import Counter, defaultdict +from html.parser import HTMLParser +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +try: + import tomllib # Python 3.11+ +except ModuleNotFoundError: # pragma: no cover - depends on runtime Python version + tomllib = None + +try: + import yaml +except Exception: # pragma: no cover - optional dependency + yaml = None + + +ROOT_FILES = {"pyproject.toml", "Dockerfile", "docker-compose.yml"} +BRANCH_NODES = [ + ast.If, + ast.For, + ast.AsyncFor, + ast.While, + ast.Try, + ast.With, + ast.AsyncWith, + ast.BoolOp, + ast.IfExp, + ast.comprehension, +] +if hasattr(ast, "Match"): # Python 3.10+ + BRANCH_NODES.append(ast.Match) +BRANCH_NODES = tuple(BRANCH_NODES) + + +def collect_target_files( + repo_root: str | Path, + exclude_tests: bool = True, + include_non_python: bool = True, +) -> List[str]: + """Collect tracked files in scope for deterministic PDF generation.""" + root = Path(repo_root).resolve() + tracked = _git_ls_files(root) + + selected: List[str] = [] + for rel in tracked: + if rel.endswith(".md"): + continue + if exclude_tests and rel.startswith("tests/"): + continue + + in_scope = ( + rel.startswith("engram/") + or rel.startswith("plugins/engram-memory/") + or rel in ROOT_FILES + ) + if not in_scope: + continue + + if not include_non_python and not rel.endswith(".py"): + continue + + selected.append(rel) + + return sorted(selected) + + +def analyze_python_file(path: str | Path) -> Dict[str, Any]: + """Analyze a Python file with AST and return a deterministic metadata payload.""" + file_path = Path(path) + source = _read_text(file_path) + tree = ast.parse(source) + lines = source.splitlines() + + imports = _extract_imports(tree) + constants = _extract_constants(tree) + top_level_functions = _extract_functions(tree.body) + classes = _extract_classes(tree.body) + top_symbols = {item["name"] for item in top_level_functions} | {item["name"] for item in classes} + + parent_map = _build_parent_map(tree) + raises = _extract_raises(tree, parent_map) + call_map = _build_call_map(tree.body, top_symbols) + side_effect_hints = _side_effect_hints(imports, tree) + + branch_count = sum(1 for node in ast.walk(tree) if isinstance(node, BRANCH_NODES)) + comment_count = sum(1 for line in lines if line.strip().startswith("#")) + non_empty = sum(1 for line in lines if line.strip()) + + module_docstring = ast.get_docstring(tree) + complexity = { + "line_count": len(lines), + "non_empty_lines": non_empty, + "comment_lines": comment_count, + "branch_nodes": branch_count, + "cyclomatic_estimate": 1 + branch_count, + } + + dependencies = sorted({item["module"] for item in imports if item["module"]}) + + return { + "file_type": "python", + "path": str(file_path), + "line_count": len(lines), + "module_docstring": module_docstring or "", + "imports": imports, + "constants": constants, + "functions": top_level_functions, + "classes": classes, + "raises": raises, + "side_effect_hints": side_effect_hints, + "call_map": call_map, + "complexity": complexity, + "dependencies": dependencies, + } + + +def analyze_non_python_file(path: str | Path) -> Dict[str, Any]: + """Analyze supported non-Python files used by runtime and integration layers.""" + file_path = Path(path) + text = _read_text(file_path) + line_count = len(text.splitlines()) + lowered = file_path.name.lower() + + parser_errors: List[str] = [] + structure: List[str] = [] + runtime_implications: List[str] = [] + integrations: List[str] = [] + instructions: List[Dict[str, Any]] = [] + + if lowered == "dockerfile": + format_name = "dockerfile" + instructions = _parse_dockerfile(text) + structure = [ + f"{item['instruction']} ({file_path.name}:{item['line']})" + for item in instructions[:80] + ] + runtime_implications.extend(_docker_runtime_implications(instructions)) + integrations.extend(_docker_integrations(instructions)) + + elif file_path.suffix == ".json": + format_name = "json" + try: + parsed = json.loads(text) + structure = _flatten_keys(parsed) + runtime_implications.extend(_config_runtime_implications(structure)) + integrations.extend(_integration_hints_from_keys(structure)) + except Exception as exc: + parser_errors.append(f"JSON parse error: {exc}") + + elif file_path.suffix == ".toml": + format_name = "toml" + try: + parsed = _load_toml(text) + structure = _flatten_keys(parsed) + runtime_implications.extend(_config_runtime_implications(structure)) + integrations.extend(_integration_hints_from_keys(structure)) + except Exception as exc: + parser_errors.append(f"TOML parse error: {exc}") + + elif file_path.suffix in {".yml", ".yaml"}: + format_name = "yaml" + try: + parsed = _load_yaml(text) + structure = _flatten_keys(parsed) + runtime_implications.extend(_config_runtime_implications(structure)) + integrations.extend(_integration_hints_from_keys(structure)) + except Exception as exc: + parser_errors.append(f"YAML parse error: {exc}") + + elif file_path.suffix == ".html": + format_name = "html" + html_info = _analyze_html(text) + structure = html_info["structure"] + integrations = html_info["integrations"] + runtime_implications = html_info["runtime_implications"] + + else: + format_name = "text" + structure = _line_headings(text) + + return { + "file_type": "non_python", + "format": format_name, + "path": str(file_path), + "line_count": line_count, + "structure": structure, + "instructions": instructions, + "runtime_implications": _stable_unique(runtime_implications), + "integrations": _stable_unique(integrations), + "parser_errors": parser_errors, + } + + +def build_doc_payload(path: str | Path, analysis: Dict[str, Any]) -> Dict[str, Any]: + """Build deep sectioned documentation payload in required section order.""" + rel_path = str(path) + file_type = analysis["file_type"] + line_count = analysis.get("line_count", 0) + + sections: List[Dict[str, Any]] = [] + + sections.append( + { + "title": "Role in repository", + "paragraphs": _role_in_repository(rel_path, analysis), + "code_blocks": [], + } + ) + + file_map_lines, metrics_block = _file_map_and_metrics(rel_path, analysis) + sections.append( + { + "title": "File map and metrics", + "paragraphs": file_map_lines, + "code_blocks": [metrics_block] if metrics_block else [], + } + ) + + interface_paragraphs, interface_blocks = _public_interfaces(rel_path, analysis) + sections.append( + { + "title": "Public interfaces and key symbols", + "paragraphs": interface_paragraphs, + "code_blocks": interface_blocks, + } + ) + + sections.append( + { + "title": "Execution/data flow walkthrough", + "paragraphs": _execution_walkthrough(rel_path, analysis), + "code_blocks": [], + } + ) + + sections.append( + { + "title": "Error handling and edge cases", + "paragraphs": _error_and_edge_cases(rel_path, analysis), + "code_blocks": [], + } + ) + + sections.append( + { + "title": "Integration and dependencies", + "paragraphs": _integration_and_dependencies(rel_path, analysis), + "code_blocks": [], + } + ) + + sections.append( + { + "title": "Safe modification guide", + "paragraphs": _safe_modification(rel_path, analysis), + "code_blocks": [], + } + ) + + sections.append( + { + "title": "Reading order for large files", + "paragraphs": _reading_order(rel_path, analysis), + "code_blocks": [], + } + ) + + return { + "file_path": rel_path, + "line_count": line_count, + "file_type": file_type, + "doc_depth": "deep", + "method": "deterministic_static", + "sections": sections, + } + + +def _git_ls_files(repo_root: Path) -> List[str]: + output = subprocess.check_output(["git", "-C", str(repo_root), "ls-files"], text=True) + return [line.strip() for line in output.splitlines() if line.strip()] + + +def _read_text(path: Path) -> str: + try: + return path.read_text(encoding="utf-8") + except UnicodeDecodeError: + return path.read_text(encoding="latin-1") + + +def _extract_imports(tree: ast.AST) -> List[Dict[str, Any]]: + imports: List[Dict[str, Any]] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + imports.append( + { + "module": alias.name, + "name": alias.asname or alias.name, + "line": node.lineno, + "kind": "import", + } + ) + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + for alias in node.names: + imports.append( + { + "module": module, + "name": alias.asname or alias.name, + "line": node.lineno, + "kind": "from", + } + ) + imports.sort(key=lambda item: (item["line"], item["module"], item["name"])) + return imports + + +def _extract_constants(tree: ast.Module) -> List[Dict[str, Any]]: + constants: List[Dict[str, Any]] = [] + for node in tree.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id.isupper(): + constants.append({"name": target.id, "line": node.lineno}) + elif isinstance(node, ast.AnnAssign): + if isinstance(node.target, ast.Name) and node.target.id.isupper(): + constants.append({"name": node.target.id, "line": node.lineno}) + return sorted(constants, key=lambda item: (item["line"], item["name"])) + + +def _extract_functions(nodes: Iterable[ast.stmt]) -> List[Dict[str, Any]]: + funcs: List[Dict[str, Any]] = [] + for node in nodes: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + funcs.append(_function_metadata(node)) + return funcs + + +def _extract_classes(nodes: Iterable[ast.stmt]) -> List[Dict[str, Any]]: + classes: List[Dict[str, Any]] = [] + for node in nodes: + if not isinstance(node, ast.ClassDef): + continue + + methods: List[Dict[str, Any]] = [] + for child in node.body: + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + methods.append(_function_metadata(child, class_name=node.name)) + + bases = [_safe_unparse(base) for base in node.bases] + decorators = [_safe_unparse(dec) for dec in node.decorator_list] + classes.append( + { + "name": node.name, + "line": node.lineno, + "end_line": getattr(node, "end_lineno", node.lineno), + "bases": bases, + "decorators": decorators, + "methods": methods, + } + ) + return classes + + +def _function_metadata( + node: ast.FunctionDef | ast.AsyncFunctionDef, + class_name: Optional[str] = None, +) -> Dict[str, Any]: + decorators = [_safe_unparse(dec) for dec in node.decorator_list] + signature = _format_signature(node) + return { + "name": node.name, + "qualified_name": f"{class_name}.{node.name}" if class_name else node.name, + "line": node.lineno, + "end_line": getattr(node, "end_lineno", node.lineno), + "is_async": isinstance(node, ast.AsyncFunctionDef), + "decorators": decorators, + "signature": signature, + } + + +def _format_signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def" + args_text = _safe_unparse(node.args) + ret_text = f" -> {_safe_unparse(node.returns)}" if node.returns is not None else "" + return f"{prefix} {node.name}({args_text}){ret_text}" + + +def _build_parent_map(tree: ast.AST) -> Dict[ast.AST, ast.AST]: + parent_map: Dict[ast.AST, ast.AST] = {} + for parent in ast.walk(tree): + for child in ast.iter_child_nodes(parent): + parent_map[child] = parent + return parent_map + + +def _extract_raises(tree: ast.AST, parent_map: Dict[ast.AST, ast.AST]) -> List[Dict[str, Any]]: + raises: List[Dict[str, Any]] = [] + for node in ast.walk(tree): + if not isinstance(node, ast.Raise): + continue + + exc = _safe_unparse(node.exc) if node.exc else "re-raise" + context = _enclosing_symbol(node, parent_map) + raises.append( + { + "line": node.lineno, + "exception": exc, + "context": context, + } + ) + raises.sort(key=lambda item: (item["line"], item["exception"])) + return raises + + +def _enclosing_symbol(node: ast.AST, parent_map: Dict[ast.AST, ast.AST]) -> str: + current = node + fn_name: Optional[str] = None + class_name: Optional[str] = None + + while current in parent_map: + current = parent_map[current] + if fn_name is None and isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef)): + fn_name = current.name + elif class_name is None and isinstance(current, ast.ClassDef): + class_name = current.name + + if class_name and fn_name: + return f"{class_name}.{fn_name}" + if fn_name: + return fn_name + if class_name: + return class_name + return "module" + + +def _build_call_map(nodes: Iterable[ast.stmt], top_symbols: set[str]) -> Dict[str, List[str]]: + call_map: Dict[str, List[str]] = {} + + for node in nodes: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + called = _called_top_symbols(node, top_symbols) + call_map[node.name] = called + elif isinstance(node, ast.ClassDef): + for child in node.body: + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + key = f"{node.name}.{child.name}" + call_map[key] = _called_top_symbols(child, top_symbols) + + return call_map + + +def _called_top_symbols(node: ast.AST, top_symbols: set[str]) -> List[str]: + called: List[str] = [] + seen: set[str] = set() + for item in ast.walk(node): + if not isinstance(item, ast.Call): + continue + raw = _call_name(item) + if not raw: + continue + first = raw.split(".")[0] + if first in top_symbols and first not in seen: + seen.add(first) + called.append(first) + return called + + +def _call_name(node: ast.Call) -> Optional[str]: + if isinstance(node.func, ast.Name): + return node.func.id + if isinstance(node.func, ast.Attribute): + chain: List[str] = [] + current: ast.AST = node.func + while isinstance(current, ast.Attribute): + chain.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + chain.append(current.id) + chain.reverse() + return ".".join(chain) + return None + + +def _side_effect_hints(imports: List[Dict[str, Any]], tree: ast.AST) -> Dict[str, List[str]]: + modules = [item["module"].lower() for item in imports if item["module"]] + calls = [(_call_name(node) or "").lower() for node in ast.walk(tree) if isinstance(node, ast.Call)] + tokens = modules + calls + + buckets = { + "database": ["sqlite", "qdrant", "database", "db", "cursor", "execute", "commit", "rollback"], + "network": ["http", "request", "socket", "openai", "gemini", "ollama", "client", "api"], + "filesystem": ["open", "path", "mkdir", "write", "unlink", "rename", "shutil", "glob"], + "logging": ["logging", "logger", "print"], + "subprocess": ["subprocess", "popen", "check_output", "os.system", "run"], + "environment": ["getenv", "environ", "os.environ"], + } + + hints: Dict[str, List[str]] = {} + for bucket, markers in buckets.items(): + matched = sorted( + { + token + for token in tokens + for marker in markers + if marker in token + } + ) + if matched: + hints[bucket] = matched + + return hints + + +def _load_toml(text: str) -> Any: + if tomllib is not None: + return tomllib.loads(text) + + try: + import tomli # type: ignore + + return tomli.loads(text) + except Exception as exc: # pragma: no cover - dependent on environment + raise RuntimeError("tomllib/tomli unavailable for TOML parsing") from exc + + +def _load_yaml(text: str) -> Any: + if yaml is not None: + return yaml.safe_load(text) + + # Fallback: parse top-level keys only if PyYAML is unavailable. + parsed: Dict[str, Any] = {} + for line in text.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if ":" in stripped and not stripped.startswith("-"): + key = stripped.split(":", 1)[0].strip() + parsed[key] = "" + return parsed + + +def _flatten_keys(data: Any, prefix: str = "") -> List[str]: + keys: List[str] = [] + + if isinstance(data, dict): + for key in sorted(data.keys(), key=lambda item: str(item)): + key_str = str(key) + current = f"{prefix}.{key_str}" if prefix else key_str + keys.append(current) + keys.extend(_flatten_keys(data[key], current)) + elif isinstance(data, list): + current = f"{prefix}[]" if prefix else "[]" + keys.append(current) + if data: + keys.extend(_flatten_keys(data[0], current)) + + return keys[:400] + + +def _config_runtime_implications(keys: List[str]) -> List[str]: + implications: List[str] = [] + lowered = [key.lower() for key in keys] + + if any("dependencies" in key for key in lowered): + implications.append("Dependency declarations drive installation/runtime compatibility.") + if any("scripts" in key for key in lowered): + implications.append("Script entries define developer and release workflows.") + if any("environment" in key or "env" in key for key in lowered): + implications.append("Environment keys influence runtime behavior across environments.") + if any("port" in key for key in lowered): + implications.append("Port settings determine network exposure and service wiring.") + if any("api" in key or "key" in key or "token" in key for key in lowered): + implications.append("Credential-related keys require secret management and redaction discipline.") + if any("database" in key or "sqlite" in key or "qdrant" in key for key in lowered): + implications.append("Storage-related keys alter persistence topology and migration expectations.") + + return implications + + +def _integration_hints_from_keys(keys: List[str]) -> List[str]: + integrations: List[str] = [] + lowered = [key.lower() for key in keys] + + for provider in ["openai", "gemini", "ollama", "qdrant", "fastapi", "docker", "mcp"]: + if any(provider in key for key in lowered): + integrations.append(f"Contains configuration surface for {provider} integration.") + + return integrations + + +def _parse_dockerfile(text: str) -> List[Dict[str, Any]]: + instructions: List[Dict[str, Any]] = [] + + for lineno, line in enumerate(text.splitlines(), start=1): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + + match = re.match(r"^([A-Za-z]+)\s+(.*)$", stripped) + if match: + instruction = match.group(1).upper() + value = match.group(2).strip() + else: + instruction = "RAW" + value = stripped + + instructions.append( + { + "line": lineno, + "instruction": instruction, + "value": value, + } + ) + + return instructions + + +def _docker_runtime_implications(instructions: List[Dict[str, Any]]) -> List[str]: + implications: List[str] = [] + names = {item["instruction"] for item in instructions} + + if "FROM" in names: + implications.append("Base image selection constrains OS packages and runtime security posture.") + if "RUN" in names: + implications.append("RUN layers define build-time dependencies and affect image reproducibility.") + if "EXPOSE" in names: + implications.append("EXPOSE signals service ports expected by runtime orchestration.") + if "CMD" in names or "ENTRYPOINT" in names: + implications.append("Process startup instructions control container lifecycle and health behavior.") + if "ENV" in names: + implications.append("ENV instructions create default environment values for downstream execution.") + + return implications + + +def _docker_integrations(instructions: List[Dict[str, Any]]) -> List[str]: + integrations: List[str] = [] + values = "\n".join(item["value"] for item in instructions).lower() + + for token in ["python", "uvicorn", "fastapi", "qdrant", "sqlite", "engram"]: + if token in values: + integrations.append(f"Docker build/runtime references {token} components.") + + return integrations + + +class _HTMLShapeParser(HTMLParser): + def __init__(self) -> None: + super().__init__() + self.tag_counter: Counter[str] = Counter() + self.ids: set[str] = set() + self.classes: set[str] = set() + self.scripts: List[str] = [] + self.links: List[str] = [] + + def handle_starttag(self, tag: str, attrs: List[tuple[str, Optional[str]]]) -> None: + self.tag_counter[tag] += 1 + attr_map = {key: value for key, value in attrs} + + if attr_map.get("id"): + self.ids.add(attr_map["id"] or "") + if attr_map.get("class"): + class_value = attr_map["class"] or "" + for item in class_value.split(): + self.classes.add(item) + if tag == "script": + src = attr_map.get("src") or "inline-script" + self.scripts.append(src) + if tag in {"a", "link"}: + href = attr_map.get("href") + if href: + self.links.append(href) + + +def _analyze_html(text: str) -> Dict[str, Any]: + parser = _HTMLShapeParser() + parser.feed(text) + + structure = [f"<{tag}> count={count}" for tag, count in parser.tag_counter.most_common(20)] + if parser.ids: + structure.append(f"IDs: {', '.join(sorted(parser.ids)[:25])}") + if parser.classes: + structure.append(f"Classes: {', '.join(sorted(parser.classes)[:25])}") + + integrations: List[str] = [] + for src in parser.scripts[:50]: + integrations.append(f"Script dependency: {src}") + for href in parser.links[:50]: + integrations.append(f"Hyperlink/resource reference: {href}") + + runtime_implications = [ + "HTML structure defines frontend entry points and developer observability surfaces.", + ] + if parser.scripts: + runtime_implications.append("Script tags can introduce runtime dependencies and browser execution order coupling.") + if parser.links: + runtime_implications.append("External links/resources may fail if deployment paths or hosts change.") + + return { + "structure": structure, + "integrations": integrations, + "runtime_implications": runtime_implications, + } + + +def _line_headings(text: str) -> List[str]: + headings: List[str] = [] + for lineno, line in enumerate(text.splitlines(), start=1): + stripped = line.strip() + if not stripped: + continue + if stripped.startswith("[") and stripped.endswith("]"): + headings.append(f"Section {stripped} (line {lineno})") + elif ":" in stripped and not stripped.startswith("#"): + key = stripped.split(":", 1)[0].strip() + headings.append(f"Key {key} (line {lineno})") + if len(headings) >= 120: + break + return headings + + +def _role_in_repository(rel_path: str, analysis: Dict[str, Any]) -> List[str]: + role = _role_from_path(rel_path) + file_type = analysis["file_type"] + paragraphs = [ + f"`{rel_path}` is part of the {role}. This file is analyzed as `{file_type}` content.", + ] + + if file_type == "python": + funcs = len(analysis.get("functions", [])) + classes = len(analysis.get("classes", [])) + paragraphs.append( + f"Symbol density: {classes} class(es), {funcs} top-level function(s), " + f"{analysis.get('line_count', 0)} total lines.") + else: + fmt = analysis.get("format", "text") + paragraphs.append( + f"Configuration/asset format `{fmt}` controls runtime behavior for the repository's integration and deployment surfaces.") + + return paragraphs + + +def _role_from_path(rel_path: str) -> str: + mapping = { + "engram/api/": "API service layer", + "engram/core/": "core memory kernel", + "engram/db/": "database persistence layer", + "engram/embeddings/": "embedding provider integration layer", + "engram/llms/": "LLM provider abstraction layer", + "engram/memory/": "public memory orchestration layer", + "engram/retrieval/": "retrieval and ranking layer", + "engram/vector_stores/": "vector-store backend layer", + "engram/utils/": "shared utility layer", + "engram/integrations/": "external tool integration layer", + "plugins/engram-memory/": "agent plugin integration package", + } + for prefix, role in mapping.items(): + if rel_path.startswith(prefix): + return role + if rel_path in ROOT_FILES: + return "root runtime/build configuration" + return "repository support layer" + + +def _file_map_and_metrics(rel_path: str, analysis: Dict[str, Any]) -> tuple[List[str], str]: + paragraphs: List[str] = [] + blocks: List[str] = [] + line_count = analysis.get("line_count", 0) + + paragraphs.append(f"The file has {line_count} lines and belongs to `{analysis['file_type']}` analysis path.") + + if analysis["file_type"] == "python": + complexity = analysis.get("complexity", {}) + blocks.extend( + [ + f"line_count: {complexity.get('line_count', line_count)}", + f"non_empty_lines: {complexity.get('non_empty_lines', 0)}", + f"comment_lines: {complexity.get('comment_lines', 0)}", + f"branch_nodes: {complexity.get('branch_nodes', 0)}", + f"cyclomatic_estimate: {complexity.get('cyclomatic_estimate', 1)}", + ] + ) + + if analysis.get("module_docstring"): + paragraphs.append( + f"Module docstring present near `{rel_path}:1`, indicating explicit file intent and usage guidance.") + else: + paragraphs.append("No module docstring detected; intent must be inferred from symbols and call graph.") + else: + fmt = analysis.get("format", "text") + structure_len = len(analysis.get("structure", [])) + paragraphs.append( + f"Detected `{fmt}` structure with {structure_len} key/instruction/tag entries captured for documentation.") + if analysis.get("parser_errors"): + paragraphs.append( + f"Parser reported {len(analysis['parser_errors'])} issue(s); documentation falls back to best-effort structural extraction.") + + return paragraphs, "\n".join(blocks) + + +def _public_interfaces(rel_path: str, analysis: Dict[str, Any]) -> tuple[List[str], List[str]]: + paragraphs: List[str] = [] + blocks: List[str] = [] + + if analysis["file_type"] == "python": + constants = analysis.get("constants", []) + classes = analysis.get("classes", []) + functions = analysis.get("functions", []) + + if constants: + constants_text = ", ".join( + f"`{item['name']}` ({rel_path}:{item['line']})" for item in constants[:40] + ) + paragraphs.append(f"Top-level constants: {constants_text}.") + + if classes: + paragraphs.append( + f"Class interfaces appear at: " + + ", ".join(f"`{item['name']}` ({rel_path}:{item['line']})" for item in classes[:30]) + + "." + ) + class_sigs: List[str] = [] + for cls in classes[:20]: + base_text = f"({', '.join(cls['bases'])})" if cls.get("bases") else "" + class_sigs.append(f"class {cls['name']}{base_text} # {rel_path}:{cls['line']}") + for method in cls.get("methods", [])[:20]: + class_sigs.append(f" {method['signature']} # {rel_path}:{method['line']}") + blocks.append("\n".join(class_sigs)) + + if functions: + paragraphs.append( + f"Top-level callable interfaces: " + + ", ".join(f"`{item['name']}` ({rel_path}:{item['line']})" for item in functions[:40]) + + "." + ) + blocks.append( + "\n".join( + f"{item['signature']} # {rel_path}:{item['line']}" for item in functions[:80] + ) + ) + + if not (constants or classes or functions): + paragraphs.append("No top-level public symbols detected (likely package marker or data-only module).") + + else: + structure = analysis.get("structure", []) + if structure: + preview = ", ".join(f"`{item}`" for item in structure[:30]) + paragraphs.append( + f"Primary structural interfaces include: {preview}." + ) + else: + paragraphs.append("No structured interface extracted; file may be minimal or free-form.") + + instructions = analysis.get("instructions", []) + if instructions: + blocks.append( + "\n".join( + f"{item['instruction']} {item['value']} # {rel_path}:{item['line']}" + for item in instructions[:80] + ) + ) + + return paragraphs, [block for block in blocks if block.strip()] + + +def _execution_walkthrough(rel_path: str, analysis: Dict[str, Any]) -> List[str]: + paragraphs: List[str] = [] + + if analysis["file_type"] == "python": + call_map = analysis.get("call_map", {}) + if not call_map: + paragraphs.append("No internal top-level call chaining detected; behavior is either declarative or externally invoked.") + else: + for symbol, callees in list(call_map.items())[:80]: + if callees: + paragraphs.append( + f"Execution path: `{symbol}` in `{rel_path}` invokes {', '.join(f'`{c}`' for c in callees)}.") + else: + paragraphs.append( + f"Execution path: `{symbol}` in `{rel_path}` has no detected calls to same-file top-level symbols.") + + async_symbols = [ + item["qualified_name"] + for item in analysis.get("functions", []) + if item.get("is_async") + ] + for cls in analysis.get("classes", []): + async_symbols.extend( + method["qualified_name"] + for method in cls.get("methods", []) + if method.get("is_async") + ) + if async_symbols: + paragraphs.append( + "Async execution surfaces detected: " + ", ".join(f"`{name}`" for name in async_symbols[:40]) + "." + ) + + else: + fmt = analysis.get("format", "text") + if fmt == "dockerfile": + paragraphs.append( + "Dockerfile flow executes top-to-bottom as image layers; changing earlier instructions invalidates downstream cache layers.") + paragraphs.append( + f"Instruction order is captured with line references in `{rel_path}` and should be reviewed sequentially during modifications.") + elif fmt in {"json", "toml", "yaml"}: + paragraphs.append( + "Configuration flow is key-driven: loaders parse hierarchy first, then runtime components consume specific sections.") + paragraphs.append( + f"Use the extracted hierarchy from `{rel_path}` to locate producer/consumer contracts before editing values.") + elif fmt == "html": + paragraphs.append( + "HTML flow is browser-driven: DOM structure defines render order while script/link tags define runtime side effects.") + paragraphs.append( + f"Tag and resource extraction from `{rel_path}` highlights where integration contracts attach.") + else: + paragraphs.append("Execution flow is not directly inferable from this file type; treat it as passive input to other systems.") + + return paragraphs + + +def _error_and_edge_cases(rel_path: str, analysis: Dict[str, Any]) -> List[str]: + paragraphs: List[str] = [] + + if analysis["file_type"] == "python": + raises = analysis.get("raises", []) + if raises: + paragraphs.append( + f"Detected {len(raises)} explicit `raise` statement(s); key sites include: " + + ", ".join( + f"`{item['exception']}` in `{item['context']}` ({rel_path}:{item['line']})" + for item in raises[:30] + ) + + "." + ) + else: + paragraphs.append("No explicit `raise` statements detected; failures may surface through dependency calls or return-value signaling.") + + complexity = analysis.get("complexity", {}) + if complexity.get("branch_nodes", 0) > 80: + paragraphs.append( + "High branch density suggests multiple conditional paths; validate edge-case behavior when changing conditionals.") + if complexity.get("cyclomatic_estimate", 1) > 120: + paragraphs.append( + "Very high cyclomatic estimate indicates elevated regression risk; target incremental edits with focused tests.") + + else: + parser_errors = analysis.get("parser_errors", []) + if parser_errors: + paragraphs.append("Parsing issues detected: " + "; ".join(parser_errors) + ".") + else: + paragraphs.append("No parser-level structural errors detected.") + + fmt = analysis.get("format") + if fmt in {"json", "toml", "yaml"}: + paragraphs.append( + "Edge cases: invalid syntax, missing required keys, and value-type drift can break boot/runtime configuration loading.") + elif fmt == "dockerfile": + paragraphs.append( + "Edge cases: cache invalidation, missing build context files, and incompatible base images can break builds.") + elif fmt == "html": + paragraphs.append( + "Edge cases: missing script resources and DOM id/class mismatches can break client-side behavior.") + + return paragraphs + + +def _integration_and_dependencies(rel_path: str, analysis: Dict[str, Any]) -> List[str]: + paragraphs: List[str] = [] + + if analysis["file_type"] == "python": + deps = analysis.get("dependencies", []) + if deps: + paragraphs.append( + "Direct imports indicate dependency surface: " + + ", ".join(f"`{dep}`" for dep in deps[:60]) + + "." + ) + + hints = analysis.get("side_effect_hints", {}) + if hints: + for bucket, tokens in hints.items(): + paragraphs.append( + f"{bucket.capitalize()} side-effect signals from `{rel_path}`: " + + ", ".join(f"`{token}`" for token in tokens[:20]) + + "." + ) + else: + paragraphs.append("No strong side-effect markers detected from imports/call signatures.") + + else: + integrations = analysis.get("integrations", []) + implications = analysis.get("runtime_implications", []) + if integrations: + paragraphs.append( + "Integration touchpoints: " + ", ".join(f"{item}" for item in integrations[:30]) + ) + if implications: + paragraphs.append( + "Runtime implications: " + " ".join(implications[:10]) + ) + if not integrations and not implications: + paragraphs.append("No major integration signals extracted from this file.") + + return paragraphs + + +def _safe_modification(rel_path: str, analysis: Dict[str, Any]) -> List[str]: + base = [ + f"Start by checking dependent call sites and imports for `{rel_path}` before renaming symbols or keys.", + "Preserve existing function/class signatures unless all callers are updated in the same change set.", + "Apply narrow edits first, then run targeted tests for affected modules before broad refactors.", + ] + + if analysis["file_type"] == "python": + if analysis.get("side_effect_hints", {}).get("database"): + base.append("Database interaction signals are present; validate migrations and transaction boundaries after edits.") + if analysis.get("side_effect_hints", {}).get("network"): + base.append("Network/API interaction signals are present; test failure and timeout handling paths explicitly.") + else: + fmt = analysis.get("format") + if fmt in {"json", "toml", "yaml"}: + base.append("Keep key names stable where they are consumed by code paths or deployment tooling.") + if fmt == "dockerfile": + base.append("Modify Dockerfile instructions with cache and layer ordering in mind to avoid accidental build regressions.") + if fmt == "html": + base.append("Coordinate HTML id/class changes with JavaScript and CSS selectors to avoid runtime UI regressions.") + + return base + + +def _reading_order(rel_path: str, analysis: Dict[str, Any]) -> List[str]: + line_count = analysis.get("line_count", 0) + + if line_count <= 200: + return [ + f"`{rel_path}` is compact ({line_count} lines). Read top-to-bottom in one pass, then revisit integration points.", + ] + + if analysis["file_type"] == "python": + steps: List[str] = [ + f"Pass 1: scan imports and constants near the top of `{rel_path}` to understand dependencies and global controls.", + ] + + classes = analysis.get("classes", []) + functions = analysis.get("functions", []) + if classes: + steps.append( + "Pass 2: read class definitions in line order: " + + ", ".join(f"`{item['name']}` ({rel_path}:{item['line']})" for item in classes[:20]) + + "." + ) + if functions: + steps.append( + "Pass 3: review top-level functions in line order: " + + ", ".join(f"`{item['name']}` ({rel_path}:{item['line']})" for item in functions[:25]) + + "." + ) + + steps.append("Final pass: inspect raise sites and side-effect hints to map operational risk points.") + return steps + + return [ + f"Read `{rel_path}` in declaration order, then validate extracted hierarchy/instructions against runtime consumers.", + "Prioritize sections that define dependencies, service commands, credentials, ports, and integration URLs.", + ] + + +def _safe_unparse(node: Optional[ast.AST]) -> str: + if node is None: + return "" + try: + return ast.unparse(node) + except Exception: + return "" + + +def _stable_unique(items: Iterable[str]) -> List[str]: + seen: set[str] = set() + out: List[str] = [] + for item in items: + if item not in seen: + seen.add(item) + out.append(item) + return out diff --git a/scripts/docgen/render_pdf.py b/scripts/docgen/render_pdf.py new file mode 100644 index 0000000..84e1b04 --- /dev/null +++ b/scripts/docgen/render_pdf.py @@ -0,0 +1,230 @@ +"""PDF rendering helpers for deterministic deep documentation.""" + +from __future__ import annotations + +import html +from pathlib import Path +from typing import Any, Dict, List + + +def _load_reportlab() -> Dict[str, Any]: + try: + from reportlab.lib.pagesizes import LETTER + from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet + from reportlab.lib.units import inch + from reportlab.pdfgen import canvas + from reportlab.platypus import PageBreak, Paragraph, Preformatted, SimpleDocTemplate, Spacer + except Exception as exc: # pragma: no cover - import error path + raise RuntimeError( + "reportlab is required for PDF rendering. Install with: pip install -e '.[docs]'" + ) from exc + + return { + "LETTER": LETTER, + "ParagraphStyle": ParagraphStyle, + "getSampleStyleSheet": getSampleStyleSheet, + "inch": inch, + "canvas": canvas, + "PageBreak": PageBreak, + "Paragraph": Paragraph, + "Preformatted": Preformatted, + "SimpleDocTemplate": SimpleDocTemplate, + "Spacer": Spacer, + } + + +def render_file_pdf(payload: Dict[str, Any], output_pdf: str | Path) -> int: + """Render one deep file-guide PDF and return the resulting page count.""" + rl = _load_reportlab() + styles = _styles(rl) + + out_path = Path(output_pdf) + out_path.parent.mkdir(parents=True, exist_ok=True) + + doc = rl["SimpleDocTemplate"]( + str(out_path), + pagesize=rl["LETTER"], + leftMargin=0.75 * rl["inch"], + rightMargin=0.75 * rl["inch"], + topMargin=0.75 * rl["inch"], + bottomMargin=0.75 * rl["inch"], + title=f"Deep File Guide: {payload.get('file_path', '')}", + author="Engram deterministic docgen", + ) + + story: List[Any] = [] + + story.append(rl["Paragraph"]("Deep File Guide", styles["title"])) + story.append(rl["Spacer"](1, 0.15 * rl["inch"])) + story.append(rl["Paragraph"](f"File: {_escape(payload.get('file_path', ''))}", styles["meta"])) + story.append(rl["Paragraph"](f"Generated: {_escape(payload.get('generated_at', ''))}", styles["meta"])) + story.append(rl["Paragraph"](f"Commit: {_escape(payload.get('commit_hash', ''))}", styles["meta"])) + story.append(rl["Paragraph"](f"Method: {_escape(payload.get('method', ''))}", styles["meta"])) + story.append(rl["Paragraph"](f"Depth: {_escape(payload.get('doc_depth', ''))}", styles["meta"])) + story.append(rl["Paragraph"](f"Line count: {_escape(str(payload.get('line_count', '')))}", styles["meta"])) + story.append(rl["PageBreak"]()) + + for section in payload.get("sections", []): + title = section.get("title", "Section") + story.append(rl["Paragraph"](_escape(title), styles["h2"])) + story.append(rl["Spacer"](1, 0.06 * rl["inch"])) + + for paragraph in section.get("paragraphs", []): + story.append(rl["Paragraph"](_paragraph_text(paragraph), styles["body"])) + story.append(rl["Spacer"](1, 0.04 * rl["inch"])) + + for block in section.get("code_blocks", []): + if not str(block).strip(): + continue + story.append(rl["Preformatted"](str(block), styles["mono"])) + story.append(rl["Spacer"](1, 0.08 * rl["inch"])) + + story.append(rl["Spacer"](1, 0.1 * rl["inch"])) + + page_count = _build_with_numbered_canvas(doc, story, rl) + return page_count + + +def render_index_pdf(index_payload: Dict[str, Any], output_pdf: str | Path) -> int: + """Render the global index PDF and return page count.""" + rl = _load_reportlab() + styles = _styles(rl) + + out_path = Path(output_pdf) + out_path.parent.mkdir(parents=True, exist_ok=True) + + doc = rl["SimpleDocTemplate"]( + str(out_path), + pagesize=rl["LETTER"], + leftMargin=0.75 * rl["inch"], + rightMargin=0.75 * rl["inch"], + topMargin=0.75 * rl["inch"], + bottomMargin=0.75 * rl["inch"], + title="Deep Documentation Index", + author="Engram deterministic docgen", + ) + + story: List[Any] = [] + story.append(rl["Paragraph"]("Deep Documentation Index", styles["title"])) + story.append(rl["Spacer"](1, 0.15 * rl["inch"])) + story.append(rl["Paragraph"](f"Generated: {_escape(index_payload.get('generated_at', ''))}", styles["meta"])) + story.append(rl["Paragraph"](f"Commit: {_escape(index_payload.get('commit_hash', ''))}", styles["meta"])) + story.append(rl["Paragraph"](f"Total files: {_escape(str(index_payload.get('total_files', 0)))}", styles["meta"])) + story.append(rl["Spacer"](1, 0.1 * rl["inch"])) + + story.append(rl["Paragraph"]("What to Read First", styles["h2"])) + for line in index_payload.get("reading_guide", []): + story.append(rl["Paragraph"](_paragraph_text(line), styles["body"])) + story.append(rl["Spacer"](1, 0.03 * rl["inch"])) + + story.append(rl["PageBreak"]()) + + groups: Dict[str, List[Dict[str, Any]]] = index_payload.get("groups", {}) + for group_name in sorted(groups): + story.append(rl["Paragraph"](_escape(group_name), styles["h2"])) + story.append(rl["Spacer"](1, 0.05 * rl["inch"])) + for item in groups[group_name]: + line = ( + f"Source: {_escape(item.get('source_path', ''))}" + f"
PDF: {_escape(item.get('output_pdf', ''))}" + f"
Lines: {_escape(str(item.get('line_count', '')))}, " + f"Pages: {_escape(str(item.get('page_count', '')))}" + ) + story.append(rl["Paragraph"](line, styles["body"])) + story.append(rl["Spacer"](1, 0.05 * rl["inch"])) + story.append(rl["Spacer"](1, 0.08 * rl["inch"])) + + page_count = _build_with_numbered_canvas(doc, story, rl) + return page_count + + +def _styles(rl: Dict[str, Any]) -> Dict[str, Any]: + style_sheet = rl["getSampleStyleSheet"]() + paragraph_style = rl["ParagraphStyle"] + + return { + "title": paragraph_style( + "DocgenTitle", + parent=style_sheet["Heading1"], + fontSize=20, + leading=24, + spaceAfter=10, + ), + "h2": paragraph_style( + "DocgenHeading2", + parent=style_sheet["Heading2"], + fontSize=13, + leading=16, + spaceBefore=4, + spaceAfter=4, + ), + "meta": paragraph_style( + "DocgenMeta", + parent=style_sheet["BodyText"], + fontSize=10, + leading=12, + ), + "body": paragraph_style( + "DocgenBody", + parent=style_sheet["BodyText"], + fontSize=10, + leading=13, + ), + "mono": paragraph_style( + "DocgenMono", + parent=style_sheet["Code"], + fontName="Courier", + fontSize=8, + leading=10, + leftIndent=8, + ), + } + + +def _build_with_numbered_canvas(doc: Any, story: List[Any], rl: Dict[str, Any]) -> int: + canvas_module = rl["canvas"] + + class NumberedCanvas(canvas_module.Canvas): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._saved_page_states: List[Dict[str, Any]] = [] + self.page_count: int = 0 + + def showPage(self) -> None: # noqa: N802 - reportlab API + self._saved_page_states.append(dict(self.__dict__)) + self._startPage() + + def save(self) -> None: # noqa: N802 - reportlab API + if not self._saved_page_states or self._saved_page_states[-1].get("_pageNumber") != self._pageNumber: + self._saved_page_states.append(dict(self.__dict__)) + + page_count = len(self._saved_page_states) + for state in self._saved_page_states: + self.__dict__.update(state) + self._draw_page_number(page_count) + super().showPage() + self.page_count = page_count + super().save() + + def _draw_page_number(self, page_count: int) -> None: + self.setFont("Helvetica", 8) + self.drawRightString(7.5 * rl["inch"], 0.45 * rl["inch"], f"Page {self._pageNumber} of {page_count}") + + holder: Dict[str, Any] = {} + + def canvas_factory(*args: Any, **kwargs: Any) -> NumberedCanvas: + canvas_obj = NumberedCanvas(*args, **kwargs) + holder["canvas"] = canvas_obj + return canvas_obj + + doc.build(story, canvasmaker=canvas_factory) + return holder["canvas"].page_count if "canvas" in holder else 0 + + +def _escape(value: Any) -> str: + return html.escape(str(value), quote=True) + + +def _paragraph_text(text: Any) -> str: + value = _escape(text) + return value.replace("\n", "
") diff --git a/scripts/generate_deep_docs.py b/scripts/generate_deep_docs.py new file mode 100644 index 0000000..a770380 --- /dev/null +++ b/scripts/generate_deep_docs.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 +"""Generate deep deterministic per-file PDF documentation for Engram.""" + +from __future__ import annotations + +import argparse +import concurrent.futures +import hashlib +import json +import os +import subprocess +from collections import defaultdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Tuple + +try: # Script execution path: /.../scripts on sys.path + from docgen.analyze import ( + analyze_non_python_file, + analyze_python_file, + build_doc_payload, + collect_target_files, + ) + from docgen.render_pdf import render_file_pdf, render_index_pdf +except ModuleNotFoundError: # Module import path: scripts.generate_deep_docs + from scripts.docgen.analyze import ( + analyze_non_python_file, + analyze_python_file, + build_doc_payload, + collect_target_files, + ) + from scripts.docgen.render_pdf import render_file_pdf, render_index_pdf + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate deep per-file PDF documentation.") + parser.add_argument( + "--repo-root", + default=str(Path(__file__).resolve().parents[1]), + help="Repository root path.", + ) + parser.add_argument( + "--output-dir", + default=str(Path(__file__).resolve().parents[1] / "docs" / "pdf"), + help="Output folder for generated PDFs and manifest.", + ) + parser.add_argument( + "--exclude-tests", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude files under tests/ from documentation scope.", + ) + parser.add_argument( + "--include-non-python", + action=argparse.BooleanOptionalAction, + default=True, + help="Include supported non-Python files (JSON/TOML/YAML/HTML/Dockerfile).", + ) + parser.add_argument( + "--force", + action="store_true", + help="Regenerate all files even when hash is unchanged.", + ) + parser.add_argument( + "--changed-only", + action="store_true", + help="Alias for incremental mode: regenerate only changed/added files.", + ) + parser.add_argument( + "--max-workers", + type=int, + default=4, + help="Parallel workers for PDF generation.", + ) + return parser.parse_args() + + +def main(argv: List[str] | None = None) -> int: + args = parse_args() if argv is None else _parse_from_list(argv) + + repo_root = Path(args.repo_root).resolve() + output_dir = Path(args.output_dir).resolve() + files_dir = output_dir / "files" + manifest_path = output_dir / "manifest.json" + index_path = output_dir / "INDEX.pdf" + + output_dir.mkdir(parents=True, exist_ok=True) + files_dir.mkdir(parents=True, exist_ok=True) + + selected = collect_target_files( + repo_root=repo_root, + exclude_tests=args.exclude_tests, + include_non_python=args.include_non_python, + ) + + print(f"[docgen] inventory ({len(selected)} files):") + for rel in selected: + print(f"[docgen] - {rel}") + + prev_manifest = _load_manifest(manifest_path) + prev_index = {item["source_path"]: item for item in prev_manifest.get("items", [])} + + commit_hash = _get_commit_hash(repo_root) + run_ts = _utc_now() + + jobs: List[Dict[str, Any]] = [] + skipped_entries: Dict[str, Dict[str, Any]] = {} + + for rel in selected: + src = repo_root / rel + sha = _sha256_file(src) + output_rel = f"files/{_sanitize_path(rel)}" + output_abs = output_dir / output_rel + line_count = _line_count(src) + + previous = prev_index.get(rel) + unchanged = ( + previous is not None + and previous.get("source_sha256") == sha + and output_abs.exists() + ) + + if args.force: + should_generate = True + else: + should_generate = not unchanged + + if args.changed_only and not args.force: + should_generate = not unchanged + + if should_generate: + jobs.append( + { + "source_path": rel, + "source_abs": src, + "source_sha256": sha, + "output_pdf": output_rel, + "output_abs": output_abs, + "line_count": line_count, + "generated_at": run_ts, + } + ) + else: + reused = dict(previous) + reused["line_count"] = line_count + skipped_entries[rel] = reused + + print(f"[docgen] generate={len(jobs)} skip={len(skipped_entries)}") + + generated_entries: Dict[str, Dict[str, Any]] = {} + + if jobs: + with concurrent.futures.ThreadPoolExecutor(max_workers=max(args.max_workers, 1)) as executor: + future_map = { + executor.submit(_generate_one, repo_root, commit_hash, job): job["source_path"] + for job in jobs + } + for future in concurrent.futures.as_completed(future_map): + source_path = future_map[future] + result = future.result() + generated_entries[source_path] = result + print( + f"[docgen] rendered {source_path} -> {result['output_pdf']} " + f"({result['page_count']} pages)" + ) + + all_entries: List[Dict[str, Any]] = [] + for rel in selected: + if rel in generated_entries: + all_entries.append(generated_entries[rel]) + elif rel in skipped_entries: + all_entries.append(skipped_entries[rel]) + + all_entries.sort(key=lambda item: item["source_path"]) + + manifest = { + "generated_at": run_ts, + "repo_root": str(repo_root), + "commit_hash": commit_hash, + "doc_depth": "deep", + "method": "deterministic_static", + "file_count": len(all_entries), + "items": all_entries, + } + + index_payload = { + "generated_at": run_ts, + "commit_hash": commit_hash, + "total_files": len(all_entries), + "reading_guide": _reading_guide(all_entries), + "groups": _group_for_index(all_entries), + } + index_pages = render_index_pdf(index_payload, index_path) + + manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=False), encoding="utf-8") + + print(f"[docgen] wrote manifest: {manifest_path}") + print(f"[docgen] wrote index: {index_path} ({index_pages} pages)") + return 0 + + +def _parse_from_list(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate deep per-file PDF documentation.") + parser.add_argument("--repo-root", default=str(Path(__file__).resolve().parents[1])) + parser.add_argument("--output-dir", default=str(Path(__file__).resolve().parents[1] / "docs" / "pdf")) + parser.add_argument("--exclude-tests", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--include-non-python", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--force", action="store_true") + parser.add_argument("--changed-only", action="store_true") + parser.add_argument("--max-workers", type=int, default=4) + return parser.parse_args(argv) + + +def _generate_one(repo_root: Path, commit_hash: str, job: Dict[str, Any]) -> Dict[str, Any]: + source_path = job["source_path"] + source_abs = job["source_abs"] + + if source_path.endswith(".py"): + analysis = analyze_python_file(source_abs) + else: + analysis = analyze_non_python_file(source_abs) + + payload = build_doc_payload(source_path, analysis) + payload["generated_at"] = job["generated_at"] + payload["commit_hash"] = commit_hash + + page_count = render_file_pdf(payload, job["output_abs"]) + + return { + "source_path": source_path, + "source_sha256": job["source_sha256"], + "output_pdf": job["output_pdf"], + "line_count": analysis.get("line_count", job["line_count"]), + "page_count": page_count, + "generated_at": job["generated_at"], + "doc_depth": "deep", + "method": "deterministic_static", + } + + +def _group_for_index(items: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: + groups: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + + for item in items: + path = item["source_path"] + parts = path.split("/") + + if path.startswith("engram/") and len(parts) >= 2: + group = f"engram/{parts[1]}" + elif path.startswith("plugins/engram-memory/") and len(parts) >= 3: + group = f"plugins/engram-memory/{parts[2]}" + else: + group = "root" + + groups[group].append(item) + + for value in groups.values(): + value.sort(key=lambda item: item["source_path"]) + + return dict(sorted(groups.items(), key=lambda item: item[0])) + + +def _reading_guide(items: List[Dict[str, Any]]) -> List[str]: + paths = {item["source_path"] for item in items} + guide = [ + "Start with `engram/memory/main.py` to understand the top-level orchestration flow.", + "Read `engram/db/sqlite.py` next to map persistence behavior and storage contracts.", + "Then inspect `engram/mcp_server.py` for tool/API exposure and request handling.", + "Use module-group sections in this index to drill into subsystems after the core pass.", + ] + + dynamic_hints: List[str] = [] + for candidate in [ + "engram/memory/main.py", + "engram/db/sqlite.py", + "engram/mcp_server.py", + "engram/core/kernel.py", + "engram/api/app.py", + ]: + if candidate in paths: + dynamic_hints.append(f"Priority module present: `{candidate}`.") + + return guide + dynamic_hints + + +def _load_manifest(path: Path) -> Dict[str, Any]: + if not path.exists(): + return {} + try: + return json.loads(path.read_text(encoding="utf-8")) + except Exception: + return {} + + +def _sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + while True: + chunk = handle.read(65536) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() + + +def _line_count(path: Path) -> int: + with path.open("r", encoding="utf-8", errors="replace") as handle: + return sum(1 for _ in handle) + + +def _sanitize_path(rel_path: str) -> str: + return rel_path.replace("/", "__") + ".pdf" + + +def _utc_now() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _get_commit_hash(repo_root: Path) -> str: + try: + return ( + subprocess.check_output(["git", "-C", str(repo_root), "rev-parse", "HEAD"], text=True) + .strip() + ) + except Exception: + return "unknown" + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/test_agent_policies.py b/tests/test_agent_policies.py new file mode 100644 index 0000000..ed156c0 --- /dev/null +++ b/tests/test_agent_policies.py @@ -0,0 +1,72 @@ +"""Tests for agent policy enforcement and session clamping.""" + +from __future__ import annotations + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def test_exact_agent_policy_clamps_session_grants(memory): + memory.upsert_agent_policy( + user_id="u-policy", + agent_id="planner", + allowed_confidentiality_scopes=["work", "personal"], + allowed_capabilities=["search", "propose_write"], + allowed_namespaces=["default", "workbench"], + ) + + session = memory.create_session( + user_id="u-policy", + agent_id="planner", + allowed_confidentiality_scopes=["work", "finance"], + capabilities=["search", "review_commits"], + namespaces=["default", "private-lab"], + ) + + assert set(session["allowed_confidentiality_scopes"]) == {"work"} + assert set(session["capabilities"]) == {"search"} + assert set(session["namespaces"]) == {"default"} + + +def test_wildcard_policy_applies_when_exact_missing(memory): + memory.upsert_agent_policy( + user_id="u-policy-wild", + agent_id="*", + allowed_confidentiality_scopes=["personal"], + allowed_capabilities=["search"], + allowed_namespaces=["default"], + ) + + session = memory.create_session( + user_id="u-policy-wild", + agent_id="new-agent", + allowed_confidentiality_scopes=["personal", "work"], + capabilities=["search", "propose_write"], + namespaces=["default", "secret"], + ) + + assert set(session["allowed_confidentiality_scopes"]) == {"personal"} + assert set(session["capabilities"]) == {"search"} + assert set(session["namespaces"]) == {"default"} + + +def test_require_agent_policy_blocks_unknown_agent(monkeypatch, memory): + monkeypatch.setenv("ENGRAM_V2_REQUIRE_AGENT_POLICY", "true") + with pytest.raises(PermissionError, match="No agent policy configured"): + memory.create_session( + user_id="u-policy-strict", + agent_id="unregistered-agent", + ) + + +def test_require_agent_policy_does_not_block_local_user_session(monkeypatch, memory): + monkeypatch.setenv("ENGRAM_V2_REQUIRE_AGENT_POLICY", "true") + session = memory.create_session(user_id="u-policy-strict", agent_id=None) + assert session["token"] diff --git a/tests/test_api_v2.py b/tests/test_api_v2.py new file mode 100644 index 0000000..665ec16 --- /dev/null +++ b/tests/test_api_v2.py @@ -0,0 +1,105 @@ +"""Tests for Engram v2 session + token-gated search behavior.""" + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def test_session_create_and_token_auth(memory): + session = memory.create_session( + user_id="u1", + agent_id="agent-a", + allowed_confidentiality_scopes=["work"], + capabilities=["search", "propose_write"], + ) + assert session["token"] + assert session["session_id"] + + +def test_agent_search_requires_token(memory): + with pytest.raises(PermissionError): + memory.search_with_context( + query="typescript", + user_id="u1", + agent_id="agent-a", + token=None, + ) + + +def test_search_returns_context_packet(memory): + session = memory.create_session(user_id="u1", agent_id="agent-a") + + staged = memory.propose_write( + content="User prefers TypeScript for backend services", + user_id="u1", + agent_id="agent-a", + token=session["token"], + mode="staging", + infer=False, + ) + assert staged["status"] in {"PENDING", "AUTO_STASHED"} + memory.approve_commit(staged["commit_id"]) + + payload = memory.search_with_context( + query="What backend language does the user prefer?", + user_id="u1", + agent_id="agent-a", + token=session["token"], + limit=5, + ) + + assert "results" in payload + assert "context_packet" in payload + assert payload["context_packet"]["snippets"] + assert "retrieval_trace" in payload + assert payload["retrieval_trace"]["strategy"] == "semantic_plus_episodic_intersection" + + +def test_non_agent_search_without_token_is_not_forced_masked(memory): + memory.add(messages="User prefers Vim keybindings", user_id="u-local", infer=False) + payload = memory.search_with_context( + query="keybindings", + user_id="u-local", + agent_id=None, + token=None, + limit=5, + ) + assert payload["results"] + assert not all(item.get("masked") for item in payload["results"]) + + +def test_capability_restrictions_are_enforced(memory): + search_only = memory.create_session( + user_id="u-cap", + agent_id="agent-cap", + capabilities=["search"], + ) + with pytest.raises(PermissionError): + memory.propose_write( + content="Agent should not be able to write with search-only token", + user_id="u-cap", + agent_id="agent-cap", + token=search_only["token"], + mode="staging", + infer=False, + ) + + write_only = memory.create_session( + user_id="u-cap", + agent_id="agent-cap", + capabilities=["propose_write"], + ) + with pytest.raises(PermissionError): + memory.search_with_context( + query="anything", + user_id="u-cap", + agent_id="agent-cap", + token=write_only["token"], + limit=3, + ) diff --git a/tests/test_backward_compat.py b/tests/test_backward_compat.py new file mode 100644 index 0000000..547adc2 --- /dev/null +++ b/tests/test_backward_compat.py @@ -0,0 +1,162 @@ +"""Tests for backward compatibility — existing Memory operations still work.""" + +import os +import tempfile +import uuid + +import pytest + +from engram.db.sqlite import SQLiteManager + + +@pytest.fixture +def db(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + mgr = SQLiteManager(path) + yield mgr + os.unlink(path) + + +class TestMemoryBackwardCompat: + """Existing memory CRUD operations should work unchanged.""" + + def test_add_memory(self, db): + mem_id = db.add_memory({ + "memory": "The user prefers dark mode", + "user_id": "default", + }) + assert mem_id + mem = db.get_memory(mem_id) + assert mem is not None + assert mem["memory"] == "The user prefers dark mode" + + def test_update_memory(self, db): + mem_id = db.add_memory({ + "memory": "old content", + "user_id": "default", + }) + success = db.update_memory(mem_id, {"memory": "new content"}) + assert success + mem = db.get_memory(mem_id) + assert mem["memory"] == "new content" + + def test_delete_memory(self, db): + mem_id = db.add_memory({ + "memory": "to delete", + "user_id": "default", + }) + db.delete_memory(mem_id) + mem = db.get_memory(mem_id) + assert mem is None # Tombstoned + + def test_get_all_memories(self, db): + db.add_memory({"memory": "mem1", "user_id": "u1"}) + db.add_memory({"memory": "mem2", "user_id": "u1"}) + db.add_memory({"memory": "mem3", "user_id": "u2"}) + + u1_mems = db.get_all_memories(user_id="u1") + assert len(u1_mems) == 2 + + all_mems = db.get_all_memories() + assert len(all_mems) == 3 + + def test_increment_access(self, db): + mem_id = db.add_memory({"memory": "test", "user_id": "default"}) + db.increment_access(mem_id) + mem = db.get_memory(mem_id) + assert mem["access_count"] == 1 + + def test_categories_still_work(self, db): + cat_id = db.save_category({ + "id": "cat1", + "name": "preferences", + "description": "User preferences", + }) + assert cat_id == "cat1" + cat = db.get_category("cat1") + assert cat["name"] == "preferences" + + def test_history(self, db): + mem_id = db.add_memory({"memory": "test", "user_id": "default"}) + history = db.get_history(mem_id) + assert len(history) >= 1 + assert history[0]["event"] == "ADD" + + def test_decay_log(self, db): + db.log_decay(5, 2, 1) + # Should not raise + + def test_memory_with_scene_id(self, db): + """Memories can now have a scene_id column.""" + mem_id = db.add_memory({"memory": "test", "user_id": "default"}) + db.update_memory(mem_id, {"scene_id": "scene-123"}) + mem = db.get_memory(mem_id) + assert mem.get("scene_id") == "scene-123" + + +class TestNewTablesCoexist: + """New tables should not interfere with existing operations.""" + + def test_scenes_empty_by_default(self, db): + scenes = db.get_scenes() + assert scenes == [] + + def test_profiles_empty_by_default(self, db): + profiles = db.get_all_profiles() + assert profiles == [] + + def test_scene_crud(self, db): + scene_id = db.add_scene({ + "user_id": "u1", + "title": "Test Scene", + "topic": "testing", + "start_time": "2024-01-01T00:00:00", + }) + scene = db.get_scene(scene_id) + assert scene["title"] == "Test Scene" + + db.update_scene(scene_id, {"title": "Updated Scene"}) + scene = db.get_scene(scene_id) + assert scene["title"] == "Updated Scene" + + def test_profile_crud(self, db): + profile_id = db.add_profile({ + "user_id": "u1", + "name": "Alice", + "profile_type": "contact", + "facts": ["Works at Google"], + }) + profile = db.get_profile(profile_id) + assert profile["name"] == "Alice" + assert "Works at Google" in profile["facts"] + + db.update_profile(profile_id, {"facts": ["Works at Google", "Likes Python"]}) + profile = db.get_profile(profile_id) + assert len(profile["facts"]) == 2 + + def test_scene_memory_junction(self, db): + mem_id = db.add_memory({"memory": "linked mem", "user_id": "u1"}) + scene_id = db.add_scene({ + "user_id": "u1", + "title": "Scene", + "start_time": "2024-01-01T00:00:00", + }) + db.add_scene_memory(scene_id, mem_id, position=0) + + scene_mems = db.get_scene_memories(scene_id) + assert len(scene_mems) == 1 + assert scene_mems[0]["id"] == mem_id + + def test_profile_memory_junction(self, db): + mem_id = db.add_memory({"memory": "about alice", "user_id": "u1"}) + profile_id = db.add_profile({ + "user_id": "u1", + "name": "Alice", + "profile_type": "contact", + }) + db.add_profile_memory(profile_id, mem_id, role="mentioned") + + profile_mems = db.get_profile_memories(profile_id) + assert len(profile_mems) == 1 + assert profile_mems[0]["id"] == mem_id diff --git a/tests/test_cosine_similarity.py b/tests/test_cosine_similarity.py new file mode 100644 index 0000000..276a3d0 --- /dev/null +++ b/tests/test_cosine_similarity.py @@ -0,0 +1,54 @@ +"""Tests for engram.utils.math cosine_similarity.""" + +import pytest +from engram.utils.math import cosine_similarity + + +def test_identical_vectors(): + assert cosine_similarity([1, 0, 0], [1, 0, 0]) == pytest.approx(1.0) + + +def test_orthogonal_vectors(): + assert cosine_similarity([1, 0], [0, 1]) == pytest.approx(0.0) + + +def test_opposite_vectors(): + assert cosine_similarity([1, 0], [-1, 0]) == pytest.approx(-1.0) + + +def test_empty_vectors(): + assert cosine_similarity([], []) == 0.0 + + +def test_mismatched_lengths(): + assert cosine_similarity([1, 2], [1, 2, 3]) == 0.0 + + +def test_zero_vector(): + assert cosine_similarity([0, 0, 0], [1, 2, 3]) == 0.0 + + +def test_none_input(): + assert cosine_similarity(None, [1, 2]) == 0.0 # type: ignore[arg-type] + assert cosine_similarity([1, 2], None) == 0.0 # type: ignore[arg-type] + + +def test_high_dimensional(): + """Test with 3072-dimensional vectors (typical embedding size).""" + import random + random.seed(42) + a = [random.gauss(0, 1) for _ in range(3072)] + b = list(a) # identical copy + assert cosine_similarity(a, b) == pytest.approx(1.0, abs=1e-6) + + +def test_known_similarity(): + a = [1.0, 2.0, 3.0] + b = [4.0, 5.0, 6.0] + # Known cosine similarity + import math + dot = 1*4 + 2*5 + 3*6 # 32 + norm_a = math.sqrt(1 + 4 + 9) # sqrt(14) + norm_b = math.sqrt(16 + 25 + 36) # sqrt(77) + expected = dot / (norm_a * norm_b) + assert cosine_similarity(a, b) == pytest.approx(expected, abs=1e-6) diff --git a/tests/test_docgen_analyze.py b/tests/test_docgen_analyze.py new file mode 100644 index 0000000..4c50bf7 --- /dev/null +++ b/tests/test_docgen_analyze.py @@ -0,0 +1,147 @@ +"""Unit tests for deterministic docgen analyzers.""" + +from __future__ import annotations + +import subprocess +from pathlib import Path + +from scripts.docgen.analyze import ( + analyze_non_python_file, + analyze_python_file, + build_doc_payload, + collect_target_files, +) + + +def _init_repo(tmp_path: Path) -> None: + subprocess.run(["git", "init"], cwd=tmp_path, check=True, stdout=subprocess.DEVNULL) + subprocess.run(["git", "config", "user.email", "test@example.com"], cwd=tmp_path, check=True) + subprocess.run(["git", "config", "user.name", "Test User"], cwd=tmp_path, check=True) + + +def _git_add_all(tmp_path: Path) -> None: + subprocess.run(["git", "add", "-A"], cwd=tmp_path, check=True) + + +def test_collect_target_files_filters_scope(tmp_path: Path) -> None: + _init_repo(tmp_path) + + files = { + "engram/core/a.py": "print('a')\n", + "engram/core/ignore.md": "# ignored\n", + "plugins/engram-memory/hooks/prompt_context.py": "x = 1\n", + "plugins/engram-memory/README.md": "ignored\n", + "tests/test_ignore.py": "def test_x():\n assert True\n", + "pyproject.toml": "[project]\nname='x'\n", + "Dockerfile": "FROM python:3.11\n", + "docker-compose.yml": "services:\n app:\n image: x\n", + "misc.py": "print('not in scope')\n", + } + + for rel, content in files.items(): + target = tmp_path / rel + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + + _git_add_all(tmp_path) + + selected = collect_target_files(tmp_path, exclude_tests=True, include_non_python=True) + assert "engram/core/a.py" in selected + assert "plugins/engram-memory/hooks/prompt_context.py" in selected + assert "pyproject.toml" in selected + assert "Dockerfile" in selected + assert "docker-compose.yml" in selected + + assert "tests/test_ignore.py" not in selected + assert "engram/core/ignore.md" not in selected + assert "plugins/engram-memory/README.md" not in selected + assert "misc.py" not in selected + + py_only = collect_target_files(tmp_path, exclude_tests=True, include_non_python=False) + assert all(item.endswith(".py") for item in py_only) + + +def test_analyze_python_file_extracts_core_metadata(tmp_path: Path) -> None: + path = tmp_path / "sample.py" + path.write_text( + """ +import os +import logging + +APP_NAME = "engram" + + +def helper(value: int) -> int: + return value + 1 + + +class Sample: + def run(self, value: int) -> int: + if value < 0: + raise ValueError("bad") + return helper(value) + + +async def top() -> int: + return helper(1) +""".strip() + + "\n", + encoding="utf-8", + ) + + analysis = analyze_python_file(path) + + assert analysis["file_type"] == "python" + assert analysis["line_count"] > 0 + assert any(item["name"] == "APP_NAME" for item in analysis["constants"]) + assert any(item["name"] == "Sample" for item in analysis["classes"]) + assert any(item["name"] == "top" and item["is_async"] for item in analysis["functions"]) + assert any(item["exception"] == "ValueError('bad')" or "ValueError" in item["exception"] for item in analysis["raises"]) + assert "helper" in analysis["call_map"].get("Sample.run", []) + + payload = build_doc_payload("engram/core/sample.py", analysis) + section_titles = [section["title"] for section in payload["sections"]] + assert section_titles == [ + "Role in repository", + "File map and metrics", + "Public interfaces and key symbols", + "Execution/data flow walkthrough", + "Error handling and edge cases", + "Integration and dependencies", + "Safe modification guide", + "Reading order for large files", + ] + + +def test_analyze_non_python_file_variants(tmp_path: Path) -> None: + json_path = tmp_path / "sample.json" + json_path.write_text('{"services": {"api": {"port": 8100}}, "token": "x"}\n', encoding="utf-8") + + toml_path = tmp_path / "sample.toml" + toml_path.write_text("[project]\nname='engram'\n[tool.test]\nkey='value'\n", encoding="utf-8") + + docker_path = tmp_path / "Dockerfile" + docker_path.write_text("FROM python:3.11\nRUN pip install engram\nCMD ['python']\n", encoding="utf-8") + + html_path = tmp_path / "page.html" + html_path.write_text( + "
", + encoding="utf-8", + ) + + json_analysis = analyze_non_python_file(json_path) + assert json_analysis["format"] == "json" + assert any("services" in key for key in json_analysis["structure"]) + + toml_analysis = analyze_non_python_file(toml_path) + assert toml_analysis["format"] == "toml" + assert any("project" in key for key in toml_analysis["structure"]) + + docker_analysis = analyze_non_python_file(docker_path) + assert docker_analysis["format"] == "dockerfile" + assert any(item["instruction"] == "FROM" for item in docker_analysis["instructions"]) + + html_analysis = analyze_non_python_file(html_path) + assert html_analysis["format"] == "html" + assert any("
" in item for item in html_analysis["structure"]) + assert any("app.js" in item for item in html_analysis["integrations"]) diff --git a/tests/test_docgen_pipeline.py b/tests/test_docgen_pipeline.py new file mode 100644 index 0000000..b160445 --- /dev/null +++ b/tests/test_docgen_pipeline.py @@ -0,0 +1,192 @@ +"""Integration and renderer tests for deterministic deep docgen pipeline.""" + +from __future__ import annotations + +import importlib.util +import json +import os +import subprocess +import time +from pathlib import Path + +import pytest + +from scripts.generate_deep_docs import main +from scripts.docgen.render_pdf import render_file_pdf, render_index_pdf + + +def _init_repo(tmp_path: Path) -> None: + subprocess.run(["git", "init"], cwd=tmp_path, check=True, stdout=subprocess.DEVNULL) + subprocess.run(["git", "config", "user.email", "test@example.com"], cwd=tmp_path, check=True) + subprocess.run(["git", "config", "user.name", "Test User"], cwd=tmp_path, check=True) + + +def _write(tmp_path: Path, rel: str, content: str) -> None: + target = tmp_path / rel + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + + +def _git_add_all(tmp_path: Path) -> None: + subprocess.run(["git", "add", "-A"], cwd=tmp_path, check=True) + + +def _has_reportlab() -> bool: + return importlib.util.find_spec("reportlab") is not None + + +@pytest.mark.skipif(not _has_reportlab(), reason="reportlab is required for PDF generation tests") +def test_renderers_smoke(tmp_path: Path) -> None: + file_pdf = tmp_path / "file.pdf" + index_pdf = tmp_path / "index.pdf" + + payload = { + "file_path": "engram/core/demo.py", + "generated_at": "2026-02-11T00:00:00+00:00", + "commit_hash": "abc123", + "method": "deterministic_static", + "doc_depth": "deep", + "line_count": 10, + "sections": [ + { + "title": "Role in repository", + "paragraphs": ["Demo paragraph."], + "code_blocks": ["def demo():\n return 1"], + } + ], + } + pages = render_file_pdf(payload, file_pdf) + assert pages >= 1 + assert file_pdf.exists() + assert file_pdf.stat().st_size > 0 + + index_payload = { + "generated_at": "2026-02-11T00:00:00+00:00", + "commit_hash": "abc123", + "total_files": 1, + "reading_guide": ["Read core files first."], + "groups": { + "engram/core": [ + { + "source_path": "engram/core/demo.py", + "output_pdf": "files/engram__core__demo.py.pdf", + "line_count": 10, + "page_count": pages, + } + ] + }, + } + index_pages = render_index_pdf(index_payload, index_pdf) + assert index_pages >= 1 + assert index_pdf.exists() + assert index_pdf.stat().st_size > 0 + + +@pytest.mark.skipif(not _has_reportlab(), reason="reportlab is required for PDF generation tests") +def test_generator_end_to_end_and_incremental(tmp_path: Path) -> None: + _init_repo(tmp_path) + + _write(tmp_path, "engram/core/a.py", "def alpha():\n return 1\n") + _write(tmp_path, "engram/memory/b.py", "class Beta:\n pass\n") + _write(tmp_path, "plugins/engram-memory/hooks/prompt_context.py", "HOOK = True\n") + _write(tmp_path, "plugins/engram-memory/hooks/hooks.json", '{"hooks": ["x"]}\n') + _write(tmp_path, "pyproject.toml", "[project]\nname='tmp'\n") + _write(tmp_path, "Dockerfile", "FROM python:3.11\nCMD ['python']\n") + _write(tmp_path, "docker-compose.yml", "services:\n app:\n image: test\n") + _write(tmp_path, "tests/test_ignore.py", "def test_ignore():\n assert True\n") + _write(tmp_path, "engram/core/ignore.md", "# ignored\n") + + _git_add_all(tmp_path) + + output_dir = tmp_path / "docs" / "pdf" + + exit_code = main( + [ + "--repo-root", + str(tmp_path), + "--output-dir", + str(output_dir), + "--exclude-tests", + "--include-non-python", + "--max-workers", + "2", + ] + ) + assert exit_code == 0 + + manifest_path = output_dir / "manifest.json" + index_path = output_dir / "INDEX.pdf" + assert manifest_path.exists() + assert index_path.exists() + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + assert manifest["file_count"] == 7 + + required_item_keys = { + "source_path", + "source_sha256", + "output_pdf", + "line_count", + "page_count", + "generated_at", + "doc_depth", + "method", + } + + items = manifest["items"] + for item in items: + assert required_item_keys.issubset(item) + target_pdf = output_dir / item["output_pdf"] + assert target_pdf.exists() + assert target_pdf.stat().st_size > 0 + + target_item = next(item for item in items if item["source_path"] == "engram/core/a.py") + untouched_item = next(item for item in items if item["source_path"] == "engram/memory/b.py") + + target_pdf = output_dir / target_item["output_pdf"] + untouched_pdf = output_dir / untouched_item["output_pdf"] + initial_target_mtime = target_pdf.stat().st_mtime + initial_untouched_mtime = untouched_pdf.stat().st_mtime + + time.sleep(1.1) + exit_code_2 = main( + [ + "--repo-root", + str(tmp_path), + "--output-dir", + str(output_dir), + "--exclude-tests", + "--include-non-python", + "--changed-only", + "--max-workers", + "2", + ] + ) + assert exit_code_2 == 0 + assert target_pdf.stat().st_mtime == initial_target_mtime + assert untouched_pdf.stat().st_mtime == initial_untouched_mtime + + time.sleep(1.1) + _write(tmp_path, "engram/core/a.py", "def alpha():\n return 2\n") + + exit_code_3 = main( + [ + "--repo-root", + str(tmp_path), + "--output-dir", + str(output_dir), + "--exclude-tests", + "--include-non-python", + "--changed-only", + "--max-workers", + "2", + ] + ) + assert exit_code_3 == 0 + + assert target_pdf.stat().st_mtime > initial_target_mtime + assert untouched_pdf.stat().st_mtime == initial_untouched_mtime + + manifest_after = json.loads(manifest_path.read_text(encoding="utf-8")) + assert manifest_after["file_count"] == 7 + assert any(item["source_path"] == "engram/core/a.py" for item in manifest_after["items"]) diff --git a/tests/test_dual_retrieval.py b/tests/test_dual_retrieval.py new file mode 100644 index 0000000..551d1ed --- /dev/null +++ b/tests/test_dual_retrieval.py @@ -0,0 +1,101 @@ +"""Tests for dual retrieval and intersection promotion.""" + +import pytest + +from engram import Engram +from engram.retrieval.reranker import intersection_promote + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def _stage_and_approve(memory, *, content, token, user_id="u-dual", agent_id="reader"): + proposal = memory.propose_write( + content=content, + user_id=user_id, + agent_id=agent_id, + token=token, + mode="staging", + infer=False, + scope="work", + ) + memory.approve_commit(proposal["commit_id"]) + + +def test_dual_retrieval_intersection_promotion(memory): + session = memory.create_session( + user_id="u-dual", + agent_id="reader", + allowed_confidentiality_scopes=["work"], + ) + + _stage_and_approve( + memory, + token=session["token"], + content="On 8 May at dance studio, Gina's team performed Finding Freedom and won first place.", + ) + _stage_and_approve( + memory, + token=session["token"], + content="Gina likes contemporary dance in general.", + ) + + payload = memory.search_with_context( + query="What piece did Gina perform when her team won first place?", + user_id="u-dual", + agent_id="reader", + token=session["token"], + limit=5, + ) + + assert payload["results"] + assert payload["scene_hits"] + assert payload["results"][0].get("episodic_match") is True + + context_packet = payload["context_packet"] + assert context_packet["snippets"] + # Scene citations are emitted in context packet. + assert context_packet["snippets"][0]["citations"]["scene_ids"] + + trace = payload["retrieval_trace"] + assert trace["strategy"] == "semantic_plus_episodic_intersection" + assert trace["intersection_candidates"] >= 1 + assert "boost_weight" in trace + assert "boost_cap" in trace + + scores = [float(item.get("composite_score", 0.0)) for item in payload["results"]] + assert scores == sorted(scores, reverse=True) + + first = payload["results"][0] + assert "base_composite_score" in first + assert "intersection_boost" in first + assert float(first["composite_score"]) >= float(first["base_composite_score"]) + + +def test_intersection_promote_applies_calibrated_boost_deterministically(): + semantic_results = [ + {"id": "m1", "composite_score": 0.75, "score": 0.75}, + {"id": "m2", "composite_score": 0.80, "score": 0.80}, + {"id": "m3", "composite_score": 0.65, "score": 0.65}, + ] + episodic_scenes = [ + {"id": "s1", "memory_ids": ["m1"], "search_score": 0.95}, + ] + + ranked = intersection_promote( + semantic_results, + episodic_scenes, + boost_weight=0.35, + max_boost=0.35, + ) + by_id = {item["id"]: item for item in ranked} + + assert ranked[0]["id"] == "m1" + assert by_id["m1"]["episodic_match"] is True + assert float(by_id["m1"]["intersection_boost"]) > 0.0 + assert float(by_id["m1"]["composite_score"]) > float(by_id["m1"]["base_composite_score"]) + assert by_id["m2"]["episodic_match"] is False + assert float(by_id["m2"]["intersection_boost"]) == 0.0 diff --git a/tests/test_handoff.py b/tests/test_handoff.py new file mode 100644 index 0000000..6a679cb --- /dev/null +++ b/tests/test_handoff.py @@ -0,0 +1,496 @@ +"""Tests for cross-agent handoff session bus and legacy compatibility APIs.""" + +from __future__ import annotations + +import os + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def _token(memory, *, user_id: str, agent_id: str, capabilities): + if {"read_handoff", "write_handoff"} & set(capabilities): + memory.db.upsert_agent_policy( + user_id=user_id, + agent_id=agent_id, + allowed_confidentiality_scopes=["work"], + allowed_capabilities=["read_handoff", "write_handoff"], + allowed_namespaces=["default"], + ) + session = memory.create_session( + user_id=user_id, + agent_id=agent_id, + capabilities=list(capabilities), + ) + return session["token"] + + +def test_save_and_get_last_session_roundtrip(memory, tmp_path): + repo_path = str(tmp_path) + alias_repo_path = os.path.join(repo_path, ".", "") + token = _token( + memory, + user_id="u-handoff-1", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + + saved = memory.save_session_digest( + "u-handoff-1", + "claude-code", + { + "task_summary": "Implement auto lane routing", + "repo": repo_path, + "status": "paused", + "decisions_made": ["Use git fingerprint repo_id"], + "files_touched": ["engram/core/handoff_bus.py"], + "todos_remaining": ["Add API tests"], + "blockers": ["Need auth token propagation"], + "key_commands": ["pytest tests/test_handoff.py -q"], + "test_results": ["unit tests pending"], + "context_snapshot": "Bus is mostly wired.", + }, + token=token, + requester_agent_id="claude-code", + ) + assert saved["id"] + assert saved["repo_id"] + + resumed = memory.get_last_session( + "u-handoff-1", + agent_id="claude-code", + repo=alias_repo_path, + token=token, + requester_agent_id="claude-code", + ) + assert resumed is not None + assert resumed["task_summary"] == "Implement auto lane routing" + assert resumed["from_agent"] == "claude-code" + assert resumed["repo_id"] == saved["repo_id"] + assert resumed["blockers"] == ["Need auth token propagation"] + + +def test_hard_prune_keeps_latest_session_order(memory, tmp_path): + repo_path = str(tmp_path) + memory.handoff_processor.session_bus.max_sessions_per_user = 2 + token = _token( + memory, + user_id="u-handoff-2", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + + for idx in range(4): + memory.save_session_digest( + "u-handoff-2", + "claude-code", + { + "task_summary": f"session-{idx}", + "repo": repo_path, + "status": "paused", + }, + token=token, + requester_agent_id="claude-code", + ) + + sessions = memory.list_sessions( + "u-handoff-2", + repo=repo_path, + limit=10, + token=token, + requester_agent_id="claude-code", + ) + assert len(sessions) == 2 + + resumed = memory.get_last_session( + "u-handoff-2", + repo=repo_path, + token=token, + requester_agent_id="claude-code", + ) + assert resumed["task_summary"] == "session-3" + + +def test_auto_resume_cross_agent_lane_continuity(memory, tmp_path): + repo_path = str(tmp_path) + claude_token = _token( + memory, + user_id="u-handoff-3", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + + initial_resume = memory.auto_resume_context( + user_id="u-handoff-3", + agent_id="claude-code", + repo_path=repo_path, + objective="Build handoff APIs", + token=claude_token, + requester_agent_id="claude-code", + ) + lane_id = initial_resume["lane_id"] + assert initial_resume["created_new_lane"] is True + + checkpoint = memory.auto_checkpoint( + user_id="u-handoff-3", + agent_id="claude-code", + repo_path=repo_path, + lane_id=lane_id, + payload={ + "task_summary": "Added handoff resume endpoint", + "files_touched": ["engram/api/app.py"], + "todos_remaining": ["Add lane listing endpoint tests"], + }, + token=claude_token, + requester_agent_id="claude-code", + ) + assert checkpoint["checkpoint_id"] + + codex_token = _token( + memory, + user_id="u-handoff-3", + agent_id="codex", + capabilities=["read_handoff", "write_handoff"], + ) + codex_resume = memory.auto_resume_context( + user_id="u-handoff-3", + agent_id="codex", + repo_path=repo_path, + objective="Continue previous work", + token=codex_token, + requester_agent_id="codex", + ) + assert codex_resume["lane_id"] == lane_id + assert codex_resume["task_summary"] == "Added handoff resume endpoint" + assert "Add lane listing endpoint tests" in codex_resume["next_actions"] + + +def test_stale_expected_version_logs_conflict(memory, tmp_path): + repo_path = str(tmp_path) + token = _token( + memory, + user_id="u-handoff-4", + agent_id="frontend", + capabilities=["read_handoff", "write_handoff"], + ) + resume = memory.auto_resume_context( + user_id="u-handoff-4", + agent_id="frontend", + repo_path=repo_path, + objective="Polish UI", + token=token, + requester_agent_id="frontend", + ) + lane_id = resume["lane_id"] + + first = memory.auto_checkpoint( + user_id="u-handoff-4", + agent_id="frontend", + lane_id=lane_id, + repo_path=repo_path, + payload={"task_summary": "Drafted UI wireframes"}, + expected_version=0, + token=token, + requester_agent_id="frontend", + ) + assert first["checkpoint_id"] + + second = memory.auto_checkpoint( + user_id="u-handoff-4", + agent_id="frontend", + lane_id=lane_id, + repo_path=repo_path, + payload={"task_summary": "Updated component hierarchy"}, + expected_version=0, # stale on purpose + token=token, + requester_agent_id="frontend", + ) + assert second["checkpoint_id"] + assert len(second.get("conflicts", [])) == 1 + + conflicts = memory.db.list_handoff_lane_conflicts(lane_id) + assert conflicts + assert "task_summary" in set(conflicts[0].get("conflict_fields", [])) + + +def test_handoff_capabilities_are_enforced(memory, tmp_path): + repo_path = str(tmp_path) + with pytest.raises(PermissionError): + memory.save_session_digest( + "u-handoff-5", + "backend", + {"task_summary": "No token should fail", "repo": repo_path}, + ) + + weak_token = _token( + memory, + user_id="u-handoff-5", + agent_id="backend", + capabilities=["search"], + ) + with pytest.raises(PermissionError): + memory.save_session_digest( + "u-handoff-5", + "backend", + {"task_summary": "Wrong capability", "repo": repo_path}, + token=weak_token, + requester_agent_id="backend", + ) + + strong_token = _token( + memory, + user_id="u-handoff-5", + agent_id="backend", + capabilities=["read_handoff", "write_handoff"], + ) + saved = memory.save_session_digest( + "u-handoff-5", + "backend", + {"task_summary": "Capability granted", "repo": repo_path}, + token=strong_token, + requester_agent_id="backend", + ) + assert saved["id"] + + +def test_get_last_session_falls_back_to_lane_checkpoint(memory, tmp_path): + repo_path = str(tmp_path) + token = _token( + memory, + user_id="u-handoff-6", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + + resume = memory.auto_resume_context( + user_id="u-handoff-6", + agent_id="claude-code", + repo_path=repo_path, + objective="Implement lane fallback", + token=token, + requester_agent_id="claude-code", + ) + lane_id = resume["lane_id"] + + memory.auto_checkpoint( + user_id="u-handoff-6", + agent_id="claude-code", + repo_path=repo_path, + lane_id=lane_id, + payload={ + "task_summary": "Checkpoint without legacy digest session", + "files_touched": ["engram/core/handoff_bus.py"], + "todos_remaining": ["Verify get_last_session fallback"], + }, + token=token, + requester_agent_id="claude-code", + ) + + # No save_session_digest call happened, so legacy session rows are absent. + # get_last_session must still return continuity from lane/checkpoint state. + resumed = memory.get_last_session( + "u-handoff-6", + repo=repo_path, + token=token, + requester_agent_id="claude-code", + ) + assert resumed is not None + assert resumed["lane_id"] == lane_id + assert resumed["task_summary"] == "Checkpoint without legacy digest session" + assert "Verify get_last_session fallback" in resumed["todos_remaining"] + + listed = memory.list_sessions( + "u-handoff-6", + repo=repo_path, + token=token, + requester_agent_id="claude-code", + ) + assert listed + assert listed[0]["lane_id"] == lane_id + assert listed[0]["task_summary"] == "Checkpoint without legacy digest session" + + +def test_get_last_session_prefers_active_lane_over_completed_digest(memory, tmp_path): + repo_path = str(tmp_path) + token = _token( + memory, + user_id="u-handoff-7", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + + memory.save_session_digest( + "u-handoff-7", + "claude-code", + { + "task_summary": "Historical completed digest", + "repo": repo_path, + "status": "completed", + }, + token=token, + requester_agent_id="claude-code", + ) + + resume = memory.auto_resume_context( + user_id="u-handoff-7", + agent_id="claude-code", + repo_path=repo_path, + objective="Continue active work", + token=token, + requester_agent_id="claude-code", + ) + lane_id = resume["lane_id"] + memory.auto_checkpoint( + user_id="u-handoff-7", + agent_id="claude-code", + repo_path=repo_path, + lane_id=lane_id, + payload={ + "task_summary": "Live lane checkpoint", + "todos_remaining": ["finish active task"], + }, + token=token, + requester_agent_id="claude-code", + ) + + resumed = memory.get_last_session( + "u-handoff-7", + repo=repo_path, + token=token, + requester_agent_id="claude-code", + ) + assert resumed is not None + assert resumed["task_summary"] == "Live lane checkpoint" + assert resumed["lane_id"] == lane_id + assert resumed["status"] == "active" + + +def test_get_last_session_respects_explicit_status_filter(memory, tmp_path): + repo_path = str(tmp_path) + token = _token( + memory, + user_id="u-handoff-8", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + + memory.save_session_digest( + "u-handoff-8", + "claude-code", + { + "task_summary": "Completed digest only", + "repo": repo_path, + "status": "completed", + }, + token=token, + requester_agent_id="claude-code", + ) + + resume = memory.auto_resume_context( + user_id="u-handoff-8", + agent_id="claude-code", + repo_path=repo_path, + objective="Live branch", + token=token, + requester_agent_id="claude-code", + ) + lane_id = resume["lane_id"] + memory.auto_checkpoint( + user_id="u-handoff-8", + agent_id="claude-code", + repo_path=repo_path, + lane_id=lane_id, + payload={"task_summary": "Lane is active"}, + token=token, + requester_agent_id="claude-code", + ) + + only_active = memory.get_last_session( + "u-handoff-8", + repo=repo_path, + statuses=["active"], + token=token, + requester_agent_id="claude-code", + ) + assert only_active is not None + assert only_active["task_summary"] == "Lane is active" + assert only_active["status"] == "active" + + only_completed = memory.get_last_session( + "u-handoff-8", + repo=repo_path, + statuses=["completed"], + token=token, + requester_agent_id="claude-code", + ) + assert only_completed is not None + assert only_completed["task_summary"] == "Completed digest only" + assert only_completed["status"] == "completed" + + +def test_get_last_session_status_filter_is_case_insensitive(memory, tmp_path): + repo_path = str(tmp_path) + token = _token( + memory, + user_id="u-handoff-9", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + memory.save_session_digest( + "u-handoff-9", + "claude-code", + { + "task_summary": "Paused digest for case-insensitive status test", + "repo": repo_path, + "status": "paused", + }, + token=token, + requester_agent_id="claude-code", + ) + + resumed = memory.get_last_session( + "u-handoff-9", + repo=repo_path, + statuses=["PAUSED"], + token=token, + requester_agent_id="claude-code", + ) + assert resumed is not None + assert resumed["status"] == "paused" + + +def test_get_last_session_rejects_invalid_status_filter(memory, tmp_path): + repo_path = str(tmp_path) + token = _token( + memory, + user_id="u-handoff-10", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + memory.save_session_digest( + "u-handoff-10", + "claude-code", + { + "task_summary": "Digest for invalid status filter test", + "repo": repo_path, + "status": "paused", + }, + token=token, + requester_agent_id="claude-code", + ) + + with pytest.raises(ValueError, match="Invalid handoff statuses"): + memory.get_last_session( + "u-handoff-10", + repo=repo_path, + statuses=["running"], + token=token, + requester_agent_id="claude-code", + ) diff --git a/tests/test_handoff_api_compat.py b/tests/test_handoff_api_compat.py new file mode 100644 index 0000000..af7f399 --- /dev/null +++ b/tests/test_handoff_api_compat.py @@ -0,0 +1,131 @@ +"""Compatibility API tests for legacy handoff session routes.""" + +from __future__ import annotations + +import importlib + +import pytest + +pytest.importorskip("fastapi") +pytest.importorskip("httpx") + +from fastapi.testclient import TestClient + +from engram import Engram + +api_app_module = importlib.import_module("engram.api.app") + + +@pytest.fixture +def client(): + eng = Engram(in_memory=True, provider="mock") + api_app_module._memory = eng._memory + with TestClient(api_app_module.app) as test_client: + yield test_client + api_app_module._memory = None + + +def _session_token(user_id: str, agent_id: str) -> str: + kernel = api_app_module.get_kernel() + kernel.db.upsert_agent_policy( + user_id=user_id, + agent_id=agent_id, + allowed_confidentiality_scopes=["work"], + allowed_capabilities=["read_handoff", "write_handoff"], + allowed_namespaces=["default"], + ) + session = kernel.create_session( + user_id=user_id, + agent_id=agent_id, + capabilities=["read_handoff", "write_handoff"], + namespaces=["default"], + ) + return session["token"] + + +def test_handoff_session_compat_routes_round_trip(client): + user_id = "u-handoff-api-compat" + agent_id = "codex" + token = _session_token(user_id=user_id, agent_id=agent_id) + headers = {"Authorization": f"Bearer {token}"} + + digest = client.post( + "/v1/handoff/sessions/digest", + headers=headers, + json={ + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": agent_id, + "task_summary": "Harden handoff compatibility routes", + "repo": "/tmp/engram-repo", + "status": "paused", + "files_touched": ["engram/api/app.py"], + "todos_remaining": ["Validate old MCP clients"], + }, + ) + assert digest.status_code == 200 + digest_payload = digest.json() + assert digest_payload.get("id") + assert digest_payload.get("task_summary") == "Harden handoff compatibility routes" + + last = client.get( + "/v1/handoff/sessions/last", + headers=headers, + params={ + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": agent_id, + "repo": "/tmp/engram-repo", + }, + ) + assert last.status_code == 200 + last_payload = last.json() + assert last_payload.get("task_summary") == "Harden handoff compatibility routes" + assert last_payload.get("from_agent") == agent_id + + listed = client.get( + "/v1/handoff/sessions", + headers=headers, + params={ + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": agent_id, + "repo": "/tmp/engram-repo", + "limit": 10, + }, + ) + assert listed.status_code == 200 + listed_payload = listed.json() + assert listed_payload.get("count", 0) >= 1 + assert listed_payload["sessions"][0]["task_summary"] + + +def test_handoff_routes_reject_invalid_status_values(client): + user_id = "u-handoff-api-compat-invalid" + agent_id = "codex" + token = _session_token(user_id=user_id, agent_id=agent_id) + headers = {"Authorization": f"Bearer {token}"} + + invalid_last = client.get( + "/v1/handoff/sessions/last", + headers=headers, + params={ + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": agent_id, + "statuses": "running", + }, + ) + assert invalid_last.status_code == 422 + + invalid_list = client.get( + "/v1/handoff/sessions", + headers=headers, + params={ + "user_id": user_id, + "agent_id": agent_id, + "requester_agent_id": agent_id, + "status": "running", + }, + ) + assert invalid_list.status_code == 422 diff --git a/tests/test_handoff_hosted_backend.py b/tests/test_handoff_hosted_backend.py new file mode 100644 index 0000000..d1e7d7f --- /dev/null +++ b/tests/test_handoff_hosted_backend.py @@ -0,0 +1,50 @@ +"""Hosted handoff backend routing tests.""" + +from __future__ import annotations + +import asyncio + +import pytest + +pytest.importorskip("mcp") + +from engram import Engram +from engram.core.handoff_backend import HostedHandoffBackend, create_handoff_backend +import engram.mcp_server as mcp_server + + +@pytest.fixture(autouse=True) +def reset_mcp_state(): + mcp_server._handoff_backend = None + mcp_server._lifecycle_state.clear() + yield + mcp_server._handoff_backend = None + mcp_server._lifecycle_state.clear() + + +def test_backend_prefers_hosted_when_api_url_is_set(monkeypatch): + eng = Engram(in_memory=True, provider="mock") + monkeypatch.setenv("ENGRAM_API_URL", "http://127.0.0.1:8100") + backend = create_handoff_backend(eng._memory) + assert isinstance(backend, HostedHandoffBackend) + + +def test_get_last_session_reports_hosted_backend_unavailable(monkeypatch): + eng = Engram(in_memory=True, provider="mock") + + monkeypatch.setenv("ENGRAM_API_URL", "http://127.0.0.1:1") + monkeypatch.setattr(mcp_server, "get_memory", lambda: eng._memory) + + output = asyncio.run( + mcp_server.call_tool( + "get_last_session", + { + "user_id": "u-hosted-err", + "agent_id": "codex", + "requester_agent_id": "codex", + "repo": "/tmp/repo", + }, + ) + ) + payload_text = output[0].text + assert "hosted_backend_unavailable" in payload_text diff --git a/tests/test_mcp_handoff_lifecycle.py b/tests/test_mcp_handoff_lifecycle.py new file mode 100644 index 0000000..79ebe61 --- /dev/null +++ b/tests/test_mcp_handoff_lifecycle.py @@ -0,0 +1,123 @@ +"""MCP lifecycle tests for automatic handoff continuity.""" + +from __future__ import annotations + +import asyncio +import time + +import pytest + +pytest.importorskip("mcp") + +from engram import Engram +import engram.mcp_server as mcp_server + + +class _FakeHandoffBackend: + def __init__(self): + self.resume_calls = [] + self.checkpoint_calls = [] + + def auto_resume_context(self, **kwargs): + self.resume_calls.append(kwargs) + return { + "lane_id": "lane-1", + "repo_id": "repo-1", + "task_summary": "Resume packet", + } + + def auto_checkpoint(self, **kwargs): + self.checkpoint_calls.append(kwargs) + return { + "lane_id": kwargs.get("lane_id") or "lane-1", + "checkpoint_id": f"cp-{len(self.checkpoint_calls)}", + "status": kwargs.get("payload", {}).get("status", "active"), + "version": len(self.checkpoint_calls), + } + + def save_session_digest(self, **kwargs): # pragma: no cover - interface completeness + return {"id": "session-1", **kwargs} + + def get_last_session(self, **kwargs): # pragma: no cover - interface completeness + return {"id": "session-1", **kwargs} + + def list_sessions(self, **kwargs): # pragma: no cover - interface completeness + return [] + + +@pytest.fixture(autouse=True) +def reset_state(): + mcp_server._lifecycle_state.clear() + mcp_server._handoff_backend = None + yield + mcp_server._lifecycle_state.clear() + mcp_server._handoff_backend = None + + +def test_auto_resume_and_tool_complete_checkpoint(monkeypatch): + eng = Engram(in_memory=True, provider="mock") + backend = _FakeHandoffBackend() + monkeypatch.setattr(mcp_server, "get_memory", lambda: eng._memory) + monkeypatch.setattr(mcp_server, "get_handoff_backend", lambda _memory: backend) + + output = asyncio.run( + mcp_server.call_tool( + "search_memory", + { + "query": "continuity", + "user_id": "u-mcp-life-1", + "requester_agent_id": "codex", + "repo_path": "/tmp/repo", + }, + ) + ) + assert output + assert backend.resume_calls + assert any(call.get("event_type") == "tool_complete" for call in backend.checkpoint_calls) + + +def test_idle_pause_checkpoint_and_shutdown_end_checkpoint(monkeypatch): + eng = Engram(in_memory=True, provider="mock") + backend = _FakeHandoffBackend() + monkeypatch.setattr(mcp_server, "get_memory", lambda: eng._memory) + monkeypatch.setattr(mcp_server, "get_handoff_backend", lambda _memory: backend) + monkeypatch.setattr(mcp_server, "_idle_pause_seconds", 1) + + key = mcp_server._handoff_key( + user_id="u-mcp-life-2", + agent_id="codex", + namespace="default", + repo_id=None, + repo_path="/tmp/repo", + ) + mcp_server._lifecycle_state[key] = { + "user_id": "u-mcp-life-2", + "agent_id": "codex", + "namespace": "default", + "repo_path": "/tmp/repo", + "lane_id": "lane-stale", + "lane_type": "general", + "objective": "Resume previous work", + "confidentiality_scope": "work", + "last_activity_ts": time.time() - 120, + } + + asyncio.run( + mcp_server.call_tool( + "search_memory", + { + "query": "resume", + "user_id": "u-mcp-life-2", + "requester_agent_id": "codex", + "repo_path": "/tmp/repo", + }, + ) + ) + + events = [call.get("event_type") for call in backend.checkpoint_calls] + assert "agent_pause" in events + assert "tool_complete" in events + + mcp_server._flush_agent_end_checkpoints() + events = [call.get("event_type") for call in backend.checkpoint_calls] + assert "agent_end" in events diff --git a/tests/test_mcp_tool_dispatch.py b/tests/test_mcp_tool_dispatch.py new file mode 100644 index 0000000..a6825ed --- /dev/null +++ b/tests/test_mcp_tool_dispatch.py @@ -0,0 +1,39 @@ +"""Tests for MCP tool handler registry (Phase 6).""" + +import pytest +from engram.mcp_server import _TOOL_HANDLERS + + +class TestToolHandlerRegistry: + def test_registry_is_populated(self): + """The tool handler registry should have entries from decorated handlers.""" + assert len(_TOOL_HANDLERS) > 0 + + def test_known_handlers_registered(self): + """Verify specific handlers are in the registry.""" + expected = { + "get_memory", + "update_memory", + "delete_memory", + "get_memory_stats", + "apply_memory_decay", + "engram_context", + "get_profile", + "list_profiles", + "search_profiles", + } + for name in expected: + assert name in _TOOL_HANDLERS, f"Handler '{name}' not found in registry" + + def test_handlers_are_callable(self): + """All registered handlers must be callable.""" + for name, handler in _TOOL_HANDLERS.items(): + assert callable(handler), f"Handler '{name}' is not callable" + + def test_handler_signature(self): + """Handlers should accept (memory, arguments, _session_token, _preview).""" + import inspect + for name, handler in _TOOL_HANDLERS.items(): + sig = inspect.signature(handler) + params = list(sig.parameters.keys()) + assert len(params) == 4, f"Handler '{name}' should have 4 params, got {len(params)}: {params}" diff --git a/tests/test_memory_client_v2.py b/tests/test_memory_client_v2.py new file mode 100644 index 0000000..1168740 --- /dev/null +++ b/tests/test_memory_client_v2.py @@ -0,0 +1,71 @@ +"""Tests for MemoryClient v2 policy-management helpers.""" + +from __future__ import annotations + +import pytest + +pytest.importorskip("requests") + +from engram.memory.client import MemoryClient + + +def test_memory_client_agent_policy_methods(monkeypatch): + calls = [] + + def fake_request(self, method, path, *, params=None, json_body=None, extra_headers=None): + calls.append( + { + "method": method, + "path": path, + "params": params, + "json_body": json_body, + "extra_headers": extra_headers, + } + ) + return {"ok": True} + + monkeypatch.setattr(MemoryClient, "_request", fake_request) + + client = MemoryClient(host="http://localhost:8100") + + client.upsert_agent_policy( + user_id="u-client", + agent_id="planner", + allowed_confidentiality_scopes=["work", "personal"], + allowed_capabilities=["search"], + allowed_namespaces=["default", "workbench"], + ) + assert calls[-1]["method"] == "POST" + assert calls[-1]["path"] == "/v1/agent-policies" + assert calls[-1]["json_body"]["agent_id"] == "planner" + + client.list_agent_policies(user_id="u-client") + assert calls[-1]["method"] == "GET" + assert calls[-1]["path"] == "/v1/agent-policies" + assert calls[-1]["params"] == {"user_id": "u-client"} + + client.get_agent_policy(user_id="u-client", agent_id="planner", include_wildcard=False) + assert calls[-1]["method"] == "GET" + assert calls[-1]["path"] == "/v1/agent-policies" + assert calls[-1]["params"]["agent_id"] == "planner" + assert calls[-1]["params"]["include_wildcard"] == "false" + + client.delete_agent_policy(user_id="u-client", agent_id="planner") + assert calls[-1]["method"] == "DELETE" + assert calls[-1]["path"] == "/v1/agent-policies" + assert calls[-1]["params"] == {"user_id": "u-client", "agent_id": "planner"} + + client.handoff_resume(user_id="u-client", agent_id="planner", repo_path="/tmp/repo") + assert calls[-1]["method"] == "POST" + assert calls[-1]["path"] == "/v1/handoff/resume" + assert calls[-1]["json_body"]["agent_id"] == "planner" + + client.handoff_checkpoint(user_id="u-client", agent_id="planner", task_summary="Continue lane") + assert calls[-1]["method"] == "POST" + assert calls[-1]["path"] == "/v1/handoff/checkpoint" + assert calls[-1]["json_body"]["task_summary"] == "Continue lane" + + client.list_handoff_lanes(user_id="u-client", limit=5) + assert calls[-1]["method"] == "GET" + assert calls[-1]["path"] == "/v1/handoff/lanes" + assert calls[-1]["params"]["limit"] == 5 diff --git a/tests/test_migration.py b/tests/test_migration.py new file mode 100644 index 0000000..4b64816 --- /dev/null +++ b/tests/test_migration.py @@ -0,0 +1,120 @@ +"""Tests for schema migration idempotency.""" + +import os +import sqlite3 +import tempfile + +import pytest + +from engram.db.sqlite import SQLiteManager + + +@pytest.fixture +def db_path(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + yield path + os.unlink(path) + + +class TestMigrationIdempotency: + def test_double_init(self, db_path): + """Tables should be created with IF NOT EXISTS — double init is safe.""" + mgr1 = SQLiteManager(db_path) + mgr2 = SQLiteManager(db_path) # Should not raise + + # Both should work + scenes = mgr2.get_scenes(user_id="test") + assert scenes == [] + + def test_tables_exist(self, db_path): + """All expected tables should be created.""" + mgr = SQLiteManager(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} + conn.close() + + expected = { + "memories", + "memory_history", + "decay_log", + "categories", + "scenes", + "scene_memories", + "profiles", + "profile_memories", + "handoff_sessions", + "handoff_session_memories", + "handoff_lanes", + "handoff_checkpoints", + "handoff_checkpoint_memories", + "handoff_checkpoint_scenes", + "handoff_lane_conflicts", + } + assert expected.issubset(tables), f"Missing tables: {expected - tables}" + + def test_scene_id_column_migration(self, db_path): + """scene_id column should be added to memories table.""" + mgr = SQLiteManager(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.execute("PRAGMA table_info(memories)") + columns = {row[1] for row in cursor.fetchall()} + conn.close() + + assert "scene_id" in columns + + def test_existing_data_untouched(self, db_path): + """Adding new tables should not affect existing data.""" + # Create with old schema (just memories) + conn = sqlite3.connect(db_path) + conn.execute(""" + CREATE TABLE memories ( + id TEXT PRIMARY KEY, + memory TEXT NOT NULL, + user_id TEXT, + agent_id TEXT, + run_id TEXT, + app_id TEXT, + metadata TEXT DEFAULT '{}', + categories TEXT DEFAULT '[]', + immutable INTEGER DEFAULT 0, + expiration_date TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP, + layer TEXT DEFAULT 'sml', + strength REAL DEFAULT 1.0, + access_count INTEGER DEFAULT 0, + last_accessed TEXT DEFAULT CURRENT_TIMESTAMP, + embedding TEXT, + related_memories TEXT DEFAULT '[]', + source_memories TEXT DEFAULT '[]', + tombstone INTEGER DEFAULT 0 + ) + """) + conn.execute( + "INSERT INTO memories (id, memory, user_id) VALUES ('test1', 'hello world', 'u1')" + ) + conn.commit() + conn.close() + + # Now init with SQLiteManager (should add new tables without touching existing data) + mgr = SQLiteManager(db_path) + mem = mgr.get_memory("test1") + assert mem is not None + assert mem["memory"] == "hello world" + + def test_handoff_recent_indexes_exist(self, db_path): + SQLiteManager(db_path) + conn = sqlite3.connect(db_path) + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='index'") + indexes = {row[0] for row in cursor.fetchall()} + conn.close() + + expected = { + "idx_handoff_sessions_recent", + "idx_handoff_sessions_repo_recent", + "idx_handoff_lanes_user_recent", + "idx_handoff_lanes_repo_recent", + } + assert expected.issubset(indexes), f"Missing indexes: {expected - indexes}" diff --git a/tests/test_namespace_system.py b/tests/test_namespace_system.py new file mode 100644 index 0000000..a4a72ea --- /dev/null +++ b/tests/test_namespace_system.py @@ -0,0 +1,76 @@ +"""Tests for namespace-aware access controls.""" + +from __future__ import annotations + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def _stage_and_approve(memory, *, content, user_id, agent_id, token, namespace): + proposal = memory.propose_write( + content=content, + user_id=user_id, + agent_id=agent_id, + token=token, + mode="staging", + infer=False, + scope="work", + namespace=namespace, + ) + memory.approve_commit(proposal["commit_id"]) + return proposal + + +def test_namespace_masking_for_reader(memory): + memory.declare_namespace(user_id="u-ns", namespace="workbench") + memory.declare_namespace(user_id="u-ns", namespace="private-lab") + + writer = memory.create_session( + user_id="u-ns", + agent_id="writer", + namespaces=["workbench", "private-lab"], + ) + _stage_and_approve( + memory, + content="Workbench note about architecture", + user_id="u-ns", + agent_id="writer", + token=writer["token"], + namespace="workbench", + ) + _stage_and_approve( + memory, + content="Private-lab secret about salaries", + user_id="u-ns", + agent_id="writer", + token=writer["token"], + namespace="private-lab", + ) + + reader = memory.create_session( + user_id="u-ns", + agent_id="reader", + namespaces=["workbench"], + ) + payload = memory.search_with_context( + query="architecture and salaries", + user_id="u-ns", + agent_id="reader", + token=reader["token"], + limit=10, + ) + + assert payload["results"] + masked = [item for item in payload["results"] if item.get("masked")] + visible = [item for item in payload["results"] if not item.get("masked")] + assert masked + assert visible + assert all(item.get("details") == "[REDACTED]" for item in masked) + assert all(item.get("namespace", "workbench") == "workbench" for item in visible) diff --git a/tests/test_policy_masking.py b/tests/test_policy_masking.py new file mode 100644 index 0000000..b7b889c --- /dev/null +++ b/tests/test_policy_masking.py @@ -0,0 +1,72 @@ +"""Tests for confidentiality scope masking behavior.""" + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def _stage_and_approve(memory, *, content, user_id, agent_id, token, scope): + proposal = memory.propose_write( + content=content, + user_id=user_id, + agent_id=agent_id, + token=token, + mode="staging", + infer=False, + scope=scope, + ) + memory.approve_commit(proposal["commit_id"]) + + +def test_out_of_scope_results_are_masked(memory): + # Writer agent can write both work and finance memories. + writer = memory.create_session( + user_id="u-mask", + agent_id="writer", + allowed_confidentiality_scopes=["work", "finance"], + ) + _stage_and_approve( + memory, + content="Work plan: migrate engram API endpoints", + user_id="u-mask", + agent_id="writer", + token=writer["token"], + scope="work", + ) + _stage_and_approve( + memory, + content="Finance update: salary is 200k", + user_id="u-mask", + agent_id="writer", + token=writer["token"], + scope="finance", + ) + + # Reader session can only read work scope. + reader = memory.create_session( + user_id="u-mask", + agent_id="reader", + allowed_confidentiality_scopes=["work"], + ) + + payload = memory.search_with_context( + query="salary and finance update", + user_id="u-mask", + agent_id="reader", + token=reader["token"], + limit=10, + ) + + assert payload["results"] + assert any(item.get("masked") for item in payload["results"]) + masked_items = [item for item in payload["results"] if item.get("masked")] + assert all(item.get("details") == "[REDACTED]" for item in masked_items) + # Ensure secret value is not leaked in masked payload. + for item in masked_items: + assert "200k" not in str(item) diff --git a/tests/test_profile.py b/tests/test_profile.py new file mode 100644 index 0000000..46c81c5 --- /dev/null +++ b/tests/test_profile.py @@ -0,0 +1,205 @@ +"""Tests for ProfileProcessor — extraction, updates, narrative, self-profile.""" + +import os +import tempfile +import uuid + +import pytest + +from engram.core.profile import ProfileProcessor, ProfileUpdate, _SELF_PATTERNS, _PERSON_PATTERN +from engram.db.sqlite import SQLiteManager + + +@pytest.fixture +def db(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + mgr = SQLiteManager(path) + yield mgr + os.unlink(path) + + +@pytest.fixture +def processor(db): + return ProfileProcessor( + db=db, + embedder=None, + llm=None, + config={ + "auto_detect_profiles": True, + "use_llm_extraction": False, # No LLM in tests + "narrative_regenerate_threshold": 10, + "self_profile_auto_create": True, + "max_facts_per_profile": 100, + }, + ) + + +class TestSelfPatterns: + def test_i_prefer(self): + assert any(p.search("I prefer dark mode") for p in _SELF_PATTERNS) + + def test_my_name(self): + assert any(p.search("my name is John") for p in _SELF_PATTERNS) + + def test_im_a(self): + assert any(p.search("I'm a software engineer") for p in _SELF_PATTERNS) + + def test_no_match(self): + assert not any(p.search("The sky is blue") for p in _SELF_PATTERNS) + + +class TestPersonPattern: + def test_full_name(self): + matches = _PERSON_PATTERN.findall("John Smith works here") + assert "John Smith" in matches + + def test_no_match(self): + matches = _PERSON_PATTERN.findall("hello world") + assert len(matches) == 0 + + +class TestExtraction: + def test_self_preference(self, processor): + updates = processor.extract_profile_mentions( + "I prefer using Python for data analysis", + user_id="u1", + ) + assert len(updates) >= 1 + self_update = next((u for u in updates if u.profile_type == "self"), None) + assert self_update is not None + assert len(self_update.new_preferences) > 0 + + def test_person_mention(self, processor): + updates = processor.extract_profile_mentions( + "Had a meeting with John Smith about the project", + user_id="u1", + ) + person_updates = [u for u in updates if u.profile_type == "contact"] + assert len(person_updates) >= 1 + assert any("John Smith" in u.profile_name for u in person_updates) + + def test_no_mentions(self, processor): + updates = processor.extract_profile_mentions( + "the sky is blue today", + user_id="u1", + ) + assert len(updates) == 0 + + +class TestSelfProfile: + def test_ensure_self_profile(self, processor, db): + profile = processor.ensure_self_profile("u1") + assert profile["name"] == "self" + assert profile["profile_type"] == "self" + + # Second call returns same profile + profile2 = processor.ensure_self_profile("u1") + assert profile2["id"] == profile["id"] + + def test_auto_create_on_self_ref(self, processor, db): + mem_id = str(uuid.uuid4()) + db.add_memory({"id": mem_id, "memory": "I prefer dark mode", "user_id": "u1"}) + + updates = processor.extract_profile_mentions("I prefer dark mode", user_id="u1") + for u in updates: + processor.apply_update(u, mem_id, "u1") + + self_profile = db.get_profile_by_name("self", user_id="u1") + assert self_profile is not None + assert len(self_profile["preferences"]) > 0 + + +class TestProfileLifecycle: + def test_create_contact(self, processor, db): + mem_id = str(uuid.uuid4()) + db.add_memory({"id": mem_id, "memory": "Met with Alice Johnson", "user_id": "u1"}) + + update = ProfileUpdate( + profile_name="Alice Johnson", + profile_type="contact", + new_facts=["Met for lunch"], + ) + profile_id = processor.apply_update(update, mem_id, "u1") + assert profile_id + + profile = db.get_profile(profile_id) + assert profile["name"] == "Alice Johnson" + assert "Met for lunch" in profile["facts"] + + def test_merge_facts(self, processor, db): + mem1 = str(uuid.uuid4()) + mem2 = str(uuid.uuid4()) + db.add_memory({"id": mem1, "memory": "fact 1", "user_id": "u1"}) + db.add_memory({"id": mem2, "memory": "fact 2", "user_id": "u1"}) + + update1 = ProfileUpdate( + profile_name="Bob Wilson", + profile_type="contact", + new_facts=["Works at Google"], + ) + pid = processor.apply_update(update1, mem1, "u1") + + update2 = ProfileUpdate( + profile_name="Bob Wilson", + profile_type="contact", + new_facts=["Likes Python", "Works at Google"], # duplicate + ) + pid2 = processor.apply_update(update2, mem2, "u1") + assert pid == pid2 + + profile = db.get_profile(pid) + assert "Works at Google" in profile["facts"] + assert "Likes Python" in profile["facts"] + # No duplicate + assert profile["facts"].count("Works at Google") == 1 + + def test_max_facts(self, processor, db): + processor.max_facts = 3 + mem_id = str(uuid.uuid4()) + db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) + + update = ProfileUpdate( + profile_name="Test Person", + profile_type="contact", + new_facts=["fact1", "fact2", "fact3", "fact4", "fact5"], + ) + pid = processor.apply_update(update, mem_id, "u1") + profile = db.get_profile(pid) + assert len(profile["facts"]) <= 3 + + +class TestProfileSearch: + def test_keyword_search(self, processor, db): + db.add_profile({ + "id": str(uuid.uuid4()), + "user_id": "u1", + "name": "Alice Johnson", + "profile_type": "contact", + "facts": ["Software engineer", "Works at Google"], + }) + + results = processor.search_profiles("Alice", user_id="u1") + assert len(results) >= 1 + assert any("Alice" in r["name"] for r in results) + + def test_no_results(self, processor, db): + results = processor.search_profiles("nonexistent", user_id="u1") + assert len(results) == 0 + + +class TestProfileMemories: + def test_link_memory(self, processor, db): + mem_id = str(uuid.uuid4()) + db.add_memory({"id": mem_id, "memory": "About Alice", "user_id": "u1"}) + + update = ProfileUpdate( + profile_name="Alice Test", + profile_type="contact", + new_facts=["Test fact"], + ) + pid = processor.apply_update(update, mem_id, "u1") + + linked = db.get_profile_memories(pid) + assert len(linked) == 1 + assert linked[0]["id"] == mem_id diff --git a/tests/test_refaware_decay.py b/tests/test_refaware_decay.py new file mode 100644 index 0000000..bd72fa6 --- /dev/null +++ b/tests/test_refaware_decay.py @@ -0,0 +1,45 @@ +"""Tests for reference-aware FadeMem behavior.""" + +from datetime import datetime, timedelta + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(monkeypatch): + monkeypatch.setenv("ENGRAM_V2_REF_AWARE_DECAY", "true") + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def _create_memory(memory, user_id: str, content: str) -> str: + added = memory.add(messages=content, user_id=user_id, infer=False) + return added["results"][0]["id"] + + +def test_strong_ref_pauses_decay(memory): + memory_id = _create_memory(memory, "u-decay-strong", "critical memory") + stale_time = (datetime.utcnow() - timedelta(days=90)).isoformat() + memory.db.update_memory(memory_id, {"strength": 0.01, "last_accessed": stale_time}) + + memory.db.add_memory_subscriber(memory_id, "agent:planner", ref_type="strong") + result = memory.apply_decay(scope={"user_id": "u-decay-strong"}) + + # Memory should not be forgotten due to strong reference. + assert memory.db.get_memory(memory_id) is not None + assert result["forgotten"] == 0 + + +def test_weak_ref_dampens_forgetting(memory): + memory_id = _create_memory(memory, "u-decay-weak", "semi-important memory") + stale_time = (datetime.utcnow() - timedelta(days=30)).isoformat() + memory.db.update_memory(memory_id, {"strength": 0.11, "last_accessed": stale_time}) + + memory.db.add_memory_subscriber(memory_id, "agent:researcher", ref_type="weak") + memory.apply_decay(scope={"user_id": "u-decay-weak"}) + + mem = memory.db.get_memory(memory_id) + assert mem is not None + assert mem["strength"] > 0.0 diff --git a/tests/test_scene.py b/tests/test_scene.py new file mode 100644 index 0000000..009f129 --- /dev/null +++ b/tests/test_scene.py @@ -0,0 +1,217 @@ +"""Tests for SceneProcessor — boundary detection, creation, closing, summarization.""" + +import os +import tempfile +import uuid +from datetime import datetime, timedelta + +import pytest + +from engram.core.scene import SceneProcessor, SceneDetectionResult, _detect_location, _cosine_similarity +from engram.db.sqlite import SQLiteManager + + +@pytest.fixture +def db(): + """Create a temporary SQLite database.""" + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + mgr = SQLiteManager(path) + yield mgr + os.unlink(path) + + +@pytest.fixture +def processor(db): + return SceneProcessor( + db=db, + embedder=None, + llm=None, + config={ + "scene_time_gap_minutes": 30, + "scene_topic_threshold": 0.55, + "auto_close_inactive_minutes": 120, + "max_scene_memories": 5, + "use_llm_summarization": False, # No LLM in tests + }, + ) + + +class TestCosineSimililarity: + def test_identical(self): + v = [1.0, 0.0, 0.5] + assert abs(_cosine_similarity(v, v) - 1.0) < 1e-6 + + def test_orthogonal(self): + a = [1.0, 0.0] + b = [0.0, 1.0] + assert abs(_cosine_similarity(a, b)) < 1e-6 + + def test_empty(self): + assert _cosine_similarity([], []) == 0.0 + + def test_different_lengths(self): + assert _cosine_similarity([1.0], [1.0, 2.0]) == 0.0 + + +class TestLocationDetection: + def test_at_location(self): + loc = _detect_location("Meeting at Starbucks") + assert loc is not None + assert "Starbucks" in loc + + def test_in_location(self): + assert _detect_location("Currently in New York") == "New York" + + def test_no_location(self): + assert _detect_location("just a random sentence") is None + + +class TestBoundaryDetection: + def test_no_current_scene(self, processor): + result = processor.detect_boundary("hello", datetime.utcnow().isoformat(), None) + assert result.is_new_scene is True + assert result.reason == "no_scene" + + def test_time_gap(self, processor): + now = datetime.utcnow() + old_time = (now - timedelta(minutes=60)).isoformat() + scene = {"start_time": old_time, "end_time": old_time, "memory_ids": ["a"]} + result = processor.detect_boundary("hi", now.isoformat(), scene) + assert result.is_new_scene is True + assert result.reason == "time_gap" + + def test_no_gap(self, processor): + now = datetime.utcnow() + recent = (now - timedelta(minutes=5)).isoformat() + scene = { + "start_time": recent, + "end_time": recent, + "memory_ids": ["a"], + "location": None, + "embedding": None, + } + result = processor.detect_boundary("hi", now.isoformat(), scene) + assert result.is_new_scene is False + + def test_max_memories(self, processor): + now = datetime.utcnow() + recent = (now - timedelta(minutes=1)).isoformat() + scene = { + "start_time": recent, + "end_time": recent, + "memory_ids": ["a", "b", "c", "d", "e"], # max is 5 + "location": None, + "embedding": None, + } + result = processor.detect_boundary("hi", now.isoformat(), scene) + assert result.is_new_scene is True + assert result.reason == "max_memories" + + def test_topic_shift(self, processor): + now = datetime.utcnow() + recent = (now - timedelta(minutes=1)).isoformat() + # Orthogonal embeddings = similarity 0 + scene_emb = [1.0, 0.0, 0.0] + mem_emb = [0.0, 1.0, 0.0] + scene = { + "start_time": recent, + "end_time": recent, + "memory_ids": ["a"], + "location": None, + "embedding": scene_emb, + } + result = processor.detect_boundary("hi", now.isoformat(), scene, embedding=mem_emb) + assert result.is_new_scene is True + assert result.reason == "topic_shift" + + def test_location_change(self, processor): + now = datetime.utcnow() + recent = (now - timedelta(minutes=1)).isoformat() + scene = { + "start_time": recent, + "end_time": recent, + "memory_ids": ["a"], + "location": "Office", + "embedding": None, + } + result = processor.detect_boundary("Meeting at Starbucks today", now.isoformat(), scene) + assert result.is_new_scene is True + assert result.reason == "location_change" + + +class TestSceneLifecycle: + def test_create_scene(self, processor, db): + mem_id = str(uuid.uuid4()) + now = datetime.utcnow().isoformat() + # Add a memory first + db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) + + scene = processor.create_scene( + first_memory_id=mem_id, + user_id="u1", + timestamp=now, + topic="Test topic", + location="Office", + ) + assert scene["id"] + assert scene["topic"] == "Test topic" + assert scene["memory_ids"] == [mem_id] + + # Verify in DB + fetched = db.get_scene(scene["id"]) + assert fetched is not None + assert fetched["user_id"] == "u1" + + def test_add_memory_to_scene(self, processor, db): + mem1 = str(uuid.uuid4()) + mem2 = str(uuid.uuid4()) + now = datetime.utcnow().isoformat() + db.add_memory({"id": mem1, "memory": "first", "user_id": "u1"}) + db.add_memory({"id": mem2, "memory": "second", "user_id": "u1"}) + + scene = processor.create_scene(mem1, "u1", now, topic="topic") + processor.add_memory_to_scene(scene["id"], mem2, timestamp=now) + + fetched = db.get_scene(scene["id"]) + assert mem2 in fetched["memory_ids"] + + def test_close_scene(self, processor, db): + mem_id = str(uuid.uuid4()) + now = datetime.utcnow().isoformat() + db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) + + scene = processor.create_scene(mem_id, "u1", now, topic="topic") + assert db.get_open_scene("u1") is not None + + processor.close_scene(scene["id"]) + fetched = db.get_scene(scene["id"]) + assert fetched["end_time"] is not None + + def test_get_open_scene(self, processor, db): + mem_id = str(uuid.uuid4()) + now = datetime.utcnow().isoformat() + db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) + + processor.create_scene(mem_id, "u1", now, topic="t1") + open_scene = db.get_open_scene("u1") + assert open_scene is not None + + +class TestSceneSearch: + def test_keyword_search(self, processor, db): + mem_id = str(uuid.uuid4()) + now = datetime.utcnow().isoformat() + db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) + + processor.create_scene(mem_id, "u1", now, topic="python debugging session") + processor.close_scene( + db.get_open_scene("u1")["id"], + timestamp=now, + ) + # Update summary manually since no LLM + scenes = db.get_scenes(user_id="u1") + db.update_scene(scenes[0]["id"], {"summary": "Debugging Python code"}) + + results = processor.search_scenes("python", user_id="u1") + assert len(results) >= 1 diff --git a/tests/test_security_handoff_strict.py b/tests/test_security_handoff_strict.py new file mode 100644 index 0000000..f4baaad --- /dev/null +++ b/tests/test_security_handoff_strict.py @@ -0,0 +1,33 @@ +"""Strict handoff security defaults tests.""" + +from __future__ import annotations + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def test_strict_default_denies_implicit_trusted_bootstrap(memory): + with pytest.raises(PermissionError): + memory.create_session( + user_id="u-strict-1", + agent_id="codex", + capabilities=["read_handoff", "write_handoff"], + ) + + +def test_opt_in_bootstrap_allows_trusted_agent(memory): + memory.handoff_config.allow_auto_trusted_bootstrap = True + session = memory.create_session( + user_id="u-strict-2", + agent_id="codex", + capabilities=["read_handoff", "write_handoff"], + ) + assert session.get("token") + assert {"read_handoff", "write_handoff"}.issubset(set(session.get("capabilities", []))) diff --git a/tests/test_security_sessions.py b/tests/test_security_sessions.py new file mode 100644 index 0000000..61ce5b6 --- /dev/null +++ b/tests/test_security_sessions.py @@ -0,0 +1,242 @@ +"""Security-focused API tests for session issuance and token enforcement.""" + +from __future__ import annotations + +import importlib + +import pytest + +pytest.importorskip("fastapi") +pytest.importorskip("httpx") + +from fastapi.testclient import TestClient + +api_app_module = importlib.import_module("engram.api.app") +import engram.api.auth as auth_module +from engram import Engram + + +@pytest.fixture +def client(): + eng = Engram(in_memory=True, provider="mock") + api_app_module._memory = eng._memory + with TestClient(api_app_module.app) as test_client: + yield test_client + api_app_module._memory = None + + +def test_session_creation_requires_admin_key_when_configured(client, monkeypatch): + monkeypatch.setenv("ENGRAM_ADMIN_KEY", "super-secret") + + denied = client.post( + "/v1/sessions", + json={"user_id": "u-admin", "agent_id": "agent-admin"}, + ) + assert denied.status_code == 403 + + allowed = client.post( + "/v1/sessions", + headers={"X-Engram-Admin-Key": "super-secret"}, + json={"user_id": "u-admin", "agent_id": "agent-admin"}, + ) + assert allowed.status_code == 200 + body = allowed.json() + assert body.get("token") + assert body.get("session_id") + + +def test_untrusted_client_must_send_bearer_token(client, monkeypatch): + # Simulate a non-local caller regardless of testclient host. + monkeypatch.setattr(auth_module, "is_trusted_local_client", lambda request: False) + + memory = api_app_module.get_memory() + memory.add(messages="Searchable memory for token test", user_id="u-token", infer=False) + + denied = client.post( + "/v1/search", + json={"query": "searchable", "user_id": "u-token", "agent_id": "agent-token"}, + ) + assert denied.status_code == 401 + + session = api_app_module.get_kernel().create_session( + user_id="u-token", + agent_id="agent-token", + capabilities=["search"], + ) + allowed = client.post( + "/v1/search", + headers={"Authorization": f"Bearer {session['token']}"}, + json={"query": "searchable", "user_id": "u-token", "agent_id": "agent-token"}, + ) + assert allowed.status_code == 200 + payload = allowed.json() + assert payload["count"] >= 1 + + +def test_session_creation_returns_403_when_policy_required_and_missing(client, monkeypatch): + monkeypatch.setenv("ENGRAM_V2_REQUIRE_AGENT_POLICY", "true") + + denied = client.post( + "/v1/sessions", + json={"user_id": "u-policy", "agent_id": "agent-missing"}, + ) + assert denied.status_code == 403 + assert "policy" in denied.json().get("detail", "").lower() + + +def test_handoff_session_creation_denies_untrusted_agent(client): + denied = client.post( + "/v1/sessions", + json={ + "user_id": "u-handoff-policy", + "agent_id": "rogue-agent", + "capabilities": ["read_handoff"], + }, + ) + assert denied.status_code == 403 + assert "handoff" in denied.json().get("detail", "").lower() + + trusted_without_policy = client.post( + "/v1/sessions", + json={ + "user_id": "u-handoff-policy", + "agent_id": "codex", + "capabilities": ["read_handoff", "write_handoff"], + }, + ) + assert trusted_without_policy.status_code == 403 + + policy = client.post( + "/v1/agent-policies", + json={ + "user_id": "u-handoff-policy", + "agent_id": "codex", + "allowed_confidentiality_scopes": ["work"], + "allowed_capabilities": ["read_handoff", "write_handoff"], + "allowed_namespaces": ["default"], + }, + ) + assert policy.status_code == 200 + + allowed = client.post( + "/v1/sessions", + json={ + "user_id": "u-handoff-policy", + "agent_id": "codex", + "capabilities": ["read_handoff", "write_handoff"], + }, + ) + assert allowed.status_code == 200 + payload = allowed.json() + assert {"read_handoff", "write_handoff"}.issubset(set(payload.get("capabilities", []))) + + +def test_agent_policy_api_round_trip_and_session_clamping(client): + upsert = client.post( + "/v1/agent-policies", + json={ + "user_id": "u-policy-api", + "agent_id": "planner", + "allowed_confidentiality_scopes": ["work"], + "allowed_capabilities": ["search"], + "allowed_namespaces": ["default"], + }, + ) + assert upsert.status_code == 200 + upsert_body = upsert.json() + assert upsert_body["user_id"] == "u-policy-api" + assert upsert_body["agent_id"] == "planner" + + session = client.post( + "/v1/sessions", + json={ + "user_id": "u-policy-api", + "agent_id": "planner", + "allowed_confidentiality_scopes": ["work", "finance"], + "capabilities": ["search", "review_commits"], + "namespaces": ["default", "private-lab"], + }, + ) + assert session.status_code == 200 + session_body = session.json() + assert set(session_body["allowed_confidentiality_scopes"]) == {"work"} + assert set(session_body["capabilities"]) == {"search"} + assert set(session_body["namespaces"]) == {"default"} + + get_one = client.get( + "/v1/agent-policies", + params={"user_id": "u-policy-api", "agent_id": "planner"}, + ) + assert get_one.status_code == 200 + payload = get_one.json() + assert payload["policy"]["agent_id"] == "planner" + + delete = client.delete( + "/v1/agent-policies", + params={"user_id": "u-policy-api", "agent_id": "planner"}, + ) + assert delete.status_code == 200 + assert delete.json()["deleted"] is True + + +def test_handoff_endpoints_require_token_and_capabilities(client, monkeypatch): + monkeypatch.setattr(auth_module, "is_trusted_local_client", lambda request: False) + + denied = client.post( + "/v1/handoff/resume", + json={"user_id": "u-handoff-api", "agent_id": "claude-code", "repo_path": "/tmp/repo"}, + ) + assert denied.status_code == 401 + + api_app_module.get_kernel().db.upsert_agent_policy( + user_id="u-handoff-api", + agent_id="claude-code", + allowed_confidentiality_scopes=["work"], + allowed_capabilities=["read_handoff", "write_handoff"], + allowed_namespaces=["default"], + ) + session = api_app_module.get_kernel().create_session( + user_id="u-handoff-api", + agent_id="claude-code", + capabilities=["read_handoff", "write_handoff"], + ) + headers = {"Authorization": f"Bearer {session['token']}"} + + resumed = client.post( + "/v1/handoff/resume", + headers=headers, + json={ + "user_id": "u-handoff-api", + "agent_id": "claude-code", + "repo_path": "/tmp/repo", + "objective": "Continue backend work", + "requester_agent_id": "claude-code", + }, + ) + assert resumed.status_code == 200 + lane_id = resumed.json().get("lane_id") + assert lane_id + + checkpoint = client.post( + "/v1/handoff/checkpoint", + headers=headers, + json={ + "user_id": "u-handoff-api", + "agent_id": "claude-code", + "lane_id": lane_id, + "repo_path": "/tmp/repo", + "task_summary": "Implemented API endpoint", + "event_type": "tool_complete", + "requester_agent_id": "claude-code", + }, + ) + assert checkpoint.status_code == 200 + assert checkpoint.json().get("checkpoint_id") + + lanes = client.get( + "/v1/handoff/lanes", + headers=headers, + params={"user_id": "u-handoff-api", "requester_agent_id": "claude-code"}, + ) + assert lanes.status_code == 200 + assert lanes.json().get("count", 0) >= 1 diff --git a/tests/test_sleep_cycle.py b/tests/test_sleep_cycle.py new file mode 100644 index 0000000..befaefd --- /dev/null +++ b/tests/test_sleep_cycle.py @@ -0,0 +1,47 @@ +"""Tests for sleep-cycle maintenance flow.""" + +from __future__ import annotations + +from datetime import datetime + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def test_sleep_cycle_generates_digest_promotes_and_cleans_refs(memory): + today = datetime.utcnow().date().isoformat() + add = memory.add( + messages="Important retention candidate", + user_id="u-sleep", + metadata={"importance": 0.95, "namespace": "default"}, + infer=False, + ) + memory_id = add["results"][0]["id"] + # Force stale weak ref so GC has work. + memory.db.add_memory_subscriber(memory_id, "agent:stale", ref_type="weak", ttl_hours=-1) + + run = memory.run_sleep_cycle( + user_id="u-sleep", + date_str=today, + apply_decay=False, + cleanup_stale_refs=True, + ) + + assert run["users"]["u-sleep"]["promoted"] >= 1 + assert run["stale_refs_removed"] >= 1 + + updated = memory.get(memory_id) + assert updated is not None + assert updated.get("layer") == "lml" + + digest = memory.get_daily_digest(user_id="u-sleep", date_str=today) + assert digest["date"] == today + assert "top_conflicts" in digest + assert "top_proposed_consolidations" in digest diff --git a/tests/test_sqlite_connection_pool.py b/tests/test_sqlite_connection_pool.py new file mode 100644 index 0000000..6915082 --- /dev/null +++ b/tests/test_sqlite_connection_pool.py @@ -0,0 +1,174 @@ +"""Tests for SQLite connection pooling, batch ops, and type safety (Phases 1, 2, 5).""" + +import os +import tempfile +import threading + +import pytest + +from engram.db.sqlite import SQLiteManager, VALID_MEMORY_COLUMNS, VALID_SCENE_COLUMNS, _utcnow_iso + + +@pytest.fixture +def db_manager(): + """Create a temporary SQLiteManager for testing.""" + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + mgr = SQLiteManager(path) + yield mgr + mgr.close() + os.unlink(path) + + +def _add_test_memory(mgr, memory_id="test-1", content="Hello world", user_id="user1"): + now = _utcnow_iso() + mgr.add_memory({ + "id": memory_id, + "memory": content, + "user_id": user_id, + "created_at": now, + "updated_at": now, + "layer": "sml", + "strength": 1.0, + }) + return memory_id + + +class TestConnectionPool: + def test_persistent_connection_wal_mode(self, db_manager): + """Verify WAL mode is enabled on the persistent connection.""" + with db_manager._get_connection() as conn: + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + assert mode == "wal" + + def test_connection_reuse(self, db_manager): + """Same connection object is yielded on successive calls.""" + with db_manager._get_connection() as conn1: + pass + with db_manager._get_connection() as conn2: + pass + assert conn1 is conn2 + + def test_thread_safety(self, db_manager): + """Concurrent threads can safely access the DB.""" + _add_test_memory(db_manager, "thread-1") + results = [] + + def reader(): + mem = db_manager.get_memory("thread-1") + results.append(mem is not None) + + threads = [threading.Thread(target=reader) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert all(results) + + def test_close(self, db_manager): + """close() shuts down cleanly.""" + _add_test_memory(db_manager, "close-test") + db_manager.close() + # Connection is None after close. + assert db_manager._conn is None + + def test_repr(self, db_manager): + assert "SQLiteManager" in repr(db_manager) + assert db_manager.db_path in repr(db_manager) + + +class TestBatchOperations: + def test_get_memories_bulk(self, db_manager): + _add_test_memory(db_manager, "bulk-1", "Memory 1") + _add_test_memory(db_manager, "bulk-2", "Memory 2") + _add_test_memory(db_manager, "bulk-3", "Memory 3") + + result = db_manager.get_memories_bulk(["bulk-1", "bulk-3"]) + assert len(result) == 2 + assert "bulk-1" in result + assert "bulk-3" in result + assert result["bulk-1"]["memory"] == "Memory 1" + + def test_get_memories_bulk_empty(self, db_manager): + assert db_manager.get_memories_bulk([]) == {} + + def test_get_memories_bulk_missing(self, db_manager): + result = db_manager.get_memories_bulk(["nonexistent"]) + assert len(result) == 0 + + def test_increment_access_bulk(self, db_manager): + _add_test_memory(db_manager, "inc-1") + _add_test_memory(db_manager, "inc-2") + + db_manager.increment_access_bulk(["inc-1", "inc-2"]) + + mem1 = db_manager.get_memory("inc-1") + mem2 = db_manager.get_memory("inc-2") + assert mem1["access_count"] == 1 + assert mem2["access_count"] == 1 + + def test_increment_access_bulk_empty(self, db_manager): + db_manager.increment_access_bulk([]) # Should not raise. + + def test_update_strength_bulk(self, db_manager): + _add_test_memory(db_manager, "str-1") + _add_test_memory(db_manager, "str-2") + + db_manager.update_strength_bulk({"str-1": 0.8, "str-2": 0.6}) + + mem1 = db_manager.get_memory("str-1") + mem2 = db_manager.get_memory("str-2") + assert abs(mem1["strength"] - 0.8) < 0.01 + assert abs(mem2["strength"] - 0.6) < 0.01 + + +class TestTypeSafety: + def test_update_memory_rejects_invalid_column(self, db_manager): + _add_test_memory(db_manager, "safe-1") + with pytest.raises(ValueError, match="Invalid memory column"): + db_manager.update_memory("safe-1", {"robert_tables; DROP TABLE memories--": "hacked"}) + + def test_update_memory_valid_columns(self, db_manager): + _add_test_memory(db_manager, "safe-2") + assert db_manager.update_memory("safe-2", {"strength": 0.5}) + mem = db_manager.get_memory("safe-2") + assert abs(mem["strength"] - 0.5) < 0.01 + + def test_update_scene_rejects_invalid_column(self, db_manager): + scene_id = db_manager.add_scene({"id": "scene-1", "user_id": "u1", "start_time": _utcnow_iso()}) + with pytest.raises(ValueError, match="Invalid scene column"): + db_manager.update_scene(scene_id, {"evil_column": "hack"}) + + def test_update_profile_rejects_invalid_column(self, db_manager): + pid = db_manager.add_profile({"id": "prof-1", "user_id": "u1", "name": "Test"}) + with pytest.raises(ValueError, match="Invalid profile column"): + db_manager.update_profile(pid, {"evil_column": "hack"}) + + def test_migrate_add_column_rejects_invalid_table(self, db_manager): + with db_manager._get_connection() as conn: + with pytest.raises(ValueError, match="Invalid table"): + db_manager._migrate_add_column_conn(conn, "evil_table", "col", "TEXT") + + def test_migrate_add_column_rejects_invalid_column_name(self, db_manager): + with db_manager._get_connection() as conn: + with pytest.raises(ValueError, match="Invalid column name"): + db_manager._migrate_add_column_conn(conn, "memories", "evil;drop", "TEXT") + + def test_valid_columns_frozensets_not_empty(self): + assert len(VALID_MEMORY_COLUMNS) > 10 + assert len(VALID_SCENE_COLUMNS) > 5 + + +class TestMigrationIdempotency: + def test_v2_columns_complete_marker(self, db_manager): + """After init, the v2_columns_complete migration should be applied.""" + with db_manager._get_connection() as conn: + assert db_manager._is_migration_applied(conn, "v2_columns_complete") + + def test_reinit_skips_backfills(self, db_manager): + """Re-running _init_db should be fast because migrations are skipped.""" + # Just verify it doesn't error on second run. + db_manager._init_db() + with db_manager._get_connection() as conn: + assert db_manager._is_migration_applied(conn, "v2_columns_complete") diff --git a/tests/test_staging.py b/tests/test_staging.py new file mode 100644 index 0000000..1ff0973 --- /dev/null +++ b/tests/test_staging.py @@ -0,0 +1,299 @@ +"""Tests for v2 staged writes, approval/rejection, and conflict stash.""" + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def test_staging_commit_lifecycle(memory): + session = memory.create_session(user_id="u-staging", agent_id="planner") + + proposal = memory.propose_write( + content="Project codename is Atlas", + user_id="u-staging", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + ) + assert proposal["commit_id"] + assert proposal["status"] in {"PENDING", "AUTO_STASHED"} + + pending = memory.list_pending_commits(user_id="u-staging", status="PENDING") + assert pending["count"] >= 1 + + approved = memory.approve_commit(proposal["commit_id"]) + assert approved["status"] == "APPROVED" + + results = memory.search( + query="Atlas codename", + user_id="u-staging", + agent_id="planner", + limit=5, + ) + assert results["results"] + + +def test_reject_commit(memory): + session = memory.create_session(user_id="u-reject", agent_id="planner") + + proposal = memory.propose_write( + content="Temporary wrong statement", + user_id="u-reject", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + ) + + rejected = memory.reject_commit(proposal["commit_id"], reason="Incorrect") + assert rejected["status"] == "REJECTED" + + +def test_invariant_conflict_creates_stash(memory): + session = memory.create_session(user_id="u-inv", agent_id="planner") + + initial = memory.propose_write( + content="my name is Alice", + user_id="u-inv", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + ) + memory.approve_commit(initial["commit_id"]) + + conflicting = memory.propose_write( + content="my name is Bob", + user_id="u-inv", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + ) + + assert conflicting["status"] == "AUTO_STASHED" + stash_items = memory.db.list_conflict_stash(user_id="u-inv", resolution="UNRESOLVED", limit=20) + assert stash_items + assert stash_items[0]["conflict_key"] == "identity.name" + + +def test_commit_listing_requires_review_capability_for_agents(memory): + writer = memory.create_session( + user_id="u-review", + agent_id="planner", + capabilities=["propose_write"], + ) + proposal = memory.propose_write( + content="Pending proposal for review authorization check", + user_id="u-review", + agent_id="planner", + token=writer["token"], + mode="staging", + infer=False, + ) + assert proposal["commit_id"] + + with pytest.raises(PermissionError): + memory.list_pending_commits(user_id="u-review", agent_id="planner", token=None, limit=10) + with pytest.raises(PermissionError): + memory.list_pending_commits(user_id="u-review", agent_id="planner", token=writer["token"], limit=10) + + reviewer = memory.create_session( + user_id="u-review", + agent_id="planner", + capabilities=["review_commits"], + ) + listed = memory.list_pending_commits( + user_id="u-review", + agent_id="planner", + token=reviewer["token"], + limit=10, + ) + assert listed["count"] >= 1 + + +def test_approve_commit_is_idempotent(memory): + session = memory.create_session(user_id="u-idempotent", agent_id="planner") + proposal = memory.propose_write( + content="Idempotent approval memory", + user_id="u-idempotent", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + ) + + first = memory.approve_commit(proposal["commit_id"]) + assert first["status"] == "APPROVED" + + count_after_first = len(memory.db.get_all_memories(user_id="u-idempotent")) + second = memory.approve_commit(proposal["commit_id"]) + count_after_second = len(memory.db.get_all_memories(user_id="u-idempotent")) + + assert second["status"] == "APPROVED" + assert second["applied"] == [] + assert count_after_second == count_after_first + + +def test_direct_write_is_idempotent_by_source_event_id(memory): + session = memory.create_session(user_id="u-source-event-direct", agent_id="planner") + + first = memory.propose_write( + content="Source event idempotency payload", + user_id="u-source-event-direct", + agent_id="planner", + token=session["token"], + mode="direct", + trusted_direct=True, + infer=False, + source_event_id="evt-direct-1", + source_app="pytest", + ) + second = memory.propose_write( + content="Source event idempotency payload", + user_id="u-source-event-direct", + agent_id="planner", + token=session["token"], + mode="direct", + trusted_direct=True, + infer=False, + source_event_id="evt-direct-1", + source_app="pytest", + ) + + assert first["mode"] == "direct" + assert second["mode"] == "direct" + assert second["result"]["idempotent"] is True + assert len(memory.db.get_all_memories(user_id="u-source-event-direct")) == 1 + + +def test_approved_retries_do_not_duplicate_memory_when_source_event_matches(memory): + session = memory.create_session(user_id="u-source-event-staging", agent_id="planner") + + first = memory.propose_write( + content="Retry-safe staged payload", + user_id="u-source-event-staging", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + source_event_id="evt-staging-1", + ) + assert memory.approve_commit(first["commit_id"])["status"] == "APPROVED" + + second = memory.propose_write( + content="Retry-safe staged payload", + user_id="u-source-event-staging", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + source_event_id="evt-staging-1", + ) + approved = memory.approve_commit(second["commit_id"]) + + assert approved["status"] == "APPROVED" + assert len(memory.db.get_all_memories(user_id="u-source-event-staging")) == 1 + + +def test_failed_commit_apply_rolls_back_added_memories(monkeypatch, memory): + commit = memory.kernel.staging_store.create_commit( + user_id="u-atomic", + agent_id="planner", + scope="work", + checks={"invariants_ok": True, "conflicts": [], "risk_score": 0.2}, + preview={}, + provenance={"source_type": "test"}, + changes=[ + { + "op": "ADD", + "target": "memory_item", + "patch": {"content": "First staged memory", "metadata": {"namespace": "default"}}, + }, + { + "op": "ADD", + "target": "memory_item", + "patch": {"content": "Second staged memory", "metadata": {"namespace": "default"}}, + }, + ], + ) + + original_apply = memory.kernel._apply_direct_write + call_count = {"n": 0} + + def flaky_apply(**kwargs): + call_count["n"] += 1 + if call_count["n"] == 2: + raise RuntimeError("forced apply failure") + return original_apply(**kwargs) + + monkeypatch.setattr(memory.kernel, "_apply_direct_write", flaky_apply) + + outcome = memory.approve_commit(commit["id"]) + assert outcome["error"] == "Commit apply failed" + assert outcome["rolled_back"] >= 1 + + all_memories = memory.db.get_all_memories(user_id="u-atomic") + assert all_memories == [] + + stored_commit = memory.kernel.staging_store.get_commit(commit["id"]) + assert stored_commit["status"] == "PENDING" + assert "apply_error" in stored_commit["checks"] + + +def test_write_quota_per_agent_blocks_excess_proposals(monkeypatch, memory): + monkeypatch.setenv("ENGRAM_V2_POLICY_WRITE_QUOTA_PER_AGENT_PER_HOUR", "1") + session = memory.create_session(user_id="u-quota-agent", agent_id="planner") + + first = memory.propose_write( + content="first quota proposal", + user_id="u-quota-agent", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + ) + assert first["commit_id"] + + with pytest.raises(PermissionError, match="per-agent hourly"): + memory.propose_write( + content="second quota proposal", + user_id="u-quota-agent", + agent_id="planner", + token=session["token"], + mode="staging", + infer=False, + ) + + +def test_write_quota_per_user_applies_across_agents(monkeypatch, memory): + monkeypatch.setenv("ENGRAM_V2_POLICY_WRITE_QUOTA_PER_USER_PER_HOUR", "1") + planner = memory.create_session(user_id="u-quota-user", agent_id="planner") + codex = memory.create_session(user_id="u-quota-user", agent_id="codex") + + first = memory.propose_write( + content="first user quota proposal", + user_id="u-quota-user", + agent_id="planner", + token=planner["token"], + mode="staging", + infer=False, + ) + assert first["commit_id"] + + with pytest.raises(PermissionError, match="per-user hourly"): + memory.propose_write( + content="second user quota proposal", + user_id="u-quota-user", + agent_id="codex", + token=codex["token"], + mode="staging", + infer=False, + ) diff --git a/tests/test_trust_scores.py b/tests/test_trust_scores.py new file mode 100644 index 0000000..69bba22 --- /dev/null +++ b/tests/test_trust_scores.py @@ -0,0 +1,131 @@ +"""Tests for agent trust scoring and auto-merge behavior.""" + +from __future__ import annotations + +import pytest + +from engram import Engram + + +@pytest.fixture +def memory(): + eng = Engram(in_memory=True, provider="mock") + return eng._memory + + +def test_agent_trust_updates_on_approve_and_reject(memory): + session = memory.create_session( + user_id="u-trust", + agent_id="writer", + capabilities=["propose_write", "review_commits"], + namespaces=["default"], + ) + + first = memory.propose_write( + content="First trust candidate memory", + user_id="u-trust", + agent_id="writer", + token=session["token"], + mode="staging", + namespace="default", + infer=False, + ) + memory.approve_commit(first["commit_id"]) + + trust_after_approve = memory.get_agent_trust(user_id="u-trust", agent_id="writer") + assert trust_after_approve["total_proposals"] == 1 + assert trust_after_approve["approved_proposals"] == 1 + assert float(trust_after_approve["trust_score"]) > 0.9 + + second = memory.propose_write( + content="Second trust candidate memory", + user_id="u-trust", + agent_id="writer", + token=session["token"], + mode="staging", + namespace="default", + infer=False, + ) + memory.reject_commit(second["commit_id"], reason="not needed") + + trust_after_reject = memory.get_agent_trust(user_id="u-trust", agent_id="writer") + assert trust_after_reject["total_proposals"] == 2 + assert trust_after_reject["approved_proposals"] == 1 + assert trust_after_reject["rejected_proposals"] == 1 + assert 0.4 <= float(trust_after_reject["trust_score"]) < 1.0 + + +def test_high_trust_agent_can_auto_merge(monkeypatch, memory): + monkeypatch.setenv("ENGRAM_V2_TRUST_AUTOMERGE", "true") + monkeypatch.setenv("ENGRAM_V2_AUTO_MERGE_TRUST_THRESHOLD", "0.6") + monkeypatch.setenv("ENGRAM_V2_AUTO_MERGE_MIN_TOTAL", "1") + monkeypatch.setenv("ENGRAM_V2_AUTO_MERGE_MIN_APPROVED", "1") + monkeypatch.setenv("ENGRAM_V2_AUTO_MERGE_MAX_REJECT_RATE", "1.0") + + session = memory.create_session( + user_id="u-automerge", + agent_id="planner", + capabilities=["propose_write", "review_commits"], + namespaces=["default"], + ) + + baseline = memory.propose_write( + content="Baseline memory to build trust", + user_id="u-automerge", + agent_id="planner", + token=session["token"], + mode="staging", + namespace="default", + infer=False, + ) + memory.approve_commit(baseline["commit_id"]) + + auto = memory.propose_write( + content="This write should auto-merge", + user_id="u-automerge", + agent_id="planner", + token=session["token"], + mode="staging", + namespace="default", + infer=False, + ) + assert auto["status"] == "APPROVED" + assert auto.get("auto_merged") is True + + +def test_auto_merge_guardrails_block_low_evidence(monkeypatch, memory): + monkeypatch.setenv("ENGRAM_V2_TRUST_AUTOMERGE", "true") + monkeypatch.setenv("ENGRAM_V2_AUTO_MERGE_TRUST_THRESHOLD", "0.5") + monkeypatch.delenv("ENGRAM_V2_AUTO_MERGE_MIN_TOTAL", raising=False) + monkeypatch.delenv("ENGRAM_V2_AUTO_MERGE_MIN_APPROVED", raising=False) + monkeypatch.delenv("ENGRAM_V2_AUTO_MERGE_MAX_REJECT_RATE", raising=False) + + session = memory.create_session( + user_id="u-guard", + agent_id="planner", + capabilities=["propose_write", "review_commits"], + namespaces=["default"], + ) + + baseline = memory.propose_write( + content="Baseline trust seed for guardrails", + user_id="u-guard", + agent_id="planner", + token=session["token"], + mode="staging", + namespace="default", + infer=False, + ) + memory.approve_commit(baseline["commit_id"]) + + guarded = memory.propose_write( + content="Should stay pending because evidence is too low", + user_id="u-guard", + agent_id="planner", + token=session["token"], + mode="staging", + namespace="default", + infer=False, + ) + assert guarded["status"] == "PENDING" + assert guarded.get("auto_merged") is False diff --git a/tests/testlongmemeval_runner.py b/tests/testlongmemeval_runner.py new file mode 100644 index 0000000..1fcb3f1 --- /dev/null +++ b/tests/testlongmemeval_runner.py @@ -0,0 +1,175 @@ +"""Tests for LongMemEval benchmark runner helpers.""" + +import json +from argparse import Namespace + +from engram.benchmarks import longmemeval + + +def test_extract_user_only_text_filters_non_user_roles(): + turns = [ + {"role": "system", "content": "ignore"}, + {"role": "user", "content": "first user fact"}, + {"role": "assistant", "content": "ignore"}, + {"role": "user", "content": "second user fact"}, + ] + text = longmemeval.extract_user_only_text(turns) + assert text == "first user fact\nsecond user fact" + + +def test_parse_session_id_from_result_prefers_metadata(): + row = { + "metadata": {"session_id": "sid_meta"}, + "memory": "Session ID: sid_text", + } + assert longmemeval.parse_session_id_from_result(row) == "sid_meta" + + +def test_compute_session_metrics_hits_expected_scores(): + metrics = longmemeval.compute_session_metrics( + retrieved_session_ids=["s1", "s2", "s3"], + answer_session_ids=["s2", "s3"], + ) + assert metrics["recall_any@1"] == 0.0 + assert metrics["recall_any@3"] == 1.0 + assert metrics["recall_all@1"] == 0.0 + assert metrics["recall_all@3"] == 1.0 + + +def test_build_output_row_excludes_debug_fields_by_default(): + row = longmemeval.build_output_row( + question_id="q1", + hypothesis="answer", + retrieved_session_ids=["s1"], + retrieval_metrics={"recall_any@1": 1.0}, + include_debug_fields=False, + ) + assert row == {"question_id": "q1", "hypothesis": "answer"} + + +def test_build_memory_full_potential_enables_echo_category_graph(tmp_path): + memory = longmemeval.build_memory( + llm_provider="mock", + embedder_provider="simple", + vector_store_provider="memory", + embedding_dims=64, + history_db_path=str(tmp_path / "h.db"), + full_potential=True, + ) + assert memory.config.echo.enable_echo is True + assert memory.config.category.enable_categories is True + assert memory.config.graph.enable_graph is True + + +def test_build_memory_minimal_disables_echo_category_graph(tmp_path): + memory = longmemeval.build_memory( + llm_provider="mock", + embedder_provider="simple", + vector_store_provider="memory", + embedding_dims=64, + history_db_path=str(tmp_path / "h.db"), + full_potential=False, + ) + assert memory.config.echo.enable_echo is False + assert memory.config.category.use_llm_categorization is False + assert memory.config.graph.enable_graph is False + + +class _StubLLM: + def generate(self, prompt: str) -> str: + _ = prompt + return "stub hypothesis" + + +class _StubMemory: + def __init__(self): + self.llm = _StubLLM() + self.deleted = [] + self.added = [] + + def delete_all(self, user_id: str): + self.deleted.append(user_id) + return {"deleted_count": 0} + + def add(self, **kwargs): + self.added.append(kwargs) + return {"id": f"mem_{len(self.added)}"} + + def search_with_context(self, **kwargs): + _ = kwargs + return { + "results": [ + { + "metadata": {"session_id": "session_1"}, + "memory": "Session ID: session_1\nUser Transcript:\nalpha", + }, + { + "metadata": {"session_id": "session_2"}, + "memory": "Session ID: session_2\nUser Transcript:\nbeta", + }, + ] + } + + +def test_run_longmemeval_writes_eval_compatible_jsonl(monkeypatch, tmp_path): + dataset = [ + { + "question_id": "q_001", + "question": "Where did I have dinner last week?", + "answer_session_ids": ["session_1"], + "haystack_session_ids": ["session_1", "session_2"], + "haystack_dates": ["2026-01-01", "2026-01-02"], + "haystack_sessions": [ + [{"role": "user", "content": "I had dinner at Juniper Lane."}], + [{"role": "user", "content": "I bought a notebook."}], + ], + } + ] + dataset_path = tmp_path / "longmemeval_small.json" + dataset_path.write_text(json.dumps(dataset), encoding="utf-8") + + output_path = tmp_path / "hypotheses.jsonl" + retrieval_path = tmp_path / "retrieval.jsonl" + + stub_memory = _StubMemory() + monkeypatch.setattr(longmemeval, "build_memory", lambda **_: stub_memory) + + args = Namespace( + dataset_path=str(dataset_path), + output_jsonl=str(output_path), + retrieval_jsonl=str(retrieval_path), + include_debug_fields=False, + full_potential=True, + user_id="u_test", + start_index=0, + end_index=-1, + max_questions=-1, + skip_abstention=False, + top_k=3, + max_context_chars=2048, + print_every=0, + answer_backend="engram-llm", + hf_model="Qwen/Qwen2.5-1.5B-Instruct", + hf_max_new_tokens=64, + llm_provider="mock", + llm_model=None, + embedder_provider="simple", + embedder_model=None, + vector_store_provider="memory", + embedding_dims=384, + history_db_path=str(tmp_path / "history.db"), + qdrant_path=None, + ) + + summary = longmemeval.run_longmemeval(args) + assert summary["processed"] == 1 + assert len(stub_memory.added) == 2 + assert stub_memory.deleted == ["u_test"] + + rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert rows == [{"question_id": "q_001", "hypothesis": "stub hypothesis"}] + + retrieval_rows = [json.loads(line) for line in retrieval_path.read_text(encoding="utf-8").splitlines()] + assert retrieval_rows[0]["question_id"] == "q_001" + assert retrieval_rows[0]["retrieved_session_ids"] == ["session_1", "session_2"] + assert retrieval_rows[0]["metrics"]["recall_any@1"] == 1.0 From 9cce9f6d4f5802e57afd50d075f8e92690afbabd Mon Sep 17 00:00:00 2001 From: Vivek Kumar Date: Wed, 11 Feb 2026 22:30:43 +0530 Subject: [PATCH 2/2] feat: add Active Memory signal bus, sqlite-vec store, and consolidation engine Introduces the Active/Passive memory architecture inspired by conscious/subconscious brain systems. Active Memory is a real-time signal bus (SQLite WAL) where agents post ephemeral signals with TTL tiers (noise/notable/critical/directive). Every MCP response auto-injects the latest signals. A consolidation engine promotes important signals to passive memory. Also adds sqlite-vec as a vector store backend for concurrent multi-agent access without Qdrant single-process file lock. New files: - engram/configs/active.py - config models (TTLTier, SignalType, SignalScope enums) - engram/core/active_memory.py - ActiveMemoryStore (write/read/clear/gc/consolidation) - engram/core/consolidation.py - ConsolidationEngine (active to passive promotion) - engram/vector_stores/sqlite_vec.py - SqliteVecStore (all 11 VectorStoreBase methods) - tests/test_active_memory.py - 24 tests - tests/test_sqlite_vec.py - 23 tests - tests/test_consolidation.py - 8 tests Co-Authored-By: Claude Opus 4.6 --- README.md | 203 +++++++++++----- engram/configs/active.py | 50 ++++ engram/configs/base.py | 3 + engram/core/active_memory.py | 336 ++++++++++++++++++++++++++ engram/core/consolidation.py | 112 +++++++++ engram/mcp_server.py | 193 +++++++++++++++ engram/memory/main.py | 19 ++ engram/utils/factory.py | 4 + engram/vector_stores/sqlite_vec.py | 362 +++++++++++++++++++++++++++++ pyproject.toml | 2 + tests/test_active_memory.py | 201 ++++++++++++++++ tests/test_consolidation.py | 108 +++++++++ tests/test_sqlite_vec.py | 233 +++++++++++++++++++ 13 files changed, 1769 insertions(+), 57 deletions(-) create mode 100644 engram/configs/active.py create mode 100644 engram/core/active_memory.py create mode 100644 engram/core/consolidation.py create mode 100644 engram/vector_stores/sqlite_vec.py create mode 100644 tests/test_active_memory.py create mode 100644 tests/test_consolidation.py create mode 100644 tests/test_sqlite_vec.py diff --git a/README.md b/README.md index 49dcc9f..fc86d7e 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@

Hit a rate limit in Claude Code? Open Codex — it already knows what you were doing.
- One memory kernel. Shared across every agent. Bio-inspired forgetting. Staged writes. Episodic recall. + One memory kernel. Shared across every agent. Active + passive memory. Bio-inspired forgetting. Real-time multi-agent coordination.

@@ -51,11 +51,13 @@ But Engram isn't just a handoff bus. It solves four fundamental problems with ho | Problem | Other Memory Layers | **Engram** | |:--------|:--------------------|:-----------| | **Switching agents = cold start** | Manual copy/paste context | **Handoff bus — session digests, auto-resume** | +| **No real-time coordination** | Polling or nothing | **Active Memory signal bus — agents see each other's state instantly** | | **Nobody forgets** | Store everything forever | **Ebbinghaus decay curve, ~45% less storage** | | **Agents write with no oversight** | Store directly | **Staging + verification + trust scoring** | | **No episodic memory** | Vector search only | **CAST scenes (time/place/topic)** | | Multi-modal encoding | Single embedding | **5 retrieval paths (EchoMem)** | | Cross-agent memory sharing | Per-agent silos | **Scoped retrieval with all-but-mask privacy** | +| Concurrent multi-agent access | Single-process locks | **sqlite-vec WAL mode — multiple agents, one DB** | | Reference-aware decay | No | **If other agents use it, don't delete it** | | Knowledge graph | Sometimes | **Entity extraction + linking** | | MCP + REST | One or the other | **Both, plug-and-play** | @@ -79,9 +81,12 @@ Restart your agent. Done — it now has persistent memory across sessions. # Default runtime (Gemini + local Qdrant + MemoryClient deps) pip install engram-memory -# Full stack extras (MCP server + REST API + async + all providers) +# Full stack extras (MCP server + REST API + async + sqlite-vec + all providers) pip install "engram-memory[all]" +# sqlite-vec for concurrent multi-agent vector search (no server needed) +pip install "engram-memory[sqlite_vec]" + # OpenAI provider add-on pip install "engram-memory[openai]" @@ -130,12 +135,18 @@ docker compose up -d # API at http://localhost:8100 ## Architecture -Engram is a **Personal Memory Kernel** — not just a vector store with an API. It has opinions about how memory should work: +Engram is a **Personal Memory Kernel** — not just a vector store with an API. It models memory the way brains do, with two distinct systems: + +- **Active Memory (conscious):** A real-time signal bus where agents post ephemeral state and events. Every MCP response includes the latest active signals — like how your conscious mind always knows "what's happening right now." Signals auto-expire by TTL tier. Important ones get consolidated into passive memory. +- **Passive Memory (subconscious):** The long-term store — FadeMem decay, EchoMem encoding, CategoryMem organization, CAST scenes. Things the agent "knows" but isn't actively thinking about. + +Engram has five opinions about how memory should work: 1. **Switching agents shouldn't mean starting over.** When an agent pauses — rate limit, crash, tool switch — it saves a session digest. The next agent loads it and continues. Zero re-explanation. -2. **Memory has a lifecycle.** New memories start in short-term (SML), get promoted to long-term (LML) through repeated access, and fade away through Ebbinghaus decay if unused. -3. **Agents are untrusted writers.** Every write is a proposal that lands in staging. Trusted agents can auto-merge; untrusted ones wait for approval. -4. **Scoping is mandatory.** Every memory is scoped by user. Agents see only what they're allowed to — everything else gets the "all but mask" treatment (structure visible, details redacted). +2. **Agents need shared real-time state.** Active Memory lets agents broadcast what they're doing right now — no polling, no coordination protocol. Agent A posts "editing auth.py"; Agent B sees it instantly. +3. **Memory has a lifecycle.** New memories start in short-term (SML), get promoted to long-term (LML) through repeated access, and fade away through Ebbinghaus decay if unused. +4. **Agents are untrusted writers.** Every write is a proposal that lands in staging. Trusted agents can auto-merge; untrusted ones wait for approval. +5. **Scoping is mandatory.** Every memory is scoped by user. Agents see only what they're allowed to — everything else gets the "all but mask" treatment (structure visible, details redacted). ``` ┌─────────────────────────────────────────────────────────────────┐ @@ -151,54 +162,83 @@ Engram is a **Personal Memory Kernel** — not just a vector store with an API. │ Server │ │ API │ └────┬─────┘ └────┬─────┘ └───────────┬──────────┘ - ▼ - ┌────────────────────────────────────┐ - │ Policy Gateway │ - │ Scopes · Masking · Quotas · │ - │ Capability Tokens · Trust Score │ - └────────────────┬───────────────────┘ │ - ┌──────────┴──────────┐ - ▼ ▼ - ┌──────────────────┐ ┌──────────────────┐ - │ Retrieval Engine │ │ Ingestion Pipeline│ - │ ┌─────────────┐ │ │ │ - │ │Semantic │ │ │ Text → Views │ - │ │(hybrid+graph│ │ │ Views → Scenes │ - │ │+categories) │ │ │ Scenes → LML │ - │ ├─────────────┤ │ │ │ - │ │Episodic │ │ └────────┬─────────┘ - │ │(CAST scenes)│ │ │ - │ └─────────────┘ │ ▼ - │ │ ┌──────────────────┐ - │ Intersection │ │Write Verification│ - │ Promotion: │ │ │ - │ match in both → │ │ Invariant checks │ - │ boost score │ │ Conflict → stash │ - └──────────────────┘ │ Trust scoring │ - └────────┬─────────┘ - │ - ┌───────────────────┼───────────────────┐ - ▼ ▼ ▼ - ┌──────────────────┐ ┌──────────────┐ ┌──────────────────┐ - │ Staging (SML) │ │ Long-Term │ │ Indexes │ - │ Proposals+Diffs │ │ Store (LML) │ │ Vector + Graph │ - │ Conflict Stash │ │ Canonical │ │ + Episodic │ - └──────────────────┘ └──────────────┘ └──────────────────┘ - │ │ │ - └───────────────────┼───────────────────┘ - ▼ - ┌──────────────────┐ - │ FadeMem GC │ - │ Ref-aware decay │ - │ If other agents │ - │ use it → keep │ - └──────────────────┘ + ┌────────────────┴────────────────────┐ + │ Policy Gateway │ + │ Scopes · Masking · Quotas · │ + │ Capability Tokens · Trust Score │ + └────────────────┬────────────────────┘ + │ + ┌───────────────────┼───────────────────┐ + ▼ ▼ ▼ +┌──────────┐ ┌──────────────────┐ ┌──────────────────┐ +│ ACTIVE │ │ Retrieval Engine │ │ Ingestion Pipeline│ +│ MEMORY │ │ ┌─────────────┐ │ │ │ +│ │ │ │Semantic │ │ │ Text → Views │ +│ Signal │ │ │(hybrid+graph│ │ │ Views → Scenes │ +│ Bus │ │ │+categories) │ │ │ Scenes → LML │ +│ │ │ ├─────────────┤ │ │ │ +│ state │ │ │Episodic │ │ └────────┬─────────┘ +│ event │ │ │(CAST scenes)│ │ │ +│ directive│ │ └─────────────┘ │ ▼ +│ │ │ Intersection │ ┌──────────────────┐ +│ Auto- │ │ Promotion: │ │Write Verification│ +│ injected │ │ match in both → │ │ Invariant checks │ +│ in every │ │ boost score │ │ Conflict → stash │ +│ response │ └──────────────────┘ │ Trust scoring │ +│ │ └────────┬─────────┘ +│ ┌─────┐ │ │ +│ │ TTL │ │ ┌─────────────────────────┼──────────────┐ +│ └──┬──┘ │ ▼ ▼ ▼ +│ │ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ ▼ │ │ Staging (SML)│ │ Long-Term │ │ Indexes │ +│ Consoli- │ │ Proposals │ │ Store (LML) │ │ Vector+Graph │ +│ dation │ │ Conflict │ │ Canonical │ │ + Episodic │ +│ Engine ──┼──│ Stash │ │ │ │ │ +│ │ └──────────────┘ └──────────────┘ └──────────────┘ +└──────────┘ │ │ │ + └─────────────────┼──────────────────┘ + ▼ + ┌──────────────────┐ + │ FadeMem GC │ + │ Ref-aware decay │ + │ If other agents │ + │ use it → keep │ + └──────────────────┘ ``` ### The Memory Stack -Engram combines five systems, each handling a different aspect of how memory should work: +Engram combines seven systems, each handling a different aspect of how memory should work: + +#### Active Memory — Real-Time Signal Bus + +The "conscious mind" of the system. A shared SQLite bus (WAL mode, concurrent-safe) where agents post ephemeral signals that other agents see immediately. Every MCP tool response auto-injects the latest signals — agents always know what's happening without polling. + +Three signal types, four TTL tiers: + +``` +Signal Types: TTL Tiers: + state → upserts by key noise → 30 minutes + event → always new row notable → 2 hours + directive → permanent rule critical → 24 hours + directive → permanent +``` + +``` +Agent A: signal_write(key="editing", value="auth.py", signal_type="state") +Agent B: signal_write(key="build_status", value="failing", signal_type="event") +User: signal_write(key="use_typescript", value="always", signal_type="directive") + +Any tool call → response includes: + "_active": [ + {"key": "use_typescript", "ttl_tier": "directive", ...}, + {"key": "build_status", "ttl_tier": "notable", ...}, + {"key": "editing", "ttl_tier": "notable", ...} + ] +``` + +**Consolidation Engine:** Important active signals get promoted to passive memory — like how sleep consolidates short-term into long-term memory. Directives become immutable LML memories. Critical and high-read signals get promoted to SML. #### FadeMem — Decay & Consolidation @@ -409,6 +449,11 @@ Once configured, your agent has access to these tools: | Tool | Description | |:-----|:------------| +| **Active Memory** | | +| `signal_write` | Post a signal to the bus (state/event/directive with TTL tiers) | +| `signal_read` | Read active signals, ordered by priority | +| `signal_clear` | Clear signals by key, agent, scope, or type | +| **Passive Memory** | | | `add_memory` | Store a new memory (lands in staging by default) | | `search_memory` | Semantic + keyword + episodic search | | `get_all_memories` | List all stored memories for a user | @@ -423,12 +468,12 @@ Once configured, your agent has access to these tools: | `list_pending_commits` | Inspect staged write queue | | `resolve_conflict` | Resolve invariant conflicts (accept proposed or keep existing) | | `search_scenes` / `get_scene` | Episodic CAST scene retrieval with masking policy | +| **Handoff** | | | `save_session_digest` | Save handoff context when pausing or switching agents | | `get_last_session` | Load session context from the last active agent | | `list_sessions` | Browse handoff history across agents | -Auto-lifecycle behavior is server-driven: when `auto_session_bus` is enabled, -Engram writes handoff checkpoints without explicit user prompts. +Every tool response auto-includes an `_active` field with the latest signals from the Active Memory bus. Auto-lifecycle behavior is server-driven: when `auto_session_bus` is enabled, Engram writes handoff checkpoints without explicit user prompts. --- @@ -519,7 +564,7 @@ curl "http://localhost:8100/v1/handoff/sessions?user_id=u123&repo=/repo&limit=20 - `hosted_backend_unavailable`: verify `ENGRAM_API_URL` and network reachability. - `missing_or_expired_token` / `missing_capability`: ensure the caller has a valid session token with `read_handoff` or `write_handoff`. -- `Storage folder ... qdrant is already accessed`: local Qdrant file mode is single-process; use hosted API mode or a shared Qdrant server for concurrent agents. +- `Storage folder ... qdrant is already accessed`: local Qdrant file mode is single-process. Fix: switch to `sqlite_vec` provider (`pip install "engram-memory[sqlite_vec]"`) which uses WAL mode for concurrent access, or use hosted API mode / a shared Qdrant server. ### Python SDK @@ -570,6 +615,12 @@ memory.fuse(memory_ids) # Combine related memories memory.decay(user_id=None) # Apply forgetting memory.history(memory_id) # Access history +# Active Memory (signal bus) +memory.active.write_signal(key="editing", value="auth.py", signal_type="state") +memory.active.read_signals(user_id="default") +memory.active.clear_signals(key="editing") +memory.consolidate_active() # Promote important signals → passive memory + # Knowledge graph memory.get_related_memories(memory_id) # Graph traversal memory.get_memory_entities(memory_id) # Extracted entities @@ -632,10 +683,16 @@ export ENGRAM_V2_AUTO_MERGE_TRUST_THRESHOLD="0.85" # Trust threshold for auto **Python config:** ```python -from engram.configs.base import MemoryConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig +from engram.configs.base import MemoryConfig, VectorStoreConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig +from engram.configs.active import ActiveMemoryConfig config = MemoryConfig( - fadem=FadeMemConfig( + # Use sqlite-vec for concurrent multi-agent access (no server needed) + vector_store=VectorStoreConfig( + provider="sqlite_vec", + config={"path": "~/.engram/vectors.db"}, + ), + engram=FadeMemConfig( enable_forgetting=True, sml_decay_rate=0.15, lml_decay_rate=0.02, @@ -653,6 +710,13 @@ config = MemoryConfig( enable_category_decay=True, max_category_depth=3, ), + # Active Memory signal bus + active=ActiveMemoryConfig( + enabled=True, + db_path="~/.engram/active.db", + default_ttl_tier="notable", + consolidation_enabled=True, + ), ) ``` @@ -663,7 +727,7 @@ config = MemoryConfig( Engram is designed for agent orchestrators. Every memory is scoped by `user_id` and optionally `agent_id`: ```python -# Research agent stores knowledge +# Research agent stores knowledge (passive memory) memory.add("OAuth 2.0 with JWT tokens", user_id="project_123", agent_id="researcher") @@ -676,6 +740,31 @@ memory.add("Security review passed", user_id="project_123", agent_id="reviewer") ``` +**Real-time coordination with Active Memory:** + +```python +# Agent A broadcasts what it's working on +memory.active.write_signal( + key="editing_file", value="src/auth.py", + signal_type="state", source_agent_id="agent-A" +) + +# Agent B posts a build event +memory.active.write_signal( + key="build", value="tests failing: 3 errors in auth module", + signal_type="event", ttl_tier="critical" +) + +# User sets a permanent directive +memory.active.write_signal( + key="style_rule", value="Always use TypeScript for new files", + signal_type="directive" # Never expires, auto-promoted to passive memory +) + +# Any agent reads the bus — or it's auto-injected into every MCP response +signals = memory.active.read_signals(user_id="default") +``` + **Agent trust scoring** determines write permissions: - High-trust agents (>0.85): proposals auto-merge - Medium-trust: queued for daily digest review @@ -696,7 +785,7 @@ Engram is based on: | Multi-hop Reasoning | +12% accuracy | | Retrieval Precision | +8% on LTI-Bench | -Biological inspirations: Ebbinghaus Forgetting Curve → exponential decay, Spaced Repetition → access boosts strength, Sleep Consolidation → SML → LML promotion, Production Effect → echo encoding, Elaborative Encoding → deeper processing = stronger memory. +Biological inspirations: Ebbinghaus Forgetting Curve → exponential decay, Spaced Repetition → access boosts strength, Sleep Consolidation → SML → LML promotion, Working Memory → Active Memory signal bus, Conscious/Subconscious Split → Active vs Passive memory, Production Effect → echo encoding, Elaborative Encoding → deeper processing = stronger memory. --- @@ -883,7 +972,7 @@ MIT License — see [LICENSE](LICENSE) for details. ---

- Switch agents without losing context. Stop re-explaining yourself. + One memory. Every agent. Real-time coordination. Zero cold starts.

GitHub · Issues · diff --git a/engram/configs/active.py b/engram/configs/active.py new file mode 100644 index 0000000..9e0f00a --- /dev/null +++ b/engram/configs/active.py @@ -0,0 +1,50 @@ +"""Configuration models for Active Memory (signal bus) and consolidation.""" + +from enum import Enum +from typing import Dict + +from pydantic import BaseModel, Field + + +class TTLTier(str, Enum): + NOISE = "noise" # 30 min + NOTABLE = "notable" # 2 hours + CRITICAL = "critical" # 24 hours + DIRECTIVE = "directive" # permanent (no expiry) + + +class SignalType(str, Enum): + STATE = "state" # Current status ("agent-X is editing file Y") + EVENT = "event" # One-shot occurrence ("build failed") + DIRECTIVE = "directive" # Permanent user rule ("always use TypeScript") + + +class SignalScope(str, Enum): + GLOBAL = "global" # All agents see it + REPO = "repo" # Only agents in same repo + NAMESPACE = "namespace" # Only agents in same namespace + + +class ActiveMemoryConfig(BaseModel): + """Configuration for the Active Memory signal bus.""" + enabled: bool = True + db_path: str = Field(default="~/.engram/active.db") + default_ttl_tier: str = "notable" + ttl_seconds: Dict[str, int] = Field(default_factory=lambda: { + "noise": 1800, # 30 min + "notable": 7200, # 2 hours + "critical": 86400, # 24 hours + "directive": 0, # permanent + }) + max_signals_per_response: int = 10 + consolidation_enabled: bool = True + consolidation_min_age_seconds: int = 600 + consolidation_min_reads: int = 3 + + +class ConsolidationConfig(BaseModel): + """Configuration for active → passive memory consolidation.""" + promote_critical: bool = True + promote_high_read: bool = True + promote_read_threshold: int = 3 + directive_to_passive: bool = True diff --git a/engram/configs/base.py b/engram/configs/base.py index ce7bba8..02c69f0 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, Field +from engram.configs.active import ActiveMemoryConfig + class VectorStoreConfig(BaseModel): provider: str = Field(default="qdrant") @@ -190,3 +192,4 @@ class MemoryConfig(BaseModel): scene: SceneConfig = Field(default_factory=SceneConfig) profile: ProfileConfig = Field(default_factory=ProfileConfig) handoff: HandoffConfig = Field(default_factory=HandoffConfig) + active: ActiveMemoryConfig = Field(default_factory=ActiveMemoryConfig) diff --git a/engram/core/active_memory.py b/engram/core/active_memory.py new file mode 100644 index 0000000..8ff1f03 --- /dev/null +++ b/engram/core/active_memory.py @@ -0,0 +1,336 @@ +""" +Active Memory Store — real-time signal bus for multi-agent coordination. + +Signals are ephemeral messages with TTL tiers that auto-expire. +Uses a separate SQLite database with WAL mode for concurrent access. +""" + +import json +import logging +import os +import sqlite3 +import threading +import uuid +from contextlib import contextmanager +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional + +from engram.configs.active import ActiveMemoryConfig + +logger = logging.getLogger(__name__) + +# TTL tier ordering for display priority (highest first) +_TIER_PRIORITY = {"directive": 0, "critical": 1, "notable": 2, "noise": 3} + + +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + +def _utcnow_iso() -> str: + return _utcnow().isoformat() + + +class ActiveMemoryStore: + """SQLite-backed signal bus for active memory.""" + + def __init__(self, config: Optional[ActiveMemoryConfig] = None): + self.config = config or ActiveMemoryConfig() + db_path = os.path.expanduser(self.config.db_path) + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + + self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=5000") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.row_factory = sqlite3.Row + self._lock = threading.RLock() + self._init_db() + + def _init_db(self) -> None: + with self._get_connection() as conn: + conn.executescript(""" + CREATE TABLE IF NOT EXISTS signals ( + id TEXT PRIMARY KEY, + signal_type TEXT NOT NULL CHECK (signal_type IN ('state', 'event', 'directive')), + scope TEXT NOT NULL DEFAULT 'global' CHECK (scope IN ('global', 'repo', 'namespace')), + scope_key TEXT, + ttl_tier TEXT NOT NULL DEFAULT 'notable', + key TEXT NOT NULL, + value TEXT NOT NULL, + source_agent_id TEXT, + user_id TEXT DEFAULT 'default', + read_count INTEGER DEFAULT 0, + read_by TEXT DEFAULT '[]', + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + expires_at TEXT, + consolidated INTEGER DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_signals_scope ON signals(scope, scope_key); + CREATE INDEX IF NOT EXISTS idx_signals_key ON signals(key); + CREATE INDEX IF NOT EXISTS idx_signals_expires ON signals(expires_at); + CREATE INDEX IF NOT EXISTS idx_signals_user ON signals(user_id); + """) + + @contextmanager + def _get_connection(self): + with self._lock: + yield self._conn + + def _compute_expires_at(self, ttl_tier: str) -> Optional[str]: + """Compute expiration timestamp based on TTL tier.""" + ttl_seconds = self.config.ttl_seconds.get(ttl_tier, 0) + if ttl_seconds <= 0: + return None # permanent + expires = _utcnow() + timedelta(seconds=ttl_seconds) + return expires.isoformat() + + def write_signal( + self, + *, + key: str, + value: str, + signal_type: str = "state", + scope: str = "global", + scope_key: Optional[str] = None, + ttl_tier: Optional[str] = None, + source_agent_id: Optional[str] = None, + user_id: str = "default", + ) -> Dict[str, Any]: + """Write a signal to the active memory bus. + + - state: UPSERT by (key, source_agent_id, scope, scope_key) + - event: always INSERT new row + - directive: UPSERT by key, expires_at=NULL (permanent) + """ + ttl_tier = ttl_tier or self.config.default_ttl_tier + if signal_type == "directive": + ttl_tier = "directive" + + expires_at = self._compute_expires_at(ttl_tier) + now = _utcnow_iso() + + with self._get_connection() as conn: + if signal_type == "state": + # Upsert: overwrite existing state signal with same key+agent+scope + existing = conn.execute( + """SELECT id FROM signals + WHERE key = ? AND source_agent_id IS ? AND scope = ? AND scope_key IS ? + AND signal_type = 'state' AND user_id = ?""", + (key, source_agent_id, scope, scope_key, user_id), + ).fetchone() + if existing: + conn.execute( + """UPDATE signals SET value = ?, ttl_tier = ?, expires_at = ?, created_at = ? + WHERE id = ?""", + (value, ttl_tier, expires_at, now, existing["id"]), + ) + conn.commit() + return {"id": existing["id"], "action": "updated", "key": key} + + elif signal_type == "directive": + # Upsert by key (directives are global per key) + existing = conn.execute( + """SELECT id FROM signals + WHERE key = ? AND signal_type = 'directive' AND user_id = ?""", + (key, user_id), + ).fetchone() + if existing: + conn.execute( + """UPDATE signals SET value = ?, source_agent_id = ?, scope = ?, scope_key = ?, + created_at = ?, expires_at = NULL + WHERE id = ?""", + (value, source_agent_id, scope, scope_key, now, existing["id"]), + ) + conn.commit() + return {"id": existing["id"], "action": "updated", "key": key} + + # event signals always INSERT; state/directive fall through if no existing row + signal_id = str(uuid.uuid4()) + conn.execute( + """INSERT INTO signals (id, signal_type, scope, scope_key, ttl_tier, key, value, + source_agent_id, user_id, created_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (signal_id, signal_type, scope, scope_key, ttl_tier, key, value, + source_agent_id, user_id, now, expires_at), + ) + conn.commit() + return {"id": signal_id, "action": "created", "key": key} + + def read_signals( + self, + *, + scope: Optional[str] = None, + scope_key: Optional[str] = None, + signal_type: Optional[str] = None, + user_id: str = "default", + reader_agent_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """Read active signals, auto-GC expired, increment read counts.""" + self.gc_expired() + + conditions = ["user_id = ?"] + params: List[Any] = [user_id] + + if scope: + conditions.append("scope = ?") + params.append(scope) + if scope_key is not None: + conditions.append("scope_key = ?") + params.append(scope_key) + if signal_type: + conditions.append("signal_type = ?") + params.append(signal_type) + + where = " AND ".join(conditions) + effective_limit = limit or self.config.max_signals_per_response + + with self._get_connection() as conn: + rows = conn.execute( + f"""SELECT * FROM signals WHERE {where} + ORDER BY + CASE ttl_tier + WHEN 'directive' THEN 0 + WHEN 'critical' THEN 1 + WHEN 'notable' THEN 2 + WHEN 'noise' THEN 3 + END, + created_at DESC + LIMIT ?""", + params + [effective_limit], + ).fetchall() + + results = [] + ids_to_update = [] + for row in rows: + signal = dict(row) + # Parse read_by JSON + try: + signal["read_by"] = json.loads(signal.get("read_by", "[]")) + except (json.JSONDecodeError, TypeError): + signal["read_by"] = [] + results.append(signal) + ids_to_update.append(signal["id"]) + + # Track reader + if reader_agent_id and reader_agent_id not in signal["read_by"]: + signal["read_by"].append(reader_agent_id) + + # Batch update read counts + if ids_to_update: + for signal in results: + conn.execute( + "UPDATE signals SET read_count = read_count + 1, read_by = ? WHERE id = ?", + (json.dumps(signal["read_by"]), signal["id"]), + ) + conn.commit() + + return results + + def clear_signals( + self, + *, + key: Optional[str] = None, + scope: Optional[str] = None, + scope_key: Optional[str] = None, + source_agent_id: Optional[str] = None, + signal_type: Optional[str] = None, + user_id: str = "default", + ) -> Dict[str, Any]: + """Clear signals matching the given criteria.""" + conditions = ["user_id = ?"] + params: List[Any] = [user_id] + + if key: + conditions.append("key = ?") + params.append(key) + if scope: + conditions.append("scope = ?") + params.append(scope) + if scope_key is not None: + conditions.append("scope_key = ?") + params.append(scope_key) + if source_agent_id: + conditions.append("source_agent_id = ?") + params.append(source_agent_id) + if signal_type: + conditions.append("signal_type = ?") + params.append(signal_type) + + where = " AND ".join(conditions) + + with self._get_connection() as conn: + cursor = conn.execute(f"DELETE FROM signals WHERE {where}", params) + conn.commit() + return {"deleted": cursor.rowcount} + + def gc_expired(self) -> int: + """Garbage collect expired signals (except directives which never expire).""" + now = _utcnow_iso() + with self._get_connection() as conn: + cursor = conn.execute( + "DELETE FROM signals WHERE expires_at IS NOT NULL AND expires_at < ? AND signal_type != 'directive'", + (now,), + ) + conn.commit() + return cursor.rowcount + + def get_consolidation_candidates( + self, + *, + min_age_seconds: Optional[int] = None, + min_reads: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """Get signals eligible for consolidation to passive memory.""" + min_age = min_age_seconds if min_age_seconds is not None else self.config.consolidation_min_age_seconds + min_reads_val = min_reads if min_reads is not None else self.config.consolidation_min_reads + cutoff = (_utcnow() - timedelta(seconds=min_age)).isoformat() + + with self._get_connection() as conn: + rows = conn.execute( + """SELECT * FROM signals + WHERE consolidated = 0 + AND created_at < ? + AND ( + signal_type = 'directive' + OR ttl_tier = 'critical' + OR read_count >= ? + ) + ORDER BY created_at ASC""", + (cutoff, min_reads_val), + ).fetchall() + + results = [] + for row in rows: + signal = dict(row) + try: + signal["read_by"] = json.loads(signal.get("read_by", "[]")) + except (json.JSONDecodeError, TypeError): + signal["read_by"] = [] + results.append(signal) + return results + + def mark_consolidated(self, signal_ids: List[str]) -> None: + """Mark signals as consolidated (promoted to passive memory).""" + if not signal_ids: + return + with self._get_connection() as conn: + placeholders = ",".join("?" for _ in signal_ids) + conn.execute( + f"UPDATE signals SET consolidated = 1 WHERE id IN ({placeholders})", + signal_ids, + ) + conn.commit() + + def close(self) -> None: + """Close the database connection.""" + with self._lock: + if self._conn: + try: + self._conn.close() + except Exception: + pass + self._conn = None # type: ignore[assignment] diff --git a/engram/core/consolidation.py b/engram/core/consolidation.py new file mode 100644 index 0000000..67fbf3d --- /dev/null +++ b/engram/core/consolidation.py @@ -0,0 +1,112 @@ +""" +Consolidation Engine — promotes important active signals to passive memory. + +Mirrors how the brain consolidates short-term memory into long-term during rest: +- Directives are always promoted (permanent rules) +- Critical-tier signals are promoted (high importance) +- High-read signals are promoted (frequently accessed = important) +""" + +import logging +from typing import Any, Dict, TYPE_CHECKING + +from engram.configs.active import ActiveMemoryConfig, ConsolidationConfig +from engram.core.active_memory import ActiveMemoryStore + +if TYPE_CHECKING: + from engram.memory.main import Memory + +logger = logging.getLogger(__name__) + + +class ConsolidationEngine: + """Promotes qualifying active signals into passive (Engram) memory.""" + + def __init__( + self, + active_store: ActiveMemoryStore, + memory: "Memory", + config: ActiveMemoryConfig, + ): + self.active = active_store + self.memory = memory + self.config = config + self.consolidation = ConsolidationConfig() + + def run_cycle(self) -> Dict[str, Any]: + """Run one consolidation cycle. Returns promotion stats.""" + candidates = self.active.get_consolidation_candidates( + min_age_seconds=self.config.consolidation_min_age_seconds, + min_reads=self.config.consolidation_min_reads, + ) + + promoted = [] + skipped = 0 + errors = 0 + + for signal in candidates: + if not self._should_promote(signal): + skipped += 1 + continue + try: + self._promote_to_passive(signal) + promoted.append(signal["id"]) + except Exception: + logger.exception("Failed to promote signal %s", signal["id"]) + errors += 1 + + if promoted: + self.active.mark_consolidated(promoted) + + return { + "promoted": len(promoted), + "checked": len(candidates), + "skipped": skipped, + "errors": errors, + } + + def _should_promote(self, signal: Dict[str, Any]) -> bool: + """Determine if a signal qualifies for promotion to passive memory.""" + signal_type = signal.get("signal_type", "") + ttl_tier = signal.get("ttl_tier", "") + read_count = signal.get("read_count", 0) + + # Directives always promote + if signal_type == "directive" and self.consolidation.directive_to_passive: + return True + + # Critical tier promotes + if ttl_tier == "critical" and self.consolidation.promote_critical: + return True + + # High-read signals promote + if ( + self.consolidation.promote_high_read + and read_count >= self.consolidation.promote_read_threshold + ): + return True + + return False + + def _promote_to_passive(self, signal: Dict[str, Any]) -> None: + """Add a signal's content to passive memory via Memory.add().""" + signal_type = signal.get("signal_type", "event") + user_id = signal.get("user_id", "default") + key = signal.get("key", "") + value = signal.get("value", "") + + # Build content string + content = f"[{key}] {value}" if key else value + + self.memory.add( + messages=content, + user_id=user_id, + metadata={ + "source": "active_signal", + "signal_key": key, + "signal_type": signal_type, + }, + immutable=(signal_type == "directive"), + initial_layer="lml" if signal_type == "directive" else "sml", + infer=False, + ) diff --git a/engram/mcp_server.py b/engram/mcp_server.py index f37d6c4..c5bdcfa 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -1031,6 +1031,122 @@ async def list_tools() -> List[Tool]: } } ), + # ---- Active Memory (signal bus) tools ---- + Tool( + name="signal_write", + description="Post a signal to the active memory bus. State signals upsert by key+agent; events always create new; directives are permanent.", + inputSchema={ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "Signal key (e.g., 'editing_file', 'build_status', 'use_typescript')" + }, + "value": { + "type": "string", + "description": "Signal value/content" + }, + "signal_type": { + "type": "string", + "enum": ["state", "event", "directive"], + "description": "Signal type: state (current status, upserts), event (one-shot), directive (permanent rule). Default: state" + }, + "scope": { + "type": "string", + "enum": ["global", "repo", "namespace"], + "description": "Visibility scope. Default: global" + }, + "scope_key": { + "type": "string", + "description": "Scope qualifier (e.g., repo path for scope=repo)" + }, + "ttl_tier": { + "type": "string", + "enum": ["noise", "notable", "critical", "directive"], + "description": "TTL tier: noise (30m), notable (2h), critical (24h), directive (permanent). Default: notable" + }, + "agent_id": { + "type": "string", + "description": "Source agent identifier" + }, + "user_id": { + "type": "string", + "description": "User identifier (default: 'default')" + }, + }, + "required": ["key", "value"], + } + ), + Tool( + name="signal_read", + description="Read active signals from the memory bus. Returns signals ordered by priority (directive > critical > notable > noise).", + inputSchema={ + "type": "object", + "properties": { + "scope": { + "type": "string", + "enum": ["global", "repo", "namespace"], + "description": "Filter by visibility scope" + }, + "scope_key": { + "type": "string", + "description": "Filter by scope qualifier" + }, + "signal_type": { + "type": "string", + "enum": ["state", "event", "directive"], + "description": "Filter by signal type" + }, + "agent_id": { + "type": "string", + "description": "Reader agent identifier (tracked for read_by)" + }, + "user_id": { + "type": "string", + "description": "User identifier (default: 'default')" + }, + "limit": { + "type": "integer", + "description": "Maximum signals to return" + }, + }, + } + ), + Tool( + name="signal_clear", + description="Clear signals from the active memory bus matching the given criteria.", + inputSchema={ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "Clear signals with this key" + }, + "scope": { + "type": "string", + "enum": ["global", "repo", "namespace"], + "description": "Clear signals with this scope" + }, + "scope_key": { + "type": "string", + "description": "Clear signals with this scope key" + }, + "agent_id": { + "type": "string", + "description": "Clear signals from this agent" + }, + "signal_type": { + "type": "string", + "enum": ["state", "event", "directive"], + "description": "Clear signals of this type" + }, + "user_id": { + "type": "string", + "description": "User identifier (default: 'default')" + }, + }, + } + ), ] # Some MCP clients cap/trim tool manifests per chat. @@ -1504,6 +1620,68 @@ def _handle_search_scenes(memory: "Memory", arguments: Dict[str, Any], _session_ } +# ---- Active Memory (signal bus) helpers and handlers ---- + +_active_store = None + + +def _get_active_store(memory: Memory): + """Lazy-initialize the global active memory store.""" + global _active_store + if _active_store is None: + if memory.config.active.enabled: + from engram.core.active_memory import ActiveMemoryStore + _active_store = ActiveMemoryStore(memory.config.active) + return _active_store + + +@_tool_handler("signal_write") +def _handle_signal_write(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + active = _get_active_store(memory) + if not active: + return {"error": "Active memory is disabled"} + return active.write_signal( + key=arguments["key"], + value=arguments["value"], + signal_type=arguments.get("signal_type", "state"), + scope=arguments.get("scope", "global"), + scope_key=arguments.get("scope_key"), + ttl_tier=arguments.get("ttl_tier", "notable"), + source_agent_id=arguments.get("agent_id"), + user_id=arguments.get("user_id", "default"), + ) + + +@_tool_handler("signal_read") +def _handle_signal_read(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + active = _get_active_store(memory) + if not active: + return {"error": "Active memory is disabled"} + return active.read_signals( + scope=arguments.get("scope"), + scope_key=arguments.get("scope_key"), + signal_type=arguments.get("signal_type"), + user_id=arguments.get("user_id", "default"), + reader_agent_id=arguments.get("agent_id"), + limit=arguments.get("limit"), + ) + + +@_tool_handler("signal_clear") +def _handle_signal_clear(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + active = _get_active_store(memory) + if not active: + return {"error": "Active memory is disabled"} + return active.clear_signals( + key=arguments.get("key"), + scope=arguments.get("scope"), + scope_key=arguments.get("scope_key"), + source_agent_id=arguments.get("agent_id"), + signal_type=arguments.get("signal_type"), + user_id=arguments.get("user_id", "default"), + ) + + @server.call_tool() async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: """Handle tool calls.""" @@ -1815,6 +1993,21 @@ def _handoff_error_payload(exc: Exception) -> Dict[str, str]: handoff_meta["resume"] = auto_resume_packet result["_handoff"] = handoff_meta + # Active memory auto-injection: attach latest signals to every response + if isinstance(result, dict): + active_store = _get_active_store(memory) + if active_store: + try: + signals = active_store.read_signals( + user_id=arguments.get("user_id", "default"), + reader_agent_id=arguments.get("agent_id"), + limit=memory.config.active.max_signals_per_response, + ) + if signals: + result["_active"] = signals + except Exception: + pass # Never break tool responses for active memory errors + return [TextContent(type="text", text=json.dumps(result, indent=2, default=str))] except Exception as e: diff --git a/engram/memory/main.py b/engram/memory/main.py index 2f5b1c2..de1f80f 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -200,6 +200,25 @@ def __init__(self, config: Optional[MemoryConfig] = None): # v2 Personal Memory Kernel orchestration layer. self.kernel = PersonalMemoryKernel(self) + # Active memory store (lazy initialized) + self._active_store = None + + @property + def active(self): + """Lazy-initialized Active Memory store for signal bus.""" + if self._active_store is None and self.config.active.enabled: + from engram.core.active_memory import ActiveMemoryStore + self._active_store = ActiveMemoryStore(self.config.active) + return self._active_store + + def consolidate_active(self) -> Dict[str, Any]: + """Run one consolidation cycle: promote important active signals to passive memory.""" + if not self.active: + return {"skipped": True, "reason": "active memory disabled"} + from engram.core.consolidation import ConsolidationEngine + engine = ConsolidationEngine(self.active, self, self.config.active) + return engine.run_cycle() + def __repr__(self) -> str: return f"Memory(db={self.db!r}, echo={self.echo_config.enable_echo}, scenes={self.scene_config.enable_scenes})" diff --git a/engram/utils/factory.py b/engram/utils/factory.py index 8e9d57a..2a1df22 100644 --- a/engram/utils/factory.py +++ b/engram/utils/factory.py @@ -63,4 +63,8 @@ def create(cls, provider: str, config: Dict[str, Any]): from engram.vector_stores.memory import InMemoryVectorStore return InMemoryVectorStore(config) + if provider == "sqlite_vec": + from engram.vector_stores.sqlite_vec import SqliteVecStore + + return SqliteVecStore(config) raise ValueError(f"Unsupported vector store provider: {provider}") diff --git a/engram/vector_stores/sqlite_vec.py b/engram/vector_stores/sqlite_vec.py new file mode 100644 index 0000000..71fed43 --- /dev/null +++ b/engram/vector_stores/sqlite_vec.py @@ -0,0 +1,362 @@ +""" +sqlite-vec vector store implementation. + +Uses sqlite-vec extension for vector similarity search with cosine distance. +Enables concurrent multi-agent access from a single SQLite database (unlike +Qdrant local which locks the directory). +""" + +from __future__ import annotations + +import json +import logging +import os +import sqlite3 +import struct +import threading +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from engram.memory.utils import matches_filters +from engram.vector_stores.base import VectorStoreBase + +logger = logging.getLogger(__name__) + + +@dataclass +class MemoryResult: + id: str + score: float = 0.0 + payload: Dict[str, Any] = None + + +def _serialize_float32(vector: List[float]) -> bytes: + """Serialize a float vector to bytes for sqlite-vec.""" + return struct.pack(f"{len(vector)}f", *vector) + + +def _deserialize_float32(data: bytes, dims: int) -> List[float]: + """Deserialize bytes back to a float vector.""" + return list(struct.unpack(f"{dims}f", data)) + + +class SqliteVecStore(VectorStoreBase): + """Vector store backed by sqlite-vec extension.""" + + def __init__(self, config: Optional[Dict[str, Any]] = None): + config = config or {} + self.config = config + self.collection_name = config.get("collection_name", "fadem_memories") + self.vector_size = ( + config.get("embedding_model_dims") + or config.get("vector_size") + or config.get("embedding_dims") + or 1536 + ) + db_path = config.get( + "path", + os.path.join(os.path.expanduser("~"), ".engram", "sqlite_vec.db"), + ) + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + + self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=5000") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.row_factory = sqlite3.Row + self._lock = threading.RLock() + + # Load sqlite-vec extension + self._conn.enable_load_extension(True) + import sqlite_vec + sqlite_vec.load(self._conn) + self._conn.enable_load_extension(False) + + self._ensure_collection(self.collection_name, self.vector_size) + + def _vec_table(self, name: str) -> str: + return f"vec_{name}" + + def _payload_table(self, name: str) -> str: + return f"payload_{name}" + + def _ensure_collection(self, name: str, vector_size: int) -> None: + """Create vec0 virtual table and payload table if they don't exist.""" + vec_table = self._vec_table(name) + payload_table = self._payload_table(name) + + with self._lock: + # Check if collection already exists + existing = self._conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (payload_table,), + ).fetchone() + + if not existing: + self._conn.execute( + f"CREATE VIRTUAL TABLE IF NOT EXISTS [{vec_table}] " + f"USING vec0(embedding float[{vector_size}] distance_metric=cosine)" + ) + self._conn.execute( + f"""CREATE TABLE IF NOT EXISTS [{payload_table}] ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + uuid TEXT UNIQUE NOT NULL, + payload TEXT DEFAULT '{{}}' + )""" + ) + self._conn.execute( + f"CREATE INDEX IF NOT EXISTS [idx_{name}_uuid] ON [{payload_table}](uuid)" + ) + self._conn.commit() + + def create_col(self, name: str, vector_size: int, distance: str = "cosine") -> None: + self._ensure_collection(name, vector_size) + + def insert( + self, + vectors: List[List[float]], + payloads: Optional[List[Dict[str, Any]]] = None, + ids: Optional[List[str]] = None, + ) -> None: + payloads = payloads or [{} for _ in vectors] + if len(payloads) != len(vectors): + raise ValueError("payloads length must match vectors length") + if ids is not None and len(ids) != len(vectors): + raise ValueError("ids length must match vectors length") + ids = ids or [str(uuid.uuid4()) for _ in vectors] + + vec_table = self._vec_table(self.collection_name) + payload_table = self._payload_table(self.collection_name) + + with self._lock: + for vector_id, vector, payload in zip(ids, vectors, payloads): + # Check if uuid already exists (upsert) + existing = self._conn.execute( + f"SELECT rowid FROM [{payload_table}] WHERE uuid = ?", + (vector_id,), + ).fetchone() + + if existing: + rowid = existing["rowid"] + self._conn.execute( + f"UPDATE [{payload_table}] SET payload = ? WHERE rowid = ?", + (json.dumps(payload, default=str), rowid), + ) + self._conn.execute( + f"UPDATE [{vec_table}] SET embedding = ? WHERE rowid = ?", + (_serialize_float32(vector), rowid), + ) + else: + cursor = self._conn.execute( + f"INSERT INTO [{payload_table}] (uuid, payload) VALUES (?, ?)", + (vector_id, json.dumps(payload, default=str)), + ) + rowid = cursor.lastrowid + self._conn.execute( + f"INSERT INTO [{vec_table}] (rowid, embedding) VALUES (?, ?)", + (rowid, _serialize_float32(vector)), + ) + self._conn.commit() + + def search( + self, + query: Optional[str], + vectors: List[float], + limit: int = 5, + filters: Optional[Dict[str, Any]] = None, + ) -> List[MemoryResult]: + vec_table = self._vec_table(self.collection_name) + payload_table = self._payload_table(self.collection_name) + + # Over-fetch when filters are present to compensate for post-filtering + fetch_limit = limit * 3 if filters else limit + + with self._lock: + # Check if collection has any rows first + count = self._conn.execute( + f"SELECT COUNT(*) as cnt FROM [{payload_table}]" + ).fetchone() + if not count or count["cnt"] == 0: + return [] + + # sqlite-vec requires `k = ?` in WHERE clause for KNN queries + rows = self._conn.execute( + f"""SELECT v.rowid, v.distance + FROM [{vec_table}] v + WHERE v.embedding MATCH ? AND k = ?""", + (_serialize_float32(vectors), fetch_limit), + ).fetchall() + + # Join with payload table in a second step + results_raw = [] + for row in rows: + p = self._conn.execute( + f"SELECT uuid, payload FROM [{payload_table}] WHERE rowid = ?", + (row["rowid"],), + ).fetchone() + if p: + results_raw.append({ + "distance": row["distance"], + "uuid": p["uuid"], + "payload": p["payload"], + }) + + results = [] + for item in results_raw: + payload = {} + try: + payload = json.loads(item["payload"]) if item["payload"] else {} + except (json.JSONDecodeError, TypeError): + pass + + if filters and not matches_filters(payload, filters): + continue + + # sqlite-vec cosine distance is 0..2 (0=identical). + # Convert to similarity score: 1 - (distance / 2) + distance = float(item["distance"]) + score = 1.0 - (distance / 2.0) + + results.append(MemoryResult( + id=item["uuid"], + score=score, + payload=payload, + )) + + return results[:limit] + + def delete(self, vector_id: str) -> None: + payload_table = self._payload_table(self.collection_name) + vec_table = self._vec_table(self.collection_name) + + with self._lock: + row = self._conn.execute( + f"SELECT rowid FROM [{payload_table}] WHERE uuid = ?", + (vector_id,), + ).fetchone() + if row: + rowid = row["rowid"] + self._conn.execute( + f"DELETE FROM [{vec_table}] WHERE rowid = ?", (rowid,) + ) + self._conn.execute( + f"DELETE FROM [{payload_table}] WHERE rowid = ?", (rowid,) + ) + self._conn.commit() + + def update( + self, + vector_id: str, + vector: Optional[List[float]] = None, + payload: Optional[Dict[str, Any]] = None, + ) -> None: + payload_table = self._payload_table(self.collection_name) + vec_table = self._vec_table(self.collection_name) + + with self._lock: + row = self._conn.execute( + f"SELECT rowid FROM [{payload_table}] WHERE uuid = ?", + (vector_id,), + ).fetchone() + if not row: + return + rowid = row["rowid"] + + if vector is not None: + self._conn.execute( + f"UPDATE [{vec_table}] SET embedding = ? WHERE rowid = ?", + (_serialize_float32(vector), rowid), + ) + if payload is not None: + self._conn.execute( + f"UPDATE [{payload_table}] SET payload = ? WHERE rowid = ?", + (json.dumps(payload, default=str), rowid), + ) + self._conn.commit() + + def get(self, vector_id: str) -> Optional[MemoryResult]: + payload_table = self._payload_table(self.collection_name) + + with self._lock: + row = self._conn.execute( + f"SELECT uuid, payload FROM [{payload_table}] WHERE uuid = ?", + (vector_id,), + ).fetchone() + + if not row: + return None + + payload = {} + try: + payload = json.loads(row["payload"]) if row["payload"] else {} + except (json.JSONDecodeError, TypeError): + pass + + return MemoryResult(id=row["uuid"], score=0.0, payload=payload) + + def list_cols(self) -> List[str]: + with self._lock: + rows = self._conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'payload_%'", + ).fetchall() + return [row["name"].replace("payload_", "", 1) for row in rows] + + def delete_col(self) -> None: + vec_table = self._vec_table(self.collection_name) + payload_table = self._payload_table(self.collection_name) + + with self._lock: + self._conn.execute(f"DROP TABLE IF EXISTS [{vec_table}]") + self._conn.execute(f"DROP TABLE IF EXISTS [{payload_table}]") + self._conn.commit() + + def col_info(self) -> Dict[str, Any]: + payload_table = self._payload_table(self.collection_name) + + with self._lock: + row = self._conn.execute( + f"SELECT COUNT(*) as cnt FROM [{payload_table}]", + ).fetchone() + + count = row["cnt"] if row else 0 + return { + "name": self.collection_name, + "points": count, + "vector_size": self.vector_size, + } + + def list( + self, + filters: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + ) -> List[MemoryResult]: + payload_table = self._payload_table(self.collection_name) + effective_limit = limit or 100 + + with self._lock: + rows = self._conn.execute( + f"SELECT uuid, payload FROM [{payload_table}] LIMIT ?", + (effective_limit * 3 if filters else effective_limit,), + ).fetchall() + + results = [] + for row in rows: + payload = {} + try: + payload = json.loads(row["payload"]) if row["payload"] else {} + except (json.JSONDecodeError, TypeError): + pass + + if filters and not matches_filters(payload, filters): + continue + + results.append(MemoryResult(id=row["uuid"], score=0.0, payload=payload)) + + return results[:effective_limit] + + def reset(self) -> None: + self.delete_col() + self._ensure_collection(self.collection_name, self.vector_size) diff --git a/pyproject.toml b/pyproject.toml index 3970d90..244f44e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ gemini = ["google-generativeai>=0.3.0"] openai = ["openai>=1.0.0"] ollama = ["ollama>=0.4.0"] qdrant = ["qdrant-client>=1.7.0"] +sqlite_vec = ["sqlite-vec>=0.1.1"] mcp = [ "mcp>=1.0.0", ] @@ -50,6 +51,7 @@ all = [ "fastapi>=0.109.0", "uvicorn[standard]>=0.27.0", "aiosqlite>=0.19.0", + "sqlite-vec>=0.1.1", ] async = [ "aiosqlite>=0.19.0", diff --git a/tests/test_active_memory.py b/tests/test_active_memory.py new file mode 100644 index 0000000..5be9107 --- /dev/null +++ b/tests/test_active_memory.py @@ -0,0 +1,201 @@ +"""Tests for Active Memory Store — signal bus with TTL tiers.""" + +import os +import tempfile +import time + +import pytest + +from engram.configs.active import ActiveMemoryConfig +from engram.core.active_memory import ActiveMemoryStore + + +@pytest.fixture +def store(tmp_path): + """Create an ActiveMemoryStore with a temporary database.""" + config = ActiveMemoryConfig( + db_path=str(tmp_path / "active_test.db"), + ttl_seconds={ + "noise": 1, # 1 second for fast test + "notable": 7200, + "critical": 86400, + "directive": 0, + }, + ) + s = ActiveMemoryStore(config) + yield s + s.close() + + +class TestWriteSignal: + def test_write_event_creates_new(self, store): + r1 = store.write_signal(key="build", value="failed", signal_type="event") + r2 = store.write_signal(key="build", value="passed", signal_type="event") + assert r1["action"] == "created" + assert r2["action"] == "created" + assert r1["id"] != r2["id"] + + def test_write_state_upserts(self, store): + r1 = store.write_signal(key="editing", value="file_a.py", signal_type="state", source_agent_id="agent-1") + r2 = store.write_signal(key="editing", value="file_b.py", signal_type="state", source_agent_id="agent-1") + assert r1["action"] == "created" + assert r2["action"] == "updated" + assert r1["id"] == r2["id"] + + signals = store.read_signals(user_id="default") + assert len(signals) == 1 + assert signals[0]["value"] == "file_b.py" + + def test_write_state_different_agents_no_upsert(self, store): + r1 = store.write_signal(key="editing", value="file_a.py", signal_type="state", source_agent_id="agent-1") + r2 = store.write_signal(key="editing", value="file_b.py", signal_type="state", source_agent_id="agent-2") + assert r1["id"] != r2["id"] + + signals = store.read_signals(user_id="default") + assert len(signals) == 2 + + def test_write_directive_upserts_by_key(self, store): + r1 = store.write_signal(key="use_typescript", value="always", signal_type="directive") + r2 = store.write_signal(key="use_typescript", value="always use strict mode", signal_type="directive") + assert r2["action"] == "updated" + assert r1["id"] == r2["id"] + + def test_directive_forces_directive_tier(self, store): + store.write_signal(key="rule1", value="test", signal_type="directive", ttl_tier="noise") + signals = store.read_signals(user_id="default") + assert signals[0]["ttl_tier"] == "directive" + assert signals[0]["expires_at"] is None + + +class TestReadSignals: + def test_read_empty(self, store): + signals = store.read_signals(user_id="default") + assert signals == [] + + def test_read_filters_by_scope(self, store): + store.write_signal(key="a", value="1", scope="global") + store.write_signal(key="b", value="2", scope="repo", scope_key="/path") + signals = store.read_signals(scope="repo", scope_key="/path", user_id="default") + assert len(signals) == 1 + assert signals[0]["key"] == "b" + + def test_read_filters_by_signal_type(self, store): + store.write_signal(key="a", value="1", signal_type="state") + store.write_signal(key="b", value="2", signal_type="event") + signals = store.read_signals(signal_type="event", user_id="default") + assert len(signals) == 1 + assert signals[0]["key"] == "b" + + def test_read_increments_read_count(self, store): + store.write_signal(key="x", value="v") + store.read_signals(user_id="default") + store.read_signals(user_id="default") + store.read_signals(user_id="default") + signals = store.read_signals(user_id="default") + # read_count reflects value at time of SELECT (before this read's increment) + assert signals[0]["read_count"] >= 3 + + def test_read_tracks_reader_agent(self, store): + store.write_signal(key="x", value="v") + store.read_signals(user_id="default", reader_agent_id="agent-A") + signals = store.read_signals(user_id="default", reader_agent_id="agent-B") + assert "agent-A" in signals[0]["read_by"] + assert "agent-B" in signals[0]["read_by"] + + def test_read_priority_order(self, store): + store.write_signal(key="noise_sig", value="1", ttl_tier="noise") + store.write_signal(key="critical_sig", value="2", ttl_tier="critical") + store.write_signal(key="directive_sig", value="3", signal_type="directive") + signals = store.read_signals(user_id="default") + tiers = [s["ttl_tier"] for s in signals] + assert tiers == ["directive", "critical", "noise"] + + def test_read_respects_limit(self, store): + for i in range(5): + store.write_signal(key=f"event_{i}", value=str(i), signal_type="event") + signals = store.read_signals(user_id="default", limit=3) + assert len(signals) == 3 + + def test_read_filters_by_user(self, store): + store.write_signal(key="a", value="1", user_id="alice") + store.write_signal(key="b", value="2", user_id="bob") + signals = store.read_signals(user_id="alice") + assert len(signals) == 1 + assert signals[0]["key"] == "a" + + +class TestTTLExpiry: + def test_noise_expires_quickly(self, store): + store.write_signal(key="temp", value="gone", ttl_tier="noise") + # noise TTL is 1 second in test config + time.sleep(1.5) + signals = store.read_signals(user_id="default") + assert len(signals) == 0 + + def test_directive_never_expires(self, store): + store.write_signal(key="rule", value="permanent", signal_type="directive") + # Even after GC, directive should persist + store.gc_expired() + signals = store.read_signals(user_id="default") + assert len(signals) == 1 + assert signals[0]["signal_type"] == "directive" + + +class TestClearSignals: + def test_clear_by_key(self, store): + store.write_signal(key="a", value="1") + store.write_signal(key="b", value="2") + result = store.clear_signals(key="a", user_id="default") + assert result["deleted"] == 1 + signals = store.read_signals(user_id="default") + assert len(signals) == 1 + assert signals[0]["key"] == "b" + + def test_clear_by_agent(self, store): + store.write_signal(key="x", value="1", source_agent_id="agent-1") + store.write_signal(key="y", value="2", source_agent_id="agent-2") + result = store.clear_signals(source_agent_id="agent-1", user_id="default") + assert result["deleted"] == 1 + + def test_clear_all_for_user(self, store): + store.write_signal(key="a", value="1", user_id="u1") + store.write_signal(key="b", value="2", user_id="u1") + result = store.clear_signals(user_id="u1") + assert result["deleted"] == 2 + + +class TestGC: + def test_gc_removes_expired(self, store): + store.write_signal(key="temp", value="gone", ttl_tier="noise") + time.sleep(1.5) + removed = store.gc_expired() + assert removed >= 1 + + +class TestConsolidationCandidates: + def test_directive_is_candidate(self, store): + store.write_signal(key="rule", value="always use tests", signal_type="directive") + # With min_age_seconds=0 to bypass age requirement + candidates = store.get_consolidation_candidates(min_age_seconds=0, min_reads=100) + assert len(candidates) == 1 + assert candidates[0]["signal_type"] == "directive" + + def test_critical_is_candidate(self, store): + store.write_signal(key="important", value="critical info", ttl_tier="critical") + candidates = store.get_consolidation_candidates(min_age_seconds=0, min_reads=100) + assert len(candidates) == 1 + + def test_high_read_is_candidate(self, store): + store.write_signal(key="popular", value="read a lot") + for _ in range(5): + store.read_signals(user_id="default") + candidates = store.get_consolidation_candidates(min_age_seconds=0, min_reads=3) + assert len(candidates) >= 1 + + def test_mark_consolidated_skips_on_next_run(self, store): + store.write_signal(key="rule", value="test", signal_type="directive") + candidates = store.get_consolidation_candidates(min_age_seconds=0) + assert len(candidates) == 1 + store.mark_consolidated([candidates[0]["id"]]) + candidates2 = store.get_consolidation_candidates(min_age_seconds=0) + assert len(candidates2) == 0 diff --git a/tests/test_consolidation.py b/tests/test_consolidation.py new file mode 100644 index 0000000..3f27d3a --- /dev/null +++ b/tests/test_consolidation.py @@ -0,0 +1,108 @@ +"""Tests for Active → Passive memory consolidation engine.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from engram.configs.active import ActiveMemoryConfig +from engram.core.active_memory import ActiveMemoryStore +from engram.core.consolidation import ConsolidationEngine + + +@pytest.fixture +def active_store(tmp_path): + config = ActiveMemoryConfig( + db_path=str(tmp_path / "consolidation_test.db"), + consolidation_min_age_seconds=0, # no age requirement for tests + consolidation_min_reads=2, + ) + s = ActiveMemoryStore(config) + yield s + s.close() + + +@pytest.fixture +def mock_memory(): + memory = MagicMock() + memory.add.return_value = {"results": [{"id": "mem-1"}]} + return memory + + +@pytest.fixture +def engine(active_store, mock_memory): + config = ActiveMemoryConfig( + consolidation_min_age_seconds=0, + consolidation_min_reads=2, + ) + return ConsolidationEngine(active_store, mock_memory, config) + + +class TestConsolidation: + def test_directive_promoted(self, active_store, engine, mock_memory): + active_store.write_signal(key="rule", value="always use tests", signal_type="directive") + result = engine.run_cycle() + assert result["promoted"] == 1 + mock_memory.add.assert_called_once() + call_kwargs = mock_memory.add.call_args + assert call_kwargs.kwargs["immutable"] is True + assert call_kwargs.kwargs["initial_layer"] == "lml" + + def test_critical_promoted(self, active_store, engine, mock_memory): + active_store.write_signal(key="important", value="critical info", ttl_tier="critical") + result = engine.run_cycle() + assert result["promoted"] == 1 + call_kwargs = mock_memory.add.call_args + assert call_kwargs.kwargs["immutable"] is False + assert call_kwargs.kwargs["initial_layer"] == "sml" + + def test_high_read_promoted(self, active_store, engine, mock_memory): + active_store.write_signal(key="popular", value="frequently read", ttl_tier="notable") + # Read 3 times to exceed threshold of 2 + for _ in range(3): + active_store.read_signals(user_id="default") + result = engine.run_cycle() + assert result["promoted"] >= 1 + + def test_low_read_not_promoted(self, active_store, engine, mock_memory): + active_store.write_signal(key="unpopular", value="barely read", ttl_tier="notable") + # Only 1 read, below threshold of 2 + active_store.read_signals(user_id="default") + result = engine.run_cycle() + assert result["promoted"] == 0 + mock_memory.add.assert_not_called() + + def test_already_consolidated_skipped(self, active_store, engine, mock_memory): + active_store.write_signal(key="rule", value="test", signal_type="directive") + # First cycle promotes + result1 = engine.run_cycle() + assert result1["promoted"] == 1 + mock_memory.add.reset_mock() + # Second cycle should skip (already consolidated) + result2 = engine.run_cycle() + assert result2["promoted"] == 0 + mock_memory.add.assert_not_called() + + def test_content_format(self, active_store, engine, mock_memory): + active_store.write_signal(key="coding_style", value="use type hints", signal_type="directive") + engine.run_cycle() + call_kwargs = mock_memory.add.call_args + assert "[coding_style]" in call_kwargs.kwargs["messages"] + assert "use type hints" in call_kwargs.kwargs["messages"] + + def test_metadata_source(self, active_store, engine, mock_memory): + active_store.write_signal(key="test_key", value="test_val", signal_type="directive") + engine.run_cycle() + call_kwargs = mock_memory.add.call_args + metadata = call_kwargs.kwargs["metadata"] + assert metadata["source"] == "active_signal" + assert metadata["signal_key"] == "test_key" + assert metadata["signal_type"] == "directive" + + def test_run_cycle_stats(self, active_store, engine, mock_memory): + active_store.write_signal(key="rule1", value="a", signal_type="directive") + active_store.write_signal(key="rule2", value="b", signal_type="directive") + active_store.write_signal(key="noise", value="c", ttl_tier="noise") + result = engine.run_cycle() + assert result["promoted"] == 2 + assert result["checked"] == 2 # noise is not a candidate + assert result["errors"] == 0 diff --git a/tests/test_sqlite_vec.py b/tests/test_sqlite_vec.py new file mode 100644 index 0000000..8712a21 --- /dev/null +++ b/tests/test_sqlite_vec.py @@ -0,0 +1,233 @@ +"""Tests for sqlite-vec vector store implementation.""" + +import math +import os + +import pytest + +sqlite_vec = pytest.importorskip("sqlite_vec", reason="sqlite-vec not installed") + +from engram.vector_stores.sqlite_vec import SqliteVecStore, MemoryResult + + +@pytest.fixture +def store(tmp_path): + """Create a SqliteVecStore with a temporary database.""" + config = { + "path": str(tmp_path / "vec_test.db"), + "collection_name": "test_col", + "embedding_model_dims": 4, + } + return SqliteVecStore(config) + + +def _norm(v): + """Normalize a vector to unit length.""" + mag = math.sqrt(sum(x * x for x in v)) + return [x / mag for x in v] if mag > 0 else v + + +class TestInsert: + def test_insert_single(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0]], + payloads=[{"text": "hello"}], + ids=["id-1"], + ) + result = store.get("id-1") + assert result is not None + assert result.id == "id-1" + assert result.payload["text"] == "hello" + + def test_insert_multiple(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + payloads=[{"text": "a"}, {"text": "b"}], + ids=["id-1", "id-2"], + ) + assert store.get("id-1") is not None + assert store.get("id-2") is not None + + def test_insert_auto_ids(self, store): + store.insert(vectors=[[1.0, 0.0, 0.0, 0.0]]) + info = store.col_info() + assert info["points"] == 1 + + def test_insert_upsert(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0]], + payloads=[{"text": "old"}], + ids=["id-1"], + ) + store.insert( + vectors=[[0.0, 1.0, 0.0, 0.0]], + payloads=[{"text": "new"}], + ids=["id-1"], + ) + result = store.get("id-1") + assert result.payload["text"] == "new" + assert store.col_info()["points"] == 1 + + def test_insert_mismatched_lengths(self, store): + with pytest.raises(ValueError): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0]], + payloads=[{"a": 1}, {"b": 2}], + ) + + +class TestSearch: + def test_cosine_similarity_ordering(self, store): + v1 = _norm([1.0, 0.0, 0.0, 0.0]) + v2 = _norm([0.7, 0.7, 0.0, 0.0]) + v3 = _norm([0.0, 1.0, 0.0, 0.0]) + store.insert( + vectors=[v1, v2, v3], + payloads=[{"label": "exact"}, {"label": "partial"}, {"label": "orthogonal"}], + ids=["a", "b", "c"], + ) + query = _norm([1.0, 0.0, 0.0, 0.0]) + results = store.search(query=None, vectors=query, limit=3) + assert len(results) == 3 + assert results[0].id == "a" # Most similar + assert results[0].score > results[1].score + assert results[1].score > results[2].score + + def test_search_respects_limit(self, store): + for i in range(10): + store.insert( + vectors=[_norm([float(i), 1.0, 0.0, 0.0])], + ids=[f"id-{i}"], + ) + results = store.search(query=None, vectors=_norm([5.0, 1.0, 0.0, 0.0]), limit=3) + assert len(results) == 3 + + def test_search_with_filters(self, store): + store.insert( + vectors=[_norm([1.0, 0.0, 0.0, 0.0]), _norm([1.0, 0.1, 0.0, 0.0])], + payloads=[{"user_id": "alice"}, {"user_id": "bob"}], + ids=["a", "b"], + ) + results = store.search( + query=None, + vectors=_norm([1.0, 0.0, 0.0, 0.0]), + limit=5, + filters={"user_id": "alice"}, + ) + assert len(results) == 1 + assert results[0].payload["user_id"] == "alice" + + def test_search_empty_collection(self, store): + results = store.search(query=None, vectors=[1.0, 0.0, 0.0, 0.0], limit=5) + assert results == [] + + +class TestDelete: + def test_delete_existing(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0]], + ids=["id-1"], + ) + store.delete("id-1") + assert store.get("id-1") is None + assert store.col_info()["points"] == 0 + + def test_delete_nonexistent(self, store): + store.delete("nonexistent") # Should not raise + + +class TestUpdate: + def test_update_payload(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0]], + payloads=[{"text": "old"}], + ids=["id-1"], + ) + store.update("id-1", payload={"text": "new"}) + result = store.get("id-1") + assert result.payload["text"] == "new" + + def test_update_vector(self, store): + v1 = _norm([1.0, 0.0, 0.0, 0.0]) + store.insert(vectors=[v1], ids=["id-1"]) + + v2 = _norm([0.0, 1.0, 0.0, 0.0]) + store.update("id-1", vector=v2) + + # Search for the new vector direction should rank it first + results = store.search(query=None, vectors=v2, limit=1) + assert results[0].id == "id-1" + + def test_update_nonexistent(self, store): + store.update("nonexistent", payload={"x": 1}) # Should not raise + + +class TestGet: + def test_get_existing(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0]], + payloads=[{"key": "value"}], + ids=["id-1"], + ) + result = store.get("id-1") + assert isinstance(result, MemoryResult) + assert result.id == "id-1" + assert result.payload["key"] == "value" + + def test_get_nonexistent(self, store): + assert store.get("nonexistent") is None + + +class TestCollectionOps: + def test_list_cols(self, store): + cols = store.list_cols() + assert "test_col" in cols + + def test_col_info(self, store): + store.insert(vectors=[[1.0, 0.0, 0.0, 0.0]], ids=["a"]) + store.insert(vectors=[[0.0, 1.0, 0.0, 0.0]], ids=["b"]) + info = store.col_info() + assert info["name"] == "test_col" + assert info["points"] == 2 + assert info["vector_size"] == 4 + + def test_delete_col(self, store): + store.insert(vectors=[[1.0, 0.0, 0.0, 0.0]], ids=["a"]) + store.delete_col() + # After delete_col, list should be empty + cols = store.list_cols() + assert "test_col" not in cols + + def test_reset(self, store): + store.insert(vectors=[[1.0, 0.0, 0.0, 0.0]], ids=["a"]) + store.reset() + assert store.col_info()["points"] == 0 + # But collection exists again + assert "test_col" in store.list_cols() + + +class TestList: + def test_list_all(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + payloads=[{"label": "a"}, {"label": "b"}], + ids=["id-1", "id-2"], + ) + results = store.list() + assert len(results) == 2 + + def test_list_with_filters(self, store): + store.insert( + vectors=[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + payloads=[{"user_id": "alice"}, {"user_id": "bob"}], + ids=["id-1", "id-2"], + ) + results = store.list(filters={"user_id": "alice"}) + assert len(results) == 1 + assert results[0].payload["user_id"] == "alice" + + def test_list_with_limit(self, store): + for i in range(10): + store.insert(vectors=[[float(i), 0.0, 0.0, 0.0]], ids=[f"id-{i}"]) + results = store.list(limit=3) + assert len(results) == 3