Skip to content

Commit

Permalink
feat: load templates with uri
Browse files Browse the repository at this point in the history
  • Loading branch information
Benehiko committed Jan 24, 2022
1 parent f3d2175 commit f32ef51
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 49 deletions.
5 changes: 5 additions & 0 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/ory/kratos/driver/config"
"net/url"
"strconv"
"time"
Expand All @@ -27,6 +28,10 @@ type (
CourierSMTPFromName() string
CourierSMTPHeaders() map[string]string
CourierTemplatesRoot() string
CourierTemplatesVerificationInvalid() *config.CourierEmailTemplate
CourierTemplatesVerificationValid() *config.CourierEmailTemplate
CourierTemplatesRecoveryInvalid() *config.CourierEmailTemplate
CourierTemplatesRecoveryValid() *config.CourierEmailTemplate
}
SMTPDependencies interface {
PersistenceProvider
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{{define "en_US"}}
{{ $l := cat "lang=" .lang }}
{{ nospace $l }}
{{end}}

{{define "base"}}
{{- if eq .lang "en_US" -}}
{{ template "en_US" . }}
{{- end -}}
{{end}}
80 changes: 71 additions & 9 deletions courier/template/load_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package template
import (
"bytes"
"embed"
"github.com/ory/x/fetcher"
htemplate "html/template"
"io"
"io/fs"
Expand All @@ -23,12 +24,31 @@ type Template interface {
Execute(wr io.Writer, data interface{}) error
}

func loadBuiltInTemplate(filesytem fs.FS, name string, html bool) (Template, error) {
type options struct {
root string
remoteURL string
}

type LoadTemplateOption func(*options)

func WithRemoteResource(url string) LoadTemplateOption {
return func(o *options) {
o.remoteURL = url
}
}

func WithRemoteResourceRoot(root string) LoadTemplateOption {
return func(o *options) {
o.root = root
}
}

func loadBuiltInTemplate(filesystem fs.FS, name string, html bool) (Template, error) {
if t, found := cache.Get(name); found {
return t.(Template), nil
}

file, err := filesytem.Open(name)
file, err := filesystem.Open(name)
if err != nil {
// try to fallback to bundled templates
var fallbackErr error
Expand Down Expand Up @@ -65,6 +85,32 @@ func loadBuiltInTemplate(filesytem fs.FS, name string, html bool) (Template, err
return tpl, nil
}

func loadRemoteTemplate(url string, name string, html bool, root string) (Template, error) {
if t, found := cache.Get(name); found {
return t.(Template), nil
}

f := fetcher.NewFetcher()
bb, err := f.Fetch(url)
if err != nil {
return nil, errors.WithStack(err)
}

var t Template
if html {
t, err = htemplate.New(root).Funcs(sprig.HtmlFuncMap()).Parse(bb.String())
if err != nil {
return nil, errors.WithStack(err)
}
} else {
t, err = template.New(root).Funcs(sprig.TxtFuncMap()).Parse(bb.String())
if err != nil {
return nil, errors.WithStack(err)
}
}
return t, nil
}

func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template, error) {
if t, found := cache.Get(name); found {
return t.(Template), nil
Expand Down Expand Up @@ -106,9 +152,17 @@ func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template,
return tpl, nil
}

func LoadTextTemplate(filesystem fs.FS, name, pattern string, model interface{}) (string, error) {
t, err := loadTemplate(filesystem, name, pattern, false)
func LoadTextTemplate(filesystem fs.FS, name, pattern string, model interface{}, remoteURL, remoteTemplateRoot string) (string, error) {
var t Template
var err error
if remoteURL != "" {
t, err = loadRemoteTemplate(remoteURL, name, false, remoteTemplateRoot)
if err != nil {
return "", err
}
}

t, err = loadTemplate(filesystem, name, pattern, false)
if err != nil {
return "", err
}
Expand All @@ -120,11 +174,19 @@ func LoadTextTemplate(filesystem fs.FS, name, pattern string, model interface{})
return b.String(), nil
}

func LoadHTMLTemplate(filesystem fs.FS, name, pattern string, model interface{}) (string, error) {
t, err := loadTemplate(filesystem, name, pattern, true)

if err != nil {
return "", err
func LoadHTMLTemplate(filesystem fs.FS, name, pattern string, model interface{}, remoteURL, remoteTemplateRoot string) (string, error) {
var t Template
var err error
if remoteURL != "" {
t, err = loadRemoteTemplate(remoteURL, name, true, remoteTemplateRoot)
if err != nil {
return "", err
}
} else {
t, err = loadTemplate(filesystem, name, pattern, true)
if err != nil {
return "", err
}
}

var b bytes.Buffer
Expand Down
24 changes: 22 additions & 2 deletions courier/template/load_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (

func TestLoadTextTemplate(t *testing.T) {
var executeTextTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string {
tp, err := LoadTextTemplate(os.DirFS(dir), name, pattern, model)
tp, err := LoadTextTemplate(os.DirFS(dir), name, pattern, model, "", "")
require.NoError(t, err)
return tp
}

var executeHTMLTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string {
tp, err := LoadHTMLTemplate(os.DirFS(dir), name, pattern, model)
tp, err := LoadHTMLTemplate(os.DirFS(dir), name, pattern, model, "", "")
require.NoError(t, err)
return tp
}
Expand Down Expand Up @@ -61,4 +61,24 @@ func TestLoadTextTemplate(t *testing.T) {
require.NoError(t, os.RemoveAll(fp))
assert.Contains(t, executeTextTemplate(t, dir, name, "", nil), "cached stub body")
})

t.Run("method=remote resource", func(t *testing.T) {
t.Run("case=base64 encoded data", func(t *testing.T) {

})

t.Run("case=file resource", func(t *testing.T) {
m := map[string]interface{}{"lang": "en_US"}
tp, err := LoadHTMLTemplate(nil, "", "", m,
"file://courier/builtin/templates/test_stub/email.body.html.nested.gotmpl",
"base",
)
require.NoError(t, err)
assert.Contains(t, tp, "lang=en_US")
})

t.Run("case=http resource", func(t *testing.T) {

})
})
}
15 changes: 12 additions & 3 deletions courier/template/recovery_invalid.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,24 @@ func (t *RecoveryInvalid) EmailRecipient() (string, error) {
}

func (t *RecoveryInvalid) EmailSubject() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m,
t.c.CourierTemplatesRecoveryInvalid().Subject,
t.c.CourierTemplatesRecoveryInvalid().TemplateRoot,
)
}

func (t *RecoveryInvalid) EmailBody() (string, error) {
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m)
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m,
t.c.CourierTemplatesRecoveryInvalid().Body.HTML,
t.c.CourierTemplatesRecoveryInvalid().TemplateRoot,
)
}

func (t *RecoveryInvalid) EmailBodyPlaintext() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m,
t.c.CourierTemplatesRecoveryInvalid().Body.PlainText,
t.c.CourierTemplatesRecoveryInvalid().TemplateRoot,
)
}

func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) {
Expand Down
15 changes: 12 additions & 3 deletions courier/template/recovery_valid.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,24 @@ func (t *RecoveryValid) EmailRecipient() (string, error) {
}

func (t *RecoveryValid) EmailSubject() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m,
t.c.CourierTemplatesRecoveryValid().Subject,
t.c.CourierTemplatesRecoveryValid().TemplateRoot,
)
}

func (t *RecoveryValid) EmailBody() (string, error) {
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m)
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m,
t.c.CourierTemplatesRecoveryValid().Body.HTML,
t.c.CourierTemplatesRecoveryValid().TemplateRoot,
)
}

