Skip to content

Commit

Permalink
refactor: 合并了 importFunc 和 localFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Nov 26, 2023
1 parent 10f6afb commit 5e9c42b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 48 deletions.
23 changes: 11 additions & 12 deletions message/extract/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Options struct {
//
// 函数至少需要一个参数,且其第一个参数的类型必须为 string。
// 如果指向的是方法,那么在调用此方法的结构必须有明确类型声明,不能由类型推荐获得。
// 比如,当 f 为 golang.org/x/text/message.Printer.Printf 时:
// 比如,当 p 为 golang.org/x/text/message.Printer.Printf 时:
//
// // 以下无法提取内容
// p := message.NewPrinter();
Expand All @@ -64,9 +64,8 @@ type Options struct {
}

type extractor struct {
log message.LogFunc
funcs []localeFunc
fset *token.FileSet
log message.LogFunc
fset *token.FileSet

mux sync.Mutex
msg []message.Message
Expand All @@ -77,20 +76,20 @@ func Extract(ctx context.Context, o *Options) (*message.Language, error) {
// NOTE: 有可能存在将 localeutil.Phrase 二次封装的情况,
// 为了尽可能多地找到本地化字符串,所以采用用户指定函数的方法。

// 获取所有需要分析的源码目录
dirs, err := getDir(o.Root, o.Recursive, o.SkipSubModule)
if err != nil {
return nil, err
}

ex := &extractor{
log: o.Log,
funcs: split(o.Funcs...),
fset: token.NewFileSet(),
log: o.Log,
fset: token.NewFileSet(),

msg: make([]message.Message, 0, 100),
}

if err := ex.scanDirs(ctx, dirs); err != nil {
if err := ex.scanDirs(ctx, dirs, o.Funcs); err != nil {
return nil, err
}

Expand All @@ -99,7 +98,7 @@ func Extract(ctx context.Context, o *Options) (*message.Language, error) {
return &message.Language{ID: o.Language, Messages: ex.msg}, nil
}

func (ex *extractor) scanDirs(ctx context.Context, dirs []string) error {
func (ex *extractor) scanDirs(ctx context.Context, dirs, funcs []string) error {
wg := &sync.WaitGroup{}
for _, dir := range dirs {
entries, err := os.ReadDir(dir)
Expand Down Expand Up @@ -127,7 +126,7 @@ func (ex *extractor) scanDirs(ctx context.Context, dirs []string) error {
return
}

ex.inspectFile(p, f)
ex.inspectFile(p, f, funcs)
}(filepath.Join(dir, e.Name()))
}
}
Expand All @@ -145,7 +144,7 @@ func logErr(err error, log message.LogFunc) {
log(localeutil.StringPhrase(err.Error()))
}

func (ex *extractor) inspectFile(p string, f *ast.File) {
func (ex *extractor) inspectFile(p string, f *ast.File, funcs []string) {
const notFound = localeutil.StringPhrase("go.mod not found")

modPath, err := source.ModPath(p)
Expand All @@ -158,7 +157,7 @@ func (ex *extractor) inspectFile(p string, f *ast.File) {
return
}

mods := filterImportFuncs(modPath, f.Imports, ex.funcs)
mods := filterImportFuncs(modPath, f.Imports, funcs)
ast.Inspect(f, func(n ast.Node) bool {
switch expr := n.(type) {
case *ast.TypeSpec, *ast.ImportSpec:
Expand Down
53 changes: 25 additions & 28 deletions message/extract/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,49 +9,27 @@ import (
"strings"
)

// 表示本地化的函数
type localeFunc struct {
path string // 函数的完整导入路径
structure string // 类型名,可能为空
name string // 函数名
}

// 表示由 import 转换后的函数名
type importFunc struct {
modName string // import 中的别名
structName string // 类型名,可能为空
name string // 函数名
}

func split(funcs ...string) []localeFunc {
ret := make([]localeFunc, 0, len(funcs))
for _, f := range funcs {
base := path.Base(f)
dir := path.Dir(f)
switch strs := strings.Split(base, "."); len(strs) {
case 2:
ret = append(ret, localeFunc{path: path.Join(dir, strs[0]), name: strs[1]})
case 3:
ret = append(ret, localeFunc{path: path.Join(dir, strs[0]), structure: strs[1], name: strs[2]})
default:
panic(fmt.Sprintf("%s 格式无效", f))
}
}
return ret
}
func filterImportFuncs(fileModPath string, imports []*ast.ImportSpec, funcList []string) []importFunc {
funcs := split(funcList...)

func filterImportFuncs(fileModPath string, imports []*ast.ImportSpec, funcs []localeFunc) []importFunc {
mods := make([]importFunc, 0, len(funcs))

for _, f := range funcs {
if fileModPath == f.path {
mods = append(mods, importFunc{name: f.name, structName: f.structure})
if fileModPath == f.modName {
mods = append(mods, importFunc{name: f.name, structName: f.structName})
continue
}

for _, ip := range imports {
modPath := strings.Trim(ip.Path.Value, "\"")
if f.path != modPath {
if f.modName != modPath {
continue
}

Expand All @@ -62,9 +40,28 @@ func filterImportFuncs(fileModPath string, imports []*ast.ImportSpec, funcs []lo
modName = path.Base(modPath)
}

mods = append(mods, importFunc{modName: modName, name: f.name, structName: f.structure})
mods = append(mods, importFunc{modName: modName, name: f.name, structName: f.structName})
}
}

return mods
}

// 返回从 [Options.Funcs] 中分析而来的中间数据
// 此时返回元素中的 modName 表示的是完整的模块导出山路径。
func split(funcs ...string) []importFunc {
ret := make([]importFunc, 0, len(funcs))
for _, f := range funcs {
base := path.Base(f)
dir := path.Dir(f)
switch strs := strings.Split(base, "."); len(strs) {
case 2:
ret = append(ret, importFunc{modName: path.Join(dir, strs[0]), name: strs[1]})
case 3:
ret = append(ret, importFunc{modName: path.Join(dir, strs[0]), structName: strs[1], name: strs[2]})
default:
panic(fmt.Sprintf("%s 格式无效", f))
}
}
return ret
}
16 changes: 8 additions & 8 deletions message/extract/func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ func TestSplit(t *testing.T) {
a := assert.New(t, false)

fns := split("github.com/issue9/localeutil.Phrase", "github.com/issue9/localeutil.Error", "github.com/issue9/localeutil.Struct.Printf")
a.Equal(fns, []localeFunc{
{path: "github.com/issue9/localeutil", name: "Phrase"},
{path: "github.com/issue9/localeutil", name: "Error"},
{path: "github.com/issue9/localeutil", name: "Printf", structure: "Struct"},
a.Equal(fns, []importFunc{
{modName: "github.com/issue9/localeutil", name: "Phrase"},
{modName: "github.com/issue9/localeutil", name: "Error"},
{modName: "github.com/issue9/localeutil", name: "Printf", structName: "Struct"},
})

a.PanicString(func() {
Expand All @@ -31,7 +31,7 @@ func TestFilterImportFuncs(t *testing.T) {
f, err := parser.ParseFile(token.NewFileSet(), "./testdata/testdata.go", nil, parser.AllErrors)
a.NotError(err).NotNil(f)

fns := split("github.com/issue9/localeutil.Phrase", "github.com/issue9/localeutil.Error")
fns := []string{"github.com/issue9/localeutil.Phrase", "github.com/issue9/localeutil.Error"}
mods := filterImportFuncs("", f.Imports, fns)
a.Equal(mods, []importFunc{
{modName: "localeutil", name: "Phrase"},
Expand All @@ -46,7 +46,7 @@ func TestFilterImportFuncs(t *testing.T) {
f, err := parser.ParseFile(token.NewFileSet(), "./testdata/testdata.go", nil, parser.AllErrors)
a.NotError(err).NotNil(f)

fns := split("github.com/issue9/localeutil.Phrase")
fns := []string{"github.com/issue9/localeutil.Phrase"}
mods := filterImportFuncs("", f.Imports, fns)
a.Equal(mods, []importFunc{
{modName: "localeutil", name: "Phrase"},
Expand All @@ -59,7 +59,7 @@ func TestFilterImportFuncs(t *testing.T) {
f, err := parser.ParseFile(token.NewFileSet(), "./testdata/struct.go", nil, parser.AllErrors)
a.NotError(err).NotNil(f)

fns := split("golang.org/x/text/message.Printer.Printf")
fns := []string{"golang.org/x/text/message.Printer.Printf"}
mods := filterImportFuncs("", f.Imports, fns)
a.Equal(mods, []importFunc{
{modName: "message", name: "Printf", structName: "Printer"},
Expand All @@ -72,7 +72,7 @@ func TestFilterImportFuncs(t *testing.T) {
f, err := parser.ParseFile(token.NewFileSet(), "./testdata/struct.go", nil, parser.AllErrors)
a.NotError(err).NotNil(f)

fns := split("golang.org/x/text/message.Printer.Printf")
fns := []string{"golang.org/x/text/message.Printer.Printf"}
mods := filterImportFuncs("golang.org/x/text/message", f.Imports, fns)
a.Equal(mods, []importFunc{
{name: "Printf", structName: "Printer"},
Expand Down

0 comments on commit 5e9c42b

Please sign in to comment.