-
Notifications
You must be signed in to change notification settings - Fork 76
Expand file tree
/
Copy pathfunction_calling.py
More file actions
146 lines (112 loc) · 4.74 KB
/
function_calling.py
File metadata and controls
146 lines (112 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import json
from typing import Literal, Sequence
from absl import app, flags
from pydantic import BaseModel, Field
from xai_sdk import Client
from xai_sdk.chat import system, tool, tool_result, user
# Boolean command-line flag; main() reads STREAM.value to choose between the
# streaming and non-streaming variants of the demo.
STREAM = flags.DEFINE_bool(
    name="stream",
    default=False,
    help="Whether streaming is enabled.",
)
def function_calling(client: Client) -> None:
    """Run an interactive multi-turn chat in which Grok can call a weather tool.

    Prompts are read from stdin until the user types ``exit`` (case and
    surrounding whitespace ignored) or stdin reaches end-of-file. Whenever the
    model replies with tool calls, every call is executed locally, its result
    is appended to the conversation, and the model is sampled again. This
    repeats until the model replies without requesting any tools, so chained
    tool calls are fully resolved before the next prompt is read.

    Args:
        client: SDK client used to open the chat session.
    """

    def get_weather(city: str, units: Literal["C", "F"]) -> str:
        # Canned demo data: no real weather service is queried.
        temperature = 20 if units == "C" else 68
        return f"The weather in {city} is sunny at a temperature of {temperature} {units}."

    def run_tool(name: str, arguments: str) -> str:
        """Execute one tool call and return its textual result.

        Problems (unknown tool, malformed JSON, missing or invalid arguments)
        are returned as text, so the model sees the error instead of the
        whole chat loop crashing.
        """
        if name != "get_weather":
            return f"Error: unknown tool {name!r}."
        try:
            tool_args = json.loads(arguments)
            city, units = tool_args["city"], tool_args["units"]
        except (ValueError, KeyError, TypeError) as err:
            # json.JSONDecodeError is a ValueError; KeyError/TypeError mean the
            # decoded value is not an object carrying the expected keys.
            return f"Error: invalid arguments for get_weather: {err!r}"
        if units not in ("C", "F"):
            return f"Error: invalid arguments for get_weather: units must be 'C' or 'F', got {units!r}"
        return get_weather(city, units)

    def show(response) -> None:
        # Print a non-empty model reply on its own "Grok: " line.
        if response.content:
            print(f"Grok: {response.content}", flush=True)

    chat = client.chat.create(
        model="grok-3",
        messages=[system("You are a helpful assistant that can answer questions and help with tasks.")],
        tools=[
            tool(
                name="get_weather",
                description="Get the weather for a given city.",
                parameters={
                    "type": "object",
                    "properties": {
                        "city": {"type": "string", "description": "The name of the city to get the weather for."},
                        "units": {
                            "type": "string",
                            "description": "The units to use for the temperature.",
                            "enum": ["C", "F"],
                        },
                    },
                    "required": ["city", "units"],
                },
            ),
        ],
    )

    while True:
        try:
            user_input = input("You: ")
        except EOFError:
            print()  # Terminate the dangling "You: " prompt line.
            break
        if user_input.strip().lower() == "exit":
            break

        chat.append(user(user_input))
        response = chat.sample()
        chat.append(response)
        show(response)

        # The model may request further tools after seeing earlier results,
        # so keep answering until it produces a reply with no tool calls.
        while response.tool_calls:
            for tool_call in response.tool_calls:
                result = run_tool(tool_call.function.name, tool_call.function.arguments)
                chat.append(tool_result(result))
            response = chat.sample()
            chat.append(response)
            show(response)
def function_calling_streaming(client: Client) -> None:
    """Run an interactive multi-turn chat with streamed replies and a weather tool.

    Prompts are read from stdin until the user types ``exit`` (case and
    surrounding whitespace ignored) or stdin reaches end-of-file. Each model
    reply is printed chunk by chunk as it streams in. Whenever a reply
    contains tool calls, every call is executed locally, its result is
    appended to the conversation, and a new reply is streamed. This repeats
    until the model stops requesting tools.

    Args:
        client: SDK client used to open the chat session.

    Raises:
        RuntimeError: If a stream finishes without yielding any response.
    """

    # Define the shape of the tool call arguments as a Pydantic model. Its JSON
    # schema is advertised to the model, and it validates incoming arguments.
    class GetWeatherRequest(BaseModel):
        city: str = Field(description="The name of the city to get the weather for.")
        units: Literal["C", "F"] = Field(description="The units to use for the temperature.")

    def get_weather(request: GetWeatherRequest) -> str:
        # Canned demo data: no real weather service is queried.
        temperature = 20 if request.units == "C" else 68
        return f"The weather in {request.city} is sunny at a temperature of {temperature} {request.units}."

    def run_tool(name: str, arguments: str) -> str:
        """Execute one tool call and return its textual result.

        Unknown tools and arguments that fail validation are reported back to
        the model as text instead of crashing the chat loop.
        """
        if name != "get_weather":
            return f"Error: unknown tool {name!r}."
        try:
            # Validate the tool call arguments as a Pydantic model and get proper type checking.
            request = GetWeatherRequest.model_validate_json(arguments)
        except ValueError as err:  # pydantic.ValidationError subclasses ValueError.
            return f"Error: invalid arguments for get_weather: {err}"
        return get_weather(request)

    conversation = client.chat.create(
        model="grok-3",
        messages=[system("You are a helpful assistant that can answer questions and help with tasks.")],
        tools=[
            tool(
                name="get_weather",
                description="Get the weather for a given city.",
                parameters=GetWeatherRequest.model_json_schema(),
            )
        ],
    )

    def stream_reply():
        """Stream the next reply to stdout, append it to the conversation, and return it."""
        last_response = None
        printed_prefix = False
        for response, chunk in conversation.stream():
            if chunk.content:
                # Emit the prefix lazily so tool-call-only replies print nothing.
                if not printed_prefix:
                    print("Grok: ", end="", flush=True)
                    printed_prefix = True
                print(chunk.content, end="", flush=True)
            last_response = response
        if printed_prefix:
            print()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if last_response is None:
            raise RuntimeError("The model stream ended without producing a response.")
        conversation.append(last_response)
        return last_response

    while True:
        try:
            user_input = input("You: ")
        except EOFError:
            print()  # Terminate the dangling "You: " prompt line.
            break
        if user_input.strip().lower() == "exit":
            break

        conversation.append(user(user_input))
        reply = stream_reply()

        # Keep answering tool calls until the model produces a final reply.
        while reply.tool_calls:
            for tool_call in reply.tool_calls:
                result = run_tool(tool_call.function.name, tool_call.function.arguments)
                conversation.append(tool_result(result))
            reply = stream_reply()
def main(argv: Sequence[str]) -> None:
    """Entry point: reject stray arguments, then run the selected chat demo.

    Args:
        argv: Command-line arguments; only the program name (argv[0]) is
            accepted.

    Raises:
        app.UsageError: If any positional arguments beyond argv[0] are given.
    """
    if argv[1:]:
        raise app.UsageError("Unexpected command line arguments.")
    client = Client()
    run_demo = function_calling_streaming if STREAM.value else function_calling
    run_demo(client)
if __name__ == "__main__":
    # Hand control to absl, which invokes main() with the command-line arguments.
    app.run(main)