From 38b61cd4ea09a8db40814de1565bfd7131cdf1ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B2=81=E5=BB=BA=E5=BC=BA?= Date: Tue, 30 Jul 2024 16:46:10 +0800 Subject: [PATCH] Add filter and sort plugin api --- filter/filter.go | 2 + filter/plugin_api_filter.go | 93 +++++++++++++++++++++++++++++++++++ module/plugin_api.go | 71 +++++++++++++++++++++++++++ recconf/recconf.go | 9 ++++ sort/plugin_api_sort.go | 97 +++++++++++++++++++++++++++++++++++++ sort/sort.go | 2 + 6 files changed, 274 insertions(+) create mode 100644 filter/plugin_api_filter.go create mode 100644 module/plugin_api.go create mode 100644 sort/plugin_api_sort.go diff --git a/filter/filter.go b/filter/filter.go index 3c14861..286a5e8 100644 --- a/filter/filter.go +++ b/filter/filter.go @@ -139,6 +139,8 @@ func RegisterFilterWithConfig(config *recconf.RecommendConfig) { f = NewUser2ItemExposureWithConditionFilter(conf) } else if conf.FilterType == "ConditionFilter" { f = NewConditionFilter(conf) + } else if conf.FilterType == "PluginAPIFilter" { + f = NewPluginAPIFilter(conf) } if f == nil { diff --git a/filter/plugin_api_filter.go b/filter/plugin_api_filter.go new file mode 100644 index 0000000..3aeee29 --- /dev/null +++ b/filter/plugin_api_filter.go @@ -0,0 +1,93 @@ +package filter + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/pairec/v2/log" + "github.com/alibaba/pairec/v2/module" + "github.com/alibaba/pairec/v2/recconf" +) + +type PluginAPIFilter struct { + url string +} + +func NewPluginAPIFilter(config recconf.FilterConfig) *PluginAPIFilter { + filter := PluginAPIFilter{} + + filter.url = config.PluginAPIFilterConf.URL + + return &filter +} +func (f *PluginAPIFilter) Filter(filterData *FilterData) error { + if _, ok := filterData.Data.([]*module.Item); !ok { + return errors.New("filter data type error") + + } + return f.doFilter(filterData) +} + +func (f *PluginAPIFilter) doFilter(filterData *FilterData) error { + items := filterData.Data.([]*module.Item) + var newItems []*module.Item + + reqData := module.NewPluginAPIRequest(filterData.User, items, filterData.Context) + // 将结构体编码为 JSON 格式 + jsonData, err := json.Marshal(reqData) + if err != nil { + return fmt.Errorf("error encoding JSON: %w", err) + } + + // 创建一个 HTTP 请求 + req, err := http.NewRequest("POST", f.url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("error: received non-200 response code: %v", resp.StatusCode) + } + + // 读取和处理响应 + var respData module.PluginAPIFilterResponse + if err := json.NewDecoder(resp.Body).Decode(&respData); err != nil { + return fmt.Errorf("error decoding response: %w", err) + } + + if respData.Code != 200 { + if respData.Code >= 400 { + return fmt.Errorf("error: received non-200 business code: %d, msg: %s", respData.Code, respData.Msg) + } else { + log.Warning(fmt.Sprintf("requestId=%s\tmodule=%s\tmsg=%v", filterData.Context.RecommendId, "PluginAPIFilter", respData.Msg)) + } + } + + itemMap := make(map[string]*module.Item, len(items)) + for i, item := range items { + itemMap[string(item.Id)] = items[i] + } + + for _, itemId := range respData.Items { + newItems = append(newItems, itemMap[itemId]) + } + + filterData.Data = newItems + + return nil +} diff --git a/module/plugin_api.go b/module/plugin_api.go new file mode 100644 index 0000000..9dbb0ea --- /dev/null +++ b/module/plugin_api.go @@ -0,0 +1,71 @@ +package module + +import "github.com/alibaba/pairec/v2/context" + +type PluginAPIRequest struct { + Uid string `json:"uid"` + Size int `json:"size"` + SceneId string `json:"scene_id"` + Features map[string]any `json:"features"` + ItemId string `json:"item_id"` + ItemList []map[string]any `json:"item_list"` + Debug bool `json:"debug"` + RequestId string `json:"request_id"` +} + +type PluginAPIFilterResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + RequestId string `json:"request_id"` + Items []string +} + +type PluginAPISortResponse = PluginAPIFilterResponse + +func NewPluginAPIRequest(user *User, items []*Item, ctx *context.RecommendContext) *PluginAPIRequest { + request := &PluginAPIRequest{} + + request.Uid = string(user.Id) + request.Size = ctx.Size + + if scene, ok := ctx.GetParameter("scene").(string); ok { + request.SceneId = scene + } + + request.Features = user.Properties + + if itemId, ok := ctx.GetParameter("item_id").(string); ok { + request.ItemId = itemId + } + + request.ItemList = make([]map[string]any, 0, len(items)) + for _, item := range items { + itemData := make(map[string]interface{}) + itemData["item_id"] = item.Id + itemData["score"] = item.Score + itemData["retrieve_id"] = item.RetrieveId + + if item.ItemType != "" { + itemData["item_type"] = item.ItemType + } + if item.Embedding != nil { + itemData["embedding"] = item.Embedding + } + if item.Properties != nil { + for k, v := range item.Properties { + itemData[k] = v + } + } + if item.algoScores != nil { + itemData["algo_scores"] = item.algoScores + } + + request.ItemList = append(request.ItemList, itemData) + } + + request.Debug = ctx.Debug + + request.RequestId = ctx.RecommendId + + return request +} diff --git a/recconf/recconf.go b/recconf/recconf.go index 0f3f50c..3ac0c8c 100644 --- a/recconf/recconf.go +++ b/recconf/recconf.go @@ -681,6 +681,7 @@ type FilterConfig struct { } DefaultFilterName string } + PluginAPIFilterConf PluginAPIFilterConfig } type BeFilterConfig struct { FilterConfig @@ -700,6 +701,7 @@ type SortConfig struct { Size int DPPConf DPPSortConfig PIDConf PIDControllerConfig + PluginAPIConfig PluginAPISortConfig MixSortRules []MixSortConfig BoostScoreConditionsFilterAll bool BoostScoreConditions []BoostScoreCondition @@ -764,6 +766,9 @@ type AdjustCountConfig struct { Count int Type string } +type PluginAPIFilterConfig struct { + URL string +} type CallBackConfig struct { DataSource DataSourceConfig RankConf RankConfig @@ -800,6 +805,10 @@ type TriggerConfig struct { Boundaries []int } +type PluginAPISortConfig struct { + URL string +} + type DPPSortConfig struct { Name string DaoConf DaoConfig diff --git a/sort/plugin_api_sort.go b/sort/plugin_api_sort.go new file mode 100644 index 0000000..c504766 --- /dev/null +++ b/sort/plugin_api_sort.go @@ -0,0 +1,97 @@ +package sort + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/alibaba/pairec/v2/log" + "github.com/alibaba/pairec/v2/module" + "github.com/alibaba/pairec/v2/recconf" +) + +type PluginAPISort struct { + name string + url string +} + +func (s *PluginAPISort) Sort(sortData *SortData) error { + if _, ok := sortData.Data.([]*module.Item); !ok { + return errors.New("sort data type error") + } + + return s.doSort(sortData) +} + +func (s *PluginAPISort) doSort(sortData *SortData) error { + start := time.Now() + items := sortData.Data.([]*module.Item) + var newItems []*module.Item + + reqData := module.NewPluginAPIRequest(sortData.User, items, sortData.Context) + // 将结构体编码为 JSON 格式 + jsonData, err := json.Marshal(reqData) + if err != nil { + return fmt.Errorf("error encoding JSON: %w", err) + } + + // 创建一个 HTTP 请求 + req, err := http.NewRequest("POST", s.url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("error: received non-200 response code: %v", resp.StatusCode) + } + + // 读取和处理响应 + var respData module.PluginAPIFilterResponse + if err := json.NewDecoder(resp.Body).Decode(&respData); err != nil { + return fmt.Errorf("error decoding response: %w", err) + } + + if respData.Code != 200 { + if respData.Code >= 400 { + return fmt.Errorf("error: received non-200 business code: %d, msg: %s", respData.Code, respData.Msg) + } else { + log.Warning(fmt.Sprintf("requestId=%s\tmodule=%s\tmsg=%v", sortData.Context.RecommendId, "PluginAPISort", respData.Msg)) + } + } + + itemMap := make(map[string]*module.Item, len(items)) + for i, item := range items { + itemMap[string(item.Id)] = items[i] + } + + for _, itemId := range respData.Items { + newItems = append(newItems, itemMap[itemId]) + } + + sortData.Data = newItems + sortInfoLogWithName(sortData, "PluginAPISort", s.name, len(items), start) + return nil +} + +func NewPluginAPISort(config recconf.SortConfig) *PluginAPISort { + p := PluginAPISort{} + p.name = config.Name + p.url = config.PluginAPIConfig.URL + + return &p +} diff --git a/sort/sort.go b/sort/sort.go index 591860e..583e306 100644 --- a/sort/sort.go +++ b/sort/sort.go @@ -185,6 +185,8 @@ func RegisterSortWithConfig(config *recconf.RecommendConfig) { s = NewBoostScoreByWeight(conf) } else if conf.SortType == "DistinctIdSort" { s = NewDistinctIdSort(conf) + } else if conf.SortType == "PluginAPISort" { + s = NewPluginAPISort(conf) } if s == nil {