diff --git a/.golangci.yml b/.golangci.yml index 3d9c8b0fe..6fe97685f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -132,7 +132,6 @@ linters: - deadcode - depguard - dogsled - - dupl - errcheck - funlen - gochecknoinits diff --git a/pkg/networkservice/chains/client/client.go b/pkg/networkservice/chains/client/client.go index a8467383d..aca90ef0c 100644 --- a/pkg/networkservice/chains/client/client.go +++ b/pkg/networkservice/chains/client/client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2022 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -28,6 +28,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/clienturl" "github.com/networkservicemesh/sdk/pkg/networkservice/common/connect" "github.com/networkservicemesh/sdk/pkg/networkservice/common/dial" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/limit" "github.com/networkservicemesh/sdk/pkg/networkservice/common/null" "github.com/networkservicemesh/sdk/pkg/networkservice/common/refresh" "github.com/networkservicemesh/sdk/pkg/networkservice/common/trimpath" @@ -63,6 +64,7 @@ func NewClient(ctx context.Context, clientOpts ...Option) networkservice.Network dial.WithDialOptions(opts.dialOptions...), dial.WithDialTimeout(opts.dialTimeout), ), + limit.NewClient(), }, append( opts.additionalFunctionality, diff --git a/pkg/networkservice/common/limit/client.go b/pkg/networkservice/common/limit/client.go new file mode 100644 index 000000000..29160143c --- /dev/null +++ b/pkg/networkservice/common/limit/client.go @@ -0,0 +1,114 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package limit provides a chain element that can set limits for the RPC calls. +package limit + +import ( + "context" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/networkservice" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +// Option overrides default values +type Option func(c *limitClient) + +// WithDialLimit sets dial limit +func WithDialLimit(d time.Duration) Option { + return func(c *limitClient) { + c.dialLimit = d + } +} + +type limitClient struct { + dialLimit time.Duration +} + +// NewClient returns new NetworkServiceClient that limits rpc +func NewClient(opts ...Option) networkservice.NetworkServiceClient { + ret := &limitClient{ + dialLimit: time.Minute, + } + + for _, opt := range opts { + opt(ret) + } + + return ret +} + +func (n *limitClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.Server(ctx).Request(ctx, request) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.Server(ctx).Request(ctx, request) + } + + doneCh := make(chan struct{}) + defer close(doneCh) + + logger := log.FromContext(ctx).WithField("throttleClient", "Request") + + go func() { + select { + case <-time.After(n.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + return next.Client(ctx).Request(ctx, request, opts...) +} + +func (n *limitClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.Server(ctx).Close(ctx, conn) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.Server(ctx).Close(ctx, conn) + } + + doneCh := make(chan struct{}) + defer close(doneCh) + + logger := log.FromContext(ctx).WithField("throttleClient", "Close") + + go func() { + select { + case <-time.After(n.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + return next.Client(ctx).Close(ctx, conn, opts...) +} diff --git a/pkg/networkservice/common/limit/client_test.go b/pkg/networkservice/common/limit/client_test.go new file mode 100644 index 000000000..ded243301 --- /dev/null +++ b/pkg/networkservice/common/limit/client_test.go @@ -0,0 +1,122 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package limit_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/limit" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkclose" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkrequest" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" +) + +type myConnection struct { + closed atomic.Bool + grpc.ClientConnInterface +} + +func (cc *myConnection) Close() error { + cc.closed.Store(true) + return nil +} + +func Test_DialLimitShouldCalled_OnLimitReached_Request(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceClient( + metadata.NewClient(), + clientconn.NewClient(cc), + limit.NewClient(limit.WithDialLimit(time.Second/5)), + checkrequest.NewClient(t, func(t *testing.T, nsr *networkservice.NetworkServiceRequest) { + time.Sleep(time.Second / 4) + }), + ) + + _, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +} + +func Test_DialLimitShouldCalled_OnLimitReached_Close(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceClient( + metadata.NewClient(), + clientconn.NewClient(cc), + limit.NewClient(limit.WithDialLimit(time.Second/5)), + checkclose.NewClient(t, func(t *testing.T, nsr *networkservice.Connection) { + time.Sleep(time.Second / 4) + }), + ) + + _, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{}) + _, _ = myChain.Close(context.Background(), &networkservice.Connection{}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +} + +func Test_DialLimitShouldNotBeCalled_OnSuccesRequest(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceClient( + metadata.NewClient(), + clientconn.NewClient(cc), + limit.NewClient(limit.WithDialLimit(time.Second/5)), + ) + + _, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{}) + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +} + +func Test_DialLimitShouldNotBeCalled_OnSuccessClose(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceClient( + metadata.NewClient(), + clientconn.NewClient(cc), + limit.NewClient(limit.WithDialLimit(time.Second/5)), + ) + + _, _ = myChain.Request(context.Background(), &networkservice.NetworkServiceRequest{}) + _, _ = myChain.Close(context.Background(), &networkservice.Connection{}) + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +} diff --git a/pkg/registry/chains/client/ns_client.go b/pkg/registry/chains/client/ns_client.go index 749151548..bcdbdce72 100644 --- a/pkg/registry/chains/client/ns_client.go +++ b/pkg/registry/chains/client/ns_client.go @@ -31,6 +31,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/heal" + "github.com/networkservicemesh/sdk/pkg/registry/common/limit" "github.com/networkservicemesh/sdk/pkg/registry/common/null" "github.com/networkservicemesh/sdk/pkg/registry/common/retry" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" @@ -63,6 +64,7 @@ func NewNetworkServiceRegistryClient(ctx context.Context, opts ...Option) regist dial.WithDialTimeout(clientOpts.dialTimeout), dial.WithDialOptions(clientOpts.dialOptions...), ), + limit.NewNetworkServiceRegistryClient(), }, append( clientOpts.nsAdditionalFunctionality, diff --git a/pkg/registry/chains/client/nse_client.go b/pkg/registry/chains/client/nse_client.go index 7ef1d7ae5..a5cc0af39 100644 --- a/pkg/registry/chains/client/nse_client.go +++ b/pkg/registry/chains/client/nse_client.go @@ -32,6 +32,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/dial" "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/heal" + "github.com/networkservicemesh/sdk/pkg/registry/common/limit" "github.com/networkservicemesh/sdk/pkg/registry/common/null" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" "github.com/networkservicemesh/sdk/pkg/registry/common/retry" @@ -66,6 +67,7 @@ func NewNetworkServiceEndpointRegistryClient(ctx context.Context, opts ...Option dial.WithDialTimeout(clientOpts.dialTimeout), dial.WithDialOptions(clientOpts.dialOptions...), ), + limit.NewNetworkServiceEndpointRegistryClient(), }, append( clientOpts.nseAdditionalFunctionality, diff --git a/pkg/registry/chains/proxydns/server_ns_test.go b/pkg/registry/chains/proxydns/server_ns_test.go index f7fba1294..316273c86 100644 --- a/pkg/registry/chains/proxydns/server_ns_test.go +++ b/pkg/registry/chains/proxydns/server_ns_test.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2022 Cisco Systems, Inc. +// Copyright (c) 2022-2024 Cisco Systems, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -143,7 +143,7 @@ func TestLocalDomain_NetworkServiceRegistry(t *testing.T) { registryclient.WithDialOptions(grpc.WithTransportCredentials(insecure.NewCredentials())), registryclient.WithClientURL(domain1.Registry.URL)) - stream, err := client2.Find(context.Background(), ®istryapi.NetworkServiceQuery{ + stream, err := client2.Find(ctx, ®istryapi.NetworkServiceQuery{ NetworkService: ®istryapi.NetworkService{ Name: expected.Name, }, diff --git a/pkg/registry/common/limit/common.go b/pkg/registry/common/limit/common.go new file mode 100644 index 000000000..2f3b572c1 --- /dev/null +++ b/pkg/registry/common/limit/common.go @@ -0,0 +1,33 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package limit + +import "time" + +type limitConfig struct { + dialLimit time.Duration +} + +// Option overrides default values +type Option func(cfg *limitConfig) + +// WithDialLimit sets dial time limit +func WithDialLimit(t time.Duration) Option { + return Option(func(cfg *limitConfig) { + cfg.dialLimit = t + }) +} diff --git a/pkg/registry/common/limit/ns_client.go b/pkg/registry/common/limit/ns_client.go new file mode 100644 index 000000000..d9079d81c --- /dev/null +++ b/pkg/registry/common/limit/ns_client.go @@ -0,0 +1,146 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package limit provides a chain element that can set limits for the RPC calls. +package limit + +import ( + "context" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type limitNSClient struct { + cfg *limitConfig +} + +func (n *limitNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + } + + doneCh := make(chan struct{}) + defer close(doneCh) + + logger := log.FromContext(ctx).WithField("throttleNSClient", "Register") + + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + return next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (n *limitNSClient) Find(ctx context.Context, in *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) + } + + logger := log.FromContext(ctx).WithField("throttleNSClient", "Find") + + doneCh := make(chan struct{}) + defer close(doneCh) + + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + + resp, err := next.NetworkServiceRegistryClient(ctx).Find(ctx, in, opts...) + if err == nil { + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-resp.Context().Done(): + return + } + }() + } + return resp, err +} + +func (n *limitNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + } + + doneCh := make(chan struct{}) + defer close(doneCh) + + logger := log.FromContext(ctx).WithField("throttleNSClient", "Unregister") + + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +// NewNetworkServiceRegistryClient - returns a new null client that does nothing but call next.NetworkServiceRegistryClient(ctx). +func NewNetworkServiceRegistryClient(opts ...Option) registry.NetworkServiceRegistryClient { + cfg := &limitConfig{ + dialLimit: time.Minute, + } + for _, opt := range opts { + opt(cfg) + } + return &limitNSClient{ + cfg: cfg, + } +} diff --git a/pkg/registry/common/limit/ns_client_test.go b/pkg/registry/common/limit/ns_client_test.go new file mode 100644 index 000000000..cd7e647f0 --- /dev/null +++ b/pkg/registry/common/limit/ns_client_test.go @@ -0,0 +1,124 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package limit_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/limit" + "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" + + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/utils/metadata" +) + +type myConnection struct { + closed atomic.Bool + grpc.ClientConnInterface +} + +func (cc *myConnection) Close() error { + cc.closed.Store(true) + return nil +} + +func Test_DialLimitShouldCalled_OnLimitReached(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceRegistryClient( + metadata.NewNetworkServiceClient(), + clientconn.NewNetworkServiceRegistryClient(), + checkcontext.NewNSClient(t, func(t *testing.T, ctx context.Context) { + clientconn.Store(ctx, cc) + }), + limit.NewNetworkServiceRegistryClient(limit.WithDialLimit(time.Second/5)), + checkcontext.NewNSClient(t, func(t *testing.T, ctx context.Context) { + time.Sleep(time.Second / 5) + }), + ) + + _, _ = myChain.Register(context.Background(), ®istry.NetworkService{Name: t.Name()}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + cc.closed.Store(false) + + _, _ = myChain.Find(context.Background(), ®istry.NetworkServiceQuery{NetworkService: ®istry.NetworkService{Name: t.Name()}}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + cc.closed.Store(false) + + _, _ = myChain.Unregister(context.Background(), ®istry.NetworkService{Name: t.Name()}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +} + +func Test_DialLimitShouldNotBeCalled_OnSuccess(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceRegistryClient( + metadata.NewNetworkServiceClient(), + clientconn.NewNetworkServiceRegistryClient(), + checkcontext.NewNSClient(t, func(t *testing.T, ctx context.Context) { + clientconn.Store(ctx, cc) + }), + limit.NewNetworkServiceRegistryClient(limit.WithDialLimit(time.Second/5)), + ) + + ctx, cancel := context.WithCancel(context.Background()) + _, _ = myChain.Register(ctx, ®istry.NetworkService{Name: t.Name()}) + cancel() + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + ctx, cancel = context.WithCancel(context.Background()) + _, _ = myChain.Find(ctx, ®istry.NetworkServiceQuery{NetworkService: ®istry.NetworkService{Name: t.Name()}}) + cancel() + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + ctx, cancel = context.WithCancel(context.Background()) + _, _ = myChain.Unregister(ctx, ®istry.NetworkService{Name: t.Name()}) + cancel() + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +} diff --git a/pkg/registry/common/limit/nse_client.go b/pkg/registry/common/limit/nse_client.go new file mode 100644 index 000000000..4d98a4d54 --- /dev/null +++ b/pkg/registry/common/limit/nse_client.go @@ -0,0 +1,148 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package limit + +import ( + "context" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/log" +) + +type limitNSEClient struct { + cfg *limitConfig +} + +func (n *limitNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + } + + doneCh := make(chan struct{}) + defer close(doneCh) + + logger := log.FromContext(ctx).WithField("throttleNSEClient", "Register") + + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (n *limitNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) + } + + logger := log.FromContext(ctx).WithField("throttleNSEClient", "Find") + doneCh := make(chan struct{}) + defer close(doneCh) + + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) + + if err == nil { + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-resp.Context().Done(): + return + } + }() + } + + return resp, err +} + +func (n *limitNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + cc, ok := clientconn.Load(ctx) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + } + + closer, ok := cc.(interface{ Close() error }) + if !ok { + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + } + + doneCh := make(chan struct{}) + defer close(doneCh) + + logger := log.FromContext(ctx).WithField("throttleNSEClient", "Unregister") + + go func() { + select { + case <-time.After(n.cfg.dialLimit): + logger.Warn("Reached dial limit, closing connection...") + _ = closer.Close() + case <-doneCh: + return + } + }() + + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +// NewNetworkServiceEndpointRegistryClient - returns a new null client that does nothing but call next.NetworkServiceEndpointRegistryClient(ctx). +func NewNetworkServiceEndpointRegistryClient(opts ...Option) registry.NetworkServiceEndpointRegistryClient { + cfg := &limitConfig{ + dialLimit: time.Minute, + } + + for _, opt := range opts { + opt(cfg) + } + + return &limitNSEClient{ + cfg: cfg, + } +} diff --git a/pkg/registry/common/limit/nse_client_test.go b/pkg/registry/common/limit/nse_client_test.go new file mode 100644 index 000000000..6939ac26f --- /dev/null +++ b/pkg/registry/common/limit/nse_client_test.go @@ -0,0 +1,112 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package limit_test + +import ( + "context" + "testing" + "time" + + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/common/clientconn" + "github.com/networkservicemesh/sdk/pkg/registry/common/limit" + "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" + + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/utils/metadata" +) + +func Test_NSEDialLimitShouldCalled_OnLimitReached(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceEndpointRegistryClient( + metadata.NewNetworkServiceEndpointClient(), + clientconn.NewNetworkServiceEndpointRegistryClient(), + checkcontext.NewNSEClient(t, func(t *testing.T, ctx context.Context) { + clientconn.Store(ctx, cc) + }), + limit.NewNetworkServiceEndpointRegistryClient(limit.WithDialLimit(time.Second/5)), + checkcontext.NewNSEClient(t, func(t *testing.T, ctx context.Context) { + time.Sleep(time.Second / 5) + }), + ) + + _, _ = myChain.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: t.Name()}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + cc.closed.Store(false) + + _, _ = myChain.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: t.Name()}}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + cc.closed.Store(false) + + _, _ = myChain.Unregister(context.Background(), ®istry.NetworkServiceEndpoint{Name: t.Name()}) + + require.Eventually(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +} + +func Test_NSEDialLimitShouldNotBeCalled_OnSuccess(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + var cc = new(myConnection) + var myChain = chain.NewNetworkServiceEndpointRegistryClient( + metadata.NewNetworkServiceEndpointClient(), + clientconn.NewNetworkServiceEndpointRegistryClient(), + checkcontext.NewNSEClient(t, func(t *testing.T, ctx context.Context) { + clientconn.Store(ctx, cc) + }), + limit.NewNetworkServiceEndpointRegistryClient(limit.WithDialLimit(time.Second/5)), + ) + + ctx, cancel := context.WithCancel(context.Background()) + _, _ = myChain.Register(ctx, ®istry.NetworkServiceEndpoint{Name: t.Name()}) + cancel() + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + ctx, cancel = context.WithCancel(context.Background()) + _, _ = myChain.Find(ctx, ®istry.NetworkServiceEndpointQuery{NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: t.Name()}}) + cancel() + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) + + ctx, cancel = context.WithCancel(context.Background()) + _, _ = myChain.Unregister(ctx, ®istry.NetworkServiceEndpoint{Name: t.Name()}) + cancel() + + require.Never(t, func() bool { + return cc.closed.Load() + }, time.Second/2, time.Millisecond*75) +}