Skip to content

Commit

Permalink
feat: sms-login initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
splaunov committed Mar 4, 2022
1 parent 567a3d7 commit de1f301
Show file tree
Hide file tree
Showing 108 changed files with 2,855 additions and 71 deletions.
5 changes: 5 additions & 0 deletions cmd/clidoc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func init() {
"NewInfoSelfServiceRemoveWebAuthn": text.NewInfoSelfServiceRemoveWebAuthn("{name}", aSecondAgo),
"NewErrorValidationVerificationFlowExpired": text.NewErrorValidationVerificationFlowExpired(-time.Second),
"NewInfoSelfServiceVerificationSuccessful": text.NewInfoSelfServiceVerificationSuccessful(),
"NewInfoSelfServicePhoneVerificationSuccessful": text.NewInfoSelfServicePhoneVerificationSuccessful(),
"NewVerificationEmailSent": text.NewVerificationEmailSent(),
"NewErrorValidationVerificationTokenInvalidOrAlreadyUsed": text.NewErrorValidationVerificationTokenInvalidOrAlreadyUsed(),
"NewErrorValidationVerificationRetrySuccess": text.NewErrorValidationVerificationRetrySuccess(),
Expand Down Expand Up @@ -114,7 +115,11 @@ func init() {
"NewErrorValidationRecoveryTokenInvalidOrAlreadyUsed": text.NewErrorValidationRecoveryTokenInvalidOrAlreadyUsed(),
"NewErrorValidationRecoveryRetrySuccess": text.NewErrorValidationRecoveryRetrySuccess(),
"NewErrorValidationRecoveryStateFailure": text.NewErrorValidationRecoveryStateFailure(),
"NewErrorValidationInvalidCode": text.NewErrorValidationInvalidCode(),
"NewErrorCodeSent": text.NewErrorCodeSent(),
"NewInfoNodeInputEmail": text.NewInfoNodeInputEmail(),
"NewInfoNodeInputPhone": text.NewInfoNodeInputPhone(),
"NewVerificationPhoneSent": text.NewVerificationPhoneSent(),
}
}

Expand Down
14 changes: 11 additions & 3 deletions courier/courier.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package courier

//go:generate mockgen -destination=mocks/mock_courier.go -package=mocks github.com/ory/kratos/courier Courier

