Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
logger = get_logger()

__all__ = [
"AdditiveToolsList",
"FunctionInvocationConfiguration",
"FunctionInvocationLayer",
"FunctionTool",
Expand Down Expand Up @@ -1285,6 +1286,94 @@ def _get_tool_map(
return tool_list


class AdditiveToolsList:
"""Thread-safe wrapper for appending tools during parallel function execution.

This wrapper ensures that tools can be safely added to the tools list when multiple
function calls execute concurrently via asyncio.gather. Only append and extend
operations are supported to maintain a clear API focused on the dynamic tool loading
use case.

Note:
Only additive operations (append, extend) are supported because the model's context
already contains the existing tools from the conversation. Removing or modifying
tools during execution could result in the model calling tools that no longer exist,
leading to errors. Adding new tools is safe as it expands the model's capabilities
without invalidating its existing context.

Tools can access the tools list via ``kwargs.get("_framework_tools")`` in functions
that accept ``**kwargs``. The framework uses the reserved key ``"_framework_tools"``
to avoid conflicts with user-supplied kwargs.

Uses a threading.Lock for synchronization. While this briefly blocks the event loop,
the blocking time is negligible (sub-millisecond) since list append/extend operations
are extremely fast and few.

Example:
.. code-block:: python

from agent_framework import tool
from typing import Any


@tool(approval_mode="never_require")
def load_tools(category: str, **kwargs: Any) -> str:
# Access via reserved framework key
tools_list = kwargs.get("_framework_tools")
if tools_list and category == "math":
# Thread-safe mutation
tools_list.append(some_tool)
return "Tools loaded"
"""

def __init__(self, wrapped_list: list[Any]) -> None:
"""Initialize the thread-safe tools list wrapper.

Args:
wrapped_list: The underlying list to wrap.
"""
import threading

self._list = wrapped_list
self._lock = threading.Lock()

# Mutation methods - require lock
def append(self, item: Any) -> None:
"""Append item to the tools list (thread-safe)."""
with self._lock:
self._list.append(item)

def extend(self, items: Sequence[Any]) -> None:
"""Extend the tools list with items (thread-safe)."""
with self._lock:
self._list.extend(items)

# Read operations - no lock needed (safe in async)
def __getitem__(self, index: int | slice) -> Any:
"""Get item at index."""
return self._list[index]

def __len__(self) -> int:
"""Get length of the tools list."""
return len(self._list)

def __iter__(self) -> Any:
"""Iterate over the tools list."""
return iter(self._list)

def __contains__(self, item: Any) -> bool:
"""Check if item is in the tools list."""
return item in self._list

def __repr__(self) -> str:
"""Return string representation."""
return f"AdditiveToolsList({self._list!r})"

def __bool__(self) -> bool:
"""Check if list is non-empty."""
return bool(self._list)


async def _try_execute_function_calls(
custom_args: dict[str, Any],
attempt_idx: int,
Expand Down Expand Up @@ -1316,6 +1405,31 @@ async def _try_execute_function_calls(
from ._types import Content

tool_map = _get_tool_map(tools)

# Make tools list available to functions that accept **kwargs
# Use the same list object if tools is already a list so modifications persist
tools_list: list[Any]
if isinstance(tools, list):
# Use the same list object so modifications persist
tools_list = tools
elif isinstance(tools, Sequence) and not isinstance(tools, (str, bytes)):
# Convert other sequences to list
tools_list = list(tools)
elif tools is not None:
tools_list = [tools]
else:
tools_list = []

# Wrap the tools list in a thread-safe wrapper to prevent race conditions
# when multiple function calls execute concurrently via asyncio.gather
additive_tools = AdditiveToolsList(tools_list)

# Use a reserved framework key "_framework_tools" instead of "tools" to prevent
# overwriting user-supplied additional_function_arguments["tools"] values.
# Tools with **kwargs can access this via kwargs.get("_framework_tools").
custom_args_with_tools = dict(custom_args)
custom_args_with_tools["_framework_tools"] = additive_tools

approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"]
logger.debug(
"_try_execute_function_calls: tool_map keys=%s, approval_tools=%s",
Expand Down Expand Up @@ -1380,7 +1494,7 @@ async def invoke_with_termination_handling(
try:
result = await _auto_invoke_function(
function_call_content=function_call, # type: ignore[arg-type]
custom_args=custom_args,
custom_args=custom_args_with_tools,
tool_map=tool_map,
sequence_index=seq_idx,
request_index=attempt_idx,
Expand Down
15 changes: 4 additions & 11 deletions python/packages/core/agent_framework/azure/_responses_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def __init__(
env_file_encoding: str | None = None,
instruction_role: str | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration
| None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Azure OpenAI Responses client.
Expand Down Expand Up @@ -190,9 +189,7 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False):
deployment_name = str(model_id)

# Project client path: create OpenAI client from an Azure AI Foundry project
if async_client is None and (
project_client is not None or project_endpoint is not None
):
if async_client is None and (project_client is not None or project_endpoint is not None):
async_client = self._create_client_from_project(
project_client=project_client,
project_endpoint=project_endpoint,
Expand Down Expand Up @@ -221,9 +218,7 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False):
and (hostname := urlparse(str(azure_openai_settings["endpoint"])).hostname)
and hostname.endswith(".openai.azure.com")
):
azure_openai_settings["base_url"] = urljoin(
str(azure_openai_settings["endpoint"]), "/openai/v1/"
)
azure_openai_settings["base_url"] = urljoin(str(azure_openai_settings["endpoint"]), "/openai/v1/")

if not azure_openai_settings["responses_deployment_name"]:
raise ServiceInitializationError(
Expand All @@ -236,9 +231,7 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False):
endpoint=azure_openai_settings["endpoint"],
base_url=azure_openai_settings["base_url"],
api_version=azure_openai_settings["api_version"], # type: ignore
api_key=azure_openai_settings["api_key"].get_secret_value()
if azure_openai_settings["api_key"]
else None,
api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
token_endpoint=azure_openai_settings["token_endpoint"],
Expand Down
Loading
Loading