Skip to content

Commit

Permalink
refactor(middleware/auth): 添加 Auth.GetInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Mar 22, 2024
1 parent 2919560 commit 9bcdcaa
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 105 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ _testmain.go
.vscode/
.idea/
.zed/
.vs/
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/issue9/assert/v4 v4.1.1
github.com/issue9/cache v0.10.0
github.com/issue9/mux/v7 v7.4.2
github.com/issue9/rands/v3 v3.0.1
github.com/issue9/unique/v2 v2.1.0
github.com/issue9/web v0.88.2
Expand All @@ -24,7 +25,6 @@ require (
github.com/issue9/errwrap v0.3.2 // indirect
github.com/issue9/localeutil v0.26.5 // indirect
github.com/issue9/logs/v7 v7.5.1 // indirect
github.com/issue9/mux/v7 v7.4.2 // indirect
github.com/issue9/query/v3 v3.1.3 // indirect
github.com/issue9/scheduled v0.19.3 // indirect
github.com/issue9/sliceutil v0.16.0 // indirect
Expand Down
27 changes: 27 additions & 0 deletions internal/mauth/mauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: 2024 caixw
//
// SPDX-License-Identifier: MIT

// Package mauth middlewares/auth 的私有函数
package mauth

import "github.com/issue9/web"

type keyType int

const valueKey keyType = 1

const AuthorizationHeader = "Authorization"

// Set 更新 [web.Context] 保存的值
func Set[T any](ctx *web.Context, val T) { ctx.SetVar(valueKey, val) }

// Get 获取当前对话关联的信息
func Get[T any](ctx *web.Context) (val T, found bool) {
if v, found := ctx.GetVar(valueKey); found {
return v.(T), true
}

var zero T
return zero, false
}
30 changes: 30 additions & 0 deletions internal/mauth/mauth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// SPDX-FileCopyrightText: 2024 caixw
//
// SPDX-License-Identifier: MIT

package mauth

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/issue9/assert/v4"
"github.com/issue9/mux/v7/types"
"github.com/issue9/web/server"
)

