From a029b705cd412dc0b171b5a8be3a3f21a6868090 Mon Sep 17 00:00:00 2001 From: Nikola Jokic Date: Thu, 21 Dec 2023 15:35:36 +0100 Subject: [PATCH] Fix proxy issue in new listener client (#3181) --- cmd/ghalistener/config/config.go | 13 +- cmd/ghalistener/config/config_client_test.go | 161 +++++++++++++++++++ 2 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 cmd/ghalistener/config/config_client_test.go diff --git a/cmd/ghalistener/config/config.go b/cmd/ghalistener/config/config.go index 127a57e755..d27d6af994 100644 --- a/cmd/ghalistener/config/config.go +++ b/cmd/ghalistener/config/config.go @@ -4,6 +4,8 @@ import ( "crypto/x509" "encoding/json" "fmt" + "net/http" + "net/url" "os" "github.com/actions/actions-runner-controller/build" @@ -101,7 +103,7 @@ func (c *Config) Logger() (logr.Logger, error) { return logger, nil } -func (c *Config) ActionsClient(logger logr.Logger) (*actions.Client, error) { +func (c *Config) ActionsClient(logger logr.Logger, clientOptions ...actions.ClientOption) (*actions.Client, error) { var creds actions.ActionsAuth switch c.Token { case "": @@ -114,9 +116,9 @@ func (c *Config) ActionsClient(logger logr.Logger) (*actions.Client, error) { creds.Token = c.Token } - options := []actions.ClientOption{ + options := append([]actions.ClientOption{ actions.WithLogger(logger), - } + }, clientOptions...) if c.ServerRootCA != "" { systemPool, err := x509.SystemCertPool() @@ -132,6 +134,11 @@ func (c *Config) ActionsClient(logger logr.Logger) (*actions.Client, error) { options = append(options, actions.WithRootCAs(pool)) } + proxyFunc := httpproxy.FromEnvironment().ProxyFunc() + options = append(options, actions.WithProxy(func(req *http.Request) (*url.URL, error) { + return proxyFunc(req.URL) + })) + client, err := actions.NewClient(c.ConfigureUrl, &creds, options...) if err != nil { return nil, fmt.Errorf("failed to create actions client: %w", err) diff --git a/cmd/ghalistener/config/config_client_test.go b/cmd/ghalistener/config/config_client_test.go new file mode 100644 index 0000000000..29a10b181b --- /dev/null +++ b/cmd/ghalistener/config/config_client_test.go @@ -0,0 +1,161 @@ +package config_test + +import ( + "context" + "crypto/tls" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/actions/actions-runner-controller/cmd/ghalistener/config" + "github.com/actions/actions-runner-controller/github/actions" + "github.com/actions/actions-runner-controller/github/actions/testserver" + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCustomerServerRootCA(t *testing.T) { + ctx := context.Background() + certsFolder := filepath.Join( + "../../../", + "github", + "actions", + "testdata", + ) + certPath := filepath.Join(certsFolder, "server.crt") + keyPath := filepath.Join(certsFolder, "server.key") + + serverCalledSuccessfully := false + + server := testserver.NewUnstarted(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + serverCalledSuccessfully = true + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"count": 0}`)) + })) + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + server.StartTLS() + + var certsString string + rootCA, err := os.ReadFile(filepath.Join(certsFolder, "rootCA.crt")) + require.NoError(t, err) + certsString = string(rootCA) + + intermediate, err := os.ReadFile(filepath.Join(certsFolder, "intermediate.pem")) + require.NoError(t, err) + certsString = certsString + string(intermediate) + + config := config.Config{ + ConfigureUrl: server.ConfigURLForOrg("myorg"), + ServerRootCA: certsString, + Token: "token", + } + + client, err := config.ActionsClient(logr.Discard()) + require.NoError(t, err) + _, err = client.GetRunnerScaleSet(ctx, 1, "test") + require.NoError(t, err) + assert.True(t, serverCalledSuccessfully) +} + +func TestProxySettings(t *testing.T) { + t.Run("http", func(t *testing.T) { + wentThroughProxy := false + + proxy := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + wentThroughProxy = true + })) + t.Cleanup(func() { + proxy.Close() + }) + + prevProxy := os.Getenv("http_proxy") + os.Setenv("http_proxy", proxy.URL) + defer os.Setenv("http_proxy", prevProxy) + + config := config.Config{ + ConfigureUrl: "https://github.com/org/repo", + Token: "token", + } + + client, err := config.ActionsClient(logr.Discard()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + require.NoError(t, err) + _, err = client.Do(req) + require.NoError(t, err) + + assert.True(t, wentThroughProxy) + }) + + t.Run("https", func(t *testing.T) { + wentThroughProxy := false + + proxy := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + wentThroughProxy = true + })) + t.Cleanup(func() { + proxy.Close() + }) + + prevProxy := os.Getenv("https_proxy") + os.Setenv("https_proxy", proxy.URL) + defer os.Setenv("https_proxy", prevProxy) + + config := config.Config{ + ConfigureUrl: "https://github.com/org/repo", + Token: "token", + } + + client, err := config.ActionsClient(logr.Discard(), actions.WithRetryMax(0)) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + + _, err = client.Do(req) + // proxy doesn't support https + assert.Error(t, err) + assert.True(t, wentThroughProxy) + }) + + t.Run("no_proxy", func(t *testing.T) { + wentThroughProxy := false + + proxy := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + wentThroughProxy = true + })) + t.Cleanup(func() { + proxy.Close() + }) + + prevProxy := os.Getenv("http_proxy") + os.Setenv("http_proxy", proxy.URL) + defer os.Setenv("http_proxy", prevProxy) + + prevNoProxy := os.Getenv("no_proxy") + os.Setenv("no_proxy", "example.com") + defer os.Setenv("no_proxy", prevNoProxy) + + config := config.Config{ + ConfigureUrl: "https://github.com/org/repo", + Token: "token", + } + + client, err := config.ActionsClient(logr.Discard()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + require.NoError(t, err) + + _, err = client.Do(req) + require.NoError(t, err) + assert.False(t, wentThroughProxy) + }) +}