Skip to content

Commit

Permalink
Handle all SNS messages with valid signatures
Browse files Browse the repository at this point in the history
This commit implements and tests the full SNS endpoint, which can confirm the subscription and act on notifications. It includes a tested implementation of SNS signature verification. It also includes one experimental fuzz test for the SNS verification. There's currently no CI infrastructure for running tests at all, let alone fuzzing, but I've run them locally and they pass.
  • Loading branch information
jameshochadel committed Feb 11, 2025
1 parent 5d37c62 commit 1df3188
Show file tree
Hide file tree
Showing 9 changed files with 454 additions and 60 deletions.
11 changes: 11 additions & 0 deletions helper/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# CSB Helper

## Conventions

Test functions are ordered in test files as follows:

1. Individual unit tests (`TestXxx`)
1. Table-driven tests
1. Benchmark tests (`BenchmarkXxx`)
1. Fuzz tests (`FuzzXxx`)
1. Example tests (`ExampleXxx`)
6 changes: 4 additions & 2 deletions helper/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ go 1.23.1
require (
github.com/aws/aws-sdk-go-v2 v1.34.0
github.com/aws/aws-sdk-go-v2/config v1.29.2
github.com/aws/aws-sdk-go-v2/credentials v1.17.55
github.com/aws/aws-sdk-go-v2/service/sns v1.33.15
golang.org/x/net v0.29.0
)

require github.com/jmespath/go-jmespath v0.4.0 // indirect
require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.55 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
)

