diff --git a/courier/courier.go b/courier/courier.go index 3e800408e560..b2adaab5d687 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "github.com/ory/kratos/driver/config" "net/url" "strconv" "time" @@ -27,6 +28,10 @@ type ( CourierSMTPFromName() string CourierSMTPHeaders() map[string]string CourierTemplatesRoot() string + CourierTemplatesVerificationInvalid() *config.CourierEmailTemplate + CourierTemplatesVerificationValid() *config.CourierEmailTemplate + CourierTemplatesRecoveryInvalid() *config.CourierEmailTemplate + CourierTemplatesRecoveryValid() *config.CourierEmailTemplate } SMTPDependencies interface { PersistenceProvider diff --git a/courier/template/courier/builtin/templates/test_stub/email.body.html.nested.gotmpl b/courier/template/courier/builtin/templates/test_stub/email.body.html.nested.gotmpl new file mode 100644 index 000000000000..b05ea541b689 --- /dev/null +++ b/courier/template/courier/builtin/templates/test_stub/email.body.html.nested.gotmpl @@ -0,0 +1,10 @@ +{{define "en_US"}} +{{ $l := cat "lang=" .lang }} +{{ nospace $l }} +{{end}} + +{{define "base"}} +{{- if eq .lang "en_US" -}} +{{ template "en_US" . }} +{{- end -}} +{{end}} diff --git a/courier/template/load_template.go b/courier/template/load_template.go index cf24c4c140a9..8bd0831a917b 100644 --- a/courier/template/load_template.go +++ b/courier/template/load_template.go @@ -3,6 +3,7 @@ package template import ( "bytes" "embed" + "github.com/ory/x/fetcher" htemplate "html/template" "io" "io/fs" @@ -23,12 +24,31 @@ type Template interface { Execute(wr io.Writer, data interface{}) error } -func loadBuiltInTemplate(filesytem fs.FS, name string, html bool) (Template, error) { +type options struct { + root string + remoteURL string +} + +type LoadTemplateOption func(*options) + +func WithRemoteResource(url string) LoadTemplateOption { + return func(o *options) { + o.remoteURL = url + } +} + +func WithRemoteResourceRoot(root string) LoadTemplateOption { + return func(o *options) { + o.root = root + } +} + +func loadBuiltInTemplate(filesystem fs.FS, name string, html bool) (Template, error) { if t, found := cache.Get(name); found { return t.(Template), nil } - file, err := filesytem.Open(name) + file, err := filesystem.Open(name) if err != nil { // try to fallback to bundled templates var fallbackErr error @@ -65,6 +85,32 @@ func loadBuiltInTemplate(filesytem fs.FS, name string, html bool) (Template, err return tpl, nil } +func loadRemoteTemplate(url string, name string, html bool, root string) (Template, error) { + if t, found := cache.Get(name); found { + return t.(Template), nil + } + + f := fetcher.NewFetcher() + bb, err := f.Fetch(url) + if err != nil { + return nil, errors.WithStack(err) + } + + var t Template + if html { + t, err = htemplate.New(root).Funcs(sprig.HtmlFuncMap()).Parse(bb.String()) + if err != nil { + return nil, errors.WithStack(err) + } + } else { + t, err = template.New(root).Funcs(sprig.TxtFuncMap()).Parse(bb.String()) + if err != nil { + return nil, errors.WithStack(err) + } + } + return t, nil +} + func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template, error) { if t, found := cache.Get(name); found { return t.(Template), nil @@ -106,9 +152,17 @@ func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template, return tpl, nil } -func LoadTextTemplate(filesystem fs.FS, name, pattern string, model interface{}) (string, error) { - t, err := loadTemplate(filesystem, name, pattern, false) +func LoadTextTemplate(filesystem fs.FS, name, pattern string, model interface{}, remoteURL, remoteTemplateRoot string) (string, error) { + var t Template + var err error + if remoteURL != "" { + t, err = loadRemoteTemplate(remoteURL, name, false, remoteTemplateRoot) + if err != nil { + return "", err + } + } + t, err = loadTemplate(filesystem, name, pattern, false) if err != nil { return "", err } @@ -120,11 +174,19 @@ func LoadTextTemplate(filesystem fs.FS, name, pattern string, model interface{}) return b.String(), nil } -func LoadHTMLTemplate(filesystem fs.FS, name, pattern string, model interface{}) (string, error) { - t, err := loadTemplate(filesystem, name, pattern, true) - - if err != nil { - return "", err +func LoadHTMLTemplate(filesystem fs.FS, name, pattern string, model interface{}, remoteURL, remoteTemplateRoot string) (string, error) { + var t Template + var err error + if remoteURL != "" { + t, err = loadRemoteTemplate(remoteURL, name, true, remoteTemplateRoot) + if err != nil { + return "", err + } + } else { + t, err = loadTemplate(filesystem, name, pattern, true) + if err != nil { + return "", err + } } var b bytes.Buffer diff --git a/courier/template/load_template_test.go b/courier/template/load_template_test.go index d618b8cb861d..09d8364024f3 100644 --- a/courier/template/load_template_test.go +++ b/courier/template/load_template_test.go @@ -14,13 +14,13 @@ import ( func TestLoadTextTemplate(t *testing.T) { var executeTextTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { - tp, err := LoadTextTemplate(os.DirFS(dir), name, pattern, model) + tp, err := LoadTextTemplate(os.DirFS(dir), name, pattern, model, "", "") require.NoError(t, err) return tp } var executeHTMLTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { - tp, err := LoadHTMLTemplate(os.DirFS(dir), name, pattern, model) + tp, err := LoadHTMLTemplate(os.DirFS(dir), name, pattern, model, "", "") require.NoError(t, err) return tp } @@ -61,4 +61,24 @@ func TestLoadTextTemplate(t *testing.T) { require.NoError(t, os.RemoveAll(fp)) assert.Contains(t, executeTextTemplate(t, dir, name, "", nil), "cached stub body") }) + + t.Run("method=remote resource", func(t *testing.T) { + t.Run("case=base64 encoded data", func(t *testing.T) { + + }) + + t.Run("case=file resource", func(t *testing.T) { + m := map[string]interface{}{"lang": "en_US"} + tp, err := LoadHTMLTemplate(nil, "", "", m, + "file://courier/builtin/templates/test_stub/email.body.html.nested.gotmpl", + "base", + ) + require.NoError(t, err) + assert.Contains(t, tp, "lang=en_US") + }) + + t.Run("case=http resource", func(t *testing.T) { + + }) + }) } diff --git a/courier/template/recovery_invalid.go b/courier/template/recovery_invalid.go index d442132c652c..5feeda19c08c 100644 --- a/courier/template/recovery_invalid.go +++ b/courier/template/recovery_invalid.go @@ -24,15 +24,24 @@ func (t *RecoveryInvalid) EmailRecipient() (string, error) { } func (t *RecoveryInvalid) EmailSubject() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m, + t.c.CourierTemplatesRecoveryInvalid().Subject, + t.c.CourierTemplatesRecoveryInvalid().TemplateRoot, + ) } func (t *RecoveryInvalid) EmailBody() (string, error) { - return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m) + return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m, + t.c.CourierTemplatesRecoveryInvalid().Body.HTML, + t.c.CourierTemplatesRecoveryInvalid().TemplateRoot, + ) } func (t *RecoveryInvalid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m, + t.c.CourierTemplatesRecoveryInvalid().Body.PlainText, + t.c.CourierTemplatesRecoveryInvalid().TemplateRoot, + ) } func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) { diff --git a/courier/template/recovery_valid.go b/courier/template/recovery_valid.go index 5a164bdc406b..b5ddf67bf378 100644 --- a/courier/template/recovery_valid.go +++ b/courier/template/recovery_valid.go @@ -26,15 +26,24 @@ func (t *RecoveryValid) EmailRecipient() (string, error) { } func (t *RecoveryValid) EmailSubject() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m, + t.c.CourierTemplatesRecoveryValid().Subject, + t.c.CourierTemplatesRecoveryValid().TemplateRoot, + ) } func (t *RecoveryValid) EmailBody() (string, error) { - return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m) + return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m, + t.c.CourierTemplatesRecoveryValid().Body.HTML, + t.c.CourierTemplatesRecoveryValid().TemplateRoot, + ) } func (t *RecoveryValid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m, + t.c.CourierTemplatesRecoveryValid().Body.PlainText, + t.c.CourierTemplatesRecoveryValid().TemplateRoot, + ) } func (t *RecoveryValid) MarshalJSON() ([]byte, error) { diff --git a/courier/template/stub.go b/courier/template/stub.go index 64481ee23f51..3101fa0c4ae3 100644 --- a/courier/template/stub.go +++ b/courier/template/stub.go @@ -25,15 +25,15 @@ func (t *TestStub) EmailRecipient() (string, error) { } func (t *TestStub) EmailSubject() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m, "", "") } func (t *TestStub) EmailBody() (string, error) { - return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m) + return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m, "", "") } func (t *TestStub) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m, "", "") } func (t *TestStub) MarshalJSON() ([]byte, error) { diff --git a/courier/template/template.go b/courier/template/template.go index 0486356fb0e8..2531e8f47b7c 100644 --- a/courier/template/template.go +++ b/courier/template/template.go @@ -1,7 +1,13 @@ package template +import "github.com/ory/kratos/driver/config" + type ( TemplateConfig interface { CourierTemplatesRoot() string + CourierTemplatesVerificationInvalid() *config.CourierEmailTemplate + CourierTemplatesVerificationValid() *config.CourierEmailTemplate + CourierTemplatesRecoveryInvalid() *config.CourierEmailTemplate + CourierTemplatesRecoveryValid() *config.CourierEmailTemplate } ) diff --git a/courier/template/verification_invalid.go b/courier/template/verification_invalid.go index fa457ee1ff07..d2c2bc005d58 100644 --- a/courier/template/verification_invalid.go +++ b/courier/template/verification_invalid.go @@ -24,15 +24,24 @@ func (t *VerificationInvalid) EmailRecipient() (string, error) { } func (t *VerificationInvalid) EmailSubject() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m, + t.c.CourierTemplatesVerificationInvalid().Subject, + t.c.CourierTemplatesVerificationInvalid().TemplateRoot, + ) } func (t *VerificationInvalid) EmailBody() (string, error) { - return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m) + return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m, + t.c.CourierTemplatesVerificationInvalid().Body.HTML, + t.c.CourierTemplatesVerificationInvalid().TemplateRoot, + ) } func (t *VerificationInvalid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m, + t.c.CourierTemplatesVerificationInvalid().Body.PlainText, + t.c.CourierTemplatesVerificationInvalid().TemplateRoot, + ) } func (t *VerificationInvalid) MarshalJSON() ([]byte, error) { diff --git a/courier/template/verification_valid.go b/courier/template/verification_valid.go index 9973bd681689..1a9af18ff9d4 100644 --- a/courier/template/verification_valid.go +++ b/courier/template/verification_valid.go @@ -26,15 +26,24 @@ func (t *VerificationValid) EmailRecipient() (string, error) { } func (t *VerificationValid) EmailSubject() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m, + t.c.CourierTemplatesVerificationValid().Subject, + t.c.CourierTemplatesVerificationValid().TemplateRoot, + ) } func (t *VerificationValid) EmailBody() (string, error) { - return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m) + return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m, + t.c.CourierTemplatesVerificationValid().Body.HTML, + t.c.CourierTemplatesVerificationValid().TemplateRoot, + ) } func (t *VerificationValid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m) + return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m, + t.c.CourierTemplatesVerificationValid().Body.PlainText, + t.c.CourierTemplatesVerificationValid().TemplateRoot, + ) } func (t *VerificationValid) MarshalJSON() ([]byte, error) { diff --git a/driver/config/config.go b/driver/config/config.go index 70e9b256e4fd..806b9e9c9819 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -59,8 +59,10 @@ const ( ViperKeyDSN = "dsn" ViperKeyCourierSMTPURL = "courier.smtp.connection_uri" ViperKeyCourierTemplatesPath = "courier.template_override_path" - ViperKeyCourierTemplatesRecovery = "courier.templates.recovery" - ViperKeyCourierTemplatesVerification = "courier.templates.verification" + ViperKeyCourierTemplatesRecoveryInvalid = "courier.templates.recovery.invalid" + ViperKeyCourierTemplatesRecoveryValid = "courier.templates.recovery.valid" + ViperKeyCourierTemplatesVerificationInvalid = "courier.templates.verification.invalid" + ViperKeyCourierTemplatesVerificationValid = "courier.templates.verification.valid" ViperKeyCourierSMTPFrom = "courier.smtp.from_address" ViperKeyCourierSMTPFromName = "courier.smtp.from_name" ViperKeyCourierSMTPHeaders = "courier.smtp.headers" @@ -213,22 +215,16 @@ type ( HTML string `json:"html"` } CourierEmailTemplate struct { - Body *CourierEmailBodyTemplate `json:"body"` - Subject string `json:"subject"` + TemplateRoot string `json:"template_root"` + Body *CourierEmailBodyTemplate `json:"body"` + Subject string `json:"subject"` } - CourierFlowTemplate struct { - Invalid *CourierEmailTemplate `json:"invalid"` - Valid *CourierEmailTemplate `json:"valid"` - } - CourierVerificationTemplate CourierFlowTemplate - CourierRecoveryTemplate CourierFlowTemplate - Config struct { + Config struct { l *logrusx.Logger p *configx.Provider identitySchema *jsonschema.Schema stdOutOrErr io.Writer } - Provider interface { Config(ctx context.Context) *Config } @@ -838,20 +834,52 @@ func (p *Config) CourierTemplatesRoot() string { return p.p.StringF(ViperKeyCourierTemplatesPath, "courier/builtin/templates") } -func (p *Config) CourierTemplatesVerification() (*CourierVerificationTemplate, error) { - var templates *CourierVerificationTemplate - if err := p.p.Unmarshal(ViperKeyCourierTemplatesVerification, &templates); err != nil { - return nil, err +func (p *Config) courierTemplatesHelper(key string) *CourierEmailTemplate { + courierTemplate := &CourierEmailTemplate{ + TemplateRoot: "", + Body: &CourierEmailBodyTemplate{ + PlainText: "", + HTML: "", + }, + Subject: "", } - return templates, nil -} -func (p *Config) CourierTemplatesRecovery() (*CourierRecoveryTemplate, error) { - var templates *CourierRecoveryTemplate - if err := p.p.Unmarshal(ViperKeyCourierTemplatesRecovery, &templates); err != nil { - return nil, err + if !p.p.Exists(key) { + return courierTemplate } - return templates, nil + + out, err := p.p.Marshal(kjson.Parser()) + if err != nil { + p.l.WithError(err).Fatalf("Unable to dencode values from %s.", key) + return courierTemplate + } + + config := gjson.GetBytes(out, key).Raw + if len(config) == 0 { + return courierTemplate + } + + if err := json.NewDecoder(bytes.NewBufferString(config)).Decode(&courierTemplate); err != nil { + p.l.WithError(err).Fatalf("Unable to encode values from %s.", key) + return courierTemplate + } + return courierTemplate +} + +func (p *Config) CourierTemplatesVerificationInvalid() *CourierEmailTemplate { + return p.courierTemplatesHelper(ViperKeyCourierTemplatesVerificationInvalid) +} + +func (p *Config) CourierTemplatesVerificationValid() *CourierEmailTemplate { + return p.courierTemplatesHelper(ViperKeyCourierTemplatesVerificationValid) +} + +func (p *Config) CourierTemplatesRecoveryInvalid() *CourierEmailTemplate { + return p.courierTemplatesHelper(ViperKeyCourierTemplatesRecoveryInvalid) +} + +func (p *Config) CourierTemplatesRecoveryValid() *CourierEmailTemplate { + return p.courierTemplatesHelper(ViperKeyCourierTemplatesRecoveryValid) } func (p *Config) CourierSMTPHeaders() map[string]string { diff --git a/embedx/config.schema.json b/embedx/config.schema.json index dd862c23ccc9..529469cedfa4 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -2277,6 +2277,10 @@ "email": { "type": "object", "properties": { + "template_root": { + "type": "string", + "description": "The entry point for the template when using nested templates. This is optional as the template does not need to define an entry point." + }, "body": { "type": "object", "properties": {