Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BED-5198 - Merge main into stage/v6.3.1 #1035

Merged
merged 12 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 71 additions & 14 deletions cmd/api/src/analysis/ad/adcs_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,63 +447,75 @@ func TestTrustedForNTAuth(t *testing.T) {
func TestEnrollOnBehalfOf(t *testing.T) {
testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema())
testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
harness.EnrollOnBehalfOfHarnessOne.Setup(testContext)
harness.EnrollOnBehalfOfHarness1.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate)
v1Templates := make([]*graph.Node, 0)
v2Templates := make([]*graph.Node, 0)

for _, template := range certTemplates {
if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil {
continue
} else if version == 1 {
v1Templates = append(v1Templates, template)
} else if version >= 2 {
continue
v2Templates = append(v2Templates, template)
}
}

require.Nil(t, err)

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates)
results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1)
require.Nil(t, err)

require.Len(t, results, 3)

require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate11.ID,
ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate11.ID,
ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})

require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate13.ID,
ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate13.ID,
ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})

require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})

return nil
})

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1)
require.Nil(t, err)

require.Len(t, results, 0)

return nil
})
})

testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
harness.EnrollOnBehalfOfHarnessTwo.Setup(testContext)
harness.EnrollOnBehalfOfHarness2.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate)
v1Templates := make([]*graph.Node, 0)
v2Templates := make([]*graph.Node, 0)

for _, template := range certTemplates {
if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil {
continue
} else if version == 1 {
continue
v1Templates = append(v1Templates, template)
} else if version >= 2 {
v2Templates = append(v2Templates, template)
}
Expand All @@ -512,15 +524,60 @@ func TestEnrollOnBehalfOf(t *testing.T) {
require.Nil(t, err)

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates)
results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2)
require.Nil(t, err)

require.Len(t, results, 0)
return nil
})

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2)
require.Nil(t, err)

require.Len(t, results, 1)
require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate21.ID,
ToID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate23.ID,
FromID: harness.EnrollOnBehalfOfHarness2.CertTemplate21.ID,
ToID: harness.EnrollOnBehalfOfHarness2.CertTemplate23.ID,
Kind: ad.EnrollOnBehalfOf,
})
return nil
})
})

testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
harness.EnrollOnBehalfOfHarness3.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - EnrollOnBehalfOf 3")

_, enterpriseCertAuthorities, certTemplates, domains, cache, err := FetchADCSPrereqs(db)
require.Nil(t, err)

if err := ad2.PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates, cache, operation); err != nil {
t.Logf("failed post processing for %s: %v", ad.EnrollOnBehalfOf.String(), err)
}
err = operation.Done()
require.Nil(t, err)

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
if startNodes, err := ops.FetchStartNodes(tx.Relationships().Filterf(func() graph.Criteria {
return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf)
})); err != nil {
t.Fatalf("error fetching EnrollOnBehalfOf edges in integration test; %v", err)
} else if endNodes, err := ops.FetchStartNodes(tx.Relationships().Filterf(func() graph.Criteria {
return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf)
})); err != nil {
t.Fatalf("error fetching EnrollOnBehalfOf edges in integration test; %v", err)
} else {
require.Len(t, startNodes, 2)
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate11))
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))

require.Len(t, endNodes, 2)
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))
}

return nil
})
Expand Down
3 changes: 3 additions & 0 deletions cmd/api/src/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ const (
ErrorResponsePayloadUnmarshalError = "error unmarshalling JSON payload"
ErrorResponseRequestTimeout = "request timed out"
ErrorResponseUserSelfDisable = "user attempted to disable themselves"
ErrorResponseUserSelfRoleChange = "user attempted to change own role"
ErrorResponseUserSelfSSOProviderChange = "user attempted to change own SSO Provider"
ErrorResponseAGTagWhiteSpace = "asset group tags must not contain whitespace"
ErrorResponseAGNameTagEmpty = "asset group name or tag must not be empty"
ErrorResponseAGDuplicateName = "asset group name must be unique"
ErrorResponseAGDuplicateTag = "asset group tag must be unique"
ErrorResponseUserDuplicatePrincipal = "principal name must be unique"
ErrorResponseDetailsUniqueViolation = "unique constraint was violated"
ErrorResponseDetailsNotImplemented = "All good things to those who wait. Not implemented."

Expand Down
45 changes: 25 additions & 20 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,24 +356,18 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht
}