require (
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.25 // indirect
Expand Down
4 changes: 4 additions & 0 deletions helper/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.10 h1:g9d+TOsu3ac7SgmY2dUf1qMgu/u
github.com/aws/aws-sdk-go-v2/service/sts v1.33.10/go.mod h1:WZfNmntu92HO44MVZAubQaz3qCuIdeOdog2sADfU6hU=
github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
2 changes: 1 addition & 1 deletion helper/internal/brokerpaks/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ import (

func Handle(sesclient *awsses.Client, snsclient *sns.Client) http.Handler {
mux := http.NewServeMux()
mux.Handle("POST /ses/reputation-alarm", ses.HandleAlarm(sesclient, snsclient))
mux.Handle("POST /ses/reputation-alarm", ses.HandleSNSRequest(sesclient, snsclient))
return mux
}
115 changes: 65 additions & 50 deletions helper/internal/brokerpaks/ses/reputation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@ import (
"io"
"log/slog"
"net/http"
"net/url"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ses"
"github.com/aws/aws-sdk-go-v2/service/sns"
)

type SESClient interface {
UpdateConfigurationSetSendingEnabled(context.Context, ses.UpdateConfigurationSetSendingEnabledInput) (ses.UpdateConfigurationSetSendingEnabledOutput, error)
}
const (
snsMessageTypeHeader = "x-amz-sns-message-type"
snsMessageTypeNotification = "Notification"
snsMessageTypeSubscriptionConfirmation = "SubscriptionConfirmation"
)

type SNSRequest struct {
Message CloudWatchAlarm
Subject string
SubscribeURL string
type SESClient interface {
UpdateConfigurationSetSendingEnabled(context.Context, *ses.UpdateConfigurationSetSendingEnabledInput, ...func(*ses.Options)) (*ses.UpdateConfigurationSetSendingEnabledOutput, error)
}

type CloudWatchAlarm struct {
Expand All @@ -35,28 +36,6 @@ type CloudWatchAlarm struct {
}
}

// UnmarshalJSON is custom implemented here because the Message field contains a JSON object
// (the SNS message) encoded in a string, with escaped quotes. The default Unmarshaller cannot
// handle this.
func (s *SNSRequest) UnmarshalJSON(b []byte) error {
// Unmarshal to an auxiliary type to get the string contents of all fields, including Message
var aux struct {
Message string
Subject string
SubscribeURL string
}

if err := json.Unmarshal(b, &aux); err != nil {
return err
}

s.Subject = aux.Subject
s.SubscribeURL = aux.SubscribeURL

// Unmarshal the Message field separately
return json.Unmarshal([]byte(aux.Message), &s.Message)
}

func (a *CloudWatchAlarm) Valid() map[string]string {
verrs := make(map[string]string)
prefix := "SES-BounceRate-Critical-Identity-"
Expand All @@ -77,9 +56,8 @@ func (a *CloudWatchAlarm) Valid() map[string]string {
return verrs
}

// parseRequests extracts the CloudWatch alarm from the body of the SNS request.
func ParseRequest(body io.Reader) (SNSRequest, error) {
var s SNSRequest
func UnmarshalMessage(body io.Reader) (SNSMessage, error) {
var s SNSMessage
b, err := io.ReadAll(body)
if err != nil {
return s, fmt.Errorf("reading SNS request body: %w", err)
Expand All @@ -94,33 +72,70 @@ func ParseRequest(body io.Reader) (SNSRequest, error) {
return s, nil
}

// TODO verify the SNS signature.
// TODO confirm subscription.
func HandleAlarm(sesclient *ses.Client, snsclient *sns.Client) http.Handler {
func handleSubscriptionConfirmation(msg SNSMessage) {
_, err := http.Get(msg.SubscribeURL)
if err != nil {
slog.Error("error confirming SNS subscription", "err", err)
} else {
slog.Info("confirmed subscription to SNS topic", "topic", msg.TopicArn)
}
}

func handleNotification(ctx context.Context, msg SNSMessage, sesclient SESClient) {
var a CloudWatchAlarm
err := json.Unmarshal([]byte(msg.Message), &a)
if err != nil {
slog.Error("unmarshalling CloudWatch alarm from SNS message body", "err", err)
}

if errs := a.Valid(); len(errs) > 0 {
slog.Error("error validating CloudWatch alarm. is the SNS subscription FilterPolicy allowing non-SES notifications?", "errs", errs)
}

cset := a.Trigger.Dimensions[0].Value
slog.Info("pausing sending on SES identity via Configuration Set", "configuration-set", cset)
_, err = sesclient.UpdateConfigurationSetSendingEnabled(ctx, &ses.UpdateConfigurationSetSendingEnabledInput{
ConfigurationSetName: aws.String(cset),
Enabled: false,
})
if err != nil {
slog.Error("error pausing sending on configuration set", "name", cset, "err", err)
}
}

// HandleSNSRequest handles requests from the platform notifications SNS topic subscription.
func HandleSNSRequest(sesclient *ses.Client, snsclient *sns.Client) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// todo: check if request is subscription request.
// check for SubscribeURL key?

req, err := ParseRequest(r.Body)
defer r.Body.Close() // todo, can return an error
msg, err := UnmarshalMessage(r.Body)
if err != nil {
slog.Error("error processing CloudWatch alarm SNS request", "err", err)
return
}
if errs := req.Message.Valid(); len(errs) > 0 {
slog.Error("error validating CloudWatch alarm. is the SNS subscription FilterPolicy allowing non-SES notifications?", "errs", errs)
u, err := url.Parse(*snsclient.Options().BaseEndpoint)
if err != nil {
slog.Error("initialized SNS client had no base endpoint -- this should never happen")
}

snsclient.ConfirmSubscription(context.Background(), &sns.ConfirmSubscriptionInput{})

cset := req.Message.Trigger.Dimensions[0].Value
if err = VerifySNSMessage(msg, u.Host); err != nil {
slog.Error("failed to verify SNS message signature", "err", err)
return
}

_, err = sesclient.UpdateConfigurationSetSendingEnabled(r.Context(), &ses.UpdateConfigurationSetSendingEnabledInput{
ConfigurationSetName: aws.String(cset),
Enabled: false,
})
if err != nil {
slog.Error("error pausing sending on configuration set", "name", cset, "err", err)
// once verified, switch on request type
mtype := r.Header.Get(snsMessageTypeHeader)
if mtype == "" {
slog.Error("SNS message passed verification but type header was empty -- this should never happen")
return
}
switch mtype {
case snsMessageTypeSubscriptionConfirmation:
handleSubscriptionConfirmation(msg)
case snsMessageTypeNotification:
handleNotification(r.Context(), msg, sesclient)
default:
// UnsubscribeConfirmation is a noop.
}
},
)
Expand Down
15 changes: 9 additions & 6 deletions helper/internal/brokerpaks/ses/reputation_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package ses_test

import (
"encoding/json"
"strings"
"testing"

"github.com/cloud-gov/csb/helper/internal/brokerpaks/ses"
)

var sesAlarm = ""

var otherAlarm = `{
var referenceAlarm = `{
"Type": "Notification",
"MessageId": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
"TopicArn": "arn:aws:sns:eu-west-1:000000000000:cloudwatch-alarms",
Expand All @@ -24,12 +23,16 @@ var otherAlarm = `{
}`

func TestParseRequest(t *testing.T) {
r := strings.NewReader(otherAlarm)
req, err := ses.ParseRequest(r)
r := strings.NewReader(referenceAlarm)
msg, err := ses.UnmarshalMessage(r)
if err != nil {
t.Fatal("error while parsing request: ", err)
}
alarm := req.Message
var alarm ses.CloudWatchAlarm
if err = json.Unmarshal([]byte(msg.Message), &alarm); err != nil {
t.Fatalf("unmarshalling alarm: %v", err.Error())
}

expectedName := "Example alarm name"
if alarm.AlarmName != expectedName {
t.Fatalf("expected alarm name %v, got %v", expectedName, alarm.AlarmName)
Expand Down
Loading

0 comments on commit 1df3188

Please sign in to comment.