Skip to content

Commit

Permalink
Use Pinpoint by default (#2173)
Browse files Browse the repository at this point in the history
  • Loading branch information
sastels authored May 22, 2024
1 parent d1d851d commit 12b9571
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 38 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ CONTACT_FORM_EMAIL_ADDRESS = ""

AWS_PINPOINT_SC_POOL_ID=
AWS_PINPOINT_SC_TEMPLATE_IDS=
AWS_PINPOINT_DEFAULT_POOL_ID=
9 changes: 6 additions & 3 deletions app/clients/sms/aws_pinpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@ class AwsPinpointClient(SmsClient):
def init_app(self, current_app, statsd_client, *args, **kwargs):
self._client = boto3.client("pinpoint-sms-voice-v2", region_name="ca-central-1")
super(AwsPinpointClient, self).__init__(*args, **kwargs)
# super(SmsClient, self).__init__(*args, **kwargs)
self.current_app = current_app
self.name = "pinpoint"
self.statsd_client = statsd_client

def get_name(self):
return self.name

def send_sms(self, to, content, reference, multi=True, sender=None):
pool_id = self.current_app.config["AWS_PINPOINT_SC_POOL_ID"]
def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None):
messageType = "TRANSACTIONAL"
matched = False

if template_id is not None and str(template_id) in self.current_app.config["AWS_PINPOINT_SC_TEMPLATE_IDS"]:
pool_id = self.current_app.config["AWS_PINPOINT_SC_POOL_ID"]
else:
pool_id = self.current_app.config["AWS_PINPOINT_DEFAULT_POOL_ID"]

for match in phonenumbers.PhoneNumberMatcher(to, "US"):
matched = True
to = phonenumbers.format_number(match.number, phonenumbers.PhoneNumberFormat.E164)
Expand Down
8 changes: 2 additions & 6 deletions app/clients/sms/aws_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from time import monotonic

import boto3
import botocore
import phonenumbers
from notifications_utils.statsd_decorators import statsd

Expand All @@ -27,7 +26,7 @@ def get_name(self):
return self.name

@statsd(namespace="clients.sns")
def send_sms(self, to, content, reference, multi=True, sender=None):
def send_sms(self, to, content, reference, multi=True, sender=None, template_id=None):
matched = False

for match in phonenumbers.PhoneNumberMatcher(to, "US"):
Expand Down Expand Up @@ -66,12 +65,9 @@ def send_sms(self, to, content, reference, multi=True, sender=None):
try:
start_time = monotonic()
response = client.publish(PhoneNumber=to, Message=content, MessageAttributes=attributes)
except botocore.exceptions.ClientError as e:
self.statsd_client.incr("clients.sns.error")
raise str(e)
except Exception as e:
self.statsd_client.incr("clients.sns.error")
raise str(e)
raise e
finally:
elapsed_time = monotonic() - start_time
self.current_app.logger.info("AWS SNS request finished in {}".format(elapsed_time))
Expand Down
1 change: 1 addition & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ class Config(object):
AWS_SES_SECRET_KEY = os.getenv("AWS_SES_SECRET_KEY")
AWS_PINPOINT_REGION = os.getenv("AWS_PINPOINT_REGION", "us-west-2")
AWS_PINPOINT_SC_POOL_ID = os.getenv("AWS_PINPOINT_SC_POOL_ID", None)
AWS_PINPOINT_DEFAULT_POOL_ID = os.getenv("AWS_PINPOINT_DEFAULT_POOL_ID", None)
AWS_PINPOINT_CONFIGURATION_SET_NAME = os.getenv("AWS_PINPOINT_CONFIGURATION_SET_NAME", "pinpoint-configuration")
AWS_PINPOINT_SC_TEMPLATE_IDS = env.list("AWS_PINPOINT_SC_TEMPLATE_IDS", [])
AWS_US_TOLL_FREE_NUMBER = os.getenv("AWS_US_TOLL_FREE_NUMBER")
Expand Down
64 changes: 53 additions & 11 deletions app/delivery/send_to_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import os
import re
from datetime import datetime
from typing import Dict
from typing import Any, Dict, Optional
from uuid import UUID

import phonenumbers
from flask import current_app
from notifications_utils.recipients import (
validate_and_format_email_address,
Expand Down Expand Up @@ -48,6 +49,7 @@
NOTIFICATION_VIRUS_SCAN_FAILED,
PINPOINT_PROVIDER,
SMS_TYPE,
SNS_PROVIDER,
BounceRateStatus,
Notification,
Service,
Expand All @@ -67,9 +69,9 @@ def send_sms_to_provider(notification):
provider = provider_to_use(
SMS_TYPE,
notification.id,
notification.to,
notification.international,
notification.reply_to_text,
template_id=notification.template_id,
)

template_dict = dao_get_template_by_id(notification.template_id, notification.template_version).__dict__
Expand Down Expand Up @@ -105,6 +107,7 @@ def send_sms_to_provider(notification):
content=str(template),
reference=str(notification.id),
sender=notification.reply_to_text,
template_id=notification.template_id,
)
except Exception as e:
notification.billable_units = template.fragment_count
Expand Down Expand Up @@ -336,16 +339,55 @@ def update_notification_to_sending(notification, provider):
dao_update_notification(notification)


