From 07b7dc40fc748abb8498059f129408b647ccad29 Mon Sep 17 00:00:00 2001 From: okhowang <3352585+okhowang@users.noreply.github.com> Date: Mon, 3 Apr 2023 20:32:44 +0800 Subject: [PATCH] =?UTF-8?q?cache=E5=A2=9E=E5=8A=A0=E5=B8=A6Context?= =?UTF-8?q?=E7=89=88=E6=9C=AC=EF=BC=8C=E5=BC=80=E6=94=BE=E5=B9=B3=E5=8F=B0?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E6=8E=A5=E5=8F=A3=E6=94=AF=E6=8C=81Context?= =?UTF-8?q?=E7=89=88=E6=9C=AC=20(#653)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cache/cache.go | 46 ++++++++++- cache/redis.go | 28 ++++++- cache/redis_test.go | 10 ++- go.mod | 1 + go.sum | 7 ++ openplatform/context/accessToken.go | 113 ++++++++++++++++++++-------- util/http.go | 16 +++- 7 files changed, 178 insertions(+), 43 deletions(-) diff --git a/cache/cache.go b/cache/cache.go index f3feb84c4..077597dbe 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,6 +1,9 @@ package cache -import "time" +import ( + "context" + "time" +) // Cache interface type Cache interface { @@ -9,3 +12,44 @@ type Cache interface { IsExist(key string) bool Delete(key string) error } + +// ContextCache interface +type ContextCache interface { + Cache + GetContext(ctx context.Context, key string) interface{} + SetContext(ctx context.Context, key string, val interface{}, timeout time.Duration) error + IsExistContext(ctx context.Context, key string) bool + DeleteContext(ctx context.Context, key string) error +} + +// GetContext get value from cache +func GetContext(ctx context.Context, cache Cache, key string) interface{} { + if cache, ok := cache.(ContextCache); ok { + return cache.GetContext(ctx, key) + } + return cache.Get(key) +} + +// SetContext set value to cache +func SetContext(ctx context.Context, cache Cache, key string, val interface{}, timeout time.Duration) error { + if cache, ok := cache.(ContextCache); ok { + return cache.SetContext(ctx, key, val, timeout) + } + return cache.Set(key, val, timeout) +} + +// IsExistContext check value exists in cache. +func IsExistContext(ctx context.Context, cache Cache, key string) bool { + if cache, ok := cache.(ContextCache); ok { + return cache.IsExistContext(ctx, key) + } + return cache.IsExist(key) +} + +// DeleteContext delete value in cache. +func DeleteContext(ctx context.Context, cache Cache, key string) error { + if cache, ok := cache.(ContextCache); ok { + return cache.DeleteContext(ctx, key) + } + return cache.Delete(key) +} diff --git a/cache/redis.go b/cache/redis.go index 24a736f94..f51f7bf88 100644 --- a/cache/redis.go +++ b/cache/redis.go @@ -47,7 +47,12 @@ func (r *Redis) SetRedisCtx(ctx context.Context) { // Get 获取一个值 func (r *Redis) Get(key string) interface{} { - result, err := r.conn.Do(r.ctx, "GET", key).Result() + return r.GetContext(r.ctx, key) +} + +// GetContext 获取一个值 +func (r *Redis) GetContext(ctx context.Context, key string) interface{} { + result, err := r.conn.Do(ctx, "GET", key).Result() if err != nil { return nil } @@ -56,17 +61,32 @@ func (r *Redis) Get(key string) interface{} { // Set 设置一个值 func (r *Redis) Set(key string, val interface{}, timeout time.Duration) error { - return r.conn.SetEX(r.ctx, key, val, timeout).Err() + return r.SetContext(r.ctx, key, val, timeout) +} + +// SetContext 设置一个值 +func (r *Redis) SetContext(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + return r.conn.SetEX(ctx, key, val, timeout).Err() } // IsExist 判断key是否存在 func (r *Redis) IsExist(key string) bool { - result, _ := r.conn.Exists(r.ctx, key).Result() + return r.IsExistContext(r.ctx, key) +} + +// IsExistContext 判断key是否存在 +func (r *Redis) IsExistContext(ctx context.Context, key string) bool { + result, _ := r.conn.Exists(ctx, key).Result() return result > 0 } // Delete 删除 func (r *Redis) Delete(key string) error { - return r.conn.Del(r.ctx, key).Err() + return r.DeleteContext(r.ctx, key) +} + +// DeleteContext 删除 +func (r *Redis) DeleteContext(ctx context.Context, key string) error { + return r.conn.Del(ctx, key).Err() } diff --git a/cache/redis_test.go b/cache/redis_test.go index 8973fe6a0..a41a2f166 100644 --- a/cache/redis_test.go +++ b/cache/redis_test.go @@ -4,17 +4,23 @@ import ( "context" "testing" "time" + + "github.com/alicebob/miniredis/v2" ) func TestRedis(t *testing.T) { + server, err := miniredis.Run() + if err != nil { + t.Error("miniredis.Run Error", err) + } + t.Cleanup(server.Close) var ( timeoutDuration = time.Second ctx = context.Background() opts = &RedisOpts{ - Host: "127.0.0.1:6379", + Host: server.Addr(), } redis = NewRedis(ctx, opts) - err error val = "silenceper" key = "username" ) diff --git a/go.mod b/go.mod index b49244c3a..0180599f0 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/silenceper/wechat/v2 go 1.16 require ( + github.com/alicebob/miniredis/v2 v2.30.0 github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d github.com/fatih/structs v1.1.0 github.com/go-redis/redis/v8 v8.11.5 diff --git a/go.sum b/go.sum index f9efc2738..64deda94f 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.30.0 h1:uA3uhDbCxfO9+DI/DuGeAMr9qI+noVWwGPNTFuKID5M= +github.com/alicebob/miniredis/v2 v2.30.0/go.mod h1:84TWKZlxYkfgMucPBf5SOQBYJceZeQRFIaQgNMiCX6Q= github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d h1:pVrfxiGfwelyab6n21ZBkbkmbevaf+WvMIiR7sr97hw= github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= @@ -71,6 +75,8 @@ github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 h1:5mLPGnFdSsevFRFc9q3yYbBkB6tsm4aCwwQV/j1JQAQ= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -89,6 +95,7 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/openplatform/context/accessToken.go b/openplatform/context/accessToken.go index a8316e2ef..68861295f 100644 --- a/openplatform/context/accessToken.go +++ b/openplatform/context/accessToken.go @@ -2,11 +2,13 @@ package context import ( + "context" "encoding/json" "fmt" "net/url" "time" + "github.com/silenceper/wechat/v2/cache" "github.com/silenceper/wechat/v2/util" ) @@ -31,24 +33,29 @@ type ComponentAccessToken struct { ExpiresIn int64 `json:"expires_in"` } -// GetComponentAccessToken 获取 ComponentAccessToken -func (ctx *Context) GetComponentAccessToken() (string, error) { +// GetComponentAccessTokenContext 获取 ComponentAccessToken +func (ctx *Context) GetComponentAccessTokenContext(stdCtx context.Context) (string, error) { accessTokenCacheKey := fmt.Sprintf("component_access_token_%s", ctx.AppID) - val := ctx.Cache.Get(accessTokenCacheKey) + val := cache.GetContext(stdCtx, ctx.Cache, accessTokenCacheKey) if val == nil { return "", fmt.Errorf("cann't get component access token") } return val.(string), nil } -// SetComponentAccessToken 通过component_verify_ticket 获取 ComponentAccessToken -func (ctx *Context) SetComponentAccessToken(verifyTicket string) (*ComponentAccessToken, error) { +// GetComponentAccessToken 获取 ComponentAccessToken +func (ctx *Context) GetComponentAccessToken() (string, error) { + return ctx.GetComponentAccessTokenContext(context.Background()) +} + +// SetComponentAccessTokenContext 通过component_verify_ticket 获取 ComponentAccessToken +func (ctx *Context) SetComponentAccessTokenContext(stdCtx context.Context, verifyTicket string) (*ComponentAccessToken, error) { body := map[string]string{ "component_appid": ctx.AppID, "component_appsecret": ctx.AppSecret, "component_verify_ticket": verifyTicket, } - respBody, err := util.PostJSON(componentAccessTokenURL, body) + respBody, err := util.PostJSONContext(stdCtx, componentAccessTokenURL, body) if err != nil { return nil, err } @@ -64,15 +71,20 @@ func (ctx *Context) SetComponentAccessToken(verifyTicket string) (*ComponentAcce accessTokenCacheKey := fmt.Sprintf("component_access_token_%s", ctx.AppID) expires := at.ExpiresIn - 1500 - if err := ctx.Cache.Set(accessTokenCacheKey, at.AccessToken, time.Duration(expires)*time.Second); err != nil { + if err := cache.SetContext(stdCtx, ctx.Cache, accessTokenCacheKey, at.AccessToken, time.Duration(expires)*time.Second); err != nil { return nil, nil } return at, nil } -// GetPreCode 获取预授权码 -func (ctx *Context) GetPreCode() (string, error) { - cat, err := ctx.GetComponentAccessToken() +// SetComponentAccessToken 通过component_verify_ticket 获取 ComponentAccessToken +func (ctx *Context) SetComponentAccessToken(stdCtx context.Context, verifyTicket string) (*ComponentAccessToken, error) { + return ctx.SetComponentAccessTokenContext(stdCtx, verifyTicket) +} + +// GetPreCodeContext 获取预授权码 +func (ctx *Context) GetPreCodeContext(stdCtx context.Context) (string, error) { + cat, err := ctx.GetComponentAccessTokenContext(stdCtx) if err != nil { return "", err } @@ -80,7 +92,7 @@ func (ctx *Context) GetPreCode() (string, error) { "component_appid": ctx.AppID, } uri := fmt.Sprintf(getPreCodeURL, cat) - body, err := util.PostJSON(uri, req) + body, err := util.PostJSONContext(stdCtx, uri, req) if err != nil { return "", err } @@ -95,24 +107,39 @@ func (ctx *Context) GetPreCode() (string, error) { return ret.PreCode, nil } -// GetComponentLoginPage 获取第三方公众号授权链接(扫码授权) -func (ctx *Context) GetComponentLoginPage(redirectURI string, authType int, bizAppID string) (string, error) { - code, err := ctx.GetPreCode() +// GetPreCode 获取预授权码 +func (ctx *Context) GetPreCode() (string, error) { + return ctx.GetPreCodeContext(context.Background()) +} + +// GetComponentLoginPageContext 获取第三方公众号授权链接(扫码授权) +func (ctx *Context) GetComponentLoginPageContext(stdCtx context.Context, redirectURI string, authType int, bizAppID string) (string, error) { + code, err := ctx.GetPreCodeContext(stdCtx) if err != nil { return "", err } return fmt.Sprintf(componentLoginURL, ctx.AppID, code, url.QueryEscape(redirectURI), authType, bizAppID), nil } -// GetBindComponentURL 获取第三方公众号授权链接(链接跳转,适用移动端) -func (ctx *Context) GetBindComponentURL(redirectURI string, authType int, bizAppID string) (string, error) { - code, err := ctx.GetPreCode() +// GetComponentLoginPage 获取第三方公众号授权链接(扫码授权) +func (ctx *Context) GetComponentLoginPage(redirectURI string, authType int, bizAppID string) (string, error) { + return ctx.GetComponentLoginPageContext(context.Background(), redirectURI, authType, bizAppID) +} + +// GetBindComponentURLContext 获取第三方公众号授权链接(链接跳转,适用移动端) +func (ctx *Context) GetBindComponentURLContext(stdCtx context.Context, redirectURI string, authType int, bizAppID string) (string, error) { + code, err := ctx.GetPreCodeContext(stdCtx) if err != nil { return "", err } return fmt.Sprintf(bindComponentURL, authType, ctx.AppID, code, url.QueryEscape(redirectURI), bizAppID), nil } +// GetBindComponentURL 获取第三方公众号授权链接(链接跳转,适用移动端) +func (ctx *Context) GetBindComponentURL(redirectURI string, authType int, bizAppID string) (string, error) { + return ctx.GetBindComponentURLContext(context.Background(), redirectURI, authType, bizAppID) +} + // ID 微信返回接口中各种类型字段 type ID struct { ID int `json:"id"` @@ -137,9 +164,9 @@ type AuthrAccessToken struct { RefreshToken string `json:"authorizer_refresh_token"` } -// QueryAuthCode 使用授权码换取公众号或小程序的接口调用凭据和授权信息 -func (ctx *Context) QueryAuthCode(authCode string) (*AuthBaseInfo, error) { - cat, err := ctx.GetComponentAccessToken() +// QueryAuthCodeContext 使用授权码换取公众号或小程序的接口调用凭据和授权信息 +func (ctx *Context) QueryAuthCodeContext(stdCtx context.Context, authCode string) (*AuthBaseInfo, error) { + cat, err := ctx.GetComponentAccessTokenContext(stdCtx) if err != nil { return nil, err } @@ -149,7 +176,7 @@ func (ctx *Context) QueryAuthCode(authCode string) (*AuthBaseInfo, error) { "authorization_code": authCode, } uri := fmt.Sprintf(queryAuthURL, cat) - body, err := util.PostJSON(uri, req) + body, err := util.PostJSONContext(stdCtx, uri, req) if err != nil { return nil, err } @@ -169,9 +196,14 @@ func (ctx *Context) QueryAuthCode(authCode string) (*AuthBaseInfo, error) { return ret.Info, nil } -// RefreshAuthrToken 获取(刷新)授权公众号或小程序的接口调用凭据(令牌) -func (ctx *Context) RefreshAuthrToken(appid, refreshToken string) (*AuthrAccessToken, error) { - cat, err := ctx.GetComponentAccessToken() +// QueryAuthCode 使用授权码换取公众号或小程序的接口调用凭据和授权信息 +func (ctx *Context) QueryAuthCode(authCode string) (*AuthBaseInfo, error) { + return ctx.QueryAuthCodeContext(context.Background(), authCode) +} + +// RefreshAuthrTokenContext 获取(刷新)授权公众号或小程序的接口调用凭据(令牌) +func (ctx *Context) RefreshAuthrTokenContext(stdCtx context.Context, appid, refreshToken string) (*AuthrAccessToken, error) { + cat, err := ctx.GetComponentAccessTokenContext(stdCtx) if err != nil { return nil, err } @@ -182,7 +214,7 @@ func (ctx *Context) RefreshAuthrToken(appid, refreshToken string) (*AuthrAccessT "authorizer_refresh_token": refreshToken, } uri := fmt.Sprintf(refreshTokenURL, cat) - body, err := util.PostJSON(uri, req) + body, err := util.PostJSONContext(stdCtx, uri, req) if err != nil { return nil, err } @@ -193,22 +225,32 @@ func (ctx *Context) RefreshAuthrToken(appid, refreshToken string) (*AuthrAccessT } authrTokenKey := "authorizer_access_token_" + appid - if err := ctx.Cache.Set(authrTokenKey, ret.AccessToken, time.Second*time.Duration(ret.ExpiresIn-30)); err != nil { + if err := cache.SetContext(stdCtx, ctx.Cache, authrTokenKey, ret.AccessToken, time.Second*time.Duration(ret.ExpiresIn-30)); err != nil { return nil, err } return ret, nil } -// GetAuthrAccessToken 获取授权方AccessToken -func (ctx *Context) GetAuthrAccessToken(appid string) (string, error) { +// RefreshAuthrToken 获取(刷新)授权公众号或小程序的接口调用凭据(令牌) +func (ctx *Context) RefreshAuthrToken(appid, refreshToken string) (*AuthrAccessToken, error) { + return ctx.RefreshAuthrTokenContext(context.Background(), appid, refreshToken) +} + +// GetAuthrAccessTokenContext 获取授权方AccessToken +func (ctx *Context) GetAuthrAccessTokenContext(stdCtx context.Context, appid string) (string, error) { authrTokenKey := "authorizer_access_token_" + appid - val := ctx.Cache.Get(authrTokenKey) + val := cache.GetContext(stdCtx, ctx.Cache, authrTokenKey) if val == nil { return "", fmt.Errorf("cannot get authorizer %s access token", appid) } return val.(string), nil } +// GetAuthrAccessToken 获取授权方AccessToken +func (ctx *Context) GetAuthrAccessToken(appid string) (string, error) { + return ctx.GetAuthrAccessTokenContext(context.Background(), appid) +} + // AuthorizerInfo 授权方详细信息 type AuthorizerInfo struct { NickName string `json:"nick_name"` @@ -258,9 +300,9 @@ type CategoriesInfo struct { Second string `wx:"second"` } -// GetAuthrInfo 获取授权方的帐号基本信息 -func (ctx *Context) GetAuthrInfo(appid string) (*AuthorizerInfo, *AuthBaseInfo, error) { - cat, err := ctx.GetComponentAccessToken() +// GetAuthrInfoContext 获取授权方的帐号基本信息 +func (ctx *Context) GetAuthrInfoContext(stdCtx context.Context, appid string) (*AuthorizerInfo, *AuthBaseInfo, error) { + cat, err := ctx.GetComponentAccessTokenContext(stdCtx) if err != nil { return nil, nil, err } @@ -271,7 +313,7 @@ func (ctx *Context) GetAuthrInfo(appid string) (*AuthorizerInfo, *AuthBaseInfo, } uri := fmt.Sprintf(getComponentInfoURL, cat) - body, err := util.PostJSON(uri, req) + body, err := util.PostJSONContext(stdCtx, uri, req) if err != nil { return nil, nil, err } @@ -286,3 +328,8 @@ func (ctx *Context) GetAuthrInfo(appid string) (*AuthorizerInfo, *AuthBaseInfo, return ret.AuthorizerInfo, ret.AuthorizationInfo, nil } + +// GetAuthrInfo 获取授权方的帐号基本信息 +func (ctx *Context) GetAuthrInfo(appid string) (*AuthorizerInfo, *AuthBaseInfo, error) { + return ctx.GetAuthrInfoContext(context.Background(), appid) +} diff --git a/util/http.go b/util/http.go index 38089aebe..fdd2f0abf 100644 --- a/util/http.go +++ b/util/http.go @@ -69,8 +69,8 @@ func HTTPPostContext(ctx context.Context, uri string, data []byte, header map[st return io.ReadAll(response.Body) } -// PostJSON post json 数据请求 -func PostJSON(uri string, obj interface{}) ([]byte, error) { +// PostJSONContext post json 数据请求 +func PostJSONContext(ctx context.Context, uri string, obj interface{}) ([]byte, error) { jsonBuf := new(bytes.Buffer) enc := json.NewEncoder(jsonBuf) enc.SetEscapeHTML(false) @@ -78,7 +78,12 @@ func PostJSON(uri string, obj interface{}) ([]byte, error) { if err != nil { return nil, err } - response, err := http.Post(uri, "application/json;charset=utf-8", jsonBuf) + req, err := http.NewRequestWithContext(ctx, "POST", uri, jsonBuf) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json;charset=utf-8") + response, err := http.DefaultClient.Do(req) if err != nil { return nil, err } @@ -90,6 +95,11 @@ func PostJSON(uri string, obj interface{}) ([]byte, error) { return io.ReadAll(response.Body) } +// PostJSON post json 数据请求 +func PostJSON(uri string, obj interface{}) ([]byte, error) { + return PostJSONContext(context.Background(), uri, obj) +} + // PostJSONWithRespContentType post json数据请求,且返回数据类型 func PostJSONWithRespContentType(uri string, obj interface{}) ([]byte, string, error) { jsonBuf := new(bytes.Buffer)