From c2dd1a0d1392bc38c4444c671a83ec489cd91e8d Mon Sep 17 00:00:00 2001 From: Shadi Romani Date: Mon, 15 Jan 2024 22:31:48 +0000 Subject: [PATCH] Move code generation to a package --- generate.go => gen/generate.go | 38 +++++++++++++++++- parse.go => gen/parse.go | 2 +- parse_test.go => gen/parse_test.go | 2 +- templates.go => gen/templates.go | 2 +- .../templates}/decision_node.tmpl | 0 {templates => gen/templates}/root.tmpl | 0 .../templates}/terminal_node.tmpl | 0 .../testdata}/breast-cancer/generate-model.py | 0 .../testdata}/breast-cancer/model.json | 0 .../testdata}/breast-cancer/preds.csv | 0 .../testdata}/breast-cancer/xtest.csv | 0 {testdata => gen/testdata}/main.go | 0 .../testdata}/small-model/model.json | 0 .../testdata}/small-model/preds.csv | 0 .../testdata}/small-model/xtest.csv | 0 main.go | 39 ++----------------- main_test.go | 7 ++-- 17 files changed, 48 insertions(+), 42 deletions(-) rename generate.go => gen/generate.go (67%) rename parse.go => gen/parse.go (99%) rename parse_test.go => gen/parse_test.go (99%) rename templates.go => gen/templates.go (99%) rename {templates => gen/templates}/decision_node.tmpl (100%) rename {templates => gen/templates}/root.tmpl (100%) rename {templates => gen/templates}/terminal_node.tmpl (100%) rename {testdata => gen/testdata}/breast-cancer/generate-model.py (100%) rename {testdata => gen/testdata}/breast-cancer/model.json (100%) rename {testdata => gen/testdata}/breast-cancer/preds.csv (100%) rename {testdata => gen/testdata}/breast-cancer/xtest.csv (100%) rename {testdata => gen/testdata}/main.go (100%) rename {testdata => gen/testdata}/small-model/model.json (100%) rename {testdata => gen/testdata}/small-model/preds.csv (100%) rename {testdata => gen/testdata}/small-model/xtest.csv (100%) diff --git a/generate.go b/gen/generate.go similarity index 67% rename from generate.go rename to gen/generate.go index b0222bb..7a0ab70 100644 --- a/generate.go +++ b/gen/generate.go @@ -1,8 +1,9 @@ -package main +package gen import ( "fmt" "go/format" + "os" ) type treeFunction struct { @@ -65,3 +66,38 @@ func codegenTree(r *renderer, tree *node, level int) (string, error) { return r.executeDecisionNode(tree, level, left, right) } + + +// GenerateFile generates a .go file containing a function that implements the XGB model. +func GenerateFile( + inputJSON string, + packageName, + funcName, + outputFile string, +) error { + x, err := readModel(inputJSON) + if err != nil { + return err + } + + trees, err := readTrees(x) + if err != nil { + return err + } + + r, err := newRenderer() + if err != nil { + return err + } + + code, err := generateSource(packageName, funcName, trees, r) + if err != nil { + return err + } + + if err := os.WriteFile(outputFile, []byte(code), 0o644); err != nil { + return fmt.Errorf("error writing file: %w", err) + } + + return nil +} \ No newline at end of file diff --git a/parse.go b/gen/parse.go similarity index 99% rename from parse.go rename to gen/parse.go index 7958a37..5e0a001 100644 --- a/parse.go +++ b/gen/parse.go @@ -1,4 +1,4 @@ -package main +package gen import ( "encoding/json" diff --git a/parse_test.go b/gen/parse_test.go similarity index 99% rename from parse_test.go rename to gen/parse_test.go index c1bac04..6fc7196 100644 --- a/parse_test.go +++ b/gen/parse_test.go @@ -1,4 +1,4 @@ -package main +package gen import ( "path/filepath" diff --git a/templates.go b/gen/templates.go similarity index 99% rename from templates.go rename to gen/templates.go index 0223f0b..422e599 100644 --- a/templates.go +++ b/gen/templates.go @@ -1,4 +1,4 @@ -package main +package gen import ( "bytes" diff --git a/templates/decision_node.tmpl b/gen/templates/decision_node.tmpl similarity index 100% rename from templates/decision_node.tmpl rename to gen/templates/decision_node.tmpl diff --git a/templates/root.tmpl b/gen/templates/root.tmpl similarity index 100% rename from templates/root.tmpl rename to gen/templates/root.tmpl diff --git a/templates/terminal_node.tmpl b/gen/templates/terminal_node.tmpl similarity index 100% rename from templates/terminal_node.tmpl rename to gen/templates/terminal_node.tmpl diff --git a/testdata/breast-cancer/generate-model.py b/gen/testdata/breast-cancer/generate-model.py similarity index 100% rename from testdata/breast-cancer/generate-model.py rename to gen/testdata/breast-cancer/generate-model.py diff --git a/testdata/breast-cancer/model.json b/gen/testdata/breast-cancer/model.json similarity index 100% rename from testdata/breast-cancer/model.json rename to gen/testdata/breast-cancer/model.json diff --git a/testdata/breast-cancer/preds.csv b/gen/testdata/breast-cancer/preds.csv similarity index 100% rename from testdata/breast-cancer/preds.csv rename to gen/testdata/breast-cancer/preds.csv diff --git a/testdata/breast-cancer/xtest.csv b/gen/testdata/breast-cancer/xtest.csv similarity index 100% rename from testdata/breast-cancer/xtest.csv rename to gen/testdata/breast-cancer/xtest.csv diff --git a/testdata/main.go b/gen/testdata/main.go similarity index 100% rename from testdata/main.go rename to gen/testdata/main.go diff --git a/testdata/small-model/model.json b/gen/testdata/small-model/model.json similarity index 100% rename from testdata/small-model/model.json rename to gen/testdata/small-model/model.json diff --git a/testdata/small-model/preds.csv b/gen/testdata/small-model/preds.csv similarity index 100% rename from testdata/small-model/preds.csv rename to gen/testdata/small-model/preds.csv diff --git a/testdata/small-model/xtest.csv b/gen/testdata/small-model/xtest.csv similarity index 100% rename from testdata/small-model/xtest.csv rename to gen/testdata/small-model/xtest.csv diff --git a/main.go b/main.go index c96e1ae..cf9bbee 100644 --- a/main.go +++ b/main.go @@ -1,4 +1,4 @@ -// xgb2code generates code for an XGB model. +// This program generates code for an XGB model. package main import ( @@ -6,6 +6,8 @@ import ( "fmt" "log" "os" + + "github.com/maxmind/xgb2code/gen" ) type language string @@ -57,42 +59,9 @@ func main() { os.Exit(1) } - err := GenerateFile(*inputJSON, *packageName, *funcName, *outputFile) + err := gen.GenerateFile(*inputJSON, *packageName, *funcName, *outputFile) if err != nil { log.Fatal(err) } } -// GenerateFile generates a .go file containing a function that implements the XGB model. -func GenerateFile( - inputJSON string, - packageName, - funcName, - outputFile string, -) error { - x, err := readModel(inputJSON) - if err != nil { - return err - } - - trees, err := readTrees(x) - if err != nil { - return err - } - - r, err := newRenderer() - if err != nil { - return err - } - - code, err := generateSource(packageName, funcName, trees, r) - if err != nil { - return err - } - - if err := os.WriteFile(outputFile, []byte(code), 0o644); err != nil { - return fmt.Errorf("error writing file: %w", err) - } - - return nil -} diff --git a/main_test.go b/main_test.go index f19dbe8..bb4b6ff 100644 --- a/main_test.go +++ b/main_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "testing" + "github.com/maxmind/xgb2code/gen" "github.com/stretchr/testify/require" ) @@ -22,7 +23,7 @@ func TestGenerateAndRunModels(t *testing.T) { t.Run(test.model, func(t *testing.T) { // Generate the code. - modelDir := filepath.Join("testdata", test.model) + modelDir := filepath.Join("gen","testdata", test.model) modelFile := filepath.Join(modelDir, "model.json") packageName := "main" @@ -31,13 +32,13 @@ func TestGenerateAndRunModels(t *testing.T) { outputDir := t.TempDir() funcFile := filepath.Join(outputDir, "predict.go") - err := GenerateFile(modelFile, packageName, functionName, funcFile) + err := gen.GenerateFile(modelFile, packageName, functionName, funcFile) require.NoError(t, err) // Copy the test program and test data into place. files := []string{ - filepath.Join("testdata", "main.go"), + filepath.Join("gen","testdata", "main.go"), filepath.Join(modelDir, "xtest.csv"), filepath.Join(modelDir, "preds.csv"), }