From e68a88e14d7fdcafb6d240311b164155c2c5d3c9 Mon Sep 17 00:00:00 2001 From: Serge Smertin <259697+nfx@users.noreply.github.com> Date: Wed, 8 Nov 2023 15:50:20 +0100 Subject: [PATCH] Added `env.UserHomeDir(ctx)` for parallel-friendly tests (#955) ## Changes `os.Getenv(..)` is not friendly with `libs/env`. This PR makes the relevant changes to places where we need to read user home directory. ## Tests Mainly done in https://github.com/databricks/cli/pull/914 --- cmd/auth/env.go | 7 +-- cmd/auth/login.go | 2 +- cmd/auth/profiles.go | 2 +- cmd/root/auth.go | 30 ++++------- libs/databrickscfg/profiles.go | 43 +++++++++------- libs/databrickscfg/profiles_test.go | 51 ++++++++++++++----- .../testdata/sample-home/.databrickscfg | 7 +++ libs/env/context.go | 21 ++++++++ libs/env/context_test.go | 7 +++ libs/env/loader.go | 50 ++++++++++++++++++ libs/env/loader_test.go | 26 ++++++++++ 11 files changed, 190 insertions(+), 56 deletions(-) create mode 100644 libs/databrickscfg/testdata/sample-home/.databrickscfg create mode 100644 libs/env/loader.go create mode 100644 libs/env/loader_test.go diff --git a/cmd/auth/env.go b/cmd/auth/env.go index 241d5f880f..04aef36a88 100644 --- a/cmd/auth/env.go +++ b/cmd/auth/env.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "errors" "fmt" @@ -68,8 +69,8 @@ func resolveSection(cfg *config.Config, iniFile *config.File) (*ini.Section, err return candidates[0], nil } -func loadFromDatabricksCfg(cfg *config.Config) error { - iniFile, err := databrickscfg.Get() +func loadFromDatabricksCfg(ctx context.Context, cfg *config.Config) error { + iniFile, err := databrickscfg.Get(ctx) if errors.Is(err, fs.ErrNotExist) { // it's fine not to have ~/.databrickscfg return nil @@ -110,7 +111,7 @@ func newEnvCommand() *cobra.Command { cfg.Profile = profile } else if cfg.Host == "" { cfg.Profile = "DEFAULT" - } else if err := loadFromDatabricksCfg(cfg); err != nil { + } else if err := loadFromDatabricksCfg(cmd.Context(), cfg); err != nil { return err } // Go SDK is lazy loaded because of Terraform semantics, diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 3a3f3a6dcf..c2b821b68c 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -128,7 +128,7 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error { // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. - _, profiles, err := databrickscfg.LoadProfiles(func(p databrickscfg.Profile) bool { + _, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool { return p.Name == profileName }) if err != nil { diff --git a/cmd/auth/profiles.go b/cmd/auth/profiles.go index 97d8eeabcc..51ae9b1850 100644 --- a/cmd/auth/profiles.go +++ b/cmd/auth/profiles.go @@ -95,7 +95,7 @@ func newProfilesCommand() *cobra.Command { cmd.RunE = func(cmd *cobra.Command, args []string) error { var profiles []*profileMetadata - iniFile, err := databrickscfg.Get() + iniFile, err := databrickscfg.Get(cmd.Context()) if os.IsNotExist(err) { // return empty list for non-configured machines iniFile = &config.File{ diff --git a/cmd/root/auth.go b/cmd/root/auth.go index 81c7147923..350cbc65a7 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/libs/cmdio" @@ -55,7 +54,7 @@ func accountClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt } // Try picking a profile dynamically if the current configuration is not valid. - profile, err := askForAccountProfile(ctx) + profile, err := AskForAccountProfile(ctx) if err != nil { return nil, err } @@ -83,7 +82,7 @@ func MustAccountClient(cmd *cobra.Command, args []string) error { // 1. only admins will have account configured // 2. 99% of admins will have access to just one account // hence, we don't need to create a special "DEFAULT_ACCOUNT" profile yet - _, profiles, err := databrickscfg.LoadProfiles(databrickscfg.MatchAccountProfiles) + _, profiles, err := databrickscfg.LoadProfiles(cmd.Context(), databrickscfg.MatchAccountProfiles) if err != nil { return err } @@ -123,7 +122,7 @@ func workspaceClientOrPrompt(ctx context.Context, cfg *config.Config, allowPromp } // Try picking a profile dynamically if the current configuration is not valid. - profile, err := askForWorkspaceProfile(ctx) + profile, err := AskForWorkspaceProfile(ctx) if err != nil { return nil, err } @@ -173,21 +172,14 @@ func SetWorkspaceClient(ctx context.Context, w *databricks.WorkspaceClient) cont return context.WithValue(ctx, &workspaceClient, w) } -func transformLoadError(path string, err error) error { - if os.IsNotExist(err) { - return fmt.Errorf("no configuration file found at %s; please create one first", path) - } - return err -} - -func askForWorkspaceProfile(ctx context.Context) (string, error) { - path, err := databrickscfg.GetPath() +func AskForWorkspaceProfile(ctx context.Context) (string, error) { + path, err := databrickscfg.GetPath(ctx) if err != nil { return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) } - file, profiles, err := databrickscfg.LoadProfiles(databrickscfg.MatchWorkspaceProfiles) + file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchWorkspaceProfiles) if err != nil { - return "", transformLoadError(path, err) + return "", err } switch len(profiles) { case 0: @@ -213,14 +205,14 @@ func askForWorkspaceProfile(ctx context.Context) (string, error) { return profiles[i].Name, nil } -func askForAccountProfile(ctx context.Context) (string, error) { - path, err := databrickscfg.GetPath() +func AskForAccountProfile(ctx context.Context) (string, error) { + path, err := databrickscfg.GetPath(ctx) if err != nil { return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) } - file, profiles, err := databrickscfg.LoadProfiles(databrickscfg.MatchAccountProfiles) + file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchAccountProfiles) if err != nil { - return "", transformLoadError(path, err) + return "", err } switch len(profiles) { case 0: diff --git a/libs/databrickscfg/profiles.go b/libs/databrickscfg/profiles.go index 864000d034..9f31eff626 100644 --- a/libs/databrickscfg/profiles.go +++ b/libs/databrickscfg/profiles.go @@ -1,11 +1,14 @@ package databrickscfg import ( + "context" + "errors" "fmt" - "os" + "io/fs" "path/filepath" "strings" + "github.com/databricks/cli/libs/env" "github.com/databricks/databricks-sdk-go/config" "github.com/spf13/cobra" ) @@ -67,43 +70,45 @@ func MatchAllProfiles(p Profile) bool { } // Get the path to the .databrickscfg file, falling back to the default in the current user's home directory. -func GetPath() (string, error) { - configFile := os.Getenv("DATABRICKS_CONFIG_FILE") +func GetPath(ctx context.Context) (string, error) { + configFile := env.Get(ctx, "DATABRICKS_CONFIG_FILE") if configFile == "" { configFile = "~/.databrickscfg" } if strings.HasPrefix(configFile, "~") { - homedir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("cannot find homedir: %w", err) - } + homedir := env.UserHomeDir(ctx) configFile = filepath.Join(homedir, configFile[1:]) } return configFile, nil } -func Get() (*config.File, error) { - configFile, err := GetPath() +var ErrNoConfiguration = errors.New("no configuration file found") + +func Get(ctx context.Context) (*config.File, error) { + path, err := GetPath(ctx) if err != nil { return nil, fmt.Errorf("cannot determine Databricks config file path: %w", err) } - return config.LoadFile(configFile) + configFile, err := config.LoadFile(path) + if errors.Is(err, fs.ErrNotExist) { + // downstreams depend on ErrNoConfiguration. TODO: expose this error through SDK + return nil, fmt.Errorf("%w at %s; please create one first", ErrNoConfiguration, path) + } else if err != nil { + return nil, err + } + return configFile, nil } -func LoadProfiles(fn ProfileMatchFunction) (file string, profiles Profiles, err error) { - f, err := Get() +func LoadProfiles(ctx context.Context, fn ProfileMatchFunction) (file string, profiles Profiles, err error) { + f, err := Get(ctx) if err != nil { return "", nil, fmt.Errorf("cannot load Databricks config file: %w", err) } - homedir, err := os.UserHomeDir() - if err != nil { - return - } - // Replace homedir with ~ if applicable. // This is to make the output more readable. - file = f.Path() + file = filepath.Clean(f.Path()) + homedir := filepath.Clean(env.UserHomeDir(ctx)) if strings.HasPrefix(file, homedir) { file = "~" + file[len(homedir):] } @@ -130,7 +135,7 @@ func LoadProfiles(fn ProfileMatchFunction) (file string, profiles Profiles, err } func ProfileCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - _, profiles, err := LoadProfiles(MatchAllProfiles) + _, profiles, err := LoadProfiles(cmd.Context(), MatchAllProfiles) if err != nil { return nil, cobra.ShellCompDirectiveError } diff --git a/libs/databrickscfg/profiles_test.go b/libs/databrickscfg/profiles_test.go index b1acdce92c..33a5c9dfd1 100644 --- a/libs/databrickscfg/profiles_test.go +++ b/libs/databrickscfg/profiles_test.go @@ -1,9 +1,11 @@ package databrickscfg import ( - "runtime" + "context" + "path/filepath" "testing" + "github.com/databricks/cli/libs/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -27,27 +29,50 @@ func TestProfilesSearchCaseInsensitive(t *testing.T) { } func TestLoadProfilesReturnsHomedirAsTilde(t *testing.T) { - if runtime.GOOS == "windows" { - t.Setenv("USERPROFILE", "./testdata") - } else { - t.Setenv("HOME", "./testdata") - } - t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") - file, _, err := LoadProfiles(func(p Profile) bool { return true }) + ctx := context.Background() + ctx = env.WithUserHomeDir(ctx, "testdata") + ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") + file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true }) require.NoError(t, err) - assert.Equal(t, "~/databrickscfg", file) + require.Equal(t, filepath.Clean("~/databrickscfg"), file) +} + +func TestLoadProfilesReturnsHomedirAsTildeExoticFile(t *testing.T) { + ctx := context.Background() + ctx = env.WithUserHomeDir(ctx, "testdata") + ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "~/databrickscfg") + file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true }) + require.NoError(t, err) + require.Equal(t, filepath.Clean("~/databrickscfg"), file) +} + +func TestLoadProfilesReturnsHomedirAsTildeDefaultFile(t *testing.T) { + ctx := context.Background() + ctx = env.WithUserHomeDir(ctx, "testdata/sample-home") + file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true }) + require.NoError(t, err) + require.Equal(t, filepath.Clean("~/.databrickscfg"), file) +} + +func TestLoadProfilesNoConfiguration(t *testing.T) { + ctx := context.Background() + ctx = env.WithUserHomeDir(ctx, "testdata") + _, _, err := LoadProfiles(ctx, func(p Profile) bool { return true }) + require.ErrorIs(t, err, ErrNoConfiguration) } func TestLoadProfilesMatchWorkspace(t *testing.T) { - t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") - _, profiles, err := LoadProfiles(MatchWorkspaceProfiles) + ctx := context.Background() + ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") + _, profiles, err := LoadProfiles(ctx, MatchWorkspaceProfiles) require.NoError(t, err) assert.Equal(t, []string{"DEFAULT", "query", "foo1", "foo2"}, profiles.Names()) } func TestLoadProfilesMatchAccount(t *testing.T) { - t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") - _, profiles, err := LoadProfiles(MatchAccountProfiles) + ctx := context.Background() + ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") + _, profiles, err := LoadProfiles(ctx, MatchAccountProfiles) require.NoError(t, err) assert.Equal(t, []string{"acc"}, profiles.Names()) } diff --git a/libs/databrickscfg/testdata/sample-home/.databrickscfg b/libs/databrickscfg/testdata/sample-home/.databrickscfg new file mode 100644 index 0000000000..96c8b7ca16 --- /dev/null +++ b/libs/databrickscfg/testdata/sample-home/.databrickscfg @@ -0,0 +1,7 @@ +[DEFAULT] +host = https://default +token = default + +[acc] +host = https://accounts.cloud.databricks.com +account_id = abc diff --git a/libs/env/context.go b/libs/env/context.go index bbe294d7b4..84518ad789 100644 --- a/libs/env/context.go +++ b/libs/env/context.go @@ -2,7 +2,9 @@ package env import ( "context" + "fmt" "os" + "runtime" "strings" ) @@ -63,6 +65,25 @@ func Set(ctx context.Context, key, value string) context.Context { return setMap(ctx, m) } +func homeEnvVar() string { + if runtime.GOOS == "windows" { + return "USERPROFILE" + } + return "HOME" +} + +func WithUserHomeDir(ctx context.Context, value string) context.Context { + return Set(ctx, homeEnvVar(), value) +} + +func UserHomeDir(ctx context.Context) string { + home := Get(ctx, homeEnvVar()) + if home == "" { + panic(fmt.Errorf("$HOME is not set")) + } + return home +} + // All returns environment variables that are defined in both os.Environ // and this package. `env.Set(ctx, x, y)` will override x from os.Environ. func All(ctx context.Context) map[string]string { diff --git a/libs/env/context_test.go b/libs/env/context_test.go index 39553448cc..5befe4acff 100644 --- a/libs/env/context_test.go +++ b/libs/env/context_test.go @@ -47,3 +47,10 @@ func TestContext(t *testing.T) { assert.Equal(t, "x=y", all["BAR"]) assert.NotEmpty(t, all["PATH"]) } + +func TestHome(t *testing.T) { + ctx := context.Background() + ctx = WithUserHomeDir(ctx, "...") + home := UserHomeDir(ctx) + assert.Equal(t, "...", home) +} diff --git a/libs/env/loader.go b/libs/env/loader.go new file mode 100644 index 0000000000..f441ffa15e --- /dev/null +++ b/libs/env/loader.go @@ -0,0 +1,50 @@ +package env + +import ( + "context" + + "github.com/databricks/databricks-sdk-go/config" +) + +// NewConfigLoader creates Databricks SDK Config loader that is aware of env.Set variables: +// +// ctx = env.Set(ctx, "DATABRICKS_WAREHOUSE_ID", "...") +// +// Usage: +// +// &config.Config{ +// Loaders: []config.Loader{ +// env.NewConfigLoader(ctx), +// config.ConfigAttributes, +// config.ConfigFile, +// }, +// } +func NewConfigLoader(ctx context.Context) *configLoader { + return &configLoader{ + ctx: ctx, + } +} + +type configLoader struct { + ctx context.Context +} + +func (le *configLoader) Name() string { + return "cli-env" +} + +func (le *configLoader) Configure(cfg *config.Config) error { + for _, a := range config.ConfigAttributes { + if !a.IsZero(cfg) { + continue + } + for _, k := range a.EnvVars { + v := Get(le.ctx, k) + if v == "" { + continue + } + a.Set(cfg, v) + } + } + return nil +} diff --git a/libs/env/loader_test.go b/libs/env/loader_test.go new file mode 100644 index 0000000000..2d1fa4002e --- /dev/null +++ b/libs/env/loader_test.go @@ -0,0 +1,26 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" +) + +func TestLoader(t *testing.T) { + ctx := context.Background() + ctx = Set(ctx, "DATABRICKS_WAREHOUSE_ID", "...") + ctx = Set(ctx, "DATABRICKS_CONFIG_PROFILE", "...") + loader := NewConfigLoader(ctx) + + cfg := &config.Config{ + Profile: "abc", + } + err := loader.Configure(cfg) + assert.NoError(t, err) + + assert.Equal(t, "...", cfg.WarehouseID) + assert.Equal(t, "abc", cfg.Profile) + assert.Equal(t, "cli-env", loader.Name()) +}