From ded19e17bdab0a21e9f4bfcf64ef8c5e7e684c1b Mon Sep 17 00:00:00 2001 From: Pratik Jagrut Date: Tue, 3 Oct 2023 17:32:04 +0530 Subject: [PATCH] test: add unit tests for tarianctl add command --- .github/workflows/ci.yaml | 40 ++++ .gitignore | 1 + Makefile | 5 +- cmd/tarianctl/cmd/add/action_test.go | 232 +++++++++++++++++++ cmd/tarianctl/cmd/add/actions.go | 40 +++- cmd/tarianctl/cmd/add/add.go | 4 +- cmd/tarianctl/cmd/add/add_test.go | 64 ++++++ cmd/tarianctl/cmd/add/constraints.go | 80 +++---- cmd/tarianctl/cmd/add/constraints_test.go | 241 ++++++++++++++++++++ cmd/tarianctl/cmd/add/util.go | 33 +++ cmd/tarianctl/cmd/add/util_test.go | 50 ++++ cmd/tarianctl/cmd/flags/flag.go | 8 +- cmd/tarianctl/cmd/flags/flags_test.go | 107 +++++++++ cmd/tarianctl/cmd/get/actions.go | 13 +- cmd/tarianctl/cmd/get/constraints.go | 34 +-- cmd/tarianctl/cmd/get/events.go | 14 +- cmd/tarianctl/cmd/get/util.go | 28 +++ cmd/tarianctl/cmd/import/import.go | 14 +- cmd/tarianctl/cmd/remove/actions.go | 14 +- cmd/tarianctl/cmd/remove/constraints.go | 13 +- cmd/tarianctl/cmd/root_test.go | 36 +++ cmd/tarianctl/cmd/version.go | 6 +- cmd/tarianctl/util/grpc/fake_grpc_client.go | 21 ++ cmd/tarianctl/util/grpc/grpc_clients.go | 27 +++ cmd/tarianctl/util/grpc/interface.go | 13 ++ dev/config/monitored-pod/configmap.yaml | 3 +- pkg/tarianctl/client/client.go | 20 -- pkg/tarianctl/client/doc.go | 2 - pkg/tarianpb/fake_api_grpc.pb.go | 68 ++++++ 29 files changed, 1110 insertions(+), 121 deletions(-) create mode 100644 cmd/tarianctl/cmd/add/action_test.go create mode 100644 cmd/tarianctl/cmd/add/add_test.go create mode 100644 cmd/tarianctl/cmd/add/constraints_test.go create mode 100644 cmd/tarianctl/cmd/add/util.go create mode 100644 cmd/tarianctl/cmd/add/util_test.go create mode 100644 cmd/tarianctl/cmd/flags/flags_test.go create mode 100644 cmd/tarianctl/cmd/get/util.go create mode 100644 cmd/tarianctl/cmd/root_test.go create mode 100644 cmd/tarianctl/util/grpc/fake_grpc_client.go create mode 100644 cmd/tarianctl/util/grpc/grpc_clients.go create mode 100644 cmd/tarianctl/util/grpc/interface.go delete mode 100644 pkg/tarianctl/client/client.go delete mode 100644 pkg/tarianctl/client/doc.go create mode 100644 pkg/tarianpb/fake_api_grpc.pb.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 44b82718..94007034 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -39,6 +39,46 @@ jobs: with: args: -v --config=.golangci.yml + unit-test: + runs-on: ubuntu-latest + env: + GOPATH: ${{ github.workspace }}/../go + HOME: ${{ github.workspace }}/.. + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version: "stable" + - name: Install dependencies + run: | + set -x + # Install required dependencies and tools + sudo apt update && sudo apt install -y pkg-config libelf-dev clang + go install github.com/mgechev/revive@latest + go install honnef.co/go/tools/cmd/staticcheck@latest + go install google.golang.org/protobuf/cmd/protoc-gen-go@32051b4f86e54c2142c7c05362c6e96ae3454a1c # @v1.28.0 + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@938f6e2f7550e542bd78f3b9e8812665db109e02 # @v1.1.0 + - name: Build + run: | + set -x + sudo apt update && sudo apt install -y jq pkg-config libelf-dev clang + go install google.golang.org/protobuf/cmd/protoc-gen-go@32051b4f86e54c2142c7c05362c6e96ae3454a1c # @v1.28.0 + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@938f6e2f7550e542bd78f3b9e8812665db109e02 # @v1.1.0 + make bin/protoc bin/goreleaser + bash ./dev/run-kind-registry.sh + make ebpf generate + ./bin/goreleaser release --snapshot --rm-dist + make push-local-images + cp dist/tarianctl_linux_amd64/tarianctl ./bin/ + - name: Run unit tests + run: make unit-test + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.xml + verbose: true # optional (default = false) + test-k8s: runs-on: ubuntu-latest env: diff --git a/.gitignore b/.gitignore index 83115815..34a013c3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ /pkg/**/capture_exec.bpf.o /pkg/tarianpb/api.pb.go /pkg/tarianpb/types.pb.go +coverage.xml diff --git a/Makefile b/Makefile index 09d03938..1f15b2e8 100644 --- a/Makefile +++ b/Makefile @@ -142,7 +142,7 @@ push-local-images: docker push localhost:5000/tarian-node-agent unit-test: - CGO_CFLAGS=$(CGO_CFLAGS_STATIC) CGO_LDFLAGS=$(CGO_LDFLAGS_STATIC) go test -v -race -count=1 ./pkg/... + CGO_CFLAGS=$(CGO_CFLAGS_STATIC) CGO_LDFLAGS=$(CGO_LDFLAGS_STATIC) go test -v -race -coverprofile=coverage.xml -covermode=atomic ./pkg/... ./cmd/... e2e-test: CGO_CFLAGS=$(CGO_CFLAGS_STATIC) CGO_LDFLAGS=$(CGO_LDFLAGS_STATIC) go test -v -race -count=1 ./test/e2e/... @@ -150,6 +150,9 @@ e2e-test: k8s-test: ./test/k8s/test.sh +coverage: unit-test + go tool cover -html=coverage.xml + manifests: bin/controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects. $(CONTROLLER_GEN) webhook paths="./pkg/clusteragent/..." output:webhook:artifacts:config=dev/config/webhook diff --git a/cmd/tarianctl/cmd/add/action_test.go b/cmd/tarianctl/cmd/add/action_test.go new file mode 100644 index 00000000..cd12b715 --- /dev/null +++ b/cmd/tarianctl/cmd/add/action_test.go @@ -0,0 +1,232 @@ +package add + +import ( + "net" + "strings" + "testing" + + "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" + "github.com/kube-tarian/tarian/pkg/log" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +func TestActionCommand_Run(t *testing.T) { + // t.Parallel() + tests := []struct { + name string + expectedErr string + expectedLog string + + grpcClient ugrpc.Client + dryRun bool + onViolatedFile bool + onViolatedProcess bool + matchLabels []string + action string + onFalcoAlert string + }{ + { + name: "Add Action Successfully", + grpcClient: ugrpc.NewFakeGrpcClient(), + action: "delete-pod", + expectedLog: "Action was added successfully", + }, + { + name: "Add Action with Dry Run", + grpcClient: ugrpc.NewFakeGrpcClient(), + dryRun: true, + action: "delete-pod", + matchLabels: []string{ + "key1=val1", + "key2=val2", + }, + expectedLog: `kind: Action +namespace: test-namespace +name: test-action +selector: + matchlabels: + - key: key1 + value: val1 + - key: key2 + value: val2 +onviolatedprocess: false +onviolatedfile: false +onfalcoalert: false +falcopriority: 0 +action: delete-pod + +`, + }, + { + name: "Add Action with Dry Run and On Falco Alert", + grpcClient: ugrpc.NewFakeGrpcClient(), + dryRun: true, + action: "delete-pod", + onFalcoAlert: "alert", + expectedLog: `kind: Action +namespace: test-namespace +name: test-action +selector: + matchlabels: [] +onviolatedprocess: false +onviolatedfile: false +onfalcoalert: true +falcopriority: 1 +action: delete-pod +`, + }, + { + name: "Add Action with Invalid Falco Alert", + grpcClient: ugrpc.NewFakeGrpcClient(), + action: "delete-pod", + onFalcoAlert: "invalid", + expectedErr: "add action: invalid falco alert: invalid", + }, + { + name: "Add Action on Violated Process and violated file", + grpcClient: ugrpc.NewFakeGrpcClient(), + dryRun: true, + action: "delete-pod", + onViolatedFile: true, + onViolatedProcess: true, + expectedLog: `kind: Action +namespace: test-namespace +name: test-action +selector: + matchlabels: [] +onviolatedprocess: true +onviolatedfile: true +onfalcoalert: false +falcopriority: 0 +action: delete-pod +`, + }, + { + name: "Add Action with Invalid Action", + grpcClient: ugrpc.NewFakeGrpcClient(), + action: "invalid", + expectedErr: "invalid action: invalid", + }, + { + name: "Use real gRPC client", + action: "delete-pod", + expectedErr: "rpc error: code = Unimplemented desc = unknown service tarianpb.api.Config", + }, + } + + serverAddr := "localhost:50051" + go startFakeServer(t, serverAddr) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create the action command with the test configuration + cmd := &actionCommand{ + globalFlags: &flags.GlobalFlags{ + ServerAddr: serverAddr, + }, + logger: log.GetLogger(), + grpcClient: tt.grpcClient, + dryRun: tt.dryRun, + onFalcoAlert: tt.onFalcoAlert, + onViolatedProcess: tt.onViolatedProcess, + onViolatedFile: tt.onViolatedFile, + action: tt.action, + name: "test-action", + namespace: "test-namespace", + matchLabels: tt.matchLabels, + } + + // Capture log output + logOutput := []byte{} + cmd.logger.Out = &logOutputWriter{&logOutput} + + // Call the run function + err := cmd.run(nil, nil) + + // Assert expected error, if any + if tt.expectedErr != "" { + assert.Contains(t, err.Error(), tt.expectedErr) + } else { + if !assert.NoError(t, err) { + assert.FailNow(t, "error not expected") + } + } + + // Assert expected log output + if tt.expectedLog != "" { + assert.Equal(t, strings.TrimSpace(cleanLog(string(logOutput))), strings.TrimSpace(tt.expectedLog)) + } + }) + } +} + +// Helper struct to capture log output +type logOutputWriter struct { + output *[]byte +} + +func (w *logOutputWriter) Write(p []byte) (n int, err error) { + *w.output = append(*w.output, p...) + return len(p), nil +} + +func startFakeServer(t *testing.T, serverAddr string) { + lis, err := net.Listen("tcp", serverAddr) + if err != nil { + assert.NoError(t, err) + } + + srv := grpc.NewServer() + + if err := srv.Serve(lis); err != nil { + assert.NoError(t, err) + } +} + +func cleanLog(logLine string) string { + index := strings.Index(logLine, "]") + return logLine[index+2:] +} + +func TestNewAddActionCommand(t *testing.T) { + // Create a mock globalFlags instance + mockGlobalFlags := &flags.GlobalFlags{ + ServerAddr: "mock-server-address", + // Add other fields as needed + } + + // Call the function to be tested + cmd := newAddActionCommand(mockGlobalFlags) + + // Check if the returned value is of type *cobra.Command + assert.IsType(t, &cobra.Command{}, cmd) + + // Check if specific flags are correctly added + namespaceFlag := cmd.Flags().Lookup("namespace") + assert.NotNil(t, namespaceFlag) + assert.Equal(t, "default", namespaceFlag.DefValue) // Check default value + + nameFlag := cmd.Flags().Lookup("name") + assert.NotNil(t, nameFlag) + + matchLabelsFlag := cmd.Flags().Lookup("match-labels") + assert.NotNil(t, matchLabelsFlag) + + actionFlag := cmd.Flags().Lookup("action") + assert.NotNil(t, actionFlag) + + dryRunFlag := cmd.Flags().Lookup("dry-run") + assert.NotNil(t, dryRunFlag) + + onViolatedProcessFlag := cmd.Flags().Lookup("on-violated-process") + assert.NotNil(t, onViolatedProcessFlag) + + onViolatedFileFlag := cmd.Flags().Lookup("on-violated-file") + assert.NotNil(t, onViolatedFileFlag) + + onFalcoAlertFlag := cmd.Flags().Lookup("on-falco-alert") + assert.NotNil(t, onFalcoAlertFlag) +} diff --git a/cmd/tarianctl/cmd/add/actions.go b/cmd/tarianctl/cmd/add/actions.go index 6e6430ac..96d19ba8 100644 --- a/cmd/tarianctl/cmd/add/actions.go +++ b/cmd/tarianctl/cmd/add/actions.go @@ -7,18 +7,22 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" + "google.golang.org/grpc" + "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) type actionCommand struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + grpcClient ugrpc.Client + name string namespace string matchLabels []string @@ -58,15 +62,26 @@ func newAddActionCommand(globalFlags *flags.GlobalFlags) *cobra.Command { } func (c *actionCommand) run(_ *cobra.Command, args []string) error { - opts, err := util.ClientOptionsFromCliContext(c.logger, c.globalFlags) - if err != nil { - return fmt.Errorf("add action: %w", err) + // TODO: Remove this check when we support more actions + if c.action != "delete-pod" { + c.logger.Errorf("invalid action: %s", c.action) + return fmt.Errorf("add action: invalid action: %s", c.action) } - configClient, err := client.NewConfigClient(c.globalFlags.ServerAddr, opts...) - if err != nil { - return fmt.Errorf("add action: failed to create config client: %w", err) + if c.grpcClient == nil { + opts, err := util.ClientOptionsFromCliContext(c.logger, c.globalFlags) + if err != nil { + return fmt.Errorf("add constraints: %w", err) + } + + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) + if err != nil { + return fmt.Errorf("add constraints: failed to connect to server: %w", err) + } + defer grpcConn.Close() + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) } + configClient := c.grpcClient.NewConfigClient() req := &tarianpb.AddActionRequest{ Action: &tarianpb.Action{ @@ -84,6 +99,15 @@ func (c *actionCommand) run(_ *cobra.Command, args []string) error { if c.onFalcoAlert != "" { req.Action.OnFalcoAlert = true + falcoAlert := map[string]bool{ + "alert": true, + "critical": true, + "emergency": true, + } + if !falcoAlert[c.onFalcoAlert] { + c.logger.Errorf("invalid falco alert: %s", c.onFalcoAlert) + return fmt.Errorf("add action: invalid falco alert: %s", c.onFalcoAlert) + } req.Action.FalcoPriority = tarianpb.FalcoPriorityFromString(c.onFalcoAlert) } diff --git a/cmd/tarianctl/cmd/add/add.go b/cmd/tarianctl/cmd/add/add.go index c84674f5..996f98ad 100644 --- a/cmd/tarianctl/cmd/add/add.go +++ b/cmd/tarianctl/cmd/add/add.go @@ -16,8 +16,8 @@ func NewAddCommand(globalFlags *flags.GlobalFlags) *cobra.Command { Short: "Add resources to the Tarian Server.", Long: "Add resources to the Tarian Server.", RunE: func(cmd *cobra.Command, args []string) error { - if len(args) == 0 { - err := errors.New("no resource specified, use `tarianctl add --help` for command usage") + if len(args) != 1 { + err := errors.New(`tarianctl needs exactly one argument, use "tarianctl add --help" for command usage`) return fmt.Errorf("add: %w", err) } return nil diff --git a/cmd/tarianctl/cmd/add/add_test.go b/cmd/tarianctl/cmd/add/add_test.go new file mode 100644 index 00000000..18a13647 --- /dev/null +++ b/cmd/tarianctl/cmd/add/add_test.go @@ -0,0 +1,64 @@ +package add + +import ( + "testing" + + "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func TestNewAddCommand(t *testing.T) { + tests := []struct { + name string + args []string + expectedSubcommand string + expectedErr string + }{ + { + name: "No subcommand provided", + args: []string{}, + expectedSubcommand: "", + expectedErr: `tarianctl needs exactly one argument, use "tarianctl add --help" for command usage`, + }, + { + name: "Valid subcommand provided constraint", + args: []string{"constraint"}, + expectedSubcommand: "constraint", + expectedErr: "failed to connect to server", + }, + { + name: "Valid subcommand provided action", + args: []string{"action"}, + expectedSubcommand: "action", + expectedErr: `required flag(s) "action" not set`, + }, + { + name: "Invalid subcommand provided", + args: []string{"invalid-subcommand"}, + expectedSubcommand: "", + expectedErr: `unknown command "invalid-subcommand" for "add"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := NewAddCommand(&flags.GlobalFlags{}) + + assert.IsType(t, &cobra.Command{}, cmd) + + cmd.SetArgs(tt.args) + cmd.SilenceUsage = true + cmd.SilenceErrors = true + + err := cmd.Execute() + + if tt.expectedErr != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErr) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/cmd/tarianctl/cmd/add/constraints.go b/cmd/tarianctl/cmd/add/constraints.go index 5046724a..026257a0 100644 --- a/cmd/tarianctl/cmd/add/constraints.go +++ b/cmd/tarianctl/cmd/add/constraints.go @@ -10,11 +10,12 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "google.golang.org/grpc" "gopkg.in/yaml.v2" ) @@ -22,6 +23,8 @@ type constraintsCommand struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + grpcClient ugrpc.Client + name string namespace string matchLabels []string @@ -62,23 +65,23 @@ tarianctl add constraint --name nginx --namespace default --match-labels run=ngi } func (c *constraintsCommand) run(cobraCmd *cobra.Command, args []string) error { - opts, err := util.ClientOptionsFromCliContext(c.logger, c.globalFlags) - if err != nil { - return fmt.Errorf("add constraints: %w", err) - } - - configClient, err := client.NewConfigClient(c.globalFlags.ServerAddr, opts...) - if err != nil { - return fmt.Errorf("add constraints: failed to create config client: %w", err) - } + if c.grpcClient == nil { + opts, err := util.ClientOptionsFromCliContext(c.logger, c.globalFlags) + if err != nil { + return fmt.Errorf("add constraints: %w", err) + } - eventClient, err := client.NewEventClient(c.globalFlags.ServerAddr, opts...) - if err != nil { - return fmt.Errorf("add constraints: failed to create event client: %w", err) + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) + if err != nil { + return fmt.Errorf("add constraints: failed to connect to server: %w", err) + } + defer grpcConn.Close() + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) } + configClient := c.grpcClient.NewConfigClient() + eventClient := c.grpcClient.NewEventClient() fromViolatedPod := c.fromViolatedPod - var req *tarianpb.AddConstraintRequest // validate required fields @@ -87,6 +90,11 @@ func (c *constraintsCommand) run(cobraCmd *cobra.Command, args []string) error { return fmt.Errorf("add constraints: %w", err) } + if fromViolatedPod != "" && c.name != "" { + err := errors.New("constraint name and from-violated-pod cannot be used together") + return fmt.Errorf("add constraints: %w", err) + } + if fromViolatedPod != "" { constraint, err := c.buildConstraintFromViolatedPod(fromViolatedPod, eventClient, configClient, c.logger) if err != nil { @@ -118,9 +126,18 @@ func (c *constraintsCommand) run(cobraCmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("add constraints: %w", err) } - c.logger.Info(string(d)) } else { + if c.allowedFileSha256Sums == nil && c.allowedProcesses == nil { + err := errors.New("no allowed processes or files found, use --allowed-processes or --allowed-file-sha256sums or both") + return fmt.Errorf("add constraints: %w", err) + } + + if c.matchLabels == nil { + err := errors.New("no match labels found, use --match-labels") + return fmt.Errorf("add constraints: %w", err) + } + response, err := configClient.AddConstraint(context.Background(), req) if err != nil { return fmt.Errorf("add constraints: failed to add constraints: %w", err) @@ -139,29 +156,6 @@ func (c *constraintsCommand) run(cobraCmd *cobra.Command, args []string) error { return nil } -func matchLabelsFromString(strLabels []string) []*tarianpb.MatchLabel { - if strLabels == nil { - return nil - } - - labels := []*tarianpb.MatchLabel{} - - for _, s := range strLabels { - idx := strings.Index(s, "=") - - if idx < 0 { - continue - } - - key := s[:idx] - value := strings.Trim(s[idx+1:], "\"") - - labels = append(labels, &tarianpb.MatchLabel{Key: key, Value: value}) - } - - return labels -} - func allowedProcessesFromString(strProcesses []string) []*tarianpb.AllowedProcessRule { if strProcesses == nil { return nil @@ -192,7 +186,7 @@ func allowedFilesFromString(strFiles []string) []*tarianpb.AllowedFileRule { for _, s := range strFiles { idx := strings.Index(s, "=") - if idx < 0 { + if idx < 0 || idx == len(s)-1 { continue } @@ -237,7 +231,7 @@ func (c *constraintsCommand) buildConstraintFromViolatedPod(podName string, even resp, err := eventClient.GetEvents(ctx, &tarianpb.GetEventsRequest{Limit: 1000}) cancel() if err != nil { - return nil, fmt.Errorf("add constraints: buildConstraintFromViolatedPod: %w", err) + return nil, fmt.Errorf("buildConstraintFromViolatedPod: %w", err) } targets := []*tarianpb.Target{} @@ -247,7 +241,7 @@ func (c *constraintsCommand) buildConstraintFromViolatedPod(podName string, even if len(targets) == 0 { err := errors.New("zero target found") - return nil, fmt.Errorf("add constraints: buildConstraintFromViolatedPod: %w", err) + return nil, fmt.Errorf("buildConstraintFromViolatedPod: %w", err) } // build process rules @@ -271,13 +265,13 @@ func (c *constraintsCommand) buildConstraintFromViolatedPod(podName string, even if targets[0].GetPod() == nil { err := errors.New("no pod found") - return nil, fmt.Errorf("add constraints: buildConstraintFromViolatedPod: %w", err) + return nil, fmt.Errorf("buildConstraintFromViolatedPod: %w", err) } labels := targets[0].GetPod().GetLabels() if labels == nil { err := errors.New("no labels found") - return nil, fmt.Errorf("add constraints: buildConstraintFromViolatedPod: %w", err) + return nil, fmt.Errorf("buildConstraintFromViolatedPod: %w", err) } ignoredLabel := "pod-template-hash" diff --git a/cmd/tarianctl/cmd/add/constraints_test.go b/cmd/tarianctl/cmd/add/constraints_test.go new file mode 100644 index 00000000..71bcf4ca --- /dev/null +++ b/cmd/tarianctl/cmd/add/constraints_test.go @@ -0,0 +1,241 @@ +package add + +import ( + "strings" + "testing" + + "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" + "github.com/kube-tarian/tarian/pkg/log" + "github.com/kube-tarian/tarian/pkg/tarianpb" + "github.com/stretchr/testify/assert" +) + +func TestConstraintCommand_Run(t *testing.T) { + // t.Parallel() + tests := []struct { + name string + expectedErr string + expectedLog string + + grpcClient ugrpc.Client + constraintName string + matchLabels []string + allowedProcesses []string + allowedFileSha256Sums []string + fromViolatedPod string + dryRun bool + }{ + { + name: "Add Constraint Successfully", + grpcClient: ugrpc.NewFakeGrpcClient(), + constraintName: "test-constraint", + matchLabels: []string{"key1=val1"}, + allowedProcesses: []string{"process1"}, + expectedLog: "Constraint was added successfully", + }, + { + name: "Add Constraint without name or from-violated-pod", + grpcClient: ugrpc.NewFakeGrpcClient(), + expectedErr: "either constraint name or from-violated-pod is required", + }, + { + name: "Add Constraint with both name and from-violated-pod", + grpcClient: ugrpc.NewFakeGrpcClient(), + constraintName: "test-constraint", + fromViolatedPod: "test-pod", + expectedErr: "constraint name and from-violated-pod cannot be used together", + }, + { + name: "Add Constraint Successfully", + grpcClient: ugrpc.NewFakeGrpcClient(), + constraintName: "test-constraint", + matchLabels: []string{"key1=val1"}, + allowedProcesses: []string{"process1", "process2"}, + allowedFileSha256Sums: []string{"file1=sha256sum1", "file2=sha256sum2"}, + dryRun: true, + expectedLog: `kind: Constraint +namespace: test-namespace +name: test-constraint +selector: + matchlabels: + - key: key1 + value: val1 +allowedprocesses: +- regex: process1 +- regex: process2 +allowedfiles: +- name: file1 + sha256sum: sha256sum1 +- name: file2 + sha256sum: sha256sum2 +`, + }, + { + name: "Add Constraint without allowedProcesses and allowedFileSha256Sums", + grpcClient: ugrpc.NewFakeGrpcClient(), + constraintName: "test-constraint", + expectedErr: "no allowed processes or files found, use --allowed-processes or --allowed-file-sha256sums or both", + }, + { + name: "Add Constraint without matchLabels", + grpcClient: ugrpc.NewFakeGrpcClient(), + constraintName: "test-constraint", + allowedProcesses: []string{"process1", "process2"}, + allowedFileSha256Sums: []string{"file1=sha256sum1", "file2=sha256sum2"}, + expectedErr: "no match labels found, use --match-labels", + }, + { + name: "Use real gRPC client", + constraintName: "test-constraint", + matchLabels: []string{"key1=val1"}, + allowedProcesses: []string{"process1", "process2"}, + expectedErr: "rpc error: code = Unimplemented desc = unknown service tarianpb.api.Config", + }, + // TODO: Add test for from-violated-pod after faking GetEvents() + // { + // name: "Add Constraint with from-violated-pod", + // grpcClient: ugrpc.NewFakeGrpcClient(), + // fromViolatedPod: "test-pod", + // }, + + // TODO: Add test for Duplicate rules + + } + serverAddr := "localhost:50052" + go startFakeServer(t, serverAddr) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &constraintsCommand{ + globalFlags: &flags.GlobalFlags{ + ServerAddr: serverAddr, + }, + logger: log.GetLogger(), + grpcClient: tt.grpcClient, + name: tt.constraintName, + namespace: "test-namespace", + matchLabels: tt.matchLabels, + allowedProcesses: tt.allowedProcesses, + allowedFileSha256Sums: tt.allowedFileSha256Sums, + fromViolatedPod: tt.fromViolatedPod, + dryRun: tt.dryRun, + } + + // Capture log output + logOutput := []byte{} + cmd.logger.Out = &logOutputWriter{&logOutput} + + // Call the run function + err := cmd.run(nil, nil) + + // Assert expected error, if any + if tt.expectedErr != "" { + assert.Contains(t, err.Error(), tt.expectedErr) + } else { + if !assert.NoError(t, err) { + assert.FailNow(t, "error not expected") + } + } + + // Assert expected log output + if tt.expectedLog != "" { + assert.Equal(t, strings.TrimSpace(cleanLog(string(logOutput))), strings.TrimSpace(tt.expectedLog)) + } + }) + } +} + +func TestAllowedProcessesFromString(t *testing.T) { + tests := []struct { + name string + input []string + expected []*tarianpb.AllowedProcessRule + shouldFail bool + }{ + { + name: "Allow both processes", + input: []string{"process1", "process2"}, + expected: []*tarianpb.AllowedProcessRule{ + {Regex: strPtr("process1")}, + {Regex: strPtr("process2")}, + }, + shouldFail: false, + }, + { + name: "No process, nil output", + input: nil, + expected: nil, + shouldFail: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := allowedProcessesFromString(tt.input) + + if tt.shouldFail { + assert.Nil(t, result) + } else { + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestAllowedFilesFromString(t *testing.T) { + tests := []struct { + name string + input []string + expected []*tarianpb.AllowedFileRule + shouldFail bool + }{ + { + name: "Allow both files", + input: []string{"file1=hash1", "file2=hash2"}, + expected: []*tarianpb.AllowedFileRule{ + {Name: "file1", Sha256Sum: strPtr("hash1")}, + {Name: "file2", Sha256Sum: strPtr("hash2")}, + }, + shouldFail: false, + }, + { + name: "Allow files with without hash", + input: []string{"file1=hash1", "file2=hash2", "file1"}, + expected: []*tarianpb.AllowedFileRule{ + {Name: "file1", Sha256Sum: strPtr("hash1")}, + {Name: "file2", Sha256Sum: strPtr("hash2")}, + }, + shouldFail: false, + }, + { + name: "Don't allow files with without hash", + input: []string{"file1="}, + expected: nil, + shouldFail: true, + }, + { + name: "No file, nil output", + input: nil, + expected: nil, + shouldFail: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := allowedFilesFromString(tt.input) + + if tt.shouldFail { + assert.Nil(t, result) + } else { + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func strPtr(s string) *string { + return &s +} + +// TODO: Add test for buildConstraintFromViolatedPod diff --git a/cmd/tarianctl/cmd/add/util.go b/cmd/tarianctl/cmd/add/util.go new file mode 100644 index 00000000..04763e9a --- /dev/null +++ b/cmd/tarianctl/cmd/add/util.go @@ -0,0 +1,33 @@ +package add + +import ( + "strings" + + "github.com/kube-tarian/tarian/pkg/tarianpb" +) + +func matchLabelsFromString(strLabels []string) []*tarianpb.MatchLabel { + if strLabels == nil { + return nil + } + + labels := []*tarianpb.MatchLabel{} + + for _, s := range strLabels { + idx := strings.Index(s, "=") + + if idx < 0 || idx == len(s)-1 { + continue + } + + key := s[:idx] + value := strings.Trim(s[idx+1:], "\"") + + labels = append(labels, &tarianpb.MatchLabel{Key: key, Value: value}) + } + + if len(labels) == 0 { + return nil + } + return labels +} diff --git a/cmd/tarianctl/cmd/add/util_test.go b/cmd/tarianctl/cmd/add/util_test.go new file mode 100644 index 00000000..6b8d6619 --- /dev/null +++ b/cmd/tarianctl/cmd/add/util_test.go @@ -0,0 +1,50 @@ +package add + +import ( + "testing" + + "github.com/kube-tarian/tarian/pkg/tarianpb" + "github.com/stretchr/testify/assert" +) + +func TestMatchLabelsFromString(t *testing.T) { + tests := []struct { + name string + input []string + expected []*tarianpb.MatchLabel + }{ + { + name: "MatchLabelsFromString should return valid Key=Value pairs", + input: []string{"key1=value1", "key2=value2"}, + expected: []*tarianpb.MatchLabel{ + {Key: "key1", Value: "value1"}, + {Key: "key2", Value: "value2"}, + }, + }, + { + name: "MatchLabelsFromString should return valid Key=Value pairs and ignore invalid labels", + input: []string{"key1=value1", "key2=value2", "invalid"}, + expected: []*tarianpb.MatchLabel{ + {Key: "key1", Value: "value1"}, + {Key: "key2", Value: "value2"}, + }, + }, + { + name: "MatchLabelsFromString should return nil if no valid Key=Value pairs are found", + input: []string{"invalid"}, + expected: nil, + }, + { + name: "MatchLabelsFromString should return nil if input is nil", + input: nil, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchLabelsFromString(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/cmd/tarianctl/cmd/flags/flag.go b/cmd/tarianctl/cmd/flags/flag.go index 9ac89740..a0f9a2c3 100644 --- a/cmd/tarianctl/cmd/flags/flag.go +++ b/cmd/tarianctl/cmd/flags/flag.go @@ -94,7 +94,9 @@ func (globalFlags *GlobalFlags) GetFlagValuesFromEnvVar() { // Read environment variable for "server-tls-enabled" flag if serverTLSEnabledEnv := os.Getenv("TARIAN_TLS_ENABLED"); serverTLSEnabledEnv != "" { - globalFlags.ServerTLSEnabled = true + if serverTLSEnabledEnv == "true" { + globalFlags.ServerTLSEnabled = true + } } // Read environment variable for "server-tls-ca-file" flag @@ -104,6 +106,8 @@ func (globalFlags *GlobalFlags) GetFlagValuesFromEnvVar() { // Read environment variable for "server-tls-insecure-skip-verify" flag if serverTLSInsecureSkipVerifyEnv := os.Getenv("TARIAN_TLS_INSECURE_SKIP_VERIFY"); serverTLSInsecureSkipVerifyEnv != "" { - globalFlags.ServerTLSInsecureSkipVerify = true + if serverTLSInsecureSkipVerifyEnv == "false" { + globalFlags.ServerTLSInsecureSkipVerify = false + } } } diff --git a/cmd/tarianctl/cmd/flags/flags_test.go b/cmd/tarianctl/cmd/flags/flags_test.go new file mode 100644 index 00000000..844ea508 --- /dev/null +++ b/cmd/tarianctl/cmd/flags/flags_test.go @@ -0,0 +1,107 @@ +package flags + +import ( + "os" + "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" +) + +func TestSetGlobalFlags(t *testing.T) { + // Create a FlagSet for testing + fs := pflag.NewFlagSet("test", pflag.ExitOnError) + + // Initialize global flags + globalFlags := SetGlobalFlags(fs) + + // Test default values + assert.Equal(t, "info", globalFlags.LogLevel) + assert.Equal(t, "text", globalFlags.LogFormatter) + assert.Equal(t, defaultServerAddress, globalFlags.ServerAddr) + assert.False(t, globalFlags.ServerTLSEnabled) + assert.Equal(t, "", globalFlags.ServerTLSCAFile) + assert.True(t, globalFlags.ServerTLSInsecureSkipVerify) +} + +func TestValidateGlobalFlags(t *testing.T) { + tests := []struct { + name string + globalFlags *GlobalFlags + expectedError string + }{ + { + name: "Valid Flags", + globalFlags: &GlobalFlags{LogLevel: "info", LogFormatter: "text", ServerTLSCAFile: "ca.pem"}, + expectedError: "", + }, + { + name: "Invalid LogLevel", + globalFlags: &GlobalFlags{LogLevel: "invalid", LogFormatter: "text"}, + expectedError: "invalid log level: invalid", + }, + { + name: "Invalid LogFormatter", + globalFlags: &GlobalFlags{LogLevel: "info", LogFormatter: "invalid"}, + expectedError: "invalid log formatter: invalid", + }, + { + name: "ServerTLSWithoutCAFile", + globalFlags: &GlobalFlags{LogLevel: "info", LogFormatter: "text", ServerTLSEnabled: true}, + expectedError: "server TLS enabled but CA file is not provided", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.globalFlags.ValidateGlobalFlags() + if tt.expectedError == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tt.expectedError) + } + }) + } +} + +func TestGetFlagValuesFromEnvVar(t *testing.T) { + // Set environment variables for testing + tarianServerEnvVar := "TARIAN_SERVER_ADDRESS" + tarianServerEnvVarValue := "test-server:1234" + if err := os.Setenv(tarianServerEnvVar, tarianServerEnvVarValue); !assert.NoError(t, err) { + assert.FailNow(t, err.Error()) + } + defer os.Unsetenv(tarianServerEnvVar) + + // Set more environment variables for testing + TLSEnabledEnvVar := "TARIAN_TLS_ENABLED" + TLSEnabledEnvVarValue := "true" + if err := os.Setenv(TLSEnabledEnvVar, TLSEnabledEnvVarValue); !assert.NoError(t, err) { + assert.FailNow(t, err.Error()) + } + defer os.Unsetenv(TLSEnabledEnvVar) + + TLSCAFilEnvVar := "TARIAN_TLS_CA_FILE" + TLSCAFilEnvVarValue := "/path/to/ca.pem" + if err := os.Setenv(TLSCAFilEnvVar, TLSCAFilEnvVarValue); !assert.NoError(t, err) { + assert.FailNow(t, err.Error()) + } + defer os.Unsetenv(TLSCAFilEnvVar) + + TLSInsecureEnvVar := "TARIAN_TLS_INSECURE_SKIP_VERIFY" + TLSInsecureEnvVarValue := "false" + if err := os.Setenv(TLSInsecureEnvVar, TLSInsecureEnvVarValue); !assert.NoError(t, err) { + assert.FailNow(t, err.Error()) + } + defer os.Unsetenv(TLSInsecureEnvVar) + + // Create global flags and load values from environment variables + globalFlags := &GlobalFlags{} + globalFlags.GetFlagValuesFromEnvVar() + + // Check if the value was correctly loaded from the environment variable + assert.Equal(t, tarianServerEnvVarValue, globalFlags.ServerAddr) + assert.True(t, globalFlags.ServerTLSEnabled) + assert.False(t, globalFlags.ServerTLSInsecureSkipVerify) + assert.Equal(t, "/path/to/ca.pem", globalFlags.ServerTLSCAFile) +} diff --git a/cmd/tarianctl/cmd/get/actions.go b/cmd/tarianctl/cmd/get/actions.go index 5fc1c0d8..3bf06559 100644 --- a/cmd/tarianctl/cmd/get/actions.go +++ b/cmd/tarianctl/cmd/get/actions.go @@ -8,12 +8,13 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/olekukonko/tablewriter" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "google.golang.org/grpc" "gopkg.in/yaml.v2" ) @@ -21,6 +22,8 @@ type actionCommand struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + grpcClient ugrpc.Client + namespace string output string } @@ -55,10 +58,14 @@ func (c *actionCommand) run(_ *cobra.Command, args []string) error { return fmt.Errorf("get actions: %w", err) } - client, err := client.NewConfigClient(c.globalFlags.ServerAddr, opts...) + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) if err != nil { - return fmt.Errorf("get actions: %w", err) + return fmt.Errorf("get actions: failed to connect to server: %w", err) } + defer grpcConn.Close() + + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) + client := c.grpcClient.NewConfigClient() request := &tarianpb.GetActionsRequest{} ns := c.namespace diff --git a/cmd/tarianctl/cmd/get/constraints.go b/cmd/tarianctl/cmd/get/constraints.go index d02b0cad..12a747cc 100644 --- a/cmd/tarianctl/cmd/get/constraints.go +++ b/cmd/tarianctl/cmd/get/constraints.go @@ -8,9 +8,10 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/sirupsen/logrus" + "google.golang.org/grpc" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/olekukonko/tablewriter" @@ -22,6 +23,8 @@ type constraintsCommand struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + grpcClient ugrpc.Client + output string } @@ -54,10 +57,14 @@ func (c *constraintsCommand) run(cobraCmd *cobra.Command, args []string) error { return fmt.Errorf("get constraints: %w", err) } - client, err := client.NewConfigClient(c.globalFlags.ServerAddr, opts...) + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) if err != nil { - return fmt.Errorf("get constraints: %w", err) + return fmt.Errorf("get constraints: failed to connect to server: %w", err) } + defer grpcConn.Close() + + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) + client := c.grpcClient.NewConfigClient() response, err := client.GetConstraints(context.Background(), &tarianpb.GetConstraintsRequest{}) if err != nil { @@ -96,27 +103,6 @@ func (c *constraintsCommand) run(cobraCmd *cobra.Command, args []string) error { return nil } -func matchLabelsToString(labels []*tarianpb.MatchLabel) string { - if len(labels) == 0 { - return "" - } - - str := strings.Builder{} - str.WriteString("matchLabels:") - - for i, l := range labels { - str.WriteString(l.GetKey()) - str.WriteString("=") - str.WriteString(l.GetValue()) - - if i < len(labels)-1 { - str.WriteString(",") - } - } - - return str.String() -} - func allowedProcessesToString(rules []*tarianpb.AllowedProcessRule) string { str := strings.Builder{} diff --git a/cmd/tarianctl/cmd/get/events.go b/cmd/tarianctl/cmd/get/events.go index 6eb7a9c4..4c331d70 100644 --- a/cmd/tarianctl/cmd/get/events.go +++ b/cmd/tarianctl/cmd/get/events.go @@ -10,8 +10,10 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" + "google.golang.org/grpc" + "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/olekukonko/tablewriter" "github.com/sirupsen/logrus" @@ -22,6 +24,8 @@ type eventsCommand struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + grpcClient ugrpc.Client + limit uint } @@ -51,10 +55,14 @@ func (c *eventsCommand) run(_ *cobra.Command, args []string) error { return fmt.Errorf("get events: %w", err) } - client, err := client.NewEventClient(c.globalFlags.ServerAddr, opts...) + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) if err != nil { - return fmt.Errorf("get events: %w", err) + return fmt.Errorf("get events: failed to connect to server: %w", err) } + defer grpcConn.Close() + + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) + client := c.grpcClient.NewEventClient() response, err := client.GetEvents(context.Background(), &tarianpb.GetEventsRequest{Limit: uint32(c.limit)}) if err != nil { diff --git a/cmd/tarianctl/cmd/get/util.go b/cmd/tarianctl/cmd/get/util.go new file mode 100644 index 00000000..1ed76ac3 --- /dev/null +++ b/cmd/tarianctl/cmd/get/util.go @@ -0,0 +1,28 @@ +package get + +import ( + "strings" + + "github.com/kube-tarian/tarian/pkg/tarianpb" +) + +func matchLabelsToString(labels []*tarianpb.MatchLabel) string { + if len(labels) == 0 { + return "" + } + + str := strings.Builder{} + str.WriteString("matchLabels:") + + for i, l := range labels { + str.WriteString(l.GetKey()) + str.WriteString("=") + str.WriteString(l.GetValue()) + + if i < len(labels)-1 { + str.WriteString(",") + } + } + + return str.String() +} diff --git a/cmd/tarianctl/cmd/import/import.go b/cmd/tarianctl/cmd/import/import.go index af0bdb71..70865c9a 100644 --- a/cmd/tarianctl/cmd/import/import.go +++ b/cmd/tarianctl/cmd/import/import.go @@ -9,17 +9,21 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" + "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "google.golang.org/grpc" "gopkg.in/yaml.v3" ) type importCommand struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + + grpcClient ugrpc.Client } // NewImportCommand creates a new `import` command @@ -60,10 +64,14 @@ func (c *importCommand) run(_ *cobra.Command, args []string) error { return fmt.Errorf("import: %w", err) } - client, err := client.NewConfigClient(c.globalFlags.ServerAddr, opts...) + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) if err != nil { - return fmt.Errorf("import: %w", err) + return fmt.Errorf("import: failed to connect to server: %w", err) } + defer grpcConn.Close() + + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) + client := c.grpcClient.NewConfigClient() for _, f := range files { err := c.importFile(f, client) diff --git a/cmd/tarianctl/cmd/remove/actions.go b/cmd/tarianctl/cmd/remove/actions.go index 0ab36bcd..d7343410 100644 --- a/cmd/tarianctl/cmd/remove/actions.go +++ b/cmd/tarianctl/cmd/remove/actions.go @@ -8,17 +8,21 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" + "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "google.golang.org/grpc" ) type removeActionsCmd struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + grpcClient ugrpc.Client + namespace string } @@ -52,10 +56,14 @@ func (c *removeActionsCmd) run(_ *cobra.Command, args []string) error { return fmt.Errorf("remove action: %w", err) } - client, err := client.NewConfigClient(c.globalFlags.ServerAddr, opts...) + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) if err != nil { - return fmt.Errorf("remove action: %w", err) + return fmt.Errorf("remove action: failed to connect to server: %w", err) } + defer grpcConn.Close() + + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) + client := c.grpcClient.NewConfigClient() for _, name := range args { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) diff --git a/cmd/tarianctl/cmd/remove/constraints.go b/cmd/tarianctl/cmd/remove/constraints.go index fdf733c3..c87d5a90 100644 --- a/cmd/tarianctl/cmd/remove/constraints.go +++ b/cmd/tarianctl/cmd/remove/constraints.go @@ -8,17 +8,20 @@ import ( "github.com/kube-tarian/tarian/cmd/tarianctl/cmd/flags" "github.com/kube-tarian/tarian/cmd/tarianctl/util" + ugrpc "github.com/kube-tarian/tarian/cmd/tarianctl/util/grpc" "github.com/kube-tarian/tarian/pkg/log" - "github.com/kube-tarian/tarian/pkg/tarianctl/client" "github.com/kube-tarian/tarian/pkg/tarianpb" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "google.golang.org/grpc" ) type removeConstraintsCmd struct { globalFlags *flags.GlobalFlags logger *logrus.Logger + grpcClient ugrpc.Client + namespace string } @@ -51,10 +54,14 @@ func (c *removeConstraintsCmd) run(_ *cobra.Command, args []string) error { return fmt.Errorf("remove constraint: %w", err) } - client, err := client.NewConfigClient(c.globalFlags.ServerAddr, opts...) + grpcConn, err := grpc.Dial(c.globalFlags.ServerAddr, opts...) if err != nil { - return fmt.Errorf("remove constraint: %w", err) + return fmt.Errorf("remove constraint: failed to connect to server: %w", err) } + defer grpcConn.Close() + + c.grpcClient = ugrpc.NewGRPCClient(grpcConn) + client := c.grpcClient.NewConfigClient() for _, name := range args { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) diff --git a/cmd/tarianctl/cmd/root_test.go b/cmd/tarianctl/cmd/root_test.go new file mode 100644 index 00000000..c484250a --- /dev/null +++ b/cmd/tarianctl/cmd/root_test.go @@ -0,0 +1,36 @@ +package cmd + +import ( + "bytes" + "io" + "testing" + + "github.com/kube-tarian/tarian/pkg/log" + "github.com/stretchr/testify/assert" +) + +func TestRootCommand(t *testing.T) { + t.Run("TestRootCommandVersion", func(t *testing.T) { + stdout := new(bytes.Buffer) + + err := runRootCommand(stdout, []string{"version"}) + if assert.NoError(t, err) { + out, _ := io.ReadAll(stdout) + assert.Contains(t, string(out), "tarianctl version:") + } + }) + + t.Run("TestRootCommandInvalidSubcommand", func(t *testing.T) { + stdout := new(bytes.Buffer) + err := runRootCommand(stdout, []string{"invalidStderr-subcommand"}) + assert.EqualError(t, err, `unknown command "invalidStderr-subcommand" for "tarianctl"`) + }) +} + +func runRootCommand(output *bytes.Buffer, args []string) error { + logger := log.GetLogger() + logger.SetOutput(output) + rootCmd := buildRootCommand(logger) + rootCmd.SetArgs(args) + return rootCmd.Execute() +} diff --git a/cmd/tarianctl/cmd/version.go b/cmd/tarianctl/cmd/version.go index 259d1f5a..92ae1dbf 100644 --- a/cmd/tarianctl/cmd/version.go +++ b/cmd/tarianctl/cmd/version.go @@ -1,9 +1,8 @@ package cmd import ( - "fmt" - version "github.com/kube-tarian/tarian/cmd" + "github.com/kube-tarian/tarian/pkg/log" "github.com/spf13/cobra" ) @@ -12,6 +11,7 @@ var versionCmd = &cobra.Command{ Args: cobra.NoArgs, Short: "Prints version of tarianctl", Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("tarianctl version: %s\n", version.GetVersion()) + logger := log.GetLogger() + logger.Infof("tarianctl version: %s\n", version.GetVersion()) }, } diff --git a/cmd/tarianctl/util/grpc/fake_grpc_client.go b/cmd/tarianctl/util/grpc/fake_grpc_client.go new file mode 100644 index 00000000..7da7907b --- /dev/null +++ b/cmd/tarianctl/util/grpc/fake_grpc_client.go @@ -0,0 +1,21 @@ +package grpc + +import "github.com/kube-tarian/tarian/pkg/tarianpb" + +// FakeGrpcClient is a fake implementation of GrpcClient. +type FakeGrpcClient struct{} + +// NewFakeGrpcClient returns a new instance of FakeGrpcClient. +func NewFakeGrpcClient() *FakeGrpcClient { + return &FakeGrpcClient{} +} + +// NewConfigClient returns a new instance of fakeConfigClient. +func (f *FakeGrpcClient) NewConfigClient() tarianpb.ConfigClient { + return tarianpb.NewFakeConfigClient() +} + +// NewEventClient returns a new instance of fakeEventClient. +func (f *FakeGrpcClient) NewEventClient() tarianpb.EventClient { + return tarianpb.NewFakeEventClient() +} diff --git a/cmd/tarianctl/util/grpc/grpc_clients.go b/cmd/tarianctl/util/grpc/grpc_clients.go new file mode 100644 index 00000000..19d23252 --- /dev/null +++ b/cmd/tarianctl/util/grpc/grpc_clients.go @@ -0,0 +1,27 @@ +package grpc + +import ( + "github.com/kube-tarian/tarian/pkg/tarianpb" + "google.golang.org/grpc" +) + +type client struct { + conn *grpc.ClientConn +} + +// NewGRPCClient creates a new GRPCClient. +func NewGRPCClient(conn *grpc.ClientConn) Client { + return &client{ + conn: conn, + } +} + +// NewConfigClient creates a new ConfigClient. +func (g *client) NewConfigClient() tarianpb.ConfigClient { + return tarianpb.NewConfigClient(g.conn) +} + +// NewEventClient creates a new EventClient. +func (g *client) NewEventClient() tarianpb.EventClient { + return tarianpb.NewEventClient(g.conn) +} diff --git a/cmd/tarianctl/util/grpc/interface.go b/cmd/tarianctl/util/grpc/interface.go new file mode 100644 index 00000000..34f7eb75 --- /dev/null +++ b/cmd/tarianctl/util/grpc/interface.go @@ -0,0 +1,13 @@ +package grpc + +import ( + "github.com/kube-tarian/tarian/pkg/tarianpb" +) + +// Client is an interface for gRPC client +type Client interface { + // NewConfigClient returns a new ConfigClient + NewConfigClient() tarianpb.ConfigClient + // NewEventClient returns a new EventClient + NewEventClient() tarianpb.EventClient +} diff --git a/dev/config/monitored-pod/configmap.yaml b/dev/config/monitored-pod/configmap.yaml index d99b760b..47260f4c 100644 --- a/dev/config/monitored-pod/configmap.yaml +++ b/dev/config/monitored-pod/configmap.yaml @@ -29,4 +29,5 @@ data: kind: ConfigMap metadata: creationTimestamp: null - name: nginx-html \ No newline at end of file + name: nginx-html + diff --git a/pkg/tarianctl/client/client.go b/pkg/tarianctl/client/client.go deleted file mode 100644 index 1493122d..00000000 --- a/pkg/tarianctl/client/client.go +++ /dev/null @@ -1,20 +0,0 @@ -package client - -import ( - "github.com/kube-tarian/tarian/pkg/tarianpb" - "google.golang.org/grpc" -) - -// NewConfigClient creates a new ConfigClient. -func NewConfigClient(serverAddress string, opts ...grpc.DialOption) (tarianpb.ConfigClient, error) { - grpcConn, err := grpc.Dial(serverAddress, opts...) - - return tarianpb.NewConfigClient(grpcConn), err -} - -// NewEventClient creates a new EventClient. -func NewEventClient(serverAddress string, opts ...grpc.DialOption) (tarianpb.EventClient, error) { - grpcConn, err := grpc.Dial(serverAddress, opts...) - - return tarianpb.NewEventClient(grpcConn), err -} diff --git a/pkg/tarianctl/client/doc.go b/pkg/tarianctl/client/doc.go deleted file mode 100644 index ba9884e2..00000000 --- a/pkg/tarianctl/client/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package client provides client construction functions to connect to tarian server -package client diff --git a/pkg/tarianpb/fake_api_grpc.pb.go b/pkg/tarianpb/fake_api_grpc.pb.go new file mode 100644 index 00000000..594767a3 --- /dev/null +++ b/pkg/tarianpb/fake_api_grpc.pb.go @@ -0,0 +1,68 @@ +package tarianpb + +import ( + context "context" + + grpc "google.golang.org/grpc" +) + +type fakeConfigClient struct{} + +// NewFakeConfigClient returns a new instance of fakeConfigClient. +func NewFakeConfigClient() ConfigClient { + return &fakeConfigClient{} +} + +// GetConstraints returns the constraints for the specified namespace. +func (f *fakeConfigClient) GetConstraints(ctx context.Context, in *GetConstraintsRequest, opts ...grpc.CallOption) (*GetConstraintsResponse, error) { + + return nil, nil +} + +// AddConstraint adds a constraint to the specified namespace. +func (f *fakeConfigClient) AddConstraint(ctx context.Context, in *AddConstraintRequest, opts ...grpc.CallOption) (*AddConstraintResponse, error) { + out := &AddConstraintResponse{ + Success: true, + } + return out, nil +} + +// RemoveConstraint removes a constraint from the specified namespace. +func (f *fakeConfigClient) RemoveConstraint(ctx context.Context, in *RemoveConstraintRequest, opts ...grpc.CallOption) (*RemoveConstraintResponse, error) { + return nil, nil +} + +// AddAction adds an action to the specified namespace. +func (f *fakeConfigClient) AddAction(ctx context.Context, in *AddActionRequest, opts ...grpc.CallOption) (*AddActionResponse, error) { + out := &AddActionResponse{ + Success: true, + } + return out, nil +} + +// GetActions returns the actions for the specified namespace. +func (f *fakeConfigClient) GetActions(ctx context.Context, in *GetActionsRequest, opts ...grpc.CallOption) (*GetActionsResponse, error) { + return nil, nil +} + +// RemoveAction removes an action from the specified namespace. +func (f *fakeConfigClient) RemoveAction(ctx context.Context, in *RemoveActionRequest, opts ...grpc.CallOption) (*RemoveActionResponse, error) { + return nil, nil +} + +type fakeEventClient struct{} + +// NewFakeEventClient returns a new instance of fakeEventClient. +func NewFakeEventClient() EventClient { + return &fakeEventClient{} +} + +// IngestEvent ingests an event to the Tarian Server. +func (f *fakeEventClient) IngestEvent(ctx context.Context, in *IngestEventRequest, opts ...grpc.CallOption) (*IngestEventResponse, error) { + return nil, nil +} + +// GetEvents returns the events from the Tarian Server. +func (f *fakeEventClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) { + return nil, nil +}