Skip to content
This repository has been archived by the owner on Jun 20, 2024. It is now read-only.

Commit

Permalink
Replace mux with gin router
Browse files Browse the repository at this point in the history
  • Loading branch information
norling committed Mar 29, 2023
1 parent a5f3579 commit fa64d9d
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 190 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ There is an README file in the [dev_utils](/dev_utils) folder with sections for
| Component | Role |
|---------------|------|
| middleware | Performs access token verification and validation |
| sda | Constructs the main API endpoints fort the NeIC SDA Data Out API. |
| sda | Constructs the main API endpoints for the NeIC SDA Data Out API. |


## Internal Components
Expand All @@ -43,4 +43,4 @@ There is an README file in the [dev_utils](/dev_utils) folder with sections for
| Component | Role |
|---------------|------|
| auth | Auth pkg is used by the middleware to parse OIDC Details and extract GA4GH Visas from a [GA4GH Passport](https://github.com/ga4gh-duri/ga4gh-duri.github.io/blob/master/researcher_ids/ga4gh_passport_v1.md) |
| request | This pkg Stores a HTTP client, so that it doesn't need to be initialised on every request. |
| request | This pkg Stores a HTTP client, so that it doesn't need to be initialised on every request. |
19 changes: 10 additions & 9 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@ import (
"net/http"
"time"

"github.com/gorilla/mux"
"github.com/gin-gonic/gin"
"github.com/neicnordic/sda-download/api/middleware"
"github.com/neicnordic/sda-download/api/sda"
"github.com/neicnordic/sda-download/internal/config"
log "github.com/sirupsen/logrus"
)

// healthResponse
func healthResponse(w http.ResponseWriter, r *http.Request) {
func healthResponse(c *gin.Context) {
// ok response to health
w.WriteHeader(http.StatusOK)
c.Writer.WriteHeader(http.StatusOK)
}

// Setup configures the web server and registers the routes
func Setup() *http.Server {
// Set up routing
log.Info("(2/5) Registering endpoint handlers")
r := mux.NewRouter().SkipClean(true)

r.Handle("/metadata/datasets", middleware.TokenMiddleware(http.HandlerFunc(sda.Datasets))).Methods("GET")
r.Handle("/metadata/datasets/{dataset:[^\\s/$.?#].[^\\s]+|[A-Za-z0-9-_:.]+}/files", middleware.TokenMiddleware(http.HandlerFunc(sda.Files))).Methods("GET")
r.Handle("/files/{fileid:[A-Za-z0-9-_:.]+}", middleware.TokenMiddleware(http.HandlerFunc(sda.Download))).Methods("GET")
r.HandleFunc("/health", healthResponse).Methods("GET")
router := gin.Default()

router.GET("/metadata/datasets", middleware.TokenMiddleware(), sda.Datasets)
router.GET("/metadata/datasets/:dataset/files", middleware.TokenMiddleware(), sda.Files)
router.GET("/files/:fileid", middleware.TokenMiddleware(), sda.Download)
router.GET("/health", healthResponse)

// Configure TLS settings
log.Info("(3/5) Configuring TLS")
Expand All @@ -45,7 +46,7 @@ func Setup() *http.Server {
log.Info("(4/5) Configuring server")
srv := &http.Server{
Addr: config.Config.App.Host + ":" + fmt.Sprint(config.Config.App.Port),
Handler: r,
Handler: router,
TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
ReadHeaderTimeout: 20 * time.Second,
Expand Down
64 changes: 27 additions & 37 deletions api/middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -1,46 +1,43 @@
package middleware

import (
"context"
"net/http"

"github.com/gin-gonic/gin"
"github.com/neicnordic/sda-download/internal/config"
"github.com/neicnordic/sda-download/internal/session"
"github.com/neicnordic/sda-download/pkg/auth"
log "github.com/sirupsen/logrus"
)

type stringVariable string

// as specified in docs: https://pkg.go.dev/context#WithValue
var datasetsKey = stringVariable("datasets")
var datasetsKey = "datasets"

// TokenMiddleware performs access token verification and validation
// JWTs are verified and validated by the app, opaque tokens are sent to AAI for verification
// Successful auth results in list of authorised datasets
func TokenMiddleware(nextHandler http.Handler) http.Handler {

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func TokenMiddleware() gin.HandlerFunc {

return func(c *gin.Context) {
// Check if dataset permissions are cached to session
sessionCookie, err := r.Cookie(config.Config.Session.Name)
sessionCookie, err := c.Cookie(config.Config.Session.Name)
if err != nil {
log.Debugf("no session cookie received")
}
var datasets []string
var exists bool
if sessionCookie != nil {
if sessionCookie != "" {
log.Debug("session cookie received")
datasets, exists = session.Get(sessionCookie.Value)
datasets, exists = session.Get(sessionCookie)
}

if !exists {
log.Debug("no session found, create new session")

// Check that a token is provided
token, code, err := auth.GetToken(r.Header.Get("Authorization"))
token, code, err := auth.GetToken(c.Request.Header.Get("Authorization"))
if err != nil {
http.Error(w, err.Error(), code)
c.String(code, err.Error())

return
}
Expand All @@ -49,7 +46,7 @@ func TokenMiddleware(nextHandler http.Handler) http.Handler {
visas, err := auth.GetVisas(auth.Details, token)
if err != nil {
log.Debug("failed to validate token at AAI")
http.Error(w, "bad token", 401)
c.String(http.StatusUnauthorized, "bad token")

return
}
Expand All @@ -64,47 +61,40 @@ func TokenMiddleware(nextHandler http.Handler) http.Handler {
// Start a new session and store datasets under the session key
key := session.NewSessionKey()
session.Set(key, datasets)
sessionCookie := &http.Cookie{
Name: config.Config.Session.Name,
Value: key,
Domain: config.Config.Session.Domain,
Secure: config.Config.Session.Secure,
HttpOnly: config.Config.Session.HTTPOnly,
// time.Duration is stored in nanoseconds, but MaxAge wants seconds
MaxAge: int(config.Config.Session.Expiration) / 1e9,
}
http.SetCookie(w, sessionCookie)
c.SetCookie(config.Config.Session.Name, // name
key, // value
int(config.Config.Session.Expiration)/1e9, // max age
"/", // path
config.Config.Session.Domain, // domain
config.Config.Session.Secure, // secure
config.Config.Session.HTTPOnly, // httpOnly
)
log.Debug("authorization check passed")
}

// Store dataset list to request context, for use in the endpoint handlers
modifiedContext := storeDatasets(r.Context(), datasets)
modifiedRequest := r.WithContext(modifiedContext)
storeDatasets(c, datasets)

// Forward request to the next endpoint handler
nextHandler.ServeHTTP(w, modifiedRequest)
})
c.Next()
}

}

// storeDatasets stores the dataset list to the request context
func storeDatasets(ctx context.Context, datasets []string) context.Context {
func storeDatasets(c *gin.Context, datasets []string) *gin.Context {
log.Debugf("storing %v datasets to request context", datasets)

ctx = context.WithValue(ctx, datasetsKey, datasets)
c.Set(datasetsKey, datasets)

return ctx
return c
}

// GetDatasets extracts the dataset list from the request context
var GetDatasets = func(ctx context.Context) []string {
datasets := ctx.Value(datasetsKey)
if datasets == nil {
log.Debug("request datasets context is empty")
var GetDatasets = func(c *gin.Context) []string {
datasets := c.GetStringSlice(datasetsKey)

return []string{}
}
log.Debugf("returning %v from request context", datasets)

return datasets.([]string)
return datasets
}
89 changes: 32 additions & 57 deletions api/middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import (
"reflect"
"testing"

"github.com/gin-gonic/gin"
"github.com/neicnordic/sda-download/internal/config"
"github.com/neicnordic/sda-download/internal/session"
"github.com/neicnordic/sda-download/pkg/auth"
log "github.com/sirupsen/logrus"
)

const token string = "token"

// testEndpoint mimics the endpoint handlers that perform business logic after passing the
// authentication middleware. This handler is generic and can be used for all cases.
func testEndpoint(w http.ResponseWriter, r *http.Request) {}
func testEndpoint(c *gin.Context) {}

func TestTokenMiddleware_Fail_GetToken(t *testing.T) {

Expand All @@ -32,18 +34,19 @@ func TestTokenMiddleware_Fail_GetToken(t *testing.T) {

// Mock request and response holders
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "https://testing.fi", nil)
r := httptest.NewRequest("GET", "/", nil)
_, router := gin.CreateTestContext(w)

// Send a request through the middleware
testHandler := TokenMiddleware(http.HandlerFunc(testEndpoint))
testHandler.ServeHTTP(w, r)
router.GET("/", TokenMiddleware(), testEndpoint)
router.ServeHTTP(w, r)

// Test the outcomes of the handler
response := w.Result()
defer response.Body.Close()
body, _ := io.ReadAll(response.Body)
expectedStatusCode := 401
expectedBody := []byte("access token must be provided\n")
expectedBody := []byte("access token must be provided")

if response.StatusCode != expectedStatusCode {
t.Errorf("TestTokenMiddleware_Fail_GetToken failed, got %d expected %d", response.StatusCode, expectedStatusCode)
Expand Down Expand Up @@ -76,18 +79,19 @@ func TestTokenMiddleware_Fail_GetVisas(t *testing.T) {

// Mock request and response holders
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "https://testing.fi", nil)
r := httptest.NewRequest("GET", "/", nil)
_, router := gin.CreateTestContext(w)

// Send a request through the middleware
testHandler := TokenMiddleware(http.HandlerFunc(testEndpoint))
testHandler.ServeHTTP(w, r)
router.GET("/", TokenMiddleware(), testEndpoint)
router.ServeHTTP(w, r)

// Test the outcomes of the handler
response := w.Result()
defer response.Body.Close()
body, _ := io.ReadAll(response.Body)
expectedStatusCode := 401
expectedBody := []byte("bad token\n")
expectedBody := []byte("bad token")

if response.StatusCode != expectedStatusCode {
t.Errorf("TestTokenMiddleware_Fail_GetVisas failed, got %d expected %d", response.StatusCode, expectedStatusCode)
Expand Down Expand Up @@ -125,11 +129,12 @@ func TestTokenMiddleware_Fail_GetPermissions(t *testing.T) {

// Mock request and response holders
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "https://testing.fi", nil)
r := httptest.NewRequest("GET", "/", nil)
_, router := gin.CreateTestContext(w)

// Send a request through the middleware
testHandler := TokenMiddleware(http.HandlerFunc(testEndpoint))
testHandler.ServeHTTP(w, r)
router.GET("/", TokenMiddleware(), testEndpoint)
router.ServeHTTP(w, r)

// Test the outcomes of the handler
response := w.Result()
Expand Down Expand Up @@ -171,20 +176,21 @@ func TestTokenMiddleware_Success_NoCache(t *testing.T) {

// Mock request and response holders
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "https://testing.fi", nil)
r := httptest.NewRequest("GET", "/", nil)
_, router := gin.CreateTestContext(w)

// Now that we are modifying the request context, we need to place the context test inside the handler
expectedDatasets := []string{"dataset1", "dataset2"}
testEndpointWithContextData := func(w http.ResponseWriter, r *http.Request) {
datasets := r.Context().Value(datasetsKey).([]string)
testEndpointWithContextData := func(c *gin.Context) {
datasets := c.GetStringSlice(datasetsKey)
if !reflect.DeepEqual(datasets, expectedDatasets) {
t.Errorf("TestTokenMiddleware_Success_NoCache failed, got %s expected %s", datasets, expectedDatasets)
}
}

// Send a request through the middleware
testHandler := TokenMiddleware(http.HandlerFunc(testEndpointWithContextData))
testHandler.ServeHTTP(w, r)
router.GET("/", TokenMiddleware(), testEndpointWithContextData)
router.ServeHTTP(w, r)

// Test the outcomes of the handler
response := w.Result()
Expand Down Expand Up @@ -219,31 +225,35 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {

// Substitute mock functions
session.Get = func(key string) ([]string, bool) {
log.Warningf("session.Get %v", key)

return []string{"dataset1", "dataset2"}, true
}

config.Config.Session.Name = "sda_session_key"

// Mock request and response holders
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "https://testing.fi", nil)
r := httptest.NewRequest("GET", "/", nil)
_, router := gin.CreateTestContext(w)

r.AddCookie(&http.Cookie{
Name: "sda_session_key",
Value: "key",
})

// Now that we are modifying the request context, we need to place the context test inside the handler
expectedDatasets := []string{"dataset1", "dataset2"}
testEndpointWithContextData := func(w http.ResponseWriter, r *http.Request) {
datasets := r.Context().Value(datasetsKey).([]string)
testEndpointWithContextData := func(c *gin.Context) {
datasets := c.GetStringSlice(datasetsKey)
if !reflect.DeepEqual(datasets, expectedDatasets) {
t.Errorf("TestTokenMiddleware_Success_FromCache failed, got %s expected %s", datasets, expectedDatasets)
}
}

// Send a request through the middleware
testHandler := TokenMiddleware(http.HandlerFunc(testEndpointWithContextData))
testHandler.ServeHTTP(w, r)
router.GET("/", TokenMiddleware(), testEndpointWithContextData)
router.ServeHTTP(w, r)

// Test the outcomes of the handler
response := w.Result()
Expand All @@ -264,38 +274,3 @@ func TestTokenMiddleware_Success_FromCache(t *testing.T) {
session.Get = originalGetCache

}

func TestStoreDatasets(t *testing.T) {

// Get a request context for testing if data is saved
r := httptest.NewRequest("GET", "https://testing.fi", nil)

// Store data to request context
datasets := []string{"dataset1", "dataset2"}
modifiedContext := storeDatasets(r.Context(), datasets)

// Verify that context has new data
storedDatasets := modifiedContext.Value(datasetsKey).([]string)
if !reflect.DeepEqual(datasets, storedDatasets) {
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
}

}

func TestGetDatasets(t *testing.T) {

// Get a request context for testing if data is saved
r := httptest.NewRequest("GET", "https://testing.fi", nil)

// Store data to request context
datasets := []string{"dataset1", "dataset2"}
modifiedContext := storeDatasets(r.Context(), datasets)
modifiedRequest := r.WithContext(modifiedContext)

// Verify that context has new data
storedDatasets := GetDatasets(modifiedRequest.Context())
if !reflect.DeepEqual(datasets, storedDatasets) {
t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets)
}

}
Loading

0 comments on commit fa64d9d

Please sign in to comment.