diff --git a/README.md b/README.md index fc86d7e..9306cf2 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,9 @@ But Engram isn't just a handoff bus. It solves four fundamental problems with ho | **Nobody forgets** | Store everything forever | **Ebbinghaus decay curve, ~45% less storage** | | **Agents write with no oversight** | Store directly | **Staging + verification + trust scoring** | | **No episodic memory** | Vector search only | **CAST scenes (time/place/topic)** | +| **No consolidation** | Store everything as-is | **CLS Distillation — replay-driven fact extraction** | +| **Single decay rate** | One exponential curve | **Multi-trace Benna-Fusi model (fast/mid/slow)** | +| **No intent routing** | Same search for all queries | **Episodic vs semantic query classification** | | Multi-modal encoding | Single embedding | **5 retrieval paths (EchoMem)** | | Cross-agent memory sharing | Per-agent silos | **Scoped retrieval with all-but-mask privacy** | | Concurrent multi-agent access | Single-process locks | **sqlite-vec WAL mode — multiple agents, one DB** | @@ -90,6 +93,9 @@ pip install "engram-memory[sqlite_vec]" # OpenAI provider add-on pip install "engram-memory[openai]" +# NVIDIA provider add-on (Llama 3.1, nv-embed-v1, etc.) +pip install "engram-memory[nvidia]" + # Ollama provider add-on pip install "engram-memory[ollama]" ``` @@ -144,7 +150,7 @@ Engram has five opinions about how memory should work: 1. **Switching agents shouldn't mean starting over.** When an agent pauses — rate limit, crash, tool switch — it saves a session digest. The next agent loads it and continues. Zero re-explanation. 2. **Agents need shared real-time state.** Active Memory lets agents broadcast what they're doing right now — no polling, no coordination protocol. Agent A posts "editing auth.py"; Agent B sees it instantly. -3. **Memory has a lifecycle.** New memories start in short-term (SML), get promoted to long-term (LML) through repeated access, and fade away through Ebbinghaus decay if unused. +3. **Memory has a lifecycle.** New memories start in short-term (SML), get promoted to long-term (LML) through repeated access, and fade away through Ebbinghaus decay if unused. Sleep cycles distill episodic conversations into durable semantic facts (CLS consolidation), cascade strength traces from fast to slow, and prune redundant or contradictory memories. 4. **Agents are untrusted writers.** Every write is a proposal that lands in staging. Trusted agents can auto-merge; untrusted ones wait for approval. 5. **Scoping is mandatory.** Every memory is scoped by user. Agents see only what they're allowed to — everything else gets the "all but mask" treatment (structure visible, details redacted). @@ -209,7 +215,7 @@ Engram has five opinions about how memory should work: ### The Memory Stack -Engram combines seven systems, each handling a different aspect of how memory should work: +Engram combines multiple systems, each handling a different aspect of how memory should work: #### Active Memory — Real-Time Signal Bus @@ -289,6 +295,48 @@ Scene: "Engram v2 architecture session" Memories: [mem_1, mem_2] ← semantic facts extracted ``` +#### CLS Distillation Memory — Bio-Inspired Consolidation (v1.4) + +Inspired by Complementary Learning Systems (CLS) theory — how the hippocampus and neocortex work together in the brain. Engram v1.4 adds five mechanisms that make memory smarter over time: + +**1. Episodic/Semantic Memory Types** +Conversations are stored as `episodic` memories. During sleep cycles, a replay-driven distiller extracts durable facts into `semantic` memories — just like how your brain consolidates experiences into knowledge overnight. + +**2. Replay-Driven Distillation** +The `ReplayDistiller` samples recent episodic memories, groups them by scene/time, and uses the LLM to extract reusable semantic facts. Every distilled fact links back to its source episodes (provenance tracking). + +**3. Multi-Mechanism Forgetting** +Beyond simple exponential decay, Engram now has three advanced forgetting mechanisms: +- **Interference Pruning** — contradictory memories are detected and the weaker one is demoted +- **Redundancy Collapse** — near-duplicate memories are auto-fused +- **Homeostatic Normalization** — memory budgets per namespace prevent unbounded growth + +**4. Multi-Timescale Strength Traces (Benna-Fusi Model)** +Each memory has three strength traces instead of one scalar: +``` +s_fast (decay: 0.20/day) — recent access, volatile +s_mid (decay: 0.05/day) — medium-term consolidation +s_slow (decay: 0.005/day) — durable long-term knowledge +``` +New memories start in `s_fast`. Sleep cycles cascade strength: `fast → mid → slow`. Important facts become nearly permanent. + +**5. Intent-Aware Retrieval Routing** +Queries are classified as episodic ("when did we discuss..."), semantic ("what is the deployment process?"), or mixed. Matching memory types get a retrieval boost — the right type of answer for the right type of question. + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Sleep Cycle (v1.4) │ +│ │ +│ 1. Standard FadeMem decay (SML/LML) │ +│ 2. Multi-trace decay (fast/mid/slow independently) │ +│ 3. Interference pruning (contradict → demote weaker) │ +│ 4. Redundancy collapse (near-dupes → fuse) │ +│ 5. Homeostatic normalization (budget enforcement) │ +│ 6. Replay distillation (episodic → semantic facts) │ +│ 7. Trace cascade (fast → mid → slow consolidation) │ +└──────────────────────────────────────────────────────────────┘ +``` + #### Handoff Bus — Cross-Agent Continuity Engram now defaults to a zero-intervention continuity model: MCP adapters automatically request resume context before tool execution and auto-write checkpoints on lifecycle events (`tool_complete`, `agent_pause`, `agent_end`). The legacy tools (`save_session_digest`, `get_last_session`, `list_sessions`) remain available for compatibility. @@ -785,7 +833,7 @@ Engram is based on: | Multi-hop Reasoning | +12% accuracy | | Retrieval Precision | +8% on LTI-Bench | -Biological inspirations: Ebbinghaus Forgetting Curve → exponential decay, Spaced Repetition → access boosts strength, Sleep Consolidation → SML → LML promotion, Working Memory → Active Memory signal bus, Conscious/Subconscious Split → Active vs Passive memory, Production Effect → echo encoding, Elaborative Encoding → deeper processing = stronger memory. +Biological inspirations: Ebbinghaus Forgetting Curve → exponential decay, Spaced Repetition → access boosts strength, Sleep Consolidation → SML → LML promotion + CLS replay distillation, Benna-Fusi Model → multi-timescale strength traces (fast/mid/slow), Complementary Learning Systems → episodic-to-semantic consolidation, Working Memory → Active Memory signal bus, Conscious/Subconscious Split → Active vs Passive memory, Production Effect → echo encoding, Elaborative Encoding → deeper processing = stronger memory. --- diff --git a/engram/api/app.py b/engram/api/app.py index 3f258d2..2d01586 100644 --- a/engram/api/app.py +++ b/engram/api/app.py @@ -85,9 +85,16 @@ class DecayResponse(BaseModel): redoc_url="/redoc", ) +_cors_origins_raw = os.environ.get("ENGRAM_CORS_ORIGINS", "") +_cors_origins = ( + [o.strip() for o in _cors_origins_raw.split(",") if o.strip()] + if _cors_origins_raw + else ["http://localhost:3000", "http://127.0.0.1:3000"] +) + app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=_cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -95,12 +102,15 @@ class DecayResponse(BaseModel): add_metrics_routes(app) _memory: Optional[Memory] = None +_memory_lock = threading.Lock() def get_memory() -> Memory: global _memory if _memory is None: - _memory = Memory() + with _memory_lock: + if _memory is None: + _memory = Memory() return _memory @@ -403,7 +413,7 @@ async def search_memories(request: SearchRequestV2, http_request: Request): raise require_session_error(exc) except Exception as exc: logger.exception("Error searching memories") - raise HTTPException(status_code=500, detail=str(exc)) + raise HTTPException(status_code=500, detail="Internal server error") @app.get("/v1/scenes") @@ -494,7 +504,7 @@ async def add_memory(request: AddMemoryRequestV2, http_request: Request): raise require_session_error(exc) except Exception as exc: logger.exception("Error creating proposal/direct memory") - raise HTTPException(status_code=500, detail=str(exc)) + raise HTTPException(status_code=500, detail="Internal server error") @app.get("/v1/staging/commits") @@ -779,7 +789,9 @@ async def get_memory_by_id(memory_id: str): @app.put("/v1/memories/{memory_id}", response_model=Dict[str, Any]) @app.put("/v1/memories/{memory_id}/", response_model=Dict[str, Any]) -async def update_memory(memory_id: str, request: Dict[str, Any]): +async def update_memory(memory_id: str, request: Dict[str, Any], http_request: Request): + token = get_token_from_request(http_request) + require_token_for_untrusted_request(http_request, token) memory = get_memory() result = memory.update(memory_id, request) return result @@ -787,7 +799,9 @@ async def update_memory(memory_id: str, request: Dict[str, Any]): @app.delete("/v1/memories/{memory_id}") @app.delete("/v1/memories/{memory_id}/") -async def delete_memory(memory_id: str): +async def delete_memory(memory_id: str, http_request: Request): + token = get_token_from_request(http_request) + require_token_for_untrusted_request(http_request, token) memory = get_memory() memory.delete(memory_id) return {"status": "deleted", "id": memory_id} @@ -796,14 +810,18 @@ async def delete_memory(memory_id: str): @app.delete("/v1/memories", response_model=Dict[str, Any]) @app.delete("/v1/memories/", response_model=Dict[str, Any]) async def delete_memories( + http_request: Request, user_id: Optional[str] = Query(default=None), agent_id: Optional[str] = Query(default=None), run_id: Optional[str] = Query(default=None), app_id: Optional[str] = Query(default=None), + dry_run: bool = Query(default=False, description="Preview what would be deleted without actually deleting"), ): + token = get_token_from_request(http_request) + require_token_for_untrusted_request(http_request, token) memory = get_memory() try: - return memory.delete_all(user_id=user_id, agent_id=agent_id, run_id=run_id, app_id=app_id) + return memory.delete_all(user_id=user_id, agent_id=agent_id, run_id=run_id, app_id=app_id, dry_run=dry_run) except FadeMemValidationError as exc: raise HTTPException(status_code=400, detail=exc.message) diff --git a/engram/api/schemas.py b/engram/api/schemas.py index e06ebb0..13deee9 100644 --- a/engram/api/schemas.py +++ b/engram/api/schemas.py @@ -93,7 +93,7 @@ class HandoffSessionDigestRequest(BaseModel): class SearchRequestV2(BaseModel): - query: str + query: str = Field(min_length=1, max_length=10000) user_id: str = Field(default="default") agent_id: Optional[str] = Field(default=None) limit: int = Field(default=10, ge=1, le=100) @@ -101,7 +101,7 @@ class SearchRequestV2(BaseModel): class AddMemoryRequestV2(BaseModel): - content: Optional[str] = Field(default=None) + content: Optional[str] = Field(default=None, max_length=100000) messages: Optional[Union[str, List[Dict[str, Any]]]] = Field(default=None) user_id: str = Field(default="default") agent_id: Optional[str] = Field(default=None) @@ -109,7 +109,7 @@ class AddMemoryRequestV2(BaseModel): categories: Optional[List[str]] = Field(default=None) scope: Optional[str] = Field(default="work") namespace: Optional[str] = Field(default="default") - mode: str = Field(default="staging", description="staging|direct") + mode: Literal["staging", "direct"] = Field(default="staging", description="staging|direct") infer: bool = Field(default=False) source_app: Optional[str] = Field(default=None) source_type: str = Field(default="rest") @@ -117,7 +117,7 @@ class AddMemoryRequestV2(BaseModel): class SceneSearchRequest(BaseModel): - query: str + query: str = Field(min_length=1, max_length=10000) user_id: str = Field(default="default") agent_id: Optional[str] = Field(default=None) limit: int = Field(default=10, ge=1, le=100) diff --git a/engram/configs/active.py b/engram/configs/active.py index 9e0f00a..40642ea 100644 --- a/engram/configs/active.py +++ b/engram/configs/active.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Dict -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator class TTLTier(str, Enum): @@ -25,6 +25,14 @@ class SignalScope(str, Enum): NAMESPACE = "namespace" # Only agents in same namespace +class ConsolidationConfig(BaseModel): + """Configuration for active → passive memory consolidation.""" + promote_critical: bool = True + promote_high_read: bool = True + promote_read_threshold: int = 3 + directive_to_passive: bool = True + + class ActiveMemoryConfig(BaseModel): """Configuration for the Active Memory signal bus.""" enabled: bool = True @@ -40,11 +48,18 @@ class ActiveMemoryConfig(BaseModel): consolidation_enabled: bool = True consolidation_min_age_seconds: int = 600 consolidation_min_reads: int = 3 + consolidation: ConsolidationConfig = Field(default_factory=ConsolidationConfig) + @field_validator("default_ttl_tier") + @classmethod + def _valid_ttl_tier(cls, v: str) -> str: + allowed = {t.value for t in TTLTier} + v = str(v).strip().lower() + if v not in allowed: + return TTLTier.NOTABLE.value + return v -class ConsolidationConfig(BaseModel): - """Configuration for active → passive memory consolidation.""" - promote_critical: bool = True - promote_high_read: bool = True - promote_read_threshold: int = 3 - directive_to_passive: bool = True + @field_validator("max_signals_per_response") + @classmethod + def _clamp_max_signals(cls, v: int) -> int: + return min(100, max(1, int(v))) diff --git a/engram/configs/base.py b/engram/configs/base.py index 02c69f0..56e5729 100644 --- a/engram/configs/base.py +++ b/engram/configs/base.py @@ -1,11 +1,16 @@ import os from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from engram.configs.active import ActiveMemoryConfig +_VALID_VECTOR_PROVIDERS = {"qdrant", "memory", "sqlite_vec"} +_VALID_LLM_PROVIDERS = {"gemini", "openai", "nvidia", "ollama", "mock"} +_VALID_EMBEDDER_PROVIDERS = {"gemini", "openai", "nvidia", "ollama", "simple"} + + class VectorStoreConfig(BaseModel): provider: str = Field(default="qdrant") config: Dict[str, Any] = Field( @@ -15,21 +20,45 @@ class VectorStoreConfig(BaseModel): } ) + @field_validator("provider") + @classmethod + def _valid_provider(cls, v: str) -> str: + v = str(v).strip().lower() + if v not in _VALID_VECTOR_PROVIDERS: + raise ValueError(f"Unknown vector store provider '{v}'. Valid: {sorted(_VALID_VECTOR_PROVIDERS)}") + return v + class LLMConfig(BaseModel): - provider: str = Field(default="gemini") + provider: str = Field(default="nvidia") config: Dict[str, Any] = Field( default_factory=lambda: { - "model": "gemini-2.0-flash", - "temperature": 0.1, + "model": "meta/llama-3.1-8b-instruct", + "temperature": 0.2, "max_tokens": 1024, } ) + @field_validator("provider") + @classmethod + def _valid_provider(cls, v: str) -> str: + v = str(v).strip().lower() + if v not in _VALID_LLM_PROVIDERS: + raise ValueError(f"Unknown LLM provider '{v}'. Valid: {sorted(_VALID_LLM_PROVIDERS)}") + return v + class EmbedderConfig(BaseModel): - provider: str = Field(default="gemini") - config: Dict[str, Any] = Field(default_factory=lambda: {"model": "gemini-embedding-001"}) + provider: str = Field(default="nvidia") + config: Dict[str, Any] = Field(default_factory=lambda: {"model": "nvidia/nv-embed-v1"}) + + @field_validator("provider") + @classmethod + def _valid_provider(cls, v: str) -> str: + v = str(v).strip().lower() + if v not in _VALID_EMBEDDER_PROVIDERS: + raise ValueError(f"Unknown embedder provider '{v}'. Valid: {sorted(_VALID_EMBEDDER_PROVIDERS)}") + return v class GraphStoreConfig(BaseModel): @@ -60,6 +89,20 @@ class EchoMemConfig(BaseModel): # Use question_form embedding for primary vector (better query matching) use_question_embedding: bool = True + @field_validator("default_depth") + @classmethod + def _valid_depth(cls, v: str) -> str: + allowed = {"shallow", "medium", "deep"} + v = str(v).strip().lower() + if v not in allowed: + return "medium" + return v + + @field_validator("shallow_multiplier", "medium_multiplier", "deep_multiplier") + @classmethod + def _positive_multiplier(cls, v: float) -> float: + return max(0.1, float(v)) + class CategoryMemConfig(BaseModel): """ @@ -94,6 +137,19 @@ class CategoryMemConfig(BaseModel): max_category_depth: int = 3 # Maximum nesting depth auto_create_subcategories: bool = True # Allow dynamic subcategory creation + @field_validator( + "category_decay_rate", "weak_category_threshold", + "category_boost_weight", "cross_category_boost", + ) + @classmethod + def _clamp_unit_float(cls, v: float) -> float: + return min(1.0, max(0.0, float(v))) + + @field_validator("max_category_depth") + @classmethod + def _clamp_depth(cls, v: int) -> int: + return min(10, max(1, int(v))) + class SceneConfig(BaseModel): """Configuration for episodic scene grouping.""" @@ -153,6 +209,74 @@ class ScopeConfig(BaseModel): global_weight: float = 0.92 +class DistillationConfig(BaseModel): + """Configuration for CLS Distillation Memory (hippocampus-neocortex consolidation).""" + + # Gap 1: Episodic/Semantic separation + enable_memory_types: bool = True + default_memory_type: str = "semantic" + + # Gap 2: Replay distillation + enable_distillation: bool = True + distillation_batch_size: int = 20 + distillation_min_episodes: int = 5 + distillation_scene_grouping: bool = True + distillation_time_window_hours: int = 24 + max_semantic_per_batch: int = 5 + + # Gap 3: Advanced forgetting + enable_interference_pruning: bool = True + enable_redundancy_collapse: bool = True + enable_homeostasis: bool = True + homeostasis_budget_per_namespace: int = 5000 + homeostasis_pressure_factor: float = 0.1 + redundancy_collapse_threshold: float = 0.92 + + # Gap 4: Multi-trace strength + enable_multi_trace: bool = True + s_fast_weight: float = 0.2 + s_mid_weight: float = 0.3 + s_slow_weight: float = 0.5 + s_fast_decay_rate: float = 0.20 + s_mid_decay_rate: float = 0.05 + s_slow_decay_rate: float = 0.005 + cascade_fast_to_mid: float = 0.1 + cascade_mid_to_slow: float = 0.05 + + # Gap 5: Intent routing + enable_intent_routing: bool = True + episodic_boost: float = 0.15 + semantic_boost: float = 0.15 + intersection_boost: float = 0.1 + + @field_validator("default_memory_type") + @classmethod + def _valid_memory_type(cls, v: str) -> str: + allowed = {"episodic", "semantic"} + v = str(v).strip().lower() + if v not in allowed: + return "semantic" + return v + + @field_validator( + "homeostasis_pressure_factor", "redundancy_collapse_threshold", + "s_fast_weight", "s_mid_weight", "s_slow_weight", + "s_fast_decay_rate", "s_mid_decay_rate", "s_slow_decay_rate", + "cascade_fast_to_mid", "cascade_mid_to_slow", + "episodic_boost", "semantic_boost", "intersection_boost", + ) + @classmethod + def _clamp_unit_float(cls, v: float) -> float: + return min(1.0, max(0.0, float(v))) + + @field_validator("homeostasis_budget_per_namespace", "distillation_batch_size", + "distillation_min_episodes", "distillation_time_window_hours", + "max_semantic_per_batch") + @classmethod + def _positive_int(cls, v: int) -> int: + return max(1, int(v)) + + class FadeMemConfig(BaseModel): enable_forgetting: bool = True sml_decay_rate: float = 0.15 @@ -167,6 +291,21 @@ class FadeMemConfig(BaseModel): enable_fusion: bool = True use_tombstone_deletion: bool = True + @field_validator( + "sml_decay_rate", "lml_decay_rate", "access_dampening_factor", + "promotion_strength_threshold", "forgetting_threshold", + "access_strength_boost", "conflict_similarity_threshold", + "fusion_similarity_threshold", + ) + @classmethod + def _clamp_unit_float(cls, v: float) -> float: + return min(1.0, max(0.0, float(v))) + + @field_validator("promotion_access_threshold") + @classmethod + def _positive_int(cls, v: int) -> int: + return max(1, int(v)) + class MemoryConfig(BaseModel): vector_store: VectorStoreConfig = Field(default_factory=VectorStoreConfig) @@ -177,8 +316,8 @@ class MemoryConfig(BaseModel): default_factory=lambda: os.path.join(os.path.expanduser("~"), ".engram", "history.db") ) collection_name: str = "fadem_memories" - embedding_model_dims: int = 3072 # gemini-embedding-001 default dimensions - version: str = "v1.3" # Updated for CategoryMem + embedding_model_dims: int = 4096 # nvidia/nv-embed-v1 default dimensions + version: str = "v1.4" # Updated for CLS Distillation Memory custom_fact_extraction_prompt: Optional[str] = None custom_conflict_prompt: Optional[str] = None custom_fusion_prompt: Optional[str] = None @@ -193,3 +332,12 @@ class MemoryConfig(BaseModel): profile: ProfileConfig = Field(default_factory=ProfileConfig) handoff: HandoffConfig = Field(default_factory=HandoffConfig) active: ActiveMemoryConfig = Field(default_factory=ActiveMemoryConfig) + distillation: DistillationConfig = Field(default_factory=DistillationConfig) + + @field_validator("embedding_model_dims") + @classmethod + def _valid_dims(cls, v: int) -> int: + v = int(v) + if v < 1 or v > 65536: + raise ValueError(f"embedding_model_dims must be 1-65536, got {v}") + return v diff --git a/engram/core/active_memory.py b/engram/core/active_memory.py index 8ff1f03..a0b109c 100644 --- a/engram/core/active_memory.py +++ b/engram/core/active_memory.py @@ -170,8 +170,6 @@ def read_signals( limit: Optional[int] = None, ) -> List[Dict[str, Any]]: """Read active signals, auto-GC expired, increment read counts.""" - self.gc_expired() - conditions = ["user_id = ?"] params: List[Any] = [user_id] @@ -189,6 +187,13 @@ def read_signals( effective_limit = limit or self.config.max_signals_per_response with self._get_connection() as conn: + # GC expired signals atomically within the same connection context + now = _utcnow_iso() + conn.execute( + "DELETE FROM signals WHERE expires_at IS NOT NULL AND expires_at < ? AND signal_type != 'directive'", + (now,), + ) + rows = conn.execute( f"""SELECT * FROM signals WHERE {where} ORDER BY @@ -226,7 +231,72 @@ def read_signals( "UPDATE signals SET read_count = read_count + 1, read_by = ? WHERE id = ?", (json.dumps(signal["read_by"]), signal["id"]), ) - conn.commit() + conn.commit() + + return results + + def peek_signals( + self, + *, + scope: Optional[str] = None, + scope_key: Optional[str] = None, + signal_type: Optional[str] = None, + user_id: str = "default", + limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """Read active signals without incrementing read_count or tracking readers. + + Identical to ``read_signals`` but read-only — no writes are issued, + making it safe for automatic injection paths that should not inflate + read counts. + """ + conditions = ["user_id = ?"] + params: List[Any] = [user_id] + + if scope: + conditions.append("scope = ?") + params.append(scope) + if scope_key is not None: + conditions.append("scope_key = ?") + params.append(scope_key) + if signal_type: + conditions.append("signal_type = ?") + params.append(signal_type) + + where = " AND ".join(conditions) + effective_limit = limit or self.config.max_signals_per_response + + with self._get_connection() as conn: + # GC expired signals atomically within the same connection context + now = _utcnow_iso() + conn.execute( + "DELETE FROM signals WHERE expires_at IS NOT NULL AND expires_at < ? AND signal_type != 'directive'", + (now,), + ) + conn.commit() + + rows = conn.execute( + f"""SELECT * FROM signals WHERE {where} + ORDER BY + CASE ttl_tier + WHEN 'directive' THEN 0 + WHEN 'critical' THEN 1 + WHEN 'notable' THEN 2 + WHEN 'noise' THEN 3 + END, + created_at DESC + LIMIT ?""", + params + [effective_limit], + ).fetchall() + + results = [] + for row in rows: + signal = dict(row) + try: + signal["read_by"] = json.loads(signal.get("read_by", "[]")) + except (json.JSONDecodeError, TypeError): + signal["read_by"] = [] + results.append(signal) return results diff --git a/engram/core/category.py b/engram/core/category.py index 67785b6..a6e0b36 100644 --- a/engram/core/category.py +++ b/engram/core/category.py @@ -20,9 +20,10 @@ import json import logging +import math import uuid from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from typing import Any, Dict, List, Optional, Set, Tuple @@ -59,7 +60,7 @@ class Category: total_strength: float = 0.0 # Sum of all memory strengths access_count: int = 0 last_accessed: Optional[str] = None - created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat()) + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) # Semantic representation embedding: Optional[List[float]] = None # Category's semantic vector @@ -109,7 +110,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "Category": total_strength=data.get("total_strength", 0.0), access_count=data.get("access_count", 0), last_accessed=data.get("last_accessed"), - created_at=data.get("created_at", datetime.utcnow().isoformat()), + created_at=data.get("created_at", datetime.now(timezone.utc).isoformat()), embedding=data.get("embedding"), keywords=data.get("keywords", []), summary=data.get("summary"), @@ -380,7 +381,10 @@ def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: if norm1 == 0 or norm2 == 0: return 0.0 - return dot_product / (norm1 * norm2) + result = dot_product / (norm1 * norm2) + if math.isnan(result) or math.isinf(result): + return 0.0 + return result def _llm_detect_category( self, @@ -511,7 +515,7 @@ def access_category(self, category_id: str): cat = self.categories[category_id] cat.access_count += 1 - cat.last_accessed = datetime.utcnow().isoformat() + cat.last_accessed = datetime.now(timezone.utc).isoformat() # Strengthen category on access (bio-inspired) cat.strength = min(1.0, cat.strength + 0.02) @@ -541,7 +545,7 @@ def generate_summary(self, category_id: str, memories: List[Dict[str, Any]]) -> try: summary = self.llm.generate(prompt) cat.summary = summary.strip() - cat.summary_updated_at = datetime.utcnow().isoformat() + cat.summary_updated_at = datetime.now(timezone.utc).isoformat() return cat.summary except Exception as e: logger.warning(f"Summary generation failed for {category_id}: {e}") @@ -575,12 +579,15 @@ def apply_category_decay(self, decay_rate: float = 0.05) -> Dict[str, Any]: if cat.last_accessed: try: last_access = datetime.fromisoformat(cat.last_accessed) - days_since = (datetime.utcnow() - last_access).days + if last_access.tzinfo is None: + last_access = last_access.replace(tzinfo=timezone.utc) + days_since = (datetime.now(timezone.utc) - last_access).days decay_amount = decay_rate * (days_since / 7) # Weekly decay cat.strength = max(0.1, cat.strength - decay_amount) decayed += 1 - except Exception: - pass + except (ValueError, TypeError) as e: + logger.debug(f"Category decay calculation failed for {cat.id}: {e}") + continue # Track weak categories for potential merging if cat.strength < 0.3 and cat.memory_count < 3: @@ -623,8 +630,9 @@ def _find_merge_target(self, weak_cat: Category) -> Optional[Category]: # Check keyword overlap if weak_cat.keywords and cat.keywords: overlap = len(set(weak_cat.keywords) & set(cat.keywords)) - if overlap >= 2: - if not best_target or overlap / len(weak_cat.keywords) > best_similarity: + kw_count = len(weak_cat.keywords) + if overlap >= 2 and kw_count > 0: + if not best_target or overlap / kw_count > best_similarity: best_target = cat return best_target @@ -660,7 +668,14 @@ def _merge_categories(self, source_id: str, target_id: str): logger.info(f"Merged category {source_id} into {target_id}") def find_related_categories(self, category_id: str, limit: int = 3) -> List[str]: - """Find categories related to the given one.""" + """Find categories related to the given one. + + Note: This is O(N * D) where N is the number of categories and D is the + embedding dimensionality, because it computes cosine similarity against + every category. For very large category counts this could become a + bottleneck; consider caching related_ids or using an approximate + nearest-neighbor index if N grows large. + """ if category_id not in self.categories: return [] diff --git a/engram/core/conflict.py b/engram/core/conflict.py index 63b5661..a59ab03 100644 --- a/engram/core/conflict.py +++ b/engram/core/conflict.py @@ -1,9 +1,12 @@ import json +import logging from dataclasses import dataclass from typing import Any, Dict, Optional from engram.utils.prompts import CONFLICT_RESOLUTION_PROMPT +logger = logging.getLogger(__name__) + @dataclass class ConflictResolution: @@ -26,13 +29,26 @@ def resolve_conflict(existing_memory: Dict[str, Any], new_content: str, llm, cus try: response = llm.generate(prompt) data = json.loads(response.strip()) + try: + confidence = float(data.get("confidence", 0.5)) + except (ValueError, TypeError): + confidence = 0.5 return ConflictResolution( classification=data.get("classification", "COMPATIBLE"), - confidence=float(data.get("confidence", 0.5)), + confidence=min(1.0, max(0.0, confidence)), merged_content=data.get("merged_content"), explanation=data.get("explanation", ""), ) - except Exception: + except (json.JSONDecodeError, ValueError, TypeError) as e: + logger.warning("Conflict resolution parsing failed: %s", e) + return ConflictResolution( + classification="COMPATIBLE", + confidence=0.5, + merged_content=None, + explanation="Failed to parse LLM response", + ) + except Exception as e: + logger.warning("Conflict resolution failed: %s", e) return ConflictResolution( classification="COMPATIBLE", confidence=0.5, diff --git a/engram/core/consolidation.py b/engram/core/consolidation.py index 67fbf3d..dfaedaa 100644 --- a/engram/core/consolidation.py +++ b/engram/core/consolidation.py @@ -10,7 +10,7 @@ import logging from typing import Any, Dict, TYPE_CHECKING -from engram.configs.active import ActiveMemoryConfig, ConsolidationConfig +from engram.configs.active import ActiveMemoryConfig from engram.core.active_memory import ActiveMemoryStore if TYPE_CHECKING: @@ -31,7 +31,7 @@ def __init__( self.active = active_store self.memory = memory self.config = config - self.consolidation = ConsolidationConfig() + self.consolidation = config.consolidation def run_cycle(self) -> Dict[str, Any]: """Run one consolidation cycle. Returns promotion stats.""" diff --git a/engram/core/decay.py b/engram/core/decay.py index a026e5b..5cf2926 100644 --- a/engram/core/decay.py +++ b/engram/core/decay.py @@ -1,5 +1,5 @@ import math -from datetime import datetime +from datetime import datetime, timezone from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -15,8 +15,13 @@ def calculate_decayed_strength( ) -> float: if isinstance(last_accessed, str): last_accessed = datetime.fromisoformat(last_accessed) + if last_accessed.tzinfo is None: + last_accessed = last_accessed.replace(tzinfo=timezone.utc) - time_elapsed_days = (datetime.utcnow() - last_accessed).total_seconds() / 86400.0 + if math.isnan(current_strength): + return 0.0 + + time_elapsed_days = (datetime.now(timezone.utc) - last_accessed).total_seconds() / 86400.0 decay_rate = config.sml_decay_rate if layer == "sml" else config.lml_decay_rate access_dampening = 1 + config.access_dampening_factor * math.log1p(access_count) new_strength = current_strength * math.exp(-decay_rate * time_elapsed_days / access_dampening) @@ -24,6 +29,8 @@ def calculate_decayed_strength( def should_forget(strength: float, config: "FadeMemConfig") -> bool: + if math.isnan(strength): + return True return strength < config.forgetting_threshold diff --git a/engram/core/distillation.py b/engram/core/distillation.py new file mode 100644 index 0000000..c99ed92 --- /dev/null +++ b/engram/core/distillation.py @@ -0,0 +1,232 @@ +"""Replay-driven semantic distillation (CLS consolidation). + +During sleep cycles, the ReplayDistiller samples recent episodic memories, +groups them by scene or time window, and uses an LLM to extract durable +semantic facts. This models the hippocampus-to-neocortex transfer in +Complementary Learning Systems theory. +""" + +from __future__ import annotations + +import json +import logging +import uuid +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from engram.memory.utils import strip_code_fences +from engram.utils.prompts import DISTILLATION_PROMPT + +if TYPE_CHECKING: + from engram.configs.base import DistillationConfig + from engram.db.sqlite import SQLiteManager + from engram.llms.base import BaseLLM + +logger = logging.getLogger(__name__) + + +class ReplayDistiller: + """Extracts semantic knowledge from episodic memory batches.""" + + def __init__( + self, + db: "SQLiteManager", + llm: "BaseLLM", + config: "DistillationConfig", + ): + self.db = db + self.llm = llm + self.config = config + + def run( + self, + user_id: str, + date_str: Optional[str] = None, + memory_add_fn: Optional[Any] = None, + ) -> Dict[str, Any]: + """Run one distillation cycle for a user. + + Args: + user_id: The user whose episodic memories to distill. + date_str: Target date (defaults to yesterday). + memory_add_fn: Callable to add a memory (typically Memory.add). + Required for actual distillation; if None, dry-run only. + + Returns: + Stats dict with episodes_sampled, semantic_created, etc. + """ + if not self.config.enable_distillation: + return {"skipped": True, "reason": "distillation disabled"} + + target_date = date_str or ( + datetime.now(timezone.utc) - timedelta(days=1) + ).date().isoformat() + + window_hours = self.config.distillation_time_window_hours + created_after = f"{target_date}T00:00:00" + created_before = f"{target_date}T23:59:59.999999" + + # Sample recent episodic memories + episodes = self.db.get_episodic_memories( + user_id, + created_after=created_after, + created_before=created_before, + limit=self.config.distillation_batch_size * 5, + ) + + if len(episodes) < self.config.distillation_min_episodes: + return { + "skipped": True, + "reason": "insufficient episodes", + "episodes_found": len(episodes), + "min_required": self.config.distillation_min_episodes, + } + + # Group into batches + batches = self._group_episodes(episodes) + + total_created = 0 + total_dedup = 0 + total_errors = 0 + + for batch in batches: + try: + created, dedup = self._distill_batch( + user_id=user_id, + batch=batch, + memory_add_fn=memory_add_fn, + ) + total_created += created + total_dedup += dedup + except Exception as e: + logger.warning("Distillation batch failed: %s", e) + total_errors += 1 + + # Log the run + run_id = self.db.log_distillation_run( + user_id=user_id, + episodes_sampled=len(episodes), + semantic_created=total_created, + semantic_deduplicated=total_dedup, + errors=total_errors, + ) + + return { + "run_id": run_id, + "episodes_sampled": len(episodes), + "batches_processed": len(batches), + "semantic_created": total_created, + "semantic_deduplicated": total_dedup, + "errors": total_errors, + } + + def _group_episodes( + self, episodes: List[Dict[str, Any]] + ) -> List[List[Dict[str, Any]]]: + """Group episodes by scene_id or into time-window chunks.""" + if self.config.distillation_scene_grouping: + # Group by scene_id first + scene_groups: Dict[Optional[str], List[Dict[str, Any]]] = {} + for ep in episodes: + scene_id = ep.get("scene_id") + scene_groups.setdefault(scene_id, []).append(ep) + + batches = [] + for scene_id, group in scene_groups.items(): + # Split large scene groups into sub-batches + batch_size = self.config.distillation_batch_size + for i in range(0, len(group), batch_size): + batches.append(group[i : i + batch_size]) + return batches + + # Fallback: chunk by batch_size + batch_size = self.config.distillation_batch_size + return [ + episodes[i : i + batch_size] + for i in range(0, len(episodes), batch_size) + ] + + def _distill_batch( + self, + user_id: str, + batch: List[Dict[str, Any]], + memory_add_fn: Optional[Any], + ) -> tuple: + """Distill a single batch of episodes. Returns (created, deduplicated).""" + # Build the episodes text for the prompt + episode_texts = [] + episode_ids = [] + for ep in batch: + ep_id = ep.get("id", "unknown") + episode_ids.append(ep_id) + content = ep.get("memory", "") + created_at = ep.get("created_at", "") + episode_texts.append(f"[{ep_id}] ({created_at}): {content}") + + episodes_str = "\n".join(episode_texts) + prompt = DISTILLATION_PROMPT.format( + episodes=episodes_str, + max_facts=self.config.max_semantic_per_batch, + ) + + raw_response = self.llm.generate(prompt) + cleaned = strip_code_fences(raw_response) + + try: + parsed = json.loads(cleaned) + except (json.JSONDecodeError, TypeError): + logger.warning("Distillation LLM returned invalid JSON: %.200s", raw_response) + return (0, 0) + + facts = parsed.get("semantic_facts", []) + if not isinstance(facts, list): + return (0, 0) + + created = 0 + deduplicated = 0 + + for fact in facts[: self.config.max_semantic_per_batch]: + content = fact.get("content", "").strip() + if not content: + continue + + importance = fact.get("importance", "medium") + source_eps = fact.get("source_episodes", episode_ids) + + if memory_add_fn is not None: + result = memory_add_fn( + content, + user_id=user_id, + infer=False, + initial_layer="lml", + initial_strength=0.8, + metadata={ + "is_distilled": True, + "distillation_source_count": len(source_eps), + "importance": importance, + "memory_type": "semantic", + }, + ) + + # Check if it was deduplicated (NOOP/SUBSUMED) + results = result.get("results", []) + if results: + first = results[0] + event = first.get("event", "ADD") + if event in ("NOOP", "SUBSUMED"): + deduplicated += 1 + else: + created += 1 + # Record provenance + semantic_id = first.get("id") + if semantic_id: + try: + self.db.add_distillation_provenance( + semantic_memory_id=semantic_id, + episodic_memory_ids=source_eps, + run_id=str(uuid.uuid4()), + ) + except Exception as e: + logger.warning("Failed to record provenance: %s", e) + + return (created, deduplicated) diff --git a/engram/core/echo.py b/engram/core/echo.py index d29902e..08c96bd 100644 --- a/engram/core/echo.py +++ b/engram/core/echo.py @@ -324,7 +324,7 @@ def _medium_echo(self, content: str) -> EchoResult: echo_depth=EchoDepth.MEDIUM, strength_multiplier=self.STRENGTH_MULTIPLIERS[EchoDepth.MEDIUM], ) - except Exception as e: + except (json.JSONDecodeError, ValueError, KeyError, AttributeError) as e: logger.warning(f"Medium echo failed, falling back to shallow: {e}") return self._shallow_echo(content) @@ -357,7 +357,7 @@ def _deep_echo(self, content: str) -> EchoResult: echo_depth=EchoDepth.DEEP, strength_multiplier=self.STRENGTH_MULTIPLIERS[EchoDepth.DEEP], ) - except Exception as e: + except (json.JSONDecodeError, ValueError, KeyError, AttributeError) as e: logger.warning(f"Deep echo failed, falling back to medium: {e}") return self._medium_echo(content) diff --git a/engram/core/forgetting.py b/engram/core/forgetting.py new file mode 100644 index 0000000..ae9e7b0 --- /dev/null +++ b/engram/core/forgetting.py @@ -0,0 +1,296 @@ +"""Advanced forgetting mechanisms for CLS Distillation Memory. + +Three biologically-inspired forgetting mechanisms beyond simple exponential decay: +1. InterferencePruner — contradictory memories demote each other +2. RedundancyCollapser — near-duplicate memories auto-fuse +3. HomeostaticNormalizer — memory budget enforcement with pressure-based decay +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from engram.configs.base import DistillationConfig, FadeMemConfig + from engram.db.sqlite import SQLiteManager + +logger = logging.getLogger(__name__) + + +class InterferencePruner: + """Demote contradictory memories discovered during decay cycles. + + For memories above a minimum strength, finds nearest neighbors and + checks for contradiction. If contradictory, the weaker memory gets demoted. + """ + + def __init__( + self, + db: "SQLiteManager", + config: "DistillationConfig", + fadem_config: "FadeMemConfig", + resolve_conflict_fn=None, + search_fn=None, + llm=None, + ): + self.db = db + self.config = config + self.fadem_config = fadem_config + self.resolve_conflict_fn = resolve_conflict_fn + self.search_fn = search_fn + self.llm = llm + + def run( + self, + memories: List[Dict[str, Any]], + user_id: Optional[str] = None, + ) -> Dict[str, int]: + """Check memories for interference and demote contradictions. + + Returns {"checked": N, "demoted": N}. + """ + if not self.config.enable_interference_pruning: + return {"checked": 0, "demoted": 0} + + if not self.resolve_conflict_fn or not self.search_fn: + return {"checked": 0, "demoted": 0} + + checked = 0 + demoted = 0 + min_strength = 0.2 + + for memory in memories: + if memory.get("immutable"): + continue + strength = float(memory.get("strength", 0.0)) + if strength < min_strength: + continue + + embedding = memory.get("embedding") + if not embedding: + continue + + checked += 1 + + # Find nearest neighbor + try: + filters = {"user_id": user_id} if user_id else {} + neighbors = self.search_fn( + query="", + vectors=embedding, + limit=2, + filters=filters, + ) + # Skip self + neighbors = [n for n in neighbors if n.id != memory["id"]] + if not neighbors: + continue + + nearest = neighbors[0] + similarity = float(nearest.score) + + if similarity < self.fadem_config.conflict_similarity_threshold: + continue + + # Fetch the neighbor memory from DB + neighbor_mem = self.db.get_memory(nearest.id) + if not neighbor_mem: + continue + + # Check for contradiction + resolution = self.resolve_conflict_fn( + neighbor_mem, memory.get("memory", ""), self.llm + ) + + if resolution and resolution.classification == "CONTRADICTORY": + # Demote the weaker one + mem_strength = float(memory.get("strength", 0.0)) + neighbor_strength = float(neighbor_mem.get("strength", 0.0)) + + if mem_strength <= neighbor_strength: + target_id = memory["id"] + old_strength = mem_strength + else: + target_id = neighbor_mem["id"] + old_strength = neighbor_strength + + new_strength = old_strength * 0.3 + self.db.update_memory(target_id, {"strength": new_strength}) + self.db.log_event( + target_id, + "INTERFERENCE_DEMOTE", + old_strength=old_strength, + new_strength=new_strength, + ) + demoted += 1 + + except Exception as e: + logger.debug("Interference check failed for %s: %s", memory.get("id"), e) + + return {"checked": checked, "demoted": demoted} + + +class RedundancyCollapser: + """Auto-fuse near-duplicate memories to reduce bloat. + + During decay cycles, finds clusters of highly similar memories + and fuses them using the existing fusion pipeline. + """ + + def __init__( + self, + db: "SQLiteManager", + config: "DistillationConfig", + fuse_fn=None, + search_fn=None, + ): + self.db = db + self.config = config + self.fuse_fn = fuse_fn + self.search_fn = search_fn + + def run( + self, + memories: List[Dict[str, Any]], + user_id: Optional[str] = None, + ) -> Dict[str, int]: + """Find and fuse redundant memory groups. + + Returns {"groups_fused": N, "memories_fused": N}. + """ + if not self.config.enable_redundancy_collapse: + return {"groups_fused": 0, "memories_fused": 0} + + if not self.fuse_fn or not self.search_fn: + return {"groups_fused": 0, "memories_fused": 0} + + threshold = self.config.redundancy_collapse_threshold + groups_fused = 0 + memories_fused = 0 + already_fused = set() + + for memory in memories: + mid = memory.get("id") + if mid in already_fused: + continue + if memory.get("immutable"): + continue + + embedding = memory.get("embedding") + if not embedding: + continue + + try: + filters = {"user_id": user_id} if user_id else {} + neighbors = self.search_fn( + query="", + vectors=embedding, + limit=5, + filters=filters, + ) + # Find highly similar memories + group_ids = [mid] + for n in neighbors: + if n.id == mid or n.id in already_fused: + continue + n_mem = self.db.get_memory(n.id) + if not n_mem or n_mem.get("immutable"): + continue + if float(n.score) >= threshold: + group_ids.append(n.id) + + if len(group_ids) >= 2: + result = self.fuse_fn(group_ids, user_id=user_id) + if result and not result.get("error"): + already_fused.update(group_ids) + groups_fused += 1 + memories_fused += len(group_ids) + + except Exception as e: + logger.debug("Redundancy collapse failed for %s: %s", mid, e) + + return {"groups_fused": groups_fused, "memories_fused": memories_fused} + + +class HomeostaticNormalizer: + """Enforce memory budgets per namespace with pressure-based decay. + + When a namespace exceeds its budget, applies extra decay pressure + to the weakest memories proportional to the excess ratio. + """ + + def __init__( + self, + db: "SQLiteManager", + config: "DistillationConfig", + fadem_config: "FadeMemConfig", + delete_fn=None, + ): + self.db = db + self.config = config + self.fadem_config = fadem_config + self.delete_fn = delete_fn + + def run( + self, + user_id: str, + ) -> Dict[str, Any]: + """Apply homeostatic pressure to namespaces over budget. + + Returns {"namespaces_over_budget": N, "pressured": N, "forgotten": N}. + """ + if not self.config.enable_homeostasis: + return {"namespaces_over_budget": 0, "pressured": 0, "forgotten": 0} + + counts = self.db.get_memory_count_by_namespace(user_id) + budget = self.config.homeostasis_budget_per_namespace + pressure_factor = self.config.homeostasis_pressure_factor + + namespaces_over = 0 + total_pressured = 0 + total_forgotten = 0 + + for namespace, count in counts.items(): + if count <= budget: + continue + + namespaces_over += 1 + excess_ratio = (count - budget) / budget + + # Fetch weakest memories in this namespace + weak_memories = self.db.get_all_memories( + user_id=user_id, + namespace=namespace, + min_strength=0.0, + limit=count, + ) + + # Sort by strength ascending (weakest first) + weak_memories.sort(key=lambda m: float(m.get("strength", 0.0))) + + for memory in weak_memories: + if memory.get("immutable"): + continue + + strength = float(memory.get("strength", 0.0)) + # Apply extra decay proportional to excess + pressure = strength * pressure_factor * excess_ratio + new_strength = max(0.0, strength - pressure) + + if new_strength < self.fadem_config.forgetting_threshold: + if self.delete_fn: + try: + self.delete_fn(memory["id"]) + total_forgotten += 1 + except Exception as e: + logger.debug("Homeostasis delete failed for %s: %s", memory["id"], e) + else: + self.db.update_memory(memory["id"], {"strength": new_strength}) + total_pressured += 1 + + return { + "namespaces_over_budget": namespaces_over, + "pressured": total_pressured, + "forgotten": total_forgotten, + } diff --git a/engram/core/fusion.py b/engram/core/fusion.py index 63b8614..305ea43 100644 --- a/engram/core/fusion.py +++ b/engram/core/fusion.py @@ -1,9 +1,12 @@ import json +import logging from dataclasses import dataclass from typing import Any, Dict, List, Optional from engram.utils.prompts import FUSION_PROMPT +logger = logging.getLogger(__name__) + @dataclass class FusedMemory: @@ -14,7 +17,10 @@ class FusedMemory: layer: str = "lml" -def fuse_memories(memories: List[Dict[str, Any]], llm, custom_prompt: Optional[str] = None) -> FusedMemory: +def fuse_memories(memories: List[Dict[str, Any]], llm, custom_prompt: Optional[str] = None) -> Optional[FusedMemory]: + if not memories: + return None + memories_text = "\n\n".join( [ f"Memory {i + 1} (strength={m.get('strength', 1.0):.2f}, accessed={m.get('access_count', 0)}x, created_at={m.get('created_at', '')}):\n{m.get('memory', '')}" @@ -28,11 +34,27 @@ def fuse_memories(memories: List[Dict[str, Any]], llm, custom_prompt: Optional[s response = llm.generate(prompt) data = json.loads(response.strip()) fused_content = data.get("consolidated_memory", "") - except Exception: + except (json.JSONDecodeError, ValueError, TypeError) as e: + logger.warning("Fusion LLM parsing failed: %s", e) + fused_content = " | ".join([m.get("memory", "") for m in memories]) + except Exception as e: + logger.warning("Fusion LLM call failed: %s", e) fused_content = " | ".join([m.get("memory", "") for m in memories]) - avg_strength = sum(m.get("strength", 1.0) for m in memories) / len(memories) - total_access = sum(m.get("access_count", 0) for m in memories) + def _safe_float(val, default: float) -> float: + try: + return float(val) + except (ValueError, TypeError): + return default + + def _safe_int(val, default: int) -> int: + try: + return int(val) + except (ValueError, TypeError): + return default + + avg_strength = sum(_safe_float(m.get("strength", 1.0), 1.0) for m in memories) / len(memories) + total_access = sum(_safe_int(m.get("access_count", 0), 0) for m in memories) return FusedMemory( content=fused_content, diff --git a/engram/core/graph.py b/engram/core/graph.py index ebaad05..23e18e3 100644 --- a/engram/core/graph.py +++ b/engram/core/graph.py @@ -9,6 +9,7 @@ import re import json import logging +from collections import deque from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set, Tuple from enum import Enum @@ -333,10 +334,11 @@ def get_related_memories( """ visited = {memory_id} results = [] - queue = [(memory_id, 0, [])] + # Use deque for O(1) popleft instead of list.pop(0) which is O(n). + queue = deque([(memory_id, 0, [])]) while queue: - current_id, depth, path = queue.pop(0) + current_id, depth, path = queue.popleft() if depth >= max_depth: continue diff --git a/engram/core/handoff_bus.py b/engram/core/handoff_bus.py index 46f8759..62bc907 100644 --- a/engram/core/handoff_bus.py +++ b/engram/core/handoff_bus.py @@ -78,6 +78,8 @@ def __init__( fallback=["active", "paused"], ) self.lane_inactivity_minutes = int(cfg.get("lane_inactivity_minutes", 240)) + # Cache bootstrapped policies per-instance to avoid a DB query on every checkpoint/resume. + self._bootstrapped_policies: set = set() self.auto_trusted_agents = { str(agent).strip().lower() for agent in cfg.get( @@ -568,9 +570,6 @@ def list_sessions( # Internal helpers # ------------------------------------------------------------------ - # Cache bootstrapped policies to avoid a DB query on every checkpoint/resume. - _bootstrapped_policies: set = set() - def _bootstrap_auto_trusted_policy(self, *, user_id: str, agent_id: Optional[str], namespace: str) -> None: if not self.allow_auto_trusted_bootstrap: return diff --git a/engram/core/intent.py b/engram/core/intent.py new file mode 100644 index 0000000..10a266d --- /dev/null +++ b/engram/core/intent.py @@ -0,0 +1,93 @@ +"""Query intent classifier for retrieval routing. + +Regex-based classifier (zero LLM cost, sub-millisecond) that determines +whether a query targets episodic memories (conversations, events), +semantic memories (facts, preferences), or is ambiguous (mixed). +""" + +from __future__ import annotations + +import re +from enum import Enum +from typing import List, Tuple + + +class QueryIntent(str, Enum): + EPISODIC = "episodic" # "when did", "last time", "what happened", "ago" + SEMANTIC = "semantic" # "what is", "prefer", "tell me about", "favorite" + MIXED = "mixed" # ambiguous or both signals + + +# Patterns that signal episodic (event/time-based) queries +_EPISODIC_PATTERNS: List[Tuple[re.Pattern, float]] = [ + (re.compile(r"\bwhen did\b", re.I), 1.0), + (re.compile(r"\blast time\b", re.I), 1.0), + (re.compile(r"\bwhat happened\b", re.I), 1.0), + (re.compile(r"\bdo you remember\b", re.I), 0.8), + (re.compile(r"\brecall\b", re.I), 0.6), + (re.compile(r"\b\d+\s*(days?|weeks?|months?|hours?)\s+ago\b", re.I), 1.0), + (re.compile(r"\byesterday\b", re.I), 0.9), + (re.compile(r"\blast (week|month|year|session|conversation)\b", re.I), 1.0), + (re.compile(r"\bwe (discussed|talked|mentioned|said)\b", re.I), 0.9), + (re.compile(r"\bi (said|told|mentioned|asked)\b", re.I), 0.8), + (re.compile(r"\bwhat did (i|we|you)\b", re.I), 0.9), + (re.compile(r"\bhistory of\b", re.I), 0.7), + (re.compile(r"\btimeline\b", re.I), 0.7), + (re.compile(r"\bsequence of events\b", re.I), 1.0), + (re.compile(r"\bfirst time\b", re.I), 0.8), + (re.compile(r"\bhow many times\b", re.I), 0.7), +] + +# Patterns that signal semantic (fact/knowledge-based) queries +_SEMANTIC_PATTERNS: List[Tuple[re.Pattern, float]] = [ + (re.compile(r"\bwhat is\b", re.I), 0.8), + (re.compile(r"\bwhat are\b", re.I), 0.7), + (re.compile(r"\bwhat'?s my\b", re.I), 0.9), + (re.compile(r"\bprefer\b", re.I), 0.9), + (re.compile(r"\bfavorite\b", re.I), 0.9), + (re.compile(r"\btell me about\b", re.I), 0.7), + (re.compile(r"\bwho is\b", re.I), 0.7), + (re.compile(r"\bexplain\b", re.I), 0.6), + (re.compile(r"\bdescribe\b", re.I), 0.6), + (re.compile(r"\bhow (do|does|to)\b", re.I), 0.7), + (re.compile(r"\bprocess for\b", re.I), 0.8), + (re.compile(r"\bsteps to\b", re.I), 0.7), + (re.compile(r"\bprocedure\b", re.I), 0.7), + (re.compile(r"\bworkflow\b", re.I), 0.7), + (re.compile(r"\bdefault\b", re.I), 0.5), + (re.compile(r"\busually\b", re.I), 0.6), + (re.compile(r"\balways\b", re.I), 0.5), + (re.compile(r"\bnever\b", re.I), 0.5), +] + + +def classify_intent(query: str) -> QueryIntent: + """Classify a search query as episodic, semantic, or mixed. + + Returns QueryIntent enum based on regex pattern matching. + Zero LLM cost, sub-millisecond execution. + """ + if not query or not query.strip(): + return QueryIntent.MIXED + + episodic_score = 0.0 + semantic_score = 0.0 + + for pattern, weight in _EPISODIC_PATTERNS: + if pattern.search(query): + episodic_score += weight + + for pattern, weight in _SEMANTIC_PATTERNS: + if pattern.search(query): + semantic_score += weight + + if episodic_score == 0.0 and semantic_score == 0.0: + return QueryIntent.MIXED + + # Require clear dominance (>1.5x) to declare a specific intent + if episodic_score > semantic_score * 1.5: + return QueryIntent.EPISODIC + if semantic_score > episodic_score * 1.5: + return QueryIntent.SEMANTIC + + return QueryIntent.MIXED diff --git a/engram/core/kernel.py b/engram/core/kernel.py index b2991b8..73b8cc7 100644 --- a/engram/core/kernel.py +++ b/engram/core/kernel.py @@ -1312,6 +1312,52 @@ def run_sleep_cycle( if apply_decay: user_stats["decay"] = self.memory.apply_decay(scope={"user_id": uid}) + + # CLS Distillation: replay distillation + trace cascade during sleep + distillation_config = getattr(self.memory.config, "distillation", None) + if distillation_config: + # Gap 2: Replay distillation + if distillation_config.enable_distillation: + try: + from engram.core.distillation import ReplayDistiller + distiller = ReplayDistiller( + db=self.db, + llm=self.memory.llm, + config=distillation_config, + ) + user_stats["distillation"] = distiller.run( + user_id=uid, + date_str=target_date, + memory_add_fn=self.memory.add, + ) + except Exception as e: + user_stats["distillation"] = {"error": str(e)} + + # Gap 4: Cascade traces (deep sleep) + if distillation_config.enable_multi_trace: + try: + from engram.core.traces import cascade_traces, compute_effective_strength + traced_memories = self.db.get_all_memories( + user_id=uid, + ) + cascade_count = 0 + for mem in traced_memories: + if mem.get("s_fast") is None: + continue + s_f, s_m, s_s = cascade_traces( + s_fast=float(mem.get("s_fast", 0.0)), + s_mid=float(mem.get("s_mid", 0.0)), + s_slow=float(mem.get("s_slow", 0.0)), + config=distillation_config, + deep_sleep=True, + ) + eff = compute_effective_strength(s_f, s_m, s_s, distillation_config) + self.db.update_multi_trace(mem["id"], s_f, s_m, s_s, eff) + cascade_count += 1 + user_stats["trace_cascades"] = cascade_count + except Exception as e: + user_stats["trace_cascades"] = {"error": str(e)} + summary["users"][uid] = user_stats if cleanup_stale_refs: diff --git a/engram/core/provenance.py b/engram/core/provenance.py index 920e99b..2607f2d 100644 --- a/engram/core/provenance.py +++ b/engram/core/provenance.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, asdict -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, Optional @@ -19,7 +19,7 @@ class Provenance: def to_dict(self) -> Dict[str, Any]: data = asdict(self) if not data["created_at"]: - data["created_at"] = datetime.utcnow().isoformat() + data["created_at"] = datetime.now(timezone.utc).isoformat() return data diff --git a/engram/core/traces.py b/engram/core/traces.py new file mode 100644 index 0000000..bc6287e --- /dev/null +++ b/engram/core/traces.py @@ -0,0 +1,112 @@ +"""Benna-Fusi inspired multi-timescale strength traces. + +Each memory has three traces (fast, mid, slow) that decay at different rates +and cascade information from fast → mid → slow during sleep cycles. +This mimics how synaptic plasticity operates at multiple timescales in biological memory. +""" + +from __future__ import annotations + +import math +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Tuple + +if TYPE_CHECKING: + from engram.configs.base import DistillationConfig + + +def initialize_traces( + strength: float, is_new: bool = True +) -> Tuple[float, float, float]: + """Initialize (s_fast, s_mid, s_slow) for a memory. + + New memories: all strength in fast trace. + Migrated memories: spread across fast and mid. + """ + strength = max(0.0, min(1.0, float(strength))) + if is_new: + return (strength, 0.0, 0.0) + return (strength, strength * 0.5, 0.0) + + +def compute_effective_strength( + s_fast: float, s_mid: float, s_slow: float, config: "DistillationConfig" +) -> float: + """Weighted combination of three traces into a single effective strength.""" + effective = ( + config.s_fast_weight * s_fast + + config.s_mid_weight * s_mid + + config.s_slow_weight * s_slow + ) + return max(0.0, min(1.0, effective)) + + +def decay_traces( + s_fast: float, + s_mid: float, + s_slow: float, + last_accessed: datetime, + access_count: int, + config: "DistillationConfig", +) -> Tuple[float, float, float]: + """Decay each trace independently at its own rate. + + Access count provides dampening (more accessed = slower decay), + mirroring the access-dampened decay in FadeMem. + """ + if isinstance(last_accessed, str): + last_accessed = datetime.fromisoformat(last_accessed) + if last_accessed.tzinfo is None: + last_accessed = last_accessed.replace(tzinfo=timezone.utc) + + elapsed_days = (datetime.now(timezone.utc) - last_accessed).total_seconds() / 86400.0 + dampening = 1.0 + 0.5 * math.log1p(access_count) + + new_fast = s_fast * math.exp(-config.s_fast_decay_rate * elapsed_days / dampening) + new_mid = s_mid * math.exp(-config.s_mid_decay_rate * elapsed_days / dampening) + new_slow = s_slow * math.exp(-config.s_slow_decay_rate * elapsed_days / dampening) + + return ( + max(0.0, min(1.0, new_fast)), + max(0.0, min(1.0, new_mid)), + max(0.0, min(1.0, new_slow)), + ) + + +def cascade_traces( + s_fast: float, + s_mid: float, + s_slow: float, + config: "DistillationConfig", + deep_sleep: bool = False, +) -> Tuple[float, float, float]: + """Transfer strength from faster traces to slower traces. + + Normal: fast → mid transfer only. + Deep sleep: fast → mid AND mid → slow transfer. + """ + fast_to_mid = s_fast * config.cascade_fast_to_mid + new_fast = s_fast - fast_to_mid + new_mid = s_mid + fast_to_mid + + if deep_sleep: + mid_to_slow = new_mid * config.cascade_mid_to_slow + new_mid = new_mid - mid_to_slow + new_slow = s_slow + mid_to_slow + else: + new_slow = s_slow + + return ( + max(0.0, min(1.0, new_fast)), + max(0.0, min(1.0, new_mid)), + max(0.0, min(1.0, new_slow)), + ) + + +def boost_fast_trace(s_fast: float, boost: float) -> float: + """On access, only the fast trace gets boosted (not mid/slow). + + This models how recent retrieval strengthens short-term plasticity + without directly affecting consolidated long-term traces. + """ + return max(0.0, min(1.0, s_fast + boost)) diff --git a/engram/db/async_sqlite.py b/engram/db/async_sqlite.py index 5303b44..c792ca5 100644 --- a/engram/db/async_sqlite.py +++ b/engram/db/async_sqlite.py @@ -8,7 +8,7 @@ import json import os from contextlib import asynccontextmanager -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, List, Optional try: @@ -59,7 +59,10 @@ async def initialize(self) -> None: embedding TEXT, related_memories TEXT DEFAULT '[]', source_memories TEXT DEFAULT '[]', - tombstone INTEGER DEFAULT 0 + tombstone INTEGER DEFAULT 0, + namespace TEXT DEFAULT 'default', + confidentiality_scope TEXT DEFAULT 'work', + importance REAL DEFAULT 0.5 ); CREATE INDEX IF NOT EXISTS idx_user_layer ON memories(user_id, layer); @@ -121,9 +124,11 @@ async def initialize(self) -> None: @asynccontextmanager async def _get_connection(self): - """Get a database connection.""" + """Get a database connection with WAL mode and busy timeout.""" conn = await aiosqlite.connect(self.db_path) conn.row_factory = aiosqlite.Row + await conn.execute("PRAGMA journal_mode=WAL") + await conn.execute("PRAGMA busy_timeout=5000") try: yield conn finally: @@ -149,7 +154,7 @@ async def add_memory( await self.initialize() async with self._get_connection() as conn: - now = datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() await conn.execute( """ INSERT INTO memories ( @@ -245,7 +250,7 @@ async def update_memory( await self.initialize() updates = ["updated_at = ?"] - params = [datetime.utcnow().isoformat()] + params = [datetime.now(timezone.utc).isoformat()] if content is not None: updates.append("memory = ?") @@ -296,7 +301,7 @@ async def increment_access(self, memory_id: str) -> int: await self.initialize() async with self._get_connection() as conn: - now = datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() await conn.execute( """ UPDATE memories @@ -437,3 +442,7 @@ def _row_to_dict(self, row) -> Dict[str, Any]: result["embedding"] = None return result + + async def close(self) -> None: + """No-op for per-call connection model. Exists for interface compatibility.""" + pass diff --git a/engram/db/sqlite.py b/engram/db/sqlite.py index eba54d3..035ee98 100644 --- a/engram/db/sqlite.py +++ b/engram/db/sqlite.py @@ -18,6 +18,7 @@ "decay_lambda", "status", "importance", "sensitivity", "namespace", "access_count", "last_accessed", "immutable", "expiration_date", "scene_id", "user_id", "agent_id", "run_id", "app_id", + "memory_type", "s_fast", "s_mid", "s_slow", }) VALID_SCENE_COLUMNS = frozenset({ @@ -32,6 +33,24 @@ "embedding", "strength", "updated_at", "role_bias", "profile_summary", }) +VALID_PROPOSAL_COMMIT_COLUMNS = frozenset({ + "status", "checks", "preview", "provenance", "updated_at", +}) + +VALID_HANDOFF_SESSION_COLUMNS = frozenset({ + "status", "task_summary", "decisions_made", "files_touched", + "todos_remaining", "context_snapshot", "linked_memory_ids", + "linked_scene_ids", "ended_at", "updated_at", "repo_id", + "blockers", "key_commands", "test_results", "lane_id", + "last_checkpoint_at", "namespace", "confidentiality_scope", +}) + +VALID_HANDOFF_LANE_COLUMNS = frozenset({ + "status", "objective", "current_state", "namespace", + "confidentiality_scope", "last_checkpoint_at", "version", + "branch", "lane_type", "repo_id", "repo_path", "updated_at", +}) + def _utcnow() -> datetime: """Return current UTC datetime (timezone-aware).""" @@ -53,7 +72,7 @@ def __init__(self, db_path: str): self._conn = sqlite3.connect(db_path, check_same_thread=False) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA busy_timeout=5000") - self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA synchronous=FULL") self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache self._conn.execute("PRAGMA temp_store=MEMORY") self._conn.row_factory = sqlite3.Row @@ -546,6 +565,29 @@ def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: ); CREATE INDEX IF NOT EXISTS idx_handoff_conflicts_lane ON handoff_lane_conflicts(lane_id, created_at DESC); """, + "v2_013": """ + CREATE TABLE IF NOT EXISTS distillation_provenance ( + id TEXT PRIMARY KEY, + semantic_memory_id TEXT NOT NULL, + episodic_memory_id TEXT NOT NULL, + distillation_run_id TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_distill_prov_semantic ON distillation_provenance(semantic_memory_id); + CREATE INDEX IF NOT EXISTS idx_distill_prov_episodic ON distillation_provenance(episodic_memory_id); + CREATE INDEX IF NOT EXISTS idx_distill_prov_run ON distillation_provenance(distillation_run_id); + + CREATE TABLE IF NOT EXISTS distillation_log ( + id TEXT PRIMARY KEY, + run_at TEXT DEFAULT CURRENT_TIMESTAMP, + user_id TEXT, + episodes_sampled INTEGER DEFAULT 0, + semantic_created INTEGER DEFAULT 0, + semantic_deduplicated INTEGER DEFAULT 0, + errors INTEGER DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_distill_log_user ON distillation_log(user_id, run_at DESC); + """, } for version, ddl in migrations.items(): @@ -734,6 +776,32 @@ def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_columns_complete')" ) + # CLS Distillation Memory columns (idempotent). + self._ensure_cls_columns(conn) + + def _ensure_cls_columns(self, conn: sqlite3.Connection) -> None: + """Add CLS Distillation Memory columns to memories table (idempotent).""" + if self._is_migration_applied(conn, "v2_cls_columns_complete"): + return + + self._migrate_add_column_conn(conn, "memories", "memory_type", "TEXT DEFAULT 'semantic'") + self._migrate_add_column_conn(conn, "memories", "s_fast", "REAL DEFAULT NULL") + self._migrate_add_column_conn(conn, "memories", "s_mid", "REAL DEFAULT NULL") + self._migrate_add_column_conn(conn, "memories", "s_slow", "REAL DEFAULT NULL") + + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_memories_memory_type ON memories(memory_type, user_id)" + ) + + # Backfill: set memory_type to 'semantic' for existing memories. + conn.execute( + "UPDATE memories SET memory_type = 'semantic' WHERE memory_type IS NULL" + ) + + conn.execute( + "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_cls_columns_complete')" + ) + def _seed_default_namespaces(self, conn: sqlite3.Connection) -> None: users = conn.execute( """ @@ -890,8 +958,9 @@ def add_memory(self, memory_data: Dict[str, Any]) -> str: created_at, updated_at, layer, strength, access_count, last_accessed, embedding, related_memories, source_memories, tombstone, confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, - status, importance, sensitivity - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + status, importance, sensitivity, + memory_type, s_fast, s_mid, s_slow + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( memory_id, @@ -923,6 +992,10 @@ def add_memory(self, memory_data: Dict[str, Any]) -> str: memory_data.get("status", "active"), memory_data.get("importance", metadata.get("importance", 0.5)), memory_data.get("sensitivity", metadata.get("sensitivity", "normal")), + memory_data.get("memory_type", "semantic"), + memory_data.get("s_fast"), + memory_data.get("s_mid"), + memory_data.get("s_slow"), ), ) @@ -1007,6 +1080,7 @@ def get_all_memories( include_tombstoned: bool = False, created_after: Optional[str] = None, created_before: Optional[str] = None, + limit: Optional[int] = None, ) -> List[Dict[str, Any]]: query = "SELECT * FROM memories WHERE strength >= ?" params: List[Any] = [min_strength] @@ -1040,6 +1114,11 @@ def get_all_memories( query += " ORDER BY strength DESC" + # Apply SQL-level LIMIT to avoid fetching unbounded rows into memory. + if limit is not None and limit > 0: + query += " LIMIT ?" + params.append(limit) + with self._get_connection() as conn: rows = conn.execute(query, params).fetchall() return [self._row_to_dict(row) for row in rows] @@ -1216,9 +1295,129 @@ def log_decay(self, decayed: int, forgotten: int, promoted: int, storage_before_ ) def purge_tombstoned(self) -> int: + """Permanently delete all tombstoned memories. This is IRREVERSIBLE.""" with self._get_connection() as conn: - cursor = conn.execute("DELETE FROM memories WHERE tombstone = 1") - return cursor.rowcount + # Log what will be purged before deletion for audit trail. + rows = conn.execute( + "SELECT id, user_id, memory FROM memories WHERE tombstone = 1" + ).fetchall() + count = len(rows) + if count > 0: + ids = [row["id"] for row in rows] + logger.warning( + "purge_tombstoned: permanently deleting %d memories: %s", + count, + ids, + ) + for row in rows: + conn.execute( + """INSERT INTO memory_history (memory_id, event, old_value, new_value, + old_strength, new_strength, old_layer, new_layer) + VALUES (?, ?, ?, NULL, NULL, NULL, NULL, NULL)""", + (row["id"], "PURGE", row["memory"]), + ) + conn.execute("DELETE FROM memories WHERE tombstone = 1") + return count + + # CLS Distillation Memory helpers + + def get_episodic_memories( + self, + user_id: str, + *, + scene_id: Optional[str] = None, + created_after: Optional[str] = None, + created_before: Optional[str] = None, + limit: int = 100, + namespace: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Fetch episodic-type memories for a user, optionally filtered by scene/time.""" + query = "SELECT * FROM memories WHERE user_id = ? AND memory_type = 'episodic' AND tombstone = 0" + params: List[Any] = [user_id] + if scene_id: + query += " AND scene_id = ?" + params.append(scene_id) + if created_after: + query += " AND created_at >= ?" + params.append(created_after) + if created_before: + query += " AND created_at <= ?" + params.append(created_before) + if namespace: + query += " AND namespace = ?" + params.append(namespace) + query += " ORDER BY created_at DESC LIMIT ?" + params.append(limit) + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._row_to_dict(row) for row in rows] + + def add_distillation_provenance( + self, + semantic_memory_id: str, + episodic_memory_ids: List[str], + run_id: str, + ) -> None: + """Record which episodic memories contributed to a distilled semantic memory.""" + with self._get_connection() as conn: + for ep_id in episodic_memory_ids: + conn.execute( + """ + INSERT INTO distillation_provenance (id, semantic_memory_id, episodic_memory_id, distillation_run_id) + VALUES (?, ?, ?, ?) + """, + (str(uuid.uuid4()), semantic_memory_id, ep_id, run_id), + ) + + def log_distillation_run( + self, + user_id: str, + episodes_sampled: int, + semantic_created: int, + semantic_deduplicated: int = 0, + errors: int = 0, + ) -> str: + """Log a distillation run and return the run ID.""" + run_id = str(uuid.uuid4()) + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO distillation_log (id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors) + VALUES (?, ?, ?, ?, ?, ?) + """, + (run_id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors), + ) + return run_id + + def get_memory_count_by_namespace(self, user_id: str) -> Dict[str, int]: + """Return {namespace: count} for active memories of a user.""" + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT COALESCE(namespace, 'default') AS ns, COUNT(*) AS cnt + FROM memories + WHERE user_id = ? AND tombstone = 0 + GROUP BY ns + """, + (user_id,), + ).fetchall() + return {row["ns"]: row["cnt"] for row in rows} + + def update_multi_trace( + self, + memory_id: str, + s_fast: float, + s_mid: float, + s_slow: float, + effective_strength: float, + ) -> bool: + """Update multi-trace columns and effective strength for a memory.""" + return self.update_memory(memory_id, { + "s_fast": s_fast, + "s_mid": s_mid, + "s_slow": s_slow, + "strength": effective_strength, + }) # CategoryMem methods def save_category(self, category_data: Dict[str, Any]) -> str: @@ -1285,12 +1484,48 @@ def delete_category(self, category_id: str) -> bool: return True def save_all_categories(self, categories: List[Dict[str, Any]]) -> int: - """Save multiple categories (batch operation).""" - count = 0 + """Save multiple categories in a single transaction for performance.""" + if not categories: + return 0 + rows = [] for cat in categories: - self.save_category(cat) - count += 1 - return count + cat_id = cat.get("id") + if not cat_id: + continue + rows.append(( + cat_id, + cat.get("name", ""), + cat.get("description", ""), + cat.get("category_type", "dynamic"), + cat.get("parent_id"), + json.dumps(cat.get("children_ids", [])), + cat.get("memory_count", 0), + cat.get("total_strength", 0.0), + cat.get("access_count", 0), + cat.get("last_accessed"), + cat.get("created_at"), + json.dumps(cat.get("embedding")) if cat.get("embedding") else None, + json.dumps(cat.get("keywords", [])), + cat.get("summary"), + cat.get("summary_updated_at"), + json.dumps(cat.get("related_ids", [])), + cat.get("strength", 1.0), + )) + if not rows: + return 0 + with self._get_connection() as conn: + conn.executemany( + """ + INSERT OR REPLACE INTO categories ( + id, name, description, category_type, parent_id, + children_ids, memory_count, total_strength, access_count, + last_accessed, created_at, embedding, keywords, + summary, summary_updated_at, related_ids, strength + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + rows, + ) + return len(rows) def _category_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: """Convert a category row to dict.""" @@ -1807,10 +2042,14 @@ def update_proposal_commit(self, commit_id: str, updates: Dict[str, Any]) -> boo set_clauses = [] params: List[Any] = [] for key, value in updates.items(): + if key not in VALID_PROPOSAL_COMMIT_COLUMNS: + continue if key in {"checks", "preview", "provenance"}: value = json.dumps(value) set_clauses.append(f"{key} = ?") params.append(value) + if not set_clauses: + return False set_clauses.append("updated_at = ?") params.append(_utcnow_iso()) params.append(commit_id) @@ -1836,6 +2075,8 @@ def transition_proposal_commit_status( set_clauses = ["status = ?", "updated_at = ?"] params: List[Any] = [str(to_status or "").upper(), _utcnow_iso()] for key, value in (updates or {}).items(): + if key not in VALID_PROPOSAL_COMMIT_COLUMNS: + continue if key in {"checks", "preview", "provenance"}: value = json.dumps(value) set_clauses.append(f"{key} = ?") @@ -2086,12 +2327,28 @@ def add_memory_subscriber( expires_at, ), ) - if existing is None: - self.adjust_memory_refcount( - memory_id, - strong_delta=1 if ref_type == "strong" else 0, - weak_delta=1 if ref_type == "weak" else 0, - ) + if existing is None: + strong_delta = 1 if ref_type == "strong" else 0 + weak_delta = 1 if ref_type == "weak" else 0 + conn.execute( + """ + INSERT INTO memory_refcounts (memory_id, strong_count, weak_count) + VALUES (?, 0, 0) + ON CONFLICT(memory_id) DO NOTHING + """, + (memory_id,), + ) + conn.execute( + """ + UPDATE memory_refcounts + SET + strong_count = CASE WHEN strong_count + ? < 0 THEN 0 ELSE strong_count + ? END, + weak_count = CASE WHEN weak_count + ? < 0 THEN 0 ELSE weak_count + ? END, + updated_at = ? + WHERE memory_id = ? + """, + (strong_delta, strong_delta, weak_delta, weak_delta, _utcnow_iso(), memory_id), + ) def remove_memory_subscriber(self, memory_id: str, subscriber: str, ref_type: str = "weak") -> None: with self._get_connection() as conn: @@ -2102,12 +2359,20 @@ def remove_memory_subscriber(self, memory_id: str, subscriber: str, ref_type: st """, (memory_id, subscriber, ref_type), ) - if cursor.rowcount > 0: - self.adjust_memory_refcount( - memory_id, - strong_delta=-1 if ref_type == "strong" else 0, - weak_delta=-1 if ref_type == "weak" else 0, - ) + if cursor.rowcount > 0: + strong_delta = -1 if ref_type == "strong" else 0 + weak_delta = -1 if ref_type == "weak" else 0 + conn.execute( + """ + UPDATE memory_refcounts + SET + strong_count = CASE WHEN strong_count + ? < 0 THEN 0 ELSE strong_count + ? END, + weak_count = CASE WHEN weak_count + ? < 0 THEN 0 ELSE weak_count + ? END, + updated_at = ? + WHERE memory_id = ? + """, + (strong_delta, strong_delta, weak_delta, weak_delta, _utcnow_iso(), memory_id), + ) def list_memory_subscribers(self, memory_id: str) -> List[str]: with self._get_connection() as conn: @@ -2930,6 +3195,8 @@ def update_handoff_session(self, session_id: str, updates: Dict[str, Any]) -> bo "linked_scene_ids", } for key, value in updates.items(): + if key not in VALID_HANDOFF_SESSION_COLUMNS: + continue if key in json_fields: value = json.dumps(value) set_clauses.append(f"{key} = ?") @@ -3080,6 +3347,8 @@ def update_handoff_lane( set_clauses = [] params: List[Any] = [] for key, value in updates.items(): + if key not in VALID_HANDOFF_LANE_COLUMNS: + continue if key == "current_state" and not isinstance(value, str): value = json.dumps(value) set_clauses.append(f"{key} = ?") diff --git a/engram/embeddings/gemini.py b/engram/embeddings/gemini.py index af10c2d..9944498 100644 --- a/engram/embeddings/gemini.py +++ b/engram/embeddings/gemini.py @@ -1,8 +1,11 @@ +import logging import os from typing import List, Optional from engram.embeddings.base import BaseEmbedder +logger = logging.getLogger(__name__) + class GeminiEmbedder(BaseEmbedder): def __init__(self, config: Optional[dict] = None): @@ -23,7 +26,7 @@ def __init__(self, config: Optional[dict] = None): genai.configure(api_key=self.api_key) self._client_type = "generativeai" self._genai = genai - except Exception: + except ImportError: try: from google import genai @@ -35,27 +38,37 @@ def __init__(self, config: Optional[dict] = None): ) from exc def embed(self, text: str, memory_action: Optional[str] = None) -> List[float]: - if self._client_type == "generativeai": - response = self._genai.embed_content( - model=self.model, - content=text, - ) - embedding = response.get("embedding") if isinstance(response, dict) else getattr(response, "embedding", None) - return embedding or [] + try: + if self._client_type == "generativeai": + response = self._genai.embed_content( + model=self.model, + content=text, + ) + embedding = response.get("embedding") if isinstance(response, dict) else getattr(response, "embedding", None) + if not embedding: + raise RuntimeError(f"Gemini embedding returned empty result (model={self.model})") + return embedding - if self._client_type == "genai": - response = self._client.models.embed_content( - model=self.model, - contents=text, - ) - return _extract_embedding_from_response(response) + if self._client_type == "genai": + response = self._client.models.embed_content( + model=self.model, + contents=text, + ) + return _extract_embedding_from_response(response) - return [] + raise RuntimeError("Gemini embedder not initialized") + except RuntimeError: + raise + except Exception as exc: + logger.error("Gemini embedding failed (model=%s): %s", self.model, exc) + raise RuntimeError( + f"Gemini embedding failed (model={self.model}): {exc}" + ) from exc def _extract_embedding_from_response(response) -> List[float]: if response is None: - return [] + raise RuntimeError("Gemini embedding response was None") embedding = getattr(response, "embedding", None) if embedding: return embedding @@ -65,4 +78,4 @@ def _extract_embedding_from_response(response) -> List[float]: vector = getattr(first, "values", None) or getattr(first, "embedding", None) if vector: return vector - return [] + raise RuntimeError("Gemini embedding response contained no embedding data") diff --git a/engram/embeddings/nvidia.py b/engram/embeddings/nvidia.py index 524dd56..270cad6 100644 --- a/engram/embeddings/nvidia.py +++ b/engram/embeddings/nvidia.py @@ -1,11 +1,14 @@ +import logging import os from typing import List, Optional from engram.embeddings.base import BaseEmbedder +logger = logging.getLogger(__name__) + class NvidiaEmbedder(BaseEmbedder): - """Embedding provider for NVIDIA API (OpenAI-compatible). Default model: nv-embedqa-e5-v5.""" + """Embedding provider for NVIDIA API (OpenAI-compatible). Default model: nv-embed-v1.""" def __init__(self, config: Optional[dict] = None): super().__init__(config) @@ -14,27 +17,39 @@ def __init__(self, config: Optional[dict] = None): except Exception as exc: raise ImportError("openai package is required for NvidiaEmbedder") from exc - api_key = self.config.get("api_key") + api_key = ( + self.config.get("api_key") + or os.getenv("NVIDIA_EMBEDDING_API_KEY") + or os.getenv("NVIDIA_API_KEY") + ) if not api_key: raise ValueError( - "NVIDIA API key required. Set config['api_key'] or NVIDIA_API_KEY env var." + "NVIDIA API key required. Set config['api_key'], " + "NVIDIA_EMBEDDING_API_KEY, or NVIDIA_API_KEY env var." ) base_url = self.config.get("base_url", "https://integrate.api.nvidia.com/v1") - self.client = OpenAI(base_url=base_url, api_key=api_key) - self.model = self.config.get("model", "nvidia/nv-embedqa-e5-v5") + timeout = self.config.get("timeout", 60) + self.client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout) + self.model = self.config.get("model", "nvidia/nv-embed-v1") def embed(self, text: str, memory_action: Optional[str] = None) -> List[float]: - # NVIDIA embedding models distinguish between passage and query input types - if memory_action in ("search", "forget"): - input_type = "query" - else: - input_type = "passage" - - response = self.client.embeddings.create( - input=[text], - model=self.model, - encoding_format="float", - extra_body={"input_type": input_type, "truncate": "NONE"}, - ) - return response.data[0].embedding + try: + extra_body = {} + # nv-embed-v1 does not use input_type; older E5 models do + if "e5" in self.model or "embedqa" in self.model: + input_type = "query" if memory_action in ("search", "forget") else "passage" + extra_body = {"input_type": input_type, "truncate": "NONE"} + + response = self.client.embeddings.create( + input=[text], + model=self.model, + encoding_format="float", + **({"extra_body": extra_body} if extra_body else {}), + ) + return response.data[0].embedding + except Exception as exc: + logger.error("NVIDIA embedding failed (model=%s): %s", self.model, exc) + raise RuntimeError( + f"NVIDIA embedding failed (model={self.model}): {exc}" + ) from exc diff --git a/engram/embeddings/openai.py b/engram/embeddings/openai.py index 40383c9..1400bfa 100644 --- a/engram/embeddings/openai.py +++ b/engram/embeddings/openai.py @@ -1,7 +1,10 @@ +import logging from typing import List, Optional from engram.embeddings.base import BaseEmbedder +logger = logging.getLogger(__name__) + class OpenAIEmbedder(BaseEmbedder): def __init__(self, config: Optional[dict] = None): @@ -10,9 +13,16 @@ def __init__(self, config: Optional[dict] = None): from openai import OpenAI except Exception as exc: raise ImportError("openai package is required for OpenAIEmbedder") from exc - self.client = OpenAI() + timeout = self.config.get("timeout", 60) + self.client = OpenAI(timeout=timeout) self.model = self.config.get("model", "text-embedding-3-small") def embed(self, text: str, memory_action: Optional[str] = None) -> List[float]: - response = self.client.embeddings.create(model=self.model, input=text) - return response.data[0].embedding + try: + response = self.client.embeddings.create(model=self.model, input=text) + return response.data[0].embedding + except Exception as exc: + logger.error("OpenAI embedding failed (model=%s): %s", self.model, exc) + raise RuntimeError( + f"OpenAI embedding failed (model={self.model}): {exc}" + ) from exc diff --git a/engram/llms/gemini.py b/engram/llms/gemini.py index 969f442..5a67fd2 100644 --- a/engram/llms/gemini.py +++ b/engram/llms/gemini.py @@ -1,8 +1,11 @@ +import logging import os from typing import Optional from engram.llms.base import BaseLLM +logger = logging.getLogger(__name__) + class GeminiLLM(BaseLLM): def __init__(self, config: Optional[dict] = None): @@ -26,7 +29,7 @@ def __init__(self, config: Optional[dict] = None): self._client_type = "generativeai" self._genai = genai self._model = genai.GenerativeModel(self.model) - except Exception: + except ImportError: try: from google import genai @@ -38,28 +41,36 @@ def __init__(self, config: Optional[dict] = None): ) from exc def generate(self, prompt: str) -> str: - if self._client_type == "generativeai": - response = self._model.generate_content( - prompt, - generation_config={ - "temperature": self.temperature, - "max_output_tokens": self.max_tokens, - }, - ) - return getattr(response, "text", "") or "" + try: + if self._client_type == "generativeai": + response = self._model.generate_content( + prompt, + generation_config={ + "temperature": self.temperature, + "max_output_tokens": self.max_tokens, + }, + ) + return getattr(response, "text", "") or "" - if self._client_type == "genai": - response = self._client.models.generate_content( - model=self.model, - contents=prompt, - config={ - "temperature": self.temperature, - "max_output_tokens": self.max_tokens, - }, - ) - return _extract_text_from_response(response) + if self._client_type == "genai": + response = self._client.models.generate_content( + model=self.model, + contents=prompt, + config={ + "temperature": self.temperature, + "max_output_tokens": self.max_tokens, + }, + ) + return _extract_text_from_response(response) - return "" + raise RuntimeError("Gemini LLM client not initialized") + except RuntimeError: + raise + except Exception as exc: + logger.error("Gemini LLM generate failed (model=%s): %s", self.model, exc) + raise RuntimeError( + f"Gemini LLM generation failed (model={self.model}): {exc}" + ) from exc def _extract_text_from_response(response) -> str: diff --git a/engram/llms/nvidia.py b/engram/llms/nvidia.py index 6b5cb2b..d91812c 100644 --- a/engram/llms/nvidia.py +++ b/engram/llms/nvidia.py @@ -1,11 +1,14 @@ +import logging import os from typing import Optional from engram.llms.base import BaseLLM +logger = logging.getLogger(__name__) + class NvidiaLLM(BaseLLM): - """LLM provider for NVIDIA API (OpenAI-compatible). Default model: Kimi K2.5.""" + """LLM provider for NVIDIA API (OpenAI-compatible). Default model: Llama 3.1 8B Instruct.""" def __init__(self, config: Optional[dict] = None): super().__init__(config) @@ -14,34 +17,46 @@ def __init__(self, config: Optional[dict] = None): except Exception as exc: raise ImportError("openai package is required for NvidiaLLM") from exc - api_key = self.config.get("api_key") + api_key = ( + self.config.get("api_key") + or os.getenv("LLAMA_API_KEY") + or os.getenv("NVIDIA_API_KEY") + ) if not api_key: raise ValueError( - "NVIDIA API key required. Set config['api_key'] or NVIDIA_API_KEY env var." + "NVIDIA API key required. Set config['api_key'], " + "LLAMA_API_KEY, or NVIDIA_API_KEY env var." ) base_url = self.config.get("base_url", "https://integrate.api.nvidia.com/v1") - self.client = OpenAI(base_url=base_url, api_key=api_key) - self.model = self.config.get("model", "moonshotai/kimi-k2.5") - self.temperature = self.config.get("temperature", 1.0) - self.max_tokens = self.config.get("max_tokens", 16384) + timeout = self.config.get("timeout", 60) + self.client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout) + self.model = self.config.get("model", "meta/llama-3.1-8b-instruct") + self.temperature = self.config.get("temperature", 0.2) + self.max_tokens = self.config.get("max_tokens", 1024) self.top_p = self.config.get("top_p", 0.7) self.enable_thinking = self.config.get("enable_thinking", False) def generate(self, prompt: str) -> str: - extra_kwargs = {} - if self.enable_thinking: - extra_kwargs["extra_body"] = { - "chat_template_kwargs": {"enable_thinking": True} - } - - response = self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - temperature=self.temperature, - top_p=self.top_p, - max_tokens=self.max_tokens, - stream=False, - **extra_kwargs, - ) - return response.choices[0].message.content + try: + extra_kwargs = {} + if self.enable_thinking: + extra_kwargs["extra_body"] = { + "chat_template_kwargs": {"enable_thinking": True} + } + + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=self.temperature, + top_p=self.top_p, + max_tokens=self.max_tokens, + stream=False, + **extra_kwargs, + ) + return response.choices[0].message.content or "" + except Exception as exc: + logger.error("NVIDIA LLM generate failed (model=%s): %s", self.model, exc) + raise RuntimeError( + f"NVIDIA LLM generation failed (model={self.model}): {exc}" + ) from exc diff --git a/engram/llms/openai.py b/engram/llms/openai.py index 752c5bd..510cbc3 100644 --- a/engram/llms/openai.py +++ b/engram/llms/openai.py @@ -1,7 +1,10 @@ +import logging from typing import Optional from engram.llms.base import BaseLLM +logger = logging.getLogger(__name__) + class OpenAILLM(BaseLLM): def __init__(self, config: Optional[dict] = None): @@ -10,16 +13,23 @@ def __init__(self, config: Optional[dict] = None): from openai import OpenAI except Exception as exc: raise ImportError("openai package is required for OpenAILLM") from exc - self.client = OpenAI() + timeout = self.config.get("timeout", 60) + self.client = OpenAI(timeout=timeout) self.model = self.config.get("model", "gpt-4o-mini") self.temperature = self.config.get("temperature", 0.1) self.max_tokens = self.config.get("max_tokens", 1000) def generate(self, prompt: str) -> str: - response = self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - return response.choices[0].message.content + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + return response.choices[0].message.content or "" + except Exception as exc: + logger.error("OpenAI LLM generate failed (model=%s): %s", self.model, exc) + raise RuntimeError( + f"OpenAI LLM generation failed (model={self.model}): {exc}" + ) from exc diff --git a/engram/mcp_server.py b/engram/mcp_server.py index c5bdcfa..4b83c85 100644 --- a/engram/mcp_server.py +++ b/engram/mcp_server.py @@ -166,13 +166,16 @@ def get_memory_instance() -> Memory: _lifecycle_state: Dict[str, Dict[str, Any]] = {} _idle_pause_seconds = max(1, int(os.environ.get("ENGRAM_MCP_IDLE_PAUSE_SECONDS", "300"))) _shutdown_hooks_registered = False +_shutdown_requested = False def get_memory() -> Memory: """Get or create the global memory instance.""" global _memory if _memory is None: - _memory = get_memory_instance() + with _lifecycle_lock: + if _memory is None: + _memory = get_memory_instance() return _memory @@ -185,7 +188,9 @@ def get_handoff_backend(memory: Memory): """Get or create the configured handoff backend.""" global _handoff_backend if _handoff_backend is None: - _handoff_backend = create_handoff_backend(memory) + with _lifecycle_lock: + if _handoff_backend is None: + _handoff_backend = create_handoff_backend(memory) return _handoff_backend @@ -202,6 +207,30 @@ def _merge_handoff_context(existing: Dict[str, Any], update: Dict[str, Any]) -> return merged +_LIFECYCLE_MAX_ENTRIES = 500 +_LIFECYCLE_MAX_AGE_SECONDS = 86400 # 24 hours + + +def _gc_lifecycle_state_locked() -> None: + """Evict stale runtime entries from _lifecycle_state. Called under _lifecycle_lock. + + NOTE: This only cleans ephemeral in-process handoff context, NOT persistent + memory data. Actual memories are safely stored in SQLite and vector stores. + """ + now = time.time() + expired = [k for k, v in _lifecycle_state.items() + if now - v.get("last_activity_ts", 0) > _LIFECYCLE_MAX_AGE_SECONDS] + for k in expired: + del _lifecycle_state[k] + if len(_lifecycle_state) > _LIFECYCLE_MAX_ENTRIES: + sorted_keys = sorted( + _lifecycle_state, + key=lambda k: _lifecycle_state[k].get("last_activity_ts", 0), + ) + for k in sorted_keys[:len(_lifecycle_state) - _LIFECYCLE_MAX_ENTRIES]: + del _lifecycle_state[k] + + def _record_handoff_context(context: Dict[str, Any]) -> None: user_id = context.get("user_id", "default") agent_id = context.get("agent_id", "claude-code") @@ -229,6 +258,9 @@ def _record_handoff_context(context: Dict[str, Any]) -> None: merged = _merge_handoff_context(existing, context) merged["last_activity_ts"] = now_ts _lifecycle_state[key] = merged + # Periodic cleanup to prevent unbounded growth + if len(_lifecycle_state) > _LIFECYCLE_MAX_ENTRIES: + _gc_lifecycle_state_locked() def _emit_lifecycle_checkpoint(memory: Memory, context: Dict[str, Any], *, event_type: str, task_summary: Optional[str]) -> Dict[str, Any]: @@ -266,8 +298,14 @@ def _flush_agent_end_checkpoints() -> None: memory = get_memory() except Exception: return - with _lifecycle_lock: + # Use non-blocking acquire so we never deadlock if a signal fired while + # the lock was already held on the same thread (atexit runs in-process). + acquired = _lifecycle_lock.acquire(blocking=False) + try: contexts = list(_lifecycle_state.values()) + finally: + if acquired: + _lifecycle_lock.release() for context in contexts: try: _emit_lifecycle_checkpoint( @@ -287,10 +325,11 @@ def _register_shutdown_hooks() -> None: atexit.register(_flush_agent_end_checkpoints) def _signal_handler(signum, _frame): # pragma: no cover - signal path - try: - _flush_agent_end_checkpoints() - finally: - raise SystemExit(0) + # Set a flag instead of acquiring _lifecycle_lock directly to avoid + # deadlock if the signal fires while _lifecycle_lock is already held. + global _shutdown_requested + _shutdown_requested = True + raise SystemExit(0) for sig_name in ("SIGTERM", "SIGINT"): sig_value = getattr(signal, sig_name, None) @@ -1241,7 +1280,10 @@ def _handle_apply_memory_decay(memory: "Memory", arguments: Dict[str, Any], _ses @_tool_handler("engram_context") def _handle_engram_context(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: user_id = arguments.get("user_id", "default") - limit = arguments.get("limit", 15) + try: + limit = max(1, min(100, int(arguments.get("limit", 15)))) + except (ValueError, TypeError): + limit = 15 all_result = memory.get_all(user_id=user_id, limit=limit * 3) all_memories = all_result.get("results", []) layer_order = {"lml": 0, "sml": 1} @@ -1296,7 +1338,10 @@ def _handle_list_profiles(memory: "Memory", arguments: Dict[str, Any], _session_ def _handle_search_profiles(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: query = arguments.get("query", "") user_id = arguments.get("user_id", "default") - limit = arguments.get("limit", 10) + try: + limit = max(1, min(100, int(arguments.get("limit", 10)))) + except (ValueError, TypeError): + limit = 10 profiles = memory.search_profiles(query=query, user_id=user_id, limit=limit) return { "profiles": [ @@ -1349,7 +1394,10 @@ def _handle_search_memory(memory: "Memory", arguments: Dict[str, Any], _session_ query = arguments.get("query", "") user_id = arguments.get("user_id", "default") agent_id = arguments.get("agent_id") - limit = arguments.get("limit", 10) + try: + limit = max(1, min(1000, int(arguments.get("limit", 10)))) + except (ValueError, TypeError): + limit = 10 categories = arguments.get("categories") if agent_id: token = _session_token(user_id=user_id, agent_id=agent_id, capabilities=["search"]) @@ -1382,10 +1430,14 @@ def _handle_search_memory(memory: "Memory", arguments: Dict[str, Any], _session_ @_tool_handler("get_all_memories") def _handle_get_all_memories(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + try: + limit = max(1, min(1000, int(arguments.get("limit", 50)))) + except (ValueError, TypeError): + limit = 50 result = memory.get_all( user_id=arguments.get("user_id", "default"), agent_id=arguments.get("agent_id"), - limit=arguments.get("limit", 50), + limit=limit, layer=arguments.get("layer"), ) if "results" in result: @@ -1433,7 +1485,8 @@ def _handle_list_pending_commits(memory: "Memory", arguments: Dict[str, Any], _s token = _session_token(user_id=user_id, agent_id=agent_id, capabilities=["review_commits"]) return memory.list_pending_commits( user_id=user_id, agent_id=agent_id, token=token, - status=arguments.get("status"), limit=arguments.get("limit", 100), + status=arguments.get("status"), + limit=max(1, min(1000, int(arguments.get("limit", 100)))) if arguments.get("limit") is not None else 100, ) @@ -1567,12 +1620,16 @@ def _handle_get_scene(memory: "Memory", arguments: Dict[str, Any], _session_toke @_tool_handler("list_scenes") def _handle_list_scenes(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + try: + scene_limit = max(1, min(200, int(arguments.get("limit", 20)))) + except (ValueError, TypeError): + scene_limit = 20 scenes = memory.get_scenes( user_id=arguments.get("user_id", "default"), topic=arguments.get("topic"), start_after=arguments.get("start_after"), start_before=arguments.get("start_before"), - limit=arguments.get("limit", 20), + limit=scene_limit, ) return { "scenes": [ @@ -1596,10 +1653,14 @@ def _handle_search_scenes(memory: "Memory", arguments: Dict[str, Any], _session_ user_id = arguments.get("user_id", "default") agent_id = arguments.get("agent_id", "claude-code") token = _session_token(user_id=user_id, agent_id=agent_id, capabilities=["read_scene"]) + try: + scene_search_limit = max(1, min(100, int(arguments.get("limit", 10)))) + except (ValueError, TypeError): + scene_search_limit = 10 payload = memory.kernel.search_scenes( query=arguments.get("query", ""), user_id=user_id, agent_id=agent_id, token=token, - limit=arguments.get("limit", 10), + limit=scene_search_limit, ) scenes = payload.get("scenes", []) return { @@ -1623,30 +1684,53 @@ def _handle_search_scenes(memory: "Memory", arguments: Dict[str, Any], _session_ # ---- Active Memory (signal bus) helpers and handlers ---- _active_store = None +_active_store_lock = threading.Lock() def _get_active_store(memory: Memory): """Lazy-initialize the global active memory store.""" global _active_store if _active_store is None: - if memory.config.active.enabled: - from engram.core.active_memory import ActiveMemoryStore - _active_store = ActiveMemoryStore(memory.config.active) + with _active_store_lock: + if _active_store is None: + if memory.config.active.enabled: + from engram.core.active_memory import ActiveMemoryStore + _active_store = ActiveMemoryStore(memory.config.active) return _active_store @_tool_handler("signal_write") def _handle_signal_write(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + # Validate required fields + key = arguments.get("key") + if key is None or not isinstance(key, str) or not key.strip(): + return {"error": "signal_write requires 'key' parameter"} + value = arguments.get("value") + if value is None or not isinstance(value, str): + return {"error": "signal_write requires 'value' parameter (string)"} + + # Validate signal_type enum + _VALID_SIGNAL_TYPES = {"state", "event", "directive"} + signal_type = arguments.get("signal_type", "state") + if signal_type not in _VALID_SIGNAL_TYPES: + return {"error": f"signal_write 'signal_type' must be one of {sorted(_VALID_SIGNAL_TYPES)}, got '{signal_type}'"} + + # Validate ttl_tier enum + _VALID_TTL_TIERS = {"noise", "notable", "critical", "directive"} + ttl_tier = arguments.get("ttl_tier", "notable") + if ttl_tier not in _VALID_TTL_TIERS: + return {"error": f"signal_write 'ttl_tier' must be one of {sorted(_VALID_TTL_TIERS)}, got '{ttl_tier}'"} + active = _get_active_store(memory) if not active: return {"error": "Active memory is disabled"} return active.write_signal( - key=arguments["key"], - value=arguments["value"], - signal_type=arguments.get("signal_type", "state"), + key=key, + value=value, + signal_type=signal_type, scope=arguments.get("scope", "global"), scope_key=arguments.get("scope_key"), - ttl_tier=arguments.get("ttl_tier", "notable"), + ttl_tier=ttl_tier, source_agent_id=arguments.get("agent_id"), user_id=arguments.get("user_id", "default"), ) @@ -1657,18 +1741,32 @@ def _handle_signal_read(memory: "Memory", arguments: Dict[str, Any], _session_to active = _get_active_store(memory) if not active: return {"error": "Active memory is disabled"} + raw_limit = arguments.get("limit") + if raw_limit is not None: + try: + limit = max(1, min(1000, int(raw_limit))) + except (ValueError, TypeError): + limit = None + else: + limit = None return active.read_signals( scope=arguments.get("scope"), scope_key=arguments.get("scope_key"), signal_type=arguments.get("signal_type"), user_id=arguments.get("user_id", "default"), reader_agent_id=arguments.get("agent_id"), - limit=arguments.get("limit"), + limit=limit, ) @_tool_handler("signal_clear") def _handle_signal_clear(memory: "Memory", arguments: Dict[str, Any], _session_token, _preview) -> Any: + # Validate signal_type enum if provided + _VALID_SIGNAL_TYPES = {"state", "event", "directive"} + signal_type = arguments.get("signal_type") + if signal_type is not None and signal_type not in _VALID_SIGNAL_TYPES: + return {"error": f"signal_clear 'signal_type' must be one of {sorted(_VALID_SIGNAL_TYPES)}, got '{signal_type}'"} + active = _get_active_store(memory) if not active: return {"error": "Active memory is disabled"} @@ -1677,7 +1775,7 @@ def _handle_signal_clear(memory: "Memory", arguments: Dict[str, Any], _session_t scope=arguments.get("scope"), scope_key=arguments.get("scope_key"), source_agent_id=arguments.get("agent_id"), - signal_type=arguments.get("signal_type"), + signal_type=signal_type, user_id=arguments.get("user_id", "default"), ) @@ -1900,7 +1998,7 @@ def _handoff_error_payload(exc: Exception) -> Dict[str, str]: repo=arguments.get("repo"), status=arguments.get("status"), statuses=arguments.get("statuses"), - limit=arguments.get("limit", 20), + limit=max(1, min(200, int(arguments.get("limit", 20)))) if arguments.get("limit") is not None else 20, ) result = { "sessions": [ @@ -1993,25 +2091,29 @@ def _handoff_error_payload(exc: Exception) -> Dict[str, str]: handoff_meta["resume"] = auto_resume_packet result["_handoff"] = handoff_meta - # Active memory auto-injection: attach latest signals to every response + # Active memory auto-injection: attach latest signals to every response. + # Use peek_signals (read-only) to avoid inflating read_count on every + # tool call; only explicit signal_read calls should bump the counter. if isinstance(result, dict): active_store = _get_active_store(memory) if active_store: try: - signals = active_store.read_signals( + signals = active_store.peek_signals( user_id=arguments.get("user_id", "default"), - reader_agent_id=arguments.get("agent_id"), limit=memory.config.active.max_signals_per_response, ) if signals: result["_active"] = signals - except Exception: - pass # Never break tool responses for active memory errors + except Exception as active_err: + logger.debug("Active memory injection failed: %s", active_err) return [TextContent(type="text", text=json.dumps(result, indent=2, default=str))] except Exception as e: - error_result = {"error": str(e)} + logger.exception("MCP tool '%s' failed", name) + # Sanitize error — only expose the exception class name + message, not internals + error_msg = f"{type(e).__name__}: {e}" + error_result = {"error": error_msg} return [TextContent(type="text", text=json.dumps(error_result, indent=2))] diff --git a/engram/memory/async_memory.py b/engram/memory/async_memory.py index bdaab8b..403b616 100644 --- a/engram/memory/async_memory.py +++ b/engram/memory/async_memory.py @@ -104,6 +104,8 @@ async def close(self) -> None: """Close all connections.""" if self._vector_store: await self._vector_store.close() + if self._db: + await self._db.close() @classmethod async def from_config(cls, config_dict: Dict[str, Any]) -> "AsyncMemory": diff --git a/engram/memory/client.py b/engram/memory/client.py index 4904e58..24507d1 100644 --- a/engram/memory/client.py +++ b/engram/memory/client.py @@ -112,7 +112,7 @@ def get(self, memory_id: str, **kwargs) -> Dict[str, Any]: def get_all(self, **kwargs) -> Dict[str, Any]: return self._request("GET", "/v1/memories/", params=kwargs) - def update(self, memory_id: str, data: str = None, metadata: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]: + def update(self, memory_id: str, data: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: payload: Dict[str, Any] = {} if data is not None: payload["data"] = data diff --git a/engram/memory/episodic_store.py b/engram/memory/episodic_store.py index e9c88e7..ec341ef 100644 --- a/engram/memory/episodic_store.py +++ b/engram/memory/episodic_store.py @@ -3,7 +3,7 @@ from __future__ import annotations import re -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Tuple @@ -39,7 +39,7 @@ def ingest_memory_as_view( timestamp: Optional[str] = None, ) -> Dict[str, Any]: metadata = metadata or {} - timestamp = timestamp or datetime.utcnow().isoformat() + timestamp = timestamp or datetime.now(timezone.utc).isoformat() namespace = str(metadata.get("namespace", "default") or "default").strip() or "default" place_type, place_value = self._extract_place(metadata) topic_label = self._extract_topic(content) diff --git a/engram/memory/main.py b/engram/memory/main.py index de1f80f..11405cb 100644 --- a/engram/memory/main.py +++ b/engram/memory/main.py @@ -17,9 +17,19 @@ ) from engram.core.decay import calculate_decayed_strength, should_forget, should_promote from engram.core.conflict import resolve_conflict +from engram.core.distillation import ReplayDistiller from engram.core.echo import EchoProcessor, EchoDepth, EchoResult +from engram.core.forgetting import HomeostaticNormalizer, InterferencePruner, RedundancyCollapser from engram.core.fusion import fuse_memories +from engram.core.intent import QueryIntent, classify_intent from engram.core.retrieval import composite_score, tokenize, HybridSearcher +from engram.core.traces import ( + boost_fast_trace, + cascade_traces, + compute_effective_strength, + decay_traces, + initialize_traces, +) from engram.core.category import CategoryProcessor, CategoryMatch from engram.core.graph import KnowledgeGraph from engram.core.scene import SceneProcessor @@ -94,6 +104,7 @@ def __init__(self, config: Optional[MemoryConfig] = None): self.fadem_config = self.config.engram self.echo_config = self.config.echo self.scope_config = getattr(self.config, "scope", None) + self.distillation_config = getattr(self.config, "distillation", None) # Initialize EchoMem processor if self.echo_config.enable_echo: @@ -219,6 +230,16 @@ def consolidate_active(self) -> Dict[str, Any]: engine = ConsolidationEngine(self.active, self, self.config.active) return engine.run_cycle() + def close(self) -> None: + """Release all resources held by the Memory instance.""" + if hasattr(self, '_active_store') and self._active_store is not None: + self._active_store.close() + self._active_store = None + if hasattr(self, 'vector_store') and self.vector_store is not None: + self.vector_store.close() + if hasattr(self, 'db') and self.db is not None: + self.db.close() + def __repr__(self) -> str: return f"Memory(db={self.db!r}, echo={self.echo_config.enable_echo}, scenes={self.scene_config.enable_scenes})" @@ -229,22 +250,22 @@ def from_config(cls, config_dict: Dict[str, Any]): def add( self, messages: Union[str, List[Dict[str, str]]], - user_id: str = None, - agent_id: str = None, - run_id: str = None, - app_id: str = None, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + app_id: Optional[str] = None, metadata: Dict[str, Any] = None, filters: Dict[str, Any] = None, categories: List[str] = None, immutable: bool = False, - expiration_date: str = None, + expiration_date: Optional[str] = None, infer: bool = True, - prompt: str = None, - includes: str = None, - excludes: str = None, + prompt: Optional[str] = None, + includes: Optional[str] = None, + excludes: Optional[str] = None, initial_layer: str = "auto", initial_strength: float = 1.0, - echo_depth: str = None, # EchoMem: override echo depth (shallow/medium/deep) + echo_depth: Optional[str] = None, # EchoMem: override echo depth (shallow/medium/deep) agent_category: Optional[str] = None, connector_id: Optional[str] = None, scope: Optional[str] = None, @@ -573,6 +594,16 @@ def _process_single_memory( ) namespace_value = str(mem_metadata.get("namespace", "default") or "default").strip() or "default" + # Gap 1: Classify memory type (episodic vs semantic) + memory_type = self._classify_memory_type(mem_metadata, role) + + # Gap 4: Initialize multi-trace strength + s_fast_val = None + s_mid_val = None + s_slow_val = None + if self.distillation_config and self.distillation_config.enable_multi_trace: + s_fast_val, s_mid_val, s_slow_val = initialize_traces(effective_strength, is_new=True) + memory_id = str(uuid.uuid4()) now = datetime.now(timezone.utc).isoformat() memory_data = { @@ -602,6 +633,10 @@ def _process_single_memory( "importance": mem_metadata.get("importance", 0.5), "sensitivity": mem_metadata.get("sensitivity", "normal"), "namespace": namespace_value, + "memory_type": memory_type, + "s_fast": s_fast_val, + "s_mid": s_mid_val, + "s_slow": s_slow_val, } vectors, payloads, vector_ids = self._build_index_vectors( @@ -619,7 +654,23 @@ def _process_single_memory( ) self.db.add_memory(memory_data) - self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) + if vectors: + try: + self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) + except Exception as e: + # Vector insert failed — roll back the DB record to prevent desync. + logger.error( + "Vector insert failed for memory %s, rolling back DB record: %s", + memory_id, e, + ) + try: + self.db.delete_memory(memory_id, use_tombstone=False) + except Exception as rollback_err: + logger.critical( + "CRITICAL: DB rollback also failed for memory %s — manual cleanup required: %s", + memory_id, rollback_err, + ) + raise # Post-store hooks. if self.category_processor and mem_categories: @@ -659,15 +710,16 @@ def _process_single_memory( "categories": mem_categories, "namespace": namespace_value, "vector_nodes": len(vectors), + "memory_type": memory_type, } def search( self, query: str, - user_id: str = None, - agent_id: str = None, - run_id: str = None, - app_id: str = None, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + app_id: Optional[str] = None, filters: Dict[str, Any] = None, categories: List[str] = None, agent_category: Optional[str] = None, @@ -683,6 +735,9 @@ def search( use_category_boost: bool = True, # CategoryMem: boost by category relevance **kwargs: Any, ) -> Dict[str, Any]: + if not query or not query.strip(): + return {"results": [], "context_packet": None} + _, effective_filters = build_filters_and_metadata( user_id=user_id, agent_id=agent_id, @@ -708,6 +763,15 @@ def search( if scope_value } + # Gap 5: Classify query intent for routing + query_intent = None + if ( + self.distillation_config + and self.distillation_config.enable_intent_routing + and self.distillation_config.enable_memory_types + ): + query_intent = classify_intent(query) + query_embedding = self.embedder.embed(query, memory_action="search") vector_results = self.vector_store.search( query=query, @@ -770,14 +834,16 @@ def search( reecho_ids: List[str] = [] subscriber_ids: List[str] = [] + # Pre-create HybridSearcher outside the loop to avoid re-allocation per result. + hybrid_searcher = HybridSearcher(alpha=hybrid_alpha) if keyword_search else None + for memory_id in candidate_ids: memory = memories_bulk.get(memory_id) if not memory: continue - # Skip expired memories + # Skip expired memories (cleanup happens in apply_decay, not during search) if self._is_expired(memory): - self.delete(memory["id"]) continue if memory.get("strength", 1.0) < min_strength: @@ -806,8 +872,7 @@ def search( # Hybrid search: combine semantic and keyword scores keyword_score = 0.0 - if keyword_search: - hybrid_searcher = HybridSearcher(alpha=hybrid_alpha) + if hybrid_searcher: scores = hybrid_searcher.score_memory( query_terms=query_terms, semantic_similarity=similarity, @@ -839,6 +904,19 @@ def search( category_boost = self.category_config.cross_category_boost combined = combined * (1 + category_boost) + # Gap 5: Intent-based retrieval routing boost + intent_boost = 0.0 + mem_type = memory.get("memory_type", "semantic") + if query_intent and self.distillation_config: + dc = self.distillation_config + if query_intent == QueryIntent.EPISODIC and mem_type == "episodic": + intent_boost = dc.episodic_boost + elif query_intent == QueryIntent.SEMANTIC and mem_type == "semantic": + intent_boost = dc.semantic_boost + elif query_intent == QueryIntent.MIXED: + intent_boost = dc.intersection_boost + combined = combined * (1 + intent_boost) + # KnowledgeGraph: Boost for memories sharing entities with query terms graph_boost = 0.0 if self.knowledge_graph: @@ -904,6 +982,9 @@ def search( "echo_boost": echo_boost, "category_boost": category_boost, "graph_boost": graph_boost, + "intent_boost": intent_boost, + "memory_type": mem_type, + "query_intent": query_intent.value if query_intent else None, } ) @@ -986,19 +1067,25 @@ def get(self, memory_id: str) -> Optional[Dict[str, Any]]: self.db.increment_access(memory_id) return memory + # Hard cap to prevent unbounded result sets even if callers pass a huge limit. + _GET_ALL_MAX_LIMIT = 10_000 + def get_all( self, - user_id: str = None, - agent_id: str = None, - run_id: str = None, - app_id: str = None, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + app_id: Optional[str] = None, filters: Dict[str, Any] = None, categories: List[str] = None, limit: int = 100, - layer: str = None, + layer: Optional[str] = None, min_strength: float = 0.0, **kwargs: Any, ) -> Dict[str, Any]: + # Clamp limit to a sensible maximum to avoid unbounded result sets. + limit = max(1, min(limit, self._GET_ALL_MAX_LIMIT)) + _, effective_filters = build_filters_and_metadata( user_id=user_id, agent_id=agent_id, @@ -1015,6 +1102,7 @@ def get_all( app_id=app_id, layer=layer, min_strength=min_strength, + limit=limit, ) if categories: @@ -1096,7 +1184,14 @@ def update(self, memory_id: str, data: Union[str, Dict[str, Any]]) -> Dict[str, run_id=memory.get("run_id"), app_id=memory.get("app_id"), ) - self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) + try: + self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) + except Exception as e: + logger.error( + "Vector re-insert failed during update for memory %s: %s. " + "DB was updated but vector index is stale — will be rebuilt on next update.", + memory_id, e, + ) else: success = self.db.update_memory( memory_id, @@ -1105,22 +1200,31 @@ def update(self, memory_id: str, data: Union[str, Dict[str, Any]]) -> Dict[str, if success: payload_updates = dict(metadata) payload_updates["categories"] = categories - self._update_vectors_for_memory(memory_id, payload_updates) + try: + self._update_vectors_for_memory(memory_id, payload_updates) + except Exception as e: + logger.error( + "Vector payload update failed for memory %s: %s. " + "DB is authoritative — vector metadata may be stale.", + memory_id, e, + ) return {"id": memory_id, "memory": content, "event": "UPDATE" if success else "ERROR"} def delete(self, memory_id: str) -> Dict[str, Any]: + logger.info("Deleting memory %s (tombstone=%s)", memory_id, self.fadem_config.use_tombstone_deletion) self.db.delete_memory(memory_id, use_tombstone=self.fadem_config.use_tombstone_deletion) self._delete_vectors_for_memory(memory_id) return {"id": memory_id, "deleted": True} def delete_all( self, - user_id: str = None, - agent_id: str = None, - run_id: str = None, - app_id: str = None, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + app_id: Optional[str] = None, filters: Dict[str, Any] = None, + dry_run: bool = False, **kwargs: Any, ) -> Dict[str, Any]: if not any([user_id, agent_id, run_id, app_id, filters]): @@ -1132,6 +1236,13 @@ def delete_all( if filters: memories = [m for m in memories if matches_filters({**m, **m.get("metadata", {})}, filters)] + if dry_run: + return {"deleted_count": 0, "would_delete": len(memories), "dry_run": True} + + logger.warning( + "delete_all: deleting %d memories (user_id=%s, agent_id=%s, filters=%s)", + len(memories), user_id, agent_id, filters, + ) count = 0 for memory in memories: self.delete(memory["id"]) @@ -1142,7 +1253,9 @@ def history(self, memory_id: str) -> List[Dict[str, Any]]: return self.db.get_history(memory_id) def reset(self) -> None: + """Delete ALL memories including tombstoned. This is IRREVERSIBLE.""" memories = self.db.get_all_memories(include_tombstoned=True) + logger.warning("reset: permanently deleting ALL %d memories", len(memories)) for mem in memories: self.delete(mem["id"]) if hasattr(self.vector_store, "reset"): @@ -1184,14 +1297,32 @@ def apply_decay(self, scope: Dict[str, Any] = None) -> Dict[str, Any]: metrics.record_ref_protected_skip(1) continue - new_strength = calculate_decayed_strength( - current_strength=memory.get("strength", 1.0), - last_accessed=memory.get("last_accessed", datetime.now(timezone.utc).isoformat()), - access_count=memory.get("access_count", 0), - layer=memory.get("layer", "sml"), - config=self.fadem_config, + # Gap 4: Multi-trace decay (if enabled and traces are initialized) + use_multi_trace = ( + self.distillation_config + and self.distillation_config.enable_multi_trace + and memory.get("s_fast") is not None ) + if use_multi_trace: + s_f, s_m, s_s = decay_traces( + s_fast=float(memory.get("s_fast", 0.0)), + s_mid=float(memory.get("s_mid", 0.0)), + s_slow=float(memory.get("s_slow", 0.0)), + last_accessed=memory.get("last_accessed", datetime.now(timezone.utc).isoformat()), + access_count=memory.get("access_count", 0), + config=self.distillation_config, + ) + new_strength = compute_effective_strength(s_f, s_m, s_s, self.distillation_config) + else: + new_strength = calculate_decayed_strength( + current_strength=memory.get("strength", 1.0), + last_accessed=memory.get("last_accessed", datetime.now(timezone.utc).isoformat()), + access_count=memory.get("access_count", 0), + layer=memory.get("layer", "sml"), + config=self.fadem_config, + ) + if ref_aware and int(ref_state.get("weak", 0)) > 0: weak = min(int(ref_state.get("weak", 0)), 10) dampening = 1.0 + weak * 0.15 @@ -1209,7 +1340,10 @@ def apply_decay(self, scope: Dict[str, Any] = None) -> Dict[str, Any]: continue if new_strength != memory.get("strength"): - self.db.update_memory(memory["id"], {"strength": new_strength}) + if use_multi_trace: + self.db.update_multi_trace(memory["id"], s_f, s_m, s_s, new_strength) + else: + self.db.update_memory(memory["id"], {"strength": new_strength}) self.db.log_event(memory["id"], "DECAY", old_strength=memory.get("strength"), new_strength=new_strength) decayed += 1 @@ -1226,15 +1360,55 @@ def apply_decay(self, scope: Dict[str, Any] = None) -> Dict[str, Any]: if self.fadem_config.use_tombstone_deletion: self.db.purge_tombstoned() + # Gap 3: Advanced forgetting mechanisms + interference_stats = {"checked": 0, "demoted": 0} + redundancy_stats = {"groups_fused": 0, "memories_fused": 0} + homeostasis_stats = {"namespaces_over_budget": 0, "pressured": 0, "forgotten": 0} + + if self.distillation_config: + user_id = scope.get("user_id") if scope else None + + if self.distillation_config.enable_interference_pruning: + pruner = InterferencePruner( + db=self.db, + config=self.distillation_config, + fadem_config=self.fadem_config, + resolve_conflict_fn=resolve_conflict, + search_fn=self.vector_store.search, + llm=self.llm, + ) + interference_stats = pruner.run(memories, user_id=user_id) + + if self.distillation_config.enable_redundancy_collapse: + collapser = RedundancyCollapser( + db=self.db, + config=self.distillation_config, + fuse_fn=self.fuse_memories, + search_fn=self.vector_store.search, + ) + redundancy_stats = collapser.run(memories, user_id=user_id) + + if self.distillation_config.enable_homeostasis and user_id: + normalizer = HomeostaticNormalizer( + db=self.db, + config=self.distillation_config, + fadem_config=self.fadem_config, + delete_fn=self.delete, + ) + homeostasis_stats = normalizer.run(user_id) + self.db.log_decay(decayed, forgotten, promoted) return { "decayed": decayed, "forgotten": forgotten, "promoted": promoted, "stale_refs_removed": stale_refs_removed, + "interference": interference_stats, + "redundancy": redundancy_stats, + "homeostasis": homeostasis_stats, } - def fuse_memories(self, memory_ids: List[str], user_id: str = None) -> Dict[str, Any]: + def fuse_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> Dict[str, Any]: memories = [self.db.get_memory(mid) for mid in memory_ids] memories = [m for m in memories if m] if len(memories) < 2: @@ -1258,7 +1432,7 @@ def fuse_memories(self, memory_ids: List[str], user_id: str = None) -> Dict[str, fused_id = result.get("results", [{}])[0].get("id") if result.get("results") else None return {"fused_id": fused_id, "source_ids": memory_ids, "fused_memory": fused.content} - def get_stats(self, user_id: str = None, agent_id: str = None) -> Dict[str, Any]: + def get_stats(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> Dict[str, Any]: memories = self.db.get_all_memories(user_id=user_id, agent_id=agent_id) sml_count = sum(1 for m in memories if m.get("layer") == "sml") lml_count = sum(1 for m in memories if m.get("layer") == "lml") @@ -1613,11 +1787,7 @@ def _extract_memories( extracted = [m for m in extracted if excludes.lower() not in m.get("content", "").lower()] return extracted except Exception as exc: - logger.warning("Failed to parse extraction response: %s", exc) - # Fallback: add last user message - last_user = next((m for m in reversed(messages) if m.get("role") == "user"), None) - if last_user: - return [{"content": last_user.get("content", "") }] + logger.warning("Memory extraction failed (LLM or JSON error): %s", exc) return [] def _should_use_agent_memory_extraction(self, messages: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool: @@ -1625,6 +1795,33 @@ def _should_use_agent_memory_extraction(self, messages: List[Dict[str, Any]], me has_assistant_messages = any(msg.get("role") == "assistant" for msg in messages) return has_agent_id and has_assistant_messages + def _classify_memory_type(self, metadata: Dict[str, Any], role: str) -> str: + """Classify a memory as 'episodic' or 'semantic' (Gap 1). + + When enable_memory_types is False, everything stays 'semantic' (backward compat). + """ + if not self.distillation_config or not self.distillation_config.enable_memory_types: + return self.distillation_config.default_memory_type if self.distillation_config else "semantic" + + # Explicit override from metadata + explicit = metadata.get("memory_type") + if explicit in ("episodic", "semantic"): + return explicit + + # Distilled content is always semantic + if metadata.get("is_distilled"): + return "semantic" + + # Conversation messages (user/assistant) are episodic + if role in ("user", "assistant"): + return "episodic" + + # Active memory signals are semantic + if metadata.get("source_type") == "active_signal": + return "semantic" + + return "semantic" + def _select_primary_text(self, content: str, echo_result: Optional[EchoResult]) -> str: if self.echo_config.use_question_embedding and echo_result and echo_result.question_form: return echo_result.question_form @@ -1849,26 +2046,43 @@ def add_node( return vectors, payloads, vector_ids def _delete_vectors_for_memory(self, memory_id: str) -> None: - vectors = self.vector_store.list(filters={"memory_id": memory_id}) - if not vectors: - self.vector_store.delete(memory_id) - return - for vec in vectors: - self.vector_store.delete(vec.id) + try: + vectors = self.vector_store.list(filters={"memory_id": memory_id}) + if not vectors: + self.vector_store.delete(memory_id) + return + for vec in vectors: + self.vector_store.delete(vec.id) + except Exception as e: + logger.error( + "Failed to delete vectors for memory %s: %s. " + "Orphaned vector entries may exist.", + memory_id, e, + ) def _update_vectors_for_memory(self, memory_id: str, payload_updates: Dict[str, Any]) -> None: - vectors = self.vector_store.list(filters={"memory_id": memory_id}) + try: + vectors = self.vector_store.list(filters={"memory_id": memory_id}) + except Exception as e: + logger.error("Failed to list vectors for memory %s: %s", memory_id, e) + return if not vectors: - existing = self.vector_store.get(memory_id) - if existing: - payload = existing.payload or {} - payload.update(payload_updates) - self.vector_store.update(memory_id, payload=payload) + try: + existing = self.vector_store.get(memory_id) + if existing: + payload = existing.payload or {} + payload.update(payload_updates) + self.vector_store.update(memory_id, payload=payload) + except Exception as e: + logger.error("Failed to update vector payload for memory %s: %s", memory_id, e) return for vec in vectors: payload = vec.payload or {} payload.update(payload_updates) - self.vector_store.update(vec.id, payload=payload) + try: + self.vector_store.update(vec.id, payload=payload) + except Exception as e: + logger.error("Failed to update vector %s for memory %s: %s", vec.id, memory_id, e) def _nearest_memory(self, embedding: List[float], filters: Dict[str, Any]) -> tuple[Optional[Dict[str, Any]], float]: results = self.vector_store.search(query=None, vectors=embedding, limit=1, filters=filters) diff --git a/engram/memory/staging_store.py b/engram/memory/staging_store.py index 45d3b64..3504062 100644 --- a/engram/memory/staging_store.py +++ b/engram/memory/staging_store.py @@ -3,7 +3,7 @@ from __future__ import annotations import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, List, Optional @@ -24,7 +24,7 @@ def create_commit( status: str = "PENDING", ) -> Dict[str, Any]: commit_id = str(uuid.uuid4()) - created_at = datetime.utcnow().isoformat() + created_at = datetime.now(timezone.utc).isoformat() payload = { "id": commit_id, "user_id": user_id, diff --git a/engram/observability.py b/engram/observability.py index 272b839..1782e36 100644 --- a/engram/observability.py +++ b/engram/observability.py @@ -24,7 +24,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Callable import threading @@ -60,7 +60,7 @@ def _log(self, level: int, message: str, **kwargs): "structured_data": { **self._context, **kwargs, - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(timezone.utc).isoformat(), } } self._logger.log(level, message, extra=extra) @@ -123,7 +123,7 @@ def record(self, latency_ms: float, error: bool = False): self.max_latency_ms = max(self.max_latency_ms, latency_ms) if error: self.errors += 1 - self.last_operation = datetime.utcnow().isoformat() + self.last_operation = datetime.now(timezone.utc).isoformat() @property def avg_latency_ms(self) -> float: @@ -183,7 +183,7 @@ def __init__(self): self._lock = threading.Lock() self._operations: Dict[str, OperationMetrics] = defaultdict(OperationMetrics) self._memory = MemoryMetrics() - self._start_time = datetime.utcnow() + self._start_time = datetime.now(timezone.utc) self._custom_gauges: Dict[str, float] = {} def record_operation( @@ -266,7 +266,7 @@ def set_gauge(self, name: str, value: float): def get_summary(self) -> Dict[str, Any]: """Get a summary of all metrics.""" with self._lock: - uptime = (datetime.utcnow() - self._start_time).total_seconds() + uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds() return { "uptime_seconds": round(uptime, 2), "operations": { diff --git a/engram/retrieval/dual_search.py b/engram/retrieval/dual_search.py index b8d4da1..5d08236 100644 --- a/engram/retrieval/dual_search.py +++ b/engram/retrieval/dual_search.py @@ -36,6 +36,9 @@ def search( allowed_confidentiality_scopes: Optional[Iterable[str]] = None, allowed_namespaces: Optional[Iterable[str]] = None, ) -> Dict[str, Any]: + # Materialize to avoid consuming a generator/iterator twice + if allowed_namespaces is not None and not isinstance(allowed_namespaces, (list, tuple, set, frozenset)): + allowed_namespaces = list(allowed_namespaces) semantic_payload = self.memory.search( query=query, user_id=user_id, diff --git a/engram/simple.py b/engram/simple.py index b1be84e..9a48c92 100644 --- a/engram/simple.py +++ b/engram/simple.py @@ -237,11 +237,14 @@ def search( ) if isinstance(result, dict) and "results" in result: results = result["results"] - for entry in results: - if "content" not in entry and "memory" in entry: - entry["content"] = entry.get("memory") - return results - return result + elif isinstance(result, list): + results = result + else: + return [] + for entry in results: + if isinstance(entry, dict) and "content" not in entry and "memory" in entry: + entry["content"] = entry.get("memory") + return results def get(self, memory_id: str) -> Optional[Dict[str, Any]]: """Get a specific memory by ID. diff --git a/engram/utils/prompts.py b/engram/utils/prompts.py index c620189..99e28d4 100644 --- a/engram/utils/prompts.py +++ b/engram/utils/prompts.py @@ -181,3 +181,35 @@ - discarded_as_redundant lists information dropped because it was repetitive - confidence reflects how well the memories merged (lower if they seem unrelated) """ + +DISTILLATION_PROMPT = """You are a memory consolidation system. Extract reusable semantic knowledge from a batch of episodic memories (conversations/events). + +EPISODIC MEMORIES: +{episodes} + +Your task is to identify durable FACTS, PREFERENCES, PATTERNS, or PROCEDURES that can be distilled from these episodic memories into long-term semantic knowledge. + +Respond ONLY with valid JSON in this exact format: +{{ + "semantic_facts": [ + {{ + "content": "The specific fact, preference, or pattern to remember", + "importance": "high|medium|low", + "source_episodes": ["episode_id_1", "episode_id_2"], + "reasoning": "Brief explanation of why this is a durable fact" + }} + ], + "skipped_as_temporary": ["Brief description of info that was too transient to distill"] +}} + +Rules: +- Extract ONLY durable facts supported by the episodic evidence +- Maximum {max_facts} facts per batch +- Each fact should be a standalone, self-contained statement +- Use third person ("User prefers..." not "I prefer...") +- Do NOT extract temporary/one-time information +- Do NOT invent information not present in the episodes +- source_episodes should reference the IDs of the episodes that support each fact +- importance: high = likely frequently relevant, medium = useful context, low = niche +- If nothing durable can be extracted, return empty semantic_facts array +""" diff --git a/engram/vector_stores/async_qdrant.py b/engram/vector_stores/async_qdrant.py index b300b7e..36587af 100644 --- a/engram/vector_stores/async_qdrant.py +++ b/engram/vector_stores/async_qdrant.py @@ -5,10 +5,12 @@ from __future__ import annotations +import asyncio import uuid -from dataclasses import dataclass from typing import Any, Dict, List, Optional +from engram.vector_stores.base import MemoryResult + try: from qdrant_client import AsyncQdrantClient from qdrant_client.models import Distance, PointStruct, VectorParams, Filter, FieldCondition, MatchValue @@ -17,14 +19,6 @@ HAS_QDRANT = False -@dataclass -class MemoryResult: - """Result from a vector search.""" - id: str - score: float = 0.0 - payload: Dict[str, Any] = None - - class AsyncQdrantVectorStore: """Async Qdrant vector store for Engram memories.""" @@ -153,20 +147,21 @@ async def search( if conditions: qdrant_filter = Filter(must=conditions) - results = await self.client.search( + response = await self.client.query_points( collection_name=self.collection_name, - query_vector=query_vector, + query=query_vector, limit=limit, query_filter=qdrant_filter, + with_payload=True, ) return [ MemoryResult( id=str(r.id), - score=r.score, + score=float(r.score or 0.0), payload=r.payload or {}, ) - for r in results + for r in response.points ] async def delete(self, ids: List[str]) -> None: diff --git a/engram/vector_stores/base.py b/engram/vector_stores/base.py index ddc0782..f688c51 100644 --- a/engram/vector_stores/base.py +++ b/engram/vector_stores/base.py @@ -1,7 +1,16 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional +@dataclass +class MemoryResult: + """Standard result type returned by all vector store implementations.""" + id: str + score: float = 0.0 + payload: Dict[str, Any] = field(default_factory=dict) + + class VectorStoreBase(ABC): @abstractmethod def create_col(self, name: str, vector_size: int, distance: str = "cosine") -> None: @@ -46,3 +55,7 @@ def list(self, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = @abstractmethod def reset(self) -> None: pass + + def close(self) -> None: + """Release resources. Override in subclasses that hold connections.""" + pass diff --git a/engram/vector_stores/memory.py b/engram/vector_stores/memory.py index c2518a3..129fbc8 100644 --- a/engram/vector_stores/memory.py +++ b/engram/vector_stores/memory.py @@ -1,19 +1,12 @@ from __future__ import annotations import math +import threading import uuid -from dataclasses import dataclass from typing import Any, Dict, List, Optional from engram.memory.utils import matches_filters -from engram.vector_stores.base import VectorStoreBase - - -@dataclass -class MemoryResult: - id: str - score: float = 0.0 - payload: Dict[str, Any] = None +from engram.vector_stores.base import MemoryResult, VectorStoreBase class InMemoryVectorStore(VectorStoreBase): @@ -22,6 +15,7 @@ def __init__(self, config: Optional[Dict[str, Any]] = None): self.collection_name = self.config.get("collection_name", "fadem_memories") self.vector_size = self.config.get("embedding_model_dims") self._store: Dict[str, Dict[str, Any]] = {} + self._lock = threading.RLock() def create_col(self, name: str, vector_size: int, distance: str = "cosine") -> None: self.collection_name = name @@ -34,8 +28,9 @@ def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict[str, A if ids is not None and len(ids) != len(vectors): raise ValueError("ids length must match vectors length") ids = ids or [str(uuid.uuid4()) for _ in vectors] - for vector_id, vector, payload in zip(ids, vectors, payloads): - self._store[vector_id] = {"vector": vector, "payload": payload} + with self._lock: + for vector_id, vector, payload in zip(ids, vectors, payloads): + self._store[vector_id] = {"vector": vector, "payload": payload} def _cosine_similarity(self, a: List[float], b: List[float]) -> float: if not a or not b: @@ -49,7 +44,9 @@ def _cosine_similarity(self, a: List[float], b: List[float]) -> float: def search(self, query: Optional[str], vectors: List[float], limit: int = 5, filters: Optional[Dict[str, Any]] = None) -> List[MemoryResult]: results: List[MemoryResult] = [] - for vector_id, record in self._store.items(): + with self._lock: + snapshot = list(self._store.items()) + for vector_id, record in snapshot: payload = record.get("payload", {}) if filters and not matches_filters(payload, filters): continue @@ -60,35 +57,41 @@ def search(self, query: Optional[str], vectors: List[float], limit: int = 5, fil return results[:limit] def delete(self, vector_id: str) -> None: - if vector_id in self._store: - del self._store[vector_id] + with self._lock: + if vector_id in self._store: + del self._store[vector_id] def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None) -> None: - if vector_id not in self._store: - return - if vector is not None: - self._store[vector_id]["vector"] = vector - if payload is not None: - self._store[vector_id]["payload"] = payload + with self._lock: + if vector_id not in self._store: + return + if vector is not None: + self._store[vector_id]["vector"] = vector + if payload is not None: + self._store[vector_id]["payload"] = payload def get(self, vector_id: str) -> Optional[MemoryResult]: - record = self._store.get(vector_id) - if not record: - return None - return MemoryResult(id=vector_id, score=0.0, payload=record.get("payload", {})) + with self._lock: + record = self._store.get(vector_id) + if not record: + return None + return MemoryResult(id=vector_id, score=0.0, payload=record.get("payload", {})) def list_cols(self) -> List[str]: return [self.collection_name] def delete_col(self) -> None: - self._store = {} + with self._lock: + self._store = {} def col_info(self) -> Dict[str, Any]: return {"name": self.collection_name, "size": len(self._store), "vector_size": self.vector_size} def list(self, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> List[MemoryResult]: results: List[MemoryResult] = [] - for vector_id, record in self._store.items(): + with self._lock: + snapshot = list(self._store.items()) + for vector_id, record in snapshot: payload = record.get("payload", {}) if filters and not matches_filters(payload, filters): continue @@ -98,4 +101,5 @@ def list(self, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = return results def reset(self) -> None: - self._store = {} + with self._lock: + self._store = {} diff --git a/engram/vector_stores/qdrant.py b/engram/vector_stores/qdrant.py index 6bab31f..fd56dda 100644 --- a/engram/vector_stores/qdrant.py +++ b/engram/vector_stores/qdrant.py @@ -1,17 +1,9 @@ from __future__ import annotations -from dataclasses import dataclass import uuid from typing import Any, Dict, List, Optional -from engram.vector_stores.base import VectorStoreBase - - -@dataclass -class MemoryResult: - id: str - score: float = 0.0 - payload: Dict[str, Any] = None +from engram.vector_stores.base import MemoryResult, VectorStoreBase class QdrantVectorStore(VectorStoreBase): @@ -101,7 +93,10 @@ def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: if vector is not None: from qdrant_client.models import PointStruct - payload = payload or {} + # Preserve existing payload when only updating the vector + if payload is None: + existing = self.get(vector_id) + payload = existing.payload if existing else {} point = PointStruct(id=vector_id, vector=vector, payload=payload) self.client.upsert(collection_name=self.collection_name, points=[point]) return @@ -141,6 +136,14 @@ def reset(self) -> None: self.delete_col() self.create_col(self.collection_name, self.vector_size, self.distance) + def close(self) -> None: + """Close the Qdrant client connection.""" + if hasattr(self, 'client') and self.client is not None: + try: + self.client.close() + except Exception: + pass + def _create_client(config: Dict[str, Any]): from qdrant_client import QdrantClient diff --git a/engram/vector_stores/sqlite_vec.py b/engram/vector_stores/sqlite_vec.py index 71fed43..7f6f73e 100644 --- a/engram/vector_stores/sqlite_vec.py +++ b/engram/vector_stores/sqlite_vec.py @@ -15,22 +15,14 @@ import struct import threading import uuid -from dataclasses import dataclass from typing import Any, Dict, List, Optional from engram.memory.utils import matches_filters -from engram.vector_stores.base import VectorStoreBase +from engram.vector_stores.base import MemoryResult, VectorStoreBase logger = logging.getLogger(__name__) -@dataclass -class MemoryResult: - id: str - score: float = 0.0 - payload: Dict[str, Any] = None - - def _serialize_float32(vector: List[float]) -> bytes: """Serialize a float vector to bytes for sqlite-vec.""" return struct.pack(f"{len(vector)}f", *vector) @@ -65,9 +57,10 @@ def __init__(self, config: Optional[Dict[str, Any]] = None): self._conn = sqlite3.connect(db_path, check_same_thread=False) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA busy_timeout=5000") - self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA synchronous=FULL") self._conn.row_factory = sqlite3.Row self._lock = threading.RLock() + self._closed = False # Load sqlite-vec extension self._conn.enable_load_extension(True) @@ -113,6 +106,7 @@ def _ensure_collection(self, name: str, vector_size: int) -> None: self._conn.commit() def create_col(self, name: str, vector_size: int, distance: str = "cosine") -> None: + self._check_open() self._ensure_collection(name, vector_size) def insert( @@ -121,6 +115,7 @@ def insert( payloads: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None, ) -> None: + self._check_open() payloads = payloads or [{} for _ in vectors] if len(payloads) != len(vectors): raise ValueError("payloads length must match vectors length") @@ -128,6 +123,10 @@ def insert( raise ValueError("ids length must match vectors length") ids = ids or [str(uuid.uuid4()) for _ in vectors] + for vector in vectors: + if len(vector) != self.vector_size: + raise ValueError(f"Vector has {len(vector)} dimensions, expected {self.vector_size}") + vec_table = self._vec_table(self.collection_name) payload_table = self._payload_table(self.collection_name) @@ -168,6 +167,7 @@ def search( limit: int = 5, filters: Optional[Dict[str, Any]] = None, ) -> List[MemoryResult]: + self._check_open() vec_table = self._vec_table(self.collection_name) payload_table = self._payload_table(self.collection_name) @@ -229,6 +229,7 @@ def search( return results[:limit] def delete(self, vector_id: str) -> None: + self._check_open() payload_table = self._payload_table(self.collection_name) vec_table = self._vec_table(self.collection_name) @@ -253,6 +254,7 @@ def update( vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None, ) -> None: + self._check_open() payload_table = self._payload_table(self.collection_name) vec_table = self._vec_table(self.collection_name) @@ -278,6 +280,7 @@ def update( self._conn.commit() def get(self, vector_id: str) -> Optional[MemoryResult]: + self._check_open() payload_table = self._payload_table(self.collection_name) with self._lock: @@ -298,6 +301,7 @@ def get(self, vector_id: str) -> Optional[MemoryResult]: return MemoryResult(id=row["uuid"], score=0.0, payload=payload) def list_cols(self) -> List[str]: + self._check_open() with self._lock: rows = self._conn.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'payload_%'", @@ -305,6 +309,7 @@ def list_cols(self) -> List[str]: return [row["name"].replace("payload_", "", 1) for row in rows] def delete_col(self) -> None: + self._check_open() vec_table = self._vec_table(self.collection_name) payload_table = self._payload_table(self.collection_name) @@ -314,6 +319,7 @@ def delete_col(self) -> None: self._conn.commit() def col_info(self) -> Dict[str, Any]: + self._check_open() payload_table = self._payload_table(self.collection_name) with self._lock: @@ -333,6 +339,7 @@ def list( filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, ) -> List[MemoryResult]: + self._check_open() payload_table = self._payload_table(self.collection_name) effective_limit = limit or 100 @@ -358,5 +365,26 @@ def list( return results[:effective_limit] def reset(self) -> None: + self._check_open() self.delete_col() self._ensure_collection(self.collection_name, self.vector_size) + + def _check_open(self) -> None: + """Raise if the store has been closed.""" + if self._closed: + raise RuntimeError("SqliteVecStore is closed") + + def close(self) -> None: + """Close the SQLite connection.""" + with self._lock: + self._closed = True + if self._conn: + try: + self._conn.execute("PRAGMA wal_checkpoint(RESTART)") + except Exception: + pass + try: + self._conn.close() + except Exception: + pass + self._conn = None # type: ignore[assignment] diff --git a/tests/test_distillation.py b/tests/test_distillation.py new file mode 100644 index 0000000..3311464 --- /dev/null +++ b/tests/test_distillation.py @@ -0,0 +1,192 @@ +"""Tests for engram.core.distillation — Replay-driven semantic distillation.""" + +import json +import os +import tempfile +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from engram.configs.base import DistillationConfig +from engram.core.distillation import ReplayDistiller +from engram.db.sqlite import SQLiteManager + + +@pytest.fixture +def tmp_db(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + db = SQLiteManager(path) + yield db + db.close() + os.unlink(path) + + +@pytest.fixture +def config(): + return DistillationConfig( + enable_distillation=True, + distillation_batch_size=10, + distillation_min_episodes=2, + max_semantic_per_batch=3, + ) + + +@pytest.fixture +def mock_llm(): + llm = MagicMock() + llm.generate.return_value = json.dumps({ + "semantic_facts": [ + { + "content": "User prefers TypeScript for frontend development", + "importance": "high", + "source_episodes": ["ep1", "ep2"], + "reasoning": "Mentioned multiple times", + }, + { + "content": "User deploys on Fridays using CI/CD pipeline", + "importance": "medium", + "source_episodes": ["ep3"], + "reasoning": "Consistent pattern", + }, + ], + "skipped_as_temporary": ["one-time error discussion"], + }) + return llm + + +def _add_episodic_memory(db, user_id, content, created_at=None): + """Helper to add an episodic memory directly to the DB.""" + now = created_at or datetime.now(timezone.utc).isoformat() + db.add_memory({ + "memory": content, + "user_id": user_id, + "memory_type": "episodic", + "created_at": now, + "updated_at": now, + "layer": "sml", + "strength": 1.0, + }) + + +class TestReplayDistiller: + def test_disabled_returns_skipped(self, tmp_db, mock_llm): + config = DistillationConfig(enable_distillation=False) + distiller = ReplayDistiller(tmp_db, mock_llm, config) + result = distiller.run("user1") + assert result["skipped"] is True + assert result["reason"] == "distillation disabled" + + def test_insufficient_episodes(self, tmp_db, mock_llm, config): + # Add only 1 episode (min is 2) + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date().isoformat() + _add_episodic_memory(tmp_db, "user1", "Hello", f"{yesterday}T12:00:00") + + distiller = ReplayDistiller(tmp_db, mock_llm, config) + result = distiller.run("user1", date_str=yesterday) + assert result["skipped"] is True + assert result["reason"] == "insufficient episodes" + + def test_successful_distillation(self, tmp_db, mock_llm, config): + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date().isoformat() + for i in range(5): + _add_episodic_memory( + tmp_db, "user1", f"Episode content {i}", + f"{yesterday}T{10+i}:00:00", + ) + + add_fn = MagicMock() + add_fn.return_value = { + "results": [{"id": f"sem_{i}", "event": "ADD"}] + } + + distiller = ReplayDistiller(tmp_db, mock_llm, config) + result = distiller.run("user1", date_str=yesterday, memory_add_fn=add_fn) + + assert result.get("skipped") is not True + assert result["episodes_sampled"] == 5 + assert result["semantic_created"] == 2 + assert "run_id" in result + # LLM was called + assert mock_llm.generate.called + + def test_dedup_detection(self, tmp_db, mock_llm, config): + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date().isoformat() + for i in range(3): + _add_episodic_memory( + tmp_db, "user1", f"Episode {i}", + f"{yesterday}T{10+i}:00:00", + ) + + add_fn = MagicMock() + # First fact gets added, second gets deduplicated + add_fn.side_effect = [ + {"results": [{"id": "sem_1", "event": "ADD"}]}, + {"results": [{"id": "existing", "event": "NOOP"}]}, + ] + + distiller = ReplayDistiller(tmp_db, mock_llm, config) + result = distiller.run("user1", date_str=yesterday, memory_add_fn=add_fn) + + assert result["semantic_created"] == 1 + assert result["semantic_deduplicated"] == 1 + + def test_invalid_llm_response(self, tmp_db, config): + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date().isoformat() + for i in range(3): + _add_episodic_memory( + tmp_db, "user1", f"Episode {i}", + f"{yesterday}T{10+i}:00:00", + ) + + bad_llm = MagicMock() + bad_llm.generate.return_value = "not valid json at all" + + distiller = ReplayDistiller(tmp_db, bad_llm, config) + result = distiller.run("user1", date_str=yesterday, memory_add_fn=MagicMock()) + + # Should not crash, just produce 0 facts + assert result.get("semantic_created", 0) == 0 + + +class TestDistillationProvenance: + def test_provenance_recorded(self, tmp_db, mock_llm, config): + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date().isoformat() + for i in range(3): + _add_episodic_memory( + tmp_db, "user1", f"Episode {i}", + f"{yesterday}T{10+i}:00:00", + ) + + add_fn = MagicMock() + add_fn.return_value = {"results": [{"id": "sem_123", "event": "ADD"}]} + + distiller = ReplayDistiller(tmp_db, mock_llm, config) + distiller.run("user1", date_str=yesterday, memory_add_fn=add_fn) + + # Check provenance was recorded + with tmp_db._get_connection() as conn: + rows = conn.execute("SELECT * FROM distillation_provenance").fetchall() + assert len(rows) > 0 + + +class TestDistillationLog: + def test_log_recorded(self, tmp_db, mock_llm, config): + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date().isoformat() + for i in range(3): + _add_episodic_memory( + tmp_db, "user1", f"Episode {i}", + f"{yesterday}T{10+i}:00:00", + ) + + add_fn = MagicMock() + add_fn.return_value = {"results": [{"id": "sem_1", "event": "ADD"}]} + + distiller = ReplayDistiller(tmp_db, mock_llm, config) + result = distiller.run("user1", date_str=yesterday, memory_add_fn=add_fn) + + with tmp_db._get_connection() as conn: + rows = conn.execute("SELECT * FROM distillation_log WHERE user_id = 'user1'").fetchall() + assert len(rows) == 1 + assert rows[0]["episodes_sampled"] == 3 diff --git a/tests/test_forgetting.py b/tests/test_forgetting.py new file mode 100644 index 0000000..2b19bce --- /dev/null +++ b/tests/test_forgetting.py @@ -0,0 +1,154 @@ +"""Tests for engram.core.forgetting — Advanced forgetting mechanisms.""" + +import os +import tempfile +from unittest.mock import MagicMock, PropertyMock + +import pytest + +from engram.configs.base import DistillationConfig, FadeMemConfig +from engram.core.forgetting import ( + HomeostaticNormalizer, + InterferencePruner, + RedundancyCollapser, +) +from engram.db.sqlite import SQLiteManager + + +@pytest.fixture +def tmp_db(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + db = SQLiteManager(path) + yield db + db.close() + os.unlink(path) + + +@pytest.fixture +def fadem_config(): + return FadeMemConfig( + conflict_similarity_threshold=0.85, + forgetting_threshold=0.1, + ) + + +def _make_memory(mid, content, strength=0.5, embedding=None, immutable=False): + return { + "id": mid, + "memory": content, + "strength": strength, + "embedding": embedding or [0.1, 0.2], + "immutable": immutable, + } + + +class TestInterferencePruner: + def test_disabled(self, tmp_db, fadem_config): + config = DistillationConfig(enable_interference_pruning=False) + pruner = InterferencePruner(tmp_db, config, fadem_config) + result = pruner.run([_make_memory("m1", "test")]) + assert result == {"checked": 0, "demoted": 0} + + def test_no_search_fn(self, tmp_db, fadem_config): + config = DistillationConfig(enable_interference_pruning=True) + pruner = InterferencePruner(tmp_db, config, fadem_config) + result = pruner.run([_make_memory("m1", "test")]) + assert result == {"checked": 0, "demoted": 0} + + def test_skips_immutable(self, tmp_db, fadem_config): + config = DistillationConfig(enable_interference_pruning=True) + mock_search = MagicMock(return_value=[]) + mock_resolve = MagicMock() + pruner = InterferencePruner( + tmp_db, config, fadem_config, + resolve_conflict_fn=mock_resolve, + search_fn=mock_search, + ) + memories = [_make_memory("m1", "test", immutable=True)] + result = pruner.run(memories) + assert result["checked"] == 0 + + def test_skips_low_strength(self, tmp_db, fadem_config): + config = DistillationConfig(enable_interference_pruning=True) + mock_search = MagicMock(return_value=[]) + mock_resolve = MagicMock() + pruner = InterferencePruner( + tmp_db, config, fadem_config, + resolve_conflict_fn=mock_resolve, + search_fn=mock_search, + ) + memories = [_make_memory("m1", "test", strength=0.1)] + result = pruner.run(memories) + assert result["checked"] == 0 + + +class TestRedundancyCollapser: + def test_disabled(self, tmp_db): + config = DistillationConfig(enable_redundancy_collapse=False) + collapser = RedundancyCollapser(tmp_db, config) + result = collapser.run([_make_memory("m1", "test")]) + assert result == {"groups_fused": 0, "memories_fused": 0} + + def test_no_fuse_fn(self, tmp_db): + config = DistillationConfig(enable_redundancy_collapse=True) + collapser = RedundancyCollapser(tmp_db, config) + result = collapser.run([_make_memory("m1", "test")]) + assert result == {"groups_fused": 0, "memories_fused": 0} + + def test_skips_immutable(self, tmp_db): + config = DistillationConfig(enable_redundancy_collapse=True) + mock_search = MagicMock(return_value=[]) + mock_fuse = MagicMock() + collapser = RedundancyCollapser(tmp_db, config, fuse_fn=mock_fuse, search_fn=mock_search) + memories = [_make_memory("m1", "test", immutable=True)] + result = collapser.run(memories) + assert result["groups_fused"] == 0 + + +class TestHomeostaticNormalizer: + def test_disabled(self, tmp_db, fadem_config): + config = DistillationConfig(enable_homeostasis=False) + normalizer = HomeostaticNormalizer(tmp_db, config, fadem_config) + result = normalizer.run("user1") + assert result == {"namespaces_over_budget": 0, "pressured": 0, "forgotten": 0} + + def test_under_budget_no_action(self, tmp_db, fadem_config): + config = DistillationConfig( + enable_homeostasis=True, + homeostasis_budget_per_namespace=5000, + ) + # Add a few memories — well under budget + for i in range(3): + tmp_db.add_memory({ + "memory": f"Memory {i}", + "user_id": "user1", + "namespace": "default", + "strength": 0.5, + }) + + normalizer = HomeostaticNormalizer(tmp_db, config, fadem_config) + result = normalizer.run("user1") + assert result["namespaces_over_budget"] == 0 + + def test_over_budget_applies_pressure(self, tmp_db, fadem_config): + config = DistillationConfig( + enable_homeostasis=True, + homeostasis_budget_per_namespace=5, # Very low budget + homeostasis_pressure_factor=0.5, + ) + # Add 10 memories, which exceeds budget of 5 + for i in range(10): + tmp_db.add_memory({ + "memory": f"Memory {i}", + "user_id": "user1", + "namespace": "default", + "strength": 0.3, + }) + + mock_delete = MagicMock() + normalizer = HomeostaticNormalizer(tmp_db, config, fadem_config, delete_fn=mock_delete) + result = normalizer.run("user1") + assert result["namespaces_over_budget"] == 1 + # Some memories should have been pressured or deleted + assert result["pressured"] + result["forgotten"] > 0 diff --git a/tests/test_intent.py b/tests/test_intent.py new file mode 100644 index 0000000..31c7ef5 --- /dev/null +++ b/tests/test_intent.py @@ -0,0 +1,90 @@ +"""Tests for engram.core.intent — Query intent classifier.""" + +import pytest + +from engram.core.intent import QueryIntent, classify_intent + + +class TestEpisodicQueries: + def test_when_did(self): + assert classify_intent("When did we discuss the API?") == QueryIntent.EPISODIC + + def test_last_time(self): + assert classify_intent("What was the last time we talked about deployment?") == QueryIntent.EPISODIC + + def test_what_happened(self): + assert classify_intent("What happened in our meeting yesterday?") == QueryIntent.EPISODIC + + def test_ago(self): + assert classify_intent("What did I say 3 days ago?") == QueryIntent.EPISODIC + + def test_last_week(self): + assert classify_intent("What was discussed last week?") == QueryIntent.EPISODIC + + def test_we_discussed(self): + assert classify_intent("We discussed the bug fix for the login page") == QueryIntent.EPISODIC + + def test_i_told(self): + assert classify_intent("I told you about my new project") == QueryIntent.EPISODIC + + def test_what_did_i(self): + assert classify_intent("What did I mention about Python?") == QueryIntent.EPISODIC + + +class TestSemanticQueries: + def test_what_is(self): + assert classify_intent("What is the deployment process?") == QueryIntent.SEMANTIC + + def test_prefer(self): + assert classify_intent("What language do I prefer for backend?") == QueryIntent.SEMANTIC + + def test_favorite(self): + assert classify_intent("What's my favorite color?") == QueryIntent.SEMANTIC + + def test_how_to(self): + assert classify_intent("How to set up the development environment?") == QueryIntent.SEMANTIC + + def test_whats_my(self): + assert classify_intent("What's my email address?") == QueryIntent.SEMANTIC + + def test_procedure(self): + assert classify_intent("What's the procedure for code review?") == QueryIntent.SEMANTIC + + def test_workflow(self): + assert classify_intent("Tell me about the CI/CD workflow") == QueryIntent.SEMANTIC + + +class TestMixedQueries: + def test_empty_query(self): + assert classify_intent("") == QueryIntent.MIXED + + def test_whitespace(self): + assert classify_intent(" ") == QueryIntent.MIXED + + def test_ambiguous(self): + assert classify_intent("Python") == QueryIntent.MIXED + + def test_no_signals(self): + assert classify_intent("project update") == QueryIntent.MIXED + + def test_both_signals(self): + # "what is" (semantic) + "last time" (episodic) — may be mixed + result = classify_intent("What is something we said last time?") + assert result in (QueryIntent.MIXED, QueryIntent.EPISODIC, QueryIntent.SEMANTIC) + + +class TestEdgeCases: + def test_none_like(self): + assert classify_intent("") == QueryIntent.MIXED + + def test_case_insensitive(self): + assert classify_intent("WHEN DID we talk?") == QueryIntent.EPISODIC + assert classify_intent("WHAT IS my name?") == QueryIntent.SEMANTIC + + def test_single_word_no_crash(self): + classify_intent("hello") + + def test_very_long_query(self): + long_query = "when did " * 100 + result = classify_intent(long_query) + assert result == QueryIntent.EPISODIC diff --git a/tests/test_memory_types.py b/tests/test_memory_types.py new file mode 100644 index 0000000..6f4d0c3 --- /dev/null +++ b/tests/test_memory_types.py @@ -0,0 +1,201 @@ +"""Tests for CLS memory type classification and backward compatibility.""" + +import os +import tempfile + +import pytest + +from engram.configs.base import DistillationConfig, MemoryConfig +from engram.core.intent import QueryIntent, classify_intent +from engram.core.traces import compute_effective_strength, initialize_traces +from engram.db.sqlite import SQLiteManager + + +@pytest.fixture +def tmp_db(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + db = SQLiteManager(path) + yield db + db.close() + os.unlink(path) + + +class TestDistillationConfig: + def test_defaults_all_enabled(self): + config = DistillationConfig() + assert config.enable_memory_types is True + assert config.enable_distillation is True + assert config.enable_interference_pruning is True + assert config.enable_redundancy_collapse is True + assert config.enable_homeostasis is True + assert config.enable_multi_trace is True + assert config.enable_intent_routing is True + + def test_memory_config_has_distillation(self): + mc = MemoryConfig() + assert hasattr(mc, "distillation") + assert isinstance(mc.distillation, DistillationConfig) + + def test_version_updated(self): + mc = MemoryConfig() + assert mc.version == "v1.4" + + def test_default_config_has_cls_enabled(self): + """Verify a default MemoryConfig has CLS features enabled.""" + mc = MemoryConfig() + assert mc.distillation.enable_memory_types is True + assert mc.distillation.default_memory_type == "semantic" + + def test_custom_config(self): + config = DistillationConfig( + enable_memory_types=True, + enable_multi_trace=True, + enable_intent_routing=True, + ) + assert config.enable_memory_types is True + assert config.enable_multi_trace is True + assert config.enable_intent_routing is True + + +class TestDBMemoryTypeColumn: + def test_memory_type_column_exists(self, tmp_db): + """Verify memory_type column was added by migration.""" + with tmp_db._get_connection() as conn: + row = conn.execute( + "PRAGMA table_info(memories)" + ).fetchall() + columns = {r["name"] for r in row} + assert "memory_type" in columns + assert "s_fast" in columns + assert "s_mid" in columns + assert "s_slow" in columns + + def test_add_memory_with_type(self, tmp_db): + mid = tmp_db.add_memory({ + "memory": "Test episodic memory", + "user_id": "user1", + "memory_type": "episodic", + }) + mem = tmp_db.get_memory(mid) + assert mem["memory_type"] == "episodic" + + def test_add_memory_default_type(self, tmp_db): + mid = tmp_db.add_memory({ + "memory": "Test default memory", + "user_id": "user1", + }) + mem = tmp_db.get_memory(mid) + assert mem["memory_type"] == "semantic" + + def test_add_memory_with_traces(self, tmp_db): + mid = tmp_db.add_memory({ + "memory": "Traced memory", + "user_id": "user1", + "s_fast": 0.8, + "s_mid": 0.0, + "s_slow": 0.0, + }) + mem = tmp_db.get_memory(mid) + assert mem["s_fast"] == 0.8 + assert mem["s_mid"] == 0.0 + assert mem["s_slow"] == 0.0 + + def test_update_multi_trace(self, tmp_db): + mid = tmp_db.add_memory({ + "memory": "Trace update test", + "user_id": "user1", + "strength": 1.0, + "s_fast": 1.0, + "s_mid": 0.0, + "s_slow": 0.0, + }) + tmp_db.update_multi_trace(mid, 0.5, 0.3, 0.1, 0.35) + mem = tmp_db.get_memory(mid) + assert mem["s_fast"] == 0.5 + assert mem["s_mid"] == 0.3 + assert mem["s_slow"] == 0.1 + assert mem["strength"] == pytest.approx(0.35) + + +class TestDBEpisodicMemories: + def test_get_episodic_memories(self, tmp_db): + tmp_db.add_memory({ + "memory": "Episodic 1", + "user_id": "user1", + "memory_type": "episodic", + }) + tmp_db.add_memory({ + "memory": "Semantic 1", + "user_id": "user1", + "memory_type": "semantic", + }) + eps = tmp_db.get_episodic_memories("user1") + assert len(eps) == 1 + assert eps[0]["memory"] == "Episodic 1" + + def test_get_episodic_empty(self, tmp_db): + eps = tmp_db.get_episodic_memories("nonexistent") + assert eps == [] + + +class TestDBDistillationTables: + def test_distillation_log(self, tmp_db): + run_id = tmp_db.log_distillation_run( + "user1", + episodes_sampled=10, + semantic_created=2, + semantic_deduplicated=1, + ) + assert run_id # non-empty string + + def test_distillation_provenance(self, tmp_db): + tmp_db.add_distillation_provenance( + semantic_memory_id="sem_1", + episodic_memory_ids=["ep_1", "ep_2"], + run_id="run_1", + ) + with tmp_db._get_connection() as conn: + rows = conn.execute( + "SELECT * FROM distillation_provenance WHERE semantic_memory_id = 'sem_1'" + ).fetchall() + assert len(rows) == 2 + + def test_memory_count_by_namespace(self, tmp_db): + for i in range(3): + tmp_db.add_memory({ + "memory": f"Mem {i}", + "user_id": "user1", + "namespace": "default", + }) + tmp_db.add_memory({ + "memory": "Work mem", + "user_id": "user1", + "namespace": "work", + }) + counts = tmp_db.get_memory_count_by_namespace("user1") + assert counts.get("default", 0) == 3 + assert counts.get("work", 0) == 1 + + +class TestMultiTraceIntegration: + def test_trace_lifecycle(self): + """Test the full lifecycle: initialize -> compute -> cascade.""" + config = DistillationConfig(enable_multi_trace=True) + + # New memory: all in fast + s_f, s_m, s_s = initialize_traces(0.9, is_new=True) + assert s_f == 0.9 + assert s_m == 0.0 + assert s_s == 0.0 + + eff = compute_effective_strength(s_f, s_m, s_s, config) + assert eff == pytest.approx(0.2 * 0.9) # Only fast has value + + # After deep sleep cascade + from engram.core.traces import cascade_traces + s_f2, s_m2, s_s2 = cascade_traces(s_f, s_m, s_s, config, deep_sleep=True) + assert s_m2 > 0.0 # Some transferred to mid + eff2 = compute_effective_strength(s_f2, s_m2, s_s2, config) + # Total energy should shift across traces but weighted sum may differ + assert eff2 > 0.0 diff --git a/tests/test_sqlite_vec.py b/tests/test_sqlite_vec.py index 8712a21..eb5d24f 100644 --- a/tests/test_sqlite_vec.py +++ b/tests/test_sqlite_vec.py @@ -7,7 +7,8 @@ sqlite_vec = pytest.importorskip("sqlite_vec", reason="sqlite-vec not installed") -from engram.vector_stores.sqlite_vec import SqliteVecStore, MemoryResult +from engram.vector_stores.base import MemoryResult +from engram.vector_stores.sqlite_vec import SqliteVecStore @pytest.fixture diff --git a/tests/test_traces.py b/tests/test_traces.py new file mode 100644 index 0000000..68e9d1c --- /dev/null +++ b/tests/test_traces.py @@ -0,0 +1,160 @@ +"""Tests for engram.core.traces — Benna-Fusi multi-timescale strength model.""" + +import math +from datetime import datetime, timedelta, timezone + +import pytest + +from engram.configs.base import DistillationConfig +from engram.core.traces import ( + boost_fast_trace, + cascade_traces, + compute_effective_strength, + decay_traces, + initialize_traces, +) + + +@pytest.fixture +def config(): + return DistillationConfig(enable_multi_trace=True) + + +class TestInitializeTraces: + def test_new_memory_all_in_fast(self): + s_fast, s_mid, s_slow = initialize_traces(0.8, is_new=True) + assert s_fast == 0.8 + assert s_mid == 0.0 + assert s_slow == 0.0 + + def test_migrated_memory_spread(self): + s_fast, s_mid, s_slow = initialize_traces(0.6, is_new=False) + assert s_fast == 0.6 + assert s_mid == pytest.approx(0.3) + assert s_slow == 0.0 + + def test_zero_strength(self): + s_fast, s_mid, s_slow = initialize_traces(0.0, is_new=True) + assert s_fast == 0.0 + assert s_mid == 0.0 + assert s_slow == 0.0 + + def test_strength_clamped(self): + s_fast, s_mid, s_slow = initialize_traces(1.5, is_new=True) + assert s_fast == 1.0 + assert s_mid == 0.0 + assert s_slow == 0.0 + + def test_negative_clamped(self): + s_fast, _, _ = initialize_traces(-0.5, is_new=True) + assert s_fast == 0.0 + + +class TestComputeEffectiveStrength: + def test_default_weights(self, config): + # 0.2*1.0 + 0.3*0.5 + 0.5*0.0 = 0.35 + eff = compute_effective_strength(1.0, 0.5, 0.0, config) + assert eff == pytest.approx(0.35) + + def test_all_ones(self, config): + # 0.2*1 + 0.3*1 + 0.5*1 = 1.0 + eff = compute_effective_strength(1.0, 1.0, 1.0, config) + assert eff == pytest.approx(1.0) + + def test_all_zeros(self, config): + eff = compute_effective_strength(0.0, 0.0, 0.0, config) + assert eff == 0.0 + + def test_slow_dominates(self, config): + # 0.2*0 + 0.3*0 + 0.5*0.8 = 0.4 + eff = compute_effective_strength(0.0, 0.0, 0.8, config) + assert eff == pytest.approx(0.4) + + def test_clamped_to_unit(self, config): + eff = compute_effective_strength(1.0, 1.0, 1.0, config) + assert eff <= 1.0 + + +class TestDecayTraces: + def test_recent_memory_minimal_decay(self, config): + now = datetime.now(timezone.utc) + s_f, s_m, s_s = decay_traces(1.0, 0.5, 0.2, now, 0, config) + assert s_f == pytest.approx(1.0, abs=0.01) + assert s_m == pytest.approx(0.5, abs=0.01) + assert s_s == pytest.approx(0.2, abs=0.01) + + def test_old_memory_significant_decay(self, config): + old = datetime.now(timezone.utc) - timedelta(days=10) + s_f, s_m, s_s = decay_traces(1.0, 1.0, 1.0, old, 0, config) + # Fast decays fastest + assert s_f < s_m < s_s + assert s_f < 1.0 + + def test_access_count_dampens_decay(self, config): + old = datetime.now(timezone.utc) - timedelta(days=5) + # No accesses + f0, m0, s0 = decay_traces(1.0, 1.0, 1.0, old, 0, config) + # Many accesses + f10, m10, s10 = decay_traces(1.0, 1.0, 1.0, old, 10, config) + # More accesses = less decay + assert f10 > f0 + assert m10 > m0 + + def test_string_last_accessed(self, config): + now = datetime.now(timezone.utc).isoformat() + s_f, s_m, s_s = decay_traces(1.0, 0.5, 0.2, now, 0, config) + assert s_f == pytest.approx(1.0, abs=0.01) + + def test_values_clamped(self, config): + old = datetime.now(timezone.utc) - timedelta(days=100) + s_f, s_m, s_s = decay_traces(1.0, 1.0, 1.0, old, 0, config) + assert s_f >= 0.0 + assert s_m >= 0.0 + assert s_s >= 0.0 + + +class TestCascadeTraces: + def test_normal_cascade_fast_to_mid(self, config): + s_f, s_m, s_s = cascade_traces(1.0, 0.0, 0.0, config, deep_sleep=False) + # 10% of fast goes to mid + assert s_f == pytest.approx(0.9) + assert s_m == pytest.approx(0.1) + assert s_s == pytest.approx(0.0) + + def test_deep_sleep_cascade(self, config): + s_f, s_m, s_s = cascade_traces(1.0, 0.5, 0.0, config, deep_sleep=True) + # Fast -> mid: fast loses 0.1 (10%), mid gains 0.1 -> mid = 0.6 + # Mid -> slow: mid loses 0.6*0.05 = 0.03, slow gains 0.03 + assert s_f == pytest.approx(0.9) + assert s_m == pytest.approx(0.57) + assert s_s == pytest.approx(0.03) + + def test_no_cascade_from_zero(self, config): + s_f, s_m, s_s = cascade_traces(0.0, 0.0, 0.0, config, deep_sleep=True) + assert s_f == 0.0 + assert s_m == 0.0 + assert s_s == 0.0 + + def test_values_clamped(self, config): + s_f, s_m, s_s = cascade_traces(1.0, 1.0, 1.0, config, deep_sleep=True) + assert 0.0 <= s_f <= 1.0 + assert 0.0 <= s_m <= 1.0 + assert 0.0 <= s_s <= 1.0 + + +class TestBoostFastTrace: + def test_basic_boost(self): + result = boost_fast_trace(0.5, 0.1) + assert result == pytest.approx(0.6) + + def test_clamped_at_one(self): + result = boost_fast_trace(0.95, 0.1) + assert result == 1.0 + + def test_zero_boost(self): + result = boost_fast_trace(0.5, 0.0) + assert result == 0.5 + + def test_clamped_at_zero(self): + result = boost_fast_trace(0.0, -0.5) + assert result == 0.0