Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/docs/openapi3gen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ type ExportComponentSchemasOptions struct {
ExportGenerics bool
}

type FieldNameGenerator func(field reflect.StructField, defaultName string) string
FieldNameGenerator allows client to set custom name for struct fields in
the generated schema. defaultName is the name, determined by generator's
standard JSON, YAML and Go field name resolution rules. Useful for
processing custom binding tags, such as `form` or `xml`. Function should
always return non-empty string.

type Generator struct {
Types map[reflect.Type]*openapi3.SchemaRef

Expand Down Expand Up @@ -60,6 +67,8 @@ func CreateComponentSchemas(exso ExportComponentSchemasOptions) Option
CreateComponents changes the default behavior to add all schemas as
components Reduces duplicate schemas in routes

func CreateFieldNameGenerator(fngnrt FieldNameGenerator) Option

func CreateTypeNameGenerator(tngnrt TypeNameGenerator) Option

func SchemaCustomizer(sc SchemaCustomizerFn) Option
Expand Down
41 changes: 22 additions & 19 deletions openapi3gen/openapi3gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,19 @@ type ExportComponentSchemasOptions struct {

type TypeNameGenerator func(t reflect.Type) string

// FieldNameGenerator allows client to set custom name for struct fields in the generated schema.
// defaultName is the name, determined by generator's standard JSON, YAML and Go field name resolution rules.
// Useful for processing custom binding tags, such as `form` or `xml`.
// Function should always return non-empty string.
type FieldNameGenerator func(field reflect.StructField, defaultName string) string

type generatorOpt struct {
useAllExportedFields bool
throwErrorOnCycle bool
schemaCustomizer SchemaCustomizerFn
exportComponentSchemas ExportComponentSchemasOptions
typeNameGenerator TypeNameGenerator
fieldNameGenerator FieldNameGenerator
}

// UseAllExportedFields changes the default behavior of only
Expand All @@ -70,6 +77,10 @@ func CreateTypeNameGenerator(tngnrt TypeNameGenerator) Option {
return func(x *generatorOpt) { x.typeNameGenerator = tngnrt }
}

func CreateFieldNameGenerator(fngnrt FieldNameGenerator) Option {
return func(x *generatorOpt) { x.fieldNameGenerator = fngnrt }
}

// ThrowErrorOnCycle changes the default behavior of creating cycle
// refs to instead error if a cycle is detected.
func ThrowErrorOnCycle() Option {
Expand Down Expand Up @@ -342,28 +353,13 @@ func (g *Generator) generateWithoutSaving(parents []*theTypeInfo, t reflect.Type
if !fieldInfo.HasJSONTag && !g.opts.useAllExportedFields {
continue
}

// If asked, try to use yaml tag
fieldName, fType := fieldInfo.JSONName, fieldInfo.Type
ff := getStructField(t, fieldInfo)
if !fieldInfo.HasJSONTag && g.opts.useAllExportedFields {
// Handle anonymous fields/embedded structs
if t.Field(fieldInfo.Index[0]).Anonymous {
ref, err := g.generateSchemaRefFor(parents, fType, fieldName, tag)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
ref = g.generateCycleSchemaRef(fType, schema)
} else {
return nil, err
}
}
if ref != nil {
g.SchemaRefs[ref]++
schema.WithPropertyRef(fieldName, ref)
}
} else {
ff := getStructField(t, fieldInfo)
if tag, ok := ff.Tag.Lookup("yaml"); ok && tag != "-" {
fieldName, fType = tag, ff.Type
}
if tag, ok := ff.Tag.Lookup("yaml"); ok && tag != "-" {
fieldName = tag
}
}

Expand All @@ -374,6 +370,13 @@ func (g *Generator) generateWithoutSaving(parents []*theTypeInfo, t reflect.Type
fieldTag = ff.Tag
}

if g.opts.fieldNameGenerator != nil {
fieldName = g.opts.fieldNameGenerator(ff, fieldName)
if fieldName == "" {
return nil, fmt.Errorf("field name can't be empty")
}
}

ref, err := g.generateSchemaRefFor(parents, fType, fieldName, fieldTag)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
Expand Down
158 changes: 158 additions & 0 deletions openapi3gen/openapi3gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"reflect"
"slices"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -704,3 +706,159 @@ func TestExportComponentSchemasSkipsAnonymousType(t *testing.T) {
assert.NotEmpty(t, key, "every component schema must have a non-empty key")
}
}

func TestEmbeddedFieldGeneratedOnce(t *testing.T) {
type Embedded struct {
Field string
}
type Container struct {
Embedded
}

calls := 0
g := openapi3gen.NewGenerator(
openapi3gen.UseAllExportedFields(),
openapi3gen.SchemaCustomizer(func(name string, _ reflect.Type, _ reflect.StructTag, _ *openapi3.Schema) error {
if name == "Field" {
calls++
}
return nil
}),
)

schemaRef, err := g.GenerateSchemaRef(reflect.TypeFor[Container]())
require.NoError(t, err)
require.Contains(t, schemaRef.Value.Properties, "Field")
require.Equal(t, 1, calls)
}

func TestFieldNameGenerator(t *testing.T) {
type Embedded struct {
EmbeddedField string
}
type Container struct {
PlainField string
JSONField string `json:"json_name"`
YAMLField string `yaml:"yaml_name"`
TaggedField string `property:"custom_name"`
Embedded
}

tests := []struct {
name string
generator openapi3gen.FieldNameGenerator
wantFields []string
wantDefaults map[string]string
wantGoFields []string
wantFieldTags map[string]string
wantErr bool
}{
{
name: "customizes untagged fields",
generator: func(_ reflect.StructField, defaultName string) string { return strings.ToLower(defaultName) },
wantFields: []string{
"plainfield",
"json_name",
"yaml_name",
"taggedfield",
"embeddedfield",
},
},
{
name: "customizes explicit json and yaml names",
generator: func(_ reflect.StructField, defaultName string) string { return "prefix_" + defaultName },
wantFields: []string{
"prefix_PlainField",
"prefix_json_name",
"prefix_yaml_name",
"prefix_TaggedField",
"prefix_EmbeddedField",
},
},
{
name: "uses custom struct tag",
generator: func(f reflect.StructField, defaultName string) string {
if name := f.Tag.Get("property"); name != "" {
return name
}
return defaultName
},
wantFields: []string{"PlainField", "json_name", "yaml_name", "custom_name", "EmbeddedField"},
wantFieldTags: map[string]string{"TaggedField": "custom_name"},
},
{
name: "receives promoted embedded field",
generator: func(_ reflect.StructField, defaultName string) string { return defaultName },
wantFields: []string{
"PlainField",
"json_name",
"yaml_name",
"TaggedField",
"EmbeddedField",
},
wantGoFields: []string{"EmbeddedField"},
},
{
name: "receives resolved default names",
generator: func(field reflect.StructField, defaultName string) string { return defaultName },
wantFields: []string{
"PlainField",
"json_name",
"yaml_name",
"TaggedField",
"EmbeddedField",
},
wantDefaults: map[string]string{
"JSONField": "json_name",
"PlainField": "PlainField",
"YAMLField": "yaml_name",
"EmbeddedField": "EmbeddedField",
},
},
{
name: "empty field names are rejected",
generator: func(field reflect.StructField, defaultName string) string { return "" },
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotDefaults := make(map[string]string)
gotTags := make(map[string]string)
gotGoFields := make(map[string]bool)

g := openapi3gen.NewGenerator(
openapi3gen.UseAllExportedFields(),
openapi3gen.CreateFieldNameGenerator(func(f reflect.StructField, defaultName string) string {
gotDefaults[f.Name] = defaultName
gotTags[f.Name] = f.Tag.Get("property")
gotGoFields[f.Name] = true
return tt.generator(f, defaultName)
}),
)

schemaRef, err := g.GenerateSchemaRef(reflect.TypeFor[Container]())
if tt.wantErr {
require.Error(t, err)
return
}

require.NoError(t, err)

require.ElementsMatch(t, tt.wantFields, slices.Collect(maps.Keys(schemaRef.Value.Properties)))

for field, want := range tt.wantDefaults {
require.Equal(t, want, gotDefaults[field])
}

for field, want := range tt.wantFieldTags {
require.Equal(t, want, gotTags[field])
}

for _, field := range tt.wantGoFields {
require.True(t, gotGoFields[field], "field name generator was not called for %s", field)
}
})
}
}
Loading