Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added sms support to courier #1941

Merged
merged 10 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 32 additions & 210 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,144 +2,68 @@ package courier

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"strconv"
"time"

"github.com/hashicorp/go-retryablehttp"

"github.com/ory/kratos/driver/config"
"github.com/ory/x/httpx"

"github.com/cenkalti/backoff"
"github.com/gofrs/uuid"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"

"github.com/ory/herodot"

gomail "github.com/ory/mail/v3"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
gomail "github.com/ory/mail/v3"
"github.com/ory/x/httpx"
)

type (
SMTPDependencies interface {
Dependencies interface {
PersistenceProvider
x.LoggingProvider
ConfigProvider
HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client
}
TemplateTyper func(t EmailTemplate) (TemplateType, error)
EmailTemplateFromMessage func(d SMTPDependencies, msg Message) (EmailTemplate, error)
Courier struct {
Dialer *gomail.Dialer
d SMTPDependencies
GetTemplateType TemplateTyper
NewEmailTemplateFromMessage EmailTemplateFromMessage

Courier interface {
Work(ctx context.Context) error
QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error)
QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error)
SmtpDialer() *gomail.Dialer
DispatchQueue(ctx context.Context) error
}

Provider interface {
Courier(ctx context.Context) *Courier
Courier(ctx context.Context) Courier
}

ConfigProvider interface {
CourierConfig(ctx context.Context) config.CourierConfigs
}
)

func NewSMTP(ctx context.Context, d SMTPDependencies) *Courier {
uri := d.CourierConfig(ctx).CourierSMTPURL()

password, _ := uri.User.Password()
port, _ := strconv.ParseInt(uri.Port(), 10, 0)

dialer := &gomail.Dialer{
Host: uri.Hostname(),
Port: int(port),
Username: uri.User.Username(),
Password: password,

Timeout: time.Second * 10,
RetryFailure: true,
}

sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify"))

// SMTP schemes
// smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default)
// smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks
// and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314)
switch uri.Scheme {
case "smtp":
// Enforcing StartTLS by default for security best practices (config review, etc.)
skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls"))
if !skipStartTLS {
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()}
// Enforcing StartTLS
dialer.StartTLSPolicy = gomail.MandatoryStartTLS
}
case "smtps":
// #nosec G402 This is ok (and required!) because it is configurable and disabled by default.
dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()}
dialer.SSL = true
courier struct {
smsClient *smsClient
smtpClient *smtpClient
deps Dependencies
failOnError bool
}
)

