diff --git a/README.md b/README.md index 144e302..207232d 100644 --- a/README.md +++ b/README.md @@ -81,8 +81,8 @@ Currently following implementations are available: ## TODO: -- [ ] Make sources configurable - - sources to be used should be configurable per instance +- [x] Make sources configurable + - [x] sources to be used should be configurable per instance - [ ] a configurable caching mechanism to enable offline usage - [ ] Add more sources (hackernews?, wikipedia?) - [ ] Enable markdown to console colored output ? diff --git a/clipmon.go b/clipmon.go index 056abed..0da402f 100644 --- a/clipmon.go +++ b/clipmon.go @@ -51,7 +51,7 @@ func (cbm *ClipboardMonitor) Run(ctx context.Context) error { query.Text = strings.TrimSpace(currentContent) cbm.Instance.Infof("running query for '%s'..", query.Text) - rs, err := cbm.Instance.Search(ctx, query) + rs, err := cbm.Instance.Search(ctx, query, Strategy1st) if err == nil && rs != nil { if len(rs) > 0 { cbm.Instance.Infof("recieved result. pasting back..") diff --git a/cmd/radium/cli.go b/cmd/radium/cli.go index 8f8c0e2..03b7107 100644 --- a/cmd/radium/cli.go +++ b/cmd/radium/cli.go @@ -2,36 +2,19 @@ package main import ( "fmt" - "log" "reflect" + "strings" - homedir "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" - "github.com/spf13/viper" ) func newCLI() *cobra.Command { cfg := &config{} rootCmd := newRootCmd(cfg) - initConfig := func() { - viper.SetConfigName("radium") - viper.SetConfigType("yaml") - - viper.AddConfigPath("./") - if hd, err := homedir.Dir(); err == nil { - viper.AddConfigPath(hd) - } - viper.AutomaticEnv() - viper.BindPFlags(rootCmd.PersistentFlags()) - - viper.ReadInConfig() - if err := viper.Unmarshal(cfg); err != nil { - log.Fatalf("config err: %s\n", err) - } - } - - cobra.OnInitialize(initConfig) + cobra.OnInitialize(func() { + initConfig(cfg, rootCmd) + }) return rootCmd } @@ -51,6 +34,7 @@ and replace queries with solutions as they come in! rootCmd.PersistentFlags().BoolP("ugly", "u", false, "Print raw output as yaml or json") rootCmd.PersistentFlags().Bool("json", false, "Print output as JSON") + rootCmd.PersistentFlags().StringSlice("sources", nil, "Enable sources") rootCmd.AddCommand(newServeCmd(cfg)) rootCmd.AddCommand(newQueryCmd(cfg)) @@ -67,14 +51,15 @@ func newListSources(cfg *config) *cobra.Command { } cmd.Run = func(_ *cobra.Command, args []string) { - ins := getNewRadiumInstance() + ins := getNewRadiumInstance(*cfg) srcs := ins.GetSources() if L := len(srcs); L > 0 { fmt.Printf("%d source(s) available:\n", L) - for name, src := range srcs { + fmt.Printf("%s\n", strings.Repeat("-", 20)) + for order, src := range srcs { ty := reflect.TypeOf(src) - fmt.Printf("* %s (Type: %s)\n", name, ty.String()) + fmt.Printf("%d. %s (Type: %s)\n", order+1, src.Name, ty.String()) } } else { fmt.Println("No sources configured") diff --git a/cmd/radium/config.go b/cmd/radium/config.go index c14e89e..d1ec9b9 100644 --- a/cmd/radium/config.go +++ b/cmd/radium/config.go @@ -1,6 +1,32 @@ package main +import ( + "log" + + homedir "github.com/mitchellh/go-homedir" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func initConfig(cfg *config, rootCmd *cobra.Command) { + viper.SetDefault("sources", "cheatsh,learnxiny,tldr") + + viper.SetConfigName("radium") + viper.AddConfigPath("./") + if hd, err := homedir.Dir(); err == nil { + viper.AddConfigPath(hd) + } + viper.AutomaticEnv() + viper.BindPFlags(rootCmd.PersistentFlags()) + + viper.ReadInConfig() + if err := viper.Unmarshal(cfg); err != nil { + log.Fatalf("config err: %s\n", err) + } +} + // config struct is used to store CLI configurations. configuration // values are read into this struct using viper type config struct { + Sources []string `json:"sources"` } diff --git a/cmd/radium/query.go b/cmd/radium/query.go index 9b7a941..65f57aa 100644 --- a/cmd/radium/query.go +++ b/cmd/radium/query.go @@ -2,7 +2,7 @@ package main import ( "context" - "fmt" + "errors" "os" "strings" @@ -19,7 +19,9 @@ func newQueryCmd(cfg *config) *cobra.Command { } var attribs []string + var strategy string cmd.Flags().StringSliceVarP(&attribs, "attr", "a", []string{}, "Attributes to narrow the search scope") + cmd.Flags().StringVarP(&strategy, "strategy", "s", "concurrent", "Strategy to use for executing sources") cmd.Run = func(_ *cobra.Command, args []string) { query := radium.Query{} @@ -31,26 +33,24 @@ func newQueryCmd(cfg *config) *cobra.Command { if len(parts) == 2 { query.Attribs[parts[0]] = parts[1] } else { - fmt.Println("Err: invalid attrib format. must be :") + writeOut(cmd, errors.New("invalid attrib format. must be :")) os.Exit(1) } } ctx := context.Background() - ins := getNewRadiumInstance() - rs, err := ins.Search(ctx, query) + ins := getNewRadiumInstance(*cfg) + rs, err := ins.Search(ctx, query, strategy) if err != nil { - writeOut(cmd, map[string]interface{}{ - "error": err.Error(), - }) - } else { - if len(rs) == 1 { - writeOut(cmd, rs[0]) - } else { - writeOut(cmd, rs) - } + writeOut(cmd, err) + os.Exit(1) } + if len(rs) == 1 { + writeOut(cmd, rs[0]) + } else { + writeOut(cmd, rs) + } } return cmd diff --git a/cmd/radium/server.go b/cmd/radium/server.go index 9d6e69d..a5083c6 100644 --- a/cmd/radium/server.go +++ b/cmd/radium/server.go @@ -44,7 +44,7 @@ by passing '--clipboard' or '-C' option and setting '--addr' blank. var wg sync.WaitGroup - ins := getNewRadiumInstance() + ins := getNewRadiumInstance(*cfg) if addr != "" { srv := radium.NewServer(ins) diff --git a/cmd/radium/setup.go b/cmd/radium/setup.go index b4812ac..9785c33 100644 --- a/cmd/radium/setup.go +++ b/cmd/radium/setup.go @@ -1,16 +1,29 @@ package main import ( + "strings" + "github.com/shivylp/radium" "github.com/shivylp/radium/sources" ) -func getNewRadiumInstance() *radium.Instance { +func getNewRadiumInstance(cfg config) *radium.Instance { ins := radium.New(nil, nil) - ins.RegisterSource("cheat.sh", sources.NewCheatSh()) - ins.RegisterSource("tldr", sources.NewTLDR()) - ins.RegisterSource("learnxinyminutes", sources.NewLearnXInYMins()) + for _, src := range cfg.Sources { + switch strings.ToLower(strings.TrimSpace(src)) { + case "cheatsh", "cheat.sh": + ins.RegisterSource("cheat.sh", sources.NewCheatSh()) + case "learnxiny", "lxy", "learnxinyminutes": + ins.RegisterSource("learnxinyminutes", sources.NewLearnXInYMins()) + case "tldr": + ins.RegisterSource("tldr", sources.NewTLDR()) + case "wiki", "wikipedia": + ins.RegisterSource("wikipedia", sources.NewWikipedia()) + default: + ins.Fatalf("unknown source type: %s", src) + } + } return ins } diff --git a/cmd/radium/util.go b/cmd/radium/util.go index 8451c2a..1d28926 100644 --- a/cmd/radium/util.go +++ b/cmd/radium/util.go @@ -21,9 +21,12 @@ func writeOut(cmd *cobra.Command, v interface{}) { } func tryPrettyPrint(v interface{}) { - if article, ok := v.(radium.Article); ok { - fmt.Println(article.Content) - } else { + switch v.(type) { + case radium.Article: + fmt.Println((v.(radium.Article)).Content) + case error: + fmt.Printf("error: %s\n", v) + default: rawDump(v, true) } } diff --git a/interfaces.go b/interfaces.go index e84422d..e033803 100644 --- a/interfaces.go +++ b/interfaces.go @@ -1,14 +1,22 @@ package radium import ( + "context" "fmt" "log" + "os" ) // Source implementation is responsible for providing // external data source to query for results. type Source interface { - Search(q Query) ([]Article, error) + Search(ctx context.Context, q Query) ([]Article, error) +} + +// RegisteredSource embeds given Source along with the registered name. +type RegisteredSource struct { + Name string + Source } // Logger implementation should provide logging @@ -18,6 +26,7 @@ type Logger interface { Infof(format string, args ...interface{}) Warnf(format string, args ...interface{}) Errorf(format string, args ...interface{}) + Fatalf(format string, args ...interface{}) } // Cache implementation is responsible for caching @@ -45,5 +54,10 @@ func (dl defaultLogger) Warnf(format string, args ...interface{}) { } func (dl defaultLogger) Errorf(format string, args ...interface{}) { - log.Printf("ERR : %s", fmt.Sprintf(format, args...)) + log.Printf("ERROR: %s", fmt.Sprintf(format, args...)) +} + +func (dl defaultLogger) Fatalf(format string, args ...interface{}) { + log.Printf("FATAL: %s", fmt.Sprintf(format, args...)) + os.Exit(1) } diff --git a/radium.go b/radium.go index dc419c7..0a4b377 100644 --- a/radium.go +++ b/radium.go @@ -5,16 +5,25 @@ import ( "fmt" ) +// Default registered strategies +const ( + Strategy1st = "1st" + StrategyConcurrent = "concurrent" +) + // New initializes an instance of radium func New(cache Cache, logger Logger) *Instance { ins := &Instance{} - ins.sources = map[string]Source{} ins.cache = cache if logger == nil { logger = defaultLogger{} } ins.Logger = logger + ins.strategies = map[string]Strategy{ + Strategy1st: NewNthResult(1, ins.Logger), + StrategyConcurrent: NewConcurrent(ins.Logger), + } return ins } @@ -23,28 +32,42 @@ func New(cache Cache, logger Logger) *Instance { type Instance struct { Logger - sources map[string]Source - cache Cache + sources []RegisteredSource + strategies map[string]Strategy + cache Cache +} + +// RegisterStrategy adds a new source to the query sources +func (ins *Instance) RegisterStrategy(name string, strategy Strategy) { + if ins.strategies == nil { + ins.strategies = map[string]Strategy{} + } + + ins.strategies[name] = strategy } // RegisterSource adds a new source to the query sources -func (ins Instance) RegisterSource(name string, src Source) error { - if _, exists := ins.sources[name]; exists { - return fmt.Errorf("source with given name already exists") +func (ins *Instance) RegisterSource(name string, src Source) error { + for _, entry := range ins.sources { + if name == entry.Name { + return fmt.Errorf("source with given name already exists") + } } - ins.sources[name] = src + ins.sources = append(ins.sources, RegisteredSource{ + Name: name, + Source: src, + }) return nil } // GetSources returns a list of registered sources -func (ins Instance) GetSources() map[string]Source { +func (ins Instance) GetSources() []RegisteredSource { return ins.sources } // Search using given query and return results if any -func (ins Instance) Search(ctx context.Context, query Query) ([]Article, error) { - +func (ins Instance) Search(ctx context.Context, query Query, strategyName string) ([]Article, error) { if err := query.Validate(); err != nil { return nil, err } @@ -53,42 +76,18 @@ func (ins Instance) Search(ctx context.Context, query Query) ([]Article, error) return rs, nil } - results := ins.findFromSources(ctx, query) - go ins.performCaching(query, results) - return results, nil -} - -func (ins Instance) findFromSources(ctx context.Context, query Query) []Article { - var results []Article - for srcName, src := range ins.sources { - resList, err := src.Search(query) - if err != nil { - ins.Warnf("source '%s' failed: %s", srcName, err) - continue - } - - for _, res := range resList { - select { - case <-ctx.Done(): - break - default: - } - - if err := res.Validate(); err != nil { - ins.Warnf("invalid result from source '%s': %s", srcName, err) - continue - } - - res.Source = srcName - results = append(results, res) - } + strategy, exists := ins.strategies[strategyName] + if !exists { + return nil, fmt.Errorf("no such strategy: %s", strategyName) } - if results == nil { - results = []Article{} + results, err := strategy.Execute(ctx, query, ins.sources) + if err != nil { + return nil, err } - return results + go ins.performCaching(query, results) + return results, nil } func (ins Instance) findInCache(query Query) []Article { @@ -96,7 +95,7 @@ func (ins Instance) findInCache(query Query) []Article { return nil } - rs, err := ins.cache.Search(query) + rs, err := ins.cache.Search(context.Background(), query) if err != nil { ins.Warnf("failed to search in cache: %s", err) return nil diff --git a/radium.yaml b/radium.yaml new file mode 100644 index 0000000..526ff43 --- /dev/null +++ b/radium.yaml @@ -0,0 +1,4 @@ +sources: + - cheat.sh + - tldr + - learnxiny diff --git a/server.go b/server.go index 44059ab..2e3fe25 100644 --- a/server.go +++ b/server.go @@ -35,6 +35,11 @@ func (srv Server) handleSearch(wr http.ResponseWriter, req *http.Request) { wr.Header().Set("Content-type", "application/json") query := Query{} + strategy := req.FormValue("strategy") + if strategy == "" { + strategy = "1st" + } + query.Text = req.FormValue("q") query.Attribs = map[string]string{} @@ -45,7 +50,7 @@ func (srv Server) handleSearch(wr http.ResponseWriter, req *http.Request) { } ctx := req.Context() - rs, err := srv.ins.Search(ctx, query) + rs, err := srv.ins.Search(ctx, query, strategy) if err != nil { wr.WriteHeader(http.StatusNotFound) json.NewEncoder(wr).Encode(map[string]interface{}{ @@ -58,9 +63,9 @@ func (srv Server) handleSearch(wr http.ResponseWriter, req *http.Request) { func (srv Server) handleSources(wr http.ResponseWriter, req *http.Request) { sources := map[string]string{} - for name, src := range srv.ins.GetSources() { + for _, src := range srv.ins.GetSources() { ty := reflect.TypeOf(src) - sources[name] = ty.String() + sources[src.Name] = ty.String() } wr.Header().Set("Content-type", "application/json") json.NewEncoder(wr).Encode(sources) diff --git a/sources/cheatsh.go b/sources/cheatsh.go index fd380f3..8b69ccd 100644 --- a/sources/cheatsh.go +++ b/sources/cheatsh.go @@ -1,6 +1,7 @@ package sources import ( + "context" "fmt" "io/ioutil" "net/http" @@ -25,7 +26,7 @@ type CheatSh struct { // Search performs an HTTP request to http://cheat.sh to find // results matching the given query. -func (csh CheatSh) Search(query radium.Query) ([]radium.Article, error) { +func (csh CheatSh) Search(ctx context.Context, query radium.Query) ([]radium.Article, error) { var results []radium.Article if lang, found := query.Attribs["language"]; found { @@ -35,7 +36,7 @@ func (csh CheatSh) Search(query radium.Query) ([]radium.Article, error) { color = true } } - res, err := csh.makeLangRequest(query.Text, lang, color) + res, err := csh.makeLangRequest(ctx, query.Text, lang, color) if err == nil { results = append(results, *res) } @@ -43,7 +44,7 @@ func (csh CheatSh) Search(query radium.Query) ([]radium.Article, error) { return results, nil } -func (csh CheatSh) makeLangRequest(q string, lang string, color bool) (*radium.Article, error) { +func (csh CheatSh) makeLangRequest(ctx context.Context, q string, lang string, color bool) (*radium.Article, error) { queryStr := url.QueryEscape(strings.Replace(q, " ", "+", -1)) csURL := fmt.Sprintf("http://cheat.sh/%s/%s", url.QueryEscape(lang), queryStr) @@ -58,6 +59,7 @@ func (csh CheatSh) makeLangRequest(q string, lang string, color bool) (*radium.A req, _ := http.NewRequest(http.MethodGet, csURL, nil) req.Header.Set("User-Agent", "curl/7.54.0") + req.WithContext(ctx) resp, err := client.Do(req) if err != nil { diff --git a/sources/learnxiny.go b/sources/learnxiny.go index 223a97d..f55563b 100644 --- a/sources/learnxiny.go +++ b/sources/learnxiny.go @@ -1,6 +1,7 @@ package sources import ( + "context" "fmt" "io/ioutil" "net/http" @@ -27,17 +28,16 @@ type LearnXInY struct { // Search attempts to download the appropriate markdown file from learn-x-in-y // repository and format it as a result -func (lxy LearnXInY) Search(query radium.Query) ([]radium.Article, error) { +func (lxy LearnXInY) Search(ctx context.Context, query radium.Query) ([]radium.Article, error) { var rs []radium.Article lang := strings.Replace(query.Text, " ", "-", -1) - if res, err := lxy.getLanguageMarkdown(lang); err == nil { + if res, err := lxy.getLanguageMarkdown(ctx, lang); err == nil { rs = append(rs, *res) } return rs, nil } -func (lxy LearnXInY) getLanguageMarkdown(language string) (*radium.Article, error) { - +func (lxy LearnXInY) getLanguageMarkdown(ctx context.Context, language string) (*radium.Article, error) { ghURL := fmt.Sprintf(learnXInYURL, url.QueryEscape(language)) timeout := time.Duration(5 * time.Second) client := http.Client{ @@ -46,6 +46,7 @@ func (lxy LearnXInY) getLanguageMarkdown(language string) (*radium.Article, erro req, _ := http.NewRequest(http.MethodGet, ghURL, nil) req.Header.Set("User-Agent", "curl/7.54.0") + req.WithContext(ctx) resp, err := client.Do(req) if err != nil { diff --git a/sources/radium.go b/sources/radium.go index 3c9448e..41d2c07 100644 --- a/sources/radium.go +++ b/sources/radium.go @@ -1,6 +1,7 @@ package sources import ( + "context" "encoding/json" "fmt" "net/http" @@ -26,8 +27,7 @@ type Radium struct { // Search makes a GET /search to the radium server and formats the // response -func (rad Radium) Search(query radium.Query) ([]radium.Article, error) { - +func (rad Radium) Search(ctx context.Context, query radium.Query) ([]radium.Article, error) { timeout := time.Duration(5 * time.Second) client := http.Client{ Timeout: timeout, @@ -52,6 +52,7 @@ func (rad Radium) Search(query radium.Query) ([]radium.Article, error) { req, _ := http.NewRequest(http.MethodGet, urlObj.String(), nil) req.Header.Set("User-Agent", "curl/7.54.0") + req.WithContext(ctx) resp, err := client.Do(req) if err != nil { diff --git a/sources/tldr.go b/sources/tldr.go index 353338d..5d26b48 100644 --- a/sources/tldr.go +++ b/sources/tldr.go @@ -1,6 +1,7 @@ package sources import ( + "context" "fmt" "io/ioutil" "net/http" @@ -24,7 +25,7 @@ type TLDR struct { } // Search for a particular query in tldr-pages repository -func (tldr TLDR) Search(query radium.Query) ([]radium.Article, error) { +func (tldr TLDR) Search(ctx context.Context, query radium.Query) ([]radium.Article, error) { var rs []radium.Article tool := strings.Replace(query.Text, " ", "-", -1) @@ -34,14 +35,14 @@ func (tldr TLDR) Search(query radium.Query) ([]radium.Article, error) { platform = val } - res, err := tldr.getPlatformToolInfo(tool, platform) + res, err := tldr.getPlatformToolInfo(ctx, tool, platform) if err == nil { rs = append(rs, *res) } return rs, nil } -func (tldr TLDR) getPlatformToolInfo(tool, platform string) (*radium.Article, error) { +func (tldr TLDR) getPlatformToolInfo(ctx context.Context, tool, platform string) (*radium.Article, error) { rawGitURL := "https://raw.githubusercontent.com/tldr-pages/tldr/master/pages/%s/%s.md" ghURL := fmt.Sprintf(rawGitURL, url.QueryEscape(platform), url.QueryEscape(tool)) @@ -52,6 +53,7 @@ func (tldr TLDR) getPlatformToolInfo(tool, platform string) (*radium.Article, er req, _ := http.NewRequest(http.MethodGet, ghURL, nil) req.Header.Set("User-Agent", "curl/7.54.0") + req.WithContext(ctx) resp, err := client.Do(req) if err != nil { diff --git a/sources/wikipedia.go b/sources/wikipedia.go new file mode 100644 index 0000000..f22bb5c --- /dev/null +++ b/sources/wikipedia.go @@ -0,0 +1,22 @@ +package sources + +import ( + "context" + + "github.com/shivylp/radium" +) + +// NewWikipedia initializes wikipedia based radium source implementation. +func NewWikipedia() *Wikipedia { + return &Wikipedia{} +} + +// Wikipedia implements Source interface using wikipedia for lookups. +type Wikipedia struct { +} + +// Search will query en.wikipedia.com to find results and extracts the first +// paragraph of the page. +func (wiki *Wikipedia) Search(ctx context.Context, query radium.Query) ([]radium.Article, error) { + return nil, nil +} diff --git a/strategies/concurrent.go b/strategies/concurrent.go new file mode 100644 index 0000000..c1008f6 --- /dev/null +++ b/strategies/concurrent.go @@ -0,0 +1,71 @@ +package strategies + +import ( + "context" + "sync" + + "github.com/shivylp/radium" +) + +// NewConcurrent initializes a concurrent radium strategy +func NewConcurrent(logger radium.Logger) *Concurrent { + return &Concurrent{ + Logger: logger, + } +} + +// Concurrent is a radium strategy implementation. +type Concurrent struct { + radium.Logger +} + +// Execute the query against given list of sources concurrently. This strategy +// ingores the source errors and simply logs them. +func (con Concurrent) Execute(ctx context.Context, query radium.Query, sources []radium.RegisteredSource) ([]radium.Article, error) { + results := newSafeResults() + wg := &sync.WaitGroup{} + + for _, source := range sources { + wg.Add(1) + + go func(wg *sync.WaitGroup, src radium.RegisteredSource, rs *safeResults) { + srcResults, err := src.Search(ctx, query) + if err != nil { + con.Warnf("source '%s' failed: %s", src.Name, err) + return + } + + rs.extend(src.Name, con.Logger, srcResults) + wg.Done() + }(wg, source, results) + } + + wg.Wait() + return results.results, nil +} + +func newSafeResults() *safeResults { + return &safeResults{ + mu: &sync.Mutex{}, + } +} + +type safeResults struct { + mu *sync.Mutex + results []radium.Article +} + +func (sr *safeResults) extend(results []radium.Article, srcName string, logger radium.Logger) { + sr.mu.Lock() + defer sr.mu.Unlock() + + for _, res := range results { + if err := res.Validate(); err != nil { + logger.Warnf("ignoring invalid result from source '%s': %s", srcName, err) + continue + } + + sr.results = append(sr.results, res) + } + +} diff --git a/strategies/nth.go b/strategies/nth.go new file mode 100644 index 0000000..f3cb89e --- /dev/null +++ b/strategies/nth.go @@ -0,0 +1,45 @@ +package strategies + +import ( + "context" + + "github.com/shivylp/radium" +) + +// NewNthResult initializes NthResult strategy with given n +func NewNthResult(n int, logger radium.Logger) *NthResult { + return &NthResult{stopAt: n, Logger: logger} +} + +// NthResult implements a radium search strategy. This strategy +// executes search in the given order of sources and stops at nth +// result or if all the sources are executed. +type NthResult struct { + radium.Logger + + stopAt int +} + +// Execute each source in srcs until n results are obtained or all sources have +// been executed. This strategy returns on first error. +func (nth *NthResult) Execute(ctx context.Context, query radium.Query, srcs []radium.RegisteredSource) ([]radium.Article, error) { + results := []radium.Article{} + for _, src := range srcs { + select { + case <-ctx.Done(): + break + default: + } + + srcResults, err := src.Search(ctx, query) + if err != nil { + return nil, err + } + + results = append(results, srcResults...) + if len(results) >= nth.stopAt { + break + } + } + return results, nil +} diff --git a/strategy.go b/strategy.go new file mode 100644 index 0000000..fde79e6 --- /dev/null +++ b/strategy.go @@ -0,0 +1,113 @@ +package radium + +import ( + "context" + "sync" +) + +// Strategy implementation is responsible for performing queries +// against given set of sources using a particular approach. +type Strategy interface { + Execute(ctx context.Context, query Query, sources []RegisteredSource) ([]Article, error) +} + +// NewConcurrent initializes a concurrent radium strategy +func NewConcurrent(logger Logger) *Concurrent { + return &Concurrent{ + Logger: logger, + } +} + +// Concurrent is a radium strategy implementation. +type Concurrent struct { + Logger +} + +// Execute the query against given list of sources concurrently. This strategy +// ingores the source errors and simply logs them. +func (con Concurrent) Execute(ctx context.Context, query Query, sources []RegisteredSource) ([]Article, error) { + results := newSafeResults() + wg := &sync.WaitGroup{} + + for _, source := range sources { + wg.Add(1) + + go func(wg *sync.WaitGroup, src RegisteredSource, rs *safeResults) { + srcResults, err := src.Search(ctx, query) + if err != nil { + con.Warnf("source '%s' failed: %s", src.Name, err) + return + } + + rs.extend(srcResults, src.Name, con.Logger) + wg.Done() + }(wg, source, results) + } + + wg.Wait() + return results.results, nil +} + +func newSafeResults() *safeResults { + return &safeResults{ + mu: &sync.Mutex{}, + } +} + +type safeResults struct { + mu *sync.Mutex + results []Article +} + +func (sr *safeResults) extend(results []Article, srcName string, logger Logger) { + sr.mu.Lock() + defer sr.mu.Unlock() + + for _, res := range results { + if err := res.Validate(); err != nil { + logger.Warnf("ignoring invalid result from source '%s': %s", srcName, err) + continue + } + + sr.results = append(sr.results, res) + } + +} + +// NewNthResult initializes NthResult strategy with given n +func NewNthResult(n int, logger Logger) *NthResult { + return &NthResult{stopAt: n, Logger: logger} +} + +// NthResult implements a radium search strategy. This strategy +// executes search in the given order of sources and stops at nth +// result or if all the sources are executed. +type NthResult struct { + Logger + + stopAt int +} + +// Execute each source in srcs until n results are obtained or all sources have +// been executed. This strategy returns on first error. +func (nth *NthResult) Execute(ctx context.Context, query Query, srcs []RegisteredSource) ([]Article, error) { + results := []Article{} + for _, src := range srcs { + select { + case <-ctx.Done(): + break + default: + } + + srcResults, err := src.Search(ctx, query) + if err != nil { + return nil, err + } + + results = append(results, srcResults...) + if len(results) >= nth.stopAt { + break + } + } + return results, nil +}