diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index ef2b127d45..e4152b77d1 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -12,7 +12,13 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload from .._agents import BaseAgent -from .._sessions import AgentSession, BaseContextProvider, BaseHistoryProvider, SessionContext +from .._sessions import ( + AgentSession, + BaseContextProvider, + BaseHistoryProvider, + InMemoryHistoryProvider, + SessionContext, +) from .._types import ( AgentResponse, AgentResponseUpdate, @@ -112,7 +118,17 @@ def __init__( if not any(is_type_compatible(list[Message], input_type) for input_type in start_executor.input_types): raise ValueError("Workflow's start executor cannot handle list[Message]") - super().__init__(id=id, name=name, description=description, context_providers=context_providers, **kwargs) + resolved_context_providers = list(context_providers) if context_providers is not None else [] + if not resolved_context_providers: + resolved_context_providers.append(InMemoryHistoryProvider("memory")) + + super().__init__( + id=id, + name=name, + description=description, + context_providers=resolved_context_providers, + **kwargs, + ) self._workflow: Workflow = workflow self._pending_requests: dict[str, WorkflowEvent[Any]] = {} diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 5adf82dd57..b2fbded39b 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -14,6 +14,7 @@ AgentSession, Content, Executor, + InMemoryHistoryProvider, Message, ResponseStream, SupportsAgentRun, @@ -562,6 +563,35 @@ async def test_empty_session_works_correctly(self) -> None: assert len(capturing_executor.received_messages) == 1 assert capturing_executor.received_messages[0].text == "Just a new message" + async def test_workflow_as_agent_adds_default_history_provider(self) -> None: + """Test that workflow.as_agent() defaults to in-memory history when no providers are configured.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="default_history_provider_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + agent = workflow.as_agent(name="Default History Provider Agent") + session = AgentSession() + + await agent.run("first message", session=session) + await agent.run("second message", session=session) + + assert any(isinstance(provider, InMemoryHistoryProvider) for provider in agent.context_providers) + texts = [message.text for message in capturing_executor.received_messages] + assert "first message" in texts + assert "second message" in texts + + async def test_workflow_agent_keeps_explicit_context_providers(self) -> None: + """Test that WorkflowAgent does not append defaults when context providers are explicitly provided.""" + workflow = WorkflowBuilder( + start_executor=ConversationHistoryCapturingExecutor(id="explicit_provider_test") + ).build() + explicit_provider = InMemoryHistoryProvider("custom-memory") + agent = WorkflowAgent( + workflow=workflow, + name="Explicit Provider Agent", + context_providers=[explicit_provider], + ) + + assert agent.context_providers == [explicit_provider] + async def test_checkpoint_storage_passed_to_workflow(self) -> None: """Test that checkpoint_storage parameter is passed through to the workflow.""" from agent_framework import InMemoryCheckpointStorage