return &Courier{
d: d,
Dialer: dialer,
GetTemplateType: GetTemplateType,
NewEmailTemplateFromMessage: NewEmailTemplateFromMessage,
func NewCourier(ctx context.Context, deps Dependencies) Courier {
return &courier{
smsClient: newSMS(ctx, deps),
smtpClient: newSMTP(ctx, deps),
deps: deps,
}
}

func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) {
recipient, err := t.EmailRecipient()
if err != nil {
return uuid.Nil, err
}

subject, err := t.EmailSubject(ctx)
if err != nil {
return uuid.Nil, err
}

bodyPlaintext, err := t.EmailBodyPlaintext(ctx)
if err != nil {
return uuid.Nil, err
}

templateType, err := m.GetTemplateType(t)
if err != nil {
return uuid.Nil, err
}

templateData, err := json.Marshal(t)
if err != nil {
return uuid.Nil, err
}

message := &Message{
Status: MessageStatusQueued,
Type: MessageTypeEmail,
Recipient: recipient,
Body: bodyPlaintext,
Subject: subject,
TemplateType: templateType,
TemplateData: templateData,
}

if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil {
return uuid.Nil, err
}
return message.ID, nil
func (c *courier) FailOnDispatchError() {
c.failOnError = true
}

func (m *Courier) Work(ctx context.Context) error {
func (c *courier) Work(ctx context.Context) error {
errChan := make(chan error)
defer close(errChan)

go m.watchMessages(ctx, errChan)
go c.watchMessages(ctx, errChan)

select {
case <-ctx.Done():
Expand All @@ -152,116 +76,14 @@ func (m *Courier) Work(ctx context.Context) error {
}
}

func (m *Courier) watchMessages(ctx context.Context, errChan chan error) {
func (c *courier) watchMessages(ctx context.Context, errChan chan error) {
for {
if err := backoff.Retry(func() error {
return m.DispatchQueue(ctx)
return c.DispatchQueue(ctx)
}, backoff.NewExponentialBackOff()); err != nil {
errChan <- err
return
}
time.Sleep(time.Second)
}
}

func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error {
switch msg.Type {
case MessageTypeEmail:
from := m.d.CourierConfig(ctx).CourierSMTPFrom()
fromName := m.d.CourierConfig(ctx).CourierSMTPFromName()
gm := gomail.NewMessage()
if fromName == "" {
gm.SetHeader("From", from)
} else {
gm.SetAddressHeader("From", from, fromName)
}

gm.SetHeader("To", msg.Recipient)
gm.SetHeader("Subject", msg.Subject)

headers := m.d.CourierConfig(ctx).CourierSMTPHeaders()
for k, v := range headers {
gm.SetHeader(k, v)
}

gm.SetBody("text/plain", msg.Body)

tmpl, err := m.NewEmailTemplateFromMessage(m.d, msg)
if err != nil {
m.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
Error(`Unable to get email template from message.`)
} else {
htmlBody, err := tmpl.EmailBody(ctx)
if err != nil {
m.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
Error(`Unable to get email body from template.`)
} else {
gm.AddAlternative("text/html", htmlBody)
}
}

if err := m.Dialer.DialAndSend(ctx, gm); err != nil {
m.d.Logger().
WithError(err).
WithField("smtp_server", fmt.Sprintf("%s:%d", m.Dialer.Host, m.Dialer.Port)).
WithField("smtp_ssl_enabled", m.Dialer.SSL).
// WithField("email_to", msg.Recipient).
WithField("message_from", from).
Error("Unable to send email using SMTP connection.")
return errors.WithStack(err)
}

if err := m.d.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil {
m.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
Error(`Unable to set the message status to "sent".`)
return err
}

m.d.Logger().
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out message.")
return nil
}
return errors.Errorf("received unexpected message type: %d", msg.Type)
}

func (m *Courier) DispatchQueue(ctx context.Context) error {
if len(m.Dialer.Host) == 0 {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!"))
}

messages, err := m.d.CourierPersister().NextMessages(ctx, 10)
if err != nil {
if errors.Is(err, ErrQueueEmpty) {
return nil
}
return err
}

for k := range messages {
var msg = messages[k]
if err := m.DispatchMessage(ctx, msg); err != nil {
for _, replace := range messages[k:] {
if err := m.d.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil {
m.d.Logger().
WithError(err).
WithField("message_id", replace.ID).
Error(`Unable to reset the failed message's status to "queued".`)
}
}

return err
}
}

return nil
}
70 changes: 70 additions & 0 deletions courier/courier_dispatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package courier

import (
"context"

"github.com/pkg/errors"
)

func (c *courier) DispatchMessage(ctx context.Context, msg Message) error {
switch msg.Type {
case MessageTypeEmail:
if err := c.dispatchEmail(ctx, msg); err != nil {
return err
}
case MessageTypePhone:
if err := c.dispatchSMS(ctx, msg); err != nil {
return err
}
default:
return errors.Errorf("received unexpected message type: %d", msg.Type)
}

if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
Error(`Unable to set the message status to "sent".`)
return err
}

c.deps.Logger().
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out message.")

return nil
}

func (c *courier) DispatchQueue(ctx context.Context) error {
messages, err := c.deps.CourierPersister().NextMessages(ctx, 10)
if err != nil {
if errors.Is(err, ErrQueueEmpty) {
return nil
}
return err
}

for k := range messages {
var msg = messages[k]
if err := c.DispatchMessage(ctx, msg); err != nil {
for _, replace := range messages[k:] {
if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil {
if c.failOnError {
return err
}
c.deps.Logger().
WithError(err).
WithField("message_id", replace.ID).
Error(`Unable to reset the failed message's status to "queued".`)
}
}

return err
}
}

return nil
}
Loading