From 68b4b2def07d494f8a134d1c3569e54c4e0126fd Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Mon, 31 Jan 2022 08:00:41 +0000 Subject: [PATCH 01/10] autogen(docs): generate and format documentation From b423c88b151b8bd657376b9aab36812d7d4e4d65 Mon Sep 17 00:00:00 2001 From: reshetnik-alexey Date: Tue, 9 Nov 2021 16:39:37 +0530 Subject: [PATCH 02/10] feat: added sms sending support to courier --- courier/courier.go | 240 ++-------------- courier/courier_test.go | 149 ---------- courier/dispatcher.go | 67 +++++ courier/{templates.go => email_templates.go} | 40 +-- ...plates_test.go => email_templates_test.go} | 26 +- courier/message.go | 5 +- courier/sms.go | 113 ++++++++ courier/sms_templates.go | 46 +++ courier/sms_templates_test.go | 60 ++++ courier/sms_test.go | 108 +++++++ courier/smtp.go | 180 ++++++++++++ courier/smtp_test.go | 154 ++++++++++ courier/stub/request.config.twilio.jsonnet | 5 + .../builtin/templates/otp/sms.body.gotmpl | 3 + .../templates/otp/test_stub/sms.body.gotmpl | 1 + courier/template/email/recovery_invalid.go | 43 +++ .../{ => email}/recovery_invalid_test.go | 9 +- courier/template/email/recovery_valid.go | 45 +++ .../{ => email}/recovery_valid_test.go | 9 +- courier/template/email/stub.go | 45 +++ .../template/email/verification_invalid.go | 43 +++ .../{ => email}/verification_invalid_test.go | 9 +- courier/template/email/verification_valid.go | 45 +++ .../{ => email}/verification_valid_test.go | 9 +- courier/template/load_template.go | 4 +- courier/template/load_template_test.go | 30 +- courier/template/recovery_invalid.go | 41 --- courier/template/recovery_valid.go | 43 --- courier/template/sms/otp.go | 37 +++ courier/template/sms/otp_test.go | 34 +++ courier/template/sms/stub.go | 37 +++ courier/template/sms/stub_test.go | 31 ++ courier/template/stub.go | 42 --- courier/template/template.go | 5 +- courier/template/testhelpers/testhelpers.go | 14 +- courier/template/verification_invalid.go | 41 --- courier/template/verification_valid.go | 43 --- driver/config/config.go | 33 +++ driver/registry_default.go | 4 +- embedx/config.schema.json | 73 +++++ request/auth.go | 31 ++ request/auth_strategy.go | 76 +++++ request/auth_strategy_test.go | 67 +++++ request/auth_test.go | 56 ++++ request/builder.go | 203 +++++++++++++ request/builder_test.go | 272 ++++++++++++++++++ request/config.go | 61 ++++ request/stub/test_body.jsonnet | 5 + selfservice/hook/web_hook.go | 221 +------------- selfservice/hook/web_hook_test.go | 268 ----------------- selfservice/strategy/link/sender.go | 15 +- x/require.go | 5 + 52 files changed, 2111 insertions(+), 1135 deletions(-) create mode 100644 courier/dispatcher.go rename courier/{templates.go => email_templates.go} (65%) rename courier/{templates_test.go => email_templates_test.go} (64%) create mode 100644 courier/sms.go create mode 100644 courier/sms_templates.go create mode 100644 courier/sms_templates_test.go create mode 100644 courier/sms_test.go create mode 100644 courier/smtp.go create mode 100644 courier/smtp_test.go create mode 100644 courier/stub/request.config.twilio.jsonnet create mode 100644 courier/template/courier/builtin/templates/otp/sms.body.gotmpl create mode 100644 courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl create mode 100644 courier/template/email/recovery_invalid.go rename courier/template/{ => email}/recovery_invalid_test.go (66%) create mode 100644 courier/template/email/recovery_valid.go rename courier/template/{ => email}/recovery_valid_test.go (67%) create mode 100644 courier/template/email/stub.go create mode 100644 courier/template/email/verification_invalid.go rename courier/template/{ => email}/verification_invalid_test.go (67%) create mode 100644 courier/template/email/verification_valid.go rename courier/template/{ => email}/verification_valid_test.go (65%) delete mode 100644 courier/template/recovery_invalid.go delete mode 100644 courier/template/recovery_valid.go create mode 100644 courier/template/sms/otp.go create mode 100644 courier/template/sms/otp_test.go create mode 100644 courier/template/sms/stub.go create mode 100644 courier/template/sms/stub_test.go delete mode 100644 courier/template/stub.go delete mode 100644 courier/template/verification_invalid.go delete mode 100644 courier/template/verification_valid.go create mode 100644 request/auth.go create mode 100644 request/auth_strategy.go create mode 100644 request/auth_strategy_test.go create mode 100644 request/auth_test.go create mode 100644 request/builder.go create mode 100644 request/builder_test.go create mode 100644 request/config.go create mode 100644 request/stub/test_body.jsonnet delete mode 100644 selfservice/hook/web_hook_test.go diff --git a/courier/courier.go b/courier/courier.go index 329a95a48d2..e173db3cb3d 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -2,144 +2,62 @@ package courier import ( "context" - "crypto/tls" - "encoding/json" - "fmt" - "strconv" "time" - "github.com/hashicorp/go-retryablehttp" - - "github.com/ory/kratos/driver/config" - "github.com/ory/x/httpx" - "github.com/cenkalti/backoff" "github.com/gofrs/uuid" + "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" - "github.com/ory/herodot" - - gomail "github.com/ory/mail/v3" - + "github.com/ory/kratos/driver/config" "github.com/ory/kratos/x" + gomail "github.com/ory/mail/v3" + "github.com/ory/x/httpx" ) type ( - SMTPDependencies interface { + Dependencies interface { PersistenceProvider x.LoggingProvider ConfigProvider HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client } - TemplateTyper func(t EmailTemplate) (TemplateType, error) - EmailTemplateFromMessage func(d SMTPDependencies, msg Message) (EmailTemplate, error) - Courier struct { - Dialer *gomail.Dialer - d SMTPDependencies - GetTemplateType TemplateTyper - NewEmailTemplateFromMessage EmailTemplateFromMessage - } - Provider interface { - Courier(ctx context.Context) *Courier - } - ConfigProvider interface { - CourierConfig(ctx context.Context) config.CourierConfigs - } -) - -func NewSMTP(ctx context.Context, d SMTPDependencies) *Courier { - uri := d.CourierConfig(ctx).CourierSMTPURL() - - password, _ := uri.User.Password() - port, _ := strconv.ParseInt(uri.Port(), 10, 0) - - dialer := &gomail.Dialer{ - Host: uri.Hostname(), - Port: int(port), - Username: uri.User.Username(), - Password: password, - Timeout: time.Second * 10, - RetryFailure: true, - } - - sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) - - // SMTP schemes - // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) - // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks - // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) - switch uri.Scheme { - case "smtp": - // Enforcing StartTLS by default for security best practices (config review, etc.) - skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) - if !skipStartTLS { - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - // Enforcing StartTLS - dialer.StartTLSPolicy = gomail.MandatoryStartTLS - } - case "smtps": - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - dialer.SSL = true + Courier interface { + Work(ctx context.Context) error + QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) + QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) + SmtpDialer() *gomail.Dialer } - return &Courier{ - d: d, - Dialer: dialer, - GetTemplateType: GetTemplateType, - NewEmailTemplateFromMessage: NewEmailTemplateFromMessage, - } -} - -func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { - recipient, err := t.EmailRecipient() - if err != nil { - return uuid.Nil, err - } - - subject, err := t.EmailSubject(ctx) - if err != nil { - return uuid.Nil, err - } - - bodyPlaintext, err := t.EmailBodyPlaintext(ctx) - if err != nil { - return uuid.Nil, err - } - - templateType, err := m.GetTemplateType(t) - if err != nil { - return uuid.Nil, err + Provider interface { + Courier(ctx context.Context) Courier } - templateData, err := json.Marshal(t) - if err != nil { - return uuid.Nil, err + ConfigProvider interface { + CourierConfig(ctx context.Context) config.CourierConfigs } - message := &Message{ - Status: MessageStatusQueued, - Type: MessageTypeEmail, - Recipient: recipient, - Body: bodyPlaintext, - Subject: subject, - TemplateType: templateType, - TemplateData: templateData, + courier struct { + smsClient *smsClient + smtpClient *smtpClient + deps Dependencies } +) - if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil { - return uuid.Nil, err +func NewCourier(ctx context.Context, deps Dependencies) Courier { + return &courier{ + smsClient: newSMS(ctx, deps), + smtpClient: newSMTP(ctx, deps), + deps: deps, } - return message.ID, nil } -func (m *Courier) Work(ctx context.Context) error { +func (c *courier) Work(ctx context.Context) error { errChan := make(chan error) defer close(errChan) - go m.watchMessages(ctx, errChan) + go c.watchMessages(ctx, errChan) select { case <-ctx.Done(): @@ -152,10 +70,10 @@ func (m *Courier) Work(ctx context.Context) error { } } -func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { +func (c *courier) watchMessages(ctx context.Context, errChan chan error) { for { if err := backoff.Retry(func() error { - return m.DispatchQueue(ctx) + return c.DispatchQueue(ctx) }, backoff.NewExponentialBackOff()); err != nil { errChan <- err return @@ -163,105 +81,3 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { time.Sleep(time.Second) } } - -func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error { - switch msg.Type { - case MessageTypeEmail: - from := m.d.CourierConfig(ctx).CourierSMTPFrom() - fromName := m.d.CourierConfig(ctx).CourierSMTPFromName() - gm := gomail.NewMessage() - if fromName == "" { - gm.SetHeader("From", from) - } else { - gm.SetAddressHeader("From", from, fromName) - } - - gm.SetHeader("To", msg.Recipient) - gm.SetHeader("Subject", msg.Subject) - - headers := m.d.CourierConfig(ctx).CourierSMTPHeaders() - for k, v := range headers { - gm.SetHeader(k, v) - } - - gm.SetBody("text/plain", msg.Body) - - tmpl, err := m.NewEmailTemplateFromMessage(m.d, msg) - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email template from message.`) - } else { - htmlBody, err := tmpl.EmailBody(ctx) - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email body from template.`) - } else { - gm.AddAlternative("text/html", htmlBody) - } - } - - if err := m.Dialer.DialAndSend(ctx, gm); err != nil { - m.d.Logger(). - WithError(err). - WithField("smtp_server", fmt.Sprintf("%s:%d", m.Dialer.Host, m.Dialer.Port)). - WithField("smtp_ssl_enabled", m.Dialer.SSL). - // WithField("email_to", msg.Recipient). - WithField("message_from", from). - Error("Unable to send email using SMTP connection.") - return errors.WithStack(err) - } - - if err := m.d.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to set the message status to "sent".`) - return err - } - - m.d.Logger(). - WithField("message_id", msg.ID). - WithField("message_type", msg.Type). - WithField("message_template_type", msg.TemplateType). - WithField("message_subject", msg.Subject). - Debug("Courier sent out message.") - return nil - } - return errors.Errorf("received unexpected message type: %d", msg.Type) -} - -func (m *Courier) DispatchQueue(ctx context.Context) error { - if len(m.Dialer.Host) == 0 { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!")) - } - - messages, err := m.d.CourierPersister().NextMessages(ctx, 10) - if err != nil { - if errors.Is(err, ErrQueueEmpty) { - return nil - } - return err - } - - for k := range messages { - var msg = messages[k] - if err := m.DispatchMessage(ctx, msg); err != nil { - for _, replace := range messages[k:] { - if err := m.d.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", replace.ID). - Error(`Unable to reset the failed message's status to "queued".`) - } - } - - return err - } - } - - return nil -} diff --git a/courier/courier_test.go b/courier/courier_test.go index d47ba8e7fd4..b871e19fbe5 100644 --- a/courier/courier_test.go +++ b/courier/courier_test.go @@ -1,30 +1,10 @@ package courier_test import ( - "context" - "fmt" - "io/ioutil" - "net/http" "testing" - "time" - - "github.com/sirupsen/logrus" "github.com/ory/kratos/x" - gomail "github.com/ory/mail/v3" - - "github.com/gofrs/uuid" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - dhelper "github.com/ory/x/sqlcon/dockertest" - - courier "github.com/ory/kratos/courier" - templates "github.com/ory/kratos/courier/template" - "github.com/ory/kratos/driver/config" - "github.com/ory/kratos/internal" ) // nolint:staticcheck @@ -33,132 +13,3 @@ func TestMain(m *testing.M) { atexit.Add(x.CleanUpTestSMTP) atexit.Exit(m.Run()) } - -func TestNewSMTP(t *testing.T) { - ctx := context.Background() - - setupConfig := func(stringURL string) *courier.Courier { - conf, reg := internal.NewFastRegistryWithMocks(t) - conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL) - t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String()) - return courier.NewSMTP(ctx, reg) - } - - if testing.Short() { - t.SkipNow() - } - - //Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false - smtp := setupConfig("smtp://foo:bar@my-server:1234/") - assert.Equal(t, smtp.Dialer.StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced") - assert.Equal(t, smtp.Dialer.SSL, false, "Implicit TLS should not be enabled") - - //Should enforce TLS => dialer.SSL = true - smtp = setupConfig("smtps://foo:bar@my-server:1234/") - assert.Equal(t, smtp.Dialer.SSL, true, "Implicit TLS should be enabled") - - //Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false - smtp = setupConfig("smtp://foo:bar@my-server:1234/?disable_starttls=true") - assert.Equal(t, smtp.Dialer.StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced") - assert.Equal(t, smtp.Dialer.SSL, false, "Implicit TLS should not be enabled") -} - -func TestSMTP(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - - smtp, api, err := x.RunTestSMTP() - require.NoError(t, err) - t.Logf("SMTP URL: %s", smtp) - t.Logf("API URL: %s", api) - - ctx := context.Background() - - conf, reg := internal.NewFastRegistryWithMocks(t) - conf.MustSet(config.ViperKeyCourierSMTPURL, smtp) - conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") - reg.Logger().Level = logrus.TraceLevel - - c := reg.Courier(ctx) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - id, err := c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ - To: "test-recipient-1@example.org", - Subject: "test-subject-1", - Body: "test-body-1", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ - To: "test-recipient-2@example.org", - Subject: "test-subject-2", - Body: "test-body-2", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - // The third email contains a sender name and custom headers - conf.MustSet(config.ViperKeyCourierSMTPFromName, "Bob") - conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header1", "foo") - conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header2", "bar") - customerHeaders := conf.CourierSMTPHeaders() - require.Len(t, customerHeaders, 2) - id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ - To: "test-recipient-3@example.org", - Subject: "test-subject-3", - Body: "test-body-3", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - go func() { - require.NoError(t, c.Work(ctx)) - }() - - var body []byte - for k := 0; k < 30; k++ { - time.Sleep(time.Second) - err = func() error { - res, err := http.Get(api + "/api/v2/messages") - if err != nil { - return err - } - - defer res.Body.Close() - body, err = ioutil.ReadAll(res.Body) - if err != nil { - return err - } - - if http.StatusOK != res.StatusCode { - return errors.Errorf("expected status code 200 but got %d with body: %s", res.StatusCode, body) - } - - if total := gjson.GetBytes(body, "total").Int(); total != 3 { - return errors.Errorf("expected to have delivered at least 3 messages but got count %d with body: %s", total, body) - } - - return nil - }() - if err == nil { - break - } - } - require.NoError(t, err) - - for k := 1; k <= 3; k++ { - assert.Contains(t, string(body), fmt.Sprintf("test-subject-%d", k)) - assert.Contains(t, string(body), fmt.Sprintf("test-body-%d", k)) - assert.Contains(t, string(body), fmt.Sprintf("test-recipient-%d@example.org", k)) - assert.Contains(t, string(body), "test-stub@ory.sh") - } - - // Assertion for the third email with sender name and headers - assert.Contains(t, string(body), "Bob") - assert.Contains(t, string(body), `"test-stub-header1":["foo"]`) - assert.Contains(t, string(body), `"test-stub-header2":["bar"]`) -} diff --git a/courier/dispatcher.go b/courier/dispatcher.go new file mode 100644 index 00000000000..4d8beb7f2fa --- /dev/null +++ b/courier/dispatcher.go @@ -0,0 +1,67 @@ +package courier + +import ( + "context" + + "github.com/pkg/errors" +) + +func (c *courier) DispatchMessage(ctx context.Context, msg Message) error { + switch msg.Type { + case MessageTypeEmail: + if err := c.dispatchEmail(ctx, msg); err != nil { + return err + } + case MessageTypePhone: + if err := c.dispatchSMS(ctx, msg); err != nil { + return err + } + default: + return errors.Errorf("received unexpected message type: %d", msg.Type) + } + + if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to set the message status to "sent".`) + return err + } + + c.deps.Logger(). + WithField("message_id", msg.ID). + WithField("message_type", msg.Type). + WithField("message_template_type", msg.TemplateType). + WithField("message_subject", msg.Subject). + Debug("Courier sent out message.") + + return nil +} + +func (c *courier) DispatchQueue(ctx context.Context) error { + messages, err := c.deps.CourierPersister().NextMessages(ctx, 10) + if err != nil { + if errors.Is(err, ErrQueueEmpty) { + return nil + } + return err + } + + for k := range messages { + var msg = messages[k] + if err := c.DispatchMessage(ctx, msg); err != nil { + for _, replace := range messages[k:] { + if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", replace.ID). + Error(`Unable to reset the failed message's status to "queued".`) + } + } + + return err + } + } + + return nil +} diff --git a/courier/templates.go b/courier/email_templates.go similarity index 65% rename from courier/templates.go rename to courier/email_templates.go index cbadd573106..2ea3ea3bdf3 100644 --- a/courier/templates.go +++ b/courier/email_templates.go @@ -6,11 +6,12 @@ import ( "github.com/pkg/errors" - "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/courier/template/email" ) type ( - TemplateType string + TemplateType string + EmailTemplate interface { json.Marshaler EmailSubject(context.Context) (string, error) @@ -25,58 +26,59 @@ const ( TypeRecoveryValid TemplateType = "recovery_valid" TypeVerificationInvalid TemplateType = "verification_invalid" TypeVerificationValid TemplateType = "verification_valid" + TypeOTP TemplateType = "otp" TypeTestStub TemplateType = "stub" ) -func GetTemplateType(t EmailTemplate) (TemplateType, error) { +func GetEmailTemplateType(t EmailTemplate) (TemplateType, error) { switch t.(type) { - case *template.RecoveryInvalid: + case *email.RecoveryInvalid: return TypeRecoveryInvalid, nil - case *template.RecoveryValid: + case *email.RecoveryValid: return TypeRecoveryValid, nil - case *template.VerificationInvalid: + case *email.VerificationInvalid: return TypeVerificationInvalid, nil - case *template.VerificationValid: + case *email.VerificationValid: return TypeVerificationValid, nil - case *template.TestStub: + case *email.TestStub: return TypeTestStub, nil default: return "", errors.Errorf("unexpected template type") } } -func NewEmailTemplateFromMessage(d SMTPDependencies, msg Message) (EmailTemplate, error) { +func NewEmailTemplateFromMessage(d Dependencies, msg Message) (EmailTemplate, error) { switch msg.TemplateType { case TypeRecoveryInvalid: - var t template.RecoveryInvalidModel + var t email.RecoveryInvalidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewRecoveryInvalid(d, &t), nil + return email.NewRecoveryInvalid(d, &t), nil case TypeRecoveryValid: - var t template.RecoveryValidModel + var t email.RecoveryValidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewRecoveryValid(d, &t), nil + return email.NewRecoveryValid(d, &t), nil case TypeVerificationInvalid: - var t template.VerificationInvalidModel + var t email.VerificationInvalidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewVerificationInvalid(d, &t), nil + return email.NewVerificationInvalid(d, &t), nil case TypeVerificationValid: - var t template.VerificationValidModel + var t email.VerificationValidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewVerificationValid(d, &t), nil + return email.NewVerificationValid(d, &t), nil case TypeTestStub: - var t template.TestStubModel + var t email.TestStubModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewTestStub(d, &t), nil + return email.NewTestStub(d, &t), nil default: return nil, errors.Errorf("received unexpected message template type: %s", msg.TemplateType) } diff --git a/courier/templates_test.go b/courier/email_templates_test.go similarity index 64% rename from courier/templates_test.go rename to courier/email_templates_test.go index e41fa0705bd..e6b97885e36 100644 --- a/courier/templates_test.go +++ b/courier/email_templates_test.go @@ -9,25 +9,23 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/kratos/courier" - "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/internal" ) func TestGetTemplateType(t *testing.T) { for expectedType, tmpl := range map[courier.TemplateType]courier.EmailTemplate{ - courier.TypeRecoveryInvalid: &template.RecoveryInvalid{}, - courier.TypeRecoveryValid: &template.RecoveryValid{}, - courier.TypeVerificationInvalid: &template.VerificationInvalid{}, - courier.TypeVerificationValid: &template.VerificationValid{}, - courier.TypeTestStub: &template.TestStub{}, + courier.TypeRecoveryInvalid: &email.RecoveryInvalid{}, + courier.TypeRecoveryValid: &email.RecoveryValid{}, + courier.TypeVerificationInvalid: &email.VerificationInvalid{}, + courier.TypeVerificationValid: &email.VerificationValid{}, + courier.TypeTestStub: &email.TestStub{}, } { t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) { - actualType, err := courier.GetTemplateType(tmpl) + actualType, err := courier.GetEmailTemplateType(tmpl) require.NoError(t, err) require.Equal(t, expectedType, actualType) - }) - } } @@ -36,11 +34,11 @@ func TestNewEmailTemplateFromMessage(t *testing.T) { ctx := context.Background() for tmplType, expectedTmpl := range map[courier.TemplateType]courier.EmailTemplate{ - courier.TypeRecoveryInvalid: template.NewRecoveryInvalid(reg, &template.RecoveryInvalidModel{To: "foo"}), - courier.TypeRecoveryValid: template.NewRecoveryValid(reg, &template.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}), - courier.TypeVerificationInvalid: template.NewVerificationInvalid(reg, &template.VerificationInvalidModel{To: "baz"}), - courier.TypeVerificationValid: template.NewVerificationValid(reg, &template.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}), - courier.TypeTestStub: template.NewTestStub(reg, &template.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}), + courier.TypeRecoveryInvalid: email.NewRecoveryInvalid(reg, &email.RecoveryInvalidModel{To: "foo"}), + courier.TypeRecoveryValid: email.NewRecoveryValid(reg, &email.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}), + courier.TypeVerificationInvalid: email.NewVerificationInvalid(reg, &email.VerificationInvalidModel{To: "baz"}), + courier.TypeVerificationValid: email.NewVerificationValid(reg, &email.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}), + courier.TypeTestStub: email.NewTestStub(reg, &email.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}), } { t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) { tmplData, err := json.Marshal(expectedTmpl) diff --git a/courier/message.go b/courier/message.go index 94b102781cb..0641a0d49b8 100644 --- a/courier/message.go +++ b/courier/message.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/ory/kratos/corp" - "github.com/gofrs/uuid" + + "github.com/ory/kratos/corp" ) type MessageStatus int @@ -21,6 +21,7 @@ type MessageType int const ( MessageTypeEmail MessageType = iota + 1 + MessageTypePhone ) // swagger:ignore diff --git a/courier/sms.go b/courier/sms.go new file mode 100644 index 00000000000..059fa638b53 --- /dev/null +++ b/courier/sms.go @@ -0,0 +1,113 @@ +package courier + +import ( + "context" + "encoding/json" + "errors" + "net/http" + + "github.com/gofrs/uuid" + + "github.com/ory/kratos/request" +) + +type sendSMSRequestBody struct { + To string + From string + Body string +} + +type smsClient struct { + *http.Client + Host string + RequestConfig json.RawMessage + + GetTemplateType func(t SMSTemplate) (TemplateType, error) + NewTemplateFromMessage func(d Dependencies, msg Message) (SMSTemplate, error) +} + +func newSMS(ctx context.Context, deps Dependencies) *smsClient { + if !deps.CourierConfig(ctx).CourierSMSEnabled() { + deps.Logger().Error("messages will not be sent - no sms gate server address is set in config") + } + + return &smsClient{ + Client: &http.Client{}, + RequestConfig: deps.CourierConfig(ctx).CourierSMSRequestConfig(), + + GetTemplateType: SMSTemplateType, + NewTemplateFromMessage: NewSMSTemplateFromMessage, + } +} + +func (c *courier) QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) { + recipient, err := t.PhoneNumber() + if err != nil { + return uuid.Nil, err + } + + templateType, err := c.smsClient.GetTemplateType(t) + if err != nil { + return uuid.Nil, err + } + + templateData, err := json.Marshal(t) + if err != nil { + return uuid.Nil, err + } + + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypePhone, + Recipient: recipient, + TemplateType: templateType, + TemplateData: templateData, + } + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *courier) dispatchSMS(ctx context.Context, msg Message) error { + tmpl, err := c.smsClient.NewTemplateFromMessage(c.deps, msg) + if err != nil { + return err + } + + body, err := tmpl.SMSBody(ctx) + if err != nil { + return err + } + + builder, err := request.NewBuilder(c.smsClient.RequestConfig, c.deps.Logger()) + if err != nil { + return err + } + + req, err := builder.BuildRequest(&sendSMSRequestBody{ + To: msg.Recipient, + From: c.deps.CourierConfig(ctx).CourierSMSFrom(), + Body: body, + }) + if err != nil { + return err + } + + res, err := c.smsClient.Do(req) + if err != nil { + return err + } + + defer res.Body.Close() + + switch res.StatusCode { + case http.StatusOK: + case http.StatusCreated: + default: + return errors.New(http.StatusText(res.StatusCode)) + } + + return nil +} diff --git a/courier/sms_templates.go b/courier/sms_templates.go new file mode 100644 index 00000000000..079268bd8e1 --- /dev/null +++ b/courier/sms_templates.go @@ -0,0 +1,46 @@ +package courier + +import ( + "context" + "encoding/json" + + "github.com/pkg/errors" + + "github.com/ory/kratos/courier/template/sms" +) + +type SMSTemplate interface { + json.Marshaler + SMSBody(context.Context) (string, error) + PhoneNumber() (string, error) +} + +func SMSTemplateType(t SMSTemplate) (TemplateType, error) { + switch t.(type) { + case *sms.OTPMessage: + return TypeOTP, nil + case *sms.TestStub: + return TypeTestStub, nil + default: + return "", errors.Errorf("unexpected template type") + } +} + +func NewSMSTemplateFromMessage(d Dependencies, m Message) (SMSTemplate, error) { + switch m.TemplateType { + case TypeOTP: + var t sms.OTPMessageModel + if err := json.Unmarshal(m.TemplateData, &t); err != nil { + return nil, err + } + return sms.NewOTPMessage(d, &t), nil + case TypeTestStub: + var t sms.TestStubModel + if err := json.Unmarshal(m.TemplateData, &t); err != nil { + return nil, err + } + return sms.NewTestStub(d, &t), nil + default: + return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType) + } +} diff --git a/courier/sms_templates_test.go b/courier/sms_templates_test.go new file mode 100644 index 00000000000..760f89a21e0 --- /dev/null +++ b/courier/sms_templates_test.go @@ -0,0 +1,60 @@ +package courier_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestSMSTemplateType(t *testing.T) { + for expectedType, tmpl := range map[courier.TemplateType]courier.SMSTemplate{ + courier.TypeOTP: &sms.OTPMessage{}, + courier.TypeTestStub: &sms.TestStub{}, + } { + t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) { + actualType, err := courier.SMSTemplateType(tmpl) + require.NoError(t, err) + require.Equal(t, expectedType, actualType) + }) + } +} + +func TestNewSMSTemplateFromMessage(t *testing.T) { + _, reg := internal.NewFastRegistryWithMocks(t) + ctx := context.Background() + + for tmplType, expectedTmpl := range map[courier.TemplateType]courier.SMSTemplate{ + courier.TypeOTP: sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: "+12345678901"}), + courier.TypeTestStub: sms.NewTestStub(reg, &sms.TestStubModel{To: "+12345678901", Body: "test body"}), + } { + t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) { + tmplData, err := json.Marshal(expectedTmpl) + require.NoError(t, err) + + m := courier.Message{TemplateType: tmplType, TemplateData: tmplData} + actualTmpl, err := courier.NewSMSTemplateFromMessage(reg, m) + require.NoError(t, err) + + require.IsType(t, expectedTmpl, actualTmpl) + + expectedRecipient, err := expectedTmpl.PhoneNumber() + require.NoError(t, err) + actualRecipient, err := actualTmpl.PhoneNumber() + require.NoError(t, err) + require.Equal(t, expectedRecipient, actualRecipient) + + expectedBody, err := expectedTmpl.SMSBody(ctx) + require.NoError(t, err) + actualBody, err := actualTmpl.SMSBody(ctx) + require.NoError(t, err) + require.Equal(t, expectedBody, actualBody) + }) + } +} diff --git a/courier/sms_test.go b/courier/sms_test.go new file mode 100644 index 00000000000..0178e2fee4f --- /dev/null +++ b/courier/sms_test.go @@ -0,0 +1,108 @@ +package courier_test + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/internal" + "github.com/ory/kratos/x" +) + +func TestQueueSMS(t *testing.T) { + expectedSender := "Kratos Test" + expectedSMS := []*sms.TestStubModel{ + { + To: "+12065550101", + Body: "test-sms-body-1", + }, + { + To: "+12065550102", + Body: "test-sms-body-2", + }, + } + + actual := make([]*sms.TestStubModel, 0, 2) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + type sendSMSRequestBody struct { + To string + From string + Body string + } + + rb, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var body sendSMSRequestBody + + err = json.Unmarshal(rb, &body) + require.NoError(t, err) + + assert.NotEmpty(t, r.Header["Authorization"]) + assert.Equal(t, "Basic bWU6MTIzNDU=", r.Header["Authorization"][0]) + + assert.Equal(t, body.From, expectedSender) + actual = append(actual, &sms.TestStubModel{ + To: body.To, + Body: body.Body, + }) + })) + + requestConfig := fmt.Sprintf(`{ + "url": "%s", + "method": "POST", + "body": "file://./stub/request.config.twilio.jsonnet", + "auth": { + "type": "basic_auth", + "config": { + "user": "me", + "password": "12345" + } + } + }`, srv.URL) + + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMSRequestConfig, requestConfig) + conf.MustSet(config.ViperKeyCourierSMSFrom, expectedSender) + conf.MustSet(config.ViperKeyCourierSMSEnabled, true) + conf.MustSet(config.ViperKeyCourierSMTPURL, "http://foo.url") + reg.Logger().Level = logrus.TraceLevel + + ctx := context.Background() + + c := reg.Courier(ctx) + + ctx, cancel := context.WithCancel(ctx) + defer t.Cleanup(cancel) + + for _, message := range expectedSMS { + id, err := c.QueueSMS(ctx, sms.NewTestStub(reg, message)) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + } + + go func() { + require.NoError(t, c.Work(ctx)) + }() + + time.Sleep(time.Second) + for i, message := range actual { + expected := expectedSMS[i] + + assert.Equal(t, expected.To, message.To) + assert.Equal(t, fmt.Sprintf("stub sms body %s\n", expected.Body), message.Body) + } + + srv.Close() +} diff --git a/courier/smtp.go b/courier/smtp.go new file mode 100644 index 00000000000..6eba348764a --- /dev/null +++ b/courier/smtp.go @@ -0,0 +1,180 @@ +package courier + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/herodot" + gomail "github.com/ory/mail/v3" +) + +type smtpClient struct { + *gomail.Dialer + + GetTemplateType func(t EmailTemplate) (TemplateType, error) + NewTemplateFromMessage func(d Dependencies, msg Message) (EmailTemplate, error) +} + +func newSMTP(ctx context.Context, deps Dependencies) *smtpClient { + uri := deps.CourierConfig(ctx).CourierSMTPURL() + + password, _ := uri.User.Password() + port, _ := strconv.ParseInt(uri.Port(), 10, 0) + + dialer := &gomail.Dialer{ + Host: uri.Hostname(), + Port: int(port), + Username: uri.User.Username(), + Password: password, + + Timeout: time.Second * 10, + RetryFailure: true, + } + + sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) + + // SMTP schemes + // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) + // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks + // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) + switch uri.Scheme { + case "smtp": + // Enforcing StartTLS by default for security best practices (config review, etc.) + skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) + if !skipStartTLS { + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + // Enforcing StartTLS + dialer.StartTLSPolicy = gomail.MandatoryStartTLS + } + case "smtps": + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + dialer.SSL = true + } + + return &smtpClient{ + Dialer: dialer, + + GetTemplateType: GetEmailTemplateType, + NewTemplateFromMessage: NewEmailTemplateFromMessage, + } +} + +func (c *courier) SmtpDialer() *gomail.Dialer { + return c.smtpClient.Dialer +} + +func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { + recipient, err := t.EmailRecipient() + if err != nil { + return uuid.Nil, err + } + + subject, err := t.EmailSubject(ctx) + if err != nil { + return uuid.Nil, err + } + + bodyPlaintext, err := t.EmailBodyPlaintext(ctx) + if err != nil { + return uuid.Nil, err + } + + templateType, err := c.smtpClient.GetTemplateType(t) + if err != nil { + return uuid.Nil, err + } + + templateData, err := json.Marshal(t) + if err != nil { + return uuid.Nil, err + } + + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypeEmail, + Recipient: recipient, + Body: bodyPlaintext, + Subject: subject, + TemplateType: templateType, + TemplateData: templateData, + } + + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *courier) dispatchEmail(ctx context.Context, msg Message) error { + if c.smtpClient.Host == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!")) + } + + from := c.deps.CourierConfig(ctx).CourierSMTPFrom() + fromName := c.deps.CourierConfig(ctx).CourierSMTPFromName() + + gm := gomail.NewMessage() + if fromName == "" { + gm.SetHeader("From", from) + } else { + gm.SetAddressHeader("From", from, fromName) + } + + gm.SetHeader("To", msg.Recipient) + gm.SetHeader("Subject", msg.Subject) + + headers := c.deps.CourierConfig(ctx).CourierSMTPHeaders() + for k, v := range headers { + gm.SetHeader(k, v) + } + + gm.SetBody("text/plain", msg.Body) + + tmpl, err := c.smtpClient.NewTemplateFromMessage(c.deps, msg) + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email template from message.`) + } else { + htmlBody, err := tmpl.EmailBody(ctx) + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email body from template.`) + } else { + gm.AddAlternative("text/html", htmlBody) + } + } + + if err := c.smtpClient.DialAndSend(ctx, gm); err != nil { + c.deps.Logger(). + WithError(err). + WithField("smtp_server", fmt.Sprintf("%s:%d", c.smtpClient.Host, c.smtpClient.Port)). + WithField("smtp_ssl_enabled", c.smtpClient.SSL). + // WithField("email_to", msg.Recipient). + WithField("message_from", from). + Error("Unable to send email using SMTP connection.") + return errors.WithStack(err) + } + + c.deps.Logger(). + WithField("message_id", msg.ID). + WithField("message_type", msg.Type). + WithField("message_template_type", msg.TemplateType). + WithField("message_subject", msg.Subject). + Debug("Courier sent out message.") + + return nil +} diff --git a/courier/smtp_test.go b/courier/smtp_test.go new file mode 100644 index 00000000000..58df4f9fc9b --- /dev/null +++ b/courier/smtp_test.go @@ -0,0 +1,154 @@ +package courier_test + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + + "github.com/ory/kratos/courier" + templates "github.com/ory/kratos/courier/template/email" + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/internal" + "github.com/ory/kratos/x" + gomail "github.com/ory/mail/v3" +) + +func TestNewSMTP(t *testing.T) { + ctx := context.Background() + + setupConfig := func(stringURL string) courier.Courier { + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL) + + t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String()) + + return courier.NewCourier(ctx, reg) + } + + if testing.Short() { + t.SkipNow() + } + + //Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false + smtp := setupConfig("smtp://foo:bar@my-server:1234/") + assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced") + assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled") + + //Should enforce TLS => dialer.SSL = true + smtp = setupConfig("smtps://foo:bar@my-server:1234/") + assert.Equal(t, smtp.SmtpDialer().SSL, true, "Implicit TLS should be enabled") + + //Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false + smtp = setupConfig("smtp://foo:bar@my-server:1234/?disable_starttls=true") + assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced") + assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled") +} + +func TestQueueEmail(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + + smtp, api, err := x.RunTestSMTP() + require.NoError(t, err) + t.Logf("SMTP URL: %s", smtp) + t.Logf("API URL: %s", api) + + ctx := context.Background() + + conf, reg := internal.NewRegistryDefaultWithDSN(t, "") + conf.MustSet(config.ViperKeyCourierSMTPURL, smtp) + conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") + reg.Logger().Level = logrus.TraceLevel + + c := reg.Courier(ctx) //??? + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + id, err := c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ + To: "test-recipient-1@example.org", + Subject: "test-subject-1", + Body: "test-body-1", + })) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + + id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ + To: "test-recipient-2@example.org", + Subject: "test-subject-2", + Body: "test-body-2", + })) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + + // The third email contains a sender name and custom headers + conf.MustSet(config.ViperKeyCourierSMTPFromName, "Bob") + conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header1", "foo") + conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header2", "bar") + customerHeaders := conf.CourierSMTPHeaders() + require.Len(t, customerHeaders, 2) + id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ + To: "test-recipient-3@example.org", + Subject: "test-subject-3", + Body: "test-body-3", + })) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + + go func() { + require.NoError(t, c.Work(ctx)) + }() + + var body []byte + for k := 0; k < 30; k++ { + time.Sleep(time.Second) + err = func() error { + res, err := http.Get(api + "/api/v2/messages") + if err != nil { + return err + } + + defer res.Body.Close() + body, err = ioutil.ReadAll(res.Body) + if err != nil { + return err + } + + if http.StatusOK != res.StatusCode { + return errors.Errorf("expected status code 200 but got %d with body: %s", res.StatusCode, body) + } + + if total := gjson.GetBytes(body, "total").Int(); total != 3 { + return errors.Errorf("expected to have delivered at least 3 messages but got count %d with body: %s", total, body) + } + + return nil + }() + if err == nil { + break + } + } + require.NoError(t, err) + + for k := 1; k <= 3; k++ { + assert.Contains(t, string(body), fmt.Sprintf("test-subject-%d", k)) + assert.Contains(t, string(body), fmt.Sprintf("test-body-%d", k)) + assert.Contains(t, string(body), fmt.Sprintf("test-recipient-%d@example.org", k)) + assert.Contains(t, string(body), "test-stub@ory.sh") + } + + // Assertion for the third email with sender name and headers + assert.Contains(t, string(body), "Bob") + assert.Contains(t, string(body), `"test-stub-header1":["foo"]`) + assert.Contains(t, string(body), `"test-stub-header2":["bar"]`) +} diff --git a/courier/stub/request.config.twilio.jsonnet b/courier/stub/request.config.twilio.jsonnet new file mode 100644 index 00000000000..93752e14503 --- /dev/null +++ b/courier/stub/request.config.twilio.jsonnet @@ -0,0 +1,5 @@ +function(ctx) { + from: ctx.From, + to: ctx.To, + body: ctx.Body +} diff --git a/courier/template/courier/builtin/templates/otp/sms.body.gotmpl b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl new file mode 100644 index 00000000000..a630a83b82d --- /dev/null +++ b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl @@ -0,0 +1,3 @@ +Hi, please verify your account using following code: + +{{ .Code }} diff --git a/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl b/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl new file mode 100644 index 00000000000..a37e4640152 --- /dev/null +++ b/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl @@ -0,0 +1 @@ +stub sms body {{ .Body }} diff --git a/courier/template/email/recovery_invalid.go b/courier/template/email/recovery_invalid.go new file mode 100644 index 00000000000..25c20c2095e --- /dev/null +++ b/courier/template/email/recovery_invalid.go @@ -0,0 +1,43 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + RecoveryInvalid struct { + d template.Dependencies + m *RecoveryInvalidModel + } + RecoveryInvalidModel struct { + To string + } +) + +func NewRecoveryInvalid(d template.Dependencies, m *RecoveryInvalidModel) *RecoveryInvalid { + return &RecoveryInvalid{d: d, m: m} +} + +func (t *RecoveryInvalid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *RecoveryInvalid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Subject) +} + +func (t *RecoveryInvalid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.HTML) +} + +func (t *RecoveryInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.PlainText) +} + +func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/recovery_invalid_test.go b/courier/template/email/recovery_invalid_test.go similarity index 66% rename from courier/template/recovery_invalid_test.go rename to courier/template/email/recovery_invalid_test.go index aa5b1c83f0c..d3d533ab53a 100644 --- a/courier/template/recovery_invalid_test.go +++ b/courier/template/email/recovery_invalid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,12 +16,12 @@ func TestRecoverInvalid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewRecoveryInvalid(reg, &template.RecoveryInvalidModel{}) + tpl := email.NewRecoveryInvalid(reg, &email.RecoveryInvalidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("case=test remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/recovery/invalid", courier.TypeRecoveryInvalid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/invalid", courier.TypeRecoveryInvalid) }) } diff --git a/courier/template/email/recovery_valid.go b/courier/template/email/recovery_valid.go new file mode 100644 index 00000000000..65ce00f27c0 --- /dev/null +++ b/courier/template/email/recovery_valid.go @@ -0,0 +1,45 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + RecoveryValid struct { + d template.Dependencies + m *RecoveryValidModel + } + RecoveryValidModel struct { + To string + RecoveryURL string + Identity map[string]interface{} + } +) + +func NewRecoveryValid(d template.Dependencies, m *RecoveryValidModel) *RecoveryValid { + return &RecoveryValid{d: d, m: m} +} + +func (t *RecoveryValid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *RecoveryValid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Subject) +} + +func (t *RecoveryValid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.HTML) +} + +func (t *RecoveryValid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.PlainText) +} + +func (t *RecoveryValid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/recovery_valid_test.go b/courier/template/email/recovery_valid_test.go similarity index 67% rename from courier/template/recovery_valid_test.go rename to courier/template/email/recovery_valid_test.go index 0de24aa4b9d..0264fba9a4d 100644 --- a/courier/template/recovery_valid_test.go +++ b/courier/template/email/recovery_valid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,12 +16,12 @@ func TestRecoverValid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewRecoveryValid(reg, &template.RecoveryValidModel{}) + tpl := email.NewRecoveryValid(reg, &email.RecoveryValidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("test=with remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/recovery/valid", courier.TypeRecoveryValid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/valid", courier.TypeRecoveryValid) }) } diff --git a/courier/template/email/stub.go b/courier/template/email/stub.go new file mode 100644 index 00000000000..e5cecaf657a --- /dev/null +++ b/courier/template/email/stub.go @@ -0,0 +1,45 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + TestStub struct { + d template.Dependencies + m *TestStubModel + } + TestStubModel struct { + To string + Subject string + Body string + } +) + +func NewTestStub(d template.Dependencies, m *TestStubModel) *TestStub { + return &TestStub{d: d, m: m} +} + +func (t *TestStub) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *TestStub) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m, "") +} + +func (t *TestStub) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m, "") +} + +func (t *TestStub) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m, "") +} + +func (t *TestStub) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/email/verification_invalid.go b/courier/template/email/verification_invalid.go new file mode 100644 index 00000000000..f153c13aa92 --- /dev/null +++ b/courier/template/email/verification_invalid.go @@ -0,0 +1,43 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + VerificationInvalid struct { + d template.Dependencies + m *VerificationInvalidModel + } + VerificationInvalidModel struct { + To string + } +) + +func NewVerificationInvalid(d template.Dependencies, m *VerificationInvalidModel) *VerificationInvalid { + return &VerificationInvalid{d: d, m: m} +} + +func (t *VerificationInvalid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *VerificationInvalid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Subject) +} + +func (t *VerificationInvalid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.HTML) +} + +func (t *VerificationInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.PlainText) +} + +func (t *VerificationInvalid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/verification_invalid_test.go b/courier/template/email/verification_invalid_test.go similarity index 67% rename from courier/template/verification_invalid_test.go rename to courier/template/email/verification_invalid_test.go index 8bbb9972e58..15a837e0996 100644 --- a/courier/template/verification_invalid_test.go +++ b/courier/template/email/verification_invalid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,14 +16,14 @@ func TestVerifyInvalid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewVerificationInvalid(reg, &template.VerificationInvalidModel{}) + tpl := email.NewVerificationInvalid(reg, &email.VerificationInvalidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("test=with remote resources", func(t *testing.T) { t.Run("test=with remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/verification/invalid", courier.TypeVerificationInvalid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/invalid", courier.TypeVerificationInvalid) }) }) } diff --git a/courier/template/email/verification_valid.go b/courier/template/email/verification_valid.go new file mode 100644 index 00000000000..3de84840bdb --- /dev/null +++ b/courier/template/email/verification_valid.go @@ -0,0 +1,45 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + VerificationValid struct { + d template.Dependencies + m *VerificationValidModel + } + VerificationValidModel struct { + To string + VerificationURL string + Identity map[string]interface{} + } +) + +func NewVerificationValid(d template.Dependencies, m *VerificationValidModel) *VerificationValid { + return &VerificationValid{d: d, m: m} +} + +func (t *VerificationValid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *VerificationValid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Subject) +} + +func (t *VerificationValid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.HTML) +} + +func (t *VerificationValid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.PlainText) +} + +func (t *VerificationValid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/verification_valid_test.go b/courier/template/email/verification_valid_test.go similarity index 65% rename from courier/template/verification_valid_test.go rename to courier/template/email/verification_valid_test.go index 2313c74d0fe..1ce209445fe 100644 --- a/courier/template/verification_valid_test.go +++ b/courier/template/email/verification_valid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,12 +16,12 @@ func TestVerifyValid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewVerificationValid(reg, &template.VerificationValidModel{}) + tpl := email.NewVerificationValid(reg, &email.VerificationValidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("test=with remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/verification/valid", courier.TypeVerificationValid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/valid", courier.TypeVerificationValid) }) } diff --git a/courier/template/load_template.go b/courier/template/load_template.go index ceceb43e196..d3e3194cbd3 100644 --- a/courier/template/load_template.go +++ b/courier/template/load_template.go @@ -149,7 +149,7 @@ func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template, return tpl, nil } -func LoadTextTemplate(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { +func LoadText(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { var t Template var err error if remoteURL != "" { @@ -171,7 +171,7 @@ func LoadTextTemplate(ctx context.Context, d templateDependencies, filesystem fs return b.String(), nil } -func LoadHTMLTemplate(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { +func LoadHTML(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { var t Template var err error if remoteURL != "" { diff --git a/courier/template/load_template_test.go b/courier/template/load_template_test.go index b27ac0f21a0..89b9a936244 100644 --- a/courier/template/load_template_test.go +++ b/courier/template/load_template_test.go @@ -30,7 +30,7 @@ func TestLoadTextTemplate(t *testing.T) { var executeTextTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { ctx := context.Background() _, reg := internal.NewFastRegistryWithMocks(t) - tp, err := template.LoadTextTemplate(ctx, reg, os.DirFS(dir), name, pattern, model, "") + tp, err := template.LoadText(ctx, reg, os.DirFS(dir), name, pattern, model, "") require.NoError(t, err) return tp } @@ -38,7 +38,7 @@ func TestLoadTextTemplate(t *testing.T) { var executeHTMLTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { ctx := context.Background() _, reg := internal.NewFastRegistryWithMocks(t) - tp, err := template.LoadHTMLTemplate(ctx, reg, os.DirFS(dir), name, pattern, model, "") + tp, err := template.LoadHTML(ctx, reg, os.DirFS(dir), name, pattern, model, "") require.NoError(t, err) return tp } @@ -70,7 +70,7 @@ func TestLoadTextTemplate(t *testing.T) { for _, tc := range nonhermetic { t.Run("case=should not support function: "+tc, func(t *testing.T) { - _, err := template.LoadTextTemplate(ctx, reg, x.NewStubFS(tc, []byte(fmt.Sprintf("{{ %s }}", tc))), tc, "", map[string]interface{}{}, "") + _, err := template.LoadText(ctx, reg, x.NewStubFS(tc, []byte(fmt.Sprintf("{{ %s }}", tc))), tc, "", map[string]interface{}{}, "") require.Error(t, err) require.Contains(t, err.Error(), fmt.Sprintf("function \"%s\" not defined", tc)) }) @@ -108,7 +108,7 @@ func TestLoadTextTemplate(t *testing.T) { f, err := ioutil.ReadFile("courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl") require.NoError(t, err) b64 := base64.StdEncoding.EncodeToString(f) - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m, "base64://"+b64) + tp, err := template.LoadHTML(ctx, reg, nil, "", "", m, "base64://"+b64) require.NoError(t, err) assert.Contains(t, tp, "lang=en_US") }) @@ -120,7 +120,7 @@ func TestLoadTextTemplate(t *testing.T) { b64 := base64.StdEncoding.EncodeToString(f) - tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m, "base64://"+b64) + tp, err := template.LoadText(ctx, reg, nil, "", "", m, "base64://"+b64) require.NoError(t, err) assert.Contains(t, tp, "stub email body something") }) @@ -130,14 +130,14 @@ func TestLoadTextTemplate(t *testing.T) { t.Run("case=file resource", func(t *testing.T) { t.Run("case=html template", func(t *testing.T) { m := map[string]interface{}{"lang": "en_US"} - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl") + tp, err := template.LoadHTML(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl") require.NoError(t, err) assert.Contains(t, tp, "lang=en_US") }) t.Run("case=plaintext", func(t *testing.T) { m := map[string]interface{}{"Body": "something"} - tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.plaintext.gotmpl") + tp, err := template.LoadText(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.plaintext.gotmpl") require.NoError(t, err) assert.Contains(t, tp, "stub email body something") }) @@ -156,14 +156,14 @@ func TestLoadTextTemplate(t *testing.T) { t.Run("case=html template", func(t *testing.T) { m := map[string]interface{}{"lang": "en_US"} - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m, ts.URL+"/html") + tp, err := template.LoadHTML(ctx, reg, nil, "", "", m, ts.URL+"/html") require.NoError(t, err) assert.Contains(t, tp, "lang=en_US") }) t.Run("case=plaintext", func(t *testing.T) { m := map[string]interface{}{"Body": "something"} - tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m, ts.URL+"/plaintext") + tp, err := template.LoadText(ctx, reg, nil, "", "", m, ts.URL+"/plaintext") require.NoError(t, err) assert.Contains(t, tp, "stub email body something") }) @@ -171,12 +171,12 @@ func TestLoadTextTemplate(t *testing.T) { }) t.Run("case=unsupported resource", func(t *testing.T) { - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") + tp, err := template.LoadHTML(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") require.ErrorIs(t, err, fetcher.ErrUnknownScheme) require.Empty(t, tp) - tp, err = template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") + tp, err = template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") require.ErrorIs(t, err, fetcher.ErrUnknownScheme) require.Empty(t, tp) }) @@ -186,22 +186,22 @@ func TestLoadTextTemplate(t *testing.T) { reg.HTTPClient(ctx).RetryMax = 1 reg.HTTPClient(ctx).RetryWaitMax = time.Millisecond - _, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") + _, err := template.LoadHTML(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") require.Error(t, err) assert.Contains(t, err.Error(), "is in the") - _, err = template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") + _, err = template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") require.Error(t, err) assert.Contains(t, err.Error(), "is in the") }) t.Run("method=cache works", func(t *testing.T) { - tp1, err := template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://e3sgJGwgOj0gY2F0ICJsYW5nPSIgLmxhbmcgfX0Ke3sgbm9zcGFjZSAkbCB9fQ==") + tp1, err := template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://e3sgJGwgOj0gY2F0ICJsYW5nPSIgLmxhbmcgfX0Ke3sgbm9zcGFjZSAkbCB9fQ==") assert.NoError(t, err) - tp2, err := template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://c3R1YiBlbWFpbCBib2R5IHt7IC5Cb2R5IH19") + tp2, err := template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://c3R1YiBlbWFpbCBib2R5IHt7IC5Cb2R5IH19") assert.NoError(t, err) require.NotEqualf(t, tp1, tp2, "Expected remote template 1 and remote template 2 to not be equal") diff --git a/courier/template/recovery_invalid.go b/courier/template/recovery_invalid.go deleted file mode 100644 index a90995e3364..00000000000 --- a/courier/template/recovery_invalid.go +++ /dev/null @@ -1,41 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - RecoveryInvalid struct { - d TemplateDependencies - m *RecoveryInvalidModel - } - RecoveryInvalidModel struct { - To string - } -) - -func NewRecoveryInvalid(d TemplateDependencies, m *RecoveryInvalidModel) *RecoveryInvalid { - return &RecoveryInvalid{d: d, m: m} -} - -func (t *RecoveryInvalid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *RecoveryInvalid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Subject) -} - -func (t *RecoveryInvalid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.HTML) -} - -func (t *RecoveryInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.PlainText) -} - -func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/recovery_valid.go b/courier/template/recovery_valid.go deleted file mode 100644 index ba7d1c0fe18..00000000000 --- a/courier/template/recovery_valid.go +++ /dev/null @@ -1,43 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - RecoveryValid struct { - d TemplateDependencies - m *RecoveryValidModel - } - RecoveryValidModel struct { - To string - RecoveryURL string - Identity map[string]interface{} - } -) - -func NewRecoveryValid(d TemplateDependencies, m *RecoveryValidModel) *RecoveryValid { - return &RecoveryValid{d: d, m: m} -} - -func (t *RecoveryValid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *RecoveryValid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Subject) -} - -func (t *RecoveryValid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.HTML) -} - -func (t *RecoveryValid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.PlainText) -} - -func (t *RecoveryValid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/sms/otp.go b/courier/template/sms/otp.go new file mode 100644 index 00000000000..38e31446c1a --- /dev/null +++ b/courier/template/sms/otp.go @@ -0,0 +1,37 @@ +package sms + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + OTPMessage struct { + d template.Dependencies + m *OTPMessageModel + } + + OTPMessageModel struct { + To string + Code string + } +) + +func NewOTPMessage(d template.Dependencies, m *OTPMessageModel) *OTPMessage { + return &OTPMessage{d: d, m: m} +} + +func (t *OTPMessage) PhoneNumber() (string, error) { + return t.m.To, nil +} + +func (t *OTPMessage) SMSBody(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "otp/sms.body.gotmpl", "otp/sms.body*", t.m, "") +} + +func (t *OTPMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/sms/otp_test.go b/courier/template/sms/otp_test.go new file mode 100644 index 00000000000..cb97c00d1c0 --- /dev/null +++ b/courier/template/sms/otp_test.go @@ -0,0 +1,34 @@ +package sms_test + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestNewOTPMessage(t *testing.T) { + _, reg := internal.NewFastRegistryWithMocks(t) + + const ( + expectedPhone = "+12345678901" + otp = "012345" + ) + + tpl := sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: expectedPhone, Code: otp}) + + expectedBody := fmt.Sprintf("Hi, please verify your account using following code:\n\n%s\n", otp) + + actualBody, err := tpl.SMSBody(context.Background()) + require.NoError(t, err) + assert.Equal(t, expectedBody, actualBody) + + actualPhone, err := tpl.PhoneNumber() + require.NoError(t, err) + assert.Equal(t, expectedPhone, actualPhone) +} diff --git a/courier/template/sms/stub.go b/courier/template/sms/stub.go new file mode 100644 index 00000000000..84140015635 --- /dev/null +++ b/courier/template/sms/stub.go @@ -0,0 +1,37 @@ +package sms + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + TestStub struct { + d template.Dependencies + m *TestStubModel + } + + TestStubModel struct { + To string + Body string + } +) + +func NewTestStub(d template.Dependencies, m *TestStubModel) *TestStub { + return &TestStub{d: d, m: m} +} + +func (t *TestStub) PhoneNumber() (string, error) { + return t.m.To, nil +} + +func (t *TestStub) SMSBody(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "otp/test_stub/sms.body.gotmpl", "otp/test_stub/sms.body*", t.m, "") +} + +func (t *TestStub) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/sms/stub_test.go b/courier/template/sms/stub_test.go new file mode 100644 index 00000000000..9b170a5532e --- /dev/null +++ b/courier/template/sms/stub_test.go @@ -0,0 +1,31 @@ +package sms_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestNewTestStub(t *testing.T) { + _, reg := internal.NewFastRegistryWithMocks(t) + + const ( + expectedPhone = "+12345678901" + expectedBody = "test sms" + ) + + tpl := sms.NewTestStub(reg, &sms.TestStubModel{To: expectedPhone, Body: expectedBody}) + + actualBody, err := tpl.SMSBody(context.Background()) + require.NoError(t, err) + assert.Equal(t, "stub sms body test sms\n", actualBody) + + actualPhone, err := tpl.PhoneNumber() + require.NoError(t, err) + assert.Equal(t, expectedPhone, actualPhone) +} diff --git a/courier/template/stub.go b/courier/template/stub.go deleted file mode 100644 index 58f95a37669..00000000000 --- a/courier/template/stub.go +++ /dev/null @@ -1,42 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type TestStub struct { - d TemplateDependencies - m *TestStubModel -} - -type TestStubModel struct { - To string - Subject string - Body string -} - -func NewTestStub(d TemplateDependencies, m *TestStubModel) *TestStub { - return &TestStub{d: d, m: m} -} - -func (t *TestStub) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *TestStub) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m, "") -} - -func (t *TestStub) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m, "") -} - -func (t *TestStub) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m, "") -} - -func (t *TestStub) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/template.go b/courier/template/template.go index 46e94b8e9c6..f81e8ce444a 100644 --- a/courier/template/template.go +++ b/courier/template/template.go @@ -10,14 +10,15 @@ import ( ) type ( - TemplateConfig interface { + Config interface { CourierTemplatesRoot() string CourierTemplatesVerificationInvalid() *config.CourierEmailTemplate CourierTemplatesVerificationValid() *config.CourierEmailTemplate CourierTemplatesRecoveryInvalid() *config.CourierEmailTemplate CourierTemplatesRecoveryValid() *config.CourierEmailTemplate } - TemplateDependencies interface { + + Dependencies interface { CourierConfig(ctx context.Context) config.CourierConfigs HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client } diff --git a/courier/template/testhelpers/testhelpers.go b/courier/template/testhelpers/testhelpers.go index 0e2a3a49082..895ec767f40 100644 --- a/courier/template/testhelpers/testhelpers.go +++ b/courier/template/testhelpers/testhelpers.go @@ -9,6 +9,8 @@ import ( "path" "testing" + "github.com/ory/kratos/courier/template/email" + "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -55,21 +57,21 @@ func TestRemoteTemplates(t *testing.T, basePath string, tmplType courier.Templat return base64.StdEncoding.EncodeToString(f) } - getTemplate := func(tmpl courier.TemplateType, d template.TemplateDependencies) interface { + getTemplate := func(tmpl courier.TemplateType, d template.Dependencies) interface { EmailBody(context.Context) (string, error) EmailSubject(context.Context) (string, error) } { switch tmpl { case courier.TypeRecoveryInvalid: - return template.NewRecoveryInvalid(d, &template.RecoveryInvalidModel{}) + return email.NewRecoveryInvalid(d, &email.RecoveryInvalidModel{}) case courier.TypeRecoveryValid: - return template.NewRecoveryValid(d, &template.RecoveryValidModel{}) + return email.NewRecoveryValid(d, &email.RecoveryValidModel{}) case courier.TypeTestStub: - return template.NewTestStub(d, &template.TestStubModel{}) + return email.NewTestStub(d, &email.TestStubModel{}) case courier.TypeVerificationInvalid: - return template.NewVerificationInvalid(d, &template.VerificationInvalidModel{}) + return email.NewVerificationInvalid(d, &email.VerificationInvalidModel{}) case courier.TypeVerificationValid: - return template.NewVerificationValid(d, &template.VerificationValidModel{}) + return email.NewVerificationValid(d, &email.VerificationValidModel{}) default: return nil } diff --git a/courier/template/verification_invalid.go b/courier/template/verification_invalid.go deleted file mode 100644 index e78ec3a106f..00000000000 --- a/courier/template/verification_invalid.go +++ /dev/null @@ -1,41 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - VerificationInvalid struct { - d TemplateDependencies - m *VerificationInvalidModel - } - VerificationInvalidModel struct { - To string - } -) - -func NewVerificationInvalid(d TemplateDependencies, m *VerificationInvalidModel) *VerificationInvalid { - return &VerificationInvalid{d: d, m: m} -} - -func (t *VerificationInvalid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *VerificationInvalid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Subject) -} - -func (t *VerificationInvalid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.HTML) -} - -func (t *VerificationInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.PlainText) -} - -func (t *VerificationInvalid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/verification_valid.go b/courier/template/verification_valid.go deleted file mode 100644 index cdd6e25c6b8..00000000000 --- a/courier/template/verification_valid.go +++ /dev/null @@ -1,43 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - VerificationValid struct { - d TemplateDependencies - m *VerificationValidModel - } - VerificationValidModel struct { - To string - VerificationURL string - Identity map[string]interface{} - } -) - -func NewVerificationValid(d TemplateDependencies, m *VerificationValidModel) *VerificationValid { - return &VerificationValid{d: d, m: m} -} - -func (t *VerificationValid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *VerificationValid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Subject) -} - -func (t *VerificationValid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.HTML) -} - -func (t *VerificationValid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.PlainText) -} - -func (t *VerificationValid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/driver/config/config.go b/driver/config/config.go index 1f728668b32..588225a632a 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -69,6 +69,9 @@ const ( ViperKeyCourierSMTPFrom = "courier.smtp.from_address" ViperKeyCourierSMTPFromName = "courier.smtp.from_name" ViperKeyCourierSMTPHeaders = "courier.smtp.headers" + ViperKeyCourierSMSRequestConfig = "courier.sms.request_config" + ViperKeyCourierSMSEnabled = "courier.sms.enabled" + ViperKeyCourierSMSFrom = "courier.sms.from" ViperKeySecretsDefault = "secrets.default" ViperKeySecretsCookie = "secrets.cookie" ViperKeySecretsCipher = "secrets.cipher" @@ -235,6 +238,9 @@ type ( CourierSMTPFrom() string CourierSMTPFromName() string CourierSMTPHeaders() map[string]string + CourierSMSEnabled() bool + CourierSMSFrom() string + CourierSMSRequestConfig() json.RawMessage CourierTemplatesRoot() string CourierTemplatesVerificationInvalid() *CourierEmailTemplate CourierTemplatesVerificationValid() *CourierEmailTemplate @@ -919,6 +925,33 @@ func (p *Config) CourierSMTPHeaders() map[string]string { return p.p.StringMap(ViperKeyCourierSMTPHeaders) } +func (p *Config) CourierSMSRequestConfig() json.RawMessage { + if !p.p.Bool(ViperKeyCourierSMSEnabled) { + return nil + } + + out, err := p.p.Marshal(kjson.Parser()) + if err != nil { + p.l.WithError(err).Warn("Unable to marshal self service strategy configuration.") + return nil + } + + config := gjson.GetBytes(out, ViperKeyCourierSMSRequestConfig).Raw + if len(config) <= 0 { + return json.RawMessage("{}") + } + + return json.RawMessage(config) +} + +func (p *Config) CourierSMSFrom() string { + return p.p.StringF(ViperKeyCourierSMSFrom, "Ory Kratos") +} + +func (p *Config) CourierSMSEnabled() bool { + return p.p.Bool(ViperKeyCourierSMSEnabled) +} + func splitUrlAndFragment(s string) (string, string) { i := strings.IndexByte(s, '#') if i < 0 { diff --git a/driver/registry_default.go b/driver/registry_default.go index 0d1e6fa1d05..62a089edb23 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -588,8 +588,8 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) { m.persister = p } -func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier { - return courier.NewSMTP(ctx, m) +func (m *RegistryDefault) Courier(ctx context.Context) courier.Courier { + return courier.NewCourier(ctx, m) } func (m *RegistryDefault) ContinuityManager() continuity.Manager { diff --git a/embedx/config.schema.json b/embedx/config.schema.json index d0310b7ca86..362a4902b51 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1445,6 +1445,79 @@ "connection_uri" ], "additionalProperties": false + }, + "sms": { + "title": "SMS sender configuration", + "description": "Configures outgoing sms messages using HTTP protocol with generic SMS provider", + "type": "object", + "properties": { + "enabled": { + "description": "Determines if SMS functionality is enabled", + "type": "boolean", + "default": false + }, + "from": { + "title": "SMS Sender Address", + "description": "The recipient of a sms will see this as the sender address.", + "type": "string", + "default": "Ory Kratos" + }, + "request_config": { + "type": "object", + "properties": { + "url": { + "title": "HTTP address of API endpoint", + "description": "This URL will be used to connect to the SMS provider.", + "examples": [ + "https://api.twillio.com/sms/send" + ], + "type": "string", + "pattern": "^https?:\\/\\/.*" + }, + "method": { + "type": "string", + "description": "The HTTP method to use (GET, POST, etc)." + }, + "body": { + "type": "string", + "oneOf": [ + { + "format": "uri", + "pattern": "^(http|https|file|base64)://", + "description": "URI pointing to the jsonnet template used for payload generation. Only used for those HTTP methods, which support HTTP body payloads", + "examples": [ + "file:///path/to/body.jsonnet", + "file://./body.jsonnet", + "base64://ZnVuY3Rpb24oY3R4KSB7CiAgaWRlbnRpdHlfaWQ6IGlmIGN0eFsiaWRlbnRpdHkiXSAhPSBudWxsIHRoZW4gY3R4LmlkZW50aXR5LmlkLAp9=", + "https://oryapis.com/default_body.jsonnet" + ] + }, + { + "description": "DEPRECATED: please use a URI instead (i.e. prefix your filepath with 'file://')", + "not": { + "pattern": "^(http|https|file|base64)://" + } + } + ] + }, + "auth": { + "type": "object", + "title": "Auth mechanisms", + "description": "Define which auth mechanism to use for auth with the SMS provider", + "oneOf": [ + { + "$ref": "#/definitions/webHookAuthApiKeyProperties" + }, + { + "$ref": "#/definitions/webHookAuthBasicAuthProperties" + } + ] + }, + "additionalProperties": false + } + } + }, + "additionalProperties": false } }, "required": [ diff --git a/request/auth.go b/request/auth.go new file mode 100644 index 00000000000..65df14402fa --- /dev/null +++ b/request/auth.go @@ -0,0 +1,31 @@ +package request + +import ( + "encoding/json" + "fmt" + "net/http" +) + +type ( + AuthStrategy interface { + apply(req *http.Request) + } + + authStrategyFactory func(c json.RawMessage) (AuthStrategy, error) +) + +var strategyFactories = map[string]authStrategyFactory{ + "": newNoopAuthStrategy, + "api_key": newApiKeyStrategy, + "basic_auth": newBasicAuthStrategy, +} + +func authStrategy(name string, config json.RawMessage) (AuthStrategy, error) { + strategyFactory, ok := strategyFactories[name] + if ok { + return strategyFactory(config) + } + + return nil, fmt.Errorf("unsupported auth type: %s", name) + +} diff --git a/request/auth_strategy.go b/request/auth_strategy.go new file mode 100644 index 00000000000..e2e41b9e0f8 --- /dev/null +++ b/request/auth_strategy.go @@ -0,0 +1,76 @@ +package request + +import ( + "encoding/json" + "net/http" +) + +type ( + noopAuthStrategy struct{} + + basicAuthStrategy struct { + user string + password string + } + + apiKeyStrategy struct { + name string + value string + in string + } +) + +func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) { + return &noopAuthStrategy{}, nil +} + +func (c *noopAuthStrategy) apply(_ *http.Request) {} + +func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { + type config struct { + User string + Password string + } + + var c config + if err := json.Unmarshal(raw, &c); err != nil { + return nil, err + } + + return &basicAuthStrategy{ + user: c.User, + password: c.Password, + }, nil +} + +func (c *basicAuthStrategy) apply(req *http.Request) { + req.SetBasicAuth(c.user, c.password) +} + +func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) { + type config struct { + In string + Name string + Value string + } + + var c config + if err := json.Unmarshal(raw, &c); err != nil { + return nil, err + } + + return &apiKeyStrategy{ + in: c.In, + name: c.Name, + value: c.Value, + }, nil +} + +func (c *apiKeyStrategy) apply(req *http.Request) { + switch c.in { + case "cookie": + req.AddCookie(&http.Cookie{Name: c.name, Value: c.value}) + default: + req.Header.Set(c.name, c.value) + } +} diff --git a/request/auth_strategy_test.go b/request/auth_strategy_test.go new file mode 100644 index 00000000000..b22d140c46e --- /dev/null +++ b/request/auth_strategy_test.go @@ -0,0 +1,67 @@ +package request + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNoopAuthStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := noopAuthStrategy{} + + auth.apply(&req) + + assert.Empty(t, req.Header, "Empty auth strategy shall not modify any request headers") +} + +func TestBasicAuthStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := basicAuthStrategy{ + user: "test-user", + password: "test-pass", + } + + auth.apply(&req) + + assert.Len(t, req.Header, 1) + + user, pass, _ := req.BasicAuth() + assert.Equal(t, "test-user", user) + assert.Equal(t, "test-pass", pass) +} + +func TestApiKeyInHeaderStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := apiKeyStrategy{ + in: "header", + name: "my-api-key-name", + value: "my-api-key-value", + } + + auth.apply(&req) + + require.Len(t, req.Header, 1) + + actualValue := req.Header.Get("my-api-key-name") + assert.Equal(t, "my-api-key-value", actualValue) +} + +func TestApiKeyInCookieStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := apiKeyStrategy{ + in: "cookie", + name: "my-api-key-name", + value: "my-api-key-value", + } + + auth.apply(&req) + + cookies := req.Cookies() + assert.Len(t, cookies, 1) + + assert.Equal(t, "my-api-key-name", cookies[0].Name) + assert.Equal(t, "my-api-key-value", cookies[0].Value) +} diff --git a/request/auth_test.go b/request/auth_test.go new file mode 100644 index 00000000000..c0df7933690 --- /dev/null +++ b/request/auth_test.go @@ -0,0 +1,56 @@ +package request + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthStrategy(t *testing.T) { + for _, tc := range map[string]struct { + name string + config string + expected AuthStrategy + }{ + "noop": { + name: "", + config: "", + expected: &noopAuthStrategy{}, + }, + "basic_auth": { + name: "basic_auth", + config: `{ + "user": "test-api-user", + "password": "secret" + }`, + expected: &basicAuthStrategy{}, + }, + "api-key/header": { + name: "api_key", + config: `{ + "in": "header", + "name": "my-api-key", + "value": "secret" + }`, + expected: &apiKeyStrategy{}, + }, + "api-key/cookie": { + name: "api_key", + config: `{ + "in": "cookie", + "name": "my-api-key", + "value": "secret" + }`, + expected: &apiKeyStrategy{}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + strategy, err := authStrategy(tc.name, json.RawMessage(tc.config)) + require.NoError(t, err) + + assert.IsTypef(t, tc.expected, strategy, "auth strategy should be of the expected type") + }) + } +} diff --git a/request/builder.go b/request/builder.go new file mode 100644 index 00000000000..21783c73c10 --- /dev/null +++ b/request/builder.go @@ -0,0 +1,203 @@ +package request + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "net/url" + "reflect" + "strings" + + "github.com/google/go-jsonnet" + + "github.com/ory/x/fetcher" + "github.com/ory/x/logrusx" +) + +const ( + ContentTypeForm = "application/x-www-form-urlencoded" + ContentTypeJSON = "application/json" +) + +type Builder struct { + r *http.Request + log *logrusx.Logger + conf *Config +} + +func NewBuilder(config json.RawMessage, l *logrusx.Logger) (*Builder, error) { + c, err := parseConfig(config) + if err != nil { + return nil, err + } + + r, err := http.NewRequest(c.Method, c.URL, nil) + if err != nil { + return nil, err + } + + return &Builder{ + r: r, + log: l, + conf: c, + }, nil +} + +func (b *Builder) addAuth() error { + authConfig := b.conf.Auth + + strategy, err := authStrategy(authConfig.Type, authConfig.Config) + if err != nil { + return err + } + + strategy.apply(b.r) + + return nil +} + +func (b *Builder) addBody(body interface{}) error { + if isNilInterface(body) { + return nil + } + + contentType := b.r.Header.Get("Content-Type") + + if b.conf.TemplateURI == "" { + return errors.New("got empty template path for request with body") + } + + switch contentType { + case ContentTypeForm: + if err := b.addURLEncodedBody(body); err != nil { + return err + } + case ContentTypeJSON: + if err := b.addJSONBody(body); err != nil { + return err + } + default: + return errors.New("invalid config - incorrect Content-Type for request with body") + } + + return nil +} + +func (b *Builder) addJSONBody(body interface{}) error { + tURL := b.conf.TemplateURI + + tpl, err := readTemplate(tURL, b.log) + if err != nil { + return err + } + + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + enc.SetIndent("", "") + + if err := enc.Encode(body); err != nil { + return err + } + + vm := jsonnet.MakeVM() + vm.TLACode("ctx", buf.String()) + + res, err := vm.EvaluateAnonymousSnippet(tURL, tpl.String()) + if err != nil { + return err + } + + rb := strings.NewReader(res) + b.r.Body = io.NopCloser(rb) + b.r.ContentLength = int64(rb.Len()) + + return nil +} + +func (b *Builder) addURLEncodedBody(body interface{}) error { + tURL := b.conf.TemplateURI + tpl, err := readTemplate(tURL, b.log) + if err != nil { + return err + } + + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + enc.SetIndent("", "") + + if err := enc.Encode(body); err != nil { + return err + } + + vm := jsonnet.MakeVM() + vm.TLACode("ctx", buf.String()) + + res, err := vm.EvaluateAnonymousSnippet(tURL, tpl.String()) + if err != nil { + return err + } + + values := map[string]string{} + if err := json.Unmarshal([]byte(res), &values); err != nil { + return err + } + + u := url.Values{} + + for key, value := range values { + u.Add(key, value) + } + + rb := strings.NewReader(u.Encode()) + b.r.Body = io.NopCloser(rb) + + return nil +} + +func (b *Builder) BuildRequest(body interface{}) (*http.Request, error) { + b.r.Header = b.conf.Header + if err := b.addAuth(); err != nil { + return nil, err + } + + // According to the HTTP spec any request method, but TRACE is allowed to + // have a body. Even this is a bad practice for some of them, like for GET + if b.conf.Method != http.MethodTrace { + if err := b.addBody(body); err != nil { + return nil, err + } + } + + return b.r, nil +} + +func readTemplate(templateURI string, l *logrusx.Logger) (*bytes.Buffer, error) { + if templateURI == "" { + return nil, nil + } + + f := fetcher.NewFetcher() + + tpl, err := f.Fetch(templateURI) + if errors.Is(err, fetcher.ErrUnknownScheme) { + // legacy filepath + templateURI = "file://" + templateURI + l.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) + + tpl, err = f.Fetch(templateURI) + } + // this handles the first error if it is a known scheme error, or the second fetch error + if err != nil { + return nil, err + } + + return tpl, nil +} + +func isNilInterface(i interface{}) bool { + return i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) +} diff --git a/request/builder_test.go b/request/builder_test.go new file mode 100644 index 00000000000..aa8fd9a9078 --- /dev/null +++ b/request/builder_test.go @@ -0,0 +1,272 @@ +package request + +import ( + _ "embed" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/x/logrusx" +) + +type testRequestBody struct { + To string + From string + Body string +} + +//go:embed stub/test_body.jsonnet +var testJSONNetTemplate []byte + +func TestBuildRequest(t *testing.T) { + for _, tc := range []struct { + name string + method string + url string + authStrategy string + header http.Header + bodyTemplateURI string + body *testRequestBody + expectedBody string + rawConfig string + }{ + { + name: "POST request without auth", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint1", + authStrategy: "", // noop strategy + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+15056445993", + From: "+12288534869", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+12288534869\",\n \"To\": \"+15056445993\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint1", + "method": "POST", + "body": "file://./stub/test_body.jsonnet" + }`, + }, + { + name: "POST request with legacy template path", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint1", + bodyTemplateURI: "./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+15056445993", + From: "+12288534869", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+12288534869\",\n \"To\": \"+15056445993\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint1", + "method": "POST", + "body": "./stub/test_body.jsonnet" + }`, + }, + { + name: "POST request with base64 encoded template path", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint1", + bodyTemplateURI: "base64://" + base64.StdEncoding.EncodeToString(testJSONNetTemplate), + body: &testRequestBody{ + To: "+15056445993", + From: "+12288534869", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+12288534869\",\n \"To\": \"+15056445993\"\n}\n", + rawConfig: fmt.Sprintf(`{ + "url": "https://test.kratos.ory.sh/my_endpoint1", + "method": "POST", + "body": "base64://%s" + }`, base64.StdEncoding.EncodeToString(testJSONNetTemplate)), + }, + { + name: "POST request with custom header", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint2", + authStrategy: "", + header: map[string][]string{"Custom-Header": {"test"}}, + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+12127110378", + From: "+15822228108", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+15822228108\",\n \"To\": \"+12127110378\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint2", + "method": "POST", + "header": { + "Custom-Header": "test" + }, + "body": "file://./stub/test_body.jsonnet" + }`, + }, + { + name: "GET request with body", + method: "GET", + url: "https://test.kratos.ory.sh/my_endpoint3", + authStrategy: "basic_auth", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+14134242223", + From: "+13104661805", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+13104661805\",\n \"To\": \"+14134242223\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint3", + "method": "GET", + "auth": { + "type": "basic_auth", + "config": { + "user": "test-api-user", + "password": "secret" + } + }, + "body": "file://./stub/test_body.jsonnet" + }`, + }, + { + name: "GET request without body", + method: "GET", + url: "https://test.kratos.ory.sh/my_endpoint4", + authStrategy: "basic_auth", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint4", + "method": "GET", + "auth": { + "type": "basic_auth", + "config": { + "user": "test-api-user", + "password": "secret" + } + } + }`, + }, + { + name: "DELETE request with body", + method: "DELETE", + url: "https://test.kratos.ory.sh/my_endpoint5", + authStrategy: "api_key", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+12235499085", + From: "+14253787846", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+14253787846\",\n \"To\": \"+12235499085\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint5", + "method": "DELETE", + "body": "file://./stub/test_body.jsonnet", + "auth": { + "type": "api_key", + "config": { + "in": "header", + "name": "my-api-key", + "value": "secret" + } + } + }`, + }, + { + name: "POST request with urlencoded body", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint6", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + authStrategy: "api_key", + header: map[string][]string{"Content-Type": {ContentTypeForm}}, + body: &testRequestBody{ + To: "+14134242223", + From: "+13104661805", + Body: "test-sms-body", + }, + expectedBody: "Body=test-sms-body&From=%2B13104661805&To=%2B14134242223", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint6", + "method": "POST", + "body": "file://./stub/test_body.jsonnet", + "header": { + "Content-Type": "application/x-www-form-urlencoded" + }, + "auth": { + "type": "api_key", + "config": { + "in": "cookie", + "name": "my-api-key", + "value": "secret" + } + } + }`, + }, + { + name: "POST request with default body type", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint7", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + authStrategy: "basic_auth", + header: map[string][]string{"Content-Type": {ContentTypeJSON}}, + body: &testRequestBody{ + To: "+14134242223", + From: "+13104661805", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+13104661805\",\n \"To\": \"+14134242223\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint7", + "method": "POST", + "body": "file://./stub/test_body.jsonnet", + "auth": { + "type": "basic_auth", + "config": { + "user": "test-api-user", + "password": "secret" + } + } + }`, + }, + } { + t.Run("request-type="+tc.name, func(t *testing.T) { + l := logrusx.New("kratos", "test") + + rb, err := NewBuilder(json.RawMessage(tc.rawConfig), l) + require.NoError(t, err) + + assert.Equal(t, tc.bodyTemplateURI, rb.conf.TemplateURI) + assert.Equal(t, tc.authStrategy, rb.conf.Auth.Type) + + req, err := rb.BuildRequest(tc.body) + require.NoError(t, err) + + assert.Equal(t, tc.url, req.URL.String()) + assert.Equal(t, tc.method, req.Method) + + if tc.body != nil { + requestBody, err := ioutil.ReadAll(req.Body) + require.NoError(t, err) + + assert.Equal(t, tc.expectedBody, string(requestBody)) + } + + if tc.header != nil { + mustContainHeader(t, tc.header, req.Header) + } + }) + } +} + +func mustContainHeader(t *testing.T, expected http.Header, actual http.Header) { + for k := range expected { + require.Contains(t, actual, k) + assert.Equal(t, expected[k], actual[k]) + } +} diff --git a/request/config.go b/request/config.go new file mode 100644 index 00000000000..caf5061bf32 --- /dev/null +++ b/request/config.go @@ -0,0 +1,61 @@ +package request + +import ( + "encoding/json" + "net/http" + + "github.com/tidwall/gjson" +) + +type ( + Auth struct { + Type string + Config json.RawMessage + } + + Config struct { + Method string `json:"method"` + URL string `json:"url"` + TemplateURI string `json:"body"` + Header http.Header `json:"header"` + Auth Auth `json:"auth,omitempty"` + } +) + +func parseConfig(r json.RawMessage) (*Config, error) { + type rawConfig struct { + Method string `json:"method"` + URL string `json:"url"` + TemplateURI string `json:"body"` + Header json.RawMessage `json:"header"` + Auth Auth `json:"auth,omitempty"` + } + + var rc rawConfig + err := json.Unmarshal(r, &rc) + if err != nil { + return nil, err + } + + rawHeader := gjson.ParseBytes(rc.Header).Map() + hdr := http.Header{} + + _, ok := rawHeader["Content-Type"] + if !ok { + hdr.Set("Content-Type", ContentTypeJSON) + } + + for key, value := range rawHeader { + hdr.Set(key, value.String()) + } + + c := Config{ + Method: rc.Method, + URL: rc.URL, + TemplateURI: rc.TemplateURI, + Header: hdr, + Auth: rc.Auth, + } + + return &c, nil +} diff --git a/request/stub/test_body.jsonnet b/request/stub/test_body.jsonnet new file mode 100644 index 00000000000..03edc83a65e --- /dev/null +++ b/request/stub/test_body.jsonnet @@ -0,0 +1,5 @@ +function(ctx) { + From: ctx.From, + To: ctx.To, + Body: ctx.Body, +} diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index 03476a8ac72..1ae2f5fe6a4 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -1,22 +1,15 @@ package hook import ( - "bytes" "context" "encoding/json" "fmt" - "io" "net/http" "github.com/hashicorp/go-retryablehttp" - "github.com/ory/x/fetcher" - "github.com/ory/x/logrusx" - - "github.com/google/go-jsonnet" - "github.com/pkg/errors" - "github.com/ory/kratos/identity" + "github.com/ory/kratos/request" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/recovery" @@ -32,32 +25,6 @@ var _ verification.PostHookExecutor = new(WebHook) var _ recovery.PostHookExecutor = new(WebHook) type ( - AuthStrategy interface { - apply(req *retryablehttp.Request) - } - - authStrategyFactory func(c json.RawMessage) (AuthStrategy, error) - - NoopAuthStrategy struct{} - - BasicAuthStrategy struct { - user string - password string - } - - ApiKeyStrategy struct { - name string - value string - in string - } - - WebHookConfig struct { - Method string - URL string - TemplateURI string - Auth AuthStrategy - } - webHookDependencies interface { x.LoggingProvider x.HTTPClientProvider @@ -72,113 +39,13 @@ type ( } WebHook struct { - r webHookDependencies - c json.RawMessage + deps webHookDependencies + conf json.RawMessage } ) -var strategyFactories = map[string]authStrategyFactory{ - "": newNoopAuthStrategy, - "api_key": newApiKeyStrategy, - "basic_auth": newBasicAuthStrategy, -} - -func newAuthStrategy(name string, c json.RawMessage) (as AuthStrategy, err error) { - if f, ok := strategyFactories[name]; ok { - as, err = f(c) - } else { - err = fmt.Errorf("unsupported auth type: %s", name) - } - return -} - -func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) { - return &NoopAuthStrategy{}, nil -} - -func (c *NoopAuthStrategy) apply(_ *retryablehttp.Request) {} - -func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { - type config struct { - User string - Password string - } - - var c config - if err := json.Unmarshal(raw, &c); err != nil { - return nil, err - } - - return &BasicAuthStrategy{ - user: c.User, - password: c.Password, - }, nil -} - -func (c *BasicAuthStrategy) apply(req *retryablehttp.Request) { - req.SetBasicAuth(c.user, c.password) -} - -func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) { - type config struct { - In string - Name string - Value string - } - - var c config - if err := json.Unmarshal(raw, &c); err != nil { - return nil, err - } - - return &ApiKeyStrategy{ - in: c.In, - name: c.Name, - value: c.Value, - }, nil -} - -func (c *ApiKeyStrategy) apply(req *retryablehttp.Request) { - switch c.in { - case "cookie": - req.AddCookie(&http.Cookie{Name: c.name, Value: c.value}) - default: - req.Header.Set(c.name, c.value) - } -} - -func newWebHookConfig(r json.RawMessage) (*WebHookConfig, error) { - type rawWebHookConfig struct { - Method string - Url string - Body string - Auth struct { - Type string - Config json.RawMessage - } - } - - var rc rawWebHookConfig - err := json.Unmarshal(r, &rc) - if err != nil { - return nil, err - } - - as, err := newAuthStrategy(rc.Auth.Type, rc.Auth.Config) - if err != nil { - return nil, fmt.Errorf("failed to create web hook auth strategy: %w", err) - } - - return &WebHookConfig{ - Method: rc.Method, - URL: rc.Url, - TemplateURI: rc.Body, - Auth: as, - }, nil -} - func NewWebHook(r webHookDependencies, c json.RawMessage) *WebHook { - return &WebHook{r: r, c: c} + return &WebHook{deps: r, conf: c} } func (e *WebHook) ExecuteLoginPreHook(_ http.ResponseWriter, req *http.Request, flow *login.Flow) error { @@ -250,88 +117,28 @@ func (e *WebHook) ExecuteSettingsPostPersistHook(_ http.ResponseWriter, req *htt } func (e *WebHook) execute(ctx context.Context, data *templateContext) error { - httpClient := e.r.HTTPClient(ctx) - - // TODO: reminder for the future: move parsing of config to the web hook initialization - conf, err := newWebHookConfig(e.c) + builder, err := request.NewBuilder(e.conf, e.deps.Logger()) if err != nil { - return fmt.Errorf("failed to parse web hook config: %w", err) - } - - var body io.Reader - if conf.Method != "TRACE" { - // According to the HTTP spec any request method, but TRACE is allowed to - // have a body. Even this is a really bad practice for some of them, like for - // GET - body, err = createBody(e.r.Logger(), conf.TemplateURI, data, httpClient) - if err != nil { - return fmt.Errorf("failed to create web hook body: %w", err) - } - } - - if body == nil { - body = bytes.NewReader(make([]byte, 0)) - } - - if err = doHttpCall(ctx, conf.Method, conf.URL, conf.Auth, body, httpClient); err != nil { - return fmt.Errorf("failed to call web hook %w", err) - } - return nil -} - -func createBody(l *logrusx.Logger, templateURI string, data *templateContext, hc *retryablehttp.Client) (*bytes.Reader, error) { - if len(templateURI) == 0 { - return bytes.NewReader(make([]byte, 0)), nil + return err } - f := fetcher.NewFetcher(fetcher.WithClient(hc)) - - template, err := f.Fetch(templateURI) - if errors.Is(err, fetcher.ErrUnknownScheme) { - // legacy filepath - templateURI = "file://" + templateURI - l.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) - template, err = f.Fetch(templateURI) - } - // this handles the first error if it is a known scheme error, or the second fetch error + req, err := builder.BuildRequest(data) if err != nil { - return nil, err - } - - vm := jsonnet.MakeVM() - - buf := new(bytes.Buffer) - enc := json.NewEncoder(buf) - enc.SetEscapeHTML(false) - enc.SetIndent("", "") - - if err := enc.Encode(data); err != nil { - return nil, err - } - vm.TLACode("ctx", buf.String()) - - if res, err := vm.EvaluateAnonymousSnippet(templateURI, template.String()); err != nil { - return nil, err - } else { - return bytes.NewReader([]byte(res)), nil + return err } -} -func doHttpCall(ctx context.Context, method string, url string, as AuthStrategy, body io.Reader, hc *retryablehttp.Client) error { - req, err := retryablehttp.NewRequest(method, url, body) - req = req.WithContext(ctx) + httpClient := e.deps.HTTPClient(ctx) + r, err := retryablehttp.FromRequest(req) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - - as.apply(req) - - resp, err := hc.Do(req) + resp, err := httpClient.Do(r) if err != nil { return err - } else if resp.StatusCode >= 400 { + } + + if resp.StatusCode >= http.StatusBadRequest { return fmt.Errorf("web hook failed with status code %v", resp.StatusCode) } diff --git a/selfservice/hook/web_hook_test.go b/selfservice/hook/web_hook_test.go deleted file mode 100644 index 41b0ba135f0..00000000000 --- a/selfservice/hook/web_hook_test.go +++ /dev/null @@ -1,268 +0,0 @@ -package hook - -import ( - _ "embed" - "encoding/base64" - "encoding/json" - "io" - "net/http" - "testing" - - "github.com/hashicorp/go-retryablehttp" - - "github.com/sirupsen/logrus/hooks/test" - - "github.com/ory/x/logrusx" - - "github.com/ory/kratos/identity" - "github.com/ory/kratos/x" - - "github.com/stretchr/testify/require" - - "github.com/ory/kratos/selfservice/flow/login" - - "github.com/stretchr/testify/assert" -) - -func TestNoopAuthStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := NoopAuthStrategy{} - - auth.apply(&req) - - assert.Empty(t, req.Header, "Empty auth strategy shall not modify any request headers") -} - -func TestBasicAuthStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := BasicAuthStrategy{ - user: "test-user", - password: "test-pass", - } - - auth.apply(&req) - - assert.Len(t, req.Header, 1) - - user, pass, _ := req.BasicAuth() - assert.Equal(t, "test-user", user) - assert.Equal(t, "test-pass", pass) -} - -func TestApiKeyInHeaderStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := ApiKeyStrategy{ - in: "header", - name: "my-api-key-name", - value: "my-api-key-value", - } - - auth.apply(&req) - - require.Len(t, req.Header, 1) - - actualValue := req.Header.Get("my-api-key-name") - assert.Equal(t, "my-api-key-value", actualValue) -} - -func TestApiKeyInCookieStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := ApiKeyStrategy{ - in: "cookie", - name: "my-api-key-name", - value: "my-api-key-value", - } - - auth.apply(&req) - - cookies := req.Cookies() - assert.Len(t, cookies, 1) - - assert.Equal(t, "my-api-key-name", cookies[0].Name) - assert.Equal(t, "my-api-key-value", cookies[0].Value) -} - -//go:embed stub/test_body.jsonnet -var testBodyJSONNet []byte - -func TestJsonNetSupport(t *testing.T) { - f := &login.Flow{ID: x.NewUUID()} - i := identity.NewIdentity("") - l := logrusx.New("kratos", "test") - - for _, tc := range []struct { - desc, template string - data *templateContext - }{ - { - desc: "simple file URI", - template: "file://./stub/test_body.jsonnet", - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"c1=v1", "c2=v2"}, - "Some-Header": []string{"Some-Value"}, - }, - RequestMethod: "POST", - RequestUrl: "https://test.kratos.ory.sh/some-test-path", - Identity: i, - }, - }, - { - desc: "legacy filepath without scheme", - template: "./stub/test_body.jsonnet", - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"c1=v1", "c2=v2"}, - "Some-Header": []string{"Some-Value"}, - }, - RequestMethod: "POST", - RequestUrl: "https://test.kratos.ory.sh/some-test-path", - Identity: i, - }, - }, - { - desc: "base64 encoded template URI", - template: "base64://" + base64.StdEncoding.EncodeToString(testBodyJSONNet), - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"foo=bar"}, - "My-Custom-Header": []string{"Cumstom-Value"}, - }, - RequestMethod: "PUT", - RequestUrl: "https://test.kratos.ory.sh/other-test-path", - Identity: i, - }, - }, - } { - t.Run("case="+tc.desc, func(t *testing.T) { - b, err := createBody(l, tc.template, tc.data, retryablehttp.NewClient()) - require.NoError(t, err) - body, err := io.ReadAll(b) - require.NoError(t, err) - - expected, err := json.Marshal(map[string]interface{}{ - "flow_id": tc.data.Flow.GetID(), - "identity_id": tc.data.Identity.ID, - "headers": tc.data.RequestHeaders, - "method": tc.data.RequestMethod, - "url": tc.data.RequestUrl, - }) - require.NoError(t, err) - - assert.JSONEq(t, string(expected), string(body)) - }) - } - - t.Run("case=warns about legacy usage", func(t *testing.T) { - hook := test.Hook{} - l := logrusx.New("kratos", "test", logrusx.WithHook(&hook)) - - _, _ = createBody(l, "./foo", nil, retryablehttp.NewClient()) - - require.Len(t, hook.Entries, 1) - assert.Contains(t, hook.LastEntry().Message, "support for filepaths without a 'file://' scheme will be dropped") - }) - - t.Run("case=return non nil body reader on empty templateURI", func(t *testing.T) { - body, err := createBody(l, "", nil, retryablehttp.NewClient()) - assert.NotNil(t, body) - assert.Nil(t, err) - }) -} - -func TestWebHookConfig(t *testing.T) { - for _, tc := range []struct { - strategy string - method string - url string - body string - rawConfig string - authStrategy AuthStrategy - }{ - { - strategy: "empty", - method: "POST", - url: "https://test.kratos.ory.sh/my_hook1", - body: "/path/to/my/jsonnet1.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook1", - "method": "POST", - "body": "/path/to/my/jsonnet1.file" - }`, - authStrategy: &NoopAuthStrategy{}, - }, - { - strategy: "basic_auth", - method: "GET", - url: "https://test.kratos.ory.sh/my_hook2", - body: "/path/to/my/jsonnet2.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook2", - "method": "GET", - "body": "/path/to/my/jsonnet2.file", - "auth": { - "type": "basic_auth", - "config": { - "user": "test-api-user", - "password": "secret" - } - } - }`, - authStrategy: &BasicAuthStrategy{}, - }, - { - strategy: "api-key/header", - method: "DELETE", - url: "https://test.kratos.ory.sh/my_hook3", - body: "/path/to/my/jsonnet3.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook3", - "method": "DELETE", - "body": "/path/to/my/jsonnet3.file", - "auth": { - "type": "api_key", - "config": { - "in": "header", - "name": "my-api-key", - "value": "secret" - } - } - }`, - authStrategy: &ApiKeyStrategy{}, - }, - { - strategy: "api-key/cookie", - method: "POST", - url: "https://test.kratos.ory.sh/my_hook4", - body: "/path/to/my/jsonnet4.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook4", - "method": "POST", - "body": "/path/to/my/jsonnet4.file", - "auth": { - "type": "api_key", - "config": { - "in": "cookie", - "name": "my-api-key", - "value": "secret" - } - } - }`, - authStrategy: &ApiKeyStrategy{}, - }, - } { - t.Run("auth-strategy="+tc.strategy, func(t *testing.T) { - conf, err := newWebHookConfig([]byte(tc.rawConfig)) - assert.Nil(t, err) - - assert.Equal(t, tc.url, conf.URL) - assert.Equal(t, tc.method, conf.Method) - assert.Equal(t, tc.body, conf.TemplateURI) - assert.NotNil(t, conf.Auth) - assert.IsTypef(t, tc.authStrategy, conf.Auth, "Auth should be of the expected type") - }) - } -} diff --git a/selfservice/strategy/link/sender.go b/selfservice/strategy/link/sender.go index f6567467f7f..df4a9bf2b85 100644 --- a/selfservice/strategy/link/sender.go +++ b/selfservice/strategy/link/sender.go @@ -7,6 +7,8 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/courier/template/email" + "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -16,7 +18,6 @@ import ( "github.com/ory/x/urlx" "github.com/ory/kratos/courier" - templates "github.com/ory/kratos/courier/template" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/recovery" @@ -66,7 +67,7 @@ func (s *Sender) SendRecoveryLink(ctx context.Context, r *http.Request, f *recov address, err := s.r.IdentityPool().FindRecoveryAddressByValue(ctx, identity.RecoveryAddressTypeEmail, to) if err != nil { - if err := s.send(ctx, string(via), templates.NewRecoveryInvalid(s.r, &templates.RecoveryInvalidModel{To: to})); err != nil { + if err := s.send(ctx, string(via), email.NewRecoveryInvalid(s.r, &email.RecoveryInvalidModel{To: to})); err != nil { return err } return errors.Cause(ErrUnknownAddress) @@ -106,7 +107,7 @@ func (s *Sender) SendVerificationLink(ctx context.Context, f *verification.Flow, WithField("via", via). WithSensitiveField("email_address", address). Info("Sending out invalid verification email because address is unknown.") - if err := s.send(ctx, string(via), templates.NewVerificationInvalid(s.r, &templates.VerificationInvalidModel{To: to})); err != nil { + if err := s.send(ctx, string(via), email.NewVerificationInvalid(s.r, &email.VerificationInvalidModel{To: to})); err != nil { return err } return errors.Cause(ErrUnknownAddress) @@ -145,8 +146,8 @@ func (s *Sender) SendRecoveryTokenTo(ctx context.Context, f *recovery.Flow, i *i return err } - return s.send(ctx, string(address.Via), templates.NewRecoveryValid(s.r, - &templates.RecoveryValidModel{To: address.Value, RecoveryURL: urlx.CopyWithQuery( + return s.send(ctx, string(address.Via), email.NewRecoveryValid(s.r, + &email.RecoveryValidModel{To: address.Value, RecoveryURL: urlx.CopyWithQuery( urlx.AppendPaths(s.r.Config(ctx).SelfServiceLinkMethodBaseURL(), recovery.RouteSubmitFlow), url.Values{ "token": {token.Token}, @@ -168,8 +169,8 @@ func (s *Sender) SendVerificationTokenTo(ctx context.Context, f *verification.Fl return err } - if err := s.send(ctx, string(address.Via), templates.NewVerificationValid(s.r, - &templates.VerificationValidModel{To: address.Value, VerificationURL: urlx.CopyWithQuery( + if err := s.send(ctx, string(address.Via), email.NewVerificationValid(s.r, + &email.VerificationValidModel{To: address.Value, VerificationURL: urlx.CopyWithQuery( urlx.AppendPaths(s.r.Config(ctx).SelfServiceLinkMethodBaseURL(), verification.RouteSubmitFlow), url.Values{ "flow": {f.ID.String()}, diff --git a/x/require.go b/x/require.go index 17154b1593f..689dc6a1aab 100644 --- a/x/require.go +++ b/x/require.go @@ -5,6 +5,7 @@ import ( "encoding/json" "testing" + "github.com/gofrs/uuid" "github.com/stretchr/testify/require" ) @@ -13,3 +14,7 @@ func RequireJSONMarshal(t *testing.T, in interface{}) []byte { require.NoError(t, json.NewEncoder(&b).Encode(in)) return b.Bytes() } + +func RequireNotNilUUID(t *testing.T, id uuid.UUID) { + require.NotEqual(t, uuid.Nil, id) +} From 1d72d9138a2d6671216d08637be694189136fb7d Mon Sep 17 00:00:00 2001 From: reshetnik-alexey Date: Wed, 16 Feb 2022 20:10:24 +0530 Subject: [PATCH 03/10] fix: mr comment fix --- courier/sms.go | 3 +- courier/sms_test.go | 14 ++++- courier/smtp_test.go | 9 +-- courier/template/sms/otp.go | 5 +- courier/template/sms/stub.go | 5 +- embedx/config.schema.json | 40 +++++++------- request/builder.go | 55 +++++++++---------- request/builder_test.go | 2 +- selfservice/hook/web_hook.go | 2 +- .../root.courierSMS.yaml | 23 ++++++++ x/require.go | 5 -- 11 files changed, 95 insertions(+), 68 deletions(-) create mode 100644 test/schema/fixtures/config.schema.test.success/root.courierSMS.yaml diff --git a/courier/sms.go b/courier/sms.go index 059fa638b53..7a26ad7e013 100644 --- a/courier/sms.go +++ b/courier/sms.go @@ -19,7 +19,6 @@ type sendSMSRequestBody struct { type smsClient struct { *http.Client - Host string RequestConfig json.RawMessage GetTemplateType func(t SMSTemplate) (TemplateType, error) @@ -81,7 +80,7 @@ func (c *courier) dispatchSMS(ctx context.Context, msg Message) error { return err } - builder, err := request.NewBuilder(c.smsClient.RequestConfig, c.deps.Logger()) + builder, err := request.NewBuilder(c.smsClient.RequestConfig, c.deps.HTTPClient(ctx), c.deps.Logger()) if err != nil { return err } diff --git a/courier/sms_test.go b/courier/sms_test.go index 0178e2fee4f..4f179615711 100644 --- a/courier/sms_test.go +++ b/courier/sms_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/gofrs/uuid" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,7 +19,7 @@ import ( "github.com/ory/kratos/courier/template/sms" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal" - "github.com/ory/kratos/x" + "github.com/ory/x/resilience" ) func TestQueueSMS(t *testing.T) { @@ -89,14 +91,20 @@ func TestQueueSMS(t *testing.T) { for _, message := range expectedSMS { id, err := c.QueueSMS(ctx, sms.NewTestStub(reg, message)) require.NoError(t, err) - x.RequireNotNilUUID(t, id) + require.NotEqual(t, uuid.Nil, id) } go func() { require.NoError(t, c.Work(ctx)) }() - time.Sleep(time.Second) + require.NoError(t, resilience.Retry(reg.Logger(), time.Millisecond*250, time.Second*10, func() error { + if len(actual) == len(expectedSMS) { + return nil + } + return errors.New("capacity not reached") + })) + for i, message := range actual { expected := expectedSMS[i] diff --git a/courier/smtp_test.go b/courier/smtp_test.go index 58df4f9fc9b..626d2c45d17 100644 --- a/courier/smtp_test.go +++ b/courier/smtp_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -70,7 +71,7 @@ func TestQueueEmail(t *testing.T) { conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") reg.Logger().Level = logrus.TraceLevel - c := reg.Courier(ctx) //??? + c := reg.Courier(ctx) ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -81,7 +82,7 @@ func TestQueueEmail(t *testing.T) { Body: "test-body-1", })) require.NoError(t, err) - x.RequireNotNilUUID(t, id) + require.NotEqual(t, uuid.Nil, id) id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ To: "test-recipient-2@example.org", @@ -89,7 +90,7 @@ func TestQueueEmail(t *testing.T) { Body: "test-body-2", })) require.NoError(t, err) - x.RequireNotNilUUID(t, id) + require.NotEqual(t, uuid.Nil, id) // The third email contains a sender name and custom headers conf.MustSet(config.ViperKeyCourierSMTPFromName, "Bob") @@ -103,7 +104,7 @@ func TestQueueEmail(t *testing.T) { Body: "test-body-3", })) require.NoError(t, err) - x.RequireNotNilUUID(t, id) + require.NotEqual(t, uuid.Nil, id) go func() { require.NoError(t, c.Work(ctx)) diff --git a/courier/template/sms/otp.go b/courier/template/sms/otp.go index 38e31446c1a..ef003f63b0e 100644 --- a/courier/template/sms/otp.go +++ b/courier/template/sms/otp.go @@ -15,8 +15,9 @@ type ( } OTPMessageModel struct { - To string - Code string + To string + Code string + Identity map[string]interface{} } ) diff --git a/courier/template/sms/stub.go b/courier/template/sms/stub.go index 84140015635..fa2fb19e3b5 100644 --- a/courier/template/sms/stub.go +++ b/courier/template/sms/stub.go @@ -15,8 +15,9 @@ type ( } TestStubModel struct { - To string - Body string + To string + Body string + Identity map[string]interface{} } ) diff --git a/embedx/config.schema.json b/embedx/config.schema.json index 362a4902b51..0f3674dd358 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1478,26 +1478,23 @@ "type": "string", "description": "The HTTP method to use (GET, POST, etc)." }, + "header": { + "type": "object", + "description": "The HTTP headers that must be applied to request", + "additionalProperties": { + "type": "string" + } + }, "body": { "type": "string", - "oneOf": [ - { - "format": "uri", - "pattern": "^(http|https|file|base64)://", - "description": "URI pointing to the jsonnet template used for payload generation. Only used for those HTTP methods, which support HTTP body payloads", - "examples": [ - "file:///path/to/body.jsonnet", - "file://./body.jsonnet", - "base64://ZnVuY3Rpb24oY3R4KSB7CiAgaWRlbnRpdHlfaWQ6IGlmIGN0eFsiaWRlbnRpdHkiXSAhPSBudWxsIHRoZW4gY3R4LmlkZW50aXR5LmlkLAp9=", - "https://oryapis.com/default_body.jsonnet" - ] - }, - { - "description": "DEPRECATED: please use a URI instead (i.e. prefix your filepath with 'file://')", - "not": { - "pattern": "^(http|https|file|base64)://" - } - } + "format": "uri", + "pattern": "^(http|https|file|base64)://", + "description": "URI pointing to the jsonnet template used for payload generation. Only used for those HTTP methods, which support HTTP body payloads", + "examples": [ + "file:///path/to/body.jsonnet", + "file://./body.jsonnet", + "base64://ZnVuY3Rpb24oY3R4KSB7CiAgaWRlbnRpdHlfaWQ6IGlmIGN0eFsiaWRlbnRpdHkiXSAhPSBudWxsIHRoZW4gY3R4LmlkZW50aXR5LmlkLAp9=", + "https://oryapis.com/default_body.jsonnet" ] }, "auth": { @@ -1514,7 +1511,12 @@ ] }, "additionalProperties": false - } + }, + "required": [ + "url", + "method" + ], + "additionalProperties": false } }, "additionalProperties": false diff --git a/request/builder.go b/request/builder.go index 21783c73c10..69d39bf9d43 100644 --- a/request/builder.go +++ b/request/builder.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/google/go-jsonnet" + "github.com/hashicorp/go-retryablehttp" "github.com/ory/x/fetcher" "github.com/ory/x/logrusx" @@ -22,12 +23,13 @@ const ( ) type Builder struct { - r *http.Request - log *logrusx.Logger - conf *Config + r *http.Request + log *logrusx.Logger + conf *Config + fetchClient *retryablehttp.Client } -func NewBuilder(config json.RawMessage, l *logrusx.Logger) (*Builder, error) { +func NewBuilder(config json.RawMessage, client *retryablehttp.Client, l *logrusx.Logger) (*Builder, error) { c, err := parseConfig(config) if err != nil { return nil, err @@ -39,9 +41,10 @@ func NewBuilder(config json.RawMessage, l *logrusx.Logger) (*Builder, error) { } return &Builder{ - r: r, - log: l, - conf: c, + r: r, + log: l, + conf: c, + fetchClient: client, }, nil } @@ -69,13 +72,18 @@ func (b *Builder) addBody(body interface{}) error { return errors.New("got empty template path for request with body") } + tpl, err := b.readTemplate() + if err != nil { + return err + } + switch contentType { case ContentTypeForm: - if err := b.addURLEncodedBody(body); err != nil { + if err := b.addURLEncodedBody(tpl, body); err != nil { return err } case ContentTypeJSON: - if err := b.addJSONBody(body); err != nil { + if err := b.addJSONBody(tpl, body); err != nil { return err } default: @@ -85,14 +93,7 @@ func (b *Builder) addBody(body interface{}) error { return nil } -func (b *Builder) addJSONBody(body interface{}) error { - tURL := b.conf.TemplateURI - - tpl, err := readTemplate(tURL, b.log) - if err != nil { - return err - } - +func (b *Builder) addJSONBody(template *bytes.Buffer, body interface{}) error { buf := new(bytes.Buffer) enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) @@ -105,7 +106,7 @@ func (b *Builder) addJSONBody(body interface{}) error { vm := jsonnet.MakeVM() vm.TLACode("ctx", buf.String()) - res, err := vm.EvaluateAnonymousSnippet(tURL, tpl.String()) + res, err := vm.EvaluateAnonymousSnippet(b.conf.TemplateURI, template.String()) if err != nil { return err } @@ -117,13 +118,7 @@ func (b *Builder) addJSONBody(body interface{}) error { return nil } -func (b *Builder) addURLEncodedBody(body interface{}) error { - tURL := b.conf.TemplateURI - tpl, err := readTemplate(tURL, b.log) - if err != nil { - return err - } - +func (b *Builder) addURLEncodedBody(template *bytes.Buffer, body interface{}) error { buf := new(bytes.Buffer) enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) @@ -136,7 +131,7 @@ func (b *Builder) addURLEncodedBody(body interface{}) error { vm := jsonnet.MakeVM() vm.TLACode("ctx", buf.String()) - res, err := vm.EvaluateAnonymousSnippet(tURL, tpl.String()) + res, err := vm.EvaluateAnonymousSnippet(b.conf.TemplateURI, template.String()) if err != nil { return err } @@ -175,18 +170,20 @@ func (b *Builder) BuildRequest(body interface{}) (*http.Request, error) { return b.r, nil } -func readTemplate(templateURI string, l *logrusx.Logger) (*bytes.Buffer, error) { +func (b *Builder) readTemplate() (*bytes.Buffer, error) { + templateURI := b.conf.TemplateURI + if templateURI == "" { return nil, nil } - f := fetcher.NewFetcher() + f := fetcher.NewFetcher(fetcher.WithClient(b.fetchClient)) tpl, err := f.Fetch(templateURI) if errors.Is(err, fetcher.ErrUnknownScheme) { // legacy filepath templateURI = "file://" + templateURI - l.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) + b.log.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) tpl, err = f.Fetch(templateURI) } diff --git a/request/builder_test.go b/request/builder_test.go index aa8fd9a9078..1deb04dee07 100644 --- a/request/builder_test.go +++ b/request/builder_test.go @@ -238,7 +238,7 @@ func TestBuildRequest(t *testing.T) { t.Run("request-type="+tc.name, func(t *testing.T) { l := logrusx.New("kratos", "test") - rb, err := NewBuilder(json.RawMessage(tc.rawConfig), l) + rb, err := NewBuilder(json.RawMessage(tc.rawConfig), nil, l) require.NoError(t, err) assert.Equal(t, tc.bodyTemplateURI, rb.conf.TemplateURI) diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index 1ae2f5fe6a4..53f0466c442 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -117,7 +117,7 @@ func (e *WebHook) ExecuteSettingsPostPersistHook(_ http.ResponseWriter, req *htt } func (e *WebHook) execute(ctx context.Context, data *templateContext) error { - builder, err := request.NewBuilder(e.conf, e.deps.Logger()) + builder, err := request.NewBuilder(e.conf, e.deps.HTTPClient(ctx), e.deps.Logger()) if err != nil { return err } diff --git a/test/schema/fixtures/config.schema.test.success/root.courierSMS.yaml b/test/schema/fixtures/config.schema.test.success/root.courierSMS.yaml new file mode 100644 index 00000000000..d8c9707ed89 --- /dev/null +++ b/test/schema/fixtures/config.schema.test.success/root.courierSMS.yaml @@ -0,0 +1,23 @@ +selfservice: + default_browser_return_url: "#/definitions/defaultReturnTo" + +dsn: foo + +identity: + schemas: + - id: default + url: https://example.com + +courier: + smtp: + connection_uri: smtps://foo:bar@my-mailserver:1234/ + from_address: no-reply@ory.kratos.sh + sms: + enabled: true + from: "+19592155527" + request_config: + url: https://sms.example.com + method: POST + body: file://request.config.twilio.jsonnet + header: + 'Content-Type': "application/x-www-form-urlencoded" diff --git a/x/require.go b/x/require.go index 689dc6a1aab..17154b1593f 100644 --- a/x/require.go +++ b/x/require.go @@ -5,7 +5,6 @@ import ( "encoding/json" "testing" - "github.com/gofrs/uuid" "github.com/stretchr/testify/require" ) @@ -14,7 +13,3 @@ func RequireJSONMarshal(t *testing.T, in interface{}) []byte { require.NoError(t, json.NewEncoder(&b).Encode(in)) return b.Bytes() } - -func RequireNotNilUUID(t *testing.T, id uuid.UUID) { - require.NotEqual(t, uuid.Nil, id) -} From 0ad29616ec5055dce4d8db81bd26ce200adfd33f Mon Sep 17 00:00:00 2001 From: reshetnik-alexey Date: Thu, 17 Feb 2022 12:33:48 +0530 Subject: [PATCH 04/10] fix: added malformed config test --- .../root.SMSConfigmalformedURL.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 test/schema/fixtures/config.schema.test.failure/root.SMSConfigmalformedURL.yaml diff --git a/test/schema/fixtures/config.schema.test.failure/root.SMSConfigmalformedURL.yaml b/test/schema/fixtures/config.schema.test.failure/root.SMSConfigmalformedURL.yaml new file mode 100644 index 00000000000..f5b1003cff6 --- /dev/null +++ b/test/schema/fixtures/config.schema.test.failure/root.SMSConfigmalformedURL.yaml @@ -0,0 +1,5 @@ +sms: + request_config: + url: "malformed uri" + method: POST + body: "malformed uri" From 80b7828c68d35baaa76d7b02e0e8707cc18b984f Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 23 Feb 2022 09:46:03 +0100 Subject: [PATCH 05/10] fix: resolve issues with the CI pipeline --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 4a3980dbad8..b2e99b81671 100644 --- a/go.mod +++ b/go.mod @@ -76,7 +76,7 @@ require ( github.com/ory/kratos-client-go v0.6.3-alpha.1 github.com/ory/mail/v3 v3.0.0 github.com/ory/nosurf v1.2.7 - github.com/ory/x v0.0.345 + github.com/ory/x v0.0.348 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.3.0 diff --git a/go.sum b/go.sum index 66299a3bbd6..db4e460ee60 100644 --- a/go.sum +++ b/go.sum @@ -1804,8 +1804,8 @@ github.com/ory/x v0.0.205/go.mod h1:A1s4iwmFIppRXZLF3J9GGWeY/HpREVm0Dk5z/787iek= github.com/ory/x v0.0.250/go.mod h1:jUJaVptu+geeqlb9SyQCogTKj5ztSDIF6APkhbKtwLc= github.com/ory/x v0.0.272/go.mod h1:1TTPgJGQutrhI2OnwdrTIHE9ITSf4MpzXFzA/ncTGRc= github.com/ory/x v0.0.288/go.mod h1:APpShLyJcVzKw1kTgrHI+j/L9YM+8BRjHlcYObc7C1U= -github.com/ory/x v0.0.345 h1:e3ZCt8SxLXQdn/fWM/xjxl+2+DhjrTNIY9DVwYMR2m4= -github.com/ory/x v0.0.345/go.mod h1:Ddbu3ecSaNDgxdntdD1gDu3ALG5fWR5AwUB1ILeBUNE= +github.com/ory/x v0.0.348 h1:Z2wbEvSpTindtjKTTrd3grIlWbBtvW2udYG5ZjTZHTo= +github.com/ory/x v0.0.348/go.mod h1:Ddbu3ecSaNDgxdntdD1gDu3ALG5fWR5AwUB1ILeBUNE= github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= From 1c213233ec22b4ece32b1ef77585064312b7286d Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 23 Feb 2022 10:59:02 +0100 Subject: [PATCH 06/10] chore: update sms template body --- .../template/courier/builtin/templates/otp/sms.body.gotmpl | 4 +--- courier/template/sms/otp_test.go | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/courier/template/courier/builtin/templates/otp/sms.body.gotmpl b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl index a630a83b82d..ff95187e9e7 100644 --- a/courier/template/courier/builtin/templates/otp/sms.body.gotmpl +++ b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl @@ -1,3 +1 @@ -Hi, please verify your account using following code: - -{{ .Code }} +Your verification code is: {{ .Code }} diff --git a/courier/template/sms/otp_test.go b/courier/template/sms/otp_test.go index cb97c00d1c0..01f4dcbbacb 100644 --- a/courier/template/sms/otp_test.go +++ b/courier/template/sms/otp_test.go @@ -22,7 +22,7 @@ func TestNewOTPMessage(t *testing.T) { tpl := sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: expectedPhone, Code: otp}) - expectedBody := fmt.Sprintf("Hi, please verify your account using following code:\n\n%s\n", otp) + expectedBody := fmt.Sprintf("Your verification code is: %s\n", otp) actualBody, err := tpl.SMSBody(context.Background()) require.NoError(t, err) From 2f03ff6c6836a9bd96a2421ad9a2292591de6227 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 23 Feb 2022 10:59:16 +0100 Subject: [PATCH 07/10] chore: code style --- request/auth.go | 1 - request/builder_test.go | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/request/auth.go b/request/auth.go index 65df14402fa..c1093250008 100644 --- a/request/auth.go +++ b/request/auth.go @@ -27,5 +27,4 @@ func authStrategy(name string, config json.RawMessage) (AuthStrategy, error) { } return nil, fmt.Errorf("unsupported auth type: %s", name) - } diff --git a/request/builder_test.go b/request/builder_test.go index 1deb04dee07..ad868392bf2 100644 --- a/request/builder_test.go +++ b/request/builder_test.go @@ -30,7 +30,7 @@ func TestBuildRequest(t *testing.T) { method string url string authStrategy string - header http.Header + expectedHeader http.Header bodyTemplateURI string body *testRequestBody expectedBody string @@ -93,7 +93,7 @@ func TestBuildRequest(t *testing.T) { method: "POST", url: "https://test.kratos.ory.sh/my_endpoint2", authStrategy: "", - header: map[string][]string{"Custom-Header": {"test"}}, + expectedHeader: map[string][]string{"Custom-Header": {"test"}}, bodyTemplateURI: "file://./stub/test_body.jsonnet", body: &testRequestBody{ To: "+12127110378", @@ -184,7 +184,7 @@ func TestBuildRequest(t *testing.T) { url: "https://test.kratos.ory.sh/my_endpoint6", bodyTemplateURI: "file://./stub/test_body.jsonnet", authStrategy: "api_key", - header: map[string][]string{"Content-Type": {ContentTypeForm}}, + expectedHeader: map[string][]string{"Content-Type": {ContentTypeForm}}, body: &testRequestBody{ To: "+14134242223", From: "+13104661805", @@ -214,7 +214,7 @@ func TestBuildRequest(t *testing.T) { url: "https://test.kratos.ory.sh/my_endpoint7", bodyTemplateURI: "file://./stub/test_body.jsonnet", authStrategy: "basic_auth", - header: map[string][]string{"Content-Type": {ContentTypeJSON}}, + expectedHeader: map[string][]string{"Content-Type": {ContentTypeJSON}}, body: &testRequestBody{ To: "+14134242223", From: "+13104661805", @@ -257,8 +257,8 @@ func TestBuildRequest(t *testing.T) { assert.Equal(t, tc.expectedBody, string(requestBody)) } - if tc.header != nil { - mustContainHeader(t, tc.header, req.Header) + if tc.expectedHeader != nil { + mustContainHeader(t, tc.expectedHeader, req.Header) } }) } From 40fc6fc1e0279571546b17b0a8635fa5a40743dc Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 23 Feb 2022 11:23:15 +0100 Subject: [PATCH 08/10] fix: ensure no internal networks can be called in SMS sender --- courier/courier.go | 12 ++++++-- .../{dispatcher.go => courier_dispatcher.go} | 3 ++ courier/sms.go | 4 +-- courier/sms_test.go | 28 +++++++++++++++++++ request/auth.go | 5 ++-- request/auth_strategy.go | 8 ++++-- request/auth_strategy_test.go | 10 ++++--- request/builder.go | 6 ++-- selfservice/hook/web_hook.go | 10 +------ 9 files changed, 59 insertions(+), 27 deletions(-) rename courier/{dispatcher.go => courier_dispatcher.go} (97%) diff --git a/courier/courier.go b/courier/courier.go index e173db3cb3d..d5253038229 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -28,6 +28,7 @@ type ( QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) SmtpDialer() *gomail.Dialer + DispatchQueue(ctx context.Context) error } Provider interface { @@ -39,9 +40,10 @@ type ( } courier struct { - smsClient *smsClient - smtpClient *smtpClient - deps Dependencies + smsClient *smsClient + smtpClient *smtpClient + deps Dependencies + failOnError bool } ) @@ -53,6 +55,10 @@ func NewCourier(ctx context.Context, deps Dependencies) Courier { } } +func (c *courier) FailOnDispatchError() { + c.failOnError = true +} + func (c *courier) Work(ctx context.Context) error { errChan := make(chan error) defer close(errChan) diff --git a/courier/dispatcher.go b/courier/courier_dispatcher.go similarity index 97% rename from courier/dispatcher.go rename to courier/courier_dispatcher.go index 4d8beb7f2fa..6810fbd50cc 100644 --- a/courier/dispatcher.go +++ b/courier/courier_dispatcher.go @@ -52,6 +52,9 @@ func (c *courier) DispatchQueue(ctx context.Context) error { if err := c.DispatchMessage(ctx, msg); err != nil { for _, replace := range messages[k:] { if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { + if c.failOnError { + return err + } c.deps.Logger(). WithError(err). WithField("message_id", replace.ID). diff --git a/courier/sms.go b/courier/sms.go index 7a26ad7e013..3fbb8d78dba 100644 --- a/courier/sms.go +++ b/courier/sms.go @@ -18,7 +18,6 @@ type sendSMSRequestBody struct { } type smsClient struct { - *http.Client RequestConfig json.RawMessage GetTemplateType func(t SMSTemplate) (TemplateType, error) @@ -31,7 +30,6 @@ func newSMS(ctx context.Context, deps Dependencies) *smsClient { } return &smsClient{ - Client: &http.Client{}, RequestConfig: deps.CourierConfig(ctx).CourierSMSRequestConfig(), GetTemplateType: SMSTemplateType, @@ -94,7 +92,7 @@ func (c *courier) dispatchSMS(ctx context.Context, msg Message) error { return err } - res, err := c.smsClient.Do(req) + res, err := c.deps.HTTPClient(ctx).Do(req) if err != nil { return err } diff --git a/courier/sms_test.go b/courier/sms_test.go index 4f179615711..fdcd234f4ce 100644 --- a/courier/sms_test.go +++ b/courier/sms_test.go @@ -114,3 +114,31 @@ func TestQueueSMS(t *testing.T) { srv.Close() } + +func TestDisallowedInternalNetwork(t *testing.T) { + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMSRequestConfig, fmt.Sprintf(`{ + "url": "http://127.0.0.1/", + "method": "GET", + "body": "file://./stub/request.config.twilio.jsonnet" + }`)) + conf.MustSet(config.ViperKeyCourierSMSEnabled, true) + conf.MustSet(config.ViperKeyCourierSMTPURL, "http://foo.url") + conf.MustSet(config.ViperKeyClientHTTPNoPrivateIPRanges, true) + reg.Logger().Level = logrus.TraceLevel + + ctx := context.Background() + c := reg.Courier(ctx) + c.(interface { + FailOnDispatchError() + }).FailOnDispatchError() + _, err := c.QueueSMS(ctx, sms.NewTestStub(reg, &sms.TestStubModel{ + To: "+12065550101", + Body: "test-sms-body-1", + })) + require.NoError(t, err) + + err = c.DispatchQueue(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "ip 127.0.0.1 is in the 127.0.0.0/8 range") +} diff --git a/request/auth.go b/request/auth.go index c1093250008..398aa0aef91 100644 --- a/request/auth.go +++ b/request/auth.go @@ -3,12 +3,13 @@ package request import ( "encoding/json" "fmt" - "net/http" + + "github.com/hashicorp/go-retryablehttp" ) type ( AuthStrategy interface { - apply(req *http.Request) + apply(req *retryablehttp.Request) } authStrategyFactory func(c json.RawMessage) (AuthStrategy, error) diff --git a/request/auth_strategy.go b/request/auth_strategy.go index e2e41b9e0f8..f280a5b92e7 100644 --- a/request/auth_strategy.go +++ b/request/auth_strategy.go @@ -3,6 +3,8 @@ package request import ( "encoding/json" "net/http" + + "github.com/hashicorp/go-retryablehttp" ) type ( @@ -24,7 +26,7 @@ func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) { return &noopAuthStrategy{}, nil } -func (c *noopAuthStrategy) apply(_ *http.Request) {} +func (c *noopAuthStrategy) apply(_ *retryablehttp.Request) {} func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { type config struct { @@ -43,7 +45,7 @@ func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { }, nil } -func (c *basicAuthStrategy) apply(req *http.Request) { +func (c *basicAuthStrategy) apply(req *retryablehttp.Request) { req.SetBasicAuth(c.user, c.password) } @@ -66,7 +68,7 @@ func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) { }, nil } -func (c *apiKeyStrategy) apply(req *http.Request) { +func (c *apiKeyStrategy) apply(req *retryablehttp.Request) { switch c.in { case "cookie": req.AddCookie(&http.Cookie{Name: c.name, Value: c.value}) diff --git a/request/auth_strategy_test.go b/request/auth_strategy_test.go index b22d140c46e..e2422fb4042 100644 --- a/request/auth_strategy_test.go +++ b/request/auth_strategy_test.go @@ -4,12 +4,14 @@ import ( "net/http" "testing" + "github.com/hashicorp/go-retryablehttp" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNoopAuthStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} auth := noopAuthStrategy{} auth.apply(&req) @@ -18,7 +20,7 @@ func TestNoopAuthStrategy(t *testing.T) { } func TestBasicAuthStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} auth := basicAuthStrategy{ user: "test-user", password: "test-pass", @@ -34,7 +36,7 @@ func TestBasicAuthStrategy(t *testing.T) { } func TestApiKeyInHeaderStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} auth := apiKeyStrategy{ in: "header", name: "my-api-key-name", @@ -50,7 +52,7 @@ func TestApiKeyInHeaderStrategy(t *testing.T) { } func TestApiKeyInCookieStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} auth := apiKeyStrategy{ in: "cookie", name: "my-api-key-name", diff --git a/request/builder.go b/request/builder.go index 69d39bf9d43..fbd1782f293 100644 --- a/request/builder.go +++ b/request/builder.go @@ -23,7 +23,7 @@ const ( ) type Builder struct { - r *http.Request + r *retryablehttp.Request log *logrusx.Logger conf *Config fetchClient *retryablehttp.Client @@ -35,7 +35,7 @@ func NewBuilder(config json.RawMessage, client *retryablehttp.Client, l *logrusx return nil, err } - r, err := http.NewRequest(c.Method, c.URL, nil) + r, err := retryablehttp.NewRequest(c.Method, c.URL, nil) if err != nil { return nil, err } @@ -153,7 +153,7 @@ func (b *Builder) addURLEncodedBody(template *bytes.Buffer, body interface{}) er return nil } -func (b *Builder) BuildRequest(body interface{}) (*http.Request, error) { +func (b *Builder) BuildRequest(body interface{}) (*retryablehttp.Request, error) { b.r.Header = b.conf.Header if err := b.addAuth(); err != nil { return nil, err diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index 53f0466c442..bb427ee7465 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -6,8 +6,6 @@ import ( "fmt" "net/http" - "github.com/hashicorp/go-retryablehttp" - "github.com/ory/kratos/identity" "github.com/ory/kratos/request" "github.com/ory/kratos/selfservice/flow" @@ -127,13 +125,7 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error { return err } - httpClient := e.deps.HTTPClient(ctx) - r, err := retryablehttp.FromRequest(req) - if err != nil { - return err - } - - resp, err := httpClient.Do(r) + resp, err := e.deps.HTTPClient(ctx).Do(req) if err != nil { return err } From fd66d6ca97a8339cb23a2ec9f277a7837ac04905 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 23 Feb 2022 13:55:02 +0100 Subject: [PATCH 09/10] test: add tests for new sms options --- courier/sms.go | 13 +++++---- courier/smtp.go | 4 ++- .../TestCourierSMS-case=configs_set.json | 15 ++++++++++ .../TestCourierSMS-case=defaults.json | 1 + driver/config/config_test.go | 22 +++++++++++++++ driver/config/stub/.kratos.courier.sms.yaml | 28 +++++++++++++++++++ 6 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 driver/config/.snapshots/TestCourierSMS-case=configs_set.json create mode 100644 driver/config/.snapshots/TestCourierSMS-case=defaults.json create mode 100644 driver/config/stub/.kratos.courier.sms.yaml diff --git a/courier/sms.go b/courier/sms.go index 3fbb8d78dba..7872e219e7c 100644 --- a/courier/sms.go +++ b/courier/sms.go @@ -3,9 +3,12 @@ package courier import ( "context" "encoding/json" - "errors" "net/http" + "github.com/pkg/errors" + + "github.com/ory/herodot" + "github.com/gofrs/uuid" "github.com/ory/kratos/request" @@ -25,10 +28,6 @@ type smsClient struct { } func newSMS(ctx context.Context, deps Dependencies) *smsClient { - if !deps.CourierConfig(ctx).CourierSMSEnabled() { - deps.Logger().Error("messages will not be sent - no sms gate server address is set in config") - } - return &smsClient{ RequestConfig: deps.CourierConfig(ctx).CourierSMSRequestConfig(), @@ -68,6 +67,10 @@ func (c *courier) QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error } func (c *courier) dispatchSMS(ctx context.Context, msg Message) error { + if !c.deps.CourierConfig(ctx).CourierSMSEnabled() { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an sms but courier.sms.enabled is set to false!")) + } + tmpl, err := c.smsClient.NewTemplateFromMessage(c.deps, msg) if err != nil { return err diff --git a/courier/smtp.go b/courier/smtp.go index 6eba348764a..8d5e19f83be 100644 --- a/courier/smtp.go +++ b/courier/smtp.go @@ -8,6 +8,8 @@ import ( "strconv" "time" + "github.com/ory/kratos/driver/config" + "github.com/gofrs/uuid" "github.com/pkg/errors" @@ -117,7 +119,7 @@ func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e func (c *courier) dispatchEmail(ctx context.Context, msg Message) error { if c.smtpClient.Host == "" { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!")) + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but %s is not set!", config.ViperKeyCourierSMTPURL)) } from := c.deps.CourierConfig(ctx).CourierSMTPFrom() diff --git a/driver/config/.snapshots/TestCourierSMS-case=configs_set.json b/driver/config/.snapshots/TestCourierSMS-case=configs_set.json new file mode 100644 index 00000000000..471e0019a44 --- /dev/null +++ b/driver/config/.snapshots/TestCourierSMS-case=configs_set.json @@ -0,0 +1,15 @@ +{ + "auth": { + "config": { + "password": "YourPass", + "user": "YourUsername" + }, + "type": "basic_auth" + }, + "body": "base64://e30=", + "header": { + "Content-Type": "application/x-www-form-urlencoded" + }, + "method": "POST", + "url": "https://api.twilio.com/2010-04-01/Accounts/YourAccountID/Messages.json" +} diff --git a/driver/config/.snapshots/TestCourierSMS-case=defaults.json b/driver/config/.snapshots/TestCourierSMS-case=defaults.json new file mode 100644 index 00000000000..19765bd501b --- /dev/null +++ b/driver/config/.snapshots/TestCourierSMS-case=defaults.json @@ -0,0 +1 @@ +null diff --git a/driver/config/config_test.go b/driver/config/config_test.go index fb674635fa1..80851755729 100644 --- a/driver/config/config_test.go +++ b/driver/config/config_test.go @@ -16,6 +16,8 @@ import ( "testing" "time" + "github.com/ory/x/snapshotx" + "github.com/ghodss/yaml" "github.com/spf13/cobra" @@ -1046,6 +1048,26 @@ func TestChangeMinPasswordLength(t *testing.T) { }) } +func TestCourierSMS(t *testing.T) { + ctx := context.Background() + + t.Run("case=configs set", func(t *testing.T) { + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, + configx.WithConfigFiles("stub/.kratos.courier.sms.yaml"), configx.SkipValidation()) + assert.True(t, conf.CourierSMSEnabled()) + snapshotx.SnapshotTExcept(t, conf.CourierSMSRequestConfig(), nil) + assert.Equal(t, "+49123456789", conf.CourierSMSFrom()) + }) + + t.Run("case=defaults", func(t *testing.T) { + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + + assert.False(t, conf.CourierSMSEnabled()) + snapshotx.SnapshotTExcept(t, conf.CourierSMSRequestConfig(), nil) + assert.Equal(t, "Ory Kratos", conf.CourierSMSFrom()) + }) +} + func TestCourierTemplatesConfig(t *testing.T) { ctx := context.Background() diff --git a/driver/config/stub/.kratos.courier.sms.yaml b/driver/config/stub/.kratos.courier.sms.yaml new file mode 100644 index 00000000000..1c2fbae89c3 --- /dev/null +++ b/driver/config/stub/.kratos.courier.sms.yaml @@ -0,0 +1,28 @@ +dsn: sqlite://foo.db?mode=memory&_fk=true + +selfservice: + default_browser_return_url: http://return-to-3-test.ory.sh/ + +identity: + default_schema_id: default + schemas: + - id: default + url: base64://ewogICIkaWQiOiAib3J5Oi8vaWRlbnRpdHktdGVzdC1zY2hlbWEiLAogICIkc2NoZW1hIjogImh0dHA6Ly9qc29uLXNjaGVtYS5vcmcvZHJhZnQtMDcvc2NoZW1hIyIsCiAgInRpdGxlIjogIklkZW50aXR5U2NoZW1hIiwKICAidHlwZSI6ICJvYmplY3QiLAogICJwcm9wZXJ0aWVzIjogewogICAgInRyYWl0cyI6IHsKICAgICAgInR5cGUiOiAib2JqZWN0IiwKICAgICAgInByb3BlcnRpZXMiOiB7CiAgICAgICAgIm5hbWUiOiB7CiAgICAgICAgICAidHlwZSI6ICJvYmplY3QiLAogICAgICAgICAgInByb3BlcnRpZXMiOiB7CiAgICAgICAgICAgICJmaXJzdCI6IHsKICAgICAgICAgICAgICAidHlwZSI6ICJzdHJpbmciCiAgICAgICAgICAgIH0sCiAgICAgICAgICAgICJsYXN0IjogewogICAgICAgICAgICAgICJ0eXBlIjogInN0cmluZyIKICAgICAgICAgICAgfQogICAgICAgICAgfQogICAgICAgIH0KICAgICAgfSwKICAgICAgInJlcXVpcmVkIjogWwogICAgICAgICJuYW1lIgogICAgICBdLAogICAgICAiYWRkaXRpb25hbFByb3BlcnRpZXMiOiB0cnVlCiAgICB9CiAgfQp9 + +courier: + smtp: + connection_uri: smtp://foo:bar@baz/ + sms: + enabled: true + from: '+49123456789' + request_config: + url: https://api.twilio.com/2010-04-01/Accounts/YourAccountID/Messages.json + method: POST + body: base64://e30= + header: + 'Content-Type': 'application/x-www-form-urlencoded' + auth: + type: basic_auth + config: + user: YourUsername + password: YourPass From f62f33c2f1a2298046f1a5d8a82cd91fe3238be9 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 23 Feb 2022 14:00:54 +0100 Subject: [PATCH 10/10] fix: lower-case jsonnet context for sms --- courier/sms.go | 6 +++--- courier/stub/request.config.twilio.jsonnet | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/courier/sms.go b/courier/sms.go index 7872e219e7c..1436a6062ab 100644 --- a/courier/sms.go +++ b/courier/sms.go @@ -15,9 +15,9 @@ import ( ) type sendSMSRequestBody struct { - To string - From string - Body string + From string `json:"from"` + To string `json:"to"` + Body string `json:"body"` } type smsClient struct { diff --git a/courier/stub/request.config.twilio.jsonnet b/courier/stub/request.config.twilio.jsonnet index 93752e14503..da0736b06df 100644 --- a/courier/stub/request.config.twilio.jsonnet +++ b/courier/stub/request.config.twilio.jsonnet @@ -1,5 +1,5 @@ function(ctx) { - from: ctx.From, - to: ctx.To, - body: ctx.Body + from: ctx.from, + to: ctx.to, + body: ctx.body }