diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 2c77ee51a0e3a..f038cf502625c 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7451,6 +7451,10 @@ type authProviderMock struct { server types.ServerV2 } +func (mock authProviderMock) ListUnifiedResources(ctx context.Context, req *authproto.ListUnifiedResourcesRequest) (*authproto.ListUnifiedResourcesResponse, error) { + return nil, nil +} + func (mock authProviderMock) GetNodes(ctx context.Context, n string) ([]types.Server, error) { return []types.Server{&mock.server}, nil } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 2117a79a639ac..1f24dcd4a0835 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -100,6 +100,7 @@ type AuthProvider interface { IsMFARequired(ctx context.Context, req *authproto.IsMFARequiredRequest) (*authproto.IsMFARequiredResponse, error) GenerateUserSingleUseCerts(ctx context.Context) (authproto.AuthService_GenerateUserSingleUseCertsClient, error) MaintainSessionPresence(ctx context.Context) (authproto.AuthService_MaintainSessionPresenceClient, error) + ListUnifiedResources(ctx context.Context, req *authproto.ListUnifiedResourcesRequest) (*authproto.ListUnifiedResourcesResponse, error) } // NewTerminal creates a web-based terminal based on WebSockets and returns a @@ -885,6 +886,21 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws terminal.WSConn, // The close error is ignored instead of using [trace.NewAggregate] because // aggregate errors do not allow error inspection with things like [trace.IsAccessDenied]. _ = conn.Close() + + // Since connection attempts are made via UUID and not hostname, any access denied errors + // will not contain the resolved host address. To provide an easier troubleshooting experience + // for users, attempt to resolve the hostname of the server and augment the error message with it. + if trace.IsAccessDenied(err) { + if resp, err := t.userAuthClient.ListUnifiedResources(ctx, &authproto.ListUnifiedResourcesRequest{ + SortBy: types.SortBy{Field: types.ResourceKind}, + Kinds: []string{types.KindNode}, + Limit: 1, + PredicateExpression: fmt.Sprintf(`resource.metadata.name == "%s"`, t.sessionData.ServerID), + }); err == nil && len(resp.Resources) > 0 { + return nil, trace.AccessDenied("access denied to %q connecting to %v", sshConfig.User, resp.Resources[0].GetNode().GetHostname()) + } + } + return nil, trace.Wrap(err) }