diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a85f1e..de64065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # CHANGELOG +## 1.0.0 + +* Refactored the codebase to allow using the code generation functionality as a library. + ## 0.1.0 (2022-10-06) * Initial version. diff --git a/generate.go b/gen/generate.go similarity index 64% rename from generate.go rename to gen/generate.go index 97a9d01..de1d473 100644 --- a/generate.go +++ b/gen/generate.go @@ -1,8 +1,10 @@ -package main +// Package gen generates Go code from an XGBoost model. +package gen import ( "fmt" "go/format" + "os" ) type treeFunction struct { @@ -10,7 +12,7 @@ type treeFunction struct { Name string } -func codegen( +func generateSource( packageName, funcName string, trees []*node, @@ -65,3 +67,36 @@ func codegenTree(r *renderer, tree *node, level int) (string, error) { return r.executeDecisionNode(tree, level, left, right) } + +// GenerateFile generates a .go file containing a function that implements the XGB model. +func GenerateFile( + inputJSON string, + packageName, + funcName, + outputFile string, +) error { + x, err := readModel(inputJSON) + if err != nil { + return err + } + + trees, err := readTrees(x) + if err != nil { + return err + } + + r, err := newRenderer() + if err != nil { + return err + } + + code, err := generateSource(packageName, funcName, trees, r) + if err != nil { + return err + } + + if err := os.WriteFile(outputFile, []byte(code), 0o644); err != nil { + return fmt.Errorf("error writing file: %w", err) + } + return nil +} diff --git a/parse.go b/gen/parse.go similarity index 99% rename from parse.go rename to gen/parse.go index 7958a37..5e0a001 100644 --- a/parse.go +++ b/gen/parse.go @@ -1,4 +1,4 @@ -package main +package gen import ( "encoding/json" diff --git a/parse_test.go b/gen/parse_test.go similarity index 99% rename from parse_test.go rename to gen/parse_test.go index c1bac04..6fc7196 100644 --- a/parse_test.go +++ b/gen/parse_test.go @@ -1,4 +1,4 @@ -package main +package gen import ( "path/filepath" diff --git a/templates.go b/gen/templates.go similarity index 99% rename from templates.go rename to gen/templates.go index 0223f0b..422e599 100644 --- a/templates.go +++ b/gen/templates.go @@ -1,4 +1,4 @@ -package main +package gen import ( "bytes" diff --git a/templates/decision_node.tmpl b/gen/templates/decision_node.tmpl similarity index 100% rename from templates/decision_node.tmpl rename to gen/templates/decision_node.tmpl diff --git a/templates/root.tmpl b/gen/templates/root.tmpl similarity index 100% rename from templates/root.tmpl rename to gen/templates/root.tmpl diff --git a/templates/terminal_node.tmpl b/gen/templates/terminal_node.tmpl similarity index 100% rename from templates/terminal_node.tmpl rename to gen/templates/terminal_node.tmpl diff --git a/testdata/breast-cancer/generate-model.py b/gen/testdata/breast-cancer/generate-model.py similarity index 100% rename from testdata/breast-cancer/generate-model.py rename to gen/testdata/breast-cancer/generate-model.py diff --git a/testdata/breast-cancer/model.json b/gen/testdata/breast-cancer/model.json similarity index 100% rename from testdata/breast-cancer/model.json rename to gen/testdata/breast-cancer/model.json diff --git a/testdata/breast-cancer/preds.csv b/gen/testdata/breast-cancer/preds.csv similarity index 100% rename from testdata/breast-cancer/preds.csv rename to gen/testdata/breast-cancer/preds.csv diff --git a/testdata/breast-cancer/xtest.csv b/gen/testdata/breast-cancer/xtest.csv similarity index 100% rename from testdata/breast-cancer/xtest.csv rename to gen/testdata/breast-cancer/xtest.csv diff --git a/testdata/main.go b/gen/testdata/main.go similarity index 100% rename from testdata/main.go rename to gen/testdata/main.go diff --git a/testdata/small-model/model.json b/gen/testdata/small-model/model.json similarity index 100% rename from testdata/small-model/model.json rename to gen/testdata/small-model/model.json diff --git a/testdata/small-model/preds.csv b/gen/testdata/small-model/preds.csv similarity index 100% rename from testdata/small-model/preds.csv rename to gen/testdata/small-model/preds.csv diff --git a/testdata/small-model/xtest.csv b/gen/testdata/small-model/xtest.csv similarity index 100% rename from testdata/small-model/xtest.csv rename to gen/testdata/small-model/xtest.csv diff --git a/main.go b/main.go index 9e4af70..66ec61b 100644 --- a/main.go +++ b/main.go @@ -1,4 +1,4 @@ -// xgb2code generates code for an XGB model. +// This program runs a command line program that generates Go code from an xgb model in JSON format. package main import ( @@ -6,6 +6,8 @@ import ( "fmt" "log" "os" + + "github.com/maxmind/xgb2code/gen" ) type language string @@ -57,41 +59,8 @@ func main() { os.Exit(1) } - err := run(*inputJSON, *packageName, *funcName, *outputFile) + err := gen.GenerateFile(*inputJSON, *packageName, *funcName, *outputFile) if err != nil { log.Fatal(err) } } - -func run( - inputJSON string, - packageName, - funcName, - outputFile string, -) error { - x, err := readModel(inputJSON) - if err != nil { - return err - } - - trees, err := readTrees(x) - if err != nil { - return err - } - - r, err := newRenderer() - if err != nil { - return err - } - - code, err := codegen(packageName, funcName, trees, r) - if err != nil { - return err - } - - if err := os.WriteFile(outputFile, []byte(code), 0o644); err != nil { - return fmt.Errorf("error writing file: %w", err) - } - - return nil -} diff --git a/main_test.go b/main_test.go index ac20e15..14890b1 100644 --- a/main_test.go +++ b/main_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "testing" + "github.com/maxmind/xgb2code/gen" "github.com/stretchr/testify/require" ) @@ -21,8 +22,7 @@ func TestGenerateAndRunModels(t *testing.T) { for _, test := range tests { t.Run(test.model, func(t *testing.T) { // Generate the code. - - modelDir := filepath.Join("testdata", test.model) + modelDir := filepath.Join("gen", "testdata", test.model) modelFile := filepath.Join(modelDir, "model.json") packageName := "main" @@ -31,13 +31,13 @@ func TestGenerateAndRunModels(t *testing.T) { outputDir := t.TempDir() funcFile := filepath.Join(outputDir, "predict.go") - err := run(modelFile, packageName, functionName, funcFile) + err := gen.GenerateFile(modelFile, packageName, functionName, funcFile) require.NoError(t, err) // Copy the test program and test data into place. files := []string{ - filepath.Join("testdata", "main.go"), + filepath.Join("gen", "testdata", "main.go"), filepath.Join(modelDir, "xtest.csv"), filepath.Join(modelDir, "preds.csv"), }