diff --git a/docs/rough_edges.md b/docs/rough_edges.md index 98758cd6..e0e6d730 100644 --- a/docs/rough_edges.md +++ b/docs/rough_edges.md @@ -45,3 +45,17 @@ v2. **Workaround**: to advertise no capabilities, set `ServerOptions.Capabilities` or `ClientOptions.Capabilities` to an empty `&ServerCapabilities{}` or `&ClientCapabilities{}` respectively. + +- `CreateMessageResult.Content` is singular `Content`, but the 2025-11-25 spec + allows `content` to be a single block or an array (for parallel tool calls). + We added `CreateMessageResultWithTools` (with `Content []Content`) as a + workaround, matching the TypeScript SDK's approach. In v2, + `CreateMessageResult` should use `[]Content` directly. Similarly, + `SamplingMessage.Content` should become `[]Content` to support sending + multiple tool_result blocks in a single user message. + +- We didn't actually need CallToolParams and CallToolParamsRaw, since even when + we're unmarshalling into a custom Go type (for the mcp.AddTool convenience + wrapper) we need to first unmarshal into a `map[string]any` in order to do + server-side validation of required fields. CallToolParams could have just had + a map[string]any. diff --git a/internal/docs/rough_edges.src.md b/internal/docs/rough_edges.src.md index 4dc32199..42e79f78 100644 --- a/internal/docs/rough_edges.src.md +++ b/internal/docs/rough_edges.src.md @@ -44,3 +44,17 @@ v2. **Workaround**: to advertise no capabilities, set `ServerOptions.Capabilities` or `ClientOptions.Capabilities` to an empty `&ServerCapabilities{}` or `&ClientCapabilities{}` respectively. + +- `CreateMessageResult.Content` is singular `Content`, but the 2025-11-25 spec + allows `content` to be a single block or an array (for parallel tool calls). + We added `CreateMessageResultWithTools` (with `Content []Content`) as a + workaround, matching the TypeScript SDK's approach. In v2, + `CreateMessageResult` should use `[]Content` directly. Similarly, + `SamplingMessage.Content` should become `[]Content` to support sending + multiple tool_result blocks in a single user message. + +- We didn't actually need CallToolParams and CallToolParamsRaw, since even when + we're unmarshalling into a custom Go type (for the mcp.AddTool convenience + wrapper) we need to first unmarshal into a `map[string]any` in order to do + server-side validation of required fields. CallToolParams could have just had + a map[string]any. diff --git a/mcp/client.go b/mcp/client.go index a1670477..25a8c310 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -52,6 +52,9 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { } options = nil // prevent reuse + if opts.CreateMessageHandler != nil && opts.CreateMessageWithToolsHandler != nil { + panic("cannot set both CreateMessageHandler and CreateMessageWithToolsHandler; use CreateMessageWithToolsHandler for tool support, or CreateMessageHandler for basic sampling") + } if opts.Logger == nil { // ensure we have a logger opts.Logger = ensureLogger(nil) } @@ -77,6 +80,19 @@ type ClientOptions struct { // non nil value for [ClientCapabilities.Sampling], that value overrides the // inferred capability. CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // CreateMessageWithToolsHandler handles incoming sampling/createMessage + // requests that may involve tool use. It returns + // [CreateMessageWithToolsResult], which supports array content for parallel + // tool calls. + // + // Setting this handler causes the client to advertise the sampling + // capability with tools support (sampling.tools). As with + // [CreateMessageHandler], [ClientOptions.Capabilities].Sampling overrides + // the inferred capability. + // + // It is a panic to set both CreateMessageHandler and + // CreateMessageWithToolsHandler. + CreateMessageWithToolsHandler func(context.Context, *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) // ElicitationHandler handles incoming requests for elicitation/create. // // Setting ElicitationHandler to a non-nil value automatically causes the @@ -109,7 +125,16 @@ type ClientOptions struct { // are set in the Capabilities field, their values override the inferred // value. // - // For example, to to configure elicitation modes: + // For example, to advertise sampling with tools and context support: + // + // Capabilities: &ClientCapabilities{ + // Sampling: &SamplingCapabilities{ + // Tools: &SamplingToolsCapabilities{}, + // Context: &SamplingContextCapabilities{}, + // }, + // } + // + // Or to configure elicitation modes: // // Capabilities: &ClientCapabilities{ // Elicitation: &ElicitationCapabilities{ @@ -119,8 +144,7 @@ type ClientOptions struct { // } // // Conversely, if Capabilities does not set a field (for example, if the - // Elicitation field is nil), the inferred elicitation capability will be - // used. + // Elicitation field is nil), the inferred capability will be used. Capabilities *ClientCapabilities // ElicitationCompleteHandler handles incoming notifications for notifications/elicitation/complete. ElicitationCompleteHandler func(context.Context, *ElicitationCompleteNotificationRequest) @@ -198,11 +222,14 @@ func (c *Client) capabilities(protocolVersion string) *ClientCapabilities { caps.Roots = *caps.RootsV2 } - // Augment with sampling capability if handler is set. - if c.opts.CreateMessageHandler != nil { + // Augment with sampling capability if a handler is set. + if c.opts.CreateMessageHandler != nil || c.opts.CreateMessageWithToolsHandler != nil { if caps.Sampling == nil { caps.Sampling = &SamplingCapabilities{} } + if c.opts.CreateMessageWithToolsHandler != nil && caps.Sampling.Tools == nil { + caps.Sampling.Tools = &SamplingToolsCapabilities{} + } } // Augment with elicitation capability if handler is set. @@ -453,12 +480,23 @@ func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRoots }, nil } -func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { - if c.opts.CreateMessageHandler == nil { - // TODO: wrap or annotate this error? Pick a standard code? - return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} +func (c *Client) createMessage(ctx context.Context, req *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { + if c.opts.CreateMessageWithToolsHandler != nil { + return c.opts.CreateMessageWithToolsHandler(ctx, req) + } + if c.opts.CreateMessageHandler != nil { + // Downconvert the request for the basic handler. + baseReq := &CreateMessageRequest{ + Session: req.Session, + Params: req.Params.toBase(), + } + res, err := c.opts.CreateMessageHandler(ctx, baseReq) + if err != nil { + return nil, err + } + return res.toWithTools(), nil } - return c.opts.CreateMessageHandler(ctx, req) + return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} } // urlElicitationMiddleware returns middleware that automatically handles URL elicitation diff --git a/mcp/content.go b/mcp/content.go index fb1a0d1e..4b911f29 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -14,7 +14,10 @@ import ( ) // A Content is a [TextContent], [ImageContent], [AudioContent], -// [ResourceLink], or [EmbeddedResource]. +// [ResourceLink], [EmbeddedResource], [ToolUseContent], or [ToolResultContent]. +// +// Note: [ToolUseContent] and [ToolResultContent] are only valid in sampling +// message contexts (CreateMessageParams/CreateMessageResult). type Content interface { MarshalJSON() ([]byte, error) fromWire(*wireContent) @@ -183,6 +186,104 @@ func (c *EmbeddedResource) fromWire(wire *wireContent) { c.Annotations = wire.Annotations } +// ToolUseContent represents a request from the assistant to invoke a tool. +// This content type is only valid in sampling messages. +type ToolUseContent struct { + // ID is a unique identifier for this tool use, used to match with ToolResultContent. + ID string + // Name is the name of the tool to invoke. + Name string + // Input contains the tool arguments as a JSON object. + Input map[string]any + Meta Meta +} + +func (c *ToolUseContent) MarshalJSON() ([]byte, error) { + input := c.Input + if input == nil { + input = map[string]any{} + } + wire := struct { + Type string `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + Input map[string]any `json:"input"` + Meta Meta `json:"_meta,omitempty"` + }{ + Type: "tool_use", + ID: c.ID, + Name: c.Name, + Input: input, + Meta: c.Meta, + } + return json.Marshal(wire) +} + +func (c *ToolUseContent) fromWire(wire *wireContent) { + c.ID = wire.ID + c.Name = wire.Name + c.Input = wire.Input + c.Meta = wire.Meta +} + +// ToolResultContent represents the result of a tool invocation. +// This content type is only valid in sampling messages with role "user". +type ToolResultContent struct { + // ToolUseID references the ID from the corresponding ToolUseContent. + ToolUseID string + // Content holds the unstructured result of the tool call. + Content []Content + // StructuredContent holds an optional structured result as a JSON object. + StructuredContent any + // IsError indicates whether the tool call ended in an error. + IsError bool + Meta Meta +} + +func (c *ToolResultContent) MarshalJSON() ([]byte, error) { + // Marshal nested content + var contentWire []*wireContent + for _, content := range c.Content { + data, err := content.MarshalJSON() + if err != nil { + return nil, err + } + var w wireContent + if err := json.Unmarshal(data, &w); err != nil { + return nil, err + } + contentWire = append(contentWire, &w) + } + if contentWire == nil { + contentWire = []*wireContent{} // avoid JSON null + } + + wire := struct { + Type string `json:"type"` + ToolUseID string `json:"toolUseId"` + Content []*wireContent `json:"content"` + StructuredContent any `json:"structuredContent,omitempty"` + IsError bool `json:"isError,omitempty"` + Meta Meta `json:"_meta,omitempty"` + }{ + Type: "tool_result", + ToolUseID: c.ToolUseID, + Content: contentWire, + StructuredContent: c.StructuredContent, + IsError: c.IsError, + Meta: c.Meta, + } + return json.Marshal(wire) +} + +func (c *ToolResultContent) fromWire(wire *wireContent) { + c.ToolUseID = wire.ToolUseID + c.StructuredContent = wire.StructuredContent + c.IsError = wire.IsError + c.Meta = wire.Meta + // Content is handled separately in contentFromWire due to nested content +} + // ResourceContents contains the contents of a specific resource or // sub-resource. type ResourceContents struct { @@ -224,10 +325,9 @@ func (r *ResourceContents) MarshalJSON() ([]byte, error) { // wireContent is the wire format for content. // It represents the protocol types TextContent, ImageContent, AudioContent, -// ResourceLink, and EmbeddedResource. +// ResourceLink, EmbeddedResource, ToolUseContent, and ToolResultContent. // The Type field distinguishes them. In the protocol, each type has a constant // value for the field. -// At most one of Text, Data, Resource, and URI is non-zero. type wireContent struct { Type string `json:"type"` Text string `json:"text,omitempty"` @@ -242,10 +342,40 @@ type wireContent struct { Meta Meta `json:"_meta,omitempty"` Annotations *Annotations `json:"annotations,omitempty"` Icons []Icon `json:"icons,omitempty"` + // Fields for ToolUseContent (type: "tool_use") + ID string `json:"id,omitempty"` + Input map[string]any `json:"input,omitempty"` + // Fields for ToolResultContent (type: "tool_result") + ToolUseID string `json:"toolUseId,omitempty"` + NestedContent []*wireContent `json:"content,omitempty"` // nested content for tool_result + StructuredContent any `json:"structuredContent,omitempty"` + IsError bool `json:"isError,omitempty"` +} + +// unmarshalContent unmarshals JSON that is either a single content object or +// an array of content objects. A single object is wrapped in a one-element slice. +func unmarshalContent(raw json.RawMessage, allow map[string]bool) ([]Content, error) { + if len(raw) == 0 || string(raw) == "null" { + return nil, fmt.Errorf("nil content") + } + // Try array first, then fall back to single object. + var wires []*wireContent + if err := json.Unmarshal(raw, &wires); err == nil { + return contentsFromWire(wires, allow) + } + var wire wireContent + if err := json.Unmarshal(raw, &wire); err != nil { + return nil, err + } + c, err := contentFromWire(&wire, allow) + if err != nil { + return nil, err + } + return []Content{c}, nil } func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) { - var blocks []Content + blocks := make([]Content, 0, len(wires)) for _, wire := range wires { block, err := contentFromWire(wire, allow) if err != nil { @@ -284,6 +414,27 @@ func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) v := new(EmbeddedResource) v.fromWire(wire) return v, nil + case "tool_use": + v := new(ToolUseContent) + v.fromWire(wire) + return v, nil + case "tool_result": + v := new(ToolResultContent) + v.fromWire(wire) + // Handle nested content - tool_result content can contain text, image, audio, + // resource_link, and resource (same as CallToolResult.content) + if wire.NestedContent != nil { + toolResultContentAllow := map[string]bool{ + "text": true, "image": true, "audio": true, + "resource_link": true, "resource": true, + } + nestedContent, err := contentsFromWire(wire.NestedContent, toolResultContentAllow) + if err != nil { + return nil, fmt.Errorf("tool_result nested content: %w", err) + } + v.Content = nestedContent + } + return v, nil } - return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type) + return nil, fmt.Errorf("unrecognized content type %q", wire.Type) } diff --git a/mcp/protocol.go b/mcp/protocol.go index bea776f9..85f468e2 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -188,7 +188,6 @@ type RootCapabilities struct { // this schema, but this is not a closed set: any client can define its own, // additional capabilities. type ClientCapabilities struct { - // NOTE: any addition to ClientCapabilities must also be reflected in // [ClientCapabilities.clone]. @@ -216,7 +215,12 @@ type ClientCapabilities struct { func (c *ClientCapabilities) clone() *ClientCapabilities { cp := *c cp.RootsV2 = shallowClone(c.RootsV2) - cp.Sampling = shallowClone(c.Sampling) + if c.Sampling != nil { + x := *c.Sampling + x.Tools = shallowClone(c.Sampling.Tools) + x.Context = shallowClone(c.Sampling.Context) + cp.Sampling = &x + } if c.Elicitation != nil { x := *c.Elicitation x.Form = shallowClone(c.Elicitation.Form) @@ -357,6 +361,11 @@ type CreateMessageParams struct { Meta `json:"_meta,omitempty"` // A request to include context from one or more MCP servers (including the // caller), to be attached to the prompt. The client may ignore this request. + // + // The default is "none". Values "thisServer" and + // "allServers" are soft-deprecated. Servers SHOULD only use these values if + // the client declares ClientCapabilities.sampling.context. These values may + // be removed in future spec releases. IncludeContext string `json:"includeContext,omitempty"` // The maximum number of tokens to sample, as requested by the server. The // client may choose to sample fewer tokens than requested. @@ -379,6 +388,105 @@ func (x *CreateMessageParams) isParams() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } +// CreateMessageWithToolsParams is a sampling request that includes tools. +// It extends the basic [CreateMessageParams] fields with tools, tool choice, +// and messages that support array content (for parallel tool calls). +// +// Use with [ServerSession.CreateMessageWithTools]. +type CreateMessageWithToolsParams struct { + Meta `json:"_meta,omitempty"` + IncludeContext string `json:"includeContext,omitempty"` + MaxTokens int64 `json:"maxTokens"` + // Messages supports array content for tool_use and tool_result blocks. + Messages []*SamplingMessageV2 `json:"messages"` + Metadata any `json:"metadata,omitempty"` + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + // Tools is the list of tools available for the model to use. + Tools []*Tool `json:"tools,omitempty"` + // ToolChoice controls how the model should use tools. + ToolChoice *ToolChoice `json:"toolChoice,omitempty"` +} + +func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// toBase converts to CreateMessageParams by taking the first content block +// from each message. Tools, ToolChoice, and any additional content blocks +// (e.g. parallel tool calls) are dropped. The first block may be a +// ToolUseContent or ToolResultContent, which the basic handler should +// tolerate since SamplingMessage accepts tool content types. +func (p *CreateMessageWithToolsParams) toBase() *CreateMessageParams { + var msgs []*SamplingMessage + for _, m := range p.Messages { + var content Content + if len(m.Content) > 0 { + content = m.Content[0] + } + msgs = append(msgs, &SamplingMessage{Content: content, Role: m.Role}) + } + return &CreateMessageParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + } +} + +// SamplingMessageV2 describes a message issued to or received from an +// LLM API, supporting array content for parallel tool calls. The "V2" refers +// to the 2025-11-25 spec, which changed content from a single block to +// single-or-array. In v2 of the SDK, this will replace [SamplingMessage]. +// +// When marshaling, a single-element Content slice is marshaled as a single +// object for compatibility with pre-2025-11-25 implementations. When +// unmarshaling, a single JSON content object is accepted and wrapped in a +// one-element slice. +type SamplingMessageV2 struct { + Content []Content `json:"content"` + Role Role `json:"role"` +} + +var samplingWithToolsAllow = map[string]bool{ + "text": true, "image": true, "audio": true, + "tool_use": true, "tool_result": true, +} + +// MarshalJSON marshals the message. A single-element Content slice is marshaled +// as a single object for backward compatibility. +func (m *SamplingMessageV2) MarshalJSON() ([]byte, error) { + if len(m.Content) == 1 { + return json.Marshal(&SamplingMessage{Content: m.Content[0], Role: m.Role}) + } + type msg SamplingMessageV2 // avoid recursion + return json.Marshal((*msg)(m)) +} + +func (m *SamplingMessageV2) UnmarshalJSON(data []byte) error { + type msg SamplingMessageV2 // avoid recursion + var wire struct { + msg + Content json.RawMessage `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.msg.Content, err = unmarshalContent(wire.Content, samplingWithToolsAllow); err != nil { + return err + } + *m = SamplingMessageV2(wire.msg) + return nil +} + // The client's response to a sampling/create_message request from the server. // The client should inform the user before returning the sampled message, to // allow them to inspect the response (human in the loop) and decide whether to @@ -392,6 +500,12 @@ type CreateMessageResult struct { Model string `json:"model"` Role Role `json:"role"` // The reason why sampling stopped, if known. + // + // Standard values: + // - "endTurn": natural end of the assistant's turn + // - "stopSequence": a stop sequence was encountered + // - "maxTokens": reached the maximum token limit + // - "toolUse": the model wants to use one or more tools StopReason string `json:"stopReason,omitempty"` } @@ -413,6 +527,84 @@ func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { return nil } +// CreateMessageWithToolsResult is the client's response to a +// sampling/create_message request that included tools. Content is a slice to +// support parallel tool calls (multiple tool_use blocks in one response). +// +// Use [ServerSession.CreateMessageWithTools] to send a sampling request with +// tools and receive this result type. +// +// When unmarshaling, a single JSON content object is accepted and wrapped in a +// one-element slice, for compatibility with clients that return a single block. +type CreateMessageWithToolsResult struct { + Meta `json:"_meta,omitempty"` + Content []Content `json:"content"` + Model string `json:"model"` + Role Role `json:"role"` + // The reason why sampling stopped. + // + // Standard values: "endTurn", "stopSequence", "maxTokens", "toolUse". + StopReason string `json:"stopReason,omitempty"` +} + +// createMessageWithToolsResultAllow lists content types valid in assistant responses. +// tool_result is excluded: it only appears in user messages. +var createMessageWithToolsResultAllow = map[string]bool{ + "text": true, "image": true, "audio": true, + "tool_use": true, +} + +func (*CreateMessageWithToolsResult) isResult() {} + +// MarshalJSON marshals the result. When Content has a single element, it is +// marshaled as a single object for compatibility with pre-2025-11-25 +// implementations that expect a single content block. +func (r *CreateMessageWithToolsResult) MarshalJSON() ([]byte, error) { + if len(r.Content) == 1 { + return json.Marshal(&CreateMessageResult{ + Meta: r.Meta, + Content: r.Content[0], + Model: r.Model, + Role: r.Role, + StopReason: r.StopReason, + }) + } + type result CreateMessageWithToolsResult // avoid recursion + return json.Marshal((*result)(r)) +} + +func (r *CreateMessageWithToolsResult) UnmarshalJSON(data []byte) error { + type result CreateMessageWithToolsResult // avoid recursion + var wire struct { + result + Content json.RawMessage `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = unmarshalContent(wire.Content, createMessageWithToolsResultAllow); err != nil { + return err + } + *r = CreateMessageWithToolsResult(wire.result) + return nil +} + +// toWithTools converts a CreateMessageResult to CreateMessageWithToolsResult. +func (r *CreateMessageResult) toWithTools() *CreateMessageWithToolsResult { + var content []Content + if r.Content != nil { + content = []Content{r.Content} + } + return &CreateMessageWithToolsResult{ + Meta: r.Meta, + Content: content, + Model: r.Model, + Role: r.Role, + StopReason: r.StopReason, + } +} + type GetPromptParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -982,7 +1174,27 @@ func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t // below directly above ClientCapabilities. // SamplingCapabilities describes the client's support for sampling. -type SamplingCapabilities struct{} +type SamplingCapabilities struct { + // Context indicates the client supports includeContext values other than "none". + Context *SamplingContextCapabilities `json:"context,omitempty"` + // Tools indicates the client supports tools and toolChoice in sampling requests. + Tools *SamplingToolsCapabilities `json:"tools,omitempty"` +} + +// SamplingContextCapabilities indicates the client supports context inclusion. +type SamplingContextCapabilities struct{} + +// SamplingToolsCapabilities indicates the client supports tool use in sampling. +type SamplingToolsCapabilities struct{} + +// ToolChoice controls how the model uses tools during sampling. +type ToolChoice struct { + // Mode controls tool invocation behavior: + // - "auto": Model decides whether to use tools (default) + // - "required": Model must use at least one tool + // - "none": Model must not use any tools + Mode string `json:"mode,omitempty"` +} // ElicitationCapabilities describes the capabilities for elicitation. // @@ -993,14 +1205,15 @@ type ElicitationCapabilities struct { } // FormElicitationCapabilities describes capabilities for form elicitation. -type FormElicitationCapabilities struct { -} +type FormElicitationCapabilities struct{} // URLElicitationCapabilities describes capabilities for url elicitation. -type URLElicitationCapabilities struct { -} +type URLElicitationCapabilities struct{} // Describes a message issued to or received from an LLM API. +// +// For assistant messages, Content may be text, image, audio, or tool_use. +// For user messages, Content may be text, image, audio, or tool_result. type SamplingMessage struct { Content Content `json:"content"` Role Role `json:"role"` @@ -1017,8 +1230,9 @@ func (m *SamplingMessage) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &wire); err != nil { return err } + // Allow text, image, audio, tool_use, and tool_result in sampling messages var err error - if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true, "tool_result": true}); err != nil { return err } *m = SamplingMessage(wire.msg) @@ -1297,7 +1511,6 @@ type ToolCapabilities struct { // ServerCapabilities describes capabilities that a server supports. type ServerCapabilities struct { - // NOTE: any addition to ServerCapabilities must also be reflected in // [ServerCapabilities.clone]. diff --git a/mcp/requests.go b/mcp/requests.go index f64d6fb6..42809413 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -24,6 +24,7 @@ type ( type ( CreateMessageRequest = ClientRequest[*CreateMessageParams] + CreateMessageWithToolsRequest = ClientRequest[*CreateMessageWithToolsParams] ElicitRequest = ClientRequest[*ElicitParams] initializedClientRequest = ClientRequest[*InitializedParams] InitializeRequest = ClientRequest[*InitializeParams] diff --git a/mcp/sampling_tools_test.go b/mcp/sampling_tools_test.go new file mode 100644 index 00000000..1afeae25 --- /dev/null +++ b/mcp/sampling_tools_test.go @@ -0,0 +1,1253 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "reflect" + "strings" + "testing" +) + +func TestToolUseContent_MarshalJSON(t *testing.T) { + tests := []struct { + name string + content *ToolUseContent + want map[string]any + }{ + { + name: "basic tool use", + content: &ToolUseContent{ + ID: "tool_123", + Name: "calculator", + Input: map[string]any{ + "operation": "add", + "x": 1.0, + "y": 2.0, + }, + }, + want: map[string]any{ + "type": "tool_use", + "id": "tool_123", + "name": "calculator", + "input": map[string]any{ + "operation": "add", + "x": 1.0, + "y": 2.0, + }, + }, + }, + { + name: "tool use with nil input", + content: &ToolUseContent{ + ID: "tool_456", + Name: "no_args_tool", + Input: nil, + }, + want: map[string]any{ + "type": "tool_use", + "id": "tool_456", + "name": "no_args_tool", + "input": map[string]any{}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := tt.content.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToolResultContent_MarshalJSON(t *testing.T) { + tests := []struct { + name string + content *ToolResultContent + want map[string]any + }{ + { + name: "basic tool result", + content: &ToolResultContent{ + ToolUseID: "tool_123", + Content: []Content{&TextContent{Text: "42"}}, + }, + want: map[string]any{ + "type": "tool_result", + "toolUseId": "tool_123", + "content": []any{ + map[string]any{ + "type": "text", + "text": "42", + }, + }, + }, + }, + { + name: "tool result with error", + content: &ToolResultContent{ + ToolUseID: "tool_456", + Content: []Content{&TextContent{Text: "division by zero"}}, + IsError: true, + }, + want: map[string]any{ + "type": "tool_result", + "toolUseId": "tool_456", + "content": []any{ + map[string]any{ + "type": "text", + "text": "division by zero", + }, + }, + "isError": true, + }, + }, + { + name: "tool result with structured content", + content: &ToolResultContent{ + ToolUseID: "tool_789", + Content: []Content{&TextContent{Text: `{"result": 42}`}}, + StructuredContent: map[string]any{"result": 42.0}, + }, + want: map[string]any{ + "type": "tool_result", + "toolUseId": "tool_789", + "structuredContent": map[string]any{"result": 42.0}, + "content": []any{ + map[string]any{ + "type": "text", + "text": `{"result": 42}`, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := tt.content.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToolUseContent_UnmarshalJSON(t *testing.T) { + jsonData := `{ + "type": "tool_use", + "id": "tool_123", + "name": "calculator", + "input": {"x": 1, "y": 2} + }` + + wire := &wireContent{} + if err := json.Unmarshal([]byte(jsonData), wire); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + content, err := contentFromWire(wire, map[string]bool{"tool_use": true}) + if err != nil { + t.Fatalf("contentFromWire() error = %v", err) + } + + toolUse, ok := content.(*ToolUseContent) + if !ok { + t.Fatalf("expected *ToolUseContent, got %T", content) + } + + if toolUse.ID != "tool_123" { + t.Errorf("ID = %v, want %v", toolUse.ID, "tool_123") + } + if toolUse.Name != "calculator" { + t.Errorf("Name = %v, want %v", toolUse.Name, "calculator") + } + if toolUse.Input["x"] != 1.0 || toolUse.Input["y"] != 2.0 { + t.Errorf("Input = %v, want map with x=1, y=2", toolUse.Input) + } +} + +func TestToolResultContent_UnmarshalJSON(t *testing.T) { + jsonData := `{ + "type": "tool_result", + "toolUseId": "tool_123", + "content": [{"type": "text", "text": "42"}], + "isError": false + }` + + wire := &wireContent{} + if err := json.Unmarshal([]byte(jsonData), wire); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + content, err := contentFromWire(wire, map[string]bool{"tool_result": true}) + if err != nil { + t.Fatalf("contentFromWire() error = %v", err) + } + + toolResult, ok := content.(*ToolResultContent) + if !ok { + t.Fatalf("expected *ToolResultContent, got %T", content) + } + + if toolResult.ToolUseID != "tool_123" { + t.Errorf("ToolUseID = %v, want %v", toolResult.ToolUseID, "tool_123") + } + if toolResult.IsError { + t.Errorf("IsError = %v, want false", toolResult.IsError) + } + if len(toolResult.Content) != 1 { + t.Fatalf("len(Content) = %v, want 1", len(toolResult.Content)) + } + textContent, ok := toolResult.Content[0].(*TextContent) + if !ok { + t.Fatalf("expected *TextContent, got %T", toolResult.Content[0]) + } + if textContent.Text != "42" { + t.Errorf("Text = %v, want %v", textContent.Text, "42") + } +} + +func TestCreateMessageWithToolsResult_ToolUseContent(t *testing.T) { + // Test that CreateMessageWithToolsResult can unmarshal tool_use content + jsonData := `{ + "content": {"type": "tool_use", "id": "tool_1", "name": "calculator", "input": {"x": 1}}, + "model": "test-model", + "role": "assistant", + "stopReason": "toolUse" + }` + + var result CreateMessageWithToolsResult + if err := json.Unmarshal([]byte(jsonData), &result); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if result.Model != "test-model" { + t.Errorf("Model = %v, want %v", result.Model, "test-model") + } + if result.StopReason != "toolUse" { + t.Errorf("StopReason = %v, want %v", result.StopReason, "toolUse") + } + + if len(result.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(result.Content)) + } + toolUse, ok := result.Content[0].(*ToolUseContent) + if !ok { + t.Fatalf("Content[0] expected *ToolUseContent, got %T", result.Content[0]) + } + if toolUse.ID != "tool_1" { + t.Errorf("Content.ID = %v, want %v", toolUse.ID, "tool_1") + } + if toolUse.Name != "calculator" { + t.Errorf("Content.Name = %v, want %v", toolUse.Name, "calculator") + } +} + +func TestSamplingMessage_ToolUseContent(t *testing.T) { + // Test that SamplingMessage can unmarshal tool_use content (assistant role) + jsonData := `{ + "content": {"type": "tool_use", "id": "tool_1", "name": "calc", "input": {}}, + "role": "assistant" + }` + + var msg SamplingMessage + if err := json.Unmarshal([]byte(jsonData), &msg); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if msg.Role != "assistant" { + t.Errorf("Role = %v, want %v", msg.Role, "assistant") + } + + toolUse, ok := msg.Content.(*ToolUseContent) + if !ok { + t.Fatalf("Content expected *ToolUseContent, got %T", msg.Content) + } + if toolUse.ID != "tool_1" { + t.Errorf("Content.ID = %v, want %v", toolUse.ID, "tool_1") + } +} + +func TestSamplingMessage_ToolResultContent(t *testing.T) { + // Test that SamplingMessage can unmarshal tool_result content (user role) + jsonData := `{ + "content": {"type": "tool_result", "toolUseId": "tool_1", "content": [{"type": "text", "text": "42"}]}, + "role": "user" + }` + + var msg SamplingMessage + if err := json.Unmarshal([]byte(jsonData), &msg); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if msg.Role != "user" { + t.Errorf("Role = %v, want %v", msg.Role, "user") + } + + toolResult, ok := msg.Content.(*ToolResultContent) + if !ok { + t.Fatalf("Content expected *ToolResultContent, got %T", msg.Content) + } + if toolResult.ToolUseID != "tool_1" { + t.Errorf("Content.ToolUseID = %v, want %v", toolResult.ToolUseID, "tool_1") + } + if len(toolResult.Content) != 1 { + t.Fatalf("len(Content.Content) = %v, want 1", len(toolResult.Content)) + } +} + +func TestSamplingCapabilities_WithTools(t *testing.T) { + caps := &SamplingCapabilities{ + Tools: &SamplingToolsCapabilities{}, + Context: &SamplingContextCapabilities{}, + } + + data, err := json.Marshal(caps) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var caps2 SamplingCapabilities + if err := json.Unmarshal(data, &caps2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if caps2.Tools == nil { + t.Error("Tools capability should not be nil") + } + if caps2.Context == nil { + t.Error("Context capability should not be nil") + } +} + +func TestSamplingCapabilities_Empty(t *testing.T) { + // Test backward compatibility - empty struct should marshal/unmarshal correctly + caps := &SamplingCapabilities{} + + data, err := json.Marshal(caps) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var caps2 SamplingCapabilities + if err := json.Unmarshal(data, &caps2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if caps2.Tools != nil { + t.Error("Tools capability should be nil for empty capabilities") + } + if caps2.Context != nil { + t.Error("Context capability should be nil for empty capabilities") + } +} + +func TestCreateMessageWithToolsParams(t *testing.T) { + params := &CreateMessageWithToolsParams{ + MaxTokens: 1000, + Messages: []*SamplingMessageV2{ + { + Role: "user", + Content: []Content{&TextContent{Text: "Calculate 1+1"}}, + }, + }, + Tools: []*Tool{ + { + Name: "calculator", + Description: "A calculator tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "x": map[string]any{"type": "number"}, + "y": map[string]any{"type": "number"}, + }, + }, + }, + }, + ToolChoice: &ToolChoice{Mode: "auto"}, + } + + data, err := json.Marshal(params) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var params2 CreateMessageWithToolsParams + if err := json.Unmarshal(data, ¶ms2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if len(params2.Tools) != 1 { + t.Fatalf("len(Tools) = %v, want 1", len(params2.Tools)) + } + if params2.Tools[0].Name != "calculator" { + t.Errorf("Tools[0].Name = %v, want %v", params2.Tools[0].Name, "calculator") + } + if params2.ToolChoice == nil || params2.ToolChoice.Mode != "auto" { + t.Errorf("ToolChoice.Mode = %v, want %v", params2.ToolChoice, &ToolChoice{Mode: "auto"}) + } +} + +func TestToolChoice_Modes(t *testing.T) { + tests := []struct { + name string + mode string + }{ + {"auto", "auto"}, + {"required", "required"}, + {"none", "none"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &ToolChoice{Mode: tt.mode} + data, err := json.Marshal(tc) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var tc2 ToolChoice + if err := json.Unmarshal(data, &tc2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if tc2.Mode != tt.mode { + t.Errorf("Mode = %v, want %v", tc2.Mode, tt.mode) + } + }) + } +} + +// Integration tests + +func TestSamplingWithTools_Integration(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Track what the client received + var receivedParams *CreateMessageWithToolsParams + + // Client with tools capability, using CreateMessageWithToolsHandler + client := NewClient(testImpl, &ClientOptions{ + CreateMessageWithToolsHandler: func(_ context.Context, req *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { + receivedParams = req.Params + // Return a tool use response + return &CreateMessageWithToolsResult{ + Model: "test-model", + Role: "assistant", + Content: []Content{&ToolUseContent{ + ID: "tool_call_1", + Name: "calculator", + Input: map[string]any{"x": 1.0, "y": 2.0}, + }}, + StopReason: "toolUse", + }, nil + }, + Capabilities: &ClientCapabilities{ + Sampling: &SamplingCapabilities{Tools: &SamplingToolsCapabilities{}}, + }, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Server sends CreateMessageWithTools + result, err := ss.CreateMessageWithTools(ctx, &CreateMessageWithToolsParams{ + MaxTokens: 1000, + Messages: []*SamplingMessageV2{ + {Role: "user", Content: []Content{&TextContent{Text: "Calculate 1+2"}}}, + }, + Tools: []*Tool{ + { + Name: "calculator", + Description: "A calculator", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "x": map[string]any{"type": "number"}, + "y": map[string]any{"type": "number"}, + }, + }, + }, + }, + ToolChoice: &ToolChoice{Mode: "auto"}, + }) + if err != nil { + t.Fatalf("CreateMessageWithTools() error = %v", err) + } + + // Verify client received the tools + if receivedParams == nil { + t.Fatal("client did not receive params") + } + if len(receivedParams.Tools) != 1 { + t.Errorf("client received %d tools, want 1", len(receivedParams.Tools)) + } + if receivedParams.Tools[0].Name != "calculator" { + t.Errorf("tool name = %v, want calculator", receivedParams.Tools[0].Name) + } + if receivedParams.ToolChoice == nil || receivedParams.ToolChoice.Mode != "auto" { + t.Errorf("tool choice mode = %v, want auto", receivedParams.ToolChoice) + } + + // Verify server received the tool use response + if result.StopReason != "toolUse" { + t.Errorf("StopReason = %v, want toolUse", result.StopReason) + } + if len(result.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(result.Content)) + } + toolUse, ok := result.Content[0].(*ToolUseContent) + if !ok { + t.Fatalf("Content[0] type = %T, want *ToolUseContent", result.Content[0]) + } + if toolUse.ID != "tool_call_1" { + t.Errorf("ToolUse.ID = %v, want tool_call_1", toolUse.ID) + } + if toolUse.Name != "calculator" { + t.Errorf("ToolUse.Name = %v, want calculator", toolUse.Name) + } +} + +func TestSamplingWithToolResult_Integration(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Track messages received by client + var receivedMessages []*SamplingMessage + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + receivedMessages = req.Params.Messages + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "The result is 3"}, + }, nil + }, + Capabilities: &ClientCapabilities{ + Sampling: &SamplingCapabilities{Tools: &SamplingToolsCapabilities{}}, + }, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Server sends CreateMessage with tool result in messages + _, err = ss.CreateMessage(ctx, &CreateMessageParams{ + MaxTokens: 1000, + Messages: []*SamplingMessage{ + {Role: "user", Content: &TextContent{Text: "Calculate 1+2"}}, + {Role: "assistant", Content: &ToolUseContent{ + ID: "tool_1", + Name: "calculator", + Input: map[string]any{"x": 1.0, "y": 2.0}, + }}, + {Role: "user", Content: &ToolResultContent{ + ToolUseID: "tool_1", + Content: []Content{&TextContent{Text: "3"}}, + }}, + }, + }) + if err != nil { + t.Fatalf("CreateMessage() error = %v", err) + } + + // Verify client received all messages including tool content + if len(receivedMessages) != 3 { + t.Fatalf("received %d messages, want 3", len(receivedMessages)) + } + + // Check first message is text + if _, ok := receivedMessages[0].Content.(*TextContent); !ok { + t.Errorf("message[0] content type = %T, want *TextContent", receivedMessages[0].Content) + } + + // Check second message is tool use + toolUse, ok := receivedMessages[1].Content.(*ToolUseContent) + if !ok { + t.Fatalf("message[1] content type = %T, want *ToolUseContent", receivedMessages[1].Content) + } + if toolUse.ID != "tool_1" { + t.Errorf("toolUse.ID = %v, want tool_1", toolUse.ID) + } + + // Check third message is tool result + toolResult, ok := receivedMessages[2].Content.(*ToolResultContent) + if !ok { + t.Fatalf("message[2] content type = %T, want *ToolResultContent", receivedMessages[2].Content) + } + if toolResult.ToolUseID != "tool_1" { + t.Errorf("toolResult.ToolUseID = %v, want tool_1", toolResult.ToolUseID) + } + if len(toolResult.Content) != 1 { + t.Fatalf("toolResult.Content len = %d, want 1", len(toolResult.Content)) + } + if tc, ok := toolResult.Content[0].(*TextContent); !ok || tc.Text != "3" { + t.Errorf("toolResult.Content[0] = %v, want TextContent with '3'", toolResult.Content[0]) + } +} + +func TestSamplingToolsCapability_Integration(t *testing.T) { + ctx := context.Background() + + t.Run("client advertises tools capability", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, _ *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "m", Content: &TextContent{}}, nil + }, + Capabilities: &ClientCapabilities{ + Sampling: &SamplingCapabilities{ + Tools: &SamplingToolsCapabilities{}, + Context: &SamplingContextCapabilities{}, + }, + }, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Check server sees client capabilities + caps := ss.InitializeParams().Capabilities + if caps.Sampling == nil { + t.Fatal("client should advertise sampling capability") + } + if caps.Sampling.Tools == nil { + t.Error("client should advertise sampling.tools capability") + } + if caps.Sampling.Context == nil { + t.Error("client should advertise sampling.context capability") + } + }) + + t.Run("client without tools capability", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, _ *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "m", Content: &TextContent{}}, nil + }, + // No Capabilities.Sampling.Tools set + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Check server sees client capabilities + caps := ss.InitializeParams().Capabilities + if caps.Sampling == nil { + t.Fatal("client should advertise sampling capability") + } + if caps.Sampling.Tools != nil { + t.Error("client should NOT advertise sampling.tools capability") + } + if caps.Sampling.Context != nil { + t.Error("client should NOT advertise sampling.context capability") + } + }) + + t.Run("CreateMessageWithToolsHandler infers tools capability", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageWithToolsHandler: func(_ context.Context, _ *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { + return &CreateMessageWithToolsResult{Model: "m", Content: []Content{&TextContent{}}}, nil + }, + // No explicit Capabilities set — tools should be inferred. + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + caps := ss.InitializeParams().Capabilities + if caps.Sampling == nil { + t.Fatal("client should advertise sampling capability") + } + if caps.Sampling.Tools == nil { + t.Error("client should infer sampling.tools capability from CreateMessageWithToolsHandler") + } + }) +} + +func TestSamplingToolResultWithError_Integration(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + var receivedMessages []*SamplingMessage + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + receivedMessages = req.Params.Messages + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "I see the tool failed"}, + }, nil + }, + Capabilities: &ClientCapabilities{ + Sampling: &SamplingCapabilities{Tools: &SamplingToolsCapabilities{}}, + }, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Server sends CreateMessage with error tool result + _, err = ss.CreateMessage(ctx, &CreateMessageParams{ + MaxTokens: 1000, + Messages: []*SamplingMessage{ + {Role: "user", Content: &ToolResultContent{ + ToolUseID: "tool_1", + Content: []Content{&TextContent{Text: "division by zero"}}, + IsError: true, + }}, + }, + }) + if err != nil { + t.Fatalf("CreateMessage() error = %v", err) + } + + if len(receivedMessages) != 1 { + t.Fatalf("received %d messages, want 1", len(receivedMessages)) + } + + toolResult, ok := receivedMessages[0].Content.(*ToolResultContent) + if !ok { + t.Fatalf("content type = %T, want *ToolResultContent", receivedMessages[0].Content) + } + if !toolResult.IsError { + t.Error("IsError should be true") + } + if toolResult.ToolUseID != "tool_1" { + t.Errorf("ToolUseID = %v, want tool_1", toolResult.ToolUseID) + } +} + +func TestToolResultContent_ImageNestedContent(t *testing.T) { + // Verify non-text nested content in ToolResultContent works. + jsonData := `{ + "type": "tool_result", + "toolUseId": "t1", + "content": [ + {"type": "image", "mimeType": "image/png", "data": "YWJj"} + ] + }` + + wire := &wireContent{} + if err := json.Unmarshal([]byte(jsonData), wire); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + content, err := contentFromWire(wire, map[string]bool{"tool_result": true}) + if err != nil { + t.Fatalf("contentFromWire() error = %v", err) + } + + toolResult, ok := content.(*ToolResultContent) + if !ok { + t.Fatalf("expected *ToolResultContent, got %T", content) + } + if len(toolResult.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(toolResult.Content)) + } + img, ok := toolResult.Content[0].(*ImageContent) + if !ok { + t.Fatalf("nested content type = %T, want *ImageContent", toolResult.Content[0]) + } + if img.MIMEType != "image/png" { + t.Errorf("MIMEType = %v, want image/png", img.MIMEType) + } +} + +func TestToolUseContent_MetaRoundTrip(t *testing.T) { + // Verify Meta round-trips through marshal/unmarshal. + orig := &ToolUseContent{ + ID: "t1", + Name: "calc", + Input: map[string]any{"x": 1.0}, + Meta: Meta{"requestId": "req-123"}, + } + + data, err := orig.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + wire := &wireContent{} + if err := json.Unmarshal(data, wire); err != nil { + t.Fatalf("Unmarshal wire error = %v", err) + } + + content, err := contentFromWire(wire, map[string]bool{"tool_use": true}) + if err != nil { + t.Fatalf("contentFromWire() error = %v", err) + } + + got, ok := content.(*ToolUseContent) + if !ok { + t.Fatalf("type = %T, want *ToolUseContent", content) + } + if got.Meta["requestId"] != "req-123" { + t.Errorf("Meta[requestId] = %v, want req-123", got.Meta["requestId"]) + } +} + +func TestParallelToolCalls_Integration(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Client returns parallel tool use results + client := NewClient(testImpl, &ClientOptions{ + CreateMessageWithToolsHandler: func(_ context.Context, req *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { + return &CreateMessageWithToolsResult{ + Model: "test-model", + Role: "assistant", + Content: []Content{ + &ToolUseContent{ID: "call_1", Name: "weather", Input: map[string]any{"city": "SF"}}, + &ToolUseContent{ID: "call_2", Name: "weather", Input: map[string]any{"city": "NY"}}, + }, + StopReason: "toolUse", + }, nil + }, + Capabilities: &ClientCapabilities{ + Sampling: &SamplingCapabilities{Tools: &SamplingToolsCapabilities{}}, + }, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + result, err := ss.CreateMessageWithTools(ctx, &CreateMessageWithToolsParams{ + MaxTokens: 1000, + Messages: []*SamplingMessageV2{ + {Role: "user", Content: []Content{&TextContent{Text: "Weather in SF and NY"}}}, + }, + Tools: []*Tool{ + {Name: "weather", InputSchema: map[string]any{"type": "object"}}, + }, + }) + if err != nil { + t.Fatalf("CreateMessageWithTools() error = %v", err) + } + + if len(result.Content) != 2 { + t.Fatalf("len(Content) = %d, want 2", len(result.Content)) + } + for i, c := range result.Content { + tu, ok := c.(*ToolUseContent) + if !ok { + t.Fatalf("Content[%d] type = %T, want *ToolUseContent", i, c) + } + if tu.Name != "weather" { + t.Errorf("Content[%d].Name = %v, want weather", i, tu.Name) + } + } + if result.Content[0].(*ToolUseContent).ID != "call_1" { + t.Errorf("Content[0].ID = %v, want call_1", result.Content[0].(*ToolUseContent).ID) + } + if result.Content[1].(*ToolUseContent).ID != "call_2" { + t.Errorf("Content[1].ID = %v, want call_2", result.Content[1].(*ToolUseContent).ID) + } +} + +func TestCreateMessageWithToolsResult_ArrayRoundTrip(t *testing.T) { + // Marshal multi-content, unmarshal, verify. + orig := &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + Content: []Content{ + &ToolUseContent{ID: "t1", Name: "calc", Input: map[string]any{"x": 1.0}}, + &ToolUseContent{ID: "t2", Name: "search", Input: map[string]any{"q": "hi"}}, + }, + StopReason: "toolUse", + } + + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var got CreateMessageWithToolsResult + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if len(got.Content) != 2 { + t.Fatalf("len(Content) = %d, want 2", len(got.Content)) + } + for i, c := range got.Content { + tu, ok := c.(*ToolUseContent) + if !ok { + t.Fatalf("Content[%d] type = %T, want *ToolUseContent", i, c) + } + origTU := orig.Content[i].(*ToolUseContent) + if tu.ID != origTU.ID || tu.Name != origTU.Name { + t.Errorf("Content[%d] = %+v, want %+v", i, tu, origTU) + } + } +} + +func TestCreateMessageWithToolsResult_SingleContentBackwardCompat(t *testing.T) { + // Single-element Content marshals as object (not array). + result := &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + Content: []Content{&TextContent{Text: "hello"}}, + } + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("Unmarshal raw error = %v", err) + } + + content := raw["content"] + for _, b := range content { + if b == ' ' || b == '\t' || b == '\n' || b == '\r' { + continue + } + if b == '[' { + t.Errorf("single-element Content marshaled as array, want object") + } + break + } +} + +func TestNewClient_BothHandlersPanics(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic when both handlers set") + } + msg, ok := r.(string) + if !ok || !strings.Contains(msg, "CreateMessageHandler") { + t.Errorf("unexpected panic: %v", r) + } + }() + NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return nil, nil + }, + CreateMessageWithToolsHandler: func(context.Context, *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { + return nil, nil + }, + }) +} + +func TestCreateMessage_MultipleContentError(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Client returns multiple content blocks via CreateMessageWithToolsHandler + client := NewClient(testImpl, &ClientOptions{ + CreateMessageWithToolsHandler: func(_ context.Context, _ *CreateMessageWithToolsRequest) (*CreateMessageWithToolsResult, error) { + return &CreateMessageWithToolsResult{ + Model: "test", + Role: "assistant", + Content: []Content{ + &TextContent{Text: "a"}, + &TextContent{Text: "b"}, + }, + }, nil + }, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Server calls CreateMessage (singular), should get error + _, err = ss.CreateMessage(ctx, &CreateMessageParams{ + MaxTokens: 100, + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "hi"}}}, + }) + if err == nil { + t.Fatal("expected error for multiple content blocks") + } + if !strings.Contains(err.Error(), "CreateMessageWithTools") { + t.Errorf("error should mention CreateMessageWithTools, got: %v", err) + } +} + +func TestUnmarshalContent_NullJSON(t *testing.T) { + // JSON null should be rejected. + jsonData := `{"content": null, "model": "m", "role": "assistant"}` + var result CreateMessageWithToolsResult + if err := json.Unmarshal([]byte(jsonData), &result); err == nil { + t.Error("expected error for null content") + } +} + +func TestUnmarshalContent_EmptyArray(t *testing.T) { + // Empty array should produce empty (non-nil) slice. + jsonData := `{"content": [], "model": "m", "role": "assistant"}` + var result CreateMessageWithToolsResult + if err := json.Unmarshal([]byte(jsonData), &result); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if result.Content == nil { + t.Error("Content should be non-nil empty slice, got nil") + } + if len(result.Content) != 0 { + t.Errorf("len(Content) = %d, want 0", len(result.Content)) + } +} + +func TestSamplingMessageV2_EmptyContent(t *testing.T) { + msg := &SamplingMessageV2{ + Role: "user", + Content: []Content{}, + } + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + var got SamplingMessageV2 + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(got.Content) != 0 { + t.Errorf("len(Content) = %d, want 0", len(got.Content)) + } +} + +func TestSamplingMessageV2_MixedContent(t *testing.T) { + // Text + tool_use in the same message (valid per spec for assistant). + msg := &SamplingMessageV2{ + Role: "assistant", + Content: []Content{ + &TextContent{Text: "Let me check the weather."}, + &ToolUseContent{ID: "c1", Name: "weather", Input: map[string]any{"city": "SF"}}, + }, + } + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + var got SamplingMessageV2 + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(got.Content) != 2 { + t.Fatalf("len(Content) = %d, want 2", len(got.Content)) + } + if _, ok := got.Content[0].(*TextContent); !ok { + t.Errorf("Content[0] type = %T, want *TextContent", got.Content[0]) + } + if _, ok := got.Content[1].(*ToolUseContent); !ok { + t.Errorf("Content[1] type = %T, want *ToolUseContent", got.Content[1]) + } +} + +func TestCreateMessageWithToolsResult_RejectsToolResult(t *testing.T) { + // tool_result should not be valid in a result (assistant role). + jsonData := `{ + "content": {"type": "tool_result", "toolUseId": "t1", "content": []}, + "model": "m", + "role": "assistant" + }` + var result CreateMessageWithToolsResult + if err := json.Unmarshal([]byte(jsonData), &result); err == nil { + t.Error("expected error for tool_result in CreateMessageWithToolsResult") + } +} + +func TestToBase_Conversion(t *testing.T) { + params := &CreateMessageWithToolsParams{ + MaxTokens: 1000, + Messages: []*SamplingMessageV2{ + {Role: "user", Content: []Content{&TextContent{Text: "hello"}}}, + {Role: "assistant", Content: []Content{ + &ToolUseContent{ID: "c1", Name: "calc", Input: map[string]any{}}, + &ToolUseContent{ID: "c2", Name: "search", Input: map[string]any{}}, + }}, + }, + Tools: []*Tool{{Name: "calc"}}, + ToolChoice: &ToolChoice{Mode: "auto"}, + } + base := params.toBase() + + // Tools and ToolChoice should be gone + if base.MaxTokens != 1000 { + t.Errorf("MaxTokens = %d, want 1000", base.MaxTokens) + } + if len(base.Messages) != 2 { + t.Fatalf("len(Messages) = %d, want 2", len(base.Messages)) + } + // First message: single content preserved + if tc, ok := base.Messages[0].Content.(*TextContent); !ok || tc.Text != "hello" { + t.Errorf("Messages[0].Content = %v, want TextContent{hello}", base.Messages[0].Content) + } + // Second message: only first content block kept + if tu, ok := base.Messages[1].Content.(*ToolUseContent); !ok || tu.ID != "c1" { + t.Errorf("Messages[1].Content = %v, want ToolUseContent{c1}", base.Messages[1].Content) + } +} + +func TestToWithTools_Conversion(t *testing.T) { + result := &CreateMessageResult{ + Model: "test", + Role: "assistant", + Content: &TextContent{Text: "hello"}, + StopReason: "endTurn", + } + wt := result.toWithTools() + if wt.Model != "test" { + t.Errorf("Model = %v, want test", wt.Model) + } + if len(wt.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(wt.Content)) + } + if tc, ok := wt.Content[0].(*TextContent); !ok || tc.Text != "hello" { + t.Errorf("Content[0] = %v, want TextContent{hello}", wt.Content[0]) + } +} + +func TestToWithTools_NilContent(t *testing.T) { + result := &CreateMessageResult{ + Model: "test", + Role: "assistant", + } + wt := result.toWithTools() + if wt.Content != nil { + t.Errorf("Content = %v, want nil", wt.Content) + } +} + +func TestClientCapabilities_CloneSampling(t *testing.T) { + caps := &ClientCapabilities{ + Sampling: &SamplingCapabilities{ + Tools: &SamplingToolsCapabilities{}, + Context: &SamplingContextCapabilities{}, + }, + } + cloned := caps.clone() + + // Verify deep copy — Sampling pointer should differ. + // (Tools and Context are empty structs, so Go may reuse the same address; + // we just check they're non-nil and that mutating Sampling doesn't alias.) + if cloned.Sampling == caps.Sampling { + t.Error("Sampling pointer should differ after clone") + } + if cloned.Sampling.Tools == nil { + t.Error("cloned Sampling.Tools should not be nil") + } + if cloned.Sampling.Context == nil { + t.Error("cloned Sampling.Context should not be nil") + } + // Verify mutation doesn't affect original. + cloned.Sampling.Tools = nil + if caps.Sampling.Tools == nil { + t.Error("modifying cloned Sampling.Tools should not affect original") + } +} diff --git a/mcp/server.go b/mcp/server.go index d68a8c23..2229f427 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1162,6 +1162,10 @@ func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) } // CreateMessage sends a sampling request to the client. +// +// If the client returns multiple content blocks (e.g. parallel tool calls), +// CreateMessage returns an error. Use [ServerSession.CreateMessageWithTools] +// for tool-enabled sampling. func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { if err := ss.checkInitialized(methodCreateMessage); err != nil { return nil, err @@ -1174,7 +1178,44 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag p2.Messages = []*SamplingMessage{} // avoid JSON "null" params = &p2 } - return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) + res, err := handleSend[*CreateMessageWithToolsResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) + if err != nil { + return nil, err + } + // Downconvert to singular content. + if len(res.Content) > 1 { + return nil, fmt.Errorf("CreateMessage result has %d content blocks; use CreateMessageWithTools for multiple content", len(res.Content)) + } + var content Content + if len(res.Content) > 0 { + content = res.Content[0] + } + return &CreateMessageResult{ + Meta: res.Meta, + Content: content, + Model: res.Model, + Role: res.Role, + StopReason: res.StopReason, + }, nil +} + +// CreateMessageWithTools sends a sampling request with tools to the client, +// returning a [CreateMessageWithToolsResult] that supports array content +// (for parallel tool calls). Use this instead of [ServerSession.CreateMessage] +// when the request includes tools. +func (ss *ServerSession) CreateMessageWithTools(ctx context.Context, params *CreateMessageWithToolsParams) (*CreateMessageWithToolsResult, error) { + if err := ss.checkInitialized(methodCreateMessage); err != nil { + return nil, err + } + if params == nil { + params = &CreateMessageWithToolsParams{Messages: []*SamplingMessageV2{}} + } + if params.Messages == nil { + p2 := *params + p2.Messages = []*SamplingMessageV2{} // avoid JSON "null" + params = &p2 + } + return handleSend[*CreateMessageWithToolsResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) } // Elicit sends an elicitation request to the client asking for user input.