Skip to content

Commit

Permalink
AST + runtime memoizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Mzack9999 committed Feb 8, 2024
1 parent 4023fb7 commit 912b6d8
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 2 deletions.
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.21

require (
github.com/Masterminds/semver/v3 v3.2.1
github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2
github.com/charmbracelet/glamour v0.6.0
github.com/denisbrodbeck/machineid v1.0.1
Expand All @@ -26,13 +27,13 @@ require (
go.uber.org/multierr v1.11.0
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
golang.org/x/oauth2 v0.11.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.16.0
golang.org/x/text v0.14.0
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057 // indirect
github.com/Mzack9999/go-http-digest-auth-client v0.6.1-0.20220414142836-eb8883508809 // indirect
github.com/VividCortex/ewma v1.2.0 // indirect
github.com/akrylysov/pogreb v0.10.1 // indirect
Expand Down Expand Up @@ -114,7 +115,7 @@ require (
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/term v0.16.0
golang.org/x/tools v0.13.0 // indirect
golang.org/x/tools v0.13.0
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.31.0 // indirect
)
15 changes: 15 additions & 0 deletions memoize/cmd/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package main

import (
"log"

"github.com/projectdiscovery/utils/memoize"
)

func main() {
out, err := memoize.File("../tests/test.go", "test")
if err != nil {
panic(err)
}
log.Println(out)
}
52 changes: 52 additions & 0 deletions memoize/gen/memoize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package main

import (
"flag"
"io/fs"
"log"
"os"
"path/filepath"
"strings"

fileutil "github.com/projectdiscovery/utils/file"
"github.com/projectdiscovery/utils/memoize"
)

var (
srcFolder = flag.String("src", "", "source folder")
dstFolder = flag.String("dst", "", "destination foldder")
packageName = flag.String("pkg", "memo", "destination package")
)

func main() {
flag.Parse()

_ = fileutil.CreateFolder(*dstFolder)

err := filepath.WalkDir(*srcFolder, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if ext := filepath.Ext(path); strings.ToLower(ext) != ".go" {
return nil
}

return process(path)
})
if err != nil {
log.Fatal(err)
}
}

func process(path string) error {
filename := filepath.Base(path)
dstFile := filepath.Join(*dstFolder, filename)
out, err := memoize.File(path, *packageName)
if err != nil {
return err
}
return os.WriteFile(dstFile, out, os.ModePerm)
}
245 changes: 245 additions & 0 deletions memoize/memoize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
package memoize

import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"os"
"strings"
"text/template"

"github.com/Mzack9999/gcache"
"golang.org/x/sync/singleflight"
"golang.org/x/tools/imports"
)

type Memoizer struct {
cache gcache.Cache[string, interface{}]
group singleflight.Group
}

type MemoizeOption func(m *Memoizer) error

func WithMaxSize(size int) MemoizeOption {
return func(m *Memoizer) error {
m.cache = gcache.New[string, interface{}](size).Build()
return nil
}
}

func New(options ...MemoizeOption) (*Memoizer, error) {
m := &Memoizer{}
for _, option := range options {
if err := option(m); err != nil {
return nil, err
}
}

return m, nil
}

func (m *Memoizer) Do(funcHash string, fn func() (interface{}, error)) (interface{}, error, bool) {
if value, err := m.cache.GetIFPresent(funcHash); !errors.Is(err, gcache.KeyNotFoundError) {
return value, err, true
}

value, err, _ := m.group.Do(funcHash, func() (interface{}, error) {
data, err := fn()

if err == nil {
m.cache.Set(funcHash, data)

Check failure on line 55 in memoize/memoize.go

View workflow job for this annotation

GitHub Actions / Lint Test

Error return value of `m.cache.Set` is not checked (errcheck)
}

return data, err
})

return value, err, false
}

func File(sourceFile, packageName string) ([]byte, error) {
data, err := os.ReadFile(sourceFile)
if err != nil {
return nil, err
}

return Src(sourceFile, data, packageName)
}

