Skip to content

Commit

Permalink
feat: added sms sending support to courier
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksiireshetnik committed Nov 9, 2021
1 parent 437cc99 commit 0ffa6b7
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 164 deletions.
205 changes: 45 additions & 160 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,128 +2,48 @@ package courier

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"strconv"
"time"

"github.com/cenkalti/backoff"
"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/herodot"

gomail "github.com/ory/mail/v3"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
gomail "github.com/ory/mail/v3"
)

type (
smtpDependencies interface {
Dependencies interface {
PersistenceProvider
x.LoggingProvider
config.Provider
}

Courier struct {
Dialer *gomail.Dialer
d smtpDependencies
smsClient *smsClient
smtpDialer *gomail.Dialer
deps Dependencies
}

Provider interface {
Courier(ctx context.Context) *Courier
}
)

func NewSMTP(d smtpDependencies, c *config.Config) *Courier {
uri := c.CourierSMTPURL()
password, _ := uri.User.Password()
port, _ := strconv.ParseInt(uri.Port(), 10, 0)

dialer := &gomail.Dialer{
Host: uri.Hostname(),
Port: int(port),
Username: uri.User.Username(),
Password: password,

Timeout: time.Second * 10,
RetryFailure: true,
}

sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify"))

// SMTP schemes
// smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default)
// smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks
// and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314)
switch uri.Scheme {
case "smtp":
// Enforcing StartTLS by default for security best practices (config review, etc.)
skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls"))
if !skipStartTLS {
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()}
// Enforcing StartTLS
dialer.StartTLSPolicy = gomail.MandatoryStartTLS
}
case "smtps":
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()}
dialer.SSL = true
}

func NewCourier(d Dependencies, c *config.Config) *Courier {
return &Courier{
d: d,
Dialer: dialer,
smsClient: newSMS(c),
smtpDialer: newSMTP(c),
deps: d,
}
}

func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) {
recipient, err := t.EmailRecipient()
if err != nil {
return uuid.Nil, err
}

subject, err := t.EmailSubject()
if err != nil {
return uuid.Nil, err
}

bodyPlaintext, err := t.EmailBodyPlaintext()
if err != nil {
return uuid.Nil, err
}

templateType, err := GetTemplateType(t)
if err != nil {
return uuid.Nil, err
}

templateData, err := json.Marshal(t)
if err != nil {
return uuid.Nil, err
}

message := &Message{
Status: MessageStatusQueued,
Type: MessageTypeEmail,
Recipient: recipient,
Body: bodyPlaintext,
Subject: subject,
TemplateType: templateType,
TemplateData: templateData,
}
if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil {
return uuid.Nil, err
}
return message.ID, nil
}

