diff --git a/common/Makefile b/common/Makefile index e47a81b89d2..28665ae30ba 100644 --- a/common/Makefile +++ b/common/Makefile @@ -1,4 +1,5 @@ GENERATED_COMMITTED := determined_common/schemas/expconf/_gen.py +GENERATION_INPUTS = ../schemas/gen.py $(shell find ../schemas/expconf -name '*.json') .PHONY: clean clean: @@ -27,11 +28,10 @@ check-gen: force-gen $(GENERATED_COMMITTED) # git reports the files as unchanged after forcibly regenerating the files: test -z "$(shell git status --porcelain $(GENERATED_COMMITTED))" -determined_common/schemas/expconf/_gen.py: ../schemas/gen.py $(shell find ../schemas/expconf -name "*.json") - ../schemas/gen.py \ - --output $@ \ - python \ - $(shell find ../schemas/expconf -name "*.json") +determined_common/schemas/expconf/_gen.py: $(GENERATION_INPUTS) + ../schemas/gen.py python\ + --package expconf \ + --output $@ .PHONY: build build: diff --git a/common/determined_common/schemas/expconf/_gen.py b/common/determined_common/schemas/expconf/_gen.py index 9929d3c920a..2447aead6bf 100644 --- a/common/determined_common/schemas/expconf/_gen.py +++ b/common/determined_common/schemas/expconf/_gen.py @@ -2816,4 +2816,3 @@ """ ), } - diff --git a/master/.golangci.yml b/master/.golangci.yml index ffbfbe76554..145b16d6141 100644 --- a/master/.golangci.yml +++ b/master/.golangci.yml @@ -8,11 +8,10 @@ run: # The exit code when at least one issue was found. issues-exit-code: 1 - # Exclude generated code. skip-files: - - pkg/schemas/schema_gen.go - - pkg/schemas/expconf/schema_gen.go - pkg/schemas/expconf/dummies.go + - pkg/schemas/*/zgen*.go + - pkg/schemas/zgen*.go output: # Linter output format. diff --git a/master/Makefile b/master/Makefile index 88903d7f374..ee93bcc22c6 100644 --- a/master/Makefile +++ b/master/Makefile @@ -1,5 +1,5 @@ -GENERATED_COMMITTED := pkg/schemas/schema_gen.go pkg/schemas/expconf/schema_gen.go -GENERATED := $(GENERATED_COMMITTED) packaging/LICENSE +GENERATED := packaging/LICENSE +GENERATION_INPUTS = ../schemas/gen.py $(shell find ./pkg/schemas/ -name '*.go') $(shell find ../schemas/expconf -name '*.json') export VERSION := $(shell cat ../VERSION) export GO111MODULE := on @@ -12,20 +12,25 @@ clean: ungen .PHONY: ungen ungen: rm -f $(GENERATED) - git checkout -- $(GENERATED_COMMITTED) + rm -f `find ./pkg/schemas/ -name 'zgen_*.go'` build/schema_gen.stamp .PHONY: gen -gen: $(GENERATED) +gen: $(GENERATED) build/schema_gen.stamp .PHONY: force-gen force-gen: - touch ../schemas/gen.py + rm -f build/schema_gen.stamp + +build/schema_gen.stamp: $(GENERATION_INPUTS) + go generate ./pkg/schemas/... + mkdir -p build + touch $@ .PHONY: check-gen -check-gen: force-gen $(GENERATED_COMMITTED) +check-gen: force-gen gen # Checking that committed, generated code is up-to-date by ensuring that # git reports the files as unchanged after forcibly regenerating the files: - test -z "$(shell git status --porcelain $(GENERATED_COMMITTED))" + test -z "$(shell git status --porcelain ./pkg/schemas)" .PHONY: get-deps get-deps: @@ -105,21 +110,3 @@ publish-dev: packaging/LICENSE: $(shell find ../tools/scripts/licenses -type f) ../tools/scripts/gen-attributions.py master $@ - -# Root schemas generated file; this contains all the schemas we have. -pkg/schemas/schema_gen.go: ../schemas/gen.py $(shell find ../schemas -regex ".*/v.*/.*\.json") - ../schemas/gen.py \ - --output $@ \ - --package schemas \ - go \ - $(shell find ../schemas -regex ".*/v.*/.*\.json") - goimports -l -local github.com/determined-ai -w $@ - -# Package-level generated file; this only generates code related to expconf objects. -pkg/schemas/expconf/schema_gen.go: ../schemas/gen.py $(shell find ../schemas/expconf -name "*.json") - ../schemas/gen.py \ - --output $@ \ - --package expconf \ - go \ - $(shell find ../schemas/expconf -name "*.json") - goimports -l -local github.com/determined-ai -w $@ diff --git a/master/pkg/schemas/copy.go b/master/pkg/schemas/copy.go new file mode 100644 index 00000000000..de583c7babb --- /dev/null +++ b/master/pkg/schemas/copy.go @@ -0,0 +1,76 @@ +package schemas + +import ( + "fmt" + "reflect" +) + +// cpy is for deep copying, but it will only work on "nice" objects, which should include our +// schema objects. Useful to other reflect code. +func cpy(v reflect.Value) reflect.Value { + // fmt.Printf("cpy(%T)\n", v.Interface()) + var out reflect.Value + + switch v.Kind() { + case reflect.Ptr: + if v.IsZero() { + return v + } + out = reflect.New(v.Elem().Type()) + out.Elem().Set(cpy(v.Elem())) + + case reflect.Interface: + if v.IsZero() { + return v.Elem() + } + out = cpy(v.Elem()) + + case reflect.Struct: + out = reflect.New(v.Type()).Elem() + // Recurse into each field of the struct. + for i := 0; i < v.NumField(); i++ { + out.Field(i).Set(cpy(v.Field(i))) + } + + case reflect.Map: + typ := reflect.MapOf(v.Type().Key(), v.Type().Elem()) + if v.IsZero() { + // unallocated map + out = reflect.Zero(typ) + } else { + out = reflect.MakeMap(typ) + // Recurse into each key of the map. + for _, key := range v.MapKeys() { + val := v.MapIndex(key) + out.SetMapIndex(key, cpy(val)) + } + } + + case reflect.Slice: + typ := reflect.SliceOf(v.Type().Elem()) + if v.IsZero() { + // unallocated slice + out = reflect.Zero(typ) + } else { + out = reflect.MakeSlice(typ, 0, v.Len()) + // Recurse into each element of the slice. + for i := 0; i < v.Len(); i++ { + val := v.Index(i) + out = reflect.Append(out, cpy(val)) + } + } + + // Assert that none of the "complex" kinds are present. + case reflect.Array, + reflect.Chan, + reflect.Func, + reflect.UnsafePointer: + panic(fmt.Sprintf("unable to cpy %T of kind %v", v.Interface(), v.Kind())) + + default: + // Simple types like string or int can be passed directly. + return v + } + + return out +} diff --git a/master/pkg/schemas/copy_test.go b/master/pkg/schemas/copy_test.go new file mode 100644 index 00000000000..599f21e0da7 --- /dev/null +++ b/master/pkg/schemas/copy_test.go @@ -0,0 +1,82 @@ +package schemas + +import ( + "reflect" + "testing" + + "gotest.tools/assert" +) + +// Copy is the non-reflect version of copy, but mostly the reflect version is called from other +// reflect code, so it's defined here in test code. +func Copy(src interface{}) interface{} { + return cpy(reflect.ValueOf(src)).Interface() +} + +func TestCopyAllocatedSlice(t *testing.T) { + src := []string{} + obj := Copy(src).([]string) + assert.DeepEqual(t, obj, src) +} + +func TestCopyUnallocatedSlice(t *testing.T) { + // Copying an unallocated slice encodes to null. + var src []string + obj := Copy(src).([]string) + assert.DeepEqual(t, obj, src) +} + +func TestCopyAllocatedMap(t *testing.T) { + // Copying an allocated map encodes to []. + src := map[string]string{} + + obj := Copy(src).(map[string]string) + assert.DeepEqual(t, obj, src) +} + +func TestCopyUnallocatedMap(t *testing.T) { + // Copying an unallocated map encodes to null. + var src map[string]string + + obj := Copy(src).(map[string]string) + assert.DeepEqual(t, obj, src) +} + +type A struct { + M map[string]string + S []int + B B +} + +type B struct { + I int + S string + C []C +} + +type C struct { + I int + D map[string]D +} + +type D struct { + I int + S string +} + +func TestCopyNested(t *testing.T) { + src := A{ + M: map[string]string{"eeny": "meeny", "miney": "moe"}, + S: []int{1, 2, 3, 4}, + B: B{ + I: 5, + S: "five", + C: []C{ + {I: 6, D: map[string]D{"one": {I: 1, S: "fish"}, "two": {I: 2, S: "fish"}}}, + {I: 6, D: map[string]D{"red": {I: 3, S: "fish"}, "blue": {I: 4, S: "fish"}}}, + }, + }, + } + obj := Copy(src).(A) + assert.DeepEqual(t, obj, src) +} diff --git a/master/pkg/schemas/defaults.go b/master/pkg/schemas/defaults.go index 7b72eb6e6b9..47a111fbf91 100644 --- a/master/pkg/schemas/defaults.go +++ b/master/pkg/schemas/defaults.go @@ -11,16 +11,10 @@ import ( // random seed based on the wall clock. type RuntimeDefaultable interface { // RuntimeDefaults must apply the runtime-defined default values. - RuntimeDefaults() + RuntimeDefaults() interface{} } -// Defaultable means a struct can have its defaults filled in automatically. -type Defaultable interface { - // DefaultSource must return a parsed json-schema object in which to find defaults. - DefaultSource() interface{} -} - -// FillDefaults will recurse through structs, maps, and slices, setting default values for any +// WithDefaults will recurse through structs, maps, and slices, setting default values for any // struct fields whose struct implements the Defaultable interface. This lets us read default // values out of json-schema automatically. // @@ -28,11 +22,6 @@ type Defaultable interface { // description to experiments with no description. This can be accomplished by implementing // the RuntimeDefaultable interface for that object. See ExperimentConfig for an example. // -// There are some objects which get their defaults from other objects' defaults. This an -// unfortunate detail of our union types which have common members that appear on the root union -// object. That's hard to reason about, and we should avoid doing that in new config objects. But -// those objects implement DefaultSource() to customize that behavior. -// // Example usage: // // config, err := expconf.ParseAnyExperimentConfigYAML(bytes) @@ -41,67 +30,70 @@ type Defaultable interface { // schemas.Merge(&config.CheckpointStorage, cluster_default_checkpoint_storage) // // // Define any remaining undefined values. -// schemas.FillDefaults(&config) +// config = schemas.WithDefaults(&config).(ExperimentConfig) // -func FillDefaults(obj interface{}) { +func WithDefaults(obj interface{}) interface{} { vObj := reflect.ValueOf(obj) - // obj can't be a non-pointer, because it edits in-place. - if vObj.Kind() != reflect.Ptr { - panic("FillDefaults must be called on a pointer") - } - // obj can't be a nil pointer, because FillDefaults(nil) doesn't make any sense. - if vObj.IsZero() { - panic("FillDefaults must be called on a non-nil pointer") - } - // Enter the recursive default filling with no default bytes for the root object (which must - // already exist), and starting with the name of the object type. name := fmt.Sprintf("%T", obj) - vObj.Elem().Set(fillDefaults(vObj.Elem(), nil, name)) + return withDefaults(vObj, nil, name).Interface() } -// fillDefaults is the recursive layer under FillDefaults. fillDefaults will return the original -// input value (not a copy of the original value). -func fillDefaults(obj reflect.Value, defaultBytes []byte, name string) reflect.Value { - switch obj.Kind() { - case reflect.Interface: - if obj.IsZero() { - // This doesn't make any sense; we need a type. - panic("got a nil interface as the obj to FillDefaults into") - } - obj.Set(fillDefaults(obj.Elem(), defaultBytes, name)) +func getDefaultSource(obj reflect.Value) interface{} { + if schema, ok := obj.Interface().(Schema); ok { + return schema.ParsedSchema() + } + return nil +} - case reflect.Ptr: +// withDefaults is the recursive layer under WithDefaults. withDefaults will return a clean copy +// of the original value, with defaults set. +func withDefaults(obj reflect.Value, defaultBytes []byte, name string) reflect.Value { + // fmt.Printf("withDefaults on %v (%T)\n", name, obj.Interface()) + + // Handle pointers first. + if obj.Kind() == reflect.Ptr { if obj.IsZero() { if defaultBytes == nil { // Nil pointer with no defaultBytes means we are done recursing. return obj } // Otherwise, since we have default bytes, allocate the new object. - obj = reflect.New(obj.Type().Elem()) + out := reflect.New(obj.Type().Elem()) // Fill the object with default bytes. - err := json.Unmarshal(defaultBytes, obj.Interface()) + err := json.Unmarshal(defaultBytes, out.Interface()) if err != nil { panic( fmt.Sprintf( - "failed to unmarshal defaultBytes into %T: %v", + "failed to unmarshal defaultBytes into %T: %q: %v", obj.Interface(), string(defaultBytes), + err.Error(), ), ) } // We already consumed defaultBytes, so set it to nil when we recurse. - obj.Elem().Set(fillDefaults(obj.Elem(), nil, name)) - } else { - // Recurse into the element inside the pointer. - obj.Elem().Set(fillDefaults(obj.Elem(), defaultBytes, name)) + return withDefaults(out, nil, name) + } + // Allocate a new pointer and set is avlue + out := reflect.New(obj.Type().Elem()) + out.Elem().Set(withDefaults(obj.Elem(), defaultBytes, name)) + return out + } + + // Next handle interfaces. + if obj.Kind() == reflect.Interface { + if obj.IsZero() { + return cpy(obj) } + return withDefaults(obj.Elem(), defaultBytes, name) + } + + var out reflect.Value + switch obj.Kind() { case reflect.Struct: defaultSource := getDefaultSource(obj) - // Create a clean copy of the object which is settable. This is necessary because if you - // have a required struct (i.e. it appears as a struct rather than a struct pointer on its - // parent object), then obj.Field(i) will not be settable. - newObj := reflect.New(obj.Type()).Elem() + out = reflect.New(obj.Type()).Elem() // Iterate through all the fields of the struct once, applying defaults. for i := 0; i < obj.NumField(); i++ { var fieldDefaultBytes []byte @@ -111,56 +103,59 @@ func fillDefaults(obj reflect.Value, defaultBytes []byte, name string) reflect.V } fieldName := fmt.Sprintf("%v.%v", name, obj.Type().Field(i).Name) // Recurse into the field. - newObj.Field(i).Set(fillDefaults(obj.Field(i), fieldDefaultBytes, fieldName)) + out.Field(i).Set(withDefaults(obj.Field(i), fieldDefaultBytes, fieldName)) } - // Use the new copy instead of the old one. - obj = newObj case reflect.Slice: - for i := 0; i < obj.Len(); i++ { - elemName := fmt.Sprintf("%v.[%v]", name, i) - // Recurse into the elem (there's no per-element defaults yet). - obj.Index(i).Set(fillDefaults(obj.Index(i), nil, elemName)) + if obj.IsZero() { + out = cpy(obj) + } else { + typ := reflect.SliceOf(obj.Type().Elem()) + out = reflect.MakeSlice(typ, 0, obj.Len()) + for i := 0; i < obj.Len(); i++ { + elemName := fmt.Sprintf("%v[%v]", name, i) + // Recurse into the elem (there's no per-element defaults yet). + out = reflect.Append(out, withDefaults(obj.Index(i), nil, elemName)) + } } case reflect.Map: - for _, key := range obj.MapKeys() { + typ := reflect.MapOf(obj.Type().Key(), obj.Type().Elem()) + out = reflect.MakeMap(typ) + iter := obj.MapRange() + for iter.Next() { + key := iter.Key() + val := iter.Value() elemName := fmt.Sprintf("%v.[%v]", name, key.Interface()) - val := obj.MapIndex(key) // Recurse into the elem (there's no per-element defaults yet). - tmp := fillDefaults(val, nil, elemName) - // Update the original value with the defaulted value. - obj.SetMapIndex(key, tmp) + out.SetMapIndex(key, withDefaults(val, nil, elemName)) } // Assert that none of the "complex" kinds are present. case reflect.Array, reflect.Chan, reflect.Func, - reflect.UnsafePointer: + reflect.UnsafePointer, + reflect.Ptr, + reflect.Interface: panic(fmt.Sprintf( - "unable to fillDefaults at %v of type %T, kind %v", name, obj.Interface(), obj.Kind(), + "unable to withDefaults at %v of type %T, kind %v", name, obj.Interface(), obj.Kind(), )) + + default: + out = cpy(obj) } - // AFTER the automatic defaults, we apply any runtime defaults. This way, we've already filled - // any nil pointers with valid objects. - if runtimeDefaultable, ok := obj.Interface().(RuntimeDefaultable); ok { - runtimeDefaultable.RuntimeDefaults() + // Any non-pointer, non-interface type may be RuntimeDefaultable. + if out.IsValid() { + if defaultable, ok := out.Interface().(RuntimeDefaultable); ok { + out = reflect.ValueOf(defaultable.RuntimeDefaults()) + } } - return obj -} + // fmt.Printf("withDefaults on %v (%T) returning %v\n", name, obj.Interface(), obj.Interface()) -// getDefaultSource gets a source of defaults from a Defaultable or Schema interface. -func getDefaultSource(v reflect.Value) interface{} { - if defaultable, ok := v.Interface().(Defaultable); ok { - return defaultable.DefaultSource() - } - if schema, ok := v.Interface().(Schema); ok { - return schema.ParsedSchema() - } - return nil + return out } // jsonNameFromJSONTag is based on encoding/json's parseTag(). diff --git a/master/pkg/schemas/defaults_test.go b/master/pkg/schemas/defaults_test.go index 4312c6c3866..2eeb9f3a78f 100644 --- a/master/pkg/schemas/defaults_test.go +++ b/master/pkg/schemas/defaults_test.go @@ -38,33 +38,33 @@ func (b BindMountV0) CompletenessValidator() *jsonschema.Schema { } func TestFillEmptyDefaults(t *testing.T) { - obj := BindMountV0{} - assertDefaults := func() { - assert.Assert(t, obj.ReadOnly != nil) - assert.Assert(t, *obj.ReadOnly == false) - assert.Assert(t, obj.Propagation != nil) - assert.Assert(t, *obj.Propagation == rprivate) + assertDefaults := func(b BindMountV0) { + assert.Assert(t, b.ReadOnly != nil) + assert.Assert(t, *b.ReadOnly == false) + assert.Assert(t, b.Propagation != nil) + assert.Assert(t, *b.Propagation == rprivate) } - FillDefaults(&obj) - assertDefaults() + obj := BindMountV0{} + out := WithDefaults(obj).(BindMountV0) + assertDefaults(out) // Make sure pointers on the input are ok. objRef := &BindMountV0{} - FillDefaults(&objRef) - assertDefaults() + objRef = WithDefaults(objRef).(*BindMountV0) + assertDefaults(*objRef) // Make sure input interfaces are ok. var iObj interface{} = &BindMountV0{} - FillDefaults(iObj) - assertDefaults() + iObj = WithDefaults(iObj) + assertDefaults(*(iObj.(*BindMountV0))) } func TestNonEmptyDefaults(t *testing.T) { obj := BindMountV0{ReadOnly: ptrs.BoolPtr(true), Propagation: ptrs.StringPtr("asdf")} - FillDefaults(&obj) - assert.Assert(t, *obj.ReadOnly == true) - assert.Assert(t, *obj.Propagation == "asdf") + out := WithDefaults(obj).(BindMountV0) + assert.Assert(t, *out.ReadOnly == true) + assert.Assert(t, *out.Propagation == "asdf") } func TestArrayOfDefautables(t *testing.T) { @@ -73,9 +73,9 @@ func TestArrayOfDefautables(t *testing.T) { obj = append(obj, BindMountV0{}) obj = append(obj, BindMountV0{}) - FillDefaults(&obj) + out := WithDefaults(obj).([]BindMountV0) - for _, b := range obj { + for _, b := range out { assert.Assert(t, b.ReadOnly != nil) assert.Assert(t, *b.ReadOnly == false) assert.Assert(t, b.Propagation != nil) diff --git a/master/pkg/schemas/expconf/schema_gen.go b/master/pkg/schemas/expconf/schema_gen.go deleted file mode 100644 index e436d45871c..00000000000 --- a/master/pkg/schemas/expconf/schema_gen.go +++ /dev/null @@ -1,573 +0,0 @@ -// This is a generated file. Editing it will make you sad. - -package expconf - -import ( - "github.com/santhosh-tekuri/jsonschema/v2" - - "github.com/determined-ai/determined/master/pkg/schemas" -) - -func (b BindMountV0) ParsedSchema() interface{} { - return schemas.ParsedBindMountV0() -} - -func (b BindMountV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/bind-mount.json") -} - -func (b BindMountV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/bind-mount.json") -} - -func (c CheckpointStorageConfigV0) ParsedSchema() interface{} { - return schemas.ParsedCheckpointStorageConfigV0() -} - -func (c CheckpointStorageConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/checkpoint-storage.json") -} - -func (c CheckpointStorageConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/checkpoint-storage.json") -} - -func (g GCSDataLayerConfigV0) ParsedSchema() interface{} { - return schemas.ParsedGCSDataLayerConfigV0() -} - -func (g GCSDataLayerConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/data-layer-gcs.json") -} - -func (g GCSDataLayerConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/data-layer-gcs.json") -} - -func (s S3DataLayerConfigV0) ParsedSchema() interface{} { - return schemas.ParsedS3DataLayerConfigV0() -} - -func (s S3DataLayerConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/data-layer-s3.json") -} - -func (s S3DataLayerConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/data-layer-s3.json") -} - -func (s SharedFSDataLayerConfigV0) ParsedSchema() interface{} { - return schemas.ParsedSharedFSDataLayerConfigV0() -} - -func (s SharedFSDataLayerConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/data-layer-shared-fs.json") -} - -func (s SharedFSDataLayerConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/data-layer-shared-fs.json") -} - -func (d DataLayerConfigV0) ParsedSchema() interface{} { - return schemas.ParsedDataLayerConfigV0() -} - -func (d DataLayerConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/data-layer.json") -} - -func (d DataLayerConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/data-layer.json") -} - -func (e EnvironmentImageMapV0) ParsedSchema() interface{} { - return schemas.ParsedEnvironmentImageMapV0() -} - -func (e EnvironmentImageMapV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/environment-image-map.json") -} - -func (e EnvironmentImageMapV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/environment-image-map.json") -} - -func (e EnvironmentImageV0) ParsedSchema() interface{} { - return schemas.ParsedEnvironmentImageV0() -} - -func (e EnvironmentImageV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/environment-image.json") -} - -func (e EnvironmentImageV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/environment-image.json") -} - -func (e EnvironmentVariablesMapV0) ParsedSchema() interface{} { - return schemas.ParsedEnvironmentVariablesMapV0() -} - -func (e EnvironmentVariablesMapV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/environment-variables-map.json") -} - -func (e EnvironmentVariablesMapV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/environment-variables-map.json") -} - -func (e EnvironmentVariablesV0) ParsedSchema() interface{} { - return schemas.ParsedEnvironmentVariablesV0() -} - -func (e EnvironmentVariablesV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/environment-variables.json") -} - -func (e EnvironmentVariablesV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/environment-variables.json") -} - -func (e EnvironmentConfigV0) ParsedSchema() interface{} { - return schemas.ParsedEnvironmentConfigV0() -} - -func (e EnvironmentConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/environment.json") -} - -func (e EnvironmentConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/environment.json") -} - -func (e ExperimentConfigV0) ParsedSchema() interface{} { - return schemas.ParsedExperimentConfigV0() -} - -func (e ExperimentConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/experiment.json") -} - -func (e ExperimentConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/experiment.json") -} - -func (g GCSConfigV0) ParsedSchema() interface{} { - return schemas.ParsedGCSConfigV0() -} - -func (g GCSConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/gcs.json") -} - -func (g GCSConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/gcs.json") -} - -func (h HDFSConfigV0) ParsedSchema() interface{} { - return schemas.ParsedHDFSConfigV0() -} - -func (h HDFSConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hdfs.json") -} - -func (h HDFSConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hdfs.json") -} - -func (c CategoricalHyperparameterV0) ParsedSchema() interface{} { - return schemas.ParsedCategoricalHyperparameterV0() -} - -func (c CategoricalHyperparameterV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-categorical.json") -} - -func (c CategoricalHyperparameterV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-categorical.json") -} - -func (c ConstHyperparameterV0) ParsedSchema() interface{} { - return schemas.ParsedConstHyperparameterV0() -} - -func (c ConstHyperparameterV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-const.json") -} - -func (c ConstHyperparameterV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-const.json") -} - -func (d DoubleHyperparameterV0) ParsedSchema() interface{} { - return schemas.ParsedDoubleHyperparameterV0() -} - -func (d DoubleHyperparameterV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-double.json") -} - -func (d DoubleHyperparameterV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-double.json") -} - -func (i IntHyperparameterV0) ParsedSchema() interface{} { - return schemas.ParsedIntHyperparameterV0() -} - -func (i IntHyperparameterV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-int.json") -} - -func (i IntHyperparameterV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-int.json") -} - -func (l LogHyperparameterV0) ParsedSchema() interface{} { - return schemas.ParsedLogHyperparameterV0() -} - -func (l LogHyperparameterV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-log.json") -} - -func (l LogHyperparameterV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hyperparameter-log.json") -} - -func (h HyperparameterV0) ParsedSchema() interface{} { - return schemas.ParsedHyperparameterV0() -} - -func (h HyperparameterV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hyperparameter.json") -} - -func (h HyperparameterV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hyperparameter.json") -} - -func (h HyperparametersV0) ParsedSchema() interface{} { - return schemas.ParsedHyperparametersV0() -} - -func (h HyperparametersV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/hyperparameters.json") -} - -func (h HyperparametersV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/hyperparameters.json") -} - -func (i InternalConfigV0) ParsedSchema() interface{} { - return schemas.ParsedInternalConfigV0() -} - -func (i InternalConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/internal.json") -} - -func (i InternalConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/internal.json") -} - -func (k KerberosConfigV0) ParsedSchema() interface{} { - return schemas.ParsedKerberosConfigV0() -} - -func (k KerberosConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/kerberos.json") -} - -func (k KerberosConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/kerberos.json") -} - -func (l LengthV0) ParsedSchema() interface{} { - return schemas.ParsedLengthV0() -} - -func (l LengthV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/length.json") -} - -func (l LengthV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/length.json") -} - -func (n NativeConfigV0) ParsedSchema() interface{} { - return schemas.ParsedNativeConfigV0() -} - -func (n NativeConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/native.json") -} - -func (n NativeConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/native.json") -} - -func (o OptimizationsConfigV0) ParsedSchema() interface{} { - return schemas.ParsedOptimizationsConfigV0() -} - -func (o OptimizationsConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/optimizations.json") -} - -func (o OptimizationsConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/optimizations.json") -} - -func (r ReproducibilityConfigV0) ParsedSchema() interface{} { - return schemas.ParsedReproducibilityConfigV0() -} - -func (r ReproducibilityConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/reproducibility.json") -} - -func (r ReproducibilityConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/reproducibility.json") -} - -func (r ResourcesConfigV0) ParsedSchema() interface{} { - return schemas.ParsedResourcesConfigV0() -} - -func (r ResourcesConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/resources.json") -} - -func (r ResourcesConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/resources.json") -} - -func (s S3ConfigV0) ParsedSchema() interface{} { - return schemas.ParsedS3ConfigV0() -} - -func (s S3ConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/s3.json") -} - -func (s S3ConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/s3.json") -} - -func (a AdaptiveASHAConfigV0) ParsedSchema() interface{} { - return schemas.ParsedAdaptiveASHAConfigV0() -} - -func (a AdaptiveASHAConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-adaptive-asha.json") -} - -func (a AdaptiveASHAConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-adaptive-asha.json") -} - -func (a AdaptiveSimpleConfigV0) ParsedSchema() interface{} { - return schemas.ParsedAdaptiveSimpleConfigV0() -} - -func (a AdaptiveSimpleConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-adaptive-simple.json") -} - -func (a AdaptiveSimpleConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-adaptive-simple.json") -} - -func (a AdaptiveConfigV0) ParsedSchema() interface{} { - return schemas.ParsedAdaptiveConfigV0() -} - -func (a AdaptiveConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-adaptive.json") -} - -func (a AdaptiveConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-adaptive.json") -} - -func (a AsyncHalvingConfigV0) ParsedSchema() interface{} { - return schemas.ParsedAsyncHalvingConfigV0() -} - -func (a AsyncHalvingConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-async-halving.json") -} - -func (a AsyncHalvingConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-async-halving.json") -} - -func (g GridConfigV0) ParsedSchema() interface{} { - return schemas.ParsedGridConfigV0() -} - -func (g GridConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-grid.json") -} - -func (g GridConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-grid.json") -} - -func (p PBTConfigV0) ParsedSchema() interface{} { - return schemas.ParsedPBTConfigV0() -} - -func (p PBTConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-pbt.json") -} - -func (p PBTConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-pbt.json") -} - -func (r RandomConfigV0) ParsedSchema() interface{} { - return schemas.ParsedRandomConfigV0() -} - -func (r RandomConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-random.json") -} - -func (r RandomConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-random.json") -} - -func (s SingleConfigV0) ParsedSchema() interface{} { - return schemas.ParsedSingleConfigV0() -} - -func (s SingleConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-single.json") -} - -func (s SingleConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-single.json") -} - -func (s SyncHalvingConfigV0) ParsedSchema() interface{} { - return schemas.ParsedSyncHalvingConfigV0() -} - -func (s SyncHalvingConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher-sync-halving.json") -} - -func (s SyncHalvingConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher-sync-halving.json") -} - -func (s SearcherConfigV0) ParsedSchema() interface{} { - return schemas.ParsedSearcherConfigV0() -} - -func (s SearcherConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/searcher.json") -} - -func (s SearcherConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/searcher.json") -} - -func (s SecurityConfigV0) ParsedSchema() interface{} { - return schemas.ParsedSecurityConfigV0() -} - -func (s SecurityConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/security.json") -} - -func (s SecurityConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/security.json") -} - -func (s SharedFSConfigV0) ParsedSchema() interface{} { - return schemas.ParsedSharedFSConfigV0() -} - -func (s SharedFSConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/shared-fs.json") -} - -func (s SharedFSConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/shared-fs.json") -} - -func (t TensorboardStorageConfigV0) ParsedSchema() interface{} { - return schemas.ParsedTensorboardStorageConfigV0() -} - -func (t TensorboardStorageConfigV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/tensorboard-storage.json") -} - -func (t TensorboardStorageConfigV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/tensorboard-storage.json") -} - -func (t TestRootV0) ParsedSchema() interface{} { - return schemas.ParsedTestRootV0() -} - -func (t TestRootV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-root.json") -} - -func (t TestRootV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-root.json") -} - -func (t TestSubV0) ParsedSchema() interface{} { - return schemas.ParsedTestSubV0() -} - -func (t TestSubV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-sub.json") -} - -func (t TestSubV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-sub.json") -} - -func (t TestUnionAV0) ParsedSchema() interface{} { - return schemas.ParsedTestUnionAV0() -} - -func (t TestUnionAV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-union-a.json") -} - -func (t TestUnionAV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-union-a.json") -} - -func (t TestUnionBV0) ParsedSchema() interface{} { - return schemas.ParsedTestUnionBV0() -} - -func (t TestUnionBV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-union-b.json") -} - -func (t TestUnionBV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-union-b.json") -} - -func (t TestUnionV0) ParsedSchema() interface{} { - return schemas.ParsedTestUnionV0() -} - -func (t TestUnionV0) SanityValidator() *jsonschema.Schema { - return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-union.json") -} - -func (t TestUnionV0) CompletenessValidator() *jsonschema.Schema { - return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-union.json") -} diff --git a/master/pkg/schemas/expconf/schema_test.go b/master/pkg/schemas/expconf/schema_test.go index b41c3a8ed80..0ffb68c2e8f 100644 --- a/master/pkg/schemas/expconf/schema_test.go +++ b/master/pkg/schemas/expconf/schema_test.go @@ -47,7 +47,7 @@ func (tc SchemaTestCase) CheckMatches(t *testing.T) { byts, err := json.Marshal(tc.Case) assert.NilError(t, err) for _, url := range *tc.Matches { - schema := schemas.GetCompletenessValidator(url) + schema := schemas.GetSanityValidator(url) err := schema.Validate(bytes.NewReader(byts)) if err == nil { continue @@ -65,7 +65,7 @@ func (tc SchemaTestCase) CheckErrors(t *testing.T) { byts, err := json.Marshal(tc.Case) assert.NilError(t, err) for url, expectedErrors := range *tc.Errors { - schema := schemas.GetCompletenessValidator(url) + schema := schemas.GetSanityValidator(url) err := schema.Validate(bytes.NewReader(byts)) if err == nil { t.Errorf("expected error matching %v but got none", url) @@ -108,6 +108,8 @@ func objectForURL(url string) interface{} { // case "http://determined.ai/schemas/expconf/v0/hyperparameter.json", // "http://determined.ai/schemas/expconf/v0/hyperparameter-int.json": // return &Hyperparameter{} + + // Test-related structs. case "http://determined.ai/schemas/expconf/v0/test-root.json": return &TestRootV0{} case "http://determined.ai/schemas/expconf/v0/test-union.json", @@ -126,7 +128,6 @@ func clearRuntimeDefaults(obj *interface{}, defaulted interface{}) { // If defaulted is a "*" and obj is not nil, set obj to be "*" too so they match. if s, ok := defaulted.(string); ok && s == "*" { if *obj != nil { - fmt.Fprintf(os.Stderr, "%v matches %v\n", *obj, defaulted) *obj = "*" } } @@ -180,7 +181,7 @@ func (tc SchemaTestCase) CheckDefaulted(t *testing.T) { err = json.Unmarshal(byts, &obj) assert.NilError(t, err) - schemas.FillDefaults(&obj) + obj = schemas.WithDefaults(obj) // Compare json-to-json. defaultedBytes, err := json.Marshal(obj) @@ -222,13 +223,15 @@ func (tc SchemaTestCase) CheckRoundTrip(t *testing.T) { assert.DeepEqual(t, obj, cpy) // Round-trip again after defaults. - schemas.FillDefaults(&obj) + obj = schemas.WithDefaults(obj) + jByts, err = json.Marshal(obj) assert.NilError(t, err) + cpy = objectForURL(url) - schemas.FillDefaults(&cpy) err = json.Unmarshal(jByts, &cpy) assert.NilError(t, err) + assert.DeepEqual(t, obj, cpy) } diff --git a/master/pkg/schemas/expconf/test_types.go b/master/pkg/schemas/expconf/test_types.go index b418af39db1..7cfeb0f32b8 100644 --- a/master/pkg/schemas/expconf/test_types.go +++ b/master/pkg/schemas/expconf/test_types.go @@ -1,18 +1,17 @@ package expconf +// Define types that are only used in testing. + import ( "encoding/json" - // "fmt" - // "time" - // petname "github.com/dustinkirkland/golang-petname" "github.com/pkg/errors" "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/union" ) +//go:generate ../gen.sh // TestUnionAV0 is exported. type TestUnionAV0 struct { Type string `json:"type"` @@ -21,6 +20,7 @@ type TestUnionAV0 struct { CommonVal *string `json:"common_val"` } +//go:generate ../gen.sh // TestUnionBV0 is exported. type TestUnionBV0 struct { Type string `json:"type"` @@ -29,14 +29,11 @@ type TestUnionBV0 struct { CommonVal *string `json:"common_val"` } +//go:generate ../gen.sh // TestUnionV0 is exported. type TestUnionV0 struct { A *TestUnionAV0 `union:"type,a" json:"-"` B *TestUnionBV0 `union:"type,b" json:"-"` - - // I think common memebers should not exist, but for now they do and you can handle them with - // the DefaultSource interface. - CommonVal *string `json:"common_val"` } // UnmarshalJSON is exported. @@ -53,17 +50,14 @@ func (t TestUnionV0) MarshalJSON() ([]byte, error) { return union.Marshal(t) } -// DefaultSource implements the Defaultable interface. -func (t TestUnionV0) DefaultSource() interface{} { - return schemas.UnionDefaultSchema(t) -} - +//go:generate ../gen.sh // TestSubV0 is exported. type TestSubV0 struct { // defaultable; pointer. ValY *string `json:"val_y"` } +//go:generate ../gen.sh // TestRootV0 is exported. type TestRootV0 struct { // required; non-pointer. @@ -76,8 +70,9 @@ type TestRootV0 struct { } // RuntimeDefaults implements the RuntimeDefaultable interface. -func (t *TestRootV0) RuntimeDefaults() { +func (t TestRootV0) RuntimeDefaults() interface{} { if t.RuntimeDefaultable == nil { t.RuntimeDefaultable = ptrs.IntPtr(10) } + return t } diff --git a/master/pkg/schemas/expconf/zgen_test_root_v0.go b/master/pkg/schemas/expconf/zgen_test_root_v0.go new file mode 100644 index 00000000000..4d1663cde29 --- /dev/null +++ b/master/pkg/schemas/expconf/zgen_test_root_v0.go @@ -0,0 +1,48 @@ +// Code generated by gen.py. DO NOT EDIT. + +package expconf + +import ( + "github.com/santhosh-tekuri/jsonschema/v2" + + "github.com/determined-ai/determined/master/pkg/schemas" +) + +func (t TestRootV0) GetValX() int { + return t.ValX +} + +func (t TestRootV0) GetSubObj() TestSubV0 { + if t.SubObj == nil { + panic("You must call WithDefaults on TestRootV0 before .GetSubObj") + } + return *t.SubObj +} + +func (t TestRootV0) GetSubUnion() *TestUnionV0 { + return t.SubUnion +} + +func (t TestRootV0) GetRuntimeDefaultable() *int { + return t.RuntimeDefaultable +} + +func (t TestRootV0) WithDefaults() TestRootV0 { + return schemas.WithDefaults(t).(TestRootV0) +} + +func (t TestRootV0) Merge(other TestRootV0) TestRootV0 { + return schemas.Merge(t, other).(TestRootV0) +} + +func (t TestRootV0) ParsedSchema() interface{} { + return schemas.ParsedTestRootV0() +} + +func (t TestRootV0) SanityValidator() *jsonschema.Schema { + return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-root.json") +} + +func (t TestRootV0) CompletenessValidator() *jsonschema.Schema { + return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-root.json") +} diff --git a/master/pkg/schemas/expconf/zgen_test_sub_v0.go b/master/pkg/schemas/expconf/zgen_test_sub_v0.go new file mode 100644 index 00000000000..c30baf7f602 --- /dev/null +++ b/master/pkg/schemas/expconf/zgen_test_sub_v0.go @@ -0,0 +1,36 @@ +// Code generated by gen.py. DO NOT EDIT. + +package expconf + +import ( + "github.com/santhosh-tekuri/jsonschema/v2" + + "github.com/determined-ai/determined/master/pkg/schemas" +) + +func (t TestSubV0) GetValY() string { + if t.ValY == nil { + panic("You must call WithDefaults on TestSubV0 before .GetValY") + } + return *t.ValY +} + +func (t TestSubV0) WithDefaults() TestSubV0 { + return schemas.WithDefaults(t).(TestSubV0) +} + +func (t TestSubV0) Merge(other TestSubV0) TestSubV0 { + return schemas.Merge(t, other).(TestSubV0) +} + +func (t TestSubV0) ParsedSchema() interface{} { + return schemas.ParsedTestSubV0() +} + +func (t TestSubV0) SanityValidator() *jsonschema.Schema { + return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-sub.json") +} + +func (t TestSubV0) CompletenessValidator() *jsonschema.Schema { + return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-sub.json") +} diff --git a/master/pkg/schemas/expconf/zgen_test_union_av0.go b/master/pkg/schemas/expconf/zgen_test_union_av0.go new file mode 100644 index 00000000000..ce23573286a --- /dev/null +++ b/master/pkg/schemas/expconf/zgen_test_union_av0.go @@ -0,0 +1,44 @@ +// Code generated by gen.py. DO NOT EDIT. + +package expconf + +import ( + "github.com/santhosh-tekuri/jsonschema/v2" + + "github.com/determined-ai/determined/master/pkg/schemas" +) + +func (t TestUnionAV0) GetType() string { + return t.Type +} + +func (t TestUnionAV0) GetValA() int { + return t.ValA +} + +func (t TestUnionAV0) GetCommonVal() string { + if t.CommonVal == nil { + panic("You must call WithDefaults on TestUnionAV0 before .GetCommonVal") + } + return *t.CommonVal +} + +func (t TestUnionAV0) WithDefaults() TestUnionAV0 { + return schemas.WithDefaults(t).(TestUnionAV0) +} + +func (t TestUnionAV0) Merge(other TestUnionAV0) TestUnionAV0 { + return schemas.Merge(t, other).(TestUnionAV0) +} + +func (t TestUnionAV0) ParsedSchema() interface{} { + return schemas.ParsedTestUnionAV0() +} + +func (t TestUnionAV0) SanityValidator() *jsonschema.Schema { + return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-union-a.json") +} + +func (t TestUnionAV0) CompletenessValidator() *jsonschema.Schema { + return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-union-a.json") +} diff --git a/master/pkg/schemas/expconf/zgen_test_union_bv0.go b/master/pkg/schemas/expconf/zgen_test_union_bv0.go new file mode 100644 index 00000000000..828337cf076 --- /dev/null +++ b/master/pkg/schemas/expconf/zgen_test_union_bv0.go @@ -0,0 +1,44 @@ +// Code generated by gen.py. DO NOT EDIT. + +package expconf + +import ( + "github.com/santhosh-tekuri/jsonschema/v2" + + "github.com/determined-ai/determined/master/pkg/schemas" +) + +func (t TestUnionBV0) GetType() string { + return t.Type +} + +func (t TestUnionBV0) GetValB() int { + return t.ValB +} + +func (t TestUnionBV0) GetCommonVal() string { + if t.CommonVal == nil { + panic("You must call WithDefaults on TestUnionBV0 before .GetCommonVal") + } + return *t.CommonVal +} + +func (t TestUnionBV0) WithDefaults() TestUnionBV0 { + return schemas.WithDefaults(t).(TestUnionBV0) +} + +func (t TestUnionBV0) Merge(other TestUnionBV0) TestUnionBV0 { + return schemas.Merge(t, other).(TestUnionBV0) +} + +func (t TestUnionBV0) ParsedSchema() interface{} { + return schemas.ParsedTestUnionBV0() +} + +func (t TestUnionBV0) SanityValidator() *jsonschema.Schema { + return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-union-b.json") +} + +func (t TestUnionBV0) CompletenessValidator() *jsonschema.Schema { + return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-union-b.json") +} diff --git a/master/pkg/schemas/expconf/zgen_test_union_v0.go b/master/pkg/schemas/expconf/zgen_test_union_v0.go new file mode 100644 index 00000000000..55860d7ece8 --- /dev/null +++ b/master/pkg/schemas/expconf/zgen_test_union_v0.go @@ -0,0 +1,59 @@ +// Code generated by gen.py. DO NOT EDIT. + +package expconf + +import ( + "github.com/santhosh-tekuri/jsonschema/v2" + + "github.com/determined-ai/determined/master/pkg/schemas" +) + +func (t TestUnionV0) GetUnionMember() interface{} { + if t.A != nil { + return nil + } + if t.B != nil { + return nil + } + panic("no union member defined") +} + +func (t TestUnionV0) GetCommonVal() string { + if t.A != nil { + return t.A.GetCommonVal() + } + if t.B != nil { + return t.B.GetCommonVal() + } + panic("no union member defined") +} + +func (t TestUnionV0) GetType() string { + if t.A != nil { + return t.A.GetType() + } + if t.B != nil { + return t.B.GetType() + } + panic("no union member defined") +} + +func (t TestUnionV0) WithDefaults() TestUnionV0 { + return schemas.WithDefaults(t).(TestUnionV0) +} + +func (t TestUnionV0) Merge(other TestUnionV0) TestUnionV0 { + return schemas.Merge(t, other).(TestUnionV0) +} + +func (t TestUnionV0) ParsedSchema() interface{} { + return schemas.ParsedTestUnionV0() +} + +func (t TestUnionV0) SanityValidator() *jsonschema.Schema { + return schemas.GetSanityValidator("http://determined.ai/schemas/expconf/v0/test-union.json") +} + +func (t TestUnionV0) CompletenessValidator() *jsonschema.Schema { + return schemas.GetCompletenessValidator("http://determined.ai/schemas/expconf/v0/test-union.json") +} diff --git a/master/pkg/schemas/gen.sh b/master/pkg/schemas/gen.sh new file mode 100755 index 00000000000..8b12ab25bac --- /dev/null +++ b/master/pkg/schemas/gen.sh @@ -0,0 +1,2 @@ +#!/usr/bin/sh +../../../../schemas/gen.py go-struct "$@" diff --git a/master/pkg/schemas/lint.go b/master/pkg/schemas/lint.go deleted file mode 100644 index 8a12a482f8e..00000000000 --- a/master/pkg/schemas/lint.go +++ /dev/null @@ -1,70 +0,0 @@ -package schemas - -// lint.go is full of functions for writing unit tests that ensure certain assumptions about the -// nature of our json-schema values and related go types hold constant. -// -// lint.go is not useful outside of writing tests. - -import ( - "encoding/json" - "reflect" - - "github.com/pkg/errors" -) - -// LintStructDefaults asserts that all fields with json-schema defaults correspond to pointers types -// so that fill-defaults will work. It also asserts that all default bytes are Unmarshalable. -// -// LintStructDefaults does not recurse; you should call for each generated struct. -// -// Since Defaultable and Schema are both defined on raw structs, you must call this on a struct, not -// a struct pointer. -func LintStructDefaults(x interface{}) []error { - t := reflect.TypeOf(x) - if t.Kind() != reflect.Struct { - return []error{errors.Errorf( - "LintStructDefaults can only be called on a struct-like input, not %v", t.Name(), - )} - } - defaultSource := getDefaultSource(reflect.ValueOf(x)) - if defaultSource == nil { - return []error{errors.Errorf( - "LintStructDefaults called on %v which has no default source", t.Name(), - )} - } - var out []error - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - // Is there a default for this field's tag? - fieldDefaultBytes := findDefaultInSchema(defaultSource, field) - if fieldDefaultBytes == nil { - continue - } - // Is this field a pointer type? - if field.Type.Kind() != reflect.Ptr { - out = append(out, errors.Errorf( - "%v.%v has default bytes '%v' but it has non-pointer type '%v'", - t.Name(), - field.Name, - string(fieldDefaultBytes), - field.Type, - )) - } - // Can we unmarshal defaultBtyes into a pointer of the field type? - fieldObj := reflect.New(field.Type).Interface() - err := json.Unmarshal(fieldDefaultBytes, fieldObj) - if err != nil { - out = append(out, - errors.Wrapf( - err, - "failed to unmarshal defaultBytes of '%v' for %v.%v of type '%v'", - string(fieldDefaultBytes), - t.Name(), - field.Name, - field.Type, - ), - ) - } - } - return out -} diff --git a/master/pkg/schemas/lint_test.go b/master/pkg/schemas/lint_test.go deleted file mode 100644 index e9069b9e301..00000000000 --- a/master/pkg/schemas/lint_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package schemas - -import ( - "encoding/json" - "strings" - "testing" - - "gotest.tools/assert" -) - -type BadDefaultStruct struct { - Val string `json:"val"` -} - -func (b BadDefaultStruct) DefaultSource() interface{} { - raw := `{ - "properties": { - "val": { - "default": "val-default" - } - } - }` - var out interface{} - err := json.Unmarshal([]byte(raw), &out) - if err != nil { - panic(err.Error()) - } - return out -} - -func TestLintStructDefaults(t *testing.T) { - var b BadDefaultStruct - errs := LintStructDefaults(b) - assert.Assert(t, len(errs) == 1) - assert.Assert(t, strings.Contains(errs[0].Error(), "non-pointer type")) -} diff --git a/master/pkg/schemas/merge.go b/master/pkg/schemas/merge.go index 10aac8f08cb..4ebe6e9219e 100644 --- a/master/pkg/schemas/merge.go +++ b/master/pkg/schemas/merge.go @@ -7,23 +7,17 @@ import ( // Mergable means an object can have custom behvaiors for schemas.Merge. type Mergable interface { - // Merge takes a non-nil version of itself and merges into itself. - Merge(interface{}) + // Merge should take a struct and return the same struct. + Merge(interface{}) interface{} } -// Merge will recurse through structs, setting empty values in obj with an non-empty values (copy -// semantics). Both obj and src must be the same type of obj, but obj must be a pointer so that it -// is settable. +// Merge will recurse through two objects of the same type and return a merged version +// (a clean copy). // -// The default behavior for merging maps is to merge keys from src to obj, and the default -// behavior for slices is to copy them. This is analgous to how json.Unmarshal treats maps and -// slices. However, the default merging behavior for an object can be overwritten by implementing -// the Mergable interface. An example of this is BindMountsConfig. -// -// Merge is intelligent enough to handle union types automatically. In those cases, Merge is -// recursive as log as the obj getting filled either does not have any of the union types defined -// or if it has the same union type defined as the src. That is, a S3 checkpoint storage object -// will never be used as a src to try to fill a SharedFS checkpoint storage object. +// The default behavior for merging maps is to include keys from both src and obj, while the default +// behavior for slices is to use one or the other. This is analgous to how json.Unmarshal treats +// maps and slices. However, the default merging behavior for an object can be overwritten by +// implementing the Mergable interface. An example of this is BindMountsConfig. // // Example usage: // @@ -32,28 +26,20 @@ type Mergable interface { // var cluster_default_checkpoint_storage expconf.CheckpointStorage = ... // // // Use the cluster checkpoint storage if the user did not specify one. -// schemas.Merge(&config.CheckpointStorage, cluster_default_checkpoint_storage) +// config.CheckpointStorage = schemas.Merge( +// config.CheckpointStorage, cluster_default_checkpoint_storage).(CheckpointStorageConfig +// ) // -func Merge(obj interface{}, src interface{}) { +func Merge(obj interface{}, src interface{}) interface{} { name := fmt.Sprintf("%T", obj) vObj := reflect.ValueOf(obj) vSrc := reflect.ValueOf(src) - // obj should always be a pointer, because Merge(&x, y) will act on x in-place. - if vObj.Kind() != reflect.Ptr { - panic("non-pointer in merge") - } - // obj can't be a nil pointer, because Merge(nil, y) doesn't make any sense. - if vObj.IsZero() { - panic("nil pointer in merge") - } - - // *obj must have the same type as src. - assertTypeMatch(vObj.Elem(), vSrc) + // obj must have the same type as src. + assertTypeMatch(vObj, vSrc) - // Think: *obj = merge(*obj, src) - vObj.Elem().Set(merge(vObj.Elem(), vSrc, name)) + return merge(vObj, vSrc, name).Interface() } func assertTypeMatch(obj reflect.Value, src reflect.Value) { @@ -76,139 +62,97 @@ func merge(obj reflect.Value, src reflect.Value, name string) reflect.Value { // fmt.Printf("merge(%T, %T, %v)\n", obj.Interface(), src.Interface(), name) assertTypeMatch(obj, src) - // If src is nil, return obj unmodified. - if src.Kind() == reflect.Ptr || src.Kind() == reflect.Interface { - if src.IsZero() { + // Always handle pointers first. + if obj.Kind() == reflect.Ptr { + if obj.IsZero() { + return cpy(src) + } else if src.IsZero() { return cpy(obj) } + out := reflect.New(obj.Elem().Type()) + out.Elem().Set(merge(obj.Elem(), src.Elem(), name)) + return out } - // Handle nil pointers by simply copying the src. - if obj.Kind() == reflect.Ptr && obj.IsZero() { - return cpy(src) + // Next handle interfaces. + if obj.Kind() == reflect.Interface { + if obj.IsZero() { + return cpy(src) + } else if src.IsZero() { + return cpy(obj) + } + return merge(obj.Elem(), src.Elem(), name) } - // If the object is Mergable, we only call Merge on it and return it as-is. - if mergeable, ok := obj.Addr().Interface().(Mergable); ok { - mergeable.Merge(src.Interface()) - return obj + // Detect Mergables. + if mergeable, ok := obj.Interface().(Mergable); ok { + return reflect.ValueOf(mergeable.Merge(src.Interface())) } switch obj.Kind() { - case reflect.Ptr: - // We already checked for nil pointers, so just recurse on the Elem of the value. - obj.Elem().Set(merge(obj.Elem(), src.Elem(), name)) - case reflect.Struct: - // Detect what to do with union fields. There are 4 important cases: - // 1. src has a union member, obj does not -> recurse into that field. - // 2. src has a union member, obj has the same one -> recurse into that field. - // 3. src has a union member, obj has the different one -> do not recurse. - // 4. src has no union member -> recursing is a noop and doesn't matter - // Logically, this reduces to: - // - if obj has a union member, src does not have the same one -> don't recurse. - // - else -> recurse - recurseIntoUnion := true - for i := 0; i < src.NumField(); i++ { - structField := src.Type().Field(i) - if _, ok := structField.Tag.Lookup("union"); ok { - if !obj.Field(i).IsZero() && src.Field(i).IsZero() { - recurseIntoUnion = false - break - } - } - } // Recurse into each field of the struct. + out := reflect.New(obj.Type()).Elem() for i := 0; i < src.NumField(); i++ { structField := src.Type().Field(i) - if _, ok := structField.Tag.Lookup("union"); ok && !recurseIntoUnion { - continue - } fieldName := fmt.Sprintf("%v.%v", name, structField.Name) x := merge(obj.Field(i), src.Field(i), fieldName) - obj.Field(i).Set(x) + out.Field(i).Set(x) } + return out case reflect.Map: - // Maps get fused together; all input keys are written into the output map. - for _, key := range src.MapKeys() { - // Ensure key is not already set in obj. - if objVal := obj.MapIndex(key); objVal.IsValid() { - continue + // Handle unallocated maps on either input. + if src.IsZero() { + return cpy(obj) + } else if obj.IsZero() { + return cpy(src) + } + // allocate a new map + typ := reflect.MapOf(obj.Type().Key(), obj.Type().Elem()) + out := reflect.MakeMap(typ) + // Iterate through keys and objects in obj. + iter := obj.MapRange() + for iter.Next() { + key := iter.Key() + objVal := iter.Value() + if srcVal := src.MapIndex(key); srcVal.IsValid() { + // Key present in both maps. + out.SetMapIndex(key, merge(objVal, srcVal, name)) + } else { + // Key is unique to obj. + out.SetMapIndex(key, cpy(objVal)) } - val := src.MapIndex(key) - obj.SetMapIndex(key, cpy(val)) } + // Check for keys only present in src. + iter = src.MapRange() + for iter.Next() { + key := iter.Key() + srcVal := iter.Value() + if objVal := obj.MapIndex(key); !objVal.IsValid() { + // Key is unique to src. + out.SetMapIndex(key, cpy(srcVal)) + } + } + return out case reflect.Slice: // Slices get copied only if the original was a nil pointer, which should always pass // through the cpy() codepath and never through here. + return cpy(obj) // Assert that none of the "complex" kinds are present. case reflect.Array, reflect.Chan, reflect.Func, - reflect.Interface, - reflect.UnsafePointer: + reflect.UnsafePointer, + // We already handled Ptr and Interface. + reflect.Ptr, + reflect.Interface: panic(fmt.Sprintf("unable to fill %T with %T at %v", obj.Interface(), src.Interface(), name)) - // Nothing to do for the simple Kinds like string or int; the only way a simple kind in the - // src can end up being merged into the obj is if it is within a call to cpy(), like after - // allocating a new pointer. This is because we only merge into nil pointers. - } - - return obj -} - -// cpy is for deep copying, but it will only work on "nice" objects, which should include our -// schema objects. -func cpy(v reflect.Value) reflect.Value { - // fmt.Printf("cpy(%T)\n", v.Interface()) - var out reflect.Value - - switch v.Kind() { - case reflect.Ptr: - if v.IsZero() { - return v - } - out = reflect.New(v.Elem().Type()) - out.Elem().Set(cpy(v.Elem())) - - case reflect.Struct: - out = reflect.New(v.Type()).Elem() - // Recurse into each field of the struct. - for i := 0; i < v.NumField(); i++ { - out.Field(i).Set(cpy(v.Field(i))) - } - - case reflect.Map: - out = reflect.New(v.Type()).Elem() - // Recurse into each key of the map. - for _, key := range v.MapKeys() { - val := v.MapIndex(key) - out.SetMapIndex(key, cpy(val)) - } - - case reflect.Slice: - out = reflect.New(v.Type()).Elem() - // Recurse into each element of the slice. - for i := 0; i < v.Len(); i++ { - val := v.Index(i) - out.Set(reflect.Append(out, cpy(val))) - } - - // Assert that none of the "complex" kinds are present. - case reflect.Array, - reflect.Chan, - reflect.Func, - reflect.Interface, - reflect.UnsafePointer: - panic(fmt.Sprintf("unable to cpy %T", v.Interface())) - default: - // Simple types like string or int can be passed directly. - return v + // Simple kinds just get copied. + return cpy(obj) } - - return out } diff --git a/master/pkg/schemas/merge_test.go b/master/pkg/schemas/merge_test.go index 2268c5eab19..780da260824 100644 --- a/master/pkg/schemas/merge_test.go +++ b/master/pkg/schemas/merge_test.go @@ -27,86 +27,11 @@ func TestMerge(t *testing.T) { C: nil, } - Merge(&obj, src) + out := Merge(obj, src).(X) - assert.Assert(t, *obj.A == "obj:x.a") - assert.Assert(t, *obj.B == "src:x.b") - assert.Assert(t, *obj.C == "obj:x.c") -} - -type Y struct { - A *UA `union:"type,ux" json:"-"` - B *UB `union:"type,uy" json:"-"` - C *string -} - -type UA struct { - A *string -} - -type UB struct { - B *string -} - -func TestUnionMerge(t *testing.T) { - // 1. src has a union member, obj does not -> recurse into that field. - obj := Y{ - A: nil, - B: nil, - C: ptrs.StringPtr("obj:c"), - } - - src := Y{ - A: nil, - B: &UB{ - B: ptrs.StringPtr("src:b:b"), - }, - C: ptrs.StringPtr("src:c"), - } - - Merge(&obj, src) - - assert.Assert(t, obj.A == nil) - assert.Assert(t, *obj.B.B == "src:b:b") - assert.Assert(t, *obj.C == "obj:c") - - // 2. src has a union member, obj has the same one -> recurse into that field. - obj = Y{ - A: &UA{}, - B: nil, - C: nil, - } - - src = Y{ - A: &UA{A: ptrs.StringPtr("src:a:a")}, - B: nil, - C: ptrs.StringPtr("src:y.c"), - } - - Merge(&obj, src) - assert.Assert(t, *obj.A.A == "src:a:a") - assert.Assert(t, obj.B == nil) - assert.Assert(t, *obj.C == "src:y.c") - - // 3. src has a union member, obj has the different one -> do not recurse. - obj = Y{ - A: &UA{}, - B: nil, - C: nil, - } - - src = Y{ - A: nil, - B: &UB{ - B: ptrs.StringPtr("src:b:b"), - }, - C: nil, - } - - Merge(&obj, src) - assert.Assert(t, obj.A.A == nil) - assert.Assert(t, obj.B == nil) - assert.Assert(t, obj.C == nil) + assert.Assert(t, *out.A == "obj:x.a") + assert.Assert(t, *out.B == "src:x.b") + assert.Assert(t, *out.C == "obj:x.c") } func TestMapMerge(t *testing.T) { @@ -119,39 +44,39 @@ func TestMapMerge(t *testing.T) { obj := map[string]string{"1": "obj:one", "2": "obj:two"} src := map[string]string{"2": "src:two", "3": "src:three"} - Merge(&obj, src) - assertCorrectMerge(obj) + out := Merge(obj, src).(map[string]string) + assertCorrectMerge(out) } func TestSliceMerge(t *testing.T) { obj := &[]int{0, 1} src := &[]int{2, 3} - Merge(&obj, src) - assert.Assert(t, len(*obj) == 2) - assert.Assert(t, (*obj)[0] == 0) - assert.Assert(t, (*obj)[1] == 1) + out := Merge(obj, src).(*[]int) + assert.Assert(t, len(*out) == 2) + assert.Assert(t, (*out)[0] == 0) + assert.Assert(t, (*out)[1] == 1) obj = nil src = &[]int{2, 3} - Merge(&obj, src) - assert.Assert(t, len(*obj) == 2) - assert.Assert(t, (*obj)[0] == 2) - assert.Assert(t, (*obj)[1] == 3) + out = Merge(obj, src).(*[]int) + assert.Assert(t, len(*out) == 2) + assert.Assert(t, (*out)[0] == 2) + assert.Assert(t, (*out)[1] == 3) } type Z []int -func (z *Z) Merge(src interface{}) { - *z = append(*z, src.(Z)...) +func (z Z) Merge(src interface{}) interface{} { + return append(z, src.(Z)...) } func TestMergable(t *testing.T) { obj := &Z{0, 1} src := &Z{2, 3} - Merge(&obj, src) - assert.Assert(t, len(*obj) == 4) - assert.Assert(t, (*obj)[0] == 0) - assert.Assert(t, (*obj)[1] == 1) - assert.Assert(t, (*obj)[2] == 2) - assert.Assert(t, (*obj)[3] == 3) + out := Merge(obj, src).(*Z) + assert.Assert(t, len(*out) == 4) + assert.Assert(t, (*out)[0] == 0) + assert.Assert(t, (*out)[1] == 1) + assert.Assert(t, (*out)[2] == 2) + assert.Assert(t, (*out)[3] == 3) } diff --git a/master/pkg/schemas/schema.go b/master/pkg/schemas/schema.go index 46285eaefe8..d66d3ef42ef 100644 --- a/master/pkg/schemas/schema.go +++ b/master/pkg/schemas/schema.go @@ -1,5 +1,7 @@ package schemas +//go:generate ../../../schemas/gen.py go-root --output zgen_schemas.go + import ( "bytes" "encoding/json" diff --git a/master/pkg/schemas/unions.go b/master/pkg/schemas/unions.go deleted file mode 100644 index 6e99ee75c4b..00000000000 --- a/master/pkg/schemas/unions.go +++ /dev/null @@ -1,75 +0,0 @@ -package schemas - -import ( - "reflect" -) - -// recursiveElem calls Elem() recursively until a non-pointer, non-interface object is reached. -// If any layer is nil, it returns (nil, false). -func recursiveElem(val reflect.Value) (reflect.Value, bool) { - for val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface { - if val.IsZero() { - return val, false - } - val = val.Elem() - } - return val, true -} - -// UnionDefaultSchema is a helper function for defining DefaultSchema on union-like objects. -// It searches for the non-nil union member and uses that member to define defaults for the common -// fields. In short it turns this: -// -// func (c CheckpointStorageConfigV0) DefaultSource { -// if c != nil { -// if c.SharedFSConfig != nil { -// return c.SharedFSConfig.DefaultSource -// } -// if c.S3Config != nil { -// return c.S3Config.DefaultSource -// } -// if c.GCSConfig != nil { -// return c.GCSConfig.DefaultSource -// } -// if c.HDFSConfig != nil { -// return c.HDFSConfig.DefaultSource -// } -// } -// return nil -// } -// -// Into this: -// -// func (c CheckpointStorageConfigV0) DefaultSource() interface{} { -// return schemas.UnionDefaultSchema(c) -// } -func UnionDefaultSchema(in interface{}) interface{} { - v := reflect.ValueOf(in) - var ok bool - if v, ok = recursiveElem(v); !ok { - return nil - } - // Iterate through all the fields of the struct. - for i := 0; i < v.NumField(); i++ { - fieldType := v.Type().Field(i) - if _, ok := fieldType.Tag.Lookup("union"); !ok { - // This field has no "union" tag and cannot provide defaults. - continue - } - - field := v.Field(i) - - if _, ok := recursiveElem(field); !ok { - // nil pointers cannot provide defaults. - continue - } - - // Get a source of defaults from a Defaultable or Schema interface. - if defaultable, ok := field.Interface().(Defaultable); ok { - return defaultable.DefaultSource() - } else if schema, ok := field.Interface().(Schema); ok { - return schema.ParsedSchema() - } - } - return nil -} diff --git a/master/pkg/schemas/schema_gen.go b/master/pkg/schemas/zgen_schemas.go similarity index 97% rename from master/pkg/schemas/schema_gen.go rename to master/pkg/schemas/zgen_schemas.go index 302359cfc11..11b78f1ae4c 100644 --- a/master/pkg/schemas/schema_gen.go +++ b/master/pkg/schemas/zgen_schemas.go @@ -1,4 +1,4 @@ -// This is a generated file. Editing it will make you sad. +// Code generated by gen.py. DO NOT EDIT. package schemas @@ -2611,60 +2611,113 @@ var ( } } `) - schemaBindMountV0 interface{} - schemaCheckDataLayerCacheV0 interface{} - schemaCheckEpochNotUsedV0 interface{} - schemaCheckGlobalBatchSizeV0 interface{} - schemaCheckGridHyperparameterV0 interface{} - schemaCheckPositiveLengthV0 interface{} - schemaCheckpointStorageConfigV0 interface{} - schemaGCSDataLayerConfigV0 interface{} - schemaS3DataLayerConfigV0 interface{} - schemaSharedFSDataLayerConfigV0 interface{} - schemaDataLayerConfigV0 interface{} - schemaEnvironmentImageMapV0 interface{} - schemaEnvironmentImageV0 interface{} - schemaEnvironmentVariablesMapV0 interface{} - schemaEnvironmentVariablesV0 interface{} - schemaEnvironmentConfigV0 interface{} - schemaExperimentConfigV0 interface{} - schemaGCSConfigV0 interface{} - schemaHDFSConfigV0 interface{} + schemaBindMountV0 interface{} + + schemaCheckDataLayerCacheV0 interface{} + + schemaCheckEpochNotUsedV0 interface{} + + schemaCheckGlobalBatchSizeV0 interface{} + + schemaCheckGridHyperparameterV0 interface{} + + schemaCheckPositiveLengthV0 interface{} + + schemaCheckpointStorageConfigV0 interface{} + + schemaGCSDataLayerConfigV0 interface{} + + schemaS3DataLayerConfigV0 interface{} + + schemaSharedFSDataLayerConfigV0 interface{} + + schemaDataLayerConfigV0 interface{} + + schemaEnvironmentImageMapV0 interface{} + + schemaEnvironmentImageV0 interface{} + + schemaEnvironmentVariablesMapV0 interface{} + + schemaEnvironmentVariablesV0 interface{} + + schemaEnvironmentConfigV0 interface{} + + schemaExperimentConfigV0 interface{} + + schemaGCSConfigV0 interface{} + + schemaHDFSConfigV0 interface{} + schemaCategoricalHyperparameterV0 interface{} - schemaConstHyperparameterV0 interface{} - schemaDoubleHyperparameterV0 interface{} - schemaIntHyperparameterV0 interface{} - schemaLogHyperparameterV0 interface{} - schemaHyperparameterV0 interface{} - schemaHyperparametersV0 interface{} - schemaInternalConfigV0 interface{} - schemaKerberosConfigV0 interface{} - schemaLengthV0 interface{} - schemaNativeConfigV0 interface{} - schemaOptimizationsConfigV0 interface{} - schemaReproducibilityConfigV0 interface{} - schemaResourcesConfigV0 interface{} - schemaS3ConfigV0 interface{} - schemaAdaptiveASHAConfigV0 interface{} - schemaAdaptiveSimpleConfigV0 interface{} - schemaAdaptiveConfigV0 interface{} - schemaAsyncHalvingConfigV0 interface{} - schemaGridConfigV0 interface{} - schemaPBTConfigV0 interface{} - schemaRandomConfigV0 interface{} - schemaSingleConfigV0 interface{} - schemaSyncHalvingConfigV0 interface{} - schemaSearcherConfigV0 interface{} - schemaSecurityConfigV0 interface{} - schemaSharedFSConfigV0 interface{} - schemaTensorboardStorageConfigV0 interface{} - schemaTestRootV0 interface{} - schemaTestSubV0 interface{} - schemaTestUnionAV0 interface{} - schemaTestUnionBV0 interface{} - schemaTestUnionV0 interface{} - cachedSchemaMap map[string]interface{} - cachedSchemaBytesMap map[string][]byte + + schemaConstHyperparameterV0 interface{} + + schemaDoubleHyperparameterV0 interface{} + + schemaIntHyperparameterV0 interface{} + + schemaLogHyperparameterV0 interface{} + + schemaHyperparameterV0 interface{} + + schemaHyperparametersV0 interface{} + + schemaInternalConfigV0 interface{} + + schemaKerberosConfigV0 interface{} + + schemaLengthV0 interface{} + + schemaNativeConfigV0 interface{} + + schemaOptimizationsConfigV0 interface{} + + schemaReproducibilityConfigV0 interface{} + + schemaResourcesConfigV0 interface{} + + schemaS3ConfigV0 interface{} + + schemaAdaptiveASHAConfigV0 interface{} + + schemaAdaptiveSimpleConfigV0 interface{} + + schemaAdaptiveConfigV0 interface{} + + schemaAsyncHalvingConfigV0 interface{} + + schemaGridConfigV0 interface{} + + schemaPBTConfigV0 interface{} + + schemaRandomConfigV0 interface{} + + schemaSingleConfigV0 interface{} + + schemaSyncHalvingConfigV0 interface{} + + schemaSearcherConfigV0 interface{} + + schemaSecurityConfigV0 interface{} + + schemaSharedFSConfigV0 interface{} + + schemaTensorboardStorageConfigV0 interface{} + + schemaTestRootV0 interface{} + + schemaTestSubV0 interface{} + + schemaTestUnionAV0 interface{} + + schemaTestUnionBV0 interface{} + + schemaTestUnionV0 interface{} + + cachedSchemaMap map[string]interface{} + + cachedSchemaBytesMap map[string][]byte ) func ParsedBindMountV0() interface{} { diff --git a/schemas/gen.py b/schemas/gen.py index a3bf30c49be..9ee27fbca15 100755 --- a/schemas/gen.py +++ b/schemas/gen.py @@ -3,8 +3,13 @@ import argparse import json import os +import re import sys -from typing import List, Optional +from typing import List, Optional, Tuple + +HERE = os.path.dirname(__file__) +ALL_PKGS = ["expconf"] +URLBASE = "http://determined.ai/schemas" def camel_to_snake(name: str) -> str: @@ -37,12 +42,21 @@ def version(self) -> str: return os.path.basename(os.path.dirname(self.url)) +def list_files(package: str) -> List[str]: + """List all json schema files in a package (like `expconf`).""" + out = [] + root = os.path.join(HERE, package) + for dirpath, _, files in os.walk(root): + out += [os.path.join(dirpath, f) for f in files if f.endswith(".json")] + return sorted(out) + + def read_schemas(files: List[str]) -> List[Schema]: + """Read all the schemas in a list of files.""" schemas = [] - urlbase = "http://determined.ai/schemas" for file in files: urlend = os.path.relpath(file, os.path.dirname(__file__)) - url = os.path.join(urlbase, urlend) + url = os.path.join(URLBASE, urlend) with open(file) as f: schema = Schema(url, f.read()) schemas.append(schema) @@ -60,13 +74,12 @@ def gen_go_schemas_package(schemas: List[Schema]) -> List[str]: urls, so that schemas of one type are free to reference schemas of another type. """ lines = [] - lines.append("// This is a generated file. Editing it will make you sad.") + lines.append("// Code generated by gen.py. DO NOT EDIT.") lines.append("") lines.append("package schemas") lines.append("") lines.append("import (") lines.append('\t"encoding/json"') - lines.append('\t"github.com/santhosh-tekuri/jsonschema/v2"') lines.append(")") lines.append("") @@ -77,9 +90,12 @@ def gen_go_schemas_package(schemas: List[Schema]) -> List[str]: [f"\ttext{schema.golang_title} = []byte(`{schema.text}`)" for schema in schemas] ) # Cached schema values, initially nil. - lines.extend([f"\tschema{schema.golang_title} interface{{}}" for schema in schemas]) + for schema in schemas: + lines.append(f"\tschema{schema.golang_title} interface{{}}") + lines.append("") # Cached map of urls to schema values, initially nil. lines.append("\tcachedSchemaMap map[string]interface{}") + lines.append("") lines.append("\tcachedSchemaBytesMap map[string][]byte") lines.append(")") lines.append("") @@ -118,49 +134,321 @@ def gen_go_schemas_package(schemas: List[Schema]) -> List[str]: return lines -def gen_go_package(schemas: List[Schema], package: str) -> List[str]: +def next_struct_name(file: str, start: int) -> str: """ - Generate a file at the level of e.g. pkg/schemas/expconf that defines the schemas.Schema - interface and schemas.Defaultable interfcae (if applicable) for all the objects in this package. + Find the name of the next struct definition in a go file starting at a given line. + + This is how we decide which struct to operate on for the //go:generate comments above structs. """ - lines = [] - lines.append("// This is a generated file. Editing it will make you sad.") - lines.append("") - lines.append(f"package {package}") - lines.append("") - lines.append("import (") - lines.append('\t"encoding/json"') - lines.append('\t"github.com/santhosh-tekuri/jsonschema/v2"') - lines.append('\t"github.com/determined-ai/determined/master/pkg/schemas"') - lines.append(")") - lines.append("") + with open(file) as f: + for lineno, line in enumerate(f.readlines()): + if lineno <= start: + continue + match = re.match("type ([\\S]+) struct", line) + if match is not None: + return match[1] + raise AssertionError(f"did not find struct in {file} after line {line}") - # Implement the Schema interface for all objects. - for schema in schemas: - if not schema.python_title.startswith("check_"): - x = schema.golang_title[0].lower() - lines.append("") - lines.append( - f"func ({x} {schema.golang_title}) ParsedSchema() interface{{}} {{" + +# FieldSpec = (field, type, tag) +FieldSpec = Tuple[str, str, str] +# UnionSpec = (field, type) +UnionSpec = Tuple[str, str] + + +def find_struct(file: str, struct_name: str) -> Tuple[List[FieldSpec], List[UnionSpec]]: + """ + Open a file and find a struct definition for a given name. + + This function uses regex to read the golang source code... hacky, but it works. + """ + field_spec = [] # type: List[FieldSpec] + union_spec = [] # type: List[UnionSpec] + with open(file) as f: + state = "pre" + for lineno, line in enumerate(f.readlines()): + if state == "pre": + if line.startswith(f"type {struct_name} struct"): + state = "fields" + elif state == "fields": + if line.strip() == "}": + # No more fields + return field_spec, union_spec + if line.strip() == "": + # No field on this line + continue + if line.startswith("\t//"): + # comment line + continue + + # Union fields. + match = re.match("\t([\\S]+)\\s+([\\S]+)\\s+`union.*", line) + if match is not None: + field, type = match[1], match[2] + union_spec.append((field, type)) + continue + + # Normal fields: capture the field name, the type, and the json tag. + match = re.match('\t([\\S]+)\\s+([\\S]+)\\s+`json:"([^,"]+)', line) + if match is not None: + field, type, tag = match[1], match[2], match[3] + # store the field name and the type + field_spec.append((field, type, tag)) + continue + + raise AssertionError( + f"unsure how to handle line {lineno}: '{line.rstrip()}'" + ) + + # We should have exited when we saw the "}" line. + raise AssertionError( + f"failed to find struct definition for {struct_name} in {file}" + ) + + +def find_schema(package: str, struct: str) -> Schema: + """Locate a json-schema file from a struct name.""" + if re.match(".*V[0-9]+", struct) is None: + raise AssertionError( + f"{struct} is not a valid schema type name; it should end in Vx where x is a digit" + ) + version = struct[-2:].lower() + dir = os.path.join(HERE, package, version) + for file in os.listdir(dir): + if not file.endswith(".json"): + continue + path = os.path.join(dir, file) + urlend = os.path.relpath(path, HERE) + url = os.path.join(URLBASE, urlend) + with open(path) as f: + schema = Schema(url, f.read()) + if schema.golang_title != struct: + continue + return schema + raise AssertionError("failed to find schema") + + +def get_defaulted_type(schema: Schema, tag: str, type: str) -> Tuple[str, str, bool]: + """ + Given the type string for a field of a given tag, determine the type of the after-defaulting + value. This is used by the auto-generated getters, so that parts of the code which consume + experiment configs can use compile-time checks to know which pointer-typed fields values might + be nil and which ones have defaults and will never be nil. + """ + prop = schema.schema["properties"].get(tag, {}) + if prop is True: + prop = {} + default = prop.get("default") + + required = schema.schema.get("required", []) or schema.schema.get( + "eventuallyRequired", [] + ) + + if default is not None: + if not type.startswith("*"): + raise AssertionError( + f"{tag} type ({type}) must be a pointer since it can be defaulted" ) - lines.append(f"\treturn schemas.Parsed{schema.golang_title}()") + if type.startswith("**"): + raise AssertionError(f"{tag} type ({type}) must not be a double pointer") + type = type[1:] + + return type, default, required + + +def go_getters(struct: str, schema: Schema, spec: List[FieldSpec]) -> List[str]: + lines = [] # type: List[str] + + if len(spec) < 1: + return lines + + x = struct[0].lower() + + for field, type, tag in spec: + defaulted_type, default, required = get_defaulted_type(schema, tag, type) + + if default is None: + lines.append(f"func ({x} {struct}) Get{field}() {type} {{") + lines.append(f"\treturn {x}.{field}") lines.append("}") lines.append("") + else: + lines.append(f"func ({x} {struct}) Get{field}() {defaulted_type} {{") + lines.append(f"\tif {x}.{field} == nil {{") lines.append( - f"func ({x} {schema.golang_title}) SanityValidator() *jsonschema.Schema {{" + f'\t\tpanic("You must call WithDefaults on {struct} before .Get{field}")' ) - lines.append(f'\treturn schemas.GetSanityValidator("{schema.url}")') + lines.append("\t}") + lines.append(f"\treturn *{x}.{field}") lines.append("}") lines.append("") - lines.append( - f"func ({x} {schema.golang_title}) CompletenessValidator() *jsonschema.Schema {{" + + return lines + + +def get_union_common_members( + file: str, package: str, union_types: List[str] +) -> List[Tuple[str, str]]: + """ + Look at all of the union members types for a union type and automatically determine which + members are common to all members. + """ + # Find all members and types of all union member types + per_struct_members = [] + for struct in union_types: + schema = find_schema(package, struct) + spec, union = find_struct(file, struct) + if len(union) > 0: + raise AssertionError( + f"detected nested union; {struct} is a union member and also a union itself" ) - lines.append(f'\treturn schemas.GetCompletenessValidator("{schema.url}")') - lines.append("}") + members = {} + for field, type, tag in spec: + type, _, _ = get_defaulted_type(schema, tag, type) + members[field] = type + per_struct_members.append(members) + + # Find common members by name. + common_fields = set(per_struct_members[0].keys()) + for members in per_struct_members[1:]: + common_fields = common_fields.intersection(set(members.keys())) + + # Validate types all match. + for field in common_fields: + field_types = {members[field] for members in per_struct_members} + if len(field_types) != 1: + raise AssertionError( + f".{field} has multiple types ({field_types}) among union members {union_types}" + ) + + # Sort this so the generation is deterministic. + return sorted( + {field: per_struct_members[0][field] for field in common_fields}.items() + ) + + +def go_unions( + struct: str, package: str, file: str, schema: Schema, union_spec: List[UnionSpec] +) -> List[str]: + lines = [] # type: List[str] + if len(union_spec) < 1: + return lines + x = struct[0].lower() + + # Define a GetUnionMember() that returns an interface. + lines.append(f"func ({x} {struct}) GetUnionMember() interface{{}} {{") + for field, _ in union_spec: + lines.append(f"\tif {x}.{field} != nil {{") + lines.append("\t\treturn nil") + lines.append("\t}") + lines.append('\tpanic("no union member defined")') + lines.append("}") + lines.append("") + + union_types = [type.lstrip("*") for _, type in union_spec] + + # Define getters for each of the common members of the union. + common_members = get_union_common_members(file, package, union_types) + for common_field, type in common_members: + lines.append(f"func ({x} {struct}) Get{common_field}() {type} {{") + for field, _ in union_spec: + lines.append(f"\tif {x}.{field} != nil {{") + lines.append(f"\t\treturn {x}.{field}.Get{common_field}()") + lines.append("\t}") + lines.append('\tpanic("no union member defined")') + lines.append("}") + lines.append("") + + return lines + + +def go_helpers(struct: str) -> List[str]: + """ + Define WithDefaults() and Merge(), which are typed wrappers around schemas.WithDefaults() and + schemas.Merge(). + """ + lines = [] + + x = struct[0].lower() + + lines.append(f"func ({x} {struct}) WithDefaults() {struct} {{") + lines.append(f"\treturn schemas.WithDefaults({x}).({struct})") + lines.append("}") + lines.append("") + + lines.append(f"func ({x} {struct}) Merge(other {struct}) {struct} {{") + lines.append(f"\treturn schemas.Merge({x}, other).({struct})") + lines.append("}") return lines +def go_schema_interface(struct: str, url: str) -> List[str]: + """ + Generate the schemas.Schema interface for a particular schema. + + This is used for getting json-schema-based validators from Schema objects, as well as being + used by the reflect code in defaults.go. + """ + lines = [] + + x = struct[0].lower() + + lines.append("") + lines.append(f"func ({x} {struct}) ParsedSchema() interface{{}} {{") + lines.append(f"\treturn schemas.Parsed{struct}()") + lines.append("}") + lines.append("") + lines.append(f"func ({x} {struct}) SanityValidator() *jsonschema.Schema {{") + lines.append(f'\treturn schemas.GetSanityValidator("{url}")') + lines.append("}") + lines.append("") + lines.append(f"func ({x} {struct}) CompletenessValidator() *jsonschema.Schema {{") + lines.append(f'\treturn schemas.GetCompletenessValidator("{url}")') + lines.append("}") + + return lines + + +def gen_go_struct( + package: str, file: str, line: int, imports: List[str] +) -> Tuple[str, List[str]]: + """Used by the //go:generate decorations on structs.""" + struct = next_struct_name(file, line) + field_spec, union_spec = find_struct(file, struct) + + if len(field_spec) and len(union_spec): + raise AssertionError(f"{struct} has both union tags and normal fields") + + schema = find_schema(package, struct) + + lines = [] + lines.append("// Code generated by gen.py. DO NOT EDIT.") + lines.append("") + + lines.append(f"package {package}") + lines.append("") + + lines.append("import (") + lines.append('\t"github.com/santhosh-tekuri/jsonschema/v2"') + + for imp in imports: + lines.append("\t" + imp) + lines.append("") + lines.append('\t"github.com/determined-ai/determined/master/pkg/schemas"') + lines.append(")") + lines.append("") + + lines += go_getters(struct, schema, field_spec) + lines += go_unions(struct, package, file, schema, union_spec) + lines += go_helpers(struct) + lines += go_schema_interface(struct, schema.url) + + filename = "zgen_" + camel_to_snake(struct) + ".go" + + return filename, lines + + def gen_python(schemas: List[Schema]) -> List[str]: lines = [] lines.append("# This is a generated file. Editing it will make you sad.") @@ -177,45 +465,100 @@ def gen_python(schemas: List[Schema]) -> List[str]: return lines -def main( - language: str, package: Optional[str], files: List[str], output: Optional[str] -) -> None: - assert language in ["go", "python"], "language must be 'go' or 'python'" - if language == "go": - assert package is not None, "--package must be provided for the go generator" - else: - assert package is None, "--package must not be provided to the python generator" - assert files, "no input files" - assert output is not None, "missing output file" +def maybe_write_output(lines: List[str], output: Optional[str]) -> None: + """Write lines to output, unless output would be unchanged.""" - schemas = read_schemas(files) + text = "\n".join(lines) + "\n" - if language == "go": - assert package is not None - if package == "schemas": - lines = gen_go_schemas_package(schemas) - else: - lines = gen_go_package(schemas, package) - else: - lines = gen_python(schemas) + if output is None: + # Write to stdout. + sys.stdout.write(text) + return - text = "\n".join([*lines, "\n"]) + if os.path.exists(output): + with open(output, "r") as f: + if f.read() == text: + return - # Write the output file. with open(output, "w") as f: f.write(text) +def python_main(package: str, output: Optional[str]) -> None: + assert package is not None, "--package must be provided" + files = list_files(package) + schemas = read_schemas(files) + + lines = gen_python(schemas) + + maybe_write_output(lines, output) + + +def go_struct_main(package: str, file: str, line: int, imports: Optional[str]) -> None: + assert package is not None, "GOPACKAGE not set" + assert file is not None, "GOFILE not set" + assert line is not None, "GOLINE not set" + + def fmt_import(imp: str) -> str: + """Turn e.g. `k8sV1:k8s.io/api/core/v1` into `k8sV1 "k8s.io/api/core/v1"`.""" + if ":" in imp: + return imp.replace(":", ' "') + '"' + else: + return '"' + imp + '"' + + imports_list = [] + if imports is not None: + imports_list = [fmt_import(i) for i in imports.split(",") if i] + + output, lines = gen_go_struct(package, file, line, imports_list) + + maybe_write_output(lines, output) + + +def go_root_main(output: Optional[str]) -> None: + files = [] + for package in ALL_PKGS: + files += list_files(package) + schemas = read_schemas(files) + + lines = gen_go_schemas_package(schemas) + + maybe_write_output(lines, output) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="generate code with embedded schemas") - parser.add_argument("language", help="go or python") - parser.add_argument("files", nargs="*", help="input files") - parser.add_argument("--output") - parser.add_argument("--package") - args = parser.parse_args() + subparsers = parser.add_subparsers(dest="generator") + + # Python generator. + python_parser = subparsers.add_parser("python") + python_parser.add_argument("--package", required=True) + python_parser.add_argument("--output") + + # Go struct generator, expect environment variables set by go generate. + go_struct_parser = subparsers.add_parser("go-struct") + go_struct_parser.add_argument("--package", default=os.environ.get("GOPACKAGE")) + go_struct_parser.add_argument("--file", default=os.environ.get("GOFILE")) + go_struct_parser.add_argument("--line", default=os.environ.get("GOLINE"), type=int) + go_struct_parser.add_argument("--imports") + + # Go root generator. + go_root_parser = subparsers.add_parser("go-root") + go_root_parser.add_argument("--output") + + args = vars(parser.parse_args()) try: - main(args.language, args.package, args.files, args.output) + assert "generator" in args, "missing generator argument on command line" + generator = args.pop("generator") + if generator == "python": + python_main(**args) + elif generator == "go-struct": + go_struct_main(**args) + elif generator == "go-root": + go_root_main(**args) + else: + raise ValueError(f"unrecognized generator: {generator}") except AssertionError as e: print(e, file=sys.stderr) sys.exit(1) diff --git a/schemas/test_cases/v0/experiment.yaml b/schemas/test_cases/v0/experiment.yaml index bb3a3694b1e..5baff6f3f95 100644 --- a/schemas/test_cases/v0/experiment.yaml +++ b/schemas/test_cases/v0/experiment.yaml @@ -107,9 +107,6 @@ matches: - http://determined.ai/schemas/expconf/v0/experiment.json case: - checkpoint_storage: - type: shared_fs - host_path: /tmp hyperparameters: global_batch_size: 32 searcher: @@ -122,14 +119,7 @@ defaulted: bind_mounts: [] checkpoint_policy: best - checkpoint_storage: - type: shared_fs - host_path: /tmp - storage_path: null - propagation: rprivate - save_experiment_best: 0 - save_trial_best: 1 - save_trial_latest: 1 + checkpoint_storage: null data: {} data_layer: type: shared_fs @@ -197,9 +187,6 @@ case: hyperparameters: global_batch_size: 32 - checkpoint_storage: - type: shared_fs - host_path: /tmp entrypoint: model_def:MyTrial searcher: name: single @@ -279,9 +266,6 @@ categorical_hparam: type: categorical vals: [1, 2, 3, 4] - checkpoint_storage: - type: shared_fs - host_path: /tmp searcher: name: grid metric: loss