Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions chains/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/httputil"
"github.com/tmc/langchaingo/internal/httprr"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/googleai"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/prompts"
Expand Down Expand Up @@ -65,6 +66,31 @@ func TestLLMChain(t *testing.T) {
require.True(t, strings.Contains(result, "Paris"))
}

// errorLanguageModel is a test double for a language model: both of its
// methods (Call and GenerateContent) ignore their inputs and return err.
type errorLanguageModel struct {
// err is the error handed back by every method on the model.
err error
}

// Call ignores the context, prompt, and options and always returns an empty
// string together with the configured error m.err.
func (m *errorLanguageModel) Call(_ context.Context, _ string, _ ...llms.CallOption) (string, error) {
return "", m.err
}

// GenerateContent ignores the context, messages, and options and always
// returns a nil response together with the configured error m.err.
func (m *errorLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageContent, _ ...llms.CallOption) (*llms.ContentResponse, error) {
return nil, m.err
}

func TestLLMChainPropagatesContentFilterError(t *testing.T) {
t.Parallel()

chain := NewLLMChain(
&errorLanguageModel{err: llms.NewError(llms.ErrCodeContentFilter, "bedrock", "blocked")},
prompts.NewPromptTemplate("{{.text}}", []string{"text"}),
)

_, err := chain.Call(context.Background(), map[string]any{"text": "unsafe prompt"})
require.Error(t, err)
require.True(t, llms.IsContentFilterError(err))
}

func TestLLMChainWithChatPromptTemplate(t *testing.T) {
ctx := context.Background()
t.Parallel()
Expand Down
17 changes: 17 additions & 0 deletions chains/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ type chainCallOption struct {
RepetitionPenalty float64
repetitionPenaltySet bool

// SafetyConfig configures provider-defined safety controls in an LLM call.
SafetyConfig map[string]any
safetyConfigSet bool

// CallbackHandler is the callback handler for Chain
CallbackHandler callbacks.Handler
}
Expand Down Expand Up @@ -146,6 +150,16 @@ func WithRepetitionPenalty(repetitionPenalty float64) ChainCallOption {
}
}

// WithSafetyConfig is an option for passing provider-defined safety controls
// to the LLM call. A nil config leaves the option untouched; any non-nil map
// (including an empty one) is stored as-is and marked as set.
func WithSafetyConfig(config map[string]any) ChainCallOption {
	return func(o *chainCallOption) {
		// Nothing to record when the caller supplied no configuration.
		if config == nil {
			return
		}
		o.safetyConfigSet = true
		o.SafetyConfig = config
	}
}

