Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions llms/ollama/internal/ollamaclient/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ollamaclient

import (
"encoding/json"
"fmt"
"os"
"time"
Expand Down Expand Up @@ -40,10 +41,65 @@ type GenerateRequest struct {

type ImageData []byte

// Tool represents a tool available for the model to call.
type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}

// ToolFunction describes a function that a tool can call.
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters,omitempty"`
}

// ToolCall represents a tool call returned by the model.
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}

// ToolCallFunction holds the function name and arguments of a tool call.
type ToolCallFunction struct {
Name string `json:"name"`
Arguments string `json:"-"` // custom unmarshal: Ollama sends object, we store as JSON string
}

// UnmarshalJSON handles Ollama's format where arguments is an object, not a string.
func (f *ToolCallFunction) UnmarshalJSON(data []byte) error {
var raw struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
f.Name = raw.Name
// Ollama returns arguments as a JSON object; store it as a string
f.Arguments = string(raw.Arguments)
return nil
}

// MarshalJSON writes the function call back in Ollama's expected format.
func (f ToolCallFunction) MarshalJSON() ([]byte, error) {
args := json.RawMessage(f.Arguments)
if len(args) == 0 {
args = json.RawMessage("{}")
}
return json.Marshal(struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}{
Name: f.Name,
Arguments: args,
})
}

type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
Role string `json:"role"` // one of ["system", "user", "assistant", "tool"]
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

type ChatRequest struct {
Expand All @@ -52,6 +108,7 @@ type ChatRequest struct {
Stream bool `json:"stream,omitempty"`
Format string `json:"format"`
KeepAlive string `json:"keep_alive,omitempty"`
Tools []Tool `json:"tools,omitempty"`

Options Options `json:"options"`
}
Expand Down
44 changes: 44 additions & 0 deletions llms/ollama/ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,50 @@ func TestCreateEmbedding(t *testing.T) {
}
}

func TestToolCall(t *testing.T) {
ctx := context.Background()

llm := newTestClient(t)

content := []llms.MessageContent{
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{llms.TextContent{Text: "What is the weather in San Francisco?"}},
},
}

tools := []llms.Tool{
{
Type: "function",
Function: &llms.FunctionDefinition{
Name: "get_weather",
Description: "Get the current weather for a location",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{
"type": "string",
"description": "The city name",
},
},
"required": []string{"location"},
},
},
},
}

rsp, err := llm.GenerateContent(ctx, content, llms.WithTools(tools))
require.NoError(t, err)

require.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
require.NotEmpty(t, c1.ToolCalls, "expected tool calls in response")

tc := c1.ToolCalls[0]
assert.Equal(t, "get_weather", tc.FunctionCall.Name)
assert.Contains(t, tc.FunctionCall.Arguments, "San Francisco")
}

func TestWithPullTimeout(t *testing.T) {
ctx := context.Background()

Expand Down
136 changes: 97 additions & 39 deletions llms/ollama/ollamallm.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,34 +108,9 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
// We have to convert it to a format Ollama undestands: ChatRequest, which
// has a sequence of Message, each of which has a role and content - single
// text + potential images.
chatMsgs := make([]*ollamaclient.Message, 0, len(messages))
for _, mc := range messages {
msg := &ollamaclient.Message{Role: typeToRole(mc.Role)}

// Look at all the parts in mc; expect to find a single Text part and
// any number of binary parts.
var text string
foundText := false
var images []ollamaclient.ImageData

for _, p := range mc.Parts {
switch pt := p.(type) {
case llms.TextContent:
if foundText {
return nil, errors.New("expecting a single Text content")
}
foundText = true
text = pt.Text
case llms.BinaryContent:
images = append(images, ollamaclient.ImageData(pt.Data))
default:
return nil, errors.New("only support Text and BinaryContent parts right now")
}
}

msg.Content = text
msg.Images = images
chatMsgs = append(chatMsgs, msg)
chatMsgs, err := makeOllamaMessages(messages)
if err != nil {
return nil, err
}

format := o.options.format
Expand All @@ -155,12 +130,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
}
}
}
// Ollama doesn't support streaming with tools
stream := opts.StreamingFunc != nil && len(opts.Tools) == 0