import (
"context"
"time"
Expand Down Expand Up @@ -79,9 +81,15 @@ func (c *courier) Work(ctx context.Context) error {

func (c *courier) watchMessages(ctx context.Context, errChan chan error) {
for {
if err := backoff.Retry(func() error {
return c.DispatchQueue(ctx)
}, backoff.NewExponentialBackOff()); err != nil {
if err := backoff.RetryNotify(
func() error {
return c.DispatchQueue(ctx)
},
backoff.NewExponentialBackOff(),
func(err error, t time.Duration) {
c.deps.Logger().WithError(err).Error("Courier DispatchQueue error")
},
); err != nil {
errChan <- err
return
}
Expand Down
14 changes: 11 additions & 3 deletions courier/courier_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,28 @@ import (
)

func (c *courier) DispatchMessage(ctx context.Context, msg Message) error {
messageStatus := MessageStatusSent
logMessage := "Courier sent out message."

switch msg.Type {
case MessageTypeEmail:
if err := c.dispatchEmail(ctx, msg); err != nil {
return err
}
case MessageTypePhone:
if err := c.dispatchSMS(ctx, msg); err != nil {
return err
if m, ok := err.(*MessageRejectedError); ok {
messageStatus = MessageStatusRejected
logMessage = m.Error()
} else {
return err
}
}
default:
return errors.Errorf("received unexpected message type: %d", msg.Type)
}

if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil {
if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, messageStatus); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
Expand All @@ -33,7 +41,7 @@ func (c *courier) DispatchMessage(ctx context.Context, msg Message) error {
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out message.")
Debug(logMessage)

return nil
}
Expand Down
1 change: 1 addition & 0 deletions courier/email_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
TypeVerificationValid TemplateType = "verification_valid"
TypeOTP TemplateType = "otp"
TypeTestStub TemplateType = "stub"
TypeCode TemplateType = "code"
)

func GetEmailTemplateType(t EmailTemplate) (TemplateType, error) {
Expand Down
18 changes: 18 additions & 0 deletions courier/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package courier

import (
"fmt"
"net/http"
)

type MessageRejectedError struct {
StatusCode int
ResponseBody string
}

func NewMessageRejectedError(statusCode int, responseBody string) error {
return &MessageRejectedError{StatusCode: statusCode, ResponseBody: responseBody}
}
func (m *MessageRejectedError) Error() string {
return fmt.Sprintf("Status: %s, body: %s", http.StatusText(m.StatusCode), m.ResponseBody)
}
1 change: 1 addition & 0 deletions courier/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
MessageStatusQueued MessageStatus = iota + 1
MessageStatusSent
MessageStatusProcessing
MessageStatusRejected // Service won't send this message for some unrecoverable reasons (incorrect phone number e.g.)
)

type MessageType int
Expand Down
124 changes: 124 additions & 0 deletions courier/mocks/mock_courier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion courier/sms.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package courier
import (
"context"
"encoding/json"
"io"
"net/http"

"github.com/pkg/errors"
Expand Down Expand Up @@ -105,8 +106,18 @@ func (c *courier) dispatchSMS(ctx context.Context, msg Message) error {
switch res.StatusCode {
case http.StatusOK:
case http.StatusCreated:
case http.StatusBadRequest:
b, err := io.ReadAll(res.Body)
if err != nil {
return err
}
return NewMessageRejectedError(res.StatusCode, string(b))
default:
return errors.New(http.StatusText(res.StatusCode))
b, err := io.ReadAll(res.Body)
if err != nil {
return err
}
return errors.Errorf("Status: %s, body: %s", http.StatusText(res.StatusCode), string(b))
}

return nil
Expand Down
8 changes: 8 additions & 0 deletions courier/sms_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ func SMSTemplateType(t SMSTemplate) (TemplateType, error) {
return TypeOTP, nil
case *sms.TestStub:
return TypeTestStub, nil
case *sms.CodeMessage:
return TypeCode, nil
default:
return "", errors.Errorf("unexpected template type")
}
Expand All @@ -40,6 +42,12 @@ func NewSMSTemplateFromMessage(d Dependencies, m Message) (SMSTemplate, error) {
return nil, err
}
return sms.NewTestStub(d, &t), nil
case TypeCode:
var t sms.CodeMessageModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
return nil, err
}
return sms.NewCodeMessage(d, &t), nil
default:
return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType)
}
Expand Down
1 change: 1 addition & 0 deletions courier/sms_templates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestNewSMSTemplateFromMessage(t *testing.T) {
for tmplType, expectedTmpl := range map[courier.TemplateType]courier.SMSTemplate{
courier.TypeOTP: sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: "+12345678901"}),
courier.TypeTestStub: sms.NewTestStub(reg, &sms.TestStubModel{To: "+12345678901", Body: "test body"}),
courier.TypeCode: sms.NewCodeMessage(reg, &sms.CodeMessageModel{To: "+12345678901"}),
} {
t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) {
tmplData, err := json.Marshal(expectedTmpl)
Expand Down
4 changes: 2 additions & 2 deletions courier/sms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ func TestQueueSMS(t *testing.T) {
expectedSMS := []*sms.TestStubModel{
{
To: "+12065550101",
Body: "test-sms-body-1",
Body: "test-code-body-1",
},
{
To: "+12065550102",
Body: "test-sms-body-2",
Body: "test-code-body-2",
},
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
code {{ .Code }}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
stub sms body {{ .Body }}
35 changes: 35 additions & 0 deletions courier/template/sms/code_login.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package sms

import (
"context"
"encoding/json"
"github.com/ory/kratos/courier/template"
"os"
)

type (
CodeMessage struct {
d template.Dependencies
m *CodeMessageModel
}
CodeMessageModel struct {
To string
Code string
}
)

func NewCodeMessage(d template.Dependencies, m *CodeMessageModel) *CodeMessage {
return &CodeMessage{d: d, m: m}
}

func (t *CodeMessage) PhoneNumber() (string, error) {
return t.m.To, nil
}

func (t *CodeMessage) SMSBody(ctx context.Context) (string, error) {
return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "login/sms.body.gotmpl", "login/sms.body*", t.m, "")
}

func (t *CodeMessage) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
9 changes: 9 additions & 0 deletions driver/clock/clock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package clock

import (
"github.com/benbjohnson/clock"
)

type Provider interface {
Clock() clock.Clock
}
10 changes: 10 additions & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ const (
ViperKeyWebAuthnRPIcon = "selfservice.methods.webauthn.config.rp.issuer"
ViperKeyClientHTTPNoPrivateIPRanges = "clients.http.disallow_private_ip_ranges"
ViperKeyVersion = "version"
CodeMaxAttempts = "selfservice.methods.code.config.max_attempts"
CodeLifespan = "selfservice.methods.code.config.lifespan"
)

const (
Expand Down Expand Up @@ -1234,3 +1236,11 @@ func (p *Config) getTSLCertificates(daemon, certBase64, keyBase64, certPath, key
p.l.Infof("TLS has not been configured for %s, skipping", daemon)
return nil
}

func (p *Config) SelfServiceCodeMaxAttempts() int {
return p.p.Int(CodeMaxAttempts)
}

func (p *Config) SelfServiceCodeLifespan() time.Duration {
return p.p.DurationF(CodeLifespan, time.Hour)
}
Loading

0 comments on commit de1f301

Please sign in to comment.