Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a per test PAM client in pam services. #34

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 103 additions & 91 deletions internal/services/pam/pam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pam_test

import (
"context"
"errors"
"flag"
"fmt"
"net"
Expand All @@ -21,8 +22,8 @@ import (
)

var (
testClient authd.PAMClient
brokerGeneratedID string
brokerManager *brokers.Manager
mockBrokerGeneratedID string
)

// Used for TestGetAuthenticationModes and TestSelectAuthenticationMode.
Expand Down Expand Up @@ -52,7 +53,9 @@ var (
func TestAvailableBrokers(t *testing.T) {
t.Parallel()

abResp, err := testClient.AvailableBrokers(context.Background(), &authd.Empty{})
client := newPamClient(t)

abResp, err := client.AvailableBrokers(context.Background(), &authd.Empty{})
require.NoError(t, err, "AvailableBrokers should not return an error, but did")

got := abResp.GetBrokersInfos()
Expand All @@ -68,18 +71,20 @@ func TestGetPreviousBroker(t *testing.T) {

username := t.Name()

client := newPamClient(t)

// Try to get the broker for the user before assigning it.
gotResp, _ := testClient.GetPreviousBroker(context.Background(), &authd.GPBRequest{Username: username})
gotResp, _ := client.GetPreviousBroker(context.Background(), &authd.GPBRequest{Username: username})
require.Empty(t, gotResp.GetPreviousBroker(), "GetPreviousBroker should return nil when the user has no broker assigned")

_, err := testClient.SetDefaultBrokerForUser(context.Background(), &authd.SDBFURequest{
_, err := client.SetDefaultBrokerForUser(context.Background(), &authd.SDBFURequest{
BrokerId: "local",
Username: username,
})
require.NoError(t, err, "Setup: could not set default broker for user for tests")

// Assert that the broker assigned to the user is correct.
gotResp, _ = testClient.GetPreviousBroker(context.Background(), &authd.GPBRequest{Username: username})
gotResp, _ = client.GetPreviousBroker(context.Background(), &authd.GPBRequest{Username: username})
require.Equal(t, "local", gotResp.GetPreviousBroker(), "GetPreviousBroker did not return the correct broker")
}

Expand Down Expand Up @@ -107,8 +112,10 @@ func TestSelectBroker(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

client := newPamClient(t)

if tc.brokerID == "" {
tc.brokerID = brokerGeneratedID
tc.brokerID = mockBrokerGeneratedID
} else if tc.brokerID == "-" {
tc.brokerID = ""
}
Expand All @@ -121,7 +128,7 @@ func TestSelectBroker(t *testing.T) {
BrokerId: tc.brokerID,
Username: tc.username,
}
sbResp, err := testClient.SelectBroker(context.Background(), sbRequest)
sbResp, err := client.SelectBroker(context.Background(), sbRequest)
if tc.wantErr {
require.Error(t, err, "SelectBroker should return an error, but did not")
return
Expand Down Expand Up @@ -166,8 +173,10 @@ func TestGetAuthenticationModes(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

client := newPamClient(t)

if !tc.noSession {
id := startSession(t, tc.username)
id := startSession(t, client, tc.username)
if tc.sessionID == "" {
tc.sessionID = id
}
Expand All @@ -184,7 +193,7 @@ func TestGetAuthenticationModes(t *testing.T) {
SessionId: tc.sessionID,
SupportedUiLayouts: tc.supportedUILayouts,
}
gamResp, err := testClient.GetAuthenticationModes(context.Background(), gamReq)
gamResp, err := client.GetAuthenticationModes(context.Background(), gamReq)
if tc.wantErr {
require.Error(t, err, "GetAuthenticationModes should return an error, but did not")
return
Expand Down Expand Up @@ -237,8 +246,10 @@ func TestSelectAuthenticationMode(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

client := newPamClient(t)

if !tc.noSession {
id := startSession(t, tc.username)
id := startSession(t, client, tc.username)
if tc.sessionID == "" {
tc.sessionID = id
}
Expand All @@ -260,15 +271,15 @@ func TestSelectAuthenticationMode(t *testing.T) {
SessionId: tc.sessionID,
SupportedUiLayouts: tc.supportedUILayouts,
}
_, err := testClient.GetAuthenticationModes(context.Background(), gamReq)
_, err := client.GetAuthenticationModes(context.Background(), gamReq)
require.NoError(t, err, "Setup: failed to get authentication modes for tests")
}

samReq := &authd.SAMRequest{
SessionId: tc.sessionID,
AuthenticationModeId: tc.authMode,
}
samResp, err := testClient.SelectAuthenticationMode(context.Background(), samReq)
samResp, err := client.SelectAuthenticationMode(context.Background(), samReq)
if tc.wantErr {
require.Error(t, err, "SelectAuthenticationMode should return an error, but did not")
return
Expand Down Expand Up @@ -316,8 +327,10 @@ func TestIsAuthenticated(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

client := newPamClient(t)

if !tc.noSession {
id := startSession(t, tc.username)
id := startSession(t, client, tc.username)
if tc.sessionID == "" {
tc.sessionID = id
}
Expand All @@ -337,7 +350,7 @@ func TestIsAuthenticated(t *testing.T) {
SessionId: tc.sessionID,
AuthenticationData: "some data",
}
iaResp, err := testClient.IsAuthenticated(ctx, iaReq)
iaResp, err := client.IsAuthenticated(ctx, iaReq)
firstCall = fmt.Sprintf("FIRST CALL:\n\taccess: %s\n\tdata: %s\n\terr: %v\n",
iaResp.GetAccess(),
iaResp.GetData(),
Expand All @@ -357,7 +370,7 @@ func TestIsAuthenticated(t *testing.T) {
SessionId: tc.sessionID,
AuthenticationData: "some data",
}
iaResp, err := testClient.IsAuthenticated(context.Background(), iaReq)
iaResp, err := client.IsAuthenticated(context.Background(), iaReq)
secondCall = fmt.Sprintf("SECOND CALL:\n\taccess: %s\n\tdata: %s\n\terr: %v\n",
iaResp.GetAccess(),
iaResp.GetData(),
Expand Down Expand Up @@ -396,7 +409,9 @@ func TestSetDefaultBrokerForUser(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

wantID := brokerGeneratedID
client := newPamClient(t)

wantID := mockBrokerGeneratedID
if tc.noBroker {
wantID = "does not exist"
}
Expand All @@ -408,14 +423,14 @@ func TestSetDefaultBrokerForUser(t *testing.T) {
BrokerId: wantID,
Username: tc.username,
}
_, err := testClient.SetDefaultBrokerForUser(context.Background(), sdbfuReq)
_, err := client.SetDefaultBrokerForUser(context.Background(), sdbfuReq)
if tc.wantErr {
require.Error(t, err, "SetDefaultBrokerForUser should return an error, but did not")
return
}
require.NoError(t, err, "SetDefaultBrokerForUser should not return an error, but did")

gotResp, _ := testClient.GetPreviousBroker(context.Background(), &authd.GPBRequest{Username: tc.username})
gotResp, _ := client.GetPreviousBroker(context.Background(), &authd.GPBRequest{Username: tc.username})
require.Equal(t, wantID, gotResp.GetPreviousBroker(), "SetDefaultBrokerForUser did not set the correct broker for the user")
})
}
Expand Down Expand Up @@ -447,8 +462,10 @@ func TestEndSession(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

client := newPamClient(t)

if !tc.noSession {
id := startSession(t, tc.username)
id := startSession(t, client, tc.username)
if tc.sessionID == "" {
tc.sessionID = id
}
Expand All @@ -460,7 +477,7 @@ func TestEndSession(t *testing.T) {
esReq := &authd.ESRequest{
SessionId: tc.sessionID,
}
_, err := testClient.EndSession(context.Background(), esReq)
_, err := client.EndSession(context.Background(), esReq)
if tc.wantErr {
require.Error(t, err, "EndSession should return an error, but did not")
return
Expand All @@ -470,70 +487,88 @@ func TestEndSession(t *testing.T) {
}
}

func startClient() (client authd.PAMClient, cleanup func(), err error) {
// initBrokers starts dbus mock brokers on the system bus. It returns its config path.
func initBrokers() (brokerConfigPath string, cleanup func(), err error) {
tmpDir, err := os.MkdirTemp("", "authd-internal-pam-tests-")
if err != nil {
return nil, nil, err
return "", nil, err
}

brokerDir := filepath.Join(tmpDir, "etc", "authd", "broker.d")
if err = os.MkdirAll(brokerDir, 0750); err != nil {
_ = os.RemoveAll(tmpDir)
return nil, nil, err
return "", nil, err
}
_, brokerCleanup, err := testutils.StartBusBrokerMock(brokerDir, "BrokerMock")
if err != nil {
_ = os.RemoveAll(tmpDir)
return nil, nil, err
return "", nil, err
}

socketPath := filepath.Join(tmpDir, "authd.sock")
lis, err := net.Listen("unix", socketPath)
if err != nil {
_ = os.RemoveAll(tmpDir)
return tmpDir, func() {
brokerCleanup()
return nil, nil, err
}
// We want everyone to be able to write to our socket and we will filter permissions
// #nosec G302
if err = os.Chmod(socketPath, 0666); err != nil {
_ = os.RemoveAll(tmpDir)
brokerCleanup()
return nil, nil, err
}
}, nil
}

brokerManager, err := brokers.NewManager(context.Background(), nil, brokers.WithRootDir(tmpDir))
if err != nil {
_ = os.RemoveAll(tmpDir)
brokerCleanup()
return nil, nil, err
}
// newPAMClient returns a new GRPC PAM client for tests connected to the global brokerManager.
func newPamClient(t *testing.T) (client authd.PAMClient) {
t.Helper()

// socket path is limited in length.
tmpDir, err := os.MkdirTemp("", "authd-socket-dir")
require.NoError(t, err, "Setup: could not setup temporary socket dir path")
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
socketPath := filepath.Join(tmpDir, "authd.sock")

lis, err := net.Listen("unix", socketPath)
require.NoError(t, err, "Setup: could not create unix socket")

grpcServer := grpc.NewServer()
service := pam.NewService(context.Background(), brokerManager)

grpcServer := grpc.NewServer()
authd.RegisterPAMServer(grpcServer, service)
done := make(chan struct{})
go func() {
defer close(done)
_ = grpcServer.Serve(lis)
}()

conn, err := grpc.Dial("unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Cleanup(func() {
grpcServer.Stop()
<-done
brokerCleanup()
_ = os.RemoveAll(tmpDir)
return nil, nil, err
})

conn, err := grpc.Dial("unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err, "Setup: Could not connect to GRPC server")
t.Cleanup(func() { _ = conn.Close() }) // We don't care about the error on cleanup

return authd.NewPAMClient(conn)
}

// getMockBrokerGeneratedID returns the generated ID for the mock broker.
func getMockBrokerGeneratedID(brokerManager *brokers.Manager) (string, error) {
for _, b := range brokerManager.AvailableBrokers() {
if b.Name != "BrokerMock" {
continue
}
return b.ID, nil
}
return "", errors.New("Setup: could not find generated broker mock ID in the broker manager list")
}

return authd.NewPAMClient(conn), func() {
conn.Close()
grpcServer.Stop()
<-done
brokerCleanup()
_ = os.RemoveAll(tmpDir)
}, nil
// startSession is a helper that starts a session on the mock broker.
func startSession(t *testing.T, client authd.PAMClient, username string) string {
t.Helper()

// Prefixes the username to avoid concurrency issues.
username = t.Name() + testutils.IDSeparator + username

sbResp, err := client.SelectBroker(context.Background(), &authd.SBRequest{
BrokerId: mockBrokerGeneratedID,
Username: username,
})
require.NoError(t, err, "Setup: failed to create session for tests")
return sbResp.GetSessionId()
}

func TestMain(m *testing.M) {
Expand All @@ -548,48 +583,25 @@ func TestMain(m *testing.M) {
}
defer busCleanup()

var cleanup func()
testClient, cleanup, err = startClient()
// Start brokers mock over dbus.
brokersConfigPath, cleanup, err := initBrokers()
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
busCleanup()
os.Exit(1)
}
defer cleanup()

brokerGeneratedID = getBrokerGeneratedID(testClient, "BrokerMock")
if brokerGeneratedID == "" {
fmt.Fprintf(os.Stderr, "could not get generated ID for BrokerMock\n")
cleanup()
busCleanup()
// Get manager shared across grpc services.
brokerManager, err = brokers.NewManager(context.Background(), nil, brokers.WithRootDir(brokersConfigPath))
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
m.Run()
}

// getBrokerGeneratedID returns the generated ID for the specified broker.
func getBrokerGeneratedID(client authd.PAMClient, brokerName string) string {
r, _ := client.AvailableBrokers(context.Background(), &authd.Empty{})
for _, b := range r.GetBrokersInfos() {
if b.GetName() != brokerName {
continue
}
return b.Id
mockBrokerGeneratedID, err = getMockBrokerGeneratedID(brokerManager)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
return ""
}

// startSession is a helper that starts a session on the specified broker.
func startSession(t *testing.T, username string) string {
t.Helper()

// Prefixes the username to avoid concurrency issues.
username = t.Name() + testutils.IDSeparator + username

sbResp, err := testClient.SelectBroker(context.Background(), &authd.SBRequest{
BrokerId: brokerGeneratedID,
Username: username,
})
require.NoError(t, err, "Setup: failed to create session for tests")
return sbResp.GetSessionId()
m.Run()
}