req := &ollamaclient.ChatRequest{
Model: model,
Format: format,
Messages: chatMsgs,
Options: ollamaOptions,
Stream: opts.StreamingFunc != nil,
Stream: stream,
Tools: makeOllamaTools(opts.Tools),
}

keepAlive := o.options.keepAlive
Expand All @@ -183,15 +162,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
}
if !req.Stream || response.Done {
resp = response
resp.Message = &ollamaclient.Message{
Role: "assistant",
Content: streamedResponse,
if resp.Message == nil {
resp.Message = &ollamaclient.Message{}
}
resp.Message.Role = "assistant"
resp.Message.Content = streamedResponse
}
return nil
}

err := o.client.GenerateChat(ctx, req, fn)
err = o.client.GenerateChat(ctx, req, fn)
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
Expand Down Expand Up @@ -235,14 +215,28 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
genInfo["ThinkingEnabled"] = true
}

choices := []*llms.ContentChoice{
{
Content: content,
GenerationInfo: genInfo,
},
choice := &llms.ContentChoice{
Content: content,
GenerationInfo: genInfo,
}

// Convert tool calls from the response
if resp.Message != nil {
for _, tc := range resp.Message.ToolCalls {
choice.ToolCalls = append(choice.ToolCalls, llms.ToolCall{
Type: "function",
FunctionCall: &llms.FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
if len(choice.ToolCalls) > 0 {
choice.FuncCall = choice.ToolCalls[0].FunctionCall
}
}

response := &llms.ContentResponse{Choices: choices}
response := &llms.ContentResponse{Choices: []*llms.ContentChoice{choice}}

if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response)
Expand Down Expand Up @@ -332,6 +326,70 @@ func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.
return ollamaOptions
}

// makeOllamaMessages converts llms.MessageContent to ollamaclient.Message,
// handling text, binary, tool call, and tool response parts.
func makeOllamaMessages(messages []llms.MessageContent) ([]*ollamaclient.Message, error) {
msgs := make([]*ollamaclient.Message, 0, len(messages))
for _, mc := range messages {
msg := &ollamaclient.Message{Role: typeToRole(mc.Role)}

switch mc.Role {
case llms.ChatMessageTypeTool:
if len(mc.Parts) != 1 {
return nil, fmt.Errorf("expected exactly one part for tool role, got %d", len(mc.Parts))
}
p, ok := mc.Parts[0].(llms.ToolCallResponse)
if !ok {
return nil, fmt.Errorf("expected ToolCallResponse for tool role, got %T", mc.Parts[0])
}
msg.Content = p.Content

default:
var images []ollamaclient.ImageData
for _, p := range mc.Parts {
switch pt := p.(type) {
case llms.TextContent:
msg.Content = pt.Text
case llms.BinaryContent:
images = append(images, ollamaclient.ImageData(pt.Data))
case llms.ToolCall:
msg.ToolCalls = append(msg.ToolCalls, ollamaclient.ToolCall{
Function: ollamaclient.ToolCallFunction{
Name: pt.FunctionCall.Name,
Arguments: pt.FunctionCall.Arguments,
},
})
default:
return nil, fmt.Errorf("unsupported content part type: %T", p)
}
}
msg.Images = images
}

msgs = append(msgs, msg)
}
return msgs, nil
}

// makeOllamaTools converts llms.Tool to ollamaclient.Tool.
func makeOllamaTools(tools []llms.Tool) []ollamaclient.Tool {
if len(tools) == 0 {
return nil
}
out := make([]ollamaclient.Tool, len(tools))
for i, t := range tools {
out[i] = ollamaclient.Tool{
Type: t.Type,
Function: ollamaclient.ToolFunction{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
},
}
}
return out
}

// pullModelIfNeeded pulls the model if it's not already available.
func (o *LLM) pullModelIfNeeded(ctx context.Context, model string) error {
// Try to use the model first. If it fails with a model not found error,
Expand Down
19 changes: 19 additions & 0 deletions llms/ollama/testdata/TestToolCall.httprr

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading