Skip to content

Commit

Permalink
fix: template tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Alano Terblanche committed Jan 31, 2022
1 parent 84b0bb3 commit 658c2a3
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 638 deletions.
14 changes: 7 additions & 7 deletions courier/template/load_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
//go:embed courier/builtin/templates/*
var templates embed.FS

var cache, _ = lru.New(16)
var Cache, _ = lru.New(16)

type Template interface {
Execute(wr io.Writer, data interface{}) error
Expand All @@ -31,7 +31,7 @@ type templateDependencies interface {
}

func loadBuiltInTemplate(filesystem fs.FS, name string, html bool) (Template, error) {
if t, found := cache.Get(name); found {
if t, found := Cache.Get(name); found {
return t.(Template), nil
}

Expand Down Expand Up @@ -68,16 +68,16 @@ func loadBuiltInTemplate(filesystem fs.FS, name string, html bool) (Template, er
tpl = t
}

_ = cache.Add(name, tpl)
_ = Cache.Add(name, tpl)
return tpl, nil
}

func loadRemoteTemplate(ctx context.Context, d templateDependencies, url string, name string, html bool, root string) (Template, error) {
if t, found := cache.Get(name); found {
if t, found := Cache.Get(name); found {
return t.(Template), nil
}

f := fetcher.NewFetcher(fetcher.WithClient(d.HTTPClient(ctx).StandardClient()))
f := fetcher.NewFetcher(fetcher.WithClient(d.HTTPClient(ctx).HTTPClient))

bb, err := f.Fetch(url)
if err != nil {
Expand All @@ -100,7 +100,7 @@ func loadRemoteTemplate(ctx context.Context, d templateDependencies, url string,
}

func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template, error) {
if t, found := cache.Get(name); found {
if t, found := Cache.Get(name); found {
return t.(Template), nil
}

Expand Down Expand Up @@ -136,7 +136,7 @@ func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template,
tpl = t
}

_ = cache.Add(name, tpl)
_ = Cache.Add(name, tpl)
return tpl, nil
}

Expand Down
76 changes: 63 additions & 13 deletions courier/template/load_template_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package template
package template_test

import (
"context"
"encoding/base64"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/internal"
"github.com/ory/x/fetcher"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"

lru "github.com/hashicorp/golang-lru"
"github.com/stretchr/testify/assert"
Expand All @@ -19,13 +25,17 @@ import (

func TestLoadTextTemplate(t *testing.T) {
var executeTextTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string {
tp, err := LoadTextTemplate(os.DirFS(dir), name, pattern, model, "", "")
ctx := context.Background()
_, reg := internal.NewFastRegistryWithMocks(t)
tp, err := template.LoadTextTemplate(ctx, reg, os.DirFS(dir), name, pattern, model, "", "")
require.NoError(t, err)
return tp
}

var executeHTMLTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string {
tp, err := LoadHTMLTemplate(os.DirFS(dir), name, pattern, model, "", "")
ctx := context.Background()
_, reg := internal.NewFastRegistryWithMocks(t)
tp, err := template.LoadHTMLTemplate(ctx, reg, os.DirFS(dir), name, pattern, model, "", "")
require.NoError(t, err)
return tp
}
Expand All @@ -36,26 +46,26 @@ func TestLoadTextTemplate(t *testing.T) {
})

t.Run("method=fallback to bundled", func(t *testing.T) {
cache, _ = lru.New(16) // prevent cache hit
template.Cache, _ = lru.New(16) // prevent Cache hit
actual := executeTextTemplate(t, "some/inexistent/dir", "test_stub/email.body.gotmpl", "", nil)
assert.Contains(t, actual, "stub email")
})

t.Run("method=with Sprig functions", func(t *testing.T) {
cache, _ = lru.New(16) // prevent cache hit
template.Cache, _ = lru.New(16) // prevent Cache hit
m := map[string]interface{}{"input": "hello world"} // create a simple model
actual := executeTextTemplate(t, "courier/builtin/templates/test_stub", "email.body.sprig.gotmpl", "", m)
assert.Contains(t, actual, "HelloWorld,HELLOWORLD")
})

t.Run("method=html with nested templates", func(t *testing.T) {
cache, _ = lru.New(16) // prevent cache hit
template.Cache, _ = lru.New(16) // prevent Cache hit
m := map[string]interface{}{"lang": "en_US"} // create a simple model
actual := executeHTMLTemplate(t, "courier/builtin/templates/test_stub", "email.body.html.gotmpl", "email.body.html*", m)
assert.Contains(t, actual, "lang=en_US")
})

t.Run("method=cache works", func(t *testing.T) {
t.Run("method=Cache works", func(t *testing.T) {
dir := os.TempDir()
name := x.NewUUID().String() + ".body.gotmpl"
fp := filepath.Join(dir, name)
Expand All @@ -68,13 +78,18 @@ func TestLoadTextTemplate(t *testing.T) {
})

t.Run("method=remote resource", func(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

t.Run("case=base64 encoded data", func(t *testing.T) {
t.Run("html template", func(t *testing.T) {
m := map[string]interface{}{"lang": "en_US"}
f, err := ioutil.ReadFile("courier/builtin/templates/test_stub/email.body.html.nested.gotmpl")
require.NoError(t, err)
b64 := base64.StdEncoding.EncodeToString(f)
tp, err := LoadHTMLTemplate(nil, "", "", m, "base64://"+b64, "base")
tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m, "base64://"+b64, "base")
require.NoError(t, err)
assert.Contains(t, tp, "lang=en_US")
})
Expand All @@ -86,7 +101,7 @@ func TestLoadTextTemplate(t *testing.T) {

b64 := base64.StdEncoding.EncodeToString(f)

tp, err := LoadTextTemplate(nil, "", "", m,
tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m,
"base64://"+b64, "base")
require.NoError(t, err)
assert.Contains(t, tp, "stub email body something")
Expand All @@ -97,7 +112,7 @@ func TestLoadTextTemplate(t *testing.T) {
t.Run("case=file resource", func(t *testing.T) {
t.Run("case=html template", func(t *testing.T) {
m := map[string]interface{}{"lang": "en_US"}
tp, err := LoadHTMLTemplate(nil, "", "", m,
tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m,
"file://courier/builtin/templates/test_stub/email.body.html.nested.gotmpl",
"base",
)
Expand All @@ -107,7 +122,7 @@ func TestLoadTextTemplate(t *testing.T) {

t.Run("case=plaintext", func(t *testing.T) {
m := map[string]interface{}{"Body": "something"}
tp, err := LoadTextTemplate(nil, "", "", m,
tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m,
"file://courier/builtin/templates/test_stub/email.body.plaintext.gotmpl",
"base")
require.NoError(t, err)
Expand All @@ -128,7 +143,7 @@ func TestLoadTextTemplate(t *testing.T) {

t.Run("case=html template", func(t *testing.T) {
m := map[string]interface{}{"lang": "en_US"}
tp, err := LoadHTMLTemplate(nil, "", "", m,
tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m,
ts.URL+"/html",
"base",
)
Expand All @@ -138,11 +153,46 @@ func TestLoadTextTemplate(t *testing.T) {

t.Run("case=plaintext", func(t *testing.T) {
m := map[string]interface{}{"Body": "something"}
tp, err := LoadTextTemplate(nil, "", "", m, ts.URL+"/plaintext", "base")
tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m, ts.URL+"/plaintext", "base")
require.NoError(t, err)
assert.Contains(t, tp, "stub email body something")
})

})

t.Run("case=unsupported resource", func(t *testing.T) {
tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", map[string]interface{}{},
"grpc://unsupported-url",
"")

require.ErrorIs(t, err, fetcher.ErrUnknownScheme)
require.Empty(t, tp)

tp, err = template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{},
"grpc://unsupported-url",
"")
require.ErrorIs(t, err, fetcher.ErrUnknownScheme)
require.Empty(t, tp)
})

t.Run("case=disallowed resources", func(t *testing.T) {
require.NoError(t, reg.Config(ctx).Source().Set(config.ViperKeyClientHTTPNoPrivateIPRanges, true))
reg.HTTPClient(ctx).RetryMax = 1
reg.HTTPClient(ctx).RetryWaitMax = time.Millisecond

_, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", map[string]interface{}{},
"http://localhost:8080/1234",
"")

require.Error(t, err)
assert.Contains(t, err.Error(), "is in the")

_, err = template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{},
"http://localhost:8080/1234",
"")
require.Error(t, err)
assert.Contains(t, err.Error(), "is in the")

})
})
}
21 changes: 12 additions & 9 deletions courier/template/recovery_invalid_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package template_test

import (
"context"
"encoding/base64"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/courier/template/testhelpers"
Expand All @@ -16,12 +17,14 @@ import (
)

func TestRecoverInvalid(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

t.Run("test=with courier templates directory", func(t *testing.T) {
conf, _ := internal.NewFastRegistryWithMocks(t)
tpl := template.NewRecoveryInvalid(conf, &template.RecoveryInvalidModel{})
_, reg := internal.NewFastRegistryWithMocks(t)
tpl := template.NewRecoveryInvalid(reg, &template.RecoveryInvalidModel{})

testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})

t.Run("test=with remote resources", func(t *testing.T) {
Expand All @@ -36,13 +39,13 @@ func TestRecoverInvalid(t *testing.T) {
ts := httptest.NewServer(router)
defer ts.Close()

tpl := template.NewRecoveryInvalid(testhelpers.SetupRemoteConfig(t,
tpl := template.NewRecoveryInvalid(testhelpers.SetupRemoteConfig(t, ctx,
ts.URL+"/email.body.plaintext.gotmpl",
ts.URL+"/email.body.gotmpl",
ts.URL+"/email.subject.gotmpl"),
&template.RecoveryInvalidModel{})

testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})

t.Run("case=base64 resource", func(t *testing.T) {
Expand All @@ -54,25 +57,25 @@ func TestRecoverInvalid(t *testing.T) {
return base64.StdEncoding.EncodeToString(f)
}

tpl := template.NewRecoveryInvalid(testhelpers.SetupRemoteConfig(t,
tpl := template.NewRecoveryInvalid(testhelpers.SetupRemoteConfig(t, ctx,
toBase64(baseUrl+"email.body.plaintext.gotmpl"),
toBase64(baseUrl+"email.body.gotmpl"),
toBase64(baseUrl+"email.subject.gotmpl")),
&template.RecoveryInvalidModel{})
testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})

t.Run("case=file resource", func(t *testing.T) {
baseUrl := "file://courier/builtin/templates/recovery/invalid/"

tpl := template.NewRecoveryInvalid(testhelpers.SetupRemoteConfig(t,
tpl := template.NewRecoveryInvalid(testhelpers.SetupRemoteConfig(t, ctx,
baseUrl+"email.body.plaintext.gotmpl",
baseUrl+"email.body.gotmpl",
baseUrl+"email.subject.gotmpl"),
&template.RecoveryInvalidModel{},
)

testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})
})
}
23 changes: 13 additions & 10 deletions courier/template/recovery_valid_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package template_test

import (
"context"
"encoding/base64"
"github.com/julienschmidt/httprouter"
"github.com/ory/kratos/courier/template/testhelpers"
Expand All @@ -16,12 +17,14 @@ import (
)

func TestRecoverValid(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

t.Run("test=with courier templates directory", func(t *testing.T) {
conf, _ := internal.NewFastRegistryWithMocks(t)
tpl := template.NewRecoveryValid(conf, &template.RecoveryValidModel{})
_, reg := internal.NewFastRegistryWithMocks(t)
tpl := template.NewRecoveryValid(reg, &template.RecoveryValidModel{})

testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})

t.Run("test=with remote resources", func(t *testing.T) {
Expand All @@ -36,13 +39,13 @@ func TestRecoverValid(t *testing.T) {
ts := httptest.NewServer(router)
defer ts.Close()

tpl := template.NewRecoveryValid(testhelpers.SetupRemoteConfig(t,
tpl := template.NewRecoveryValid(testhelpers.SetupRemoteConfig(t, ctx,
ts.URL+"/email.body.plaintext.gotmpl",
ts.URL+"/email.body.gotmpl",
ts.URL+"/email.subject.gotmpl"),
&template.RecoveryValidModel{})

testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})

t.Run("case=base64 resource", func(t *testing.T) {
Expand All @@ -54,25 +57,25 @@ func TestRecoverValid(t *testing.T) {
return base64.StdEncoding.EncodeToString(f)
}

tpl := template.NewRecoveryValid(testhelpers.SetupRemoteConfig(t,
tpl := template.NewRecoveryValid(testhelpers.SetupRemoteConfig(t, ctx,
toBase64(baseUrl+"email.body.plaintext.gotmpl"),
toBase64(baseUrl+"email.body.gotmpl"),
toBase64(baseUrl+"email.subject.gotmpl")),
&template.RecoveryValidModel{})
testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})

t.Run("case=file resource", func(t *testing.T) {
baseUrl := "file://courier/builtin/templates/recovery/valid/"

tpl := template.NewRecoveryValid(testhelpers.SetupRemoteConfig(t,
tpl := template.NewRecoveryValid(testhelpers.SetupRemoteConfig(t, ctx,
baseUrl+"email.body.plaintext.gotmpl",
baseUrl+"email.body.gotmpl",
baseUrl+"email.subject.gotmpl"),
&template.RecoveryValidModel{},
)

testhelpers.TestRendered(t, tpl)
testhelpers.TestRendered(t, ctx, tpl)
})
})
}
Loading

0 comments on commit 658c2a3

Please sign in to comment.