Skip to content

Commit

Permalink
fix: ensure no internal networks can be called in SMS sender
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Feb 23, 2022
1 parent aa9c8d7 commit 65e42e5
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 27 deletions.
12 changes: 9 additions & 3 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type (
QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error)
QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error)
SmtpDialer() *gomail.Dialer
DispatchQueue(ctx context.Context) error
}

Provider interface {
Expand All @@ -39,9 +40,10 @@ type (
}

courier struct {
smsClient *smsClient
smtpClient *smtpClient
deps Dependencies
smsClient *smsClient
smtpClient *smtpClient
deps Dependencies
failOnError bool
}
)

Expand All @@ -53,6 +55,10 @@ func NewCourier(ctx context.Context, deps Dependencies) Courier {
}
}

func (c *courier) FailOnDispatchError() {
c.failOnError = true
}

func (c *courier) Work(ctx context.Context) error {
errChan := make(chan error)
defer close(errChan)
Expand Down
3 changes: 3 additions & 0 deletions courier/dispatcher.go → courier/courier_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func (c *courier) DispatchQueue(ctx context.Context) error {
if err := c.DispatchMessage(ctx, msg); err != nil {
for _, replace := range messages[k:] {
if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil {
if c.failOnError {
return err
}
c.deps.Logger().
WithError(err).
WithField("message_id", replace.ID).
Expand Down
4 changes: 1 addition & 3 deletions courier/sms.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ type sendSMSRequestBody struct {
}

type smsClient struct {
*http.Client
RequestConfig json.RawMessage

GetTemplateType func(t SMSTemplate) (TemplateType, error)
Expand All @@ -31,7 +30,6 @@ func newSMS(ctx context.Context, deps Dependencies) *smsClient {
}

return &smsClient{
Client: &http.Client{},
RequestConfig: deps.CourierConfig(ctx).CourierSMSRequestConfig(),

GetTemplateType: SMSTemplateType,
Expand Down Expand Up @@ -94,7 +92,7 @@ func (c *courier) dispatchSMS(ctx context.Context, msg Message) error {
return err
}

res, err := c.smsClient.Do(req)
res, err := c.deps.HTTPClient(ctx).Do(req)
if err != nil {
return err
}
Expand Down
28 changes: 28 additions & 0 deletions courier/sms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,31 @@ func TestQueueSMS(t *testing.T) {

srv.Close()
}

func TestDisallowedInternalNetwork(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(config.ViperKeyCourierSMSRequestConfig, fmt.Sprintf(`{
"url": "http://127.0.0.1/",
"method": "GET",
"body": "file://./stub/request.config.twilio.jsonnet"
}`))
conf.MustSet(config.ViperKeyCourierSMSEnabled, true)
conf.MustSet(config.ViperKeyCourierSMTPURL, "http://foo.url")
conf.MustSet(config.ViperKeyClientHTTPNoPrivateIPRanges, true)
reg.Logger().Level = logrus.TraceLevel

ctx := context.Background()
c := reg.Courier(ctx)
c.(interface {
FailOnDispatchError()
}).FailOnDispatchError()
_, err := c.QueueSMS(ctx, sms.NewTestStub(reg, &sms.TestStubModel{
To: "+12065550101",
Body: "test-sms-body-1",
}))
require.NoError(t, err)

err = c.DispatchQueue(ctx)
require.Error(t, err)
assert.Contains(t, err.Error(), "ip 127.0.0.1 is in the 127.0.0.0/8 range")
}
5 changes: 3 additions & 2 deletions request/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package request
import (
"encoding/json"
"fmt"
"net/http"

"github.com/hashicorp/go-retryablehttp"
)

type (
AuthStrategy interface {
apply(req *http.Request)
apply(req *retryablehttp.Request)
}

authStrategyFactory func(c json.RawMessage) (AuthStrategy, error)
Expand Down
8 changes: 5 additions & 3 deletions request/auth_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package request
import (
"encoding/json"
"net/http"

"github.com/hashicorp/go-retryablehttp"
)

type (
Expand All @@ -24,7 +26,7 @@ func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) {
return &noopAuthStrategy{}, nil
}

func (c *noopAuthStrategy) apply(_ *http.Request) {}
func (c *noopAuthStrategy) apply(_ *retryablehttp.Request) {}

func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) {
type config struct {
Expand All @@ -43,7 +45,7 @@ func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) {
}, nil
}

func (c *basicAuthStrategy) apply(req *http.Request) {
func (c *basicAuthStrategy) apply(req *retryablehttp.Request) {
req.SetBasicAuth(c.user, c.password)
}

Expand All @@ -66,7 +68,7 @@ func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) {
}, nil
}

func (c *apiKeyStrategy) apply(req *http.Request) {
func (c *apiKeyStrategy) apply(req *retryablehttp.Request) {
switch c.in {
case "cookie":
req.AddCookie(&http.Cookie{Name: c.name, Value: c.value})
Expand Down
10 changes: 6 additions & 4 deletions request/auth_strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"net/http"
"testing"

"github.com/hashicorp/go-retryablehttp"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNoopAuthStrategy(t *testing.T) {
req := http.Request{Header: map[string][]string{}}
req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}}
auth := noopAuthStrategy{}

auth.apply(&req)
Expand All @@ -18,7 +20,7 @@ func TestNoopAuthStrategy(t *testing.T) {
}

func TestBasicAuthStrategy(t *testing.T) {
req := http.Request{Header: map[string][]string{}}
req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}}
auth := basicAuthStrategy{
user: "test-user",
password: "test-pass",
Expand All @@ -34,7 +36,7 @@ func TestBasicAuthStrategy(t *testing.T) {
}

func TestApiKeyInHeaderStrategy(t *testing.T) {
req := http.Request{Header: map[string][]string{}}
req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}}
auth := apiKeyStrategy{
in: "header",
name: "my-api-key-name",
Expand All @@ -50,7 +52,7 @@ func TestApiKeyInHeaderStrategy(t *testing.T) {
}

func TestApiKeyInCookieStrategy(t *testing.T) {
req := http.Request{Header: map[string][]string{}}
req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}}
auth := apiKeyStrategy{
in: "cookie",
name: "my-api-key-name",
Expand Down
6 changes: 3 additions & 3 deletions request/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const (
)

type Builder struct {
r *http.Request
r *retryablehttp.Request
log *logrusx.Logger
conf *Config
fetchClient *retryablehttp.Client
Expand All @@ -35,7 +35,7 @@ func NewBuilder(config json.RawMessage, client *retryablehttp.Client, l *logrusx
return nil, err
}

r, err := http.NewRequest(c.Method, c.URL, nil)
r, err := retryablehttp.NewRequest(c.Method, c.URL, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -153,7 +153,7 @@ func (b *Builder) addURLEncodedBody(template *bytes.Buffer, body interface{}) er
return nil
}

func (b *Builder) BuildRequest(body interface{}) (*http.Request, error) {
func (b *Builder) BuildRequest(body interface{}) (*retryablehttp.Request, error) {
b.r.Header = b.conf.Header
if err := b.addAuth(); err != nil {
return nil, err
Expand Down
10 changes: 1 addition & 9 deletions selfservice/hook/web_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"fmt"
"net/http"

"github.com/hashicorp/go-retryablehttp"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/request"
"github.com/ory/kratos/selfservice/flow"
Expand Down Expand Up @@ -127,13 +125,7 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
return err
}

httpClient := e.deps.HTTPClient(ctx)
r, err := retryablehttp.FromRequest(req)
if err != nil {
return err
}

resp, err := httpClient.Do(r)
resp, err := e.deps.HTTPClient(ctx).Do(req)
if err != nil {
return err
}
Expand Down

0 comments on commit 65e42e5

Please sign in to comment.