func Src(sourcePath string, source []byte, packageName string) ([]byte, error) {
var (
fileData FileData
content bytes.Buffer
)

tmpl, err := template.New("package_template").Parse(packageTemplate)
if err != nil {
return nil, err
}

fileData.PackageName = packageName

fset := token.NewFileSet()
node, err := parser.ParseFile(fset, sourcePath, source, parser.ParseComments)
if err != nil {
return nil, err
}

for _, nn := range node.Imports {
var packageImport PackageImport
if nn.Name != nil {
packageImport.Name = nn.Name.Name
}

if nn.Path != nil {
packageImport.Path = nn.Path.Value
}

fileData.Imports = append(fileData.Imports, packageImport)
}

fileData.SourcePackage = node.Name.Name

ast.Inspect(node, func(n ast.Node) bool {
switch nn := n.(type) {
case *ast.FuncDecl:
if !nn.Name.IsExported() {
return false
}
if nn.Doc == nil {
return false
}

var funcDeclaration FunctionDeclaration
funcDeclaration.IsExported = true
funcDeclaration.Name = nn.Name.Name
funcDeclaration.SourcePackage = fileData.SourcePackage
var funcSign strings.Builder
printer.Fprint(&funcSign, fset, nn.Type)
funcDeclaration.Signature = strings.Replace(funcSign.String(), "func", "func "+funcDeclaration.Name, 1)

for _, comment := range nn.Doc.List {
if comment.Text == "// @memo" {
if nn.Type.Params != nil {
for idx, param := range nn.Type.Params.List {
var funcParam FuncValue
funcParam.Index = idx
for _, name := range param.Names {
funcParam.Name = name.String()
}
funcParam.Type = fmt.Sprint(param.Type)
funcDeclaration.Params = append(funcDeclaration.Params, funcParam)
}
}

if nn.Type.Results != nil {
for idx, res := range nn.Type.Results.List {
var result FuncValue
result.Index = idx
for _, name := range res.Names {
result.Name = name.String()
}
result.Type = fmt.Sprint(res.Type)
funcDeclaration.Results = append(funcDeclaration.Results, result)
}
}
}
}
fileData.Functions = append(fileData.Functions, funcDeclaration)
return false
default:
return true
}
})

err = tmpl.Execute(&content, fileData)
if err != nil {
return nil, err
}

out, err := imports.Process(sourcePath, content.Bytes(), nil)
if err != nil {
return nil, err
}

return format.Source(out)
}

type PackageImport struct {
Name string
Path string
}

type FuncValue struct {
Index int
Name string
Type string
}

func (f FuncValue) ResultName() string {
return fmt.Sprintf("result%d", f.Index)
}

type FunctionDeclaration struct {
SourcePackage string
IsExported bool
Name string
Params []FuncValue
Results []FuncValue
Signature string
}

func (f FunctionDeclaration) HasParams() bool {
return len(f.Params) > 0
}

func (f FunctionDeclaration) ParamsNames() string {
var params []string
for _, param := range f.Params {
params = append(params, param.Name)
}
return strings.Join(params, ",")
}

func (f FunctionDeclaration) HasReturn() bool {
return len(f.Results) > 0
}

func (f FunctionDeclaration) WantSyncOnce() bool {
return !f.HasParams()
}

func (f FunctionDeclaration) SyncOnceVarName() string {
return fmt.Sprintf("once%s", f.Name)
}

func (f FunctionDeclaration) WantReturn() bool {
return f.HasReturn()
}

func (f FunctionDeclaration) ResultStructType() string {
return fmt.Sprintf("result%s", f.Name)
}

func (f FunctionDeclaration) ResultStructVarName() string {
return fmt.Sprintf("v%s", f.ResultStructType())
}

func (f FunctionDeclaration) ResultStructFields() string {
var results []string
for _, result := range f.Results {
results = append(results, fmt.Sprintf("%s.%s", f.ResultStructVarName(), result.ResultName()))
}
return strings.Join(results, ",")
}

type FileData struct {
PackageName string
SourcePackage string
Imports []PackageImport
Functions []FunctionDeclaration
}
28 changes: 28 additions & 0 deletions memoize/memoize_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package memoize

import (
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestMemo(t *testing.T) {
testingFunc := func() (interface{}, error) {
time.Sleep(10 * time.Second)
return "b", nil
}

m, err := New(WithMaxSize(5))
require.Nil(t, err)
start := time.Now()
m.Do("test", testingFunc)

Check failure on line 19 in memoize/memoize_test.go

View workflow job for this annotation

GitHub Actions / Lint Test

Error return value of `m.Do` is not checked (errcheck)
m.Do("test", testingFunc)

Check failure on line 20 in memoize/memoize_test.go

View workflow job for this annotation

GitHub Actions / Lint Test

Error return value of `m.Do` is not checked (errcheck)
require.True(t, time.Since(start) < time.Duration(15*time.Second))
}

func TestSrc(t *testing.T) {
out, err := File("tests/test.go", "test")
require.Nil(t, err)
require.True(t, len(out) > 0)
}
6 changes: 6 additions & 0 deletions memoize/package_template.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package memoize

import _ "embed"

//go:embed package_template.tpl
var packageTemplate string
Loading

0 comments on commit 912b6d8

Please sign in to comment.