diff --git a/message/extract/extract.go b/message/extract/extract.go index 797d899..ec495f5 100644 --- a/message/extract/extract.go +++ b/message/extract/extract.go @@ -51,7 +51,7 @@ type Options struct { // // 函数至少需要一个参数,且其第一个参数的类型必须为 string。 // 如果指向的是方法,那么在调用此方法的结构必须有明确类型声明,不能由类型推荐获得。 - // 比如,当 f 为 golang.org/x/text/message.Printer.Printf 时: + // 比如,当 p 为 golang.org/x/text/message.Printer.Printf 时: // // // 以下无法提取内容 // p := message.NewPrinter(); @@ -64,9 +64,8 @@ type Options struct { } type extractor struct { - log message.LogFunc - funcs []localeFunc - fset *token.FileSet + log message.LogFunc + fset *token.FileSet mux sync.Mutex msg []message.Message @@ -77,20 +76,20 @@ func Extract(ctx context.Context, o *Options) (*message.Language, error) { // NOTE: 有可能存在将 localeutil.Phrase 二次封装的情况, // 为了尽可能多地找到本地化字符串,所以采用用户指定函数的方法。 + // 获取所有需要分析的源码目录 dirs, err := getDir(o.Root, o.Recursive, o.SkipSubModule) if err != nil { return nil, err } ex := &extractor{ - log: o.Log, - funcs: split(o.Funcs...), - fset: token.NewFileSet(), + log: o.Log, + fset: token.NewFileSet(), msg: make([]message.Message, 0, 100), } - if err := ex.scanDirs(ctx, dirs); err != nil { + if err := ex.scanDirs(ctx, dirs, o.Funcs); err != nil { return nil, err } @@ -99,7 +98,7 @@ func Extract(ctx context.Context, o *Options) (*message.Language, error) { return &message.Language{ID: o.Language, Messages: ex.msg}, nil } -func (ex *extractor) scanDirs(ctx context.Context, dirs []string) error { +func (ex *extractor) scanDirs(ctx context.Context, dirs, funcs []string) error { wg := &sync.WaitGroup{} for _, dir := range dirs { entries, err := os.ReadDir(dir) @@ -127,7 +126,7 @@ func (ex *extractor) scanDirs(ctx context.Context, dirs []string) error { return } - ex.inspectFile(p, f) + ex.inspectFile(p, f, funcs) }(filepath.Join(dir, e.Name())) } } @@ -145,7 +144,7 @@ func logErr(err error, log message.LogFunc) { log(localeutil.StringPhrase(err.Error())) } -func (ex *extractor) inspectFile(p string, f *ast.File) { +func (ex *extractor) inspectFile(p string, f *ast.File, funcs []string) { const notFound = localeutil.StringPhrase("go.mod not found") modPath, err := source.ModPath(p) @@ -158,7 +157,7 @@ func (ex *extractor) inspectFile(p string, f *ast.File) { return } - mods := filterImportFuncs(modPath, f.Imports, ex.funcs) + mods := filterImportFuncs(modPath, f.Imports, funcs) ast.Inspect(f, func(n ast.Node) bool { switch expr := n.(type) { case *ast.TypeSpec, *ast.ImportSpec: diff --git a/message/extract/func.go b/message/extract/func.go index 509be0b..df9af1b 100644 --- a/message/extract/func.go +++ b/message/extract/func.go @@ -9,13 +9,6 @@ import ( "strings" ) -// 表示本地化的函数 -type localeFunc struct { - path string // 函数的完整导入路径 - structure string // 类型名,可能为空 - name string // 函数名 -} - // 表示由 import 转换后的函数名 type importFunc struct { modName string // import 中的别名 @@ -23,35 +16,20 @@ type importFunc struct { name string // 函数名 } -func split(funcs ...string) []localeFunc { - ret := make([]localeFunc, 0, len(funcs)) - for _, f := range funcs { - base := path.Base(f) - dir := path.Dir(f) - switch strs := strings.Split(base, "."); len(strs) { - case 2: - ret = append(ret, localeFunc{path: path.Join(dir, strs[0]), name: strs[1]}) - case 3: - ret = append(ret, localeFunc{path: path.Join(dir, strs[0]), structure: strs[1], name: strs[2]}) - default: - panic(fmt.Sprintf("%s 格式无效", f)) - } - } - return ret -} +func filterImportFuncs(fileModPath string, imports []*ast.ImportSpec, funcList []string) []importFunc { + funcs := split(funcList...) -func filterImportFuncs(fileModPath string, imports []*ast.ImportSpec, funcs []localeFunc) []importFunc { mods := make([]importFunc, 0, len(funcs)) for _, f := range funcs { - if fileModPath == f.path { - mods = append(mods, importFunc{name: f.name, structName: f.structure}) + if fileModPath == f.modName { + mods = append(mods, importFunc{name: f.name, structName: f.structName}) continue } for _, ip := range imports { modPath := strings.Trim(ip.Path.Value, "\"") - if f.path != modPath { + if f.modName != modPath { continue } @@ -62,9 +40,28 @@ func filterImportFuncs(fileModPath string, imports []*ast.ImportSpec, funcs []lo modName = path.Base(modPath) } - mods = append(mods, importFunc{modName: modName, name: f.name, structName: f.structure}) + mods = append(mods, importFunc{modName: modName, name: f.name, structName: f.structName}) } } return mods } + +// 返回从 [Options.Funcs] 中分析而来的中间数据 +// 此时返回元素中的 modName 表示的是完整的模块导出山路径。 +func split(funcs ...string) []importFunc { + ret := make([]importFunc, 0, len(funcs)) + for _, f := range funcs { + base := path.Base(f) + dir := path.Dir(f) + switch strs := strings.Split(base, "."); len(strs) { + case 2: + ret = append(ret, importFunc{modName: path.Join(dir, strs[0]), name: strs[1]}) + case 3: + ret = append(ret, importFunc{modName: path.Join(dir, strs[0]), structName: strs[1], name: strs[2]}) + default: + panic(fmt.Sprintf("%s 格式无效", f)) + } + } + return ret +} diff --git a/message/extract/func_test.go b/message/extract/func_test.go index 9394581..981335e 100644 --- a/message/extract/func_test.go +++ b/message/extract/func_test.go @@ -14,10 +14,10 @@ func TestSplit(t *testing.T) { a := assert.New(t, false) fns := split("github.com/issue9/localeutil.Phrase", "github.com/issue9/localeutil.Error", "github.com/issue9/localeutil.Struct.Printf") - a.Equal(fns, []localeFunc{ - {path: "github.com/issue9/localeutil", name: "Phrase"}, - {path: "github.com/issue9/localeutil", name: "Error"}, - {path: "github.com/issue9/localeutil", name: "Printf", structure: "Struct"}, + a.Equal(fns, []importFunc{ + {modName: "github.com/issue9/localeutil", name: "Phrase"}, + {modName: "github.com/issue9/localeutil", name: "Error"}, + {modName: "github.com/issue9/localeutil", name: "Printf", structName: "Struct"}, }) a.PanicString(func() { @@ -31,7 +31,7 @@ func TestFilterImportFuncs(t *testing.T) { f, err := parser.ParseFile(token.NewFileSet(), "./testdata/testdata.go", nil, parser.AllErrors) a.NotError(err).NotNil(f) - fns := split("github.com/issue9/localeutil.Phrase", "github.com/issue9/localeutil.Error") + fns := []string{"github.com/issue9/localeutil.Phrase", "github.com/issue9/localeutil.Error"} mods := filterImportFuncs("", f.Imports, fns) a.Equal(mods, []importFunc{ {modName: "localeutil", name: "Phrase"}, @@ -46,7 +46,7 @@ func TestFilterImportFuncs(t *testing.T) { f, err := parser.ParseFile(token.NewFileSet(), "./testdata/testdata.go", nil, parser.AllErrors) a.NotError(err).NotNil(f) - fns := split("github.com/issue9/localeutil.Phrase") + fns := []string{"github.com/issue9/localeutil.Phrase"} mods := filterImportFuncs("", f.Imports, fns) a.Equal(mods, []importFunc{ {modName: "localeutil", name: "Phrase"}, @@ -59,7 +59,7 @@ func TestFilterImportFuncs(t *testing.T) { f, err := parser.ParseFile(token.NewFileSet(), "./testdata/struct.go", nil, parser.AllErrors) a.NotError(err).NotNil(f) - fns := split("golang.org/x/text/message.Printer.Printf") + fns := []string{"golang.org/x/text/message.Printer.Printf"} mods := filterImportFuncs("", f.Imports, fns) a.Equal(mods, []importFunc{ {modName: "message", name: "Printf", structName: "Printer"}, @@ -72,7 +72,7 @@ func TestFilterImportFuncs(t *testing.T) { f, err := parser.ParseFile(token.NewFileSet(), "./testdata/struct.go", nil, parser.AllErrors) a.NotError(err).NotNil(f) - fns := split("golang.org/x/text/message.Printer.Printf") + fns := []string{"golang.org/x/text/message.Printer.Printf"} mods := filterImportFuncs("golang.org/x/text/message", f.Imports, fns) a.Equal(mods, []importFunc{ {name: "Printf", structName: "Printer"},