Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 92 additions & 20 deletions python/packages/azure-ai/agent_framework_azure_ai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

import json
import sys
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from contextlib import suppress
from typing import Any, ClassVar, Generic, Literal, TypedDict, TypeVar, cast

from agent_framework import (
Expand Down Expand Up @@ -218,6 +220,10 @@ class MyOptions(ChatOptions, total=False):
self._is_application_endpoint = "/applications/" in project_client._config.endpoint # type: ignore
# Track whether we should close client connection
self._should_close_client = should_close_client
# Track creation-time agent configuration for runtime mismatch warnings.
self.warn_runtime_tools_and_structure_changed = False
self._created_agent_tool_names: set[str] = set()
self._created_agent_structured_output_signature: str | None = None

async def configure_azure_monitor(
self,
Expand Down Expand Up @@ -341,18 +347,18 @@ async def _get_agent_reference_or_create(
"Agent name is required. Provide 'agent_name' when initializing AzureAIClient "
"or 'name' when initializing Agent."
)
# If the agent exists and we do not want to track agent configuration, return early
if self.agent_version is not None and not self.warn_runtime_tools_and_structure_changed:
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}

# If no agent_version is provided, either use latest version or create a new agent:
if self.agent_version is None:
# Try to use latest version if requested and agent exists
if self.use_latest_version:
try:
with suppress(ResourceNotFoundError):
existing_agent = await self.project_client.agents.get(self.agent_name)
self.agent_version = existing_agent.versions.latest.version
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}
except ResourceNotFoundError:
# Agent doesn't exist, fall through to creation logic
pass

