diff --git a/cli/agent/agent.go b/cli/agent/agent.go index cfbd83a..09b4fdb 100644 --- a/cli/agent/agent.go +++ b/cli/agent/agent.go @@ -77,7 +77,9 @@ var Cmd = &cobra.Command{ }() logger.Info().Msg("listening os signal: SIGINT, SIGTERM") - grpcServer = grpc.NewServer() + grpcServer = grpc.NewServer( + grpc.UnaryInterceptor(util.UnaryServerInterceptor()), + ) logger.Info().Msg("initalized gRPC server") lakeService, err = lake.NewService( ctx, diff --git a/cli/client/client.go b/cli/client/client.go index 07209ff..07651a9 100644 --- a/cli/client/client.go +++ b/cli/client/client.go @@ -170,65 +170,64 @@ var Cmd = &cobra.Command{ go func() { defer wg.Done() reader := bufio.NewReader(os.Stdin) - ok := true for { - printInput(false) - input, err := reader.ReadString('\n') - if err == io.EOF { + select { + case <-ctx.Done(): return - } - if err != nil { - fmt.Println(err) - continue - } - input = strings.TrimSuffix(input, "\n") - if input == "" { - continue - } - if !ok { - cancel() - return - } - go func() { - msg := Msg{ - Data: []byte(input), - Metadata: map[string][]byte{ - "nickname": []byte(nickname), - }, - } - data, err := json.Marshal(msg) - if err != nil { - fmt.Println(err) + default: + printInput(false) + input, err := reader.ReadString('\n') + if err == io.EOF { return } - sigDataBytes, err := privKey.Sign(data) if err != nil { fmt.Println(err) - return + continue + } + input = strings.TrimSuffix(input, "\n") + if input == "" { + continue } + go func() { + msg := Msg{ + Data: []byte(input), + Metadata: map[string][]byte{ + "nickname": []byte(nickname), + }, + } + data, err := json.Marshal(msg) + if err != nil { + fmt.Println(err) + return + } + sigDataBytes, err := privKey.Sign(data) + if err != nil { + fmt.Println(err) + return + } - pubRes, err := cli.Publish(ctx, &pb.PublishReq{ - TopicId: topicID, - MsgCapsule: &pb.MsgCapsule{ - Data: data, - Signature: &pb.Signature{ - PubKey: pubKeyBytes, - Data: sigDataBytes, + pubRes, err := cli.Publish(ctx, &pb.PublishReq{ + TopicId: topicID, + MsgCapsule: &pb.MsgCapsule{ + Data: data, + Signature: &pb.Signature{ + PubKey: pubKeyBytes, + Data: sigDataBytes, + }, }, - }, - }) - if err != nil { - fmt.Println(err) - ok = false - return - } + }) + if err != nil { + fmt.Println(err) + return + } - // check publish res - if !pubRes.GetOk() { - fmt.Println("failed to send message") - return - } - }() + // check publish res + if !pubRes.GetOk() { + fmt.Println("failed to send message") + return + } + }() + } } }() diff --git a/go.mod b/go.mod index b2e8fe9..a01735b 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/rs/zerolog v1.29.1 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 + go.uber.org/ratelimit v0.3.0 google.golang.org/grpc v1.55.0 google.golang.org/protobuf v1.30.0 ) diff --git a/go.sum b/go.sum index 913a120..db399be 100644 --- a/go.sum +++ b/go.sum @@ -682,6 +682,8 @@ go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKY go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/ratelimit v0.3.0 h1:IdZd9wqvFXnvLvSEBo0KPcGfkoBGNkpTHlrE3Rcjkjw= +go.uber.org/ratelimit v0.3.0/go.mod h1:So5LG7CV1zWpY1sHe+DXTJqQvOx+FFPFaAs2SnoyBaI= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= diff --git a/msg/box.go b/msg/box.go index f3e096e..9fc64b4 100644 --- a/msg/box.go +++ b/msg/box.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strconv" "sync" "time" @@ -256,21 +255,15 @@ func (box *Box) LeaveSub(subscriberID string) error { } func init() { - tmp, err := getEnvInt("INTERNAL_CHAN_BUFFER_SIZE", DefaultInternalChanBufferSize) + tmp, err := util.GetEnvInt("INTERNAL_CHAN_BUFFER_SIZE", DefaultInternalChanBufferSize) if err != nil { panic(err) } internalChanBufferSize = tmp - tmp, err = getEnvInt("EXTERNAL_CHAN_BUFFER_SIZE", DefaultExternalChanBufferSize) + tmp, err = util.GetEnvInt("EXTERNAL_CHAN_BUFFER_SIZE", DefaultExternalChanBufferSize) if err != nil { panic(err) } externalChanBufferSize = tmp } - -func getEnvInt(key string, fallback int) (int, error) { - tmpStr := strconv.Itoa(fallback) - tmpStr = util.GetEnv(key, tmpStr) - return strconv.Atoi(tmpStr) -} diff --git a/util/grpc.go b/util/grpc.go new file mode 100644 index 0000000..aba1e20 --- /dev/null +++ b/util/grpc.go @@ -0,0 +1,32 @@ +package util + +import ( + "context" + + "go.uber.org/ratelimit" + "google.golang.org/grpc" +) + +const ( + DefaultUnaryServerInterceptorRateLimit = 10000 +) + +var ( + unaryServerInterceptorRateLimit int +) + +func UnaryServerInterceptor() grpc.UnaryServerInterceptor { + rl := ratelimit.New(unaryServerInterceptorRateLimit) + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + rl.Take() + return handler(ctx, req) + } +} + +func init() { + tmp, err := getEnvInt("UNARY_SERVER_INTERCEPTOR_RATE_LIMIT", DefaultUnaryServerInterceptorRateLimit) + if err != nil { + panic(err) + } + unaryServerInterceptorRateLimit = tmp +} diff --git a/util/util.go b/util/util.go index 9521e63..98846e9 100644 --- a/util/util.go +++ b/util/util.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/base64" "os" + "strconv" ) func CheckStrLen(target string, min, max int) bool { @@ -20,7 +21,7 @@ func GenerateRandomBase64String(size int) string { return base64.RawStdEncoding.EncodeToString(bytes) } -func GetEnv(key, fallback string) string { +func getEnv(key, fallback string) string { value, ok := os.LookupEnv(key) if !ok { return fallback @@ -28,6 +29,20 @@ func GetEnv(key, fallback string) string { return value } +func GetEnv(key, fallback string) string { + return getEnv(key, fallback) +} + +func getEnvInt(key string, fallback int) (int, error) { + tmpStr := strconv.Itoa(fallback) + tmpStr = getEnv(key, tmpStr) + return strconv.Atoi(tmpStr) +} + +func GetEnvInt(key string, fallback int) (int, error) { + return getEnvInt(key, fallback) +} + func GetLogLevel() string { return GetEnv("LOG_LEVEL", "info") }