From 8123f6d249e35fe52b74e79b64f95878f59a6f13 Mon Sep 17 00:00:00 2001 From: Nicolas De Loof Date: Mon, 30 Oct 2023 17:04:15 +0100 Subject: [PATCH] detect include cycle using compose-file stored in context Signed-off-by: Nicolas De Loof --- loader/extends.go | 2 +- loader/include.go | 23 ++++++++++++++++------- loader/loader.go | 14 ++++++++------ loader/loader_yaml_test.go | 2 +- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/loader/extends.go b/loader/extends.go index 6fa611a1..15edb5e0 100644 --- a/loader/extends.go +++ b/loader/extends.go @@ -71,7 +71,7 @@ func ApplyExtends(ctx context.Context, dict map[string]any, workingdir string, o ConfigFiles: []types.ConfigFile{ {Filename: local}, }, - }, extendsOpts, ct, nil) + }, extendsOpts, ct) if err != nil { return err } diff --git a/loader/include.go b/loader/include.go index 037a4022..35674250 100644 --- a/loader/include.go +++ b/loader/include.go @@ -23,6 +23,7 @@ import ( "reflect" "strings" + "github.com/compose-spec/compose-go/v2/consts" "github.com/compose-spec/compose-go/v2/dotenv" interp "github.com/compose-spec/compose-go/v2/interpolation" "github.com/compose-spec/compose-go/v2/types" @@ -39,7 +40,7 @@ func loadIncludeConfig(source any) ([]types.IncludeConfig, error) { return requires, err } -func ApplyInclude(ctx context.Context, configDetails types.ConfigDetails, model map[string]any, options *Options, included []string) error { +func ApplyInclude(ctx context.Context, configDetails types.ConfigDetails, model map[string]any, options *Options) error { includeConfig, err := loadIncludeConfig(model["include"]) if err != nil { return err @@ -60,11 +61,9 @@ func ApplyInclude(ctx context.Context, configDetails types.ConfigDetails, model } mainFile := r.Path[0] - for _, f := range included { - if f == mainFile { - included = append(included, mainFile) - return errors.Errorf("include cycle detected:\n%s\n include %s", included[0], strings.Join(included[1:], "\n include ")) - } + err := checkIncludeCycle(ctx, mainFile) + if err != nil { + return err } if r.ProjectDirectory == "" { @@ -91,7 +90,7 @@ func ApplyInclude(ctx context.Context, configDetails types.ConfigDetails, model LookupValue: config.LookupEnv, TypeCastMapping: options.Interpolate.TypeCastMapping, } - imported, err := loadYamlModel(ctx, config, loadOptions, &cycleTracker{}, included) + imported, err := loadYamlModel(ctx, config, loadOptions, &cycleTracker{}) if err != nil { return err } @@ -104,6 +103,16 @@ func ApplyInclude(ctx context.Context, configDetails types.ConfigDetails, model return nil } +func checkIncludeCycle(ctx context.Context, mainFile string) error { + files, _ := ctx.Value(consts.ComposeFileKey{}).([]string) + for _, f := range files { + if f == mainFile { + return errors.Errorf("include cycle detected:\n%s\n include %s", strings.Join(files, "\n include "), mainFile) + } + } + return nil +} + // importResources import into model all resources defined by imported, and report error on conflict func importResources(source map[string]any, target map[string]any) error { if err := importResource(source, target, "services"); err != nil { diff --git a/loader/loader.go b/loader/loader.go index e15d4377..da7acf60 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -274,13 +274,16 @@ func LoadWithContext(ctx context.Context, configDetails types.ConfigDetails, opt return load(ctx, configDetails, opts, nil) } -func loadYamlModel(ctx context.Context, config types.ConfigDetails, opts *Options, ct *cycleTracker, included []string) (map[string]interface{}, error) { +func loadYamlModel(ctx context.Context, config types.ConfigDetails, opts *Options, ct *cycleTracker) (map[string]interface{}, error) { var ( dict = map[string]interface{}{} err error ) + + f, _ := ctx.Value(consts.ComposeFileKey{}).([]string) + ctx = context.WithValue(ctx, consts.ComposeFileKey{}, append(f, config.ConfigFiles[0].Filename)) + for _, file := range config.ConfigFiles { - fctx := context.WithValue(ctx, consts.ComposeFileKey{}, file.Filename) if len(file.Content) == 0 && file.Config == nil { content, err := os.ReadFile(file.Filename) if err != nil { @@ -315,7 +318,7 @@ func loadYamlModel(ctx context.Context, config types.ConfigDetails, opts *Option } if !opts.SkipExtends { - err = ApplyExtends(fctx, cfg, config.WorkingDir, opts, ct, processors...) + err = ApplyExtends(ctx, cfg, config.WorkingDir, opts, ct, processors...) if err != nil { return err } @@ -364,8 +367,7 @@ func loadYamlModel(ctx context.Context, config types.ConfigDetails, opts *Option } if !opts.SkipInclude { - included = append(included, config.ConfigFiles[0].Filename) - err = ApplyInclude(ctx, config, dict, opts, included) + err = ApplyInclude(ctx, config, dict, opts) if err != nil { return nil, err } @@ -400,7 +402,7 @@ func load(ctx context.Context, configDetails types.ConfigDetails, opts *Options, includeRefs := make(map[string][]types.IncludeConfig) - dict, err := loadYamlModel(ctx, configDetails, opts, &cycleTracker{}, nil) + dict, err := loadYamlModel(ctx, configDetails, opts, &cycleTracker{}) if err != nil { return nil, err } diff --git a/loader/loader_yaml_test.go b/loader/loader_yaml_test.go index c02129bb..f33e99b4 100644 --- a/loader/loader_yaml_test.go +++ b/loader/loader_yaml_test.go @@ -43,7 +43,7 @@ services: image: bar command: echo world init: false -`)}}}, &Options{}, &cycleTracker{}, nil) +`)}}}, &Options{}, &cycleTracker{}) assert.NilError(t, err) assert.DeepEqual(t, model, map[string]interface{}{ "services": map[string]interface{}{