func (m *Courier) Work(ctx context.Context) error {
func (c *Courier) Work(ctx context.Context) error {
errChan := make(chan error)
defer close(errChan)

go m.watchMessages(ctx, errChan)
go c.watchMessages(ctx, errChan)

select {
case <-ctx.Done():
Expand All @@ -136,10 +56,10 @@ func (m *Courier) Work(ctx context.Context) error {
}
}

func (m *Courier) watchMessages(ctx context.Context, errChan chan error) {
func (c *Courier) watchMessages(ctx context.Context, errChan chan error) {
for {
if err := backoff.Retry(func() error {
return m.DispatchQueue(ctx)
return c.DispatchQueue(ctx)
}, backoff.NewExponentialBackOff()); err != nil {
errChan <- err
return
Expand All @@ -148,82 +68,47 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) {
}
}

func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error {
func (c *Courier) DispatchMessage(ctx context.Context, msg Message) error {
switch msg.Type {
case MessageTypeEmail:
from := m.d.Config(ctx).CourierSMTPFrom()
fromName := m.d.Config(ctx).CourierSMTPFromName()
gm := gomail.NewMessage()
if fromName == "" {
gm.SetHeader("From", from)
} else {
gm.SetAddressHeader("From", from, fromName)
}

gm.SetHeader("To", msg.Recipient)
gm.SetHeader("Subject", msg.Subject)

headers := m.d.Config(ctx).CourierSMTPHeaders()
for k, v := range headers {
gm.SetHeader(k, v)
}

gm.SetBody("text/plain", msg.Body)

tmpl, err := NewEmailTemplateFromMessage(m.d.Config(ctx), msg)
if err != nil {
m.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
Error(`Unable to get email template from message.`)
} else {
htmlBody, err := tmpl.EmailBody()
if err != nil {
m.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
Error(`Unable to get email body from template.`)
} else {
gm.AddAlternative("text/html", htmlBody)
}
}

if err := m.Dialer.DialAndSend(ctx, gm); err != nil {
m.d.Logger().
WithError(err).
WithField("smtp_server", fmt.Sprintf("%s:%d", m.Dialer.Host, m.Dialer.Port)).
WithField("smtp_ssl_enabled", m.Dialer.SSL).
// WithField("email_to", msg.Recipient).
WithField("message_from", from).
Error("Unable to send email using SMTP connection.")
return errors.WithStack(err)
if err := c.dispatchEmail(ctx, msg); err != nil {
return err
}

if err := m.d.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil {
m.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
Error(`Unable to set the message status to "sent".`)
case MessageTypePhone:
if err := c.dispatchSMS(ctx, msg); err != nil {
return err
}
default:
return errors.New("received unexpected message type")
}

m.d.Logger().
if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out message.")
return nil
Error(`Unable to set the message status to "sent".`)
return err
}

c.deps.Logger().
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out message.")

return errors.Errorf("received unexpected message type: %d", msg.Type)
}

func (m *Courier) DispatchQueue(ctx context.Context) error {
if len(m.Dialer.Host) == 0 {
func (c *Courier) DispatchQueue(ctx context.Context) error {
if len(c.smtpDialer.Host) == 0 {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!"))
}
if len(c.smsClient.Host) == 0 {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver a sms but courier.sms.host is not set!"))
}

messages, err := m.d.CourierPersister().NextMessages(ctx, 10)
messages, err := c.deps.CourierPersister().NextMessages(ctx, 10)
if err != nil {
if errors.Is(err, ErrQueueEmpty) {
return nil
Expand All @@ -233,10 +118,10 @@ func (m *Courier) DispatchQueue(ctx context.Context) error {

for k := range messages {
var msg = messages[k]
if err := m.DispatchMessage(ctx, msg); err != nil {
if err := c.DispatchMessage(ctx, msg); err != nil {
for _, replace := range messages[k:] {
if err := m.d.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil {
m.d.Logger().
if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", replace.ID).
Error(`Unable to reset the failed message's status to "queued".`)
Expand Down
5 changes: 3 additions & 2 deletions courier/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"context"
"time"

"github.com/ory/kratos/corp"

"github.com/gofrs/uuid"

"github.com/ory/kratos/corp"
)

type MessageStatus int
Expand All @@ -21,6 +21,7 @@ type MessageType int

const (
MessageTypeEmail MessageType = iota + 1
MessageTypePhone
)

// swagger:ignore
Expand Down
59 changes: 59 additions & 0 deletions courier/sms.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package courier

import (
"context"
"errors"
"net/http"
"net/url"

"github.com/gofrs/uuid"

"github.com/ory/kratos/driver/config"
)

type smsClient struct {
*http.Client
Host string
}

func newSMS(c *config.Config) *smsClient {
return &smsClient{
Client: &http.Client{},
Host: c.CourierSMSHost().String(),
}

}

func (c *Courier) QueueSMS(ctx context.Context, t EmailTemplate) (uuid.UUID, error) {
message := &Message{
Status: MessageStatusQueued,
Type: MessageTypePhone,
}
if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil {
return uuid.Nil, err
}

return message.ID, nil
}

func (c *Courier) dispatchSMS(ctx context.Context, msg Message) error {
from := c.deps.Config(ctx).CourierSMSFrom()

v := url.Values{}
v.Set("To", msg.Recipient)
v.Set("From", from)
v.Set("Body", msg.Body)

res, err := c.smsClient.PostForm(c.smsClient.Host, v)
if err != nil {
return err
}

defer res.Body.Close()

if res.StatusCode != http.StatusOK {
return errors.New(http.StatusText(res.StatusCode))
}

return nil
}
Loading

0 comments on commit 0ffa6b7

Please sign in to comment.