diff --git a/package.go b/package.go index 92b521b..e6b5a2c 100644 --- a/package.go +++ b/package.go @@ -3,6 +3,7 @@ package main import ( _ "code.google.com/p/go.tools/go/gcimporter" "code.google.com/p/go.tools/go/types" + "code.google.com/p/go.tools/importer" "errors" "fmt" "go/ast" @@ -29,6 +30,28 @@ type GenTag struct { Negated bool } +func checkTypes(path string, fset *token.FileSet, files []*ast.File) (*types.Package, error) { + /* + * Use customized types.Config instead of types.Check + * By default, types.Check will use gcimporter, which will only import gc-generated files (*.a) + * The customized types.Config will use importer, instead of gcimporter. importer will import gc-generated files (*.a) and source files (*.go). + */ + typesConfig := types.Config{} + typesConfig.Import = func(imports map[string]*types.Package, path string) (*types.Package, error) { + impConfig := importer.Config{} + impConfig.TypeChecker = typesConfig + impConfig.Build = &build.Default + imp := importer.New(&impConfig) + pkgInfo, err := imp.LoadPackage(path) + if nil != err { + return nil, err + } + return pkgInfo.Pkg, nil + } + + return typesConfig.Check(path, fset, files, nil) +} + // Returns one gen Package per Go package found in current directory func getPackages() (result []*Package) { fset := token.NewFileSet() @@ -41,7 +64,7 @@ func getPackages() (result []*Package) { for name, astPackage := range astPackages { pkg := &Package{Name: name} - typesPkg, err := types.Check(name, fset, getAstFiles(astPackage, rootDir)) + typesPkg, err := checkTypes(name, fset, getAstFiles(astPackage, rootDir)) if err != nil { errs = append(errs, err) } diff --git a/predicates.go b/predicates.go index 5fe9323..7d618d9 100644 --- a/predicates.go +++ b/predicates.go @@ -13,7 +13,25 @@ import ( ) func isComparable(typ types.Type) bool { - return types.Comparable(typ) + // From : https://code.google.com/p/go/source/browse/go/types/predicates.go?repo=tools&name=release-branch.go1.2 + switch t := typ.Underlying().(type) { + case *types.Basic: + // assume invalid types to be comparable + // to avoid follow-up errors + return t.Kind() != types.UntypedNil + case *types.Pointer, *types.Interface, *types.Chan: + return true + case *types.Struct: + for i := 0; i < t.NumFields(); i++ { + if !isComparable(t.Field(i).Type()) { + return false + } + } + return true + case *types.Array: + return isComparable(t.Elem()) + } + return false } func isNumeric(typ types.Type) bool {