// Review preamble. This text sits before the first "diff --git" header, so it is
// commentary only and is not part of any hunk.
//
// What this patch does, per file:
//   - llms/googleai/googleai.go: GenerateContent gains a branch that runs when
//     opts.ResponseSchema is non-nil. The value must type-assert to *genai.Schema,
//     otherwise an error naming the actual type (%T) is returned. If
//     model.ResponseMIMEType is already non-empty and differs from
//     ResponseMIMETypeJson, an error is returned. Otherwise model.ResponseSchema is
//     assigned and an empty model.ResponseMIMEType is set to ResponseMIMETypeJson.
//   - llms/googleai/googleai_unit_test.go: adds TestSchemaTypeAliases and
//     TestWithResponseSchemaOption.
//   - llms/googleai/schema.go (new file): declares Schema and Type as aliases of
//     genai.Schema and genai.Type, and re-exports the TypeUnspecified, TypeString,
//     TypeNumber, TypeInteger, TypeBoolean, TypeArray and TypeObject constants.
//   - llms/options.go: adds the CallOptions.ResponseSchema field (type any) and the
//     WithResponseSchema CallOption, which stores its argument in that field.
//
// NOTE(review): ResponseSchema is an `any`. A typed nil such as (*genai.Schema)(nil)
// passes both the `!= nil` check and the type assertion, so model.ResponseSchema
// would be set to nil while ResponseMIMEType is still forced to JSON. Confirm
// whether a nil pointer should be treated as "no schema".
// NOTE(review): the "complex nested schema" subtest uses the single-value assertion
// opts.ResponseSchema.(*Schema), which panics instead of failing the test if the
// stored type ever differs. Consider the comma-ok form.
// NOTE(review): the tests added here only exercise llms.CallOptions and the alias
// types. None of them calls GenerateContent, so the type-mismatch error, the
// MIME-conflict error and the auto-set path look untested in this patch. Verify
// against the rest of the test suite.
// NOTE(review): assert.Equal(t, 0.7, opts.Temperature) compares floats exactly.
// This is fine while the option stores the literal unchanged. TODO confirm.
// NOTE(review): the line breaks of this patch appear to have been collapsed. The
// original indentation and blank context lines are not recoverable from this text,
// so the patch body below is left untouched.
diff --git a/llms/googleai/googleai.go b/llms/googleai/googleai.go index 8d07c7f10..38615a575 100644 --- a/llms/googleai/googleai.go +++ b/llms/googleai/googleai.go @@ -110,6 +110,26 @@ func (g *GoogleAI) GenerateContent( model.ResponseMIMEType = ResponseMIMETypeJson } + // Support for ResponseSchema (structured output) + // Accepts *genai.Schema for native Google AI structured output + // Google API requires ResponseMIMEType to be "application/json" when using ResponseSchema + if opts.ResponseSchema != nil { + schema, ok := opts.ResponseSchema.(*genai.Schema) + if !ok { + return nil, fmt.Errorf("ResponseSchema must be *genai.Schema for Google AI, got %T", opts.ResponseSchema) + } + // Validate MIME type compatibility + if model.ResponseMIMEType != "" && model.ResponseMIMEType != ResponseMIMETypeJson { + return nil, fmt.Errorf("ResponseSchema requires ResponseMIMEType to be %q, got %q", + ResponseMIMETypeJson, model.ResponseMIMEType) + } + model.ResponseSchema = schema + // Auto-set ResponseMIMEType to JSON if not already set + if model.ResponseMIMEType == "" { + model.ResponseMIMEType = ResponseMIMETypeJson + } + } + var response *llms.ContentResponse if len(messages) == 1 { diff --git a/llms/googleai/googleai_unit_test.go b/llms/googleai/googleai_unit_test.go index ba2186551..ecc5162e5 100644 --- a/llms/googleai/googleai_unit_test.go +++ b/llms/googleai/googleai_unit_test.go @@ -570,3 +570,161 @@ func TestConvertTools(t *testing.T) { //nolint:funlen // comprehensive test //no assert.Contains(t, customizationsProp.Items.Required, "value") }) } + +// TestSchemaTypeAliases tests the Schema type aliases exported from the package +func TestSchemaTypeAliases(t *testing.T) { + t.Parallel() + + // Test that Schema type alias works correctly + t.Run("Schema type alias", func(t *testing.T) { + schema := &Schema{ + Type: TypeObject, + Description: "A test object", + Properties: map[string]*Schema{ + "name": { + Type: TypeString, + Description: "The name field", + }, + 
"age": { + Type: TypeInteger, + Description: "The age field", + }, + "score": { + Type: TypeNumber, + Description: "The score field", + }, + "active": { + Type: TypeBoolean, + Description: "The active flag", + }, + "tags": { + Type: TypeArray, + Description: "List of tags", + Items: &Schema{Type: TypeString}, + }, + }, + Required: []string{"name"}, + } + + // Verify the schema is correctly constructed + assert.Equal(t, TypeObject, schema.Type) + assert.Equal(t, "A test object", schema.Description) + assert.Len(t, schema.Properties, 5) + assert.Equal(t, []string{"name"}, schema.Required) + + // Verify nested properties + assert.Equal(t, TypeString, schema.Properties["name"].Type) + assert.Equal(t, TypeInteger, schema.Properties["age"].Type) + assert.Equal(t, TypeNumber, schema.Properties["score"].Type) + assert.Equal(t, TypeBoolean, schema.Properties["active"].Type) + assert.Equal(t, TypeArray, schema.Properties["tags"].Type) + assert.Equal(t, TypeString, schema.Properties["tags"].Items.Type) + }) + + // Test Type constants + t.Run("Type constants", func(t *testing.T) { + // Verify type constants are correctly exported + assert.NotEqual(t, TypeUnspecified, TypeString) + assert.NotEqual(t, TypeString, TypeNumber) + assert.NotEqual(t, TypeNumber, TypeInteger) + assert.NotEqual(t, TypeInteger, TypeBoolean) + assert.NotEqual(t, TypeBoolean, TypeArray) + assert.NotEqual(t, TypeArray, TypeObject) + }) +} + +// TestWithResponseSchemaOption tests the WithResponseSchema CallOption +func TestWithResponseSchemaOption(t *testing.T) { + t.Parallel() + + t.Run("sets ResponseSchema correctly", func(t *testing.T) { + schema := &Schema{ + Type: TypeObject, + Properties: map[string]*Schema{ + "result": {Type: TypeString}, + }, + } + + opts := &llms.CallOptions{} + llms.WithResponseSchema(schema)(opts) + + assert.NotNil(t, opts.ResponseSchema) + assert.Equal(t, schema, opts.ResponseSchema) + }) + + t.Run("nil schema", func(t *testing.T) { + opts := &llms.CallOptions{} + 
llms.WithResponseSchema(nil)(opts) + + assert.Nil(t, opts.ResponseSchema) + }) + + // Test backward compatibility - existing options should still work + t.Run("backward compatibility with existing options", func(t *testing.T) { + opts := &llms.CallOptions{} + + // Apply multiple options including ResponseSchema + llms.WithModel("gemini-2.0-flash")(opts) + llms.WithTemperature(0.7)(opts) + llms.WithMaxTokens(1000)(opts) + llms.WithJSONMode()(opts) + llms.WithResponseSchema(&Schema{Type: TypeObject})(opts) + + // Verify all options are set correctly + assert.Equal(t, "gemini-2.0-flash", opts.Model) + assert.Equal(t, 0.7, opts.Temperature) + assert.Equal(t, 1000, opts.MaxTokens) + assert.True(t, opts.JSONMode) + assert.NotNil(t, opts.ResponseSchema) + }) + + // Test ResponseSchema with ResponseMIMEType compatibility + t.Run("compatible with application/json MIME type", func(t *testing.T) { + opts := &llms.CallOptions{} + llms.WithResponseMIMEType("application/json")(opts) + llms.WithResponseSchema(&Schema{Type: TypeObject})(opts) + + assert.Equal(t, "application/json", opts.ResponseMIMEType) + assert.NotNil(t, opts.ResponseSchema) + }) + + // Test complex nested schema + t.Run("complex nested schema", func(t *testing.T) { + schema := &Schema{ + Type: TypeObject, + Description: "User profile response", + Properties: map[string]*Schema{ + "user": { + Type: TypeObject, + Properties: map[string]*Schema{ + "id": {Type: TypeInteger}, + "name": {Type: TypeString}, + "email": {Type: TypeString}, + "roles": { + Type: TypeArray, + Items: &Schema{Type: TypeString}, + }, + }, + Required: []string{"id", "name"}, + }, + "metadata": { + Type: TypeObject, + Properties: map[string]*Schema{ + "created_at": {Type: TypeString}, + "updated_at": {Type: TypeString}, + }, + }, + }, + Required: []string{"user"}, + } + + opts := &llms.CallOptions{} + llms.WithResponseSchema(schema)(opts) + + assert.NotNil(t, opts.ResponseSchema) + s := opts.ResponseSchema.(*Schema) + assert.Equal(t, 
TypeObject, s.Type) + assert.Len(t, s.Properties, 2) + assert.Contains(t, s.Required, "user") + }) +} diff --git a/llms/googleai/schema.go b/llms/googleai/schema.go new file mode 100644 index 000000000..91e5368b1 --- /dev/null +++ b/llms/googleai/schema.go @@ -0,0 +1,34 @@ +package googleai + +import ( + "github.com/google/generative-ai-go/genai" +) + +// Schema type aliases for convenient structured output usage. +// These are re-exported from github.com/google/generative-ai-go/genai. +type ( + // Schema represents the structure of generated content. + // Use this with llms.WithResponseSchema() for structured output. + Schema = genai.Schema + + // Type represents the data type of a Schema. + Type = genai.Type +) + +// Type constants for Schema definition. +const ( + // TypeUnspecified means not specified, should not be used. + TypeUnspecified = genai.TypeUnspecified + // TypeString means string type. + TypeString = genai.TypeString + // TypeNumber means number type (float). + TypeNumber = genai.TypeNumber + // TypeInteger means integer type. + TypeInteger = genai.TypeInteger + // TypeBoolean means boolean type. + TypeBoolean = genai.TypeBoolean + // TypeArray means array type. + TypeArray = genai.TypeArray + // TypeObject means object/struct type. + TypeObject = genai.TypeObject +) diff --git a/llms/options.go b/llms/options.go index 87b35359c..cbdd74fcb 100644 --- a/llms/options.go +++ b/llms/options.go @@ -70,6 +70,17 @@ type CallOptions struct { // application/json: JSON response in the response candidates. ResponseMIMEType string `json:"response_mime_type,omitempty"` + // ResponseSchema specifies the schema for structured output. 
+ // The schema format is provider-specific: + // - For Google AI: use *genai.Schema from github.com/google/generative-ai-go/genai + // (or use the re-exported googleai.Schema type alias) + // - For other providers: check provider documentation + // + // Note: Google API requires ResponseMIMEType to be "application/json" when using + // ResponseSchema. If not set, it will be auto-configured. If set to a different + // value, an error will be returned. + ResponseSchema any `json:"response_schema,omitempty"` + // WebSearchOptions configures web search behavior for models that support it. // Currently supported by OpenAI models like gpt-4o-search-preview. WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` @@ -329,6 +340,20 @@ func WithResponseMIMEType(responseMIMEType string) CallOption { } } +// WithResponseSchema will add an option to set the ResponseSchema for structured output. +// The schema format is provider-specific and must match the provider being used: +// - For Google AI: use *genai.Schema from github.com/google/generative-ai-go/genai +// (or use the re-exported googleai.Schema type alias) +// - For other providers: check provider documentation +// +// Passing an incorrect type will result in an error from the provider. +// When using this option, ResponseMIMEType will typically be auto-set to "application/json". +func WithResponseSchema(schema any) CallOption { + return func(o *CallOptions) { + o.ResponseSchema = schema + } +} + // WithWebSearch enables web search for models that support it. // Use with OpenAI models like gpt-4o-search-preview and gpt-4o-mini-search-preview. // Pass nil for default web search behavior, or provide WebSearchOptions to customize.