func (t *RecoveryValid) EmailBodyPlaintext() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m,
t.c.CourierTemplatesRecoveryValid().Body.PlainText,
t.c.CourierTemplatesRecoveryValid().TemplateRoot,
)
}

func (t *RecoveryValid) MarshalJSON() ([]byte, error) {
Expand Down
6 changes: 3 additions & 3 deletions courier/template/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ func (t *TestStub) EmailRecipient() (string, error) {
}

func (t *TestStub) EmailSubject() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m, "", "")
}

func (t *TestStub) EmailBody() (string, error) {
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m)
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m, "", "")
}

func (t *TestStub) EmailBodyPlaintext() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m, "", "")
}

func (t *TestStub) MarshalJSON() ([]byte, error) {
Expand Down
6 changes: 6 additions & 0 deletions courier/template/template.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package template

import "github.com/ory/kratos/driver/config"

type (
TemplateConfig interface {
CourierTemplatesRoot() string
CourierTemplatesVerificationInvalid() *config.CourierEmailTemplate
CourierTemplatesVerificationValid() *config.CourierEmailTemplate
CourierTemplatesRecoveryInvalid() *config.CourierEmailTemplate
CourierTemplatesRecoveryValid() *config.CourierEmailTemplate
}
)
15 changes: 12 additions & 3 deletions courier/template/verification_invalid.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,24 @@ func (t *VerificationInvalid) EmailRecipient() (string, error) {
}

func (t *VerificationInvalid) EmailSubject() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m,
t.c.CourierTemplatesVerificationInvalid().Subject,
t.c.CourierTemplatesVerificationInvalid().TemplateRoot,
)
}

func (t *VerificationInvalid) EmailBody() (string, error) {
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m)
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m,
t.c.CourierTemplatesVerificationInvalid().Body.HTML,
t.c.CourierTemplatesVerificationInvalid().TemplateRoot,
)
}

func (t *VerificationInvalid) EmailBodyPlaintext() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m,
t.c.CourierTemplatesVerificationInvalid().Body.PlainText,
t.c.CourierTemplatesVerificationInvalid().TemplateRoot,
)
}

func (t *VerificationInvalid) MarshalJSON() ([]byte, error) {
Expand Down
15 changes: 12 additions & 3 deletions courier/template/verification_valid.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,24 @@ func (t *VerificationValid) EmailRecipient() (string, error) {
}

func (t *VerificationValid) EmailSubject() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m,
t.c.CourierTemplatesVerificationValid().Subject,
t.c.CourierTemplatesVerificationValid().TemplateRoot,
)
}

func (t *VerificationValid) EmailBody() (string, error) {
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m)
return LoadHTMLTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m,
t.c.CourierTemplatesVerificationValid().Body.HTML,
t.c.CourierTemplatesVerificationValid().TemplateRoot,
)
}

func (t *VerificationValid) EmailBodyPlaintext() (string, error) {
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m)
return LoadTextTemplate(os.DirFS(t.c.CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m,
t.c.CourierTemplatesVerificationValid().Body.PlainText,
t.c.CourierTemplatesVerificationValid().TemplateRoot,
)
}

func (t *VerificationValid) MarshalJSON() ([]byte, error) {
Expand Down
Loading

0 comments on commit f32ef51

Please sign in to comment.