Skip to content

Commit

Permalink
chore: check for closed connections before sending [DET-4829] (#1949)
Browse files Browse the repository at this point in the history
  • Loading branch information
mackrorysd authored Feb 11, 2021
1 parent b62a0b1 commit 2cff862
Show file tree
Hide file tree
Showing 15 changed files with 68 additions and 49 deletions.
2 changes: 1 addition & 1 deletion CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# API
/proto/ @hamidzr
/master/internal/api @hamidzr
/master/internal/grpc @hamidzr
/master/internal/grpcutil @hamidzr
/master/internal/api*.go @hamidzr

# Kubernetes
Expand Down
10 changes: 7 additions & 3 deletions master/internal/api/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ func NewLogStoreProcessor(
}
}

// connectionIsClosed reports whether ctx is finished, i.e. whether ctx.Err()
// returns a non-nil error. The log processors in this file use it as the
// signal to stop themselves instead of sending on a context that is done.
func connectionIsClosed(ctx context.Context) bool {
	err := ctx.Err()
	return err != nil
}

// Receive implements the actor.Actor interface.
func (l *LogStoreProcessor) Receive(ctx *actor.Context) error {
type tick struct{}
Expand All @@ -100,7 +104,7 @@ func (l *LogStoreProcessor) Receive(ctx *actor.Context) error {
ctx.Tell(ctx.Self(), tick{})

case tick:
if l.ctx.Err() != nil {
if connectionIsClosed(l.ctx) {
ctx.Self().Stop()
return nil
}
Expand Down Expand Up @@ -131,7 +135,7 @@ func (l *LogStoreProcessor) Receive(ctx *actor.Context) error {
default:
// Check the ctx again before we process, since fetch takes most of the time and
// a send on a closed ctx will print errors in the master log that can be misleading.
if l.ctx.Err() != nil {
if connectionIsClosed(l.ctx) {
ctx.Self().Stop()
return nil
}
Expand Down Expand Up @@ -194,7 +198,7 @@ func (l *LogStreamProcessor) Receive(ctx *actor.Context) error {
}

case LogBatch:
if l.ctx.Err() != nil {
if connectionIsClosed(l.ctx) {
ctx.Self().Stop()
break
}
Expand Down
12 changes: 6 additions & 6 deletions master/internal/api_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

Expand All @@ -30,17 +30,17 @@ func (a *apiServer) Login(
switch err {
case nil:
case db.ErrNotFound:
return nil, grpc.ErrInvalidCredentials
return nil, grpcutil.ErrInvalidCredentials
default:
return nil, err
}

if !user.ValidatePassword(replicateClientSideSaltAndHash(req.Password)) {
return nil, grpc.ErrInvalidCredentials
return nil, grpcutil.ErrInvalidCredentials
}

if !user.Active {
return nil, grpc.ErrPermissionDenied
return nil, grpcutil.ErrPermissionDenied
}

token, err := a.m.db.StartUserSession(user)
Expand All @@ -53,7 +53,7 @@ func (a *apiServer) Login(

func (a *apiServer) CurrentUser(
ctx context.Context, _ *apiv1.CurrentUserRequest) (*apiv1.CurrentUserResponse, error) {
user, _, err := grpc.GetUser(ctx, a.m.db)
user, _, err := grpcutil.GetUser(ctx, a.m.db)
if err != nil {
return nil, err
}
Expand All @@ -63,7 +63,7 @@ func (a *apiServer) CurrentUser(

func (a *apiServer) Logout(
ctx context.Context, _ *apiv1.LogoutRequest) (*apiv1.LogoutResponse, error) {
_, userSession, err := grpc.GetUser(ctx, a.m.db)
_, userSession, err := grpcutil.GetUser(ctx, a.m.db)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions master/internal/api_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

"github.com/determined-ai/determined/master/internal/api"
"github.com/determined-ai/determined/master/internal/command"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/master/pkg/actor"
"github.com/determined-ai/determined/master/pkg/model"
Expand All @@ -35,7 +35,7 @@ type protoCommandParams struct {
func (a *apiServer) prepareLaunchParams(ctx context.Context, req *protoCommandParams) (
*command.CommandParams, *model.User, error,
) {
user, _, err := grpc.GetUser(ctx, a.m.db)
user, _, err := grpcutil.GetUser(ctx, a.m.db)
if err != nil {
return nil, nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}
Expand Down
34 changes: 22 additions & 12 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/pkg/errors"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/hpimportance"
"github.com/determined-ai/determined/master/internal/lttb"
"github.com/determined-ai/determined/master/pkg/actor"
Expand Down Expand Up @@ -514,7 +514,7 @@ func (a *apiServer) CreateExperiment(
return &apiv1.CreateExperimentResponse{}, nil
}

user, _, err := grpc.GetUser(ctx, a.m.db)
user, _, err := grpcutil.GetUser(ctx, a.m.db)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}
Expand Down Expand Up @@ -585,6 +585,9 @@ func (a *apiServer) MetricNames(req *apiv1.MetricNamesRequest,
}
}

if grpcutil.ConnectionIsClosed(resp) {
return nil
}
if err = resp.Send(&response); err != nil {
return err
}
Expand All @@ -598,8 +601,7 @@ func (a *apiServer) MetricNames(req *apiv1.MetricNamesRequest,
}

time.Sleep(period)
if err := resp.Context().Err(); err != nil {
// connection is closed
if grpcutil.ConnectionIsClosed(resp) {
return nil
}
}
Expand Down Expand Up @@ -654,6 +656,9 @@ func (a *apiServer) MetricBatches(req *apiv1.MetricBatchesRequest,
}
}

if grpcutil.ConnectionIsClosed(resp) {
return nil
}
if err = resp.Send(&response); err != nil {
return errors.Wrapf(err, "error sending batches recorded for metric")
}
Expand All @@ -667,8 +672,7 @@ func (a *apiServer) MetricBatches(req *apiv1.MetricBatchesRequest,
}

time.Sleep(period)
if err := resp.Context().Err(); err != nil {
// connection is closed
if grpcutil.ConnectionIsClosed(resp) {
return nil
}
}
Expand Down Expand Up @@ -720,6 +724,9 @@ func (a *apiServer) TrialsSnapshot(req *apiv1.TrialsSnapshotRequest,

response.Trials = newTrials

if grpcutil.ConnectionIsClosed(resp) {
return nil
}
if err = resp.Send(&response); err != nil {
return errors.Wrapf(err, "error sending batches recorded for metrics")
}
Expand All @@ -733,8 +740,7 @@ func (a *apiServer) TrialsSnapshot(req *apiv1.TrialsSnapshotRequest,
}

time.Sleep(period)
if err := resp.Context().Err(); err != nil {
// connection is closed
if grpcutil.ConnectionIsClosed(resp) {
return nil
}
}
Expand Down Expand Up @@ -919,6 +925,9 @@ func (a *apiServer) TrialsSample(req *apiv1.TrialsSampleRequest,
response.PromotedTrials = promotedTrials
response.DemotedTrials = demotedTrials

if grpcutil.ConnectionIsClosed(resp) {
return nil
}
if err = resp.Send(&response); err != nil {
return errors.Wrap(err, "error sending sample of trial metric streams")
}
Expand All @@ -932,8 +941,7 @@ func (a *apiServer) TrialsSample(req *apiv1.TrialsSampleRequest,
}

time.Sleep(period)
if err := resp.Context().Err(); err != nil {
// connection is closed
if grpcutil.ConnectionIsClosed(resp) {
return nil
}
}
Expand Down Expand Up @@ -1011,6 +1019,9 @@ func (a *apiServer) GetHPImportance(req *apiv1.GetHPImportanceRequest,
response.ValidationMetrics[metric] = protoMetricHPI(metricHpi)
}

if grpcutil.ConnectionIsClosed(resp) {
return nil
}
if err := resp.Send(&response); err != nil {
return errors.Wrap(err, "error sending hyperparameter importance response")
}
Expand Down Expand Up @@ -1040,8 +1051,7 @@ func (a *apiServer) GetHPImportance(req *apiv1.GetHPImportanceRequest,
}

time.Sleep(period)
if err := resp.Context().Err(); err != nil {
// connection is closed
if grpcutil.ConnectionIsClosed(resp) {
return nil
}
}
Expand Down
8 changes: 4 additions & 4 deletions master/internal/api_master.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/determined-ai/determined/proto/pkg/logv1"

"github.com/determined-ai/determined/master/internal/api"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

Expand Down Expand Up @@ -55,9 +55,9 @@ func (a *apiServer) GetMasterConfig(

func (a *apiServer) MasterLogs(
req *apiv1.MasterLogsRequest, resp apiv1.Determined_MasterLogsServer) error {
if err := grpc.ValidateRequest(
grpc.ValidateLimit(req.Limit),
grpc.ValidateFollow(req.Limit, req.Follow),
if err := grpcutil.ValidateRequest(
grpcutil.ValidateLimit(req.Limit),
grpcutil.ValidateFollow(req.Limit, req.Follow),
); err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions master/internal/api_notebook.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

"github.com/determined-ai/determined/master/internal/api"
"github.com/determined-ai/determined/master/internal/command"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/master/pkg/actor"
"github.com/determined-ai/determined/master/pkg/logger"
Expand Down Expand Up @@ -42,8 +42,8 @@ func (a *apiServer) KillNotebook(

func (a *apiServer) NotebookLogs(
req *apiv1.NotebookLogsRequest, resp apiv1.Determined_NotebookLogsServer) error {
if err := grpc.ValidateRequest(
grpc.ValidateLimit(req.Limit),
if err := grpcutil.ValidateRequest(
grpcutil.ValidateLimit(req.Limit),
); err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

"github.com/determined-ai/determined/master/internal/api"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/pkg/actor"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/proto/pkg/apiv1"
Expand Down Expand Up @@ -50,9 +50,9 @@ type TrialLogBackend interface {

func (a *apiServer) TrialLogs(
req *apiv1.TrialLogsRequest, resp apiv1.Determined_TrialLogsServer) error {
if err := grpc.ValidateRequest(
grpc.ValidateLimit(req.Limit),
grpc.ValidateFollow(req.Limit, req.Follow),
if err := grpcutil.ValidateRequest(
grpcutil.ValidateLimit(req.Limit),
grpcutil.ValidateFollow(req.Limit, req.Follow),
); err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions master/internal/api_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"google.golang.org/grpc/status"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/userv1"
Expand Down Expand Up @@ -80,14 +80,14 @@ func (a *apiServer) GetUser(

func (a *apiServer) PostUser(
ctx context.Context, req *apiv1.PostUserRequest) (*apiv1.PostUserResponse, error) {
curUser, _, err := grpc.GetUser(ctx, a.m.db)
curUser, _, err := grpcutil.GetUser(ctx, a.m.db)
if err != nil {
return nil, err
}
if !curUser.Admin {
return nil, grpc.ErrPermissionDenied
return nil, grpcutil.ErrPermissionDenied
}
if err = grpc.ValidateRequest(
if err = grpcutil.ValidateRequest(
func() (bool, string) { return req.User != nil, "no user specified" },
func() (bool, string) { return req.User.Username != "", "no username specified" },
); err != nil {
Expand Down Expand Up @@ -121,12 +121,12 @@ func (a *apiServer) PostUser(

func (a *apiServer) SetUserPassword(
ctx context.Context, req *apiv1.SetUserPasswordRequest) (*apiv1.SetUserPasswordResponse, error) {
curUser, _, err := grpc.GetUser(ctx, a.m.db)
curUser, _, err := grpcutil.GetUser(ctx, a.m.db)
if err != nil {
return nil, err
}
if !curUser.Admin && curUser.Username != req.Username {
return nil, grpc.ErrPermissionDenied
return nil, grpcutil.ErrPermissionDenied
}
user := &model.User{Username: req.Username}
if err = user.UpdatePasswordHash(replicateClientSideSaltAndHash(req.Password)); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
"github.com/determined-ai/determined/master/internal/command"
detContext "github.com/determined-ai/determined/master/internal/context"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/hpimportance"
"github.com/determined-ai/determined/master/internal/proxy"
"github.com/determined-ai/determined/master/internal/resourcemanagers"
Expand Down Expand Up @@ -163,7 +163,7 @@ func (m *Master) startServers(ctx context.Context, cert *tls.Certificate) error
}

// Initialize listeners and multiplexing.
err = grpc.RegisterHTTPProxy(ctx, m.echo, m.config.Port, cert)
err = grpcutil.RegisterHTTPProxy(ctx, m.echo, m.config.Port, cert)
if err != nil {
return errors.Wrap(err, "failed to register gRPC gateway")
}
Expand All @@ -188,7 +188,7 @@ func (m *Master) startServers(ctx context.Context, cert *tls.Certificate) error
}()
}
start("gRPC server", func() error {
srv := grpc.NewGRPCServer(m.db, &apiServer{m: m})
srv := grpcutil.NewGRPCServer(m.db, &apiServer{m: m})
// We should defer srv.Stop() here, but cmux does not unblock accept calls when underlying
// listeners close, and grpc-go depends on cmux unblocking and closing, so Stop() blocks
// indefinitely when using cmux.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package grpcutil

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package grpcutil

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package grpcutil

import (
"context"
Expand All @@ -7,6 +7,7 @@ import (

"github.com/grpc-ecosystem/grpc-gateway/runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -55,3 +56,7 @@ func errorHandler(
}
}
}

// ConnectionIsClosed reports whether the context attached to the given server
// stream is finished, i.e. whether its Err method returns a non-nil error.
// Streaming handlers call it before sending so they can return quietly rather
// than attempt a send on a stream whose context is already done.
func ConnectionIsClosed(stream grpc.ServerStream) bool {
	streamCtx := stream.Context()
	return streamCtx.Err() != nil
}
Loading

0 comments on commit 2cff862

Please sign in to comment.