diff --git a/courier/courier.go b/courier/courier.go index 9f549efa070..1a3a33b7e52 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -2,128 +2,48 @@ package courier import ( "context" - "crypto/tls" - "encoding/json" - "fmt" - "strconv" "time" "github.com/cenkalti/backoff" - "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/ory/herodot" - - gomail "github.com/ory/mail/v3" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/x" + gomail "github.com/ory/mail/v3" ) type ( - smtpDependencies interface { + Dependencies interface { PersistenceProvider x.LoggingProvider config.Provider } + Courier struct { - Dialer *gomail.Dialer - d smtpDependencies + smsClient *smsClient + smtpDialer *gomail.Dialer + deps Dependencies } + Provider interface { Courier(ctx context.Context) *Courier } ) -func NewSMTP(d smtpDependencies, c *config.Config) *Courier { - uri := c.CourierSMTPURL() - password, _ := uri.User.Password() - port, _ := strconv.ParseInt(uri.Port(), 10, 0) - - dialer := &gomail.Dialer{ - Host: uri.Hostname(), - Port: int(port), - Username: uri.User.Username(), - Password: password, - - Timeout: time.Second * 10, - RetryFailure: true, - } - - sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) - - // SMTP schemes - // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) - // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks - // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) - switch uri.Scheme { - case "smtp": - // Enforcing StartTLS by default for security best practices (config review, etc.) - skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) - if !skipStartTLS { - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - // Enforcing StartTLS - dialer.StartTLSPolicy = gomail.MandatoryStartTLS - } - case "smtps": - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - dialer.SSL = true - } - +func NewCourier(d Dependencies, c *config.Config) *Courier { return &Courier{ - d: d, - Dialer: dialer, + smsClient: newSMS(c), + smtpDialer: newSMTP(c), + deps: d, } } -func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { - recipient, err := t.EmailRecipient() - if err != nil { - return uuid.Nil, err - } - - subject, err := t.EmailSubject() - if err != nil { - return uuid.Nil, err - } - - bodyPlaintext, err := t.EmailBodyPlaintext() - if err != nil { - return uuid.Nil, err - } - - templateType, err := GetTemplateType(t) - if err != nil { - return uuid.Nil, err - } - - templateData, err := json.Marshal(t) - if err != nil { - return uuid.Nil, err - } - - message := &Message{ - Status: MessageStatusQueued, - Type: MessageTypeEmail, - Recipient: recipient, - Body: bodyPlaintext, - Subject: subject, - TemplateType: templateType, - TemplateData: templateData, - } - if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil { - return uuid.Nil, err - } - return message.ID, nil -} - -func (m *Courier) Work(ctx context.Context) error { +func (c *Courier) Work(ctx context.Context) error { errChan := make(chan error) defer close(errChan) - go m.watchMessages(ctx, errChan) + go c.watchMessages(ctx, errChan) select { case <-ctx.Done(): @@ -136,10 +56,10 @@ func (m *Courier) Work(ctx context.Context) error { } } -func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { +func (c *Courier) watchMessages(ctx context.Context, errChan chan error) { for { if err := backoff.Retry(func() error { - return m.DispatchQueue(ctx) + return c.DispatchQueue(ctx) }, backoff.NewExponentialBackOff()); err != nil { errChan <- err return @@ -148,82 +68,47 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { } } -func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error { +func (c *Courier) DispatchMessage(ctx context.Context, msg Message) error { switch msg.Type { case MessageTypeEmail: - from := m.d.Config(ctx).CourierSMTPFrom() - fromName := m.d.Config(ctx).CourierSMTPFromName() - gm := gomail.NewMessage() - if fromName == "" { - gm.SetHeader("From", from) - } else { - gm.SetAddressHeader("From", from, fromName) - } - - gm.SetHeader("To", msg.Recipient) - gm.SetHeader("Subject", msg.Subject) - - headers := m.d.Config(ctx).CourierSMTPHeaders() - for k, v := range headers { - gm.SetHeader(k, v) - } - - gm.SetBody("text/plain", msg.Body) - - tmpl, err := NewEmailTemplateFromMessage(m.d.Config(ctx), msg) - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email template from message.`) - } else { - htmlBody, err := tmpl.EmailBody() - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email body from template.`) - } else { - gm.AddAlternative("text/html", htmlBody) - } - } - - if err := m.Dialer.DialAndSend(ctx, gm); err != nil { - m.d.Logger(). - WithError(err). - WithField("smtp_server", fmt.Sprintf("%s:%d", m.Dialer.Host, m.Dialer.Port)). - WithField("smtp_ssl_enabled", m.Dialer.SSL). - // WithField("email_to", msg.Recipient). - WithField("message_from", from). - Error("Unable to send email using SMTP connection.") - return errors.WithStack(err) + if err := c.dispatchEmail(ctx, msg); err != nil { + return err } - - if err := m.d.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to set the message status to "sent".`) + case MessageTypePhone: + if err := c.dispatchSMS(ctx, msg); err != nil { return err } + default: + return errors.New("received unexpected message type") + } - m.d.Logger(). + if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { + c.deps.Logger(). + WithError(err). WithField("message_id", msg.ID). - WithField("message_type", msg.Type). - WithField("message_template_type", msg.TemplateType). - WithField("message_subject", msg.Subject). - Debug("Courier sent out message.") - return nil + Error(`Unable to set the message status to "sent".`) + return err } + + c.deps.Logger(). + WithField("message_id", msg.ID). + WithField("message_type", msg.Type). + WithField("message_template_type", msg.TemplateType). + WithField("message_subject", msg.Subject). + Debug("Courier sent out message.") + return errors.Errorf("received unexpected message type: %d", msg.Type) } -func (m *Courier) DispatchQueue(ctx context.Context) error { - if len(m.Dialer.Host) == 0 { +func (c *Courier) DispatchQueue(ctx context.Context) error { + if len(c.smtpDialer.Host) == 0 { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!")) } + if len(c.smsClient.Host) == 0 { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver a sms but courier.sms.host is not set!")) + } - messages, err := m.d.CourierPersister().NextMessages(ctx, 10) + messages, err := c.deps.CourierPersister().NextMessages(ctx, 10) if err != nil { if errors.Is(err, ErrQueueEmpty) { return nil @@ -233,10 +118,10 @@ func (m *Courier) DispatchQueue(ctx context.Context) error { for k := range messages { var msg = messages[k] - if err := m.DispatchMessage(ctx, msg); err != nil { + if err := c.DispatchMessage(ctx, msg); err != nil { for _, replace := range messages[k:] { - if err := m.d.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { - m.d.Logger(). + if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { + c.deps.Logger(). WithError(err). WithField("message_id", replace.ID). Error(`Unable to reset the failed message's status to "queued".`) diff --git a/courier/message.go b/courier/message.go index 94b102781cb..0641a0d49b8 100644 --- a/courier/message.go +++ b/courier/message.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/ory/kratos/corp" - "github.com/gofrs/uuid" + + "github.com/ory/kratos/corp" ) type MessageStatus int @@ -21,6 +21,7 @@ type MessageType int const ( MessageTypeEmail MessageType = iota + 1 + MessageTypePhone ) // swagger:ignore diff --git a/courier/sms.go b/courier/sms.go new file mode 100644 index 00000000000..42b00fd0949 --- /dev/null +++ b/courier/sms.go @@ -0,0 +1,59 @@ +package courier + +import ( + "context" + "errors" + "net/http" + "net/url" + + "github.com/gofrs/uuid" + + "github.com/ory/kratos/driver/config" +) + +type smsClient struct { + *http.Client + Host string +} + +func newSMS(c *config.Config) *smsClient { + return &smsClient{ + Client: &http.Client{}, + Host: c.CourierSMSHost().String(), + } + +} + +func (c *Courier) QueueSMS(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypePhone, + } + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *Courier) dispatchSMS(ctx context.Context, msg Message) error { + from := c.deps.Config(ctx).CourierSMSFrom() + + v := url.Values{} + v.Set("To", msg.Recipient) + v.Set("From", from) + v.Set("Body", msg.Body) + + res, err := c.smsClient.PostForm(c.smsClient.Host, v) + if err != nil { + return err + } + + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return errors.New(http.StatusText(res.StatusCode)) + } + + return nil +} diff --git a/courier/smtp.go b/courier/smtp.go new file mode 100644 index 00000000000..fbefd3adc5b --- /dev/null +++ b/courier/smtp.go @@ -0,0 +1,150 @@ +package courier + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/kratos/driver/config" + gomail "github.com/ory/mail/v3" +) + +func newSMTP(c *config.Config) *gomail.Dialer { + uri := c.CourierSMTPURL() + password, _ := uri.User.Password() + port, _ := strconv.ParseInt(uri.Port(), 10, 0) + + dialer := &gomail.Dialer{ + Host: uri.Hostname(), + Port: int(port), + Username: uri.User.Username(), + Password: password, + + Timeout: time.Second * 10, + RetryFailure: true, + } + + sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) + + // SMTP schemes + // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) + // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks + // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) + switch uri.Scheme { + case "smtp": + // Enforcing StartTLS by default for security best practices (config review, etc.) + skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) + if !skipStartTLS { + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + // Enforcing StartTLS + dialer.StartTLSPolicy = gomail.MandatoryStartTLS + } + case "smtps": + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + dialer.SSL = true + } + + return dialer +} + +func (c *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { + recipient, err := t.EmailRecipient() + if err != nil { + return uuid.Nil, err + } + + subject, err := t.EmailSubject() + if err != nil { + return uuid.Nil, err + } + + bodyPlaintext, err := t.EmailBodyPlaintext() + if err != nil { + return uuid.Nil, err + } + + templateType, err := GetTemplateType(t) + if err != nil { + return uuid.Nil, err + } + + templateData, err := json.Marshal(t) + if err != nil { + return uuid.Nil, err + } + + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypeEmail, + Recipient: recipient, + Body: bodyPlaintext, + Subject: subject, + TemplateType: templateType, + TemplateData: templateData, + } + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *Courier) dispatchEmail(ctx context.Context, msg Message) error { + from := c.deps.Config(ctx).CourierSMTPFrom() + fromName := c.deps.Config(ctx).CourierSMTPFromName() + gm := gomail.NewMessage() + if fromName == "" { + gm.SetHeader("From", from) + } else { + gm.SetAddressHeader("From", from, fromName) + } + + gm.SetHeader("To", msg.Recipient) + gm.SetHeader("Subject", msg.Subject) + + headers := c.deps.Config(ctx).CourierSMTPHeaders() + for k, v := range headers { + gm.SetHeader(k, v) + } + + gm.SetBody("text/plain", msg.Body) + + tmpl, err := NewEmailTemplateFromMessage(c.deps.Config(ctx), msg) + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email template from message.`) + } else { + htmlBody, err := tmpl.EmailBody() + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email body from template.`) + } else { + gm.AddAlternative("text/html", htmlBody) + } + } + + if err := c.smtpDialer.DialAndSend(ctx, gm); err != nil { + c.deps.Logger(). + WithError(err). + WithField("smtp_server", fmt.Sprintf("%s:%d", c.smtpDialer.Host, c.smtpDialer.Port)). + WithField("smtp_ssl_enabled", c.smtpDialer.SSL). + // WithField("email_to", msg.Recipient). + WithField("message_from", from). + Error("Unable to send email using SMTP connection.") + return errors.WithStack(err) + } + + return nil +} diff --git a/driver/config/config.go b/driver/config/config.go index a7a71b5beaa..ed14ebab1b9 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -61,6 +61,8 @@ const ( ViperKeyCourierSMTPFrom = "courier.smtp.from_address" ViperKeyCourierSMTPFromName = "courier.smtp.from_name" ViperKeyCourierSMTPHeaders = "courier.smtp.headers" + ViperKeyCourierSMSHost = "courier.sms.host" + ViperKeyCourierSMSFrom = "courier.sms.from_name" ViperKeySecretsDefault = "secrets.default" ViperKeySecretsCookie = "secrets.cookie" ViperKeySecretsCipher = "secrets.cipher" @@ -847,6 +849,14 @@ func (p *Config) CourierSMTPHeaders() map[string]string { return p.p.StringMap(ViperKeyCourierSMTPHeaders) } +func (p *Config) CourierSMSHost() *url.URL { + return p.ParseURIOrFail(ViperKeyCourierSMSHost) +} + +func (p *Config) CourierSMSFrom() string { + return p.p.StringF(ViperKeyCourierSMSFrom, "Kratos") +} + func splitUrlAndFragment(s string) (string, string) { i := strings.IndexByte(s, '#') if i < 0 { diff --git a/driver/registry_default.go b/driver/registry_default.go index c6fe37a9eeb..00c8af69d16 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -579,7 +579,7 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) { } func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier { - return courier.NewSMTP(m, m.Config(ctx)) + return courier.NewCourier(m, m.Config(ctx)) } func (m *RegistryDefault) ContinuityManager() continuity.Manager { diff --git a/embedx/config.schema.json b/embedx/config.schema.json index f55a4c380e5..dd4ba07b1fb 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1337,10 +1337,37 @@ "connection_uri" ], "additionalProperties": false + }, + "sms": { + "title": "SMS sender configuration", + "description": "Configures outgoing sms messages using HTTP protocol with generic SMS provider", + "type": "object", + "properties": { + "host": { + "title": "HTTP address of API endpoint", + "description": "This URL will be used to connect to SMS provider.", + "examples": [ + "https://api.twillio.com/sms/send" + ], + "type": "string", + "pattern": "^https?:\\/\\/.*" + }, + "from_name": { + "title": "SMS Sender Address", + "description": "The recipient of a sms will see this as the sender address.", + "type": "string", + "default": "ORY/Kratos" + } + }, + "required": [ + "host" + ], + "additionalProperties": false } }, "required": [ - "smtp" + "smtp", + "sms" ], "additionalProperties": false }, diff --git a/selfservice/strategy/link/sender.go b/selfservice/strategy/link/sender.go index 025a3885512..6c3a90af54a 100644 --- a/selfservice/strategy/link/sender.go +++ b/selfservice/strategy/link/sender.go @@ -182,6 +182,9 @@ func (s *Sender) send(ctx context.Context, via string, t courier.EmailTemplate) case identity.AddressTypeEmail: _, err := s.r.Courier(ctx).QueueEmail(ctx, t) return err + case identity.AddressTypePhone: + _, err := s.r.Courier(ctx).QueueSMS(ctx, t) + return err default: return errors.Errorf("received unexpected via type: %s", via) }