diff --git a/backend/pkg/middleware/auth.go b/backend/pkg/middleware/auth.go index 71ae2cc70..19b9f2f73 100644 --- a/backend/pkg/middleware/auth.go +++ b/backend/pkg/middleware/auth.go @@ -11,14 +11,14 @@ func NewAuthSkipper(auth string) middleware.Skipper { return func(c echo.Context) bool { switch auth { case "oidc": - paths := []string{"/health", "/login", "/config", "/*", "/login/cb", "/login/token", "/v1/update"} + paths := []string{"/health", "/login", "/config", "/*", "/flatcar/*", "/login/cb", "/login/token", "/v1/update"} for _, path := range paths { if c.Path() == path { return true } } case "github": - paths := []string{"/health", "/v1/update", "/login/cb", "/login/webhook"} + paths := []string{"/health", "/v1/update", "/login/cb", "/login/webhook", "/flatcar/*"} for _, path := range paths { if c.Path() == path { return true diff --git a/backend/pkg/server/server.go b/backend/pkg/server/server.go index 7795e0eac..bb7decd52 100644 --- a/backend/pkg/server/server.go +++ b/backend/pkg/server/server.go @@ -36,8 +36,7 @@ var ( logger = util.NewLogger("nebraska") middlewareSkipper = func(c echo.Context) bool { requestPath := c.Path() - - paths := []string{"/health", "/metrics", "/config", "/v1/update", "/*"} + paths := []string{"/health", "/metrics", "/config", "/v1/update", "/flatcar/*", "/*"} for _, path := range paths { if requestPath == path { return true @@ -108,12 +107,18 @@ func New(conf *config.Config, db *db.API) (*echo.Echo, error) { e.Static("/", conf.HTTPStaticDir) + if conf.HostFlatcarPackages && conf.FlatcarPackagesPath != "" { + e.Static("/flatcar/", conf.FlatcarPackagesPath) + } + e.HTTPErrorHandler = func(err error, c echo.Context) { code := http.StatusNotFound if he, ok := err.(*echo.HTTPError); ok { if code == he.Code { fileErr := c.File(path.Join(conf.HTTPStaticDir, "index.html")) - logger.Err(fileErr).Msg("Error serving index.html") + if fileErr != nil { + logger.Err(fileErr).Msg("Error serving index.html") + } return } } diff --git a/backend/test/api/flatcar_package_test.go b/backend/test/api/flatcar_package_test.go new file mode 100644 index 000000000..3bd21bd6e --- /dev/null +++ b/backend/test/api/flatcar_package_test.go @@ -0,0 +1,97 @@ +package api_test + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "os" + "path" + "testing" + + "github.com/google/uuid" + "github.com/kinvolk/nebraska/backend/pkg/config" + "github.com/kinvolk/nebraska/backend/pkg/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHostFlatcarPackage(t *testing.T) { + currentDir, err := os.Getwd() + require.NoError(t, err) + + serverPort := uint(6000) + serverPortStr := fmt.Sprintf(":%d", serverPort) + + conf := &config.Config{ + HostFlatcarPackages: true, + FlatcarPackagesPath: currentDir, + AuthMode: "noop", + ServerPort: serverPort, + } + + db := newDBForTest(t) + + t.Run("file_exists", func(t *testing.T) { + server, err := server.New(conf, db) + require.NotNil(t, server) + require.NoError(t, err) + + //nolint:errcheck + go server.Start(serverPortStr) + + //nolint:errcheck + defer server.Shutdown(context.Background()) + + // create a temp file + fileName := fmt.Sprintf("%s.txt", uuid.NewString()) + file, err := os.Create(path.Join(currentDir, fileName)) + require.NoError(t, err) + + fileString := "This is a test" + _, err = file.WriteString(fileString) + require.NoError(t, err) + err = file.Close() + require.NoError(t, err) + + _, err = waitServerReady(fmt.Sprintf("http://localhost:%d", serverPort)) + require.NoError(t, err) + + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/flatcar/%s", serverPort, fileName)) + assert.NoError(t, err) + assert.NotNil(t, resp) + + bodyBytes, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, fileString, string(bodyBytes)) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // delete the temp file + err = os.Remove(path.Join(currentDir, fileName)) + require.NoError(t, err) + }) + + t.Run("file_not_exists", func(t *testing.T) { + server, err := server.New(conf, db) + require.NotNil(t, server) + require.NoError(t, err) + + fileName := fmt.Sprintf("%s.txt", uuid.NewString()) + + //nolint:errcheck + go server.Start(serverPortStr) + + //nolint:errcheck + defer server.Shutdown(context.Background()) + + _, err = waitServerReady(fmt.Sprintf("http://localhost:%d", serverPort)) + require.NoError(t, err) + + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/flatcar/%s", serverPort, fileName)) + assert.NoError(t, err) + assert.NotNil(t, resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) +} diff --git a/backend/test/api/helper_test.go b/backend/test/api/helper_test.go index 81e79b6a4..23579adc8 100644 --- a/backend/test/api/helper_test.go +++ b/backend/test/api/helper_test.go @@ -3,6 +3,7 @@ package api_test import ( "encoding/json" "encoding/xml" + "errors" "fmt" "io" "io/ioutil" @@ -107,3 +108,33 @@ func httpDo(t *testing.T, url string, method string, payload io.Reader, statusco require.NoError(t, err) } } + +var ErrOutOfRetries = errors.New("test: out of retries") + +func waitServerReady(serverURL string) (bool, error) { + retries := 5 + for i := 0; i < retries; i++ { + if i != 0 { + time.Sleep(100 * time.Millisecond) + } + req, err := http.NewRequest("GET", fmt.Sprintf("%s/health", serverURL), nil) + if err != nil { + continue + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + continue + } + + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + continue + } + + if (http.StatusOK == resp.StatusCode) && ("OK" == string(bodyBytes)) { + return true, nil + } + } + return false, ErrOutOfRetries +}