if newUser, err := s.db.CreateUser(request.Context(), userTemplate); err != nil {
api.HandleDatabaseError(request, response, err)
if errors.Is(err, database.ErrDuplicateUserPrincipal) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, api.ErrorResponseUserDuplicatePrincipal, request), response)
} else {
api.HandleDatabaseError(request, response, err)
}
} else {
api.WriteBasicResponse(request.Context(), newUser, http.StatusOK, response)
}

}
}

func (s ManagementResource) ensureUserHasNoAuthSecret(ctx context.Context, user model.User) error {
if user.AuthSecret != nil {
if err := s.db.DeleteAuthSecret(ctx, *user.AuthSecret); err != nil {
return api.FormatDatabaseError(err)
}
}

return nil
}

func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *http.Request) {
var (
updateUserRequest v2.UpdateUserRequest
Expand All @@ -400,8 +394,10 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.PrincipalName = updateUserRequest.Principal
user.IsDisabled = updateUserRequest.IsDisabled

loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx)

if user.IsDisabled {
if loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx); user.ID == loggedInUser.ID {
if user.ID == loggedInUser.ID {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfDisable, request), response)
return
} else if userSessions, err := s.db.LookupActiveSessionsByUser(request.Context(), user); err != nil {
Expand All @@ -419,9 +415,6 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
if samlProviderID, err := serde.ParseInt32(updateUserRequest.SAMLProviderID); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("SAML Provider ID must be a number: %v", err.Error()), request), response)
return
} else if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if provider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil {
api.HandleDatabaseError(request, response, err)
return
Expand All @@ -431,10 +424,7 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = provider.SSOProviderID
}
} else if updateUserRequest.SSOProviderID.Valid {
if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
Expand All @@ -447,8 +437,23 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = null.NewInt32(0, false)
}

// Prevent a user from modifying their own roles/permissions
if user.ID == loggedInUser.ID {
if !slices.Equal(roles.IDs(), loggedInUser.Roles.IDs()) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfRoleChange, request), response)
return
} else if !user.SSOProviderID.Equal(loggedInUser.SSOProviderID) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfSSOProviderChange, request), response)
return
}
}

if err := s.db.UpdateUser(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
if errors.Is(err, database.ErrDuplicateUserPrincipal) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, api.ErrorResponseUserDuplicatePrincipal, request), response)
} else {
api.HandleDatabaseError(request, response, err)
}
} else {
response.WriteHeader(http.StatusOK)
}
Expand Down
66 changes: 65 additions & 1 deletion cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,17 @@ func TestManagementResource_PutUserAuthSecret(t *testing.T) {

func TestManagementResource_EnableUserSAML(t *testing.T) {
var (
adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}}
goodRoles = []int32{0}
goodUserID = must.NewUUIDv4()
badUserID = must.NewUUIDv4()
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
bhCtx.AuthCtx.Owner = adminUser

defer mockCtrl.Finish()

t.Run("Successfully update user with deprecated saml provider", func(t *testing.T) {
Expand All @@ -181,6 +185,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -197,9 +202,9 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)
mockDB.EXPECT().DeleteAuthSecret(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": badUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -218,6 +223,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand Down Expand Up @@ -1511,6 +1517,64 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) {
require.Contains(t, response.Body.String(), api.ErrorResponseUserSelfDisable)
}

func TestManagementResource_UpdateUser_UserSelfModify(t *testing.T) {
var (
adminRole = model.Role{
Serial: model.Serial{
ID: 1,
},
}
goodRoles = []int32{1}
badRole = model.Role{
Serial: model.Serial{
ID: 2,
},
}
badRoles = []int32{2}
adminUser = model.User{AuthSecret: defaultDigestAuthSecret(t, "currentPassword"), Unique: model.Unique{ID: must.NewUUIDv4()}, Roles: model.Roles{adminRole}}
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
bhCtx.AuthCtx.Owner = adminUser

defer mockCtrl.Finish()

t.Run("Prevent users from changing their own SSO provider", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{adminRole}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil)
test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SSOProviderID: null.Int32From(123),
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("Prevent users from changing their own roles", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{badRole}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: badRoles,
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusBadRequest)
})
}

func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
Expand Down
Loading
Loading