func TestGetSet(t *testing.T) {
a := assert.New(t, false)

s, err := server.New("test", "1.0.0", nil)
a.NotError(err).NotNil(s)

ctx := s.NewContext(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/path", nil), types.NewContext())
val, found := Get[int](ctx)
a.False(found).Zero(val)

Set(ctx, 5)
val, found = Get[int](ctx)
a.True(found).Equal(val, 5)
}
28 changes: 26 additions & 2 deletions middlewares/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,36 @@
// Package auth 登录凭证的验证
package auth

import "github.com/issue9/web"
import (
"strings"

"github.com/issue9/web"
)

// Auth 登录凭证的验证接口
type Auth interface {
type Auth[T any] interface {
web.Middleware

// Logout 退出
Logout(*web.Context) error

// GetInfo 获取用户数据
//
// 当验证通过之后,验证接口同时会将用户信息写入到 [web.Context]
// 可通过当前方法获取写入的数据。
GetInfo(*web.Context) (T, bool)
}

// GetToken 获取客户端提交的 token
//
// header 表示报头的名称;
// prefix 表示报头内容的前缀;
func GetToken(ctx *web.Context, prefix, header string) string {
prefixLen := len(prefix)
h := ctx.Request().Header.Get(header)
if len(h) > prefixLen && strings.ToLower(h[:prefixLen]) == prefix {
return h[prefixLen:]
}
ctx.Logs().DEBUG().LocaleString(web.Phrase("the client %s header %s is invalid format", header, h))
return ""
}
28 changes: 8 additions & 20 deletions middlewares/auth/basic/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ package basic
import (
"bytes"
"encoding/base64"
"strings"

"github.com/issue9/web"
)

type keyType int

const valueKey keyType = 1
"github.com/issue9/webuse/v7/internal/mauth"
"github.com/issue9/webuse/v7/middlewares/auth"
)

const prefix = "basic "

Expand Down Expand Up @@ -51,12 +49,12 @@ type basic[T any] struct {
// T 表示验证成功之后,向用户传递的一些额外信息。之后可通过 [GetValue] 获取。
//
// [Basic 验证]: https://datatracker.ietf.org/doc/html/rfc7617
func New[T any](srv web.Server, auth AuthFunc[T], realm string, proxy bool) web.Middleware {
func New[T any](srv web.Server, auth AuthFunc[T], realm string, proxy bool) auth.Auth[T] {
if auth == nil {
panic("auth 参数不能为空")
}

authorization := "Authorization"
authorization := mauth.AuthorizationHeader
authenticate := "WWW-Authenticate"
problemID := web.ProblemUnauthorized
if proxy {
Expand All @@ -79,10 +77,7 @@ func New[T any](srv web.Server, auth AuthFunc[T], realm string, proxy bool) web.

func (b *basic[T]) Middleware(next web.HandlerFunc) web.HandlerFunc {
return func(ctx *web.Context) web.Responser {
h := ctx.Request().Header.Get(b.authorization)
if len(h) > prefixLen && strings.ToLower(h[:prefixLen]) == prefix {
h = h[prefixLen:]
}
h := auth.GetToken(ctx, prefix, b.authorization)

secret, err := base64.StdEncoding.DecodeString(h)
if err != nil {
Expand All @@ -98,8 +93,8 @@ func (b *basic[T]) Middleware(next web.HandlerFunc) web.HandlerFunc {
if !ok {
return b.unauthorization(ctx)
}
ctx.SetVar(valueKey, v)

mauth.Set(ctx, v)
return next(ctx)
}
}
Expand All @@ -111,11 +106,4 @@ func (b *basic[T]) unauthorization(ctx *web.Context) web.Responser {
return ctx.Problem(b.problemID)
}

// GetValue 获取当前对话关联的登录信息
func GetValue[T any](ctx *web.Context) (T, bool) {
if v, found := ctx.GetVar(valueKey); found {
return v.(T), true
}
var vv T
return vv, false
}
func (b *basic[T]) GetInfo(ctx *web.Context) (T, bool) { return mauth.Get[T](ctx) }
13 changes: 7 additions & 6 deletions middlewares/auth/basic/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/issue9/web/server"
"github.com/issue9/web/server/servertest"

"github.com/issue9/webuse/v7/internal/mauth"
"github.com/issue9/webuse/v7/middlewares/auth"
)

Expand All @@ -21,7 +22,7 @@ var (
return username, true
}

_ auth.Auth = &basic[[]byte]{}
_ auth.Auth[[]byte] = &basic[[]byte]{}
)

func TestNew(t *testing.T) {
Expand All @@ -39,7 +40,7 @@ func TestNew(t *testing.T) {

b = New(srv, authFunc, "", false).(*basic[[]byte])

a.Equal(b.authorization, "Authorization").
a.Equal(b.authorization, mauth.AuthorizationHeader).
Equal(b.authenticate, "WWW-Authenticate").
Equal(b.problemID, web.ProblemUnauthorized).
NotNil(b.auth)
Expand All @@ -66,7 +67,7 @@ func TestServeHTTP_ok(t *testing.T) {
r := s.Routers().New("def", nil)
r.Use(b)
r.Get("/path", func(ctx *web.Context) web.Responser {
username, found := GetValue[[]byte](ctx)
username, found := b.GetInfo(ctx)
a.True(found).Equal(string(username), "Aladdin")
return web.Status(http.StatusCreated)
})
Expand All @@ -81,7 +82,7 @@ func TestServeHTTP_ok(t *testing.T) {

// 正确的访问
servertest.Get(a, "http://localhost:8080/path").
Header("Authorization", "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="). // Aladdin, open sesame,来自 https://zh.wikipedia.org/wiki/HTTP基本认证
Header(mauth.AuthorizationHeader, "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="). // Aladdin, open sesame,来自 https://zh.wikipedia.org/wiki/HTTP基本认证
Do(nil).
Status(http.StatusCreated)
}
Expand All @@ -100,7 +101,7 @@ func TestServeHTTP_failed(t *testing.T) {
r := s.Routers().New("def", nil)
r.Use(b)
r.Get("/path", func(ctx *web.Context) web.Responser {
obj, found := GetValue[[]byte](ctx)
obj, found := b.GetInfo(ctx)
a.True(found).Nil(obj)
return nil
})
Expand All @@ -115,7 +116,7 @@ func TestServeHTTP_failed(t *testing.T) {

// 错误的编码
servertest.Get(a, "http://localhost:8080/path").
Header("Authorization", "Basic aaQWxhZGRpbjpvcGVuIHNlc2FtZQ===").
Header(mauth.AuthorizationHeader, "Basic aaQWxhZGRpbjpvcGVuIHNlc2FtZQ===").
Do(nil).
Status(http.StatusUnauthorized)
}
3 changes: 1 addition & 2 deletions middlewares/auth/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ func (j *JWT[T]) VerifiyRefresh(next web.HandlerFunc) web.HandlerFunc {
// Middleware 解码用户的 token 并写入 [web.Context]
func (j *JWT[T]) Middleware(next web.HandlerFunc) web.HandlerFunc { return j.v.Middleware(next) }

// GetValue 返回解码后的 Claims 对象
func (j *JWT[T]) GetValue(ctx *web.Context) (T, bool) { return j.v.GetValue(ctx) }
func (j *JWT[T]) GetInfo(ctx *web.Context) (T, bool) { return j.v.GetInfo(ctx) }

// Render 向客户端输出令牌
//
Expand Down
29 changes: 16 additions & 13 deletions middlewares/auth/jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ import (
xjson "github.com/issue9/web/mimetype/json"
"github.com/issue9/web/server"
"github.com/issue9/web/server/servertest"

"github.com/issue9/webuse/v7/internal/mauth"
"github.com/issue9/webuse/v7/middlewares/auth"
)

var _ web.Middleware = &JWT[*testClaims]{}
var _ auth.Auth[*testClaims] = &JWT[*testClaims]{}

func newJWT(a *assert.Assertion, expired, refresh time.Duration) (web.Server, *JWT[*testClaims]) {
s, err := server.New("test", "1.0.0", &server.Options{
Expand Down Expand Up @@ -125,7 +128,7 @@ func verifierMiddleware(a *assert.Assertion, s web.Server, j *JWT[*testClaims])
r.Post("/refresh", j.VerifiyRefresh(func(ctx *web.Context) web.Responser {
a.TB().Helper()

claims, ok := j.GetValue(ctx)
claims, ok := j.GetInfo(ctx)
if !ok {
return ctx.Problem(web.ProblemUnauthorized)
}
Expand All @@ -136,7 +139,7 @@ func verifierMiddleware(a *assert.Assertion, s web.Server, j *JWT[*testClaims])
r.Get("/info", j.Middleware(func(ctx *web.Context) web.Responser {
a.TB().Helper()

val, found := j.GetValue(ctx)
val, found := j.GetInfo(ctx)
if !found {
return web.Status(http.StatusNotFound)
}
Expand Down Expand Up @@ -171,13 +174,13 @@ func verifierMiddleware(a *assert.Assertion, s web.Server, j *JWT[*testClaims])
NotEmpty(resp.Refresh)

servertest.Get(a, "http://localhost:8080/info").
Header("Authorization", prefix+resp.Access).
Header(mauth.AuthorizationHeader, prefix+resp.Access).
Do(nil).
Status(http.StatusOK)

resp2 := &Response{}
servertest.Post(a, "http://localhost:8080/refresh", nil).
Header("Authorization", prefix+resp.Refresh).
Header(mauth.AuthorizationHeader, prefix+resp.Refresh).
Do(nil).
Status(http.StatusCreated).
BodyFunc(func(a *assert.Assertion, body []byte) {
Expand All @@ -194,24 +197,24 @@ func verifierMiddleware(a *assert.Assertion, s web.Server, j *JWT[*testClaims])

// 旧令牌已经无法访问
servertest.Get(a, "http://localhost:8080/info").
Header("Authorization", prefix+resp.Access).
Header(mauth.AuthorizationHeader, prefix+resp.Access).
Do(nil).
Status(http.StatusUnauthorized)

// 新令牌可以访问
servertest.Get(a, "http://localhost:8080/info").
Header("Authorization", prefix+resp2.Access).
Header(mauth.AuthorizationHeader, prefix+resp2.Access).
Do(nil).
Status(http.StatusOK)

servertest.Delete(a, "http://localhost:8080/login").
Header("Authorization", prefix+resp2.Access).
Header(mauth.AuthorizationHeader, prefix+resp2.Access).
Do(nil).
Status(http.StatusNoContent)

// token 已经在 delete /login 中被弃用
servertest.Get(a, "http://localhost:8080/info").
Header("Authorization", prefix+resp2.Access).
Header(mauth.AuthorizationHeader, prefix+resp2.Access).
Do(nil).
Status(http.StatusUnauthorized)
})
Expand All @@ -232,7 +235,7 @@ func TestVerifier_client(t *testing.T) {
})

r.Get("/info", j.Middleware(func(ctx *web.Context) web.Responser {
val, found := j.GetValue(ctx)
val, found := j.GetInfo(ctx)
if !found {
return web.Status(http.StatusNotFound)
}
Expand Down Expand Up @@ -265,7 +268,7 @@ func TestVerifier_client(t *testing.T) {
header["alg"] = "ES256"
parts[0] = encodeHeader(a, header)
servertest.Get(a, "http://localhost:8080/info").
Header("Authorization", "BEARER "+strings.Join(parts, ".")).
Header(mauth.AuthorizationHeader, "BEARER "+strings.Join(parts, ".")).
Do(nil).
Status(http.StatusUnauthorized)

Expand All @@ -275,7 +278,7 @@ func TestVerifier_client(t *testing.T) {
header["alg"] = "none"
parts[0] = encodeHeader(a, header)
servertest.Get(a, "http://localhost:8080/info").
Header("Authorization", "BEARER "+strings.Join(parts, ".")).
Header(mauth.AuthorizationHeader, "BEARER "+strings.Join(parts, ".")).
Do(nil).
Status(http.StatusUnauthorized)

Expand All @@ -284,7 +287,7 @@ func TestVerifier_client(t *testing.T) {
header["alg"] = "none"
parts[0] = encodeHeader(a, header)
servertest.Get(a, "http://localhost:8080/info").
Header("Authorization", "BEARER "+strings.Join(parts, ".")).
Header(mauth.AuthorizationHeader, "BEARER "+strings.Join(parts, ".")).
Do(nil).
Status(http.StatusUnauthorized)
})
Expand Down
Loading

0 comments on commit 9bcdcaa

Please sign in to comment.