// WithStopWords is an option for setting the stop words for LLM.Call.
func WithStopWords(stopWords []string) ChainCallOption {
return func(o *chainCallOption) {
Expand Down Expand Up @@ -208,6 +222,9 @@ func GetLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint:
if opts.repetitionPenaltySet {
chainCallOption = append(chainCallOption, llms.WithRepetitionPenalty(opts.RepetitionPenalty))
}
if opts.safetyConfigSet {
chainCallOption = append(chainCallOption, llms.WithSafetyConfig(opts.SafetyConfig))
}
chainCallOption = append(chainCallOption, llms.WithStreamingFunc(opts.StreamingFunc))

return chainCallOption
Expand Down
48 changes: 28 additions & 20 deletions llms/bedrock/bedrockllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,37 +103,45 @@ func processMessages(messages []llms.MessageContent) ([]bedrockclient.Message, e

for _, m := range messages {
for _, part := range m.Parts {
switch part := part.(type) {
var cacheControl *llms.CacheControl
if cached, ok := part.(llms.CachedContent); ok {
cacheControl = cached.CacheControl
part = cached.ContentPart
}

switch p := part.(type) {
case llms.TextContent:
bedrockMsgs = append(bedrockMsgs, bedrockclient.Message{
Role: m.Role,
Content: part.Text,
Type: "text",
Role: m.Role,
Content: p.Text,
Type: "text",
CacheControl: cacheControl,
})
case llms.BinaryContent:
bedrockMsgs = append(bedrockMsgs, bedrockclient.Message{
Role: m.Role,
Content: string(part.Data),
MimeType: part.MIMEType,
Type: "image",
Role: m.Role,
Content: string(p.Data),
MimeType: p.MIMEType,
Type: "image",
CacheControl: cacheControl,
})
case llms.ToolCall:
// Handle tool calls from AI messages
bedrockMsgs = append(bedrockMsgs, bedrockclient.Message{
Role: m.Role,
Content: "", // Content will be empty for tool calls
Type: "tool_call",
ToolCallID: part.ID,
ToolName: part.FunctionCall.Name,
ToolArgs: part.FunctionCall.Arguments,
Role: m.Role,
Content: "", // Content will be empty for tool calls
Type: "tool_call",
ToolCallID: p.ID,
ToolName: p.FunctionCall.Name,
ToolArgs: p.FunctionCall.Arguments,
CacheControl: cacheControl,
})
case llms.ToolCallResponse:
// Handle tool result messages
bedrockMsgs = append(bedrockMsgs, bedrockclient.Message{
Role: m.Role,
Content: part.Content,
Type: "tool_result",
ToolUseID: part.ToolCallID,
Role: m.Role,
Content: p.Content,
Type: "tool_result",
ToolUseID: p.ToolCallID,
CacheControl: cacheControl,
})
default:
return nil, errors.New("unsupported message type")
Expand Down
19 changes: 19 additions & 0 deletions llms/bedrock/caching.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package bedrock

import (
"time"

"github.com/tmc/langchaingo/llms"
)

// EphemeralCache builds a cache control of type "ephemeral" for Bedrock
// prompt caching. No duration is set, so no ttl field goes on the wire and
// Bedrock applies its default 5-minute behavior.
func EphemeralCache() *llms.CacheControl {
	cc := new(llms.CacheControl)
	cc.Type = "ephemeral"
	return cc
}

// EphemeralCacheOneHour builds a cache control of type "ephemeral" whose
// duration is one hour (ttl=1h).
func EphemeralCacheOneHour() *llms.CacheControl {
	cc := new(llms.CacheControl)
	cc.Type = "ephemeral"
	cc.Duration = time.Hour
	return cc
}
3 changes: 3 additions & 0 deletions llms/bedrock/internal/bedrockclient/bedrockclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ type Message struct {
ToolArgs string `json:"tool_args,omitempty"`
// Tool result fields
ToolUseID string `json:"tool_use_id,omitempty"`
// CacheControl marks this message as a prompt-cache breakpoint when set.
// Providers that don't support prompt caching silently ignore the field.
CacheControl *llms.CacheControl `json:"-"`
}

func getProvider(modelID string) string {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,10 @@ func TestClient_CreateCompletion(t *testing.T) {
},
StopReason: AnthropicCompletionReasonEndTurn,
Usage: struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}{
InputTokens: 10,
OutputTokens: 5,
Expand Down Expand Up @@ -429,16 +431,20 @@ func TestClient_CreateCompletion_Streaming(t *testing.T) {
StopReason any `json:"stop_reason"`
StopSequence any `json:"stop_sequence"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
} `json:"usage"`
}{
ID: "msg-123",
Type: "message",
Role: "assistant",
Usage: struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}{
InputTokens: 10,
},
Expand Down Expand Up @@ -485,7 +491,9 @@ func TestClient_CreateCompletion_Streaming(t *testing.T) {
StopReason: AnthropicCompletionReasonEndTurn,
},
Usage: struct {
OutputTokens int `json:"output_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}{
OutputTokens: 15,
},
Expand Down
32 changes: 20 additions & 12 deletions llms/bedrock/internal/bedrockclient/bedrockclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ func TestProcessInputMessagesAnthropic(t *testing.T) {
name string
messages []Message
expectedMsgs int
expectedSystem string
expectedSystem interface{}
expectError bool
errorContains string
}{
Expand All @@ -268,7 +268,7 @@ func TestProcessInputMessagesAnthropic(t *testing.T) {
{Role: llms.ChatMessageTypeHuman, Type: "text", Content: "Hello"},
},
expectedMsgs: 1,
expectedSystem: "",
expectedSystem: nil,
},
{
name: "system message extracted",
Expand Down Expand Up @@ -304,7 +304,7 @@ func TestProcessInputMessagesAnthropic(t *testing.T) {
{Role: llms.ChatMessageTypeHuman, Type: "text", Content: "How are you?"},
},
expectedMsgs: 3,
expectedSystem: "",
expectedSystem: nil,
},
{
name: "multiple messages same role chunked together",
Expand All @@ -314,15 +314,15 @@ func TestProcessInputMessagesAnthropic(t *testing.T) {
{Role: llms.ChatMessageTypeAI, Type: "text", Content: "Response"},
},
expectedMsgs: 2,
expectedSystem: "",
expectedSystem: nil,
},
{
name: "function role converted to user",
messages: []Message{
{Role: llms.ChatMessageTypeFunction, Type: "text", Content: "Function call"},
},
expectedMsgs: 1,
expectedSystem: "",
expectedSystem: nil,
},
}

Expand Down Expand Up @@ -673,8 +673,10 @@ func TestAnthropicResponseParsing(t *testing.T) {
StopReason: AnthropicCompletionReasonEndTurn,
StopSequence: "",
Usage: struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}{
InputTokens: 10,
OutputTokens: 15,
Expand Down Expand Up @@ -746,17 +748,21 @@ func TestAnthropicStreamingResponseChunk(t *testing.T) {
StopReason any `json:"stop_reason"`
StopSequence any `json:"stop_sequence"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
} `json:"usage"`
}{
ID: "msg-123",
Type: "message",
Role: "assistant",
Model: "claude-3",
Usage: struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}{
InputTokens: 25,
},
Expand Down Expand Up @@ -792,7 +798,9 @@ func TestAnthropicStreamingResponseChunk(t *testing.T) {
StopReason: AnthropicCompletionReasonEndTurn,
},
Usage: struct {
OutputTokens int `json:"output_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}{
OutputTokens: 12,
},
Expand Down
Loading
Loading