From fed6ee27c12b6f55247d2c3ae2c3f7e153f25308 Mon Sep 17 00:00:00 2001 From: yuanzhao <2206582181@qq.com> Date: Thu, 6 Jun 2024 15:50:42 +0800 Subject: [PATCH] =?UTF-8?q?add:=20=E6=B7=BB=E5=8A=A0=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E8=BE=85=E5=8A=A9=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/func.go | 1 + app/http/context.go | 52 ++++++++++++++++++++++++------ bootstrap/services/http_server.go | 2 +- bootstrap/services/redis_server.go | 15 +++++---- 4 files changed, 53 insertions(+), 17 deletions(-) diff --git a/app/func.go b/app/func.go index 922640f..aa1a9c8 100644 --- a/app/func.go +++ b/app/func.go @@ -115,6 +115,7 @@ func StringToHump(s string) string { return string(data[:]) } +// GetRoot 获取项目根目录,在一些单元测试中可以便捷获取 func GetRoot() string { return app.GetBean("config").(app.GetRoot).GetRoot() } diff --git a/app/http/context.go b/app/http/context.go index 9f50106..ea341dd 100644 --- a/app/http/context.go +++ b/app/http/context.go @@ -6,8 +6,10 @@ import ( "github.com/sirupsen/logrus" "net/http" "strings" + "time" ) +const UserKey = "user" const UserIdKey = "user_id" // UserModel 不能赋值指针 @@ -28,7 +30,7 @@ type Ctx struct { UserInfo interface{} } -func (receiver Ctx) Success(data interface{}) { +func (receiver *Ctx) Success(data interface{}) { receiver.JSON(http.StatusOK, map[string]interface{}{ "data": data, "code": 0, @@ -36,25 +38,25 @@ func (receiver Ctx) Success(data interface{}) { }) } -func (receiver Ctx) Fail(err error) { +func (receiver *Ctx) Fail(err error) { receiver.JSON(http.StatusOK, map[string]interface{}{ "code": 1, "msg": err.Error(), }) } -func (receiver Ctx) Gin() *gin.Context { +func (receiver *Ctx) Gin() *gin.Context { return receiver.Context } -func (receiver Ctx) User() interface{} { +func (receiver *Ctx) User() interface{} { if receiver.UserInfo == nil { receiver.InitUser() } return receiver.UserInfo } -func (receiver Ctx) Id() uint64 { +func (receiver *Ctx) Id() uint64 { u, ok := receiver.Context.Get(UserIdKey) if !ok { logrus.Fatal("id 不存在, todo Context.Set(UserIdKey, Uid)") @@ -63,7 +65,7 @@ func (receiver Ctx) Id() uint64 { return u.(uint64) } -func (receiver Ctx) IdStr() string { +func (receiver *Ctx) IdStr() string { u, ok := receiver.Context.Get(UserIdKey) if !ok { return "" @@ -71,7 +73,7 @@ func (receiver Ctx) IdStr() string { return u.(string) } -func (receiver Ctx) Token() string { +func (receiver *Ctx) Token() string { tokenString := receiver.Context.GetHeader("Authorization") if strings.HasPrefix(tokenString, "Bearer ") { @@ -80,9 +82,15 @@ func (receiver Ctx) Token() string { return tokenString } -func (receiver Ctx) InitUser() { +func (receiver *Ctx) InitUser() { if receiver.UserInfo == nil { - uid, ok := receiver.Context.Get(UserIdKey) + u, ok := receiver.Context.Get(UserKey) + if ok { + receiver.UserInfo = u + return + } + + uid, ok := receiver.Get(UserIdKey) if ok { user := UserModel database.DB().Model(UserModel).First(&user, uid) @@ -100,4 +108,30 @@ type Context interface { Id() uint64 IdStr() string User() interface{} + + // 下面是补充 gin.Context 的方法 + + JSON(code int, obj interface{}) + String(code int, format string, values ...interface{}) + Param(key string) string + Query(key string) string + PostForm(key string) string + BindJSON(obj interface{}) error + Status(code int) + Set(key string, value interface{}) + Get(key string) (value interface{}, exists bool) + AbortWithStatus(code int) + Next() + + GetString(key string) string + GetBool(key string) bool + GetInt(key string) int + GetInt64(key string) int64 + GetFloat64(key string) float64 + GetTime(key string) time.Time + GetDuration(key string) time.Duration + GetStringSlice(key string) []string + GetStringMap(key string) map[string]interface{} + GetStringMapString(key string) map[string]string + GetStringMapStringSlice(key string) map[string][]string } diff --git a/bootstrap/services/http_server.go b/bootstrap/services/http_server.go index eae3d98..242f6f3 100644 --- a/bootstrap/services/http_server.go +++ b/bootstrap/services/http_server.go @@ -46,6 +46,6 @@ func (receiver *HttpServer) SetPort(port int) { func (receiver *HttpServer) RunListener() { err := receiver.GetEngine().Run(":" + receiver.port) if err != nil { - logrus.WithFields(logrus.Fields{"port": receiver.port}).Error("http发送错误") + logrus.WithFields(logrus.Fields{"port": receiver.port}).Error("http server 启动发生错误") } } diff --git a/bootstrap/services/redis_server.go b/bootstrap/services/redis_server.go index 523137c..90178f5 100644 --- a/bootstrap/services/redis_server.go +++ b/bootstrap/services/redis_server.go @@ -2,6 +2,7 @@ package services import ( "context" + "errors" "github.com/go-redis/redis/v8" log "github.com/sirupsen/logrus" "time" @@ -22,7 +23,7 @@ func (r Redis) Get(key string) *redis.StringCmd { func (r Redis) GetString(key string) (string, bool) { cmd := r.Client.Get(context.Background(), key) err := cmd.Err() - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return "", false } else if err != nil { log.Error(err) @@ -34,7 +35,7 @@ func (r Redis) GetString(key string) (string, bool) { func (r Redis) GetInt(key string) (int, bool) { i, err := r.Client.Get(context.Background(), key).Int() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return 0, false } log.Errorf("GetInt %v", err) @@ -45,7 +46,7 @@ func (r Redis) GetInt(key string) (int, bool) { func (r Redis) GetInt64(key string) (int64, bool) { i, err := r.Client.Get(context.Background(), key).Int64() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return 0, false } log.Errorf("GetInt64 %v", err) @@ -56,7 +57,7 @@ func (r Redis) GetInt64(key string) (int64, bool) { func (r Redis) GetFloat32(key string) (float32, bool) { i, err := r.Client.Get(context.Background(), key).Float32() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return 0, false } log.Errorf("GetFloat32 %v", err) @@ -67,7 +68,7 @@ func (r Redis) GetFloat32(key string) (float32, bool) { func (r Redis) GetFloat64(key string) (float64, bool) { i, err := r.Client.Get(context.Background(), key).Float64() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return 0, false } log.Errorf("GetFloat32 %v", err) @@ -78,7 +79,7 @@ func (r Redis) GetFloat64(key string) (float64, bool) { func (r Redis) GetBool(key string) (bool, bool) { i, err := r.Client.Get(context.Background(), key).Bool() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return false, false } log.Errorf("GetFloat32 %v", err) @@ -89,7 +90,7 @@ func (r Redis) GetBool(key string) (bool, bool) { func (r Redis) Incr(key string) (int64, bool) { cmd := r.Client.Incr(context.Background(), key) err := cmd.Err() - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return 0, true } else if err != nil { log.Error(err)