if "model" not in run_options or not run_options["model"]:
raise ServiceInitializationError(
Expand Down Expand Up @@ -395,6 +401,26 @@ async def _get_agent_reference_or_create(
)

self.agent_version = created_agent.version
self.warn_runtime_tools_and_structure_changed = True
self._created_agent_tool_names = self._extract_tool_names(run_options.get("tools"))
self._created_agent_structured_output_signature = self._get_structured_output_signature(chat_options)
else:
runtime_tools = run_options.get("tools")
tools_changed = False
if runtime_tools is not None:
tools_changed = self._extract_tool_names(runtime_tools) != self._created_agent_tool_names

runtime_structured_output = self._get_structured_output_signature(chat_options)
structured_output_changed = (
runtime_structured_output is not None
and runtime_structured_output != self._created_agent_structured_output_signature
)

if tools_changed or structured_output_changed:
logger.warning(
"AzureAIClient does not support runtime tools or structured_output overrides after agent creation. "
"Use AzureOpenAIResponsesClient instead."
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the issue here that runtime tools or structured output is not supported so we are trying to catch the case where they happen to match what the agent is configured with and allow that to pass?

One issue with this approach is that if the agent changes on the server the client code stops working even though it didn't change. Would it be just better to ignore the runtime options and just log a warning?


return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}

Expand All @@ -403,6 +429,66 @@ async def _close_client_if_needed(self) -> None:
if self._should_close_client:
await self.project_client.close()

def _extract_tool_names(self, tools: Any) -> set[str]:
"""Extract comparable tool names from runtime tool payloads."""
if not isinstance(tools, Sequence) or isinstance(tools, str | bytes):
return set()
return {self._get_tool_name(tool) for tool in tools}

def _get_tool_name(self, tool: Any) -> str:
"""Get a stable name for a tool for runtime comparison."""
if isinstance(tool, FunctionTool):
return tool.name
if isinstance(tool, Mapping):
tool_type = tool.get("type")
if tool_type == "function":
if isinstance(function_data := tool.get("function"), Mapping) and function_data.get("name"):
return str(function_data["name"])
if tool.get("name"):
return str(tool["name"])
if tool.get("name"):
return str(tool["name"])
if tool.get("server_label"):
return f"mcp:{tool['server_label']}"
if tool_type:
return str(tool_type)
if getattr(tool, "name", None):
return str(tool.name)
if getattr(tool, "server_label", None):
return f"mcp:{tool.server_label}"
if getattr(tool, "type", None):
return str(tool.type)
return type(tool).__name__

def _get_structured_output_signature(self, chat_options: Mapping[str, Any] | None) -> str | None:
"""Build a stable signature for structured_output/response_format values."""
if not chat_options:
return None
response_format = chat_options.get("response_format")
if response_format is None:
return None
if isinstance(response_format, type):
return f"{response_format.__module__}.{response_format.__qualname__}"
if isinstance(response_format, Mapping):
return json.dumps(response_format, sort_keys=True, default=str)
return str(response_format)

def _remove_agent_level_run_options(self, run_options: dict[str, Any]) -> None:
"""Remove request-level options that Azure AI only supports at agent creation time."""
agent_level_option_to_run_keys = {
"model_id": ("model",),
"tools": ("tools",),
"response_format": ("response_format", "text", "text_format"),
"rai_config": ("rai_config",),
"temperature": ("temperature",),
"top_p": ("top_p",),
"reasoning": ("reasoning",),
}

for run_keys in agent_level_option_to_run_keys.values():
for run_key in run_keys:
run_options.pop(run_key, None)

@override
async def _prepare_options(
self,
Expand All @@ -427,22 +513,8 @@ async def _prepare_options(
agent_reference = await self._get_agent_reference_or_create(run_options, instructions, options)
run_options["extra_body"] = {"agent": agent_reference}

# Remove properties that are not supported on request level
# but were configured on agent level
exclude = [
"model",
"tools",
"response_format",
"rai_config",
"temperature",
"top_p",
"text",
"text_format",
"reasoning",
]

for property in exclude:
run_options.pop(property, None)
# Remove only keys that map to this client's declared options TypedDict.
self._remove_agent_level_run_options(run_options)

return run_options

Expand Down
95 changes: 95 additions & 0 deletions python/packages/azure-ai/tests/test_azure_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def create_test_azure_ai_client(
client.conversation_id = conversation_id
client._is_application_endpoint = False # type: ignore
client._should_close_client = should_close_client # type: ignore
client.warn_runtime_tools_and_structure_changed = False # type: ignore
client._created_agent_tool_names = set() # type: ignore
client._created_agent_structured_output_signature = None # type: ignore
client.additional_properties = {}
client.middleware = None

Expand Down Expand Up @@ -773,6 +776,31 @@ async def test_agent_creation_with_tools(
assert call_args[1]["definition"].tools == test_tools


async def test_runtime_tools_override_logs_warning(
mock_project_client: MagicMock,
) -> None:
"""Test warning is logged when runtime tools differ from creation-time tools."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent")

mock_agent = MagicMock()
mock_agent.name = "test-agent"
mock_agent.version = "1.0"
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)

await client._get_agent_reference_or_create(
{"model": "test-model", "tools": [{"type": "function", "name": "tool_one"}]},
None,
)

with patch("agent_framework_azure_ai._client.logger.warning") as mock_warning:
await client._get_agent_reference_or_create(
{"model": "test-model", "tools": [{"type": "function", "name": "tool_two"}]},
None,
)
mock_warning.assert_called_once()
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]


async def test_use_latest_version_existing_agent(
mock_project_client: MagicMock,
) -> None:
Expand Down Expand Up @@ -872,6 +900,13 @@ class ResponseFormatModel(BaseModel):
model_config = ConfigDict(extra="forbid")


class AlternateResponseFormatModel(BaseModel):
"""Alternate model for structured output warning checks."""

summary: str
confidence: float


async def test_agent_creation_with_response_format(
mock_project_client: MagicMock,
) -> None:
Expand Down Expand Up @@ -964,6 +999,33 @@ async def test_agent_creation_with_mapping_response_format(
assert format_config.strict is True


async def test_runtime_structured_output_override_logs_warning(
mock_project_client: MagicMock,
) -> None:
"""Test warning is logged when runtime structured_output differs from creation-time configuration."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent")

mock_agent = MagicMock()
mock_agent.name = "test-agent"
mock_agent.version = "1.0"
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)

await client._get_agent_reference_or_create(
{"model": "test-model"},
None,
{"response_format": ResponseFormatModel},
)

with patch("agent_framework_azure_ai._client.logger.warning") as mock_warning:
await client._get_agent_reference_or_create(
{"model": "test-model"},
None,
{"response_format": AlternateResponseFormatModel},
)
mock_warning.assert_called_once()
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]


async def test_prepare_options_excludes_response_format(
mock_project_client: MagicMock,
) -> None:
Expand Down Expand Up @@ -1001,6 +1063,39 @@ async def test_prepare_options_excludes_response_format(
assert run_options["extra_body"]["agent"]["name"] == "test-agent"


async def test_prepare_options_keeps_values_for_unsupported_option_keys(
mock_project_client: MagicMock,
) -> None:
"""Test that run_options removal only applies to known AzureAI agent-level option mappings."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]

with (
patch(
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
return_value={
"model": "test-model",
"tools": [{"type": "function", "name": "weather"}],
"text": {"format": {"type": "json_schema", "name": "schema"}},
"text_format": ResponseFormatModel,
"custom_option": "keep-me",
},
),
patch.object(
client,
"_get_agent_reference_or_create",
return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"},
),
):
run_options = await client._prepare_options(messages, {})

assert "model" not in run_options
assert "tools" not in run_options
assert "text" not in run_options
assert "text_format" not in run_options
assert run_options["custom_option"] == "keep-me"


def test_get_conversation_id_with_store_true_and_conversation_id() -> None:
"""Test _get_conversation_id returns conversation ID when store is True and conversation exists."""
client = create_test_azure_ai_client(MagicMock())
Expand Down