def provider_to_use(notification_type, notification_id, international=False, sender=None, template_id=None):
# Temporary redirect setup for template IDs that are meant for the short code usage.
if notification_type == SMS_TYPE and template_id is not None and str(template_id) in Config.AWS_PINPOINT_SC_TEMPLATE_IDS:
return clients.get_client_by_name_and_type("pinpoint", SMS_TYPE)
def provider_to_use(
notification_type: str,
notification_id: UUID,
to: Optional[str] = None,
international: bool = False,
sender: Optional[str] = None,
) -> Any:
"""
Get the provider to use for sending the notification.
SMS that are being sent with a dedicated number or to a US number should not use Pinpoint.
Args:
notification_type (str): SMS or EMAIL.
notification_id (UUID): id of notification. Just used for logging.
to (str, optional): recipient. Defaults to None.
international (bool, optional): Recipient is international. Defaults to False.
sender (str, optional): reply_to_text to use. Defaults to None.
Raises:
Exception: No active providers.
active_providers_in_order = [
p
for p in get_provider_details_by_notification_type(notification_type, international)
if p.active and p.identifier != PINPOINT_PROVIDER
]
Returns:
provider: Provider to use to send the notification.
"""

has_dedicated_number = sender is not None and sender.startswith("+1")
sending_to_us_number = False
if to is not None:
match = next(iter(phonenumbers.PhoneNumberMatcher(to, "US")), None)
if match and phonenumbers.region_code_for_number(match.number) == "US":
sending_to_us_number = True

if (
has_dedicated_number
or sending_to_us_number
or current_app.config["AWS_PINPOINT_SC_POOL_ID"] is None
or current_app.config["AWS_PINPOINT_DEFAULT_POOL_ID"] is None
):
active_providers_in_order = [
p
for p in get_provider_details_by_notification_type(notification_type, international)
if p.active and p.identifier != PINPOINT_PROVIDER
]
else:
active_providers_in_order = [
p
for p in get_provider_details_by_notification_type(notification_type, international)
if p.active and p.identifier != SNS_PROVIDER
]

if not active_providers_in_order:
current_app.logger.error("{} {} failed as no active providers".format(notification_type, notification_id))
Expand Down
19 changes: 19 additions & 0 deletions migrations/versions/0450_enable_pinpoint_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
Revision ID: 0450_enable_pinpoint_provider
Revises: 0449_update_magic_link_auth
Create Date: 2021-01-08 09:03:00 .214680
"""
from alembic import op

revision = "0450_enable_pinpoint_provider"
down_revision = "0449_update_magic_link_auth"


def upgrade():
op.execute("UPDATE provider_details set active=true where identifier in ('pinpoint');")


def downgrade():
op.execute("UPDATE provider_details set active=false where identifier in ('pinpoint');")
73 changes: 73 additions & 0 deletions tests/app/clients/test_aws_pinpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest

from app import aws_pinpoint_client
from tests.conftest import set_config_values


@pytest.mark.serial
def test_send_sms_sends_to_default_pool(notify_api, mocker, sample_template):
boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True)
mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True)
to = "6135555555"
content = "foo"
reference = "ref"

with set_config_values(
notify_api,
{
"AWS_PINPOINT_SC_POOL_ID": "sc_pool_id",
"AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id",
"AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name",
"AWS_PINPOINT_SC_TEMPLATE_IDS": [],
},
):
aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id)

boto_mock.send_text_message.assert_called_once_with(
DestinationPhoneNumber="+16135555555",
OriginationIdentity="default_pool_id",
MessageBody=content,
MessageType="TRANSACTIONAL",
ConfigurationSetName="config_set_name",
)


@pytest.mark.serial
def test_send_sms_sends_to_shortcode_pool(notify_api, mocker, sample_template):
boto_mock = mocker.patch.object(aws_pinpoint_client, "_client", create=True)
mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True)
to = "6135555555"
content = "foo"
reference = "ref"

with set_config_values(
notify_api,
{
"AWS_PINPOINT_SC_POOL_ID": "sc_pool_id",
"AWS_PINPOINT_DEFAULT_POOL_ID": "default_pool_id",
"AWS_PINPOINT_CONFIGURATION_SET_NAME": "config_set_name",
"AWS_PINPOINT_SC_TEMPLATE_IDS": [str(sample_template.id)],
},
):
aws_pinpoint_client.send_sms(to, content, reference=reference, template_id=sample_template.id)

boto_mock.send_text_message.assert_called_once_with(
DestinationPhoneNumber="+16135555555",
OriginationIdentity="sc_pool_id",
MessageBody=content,
MessageType="TRANSACTIONAL",
ConfigurationSetName="config_set_name",
)


def test_send_sms_returns_raises_error_if_there_is_no_valid_number_is_found(notify_api, mocker):
mocker.patch.object(aws_pinpoint_client, "_client", create=True)
mocker.patch.object(aws_pinpoint_client, "statsd_client", create=True)

to = ""
content = reference = "foo"

with pytest.raises(ValueError) as excinfo:
aws_pinpoint_client.send_sms(to, content, reference)

assert "No valid numbers found for SMS delivery" in str(excinfo.value)
7 changes: 6 additions & 1 deletion tests/app/dao/test_provider_details_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,14 @@ def test_get_sms_provider_with_equal_priority_returns_provider(


def test_get_current_sms_provider_returns_active_only(restore_provider_details):
# Note that we currently have two active sms providers: sns and pinpoint.
current_provider = get_current_provider("sms")
current_provider.active = False
dao_update_provider_details(current_provider)
current_provider = get_current_provider("sms")
current_provider.active = False
dao_update_provider_details(current_provider)

new_current_provider = get_current_provider("sms")

assert new_current_provider is None
Expand Down Expand Up @@ -308,5 +313,5 @@ def test_dao_get_provider_stats(notify_db_session):
assert result[5].identifier == "pinpoint"
assert result[5].notification_type == "sms"
assert result[5].supports_international is False
assert result[5].active is False
assert result[5].active is True
assert result[5].current_month_billable_sms == 0
Loading

0 comments on commit 12b9571

Please sign in to comment.