Skip to content

Commit

Permalink
feat: lighter SSRF migration (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl authored Dec 1, 2022
1 parent 6e01212 commit d5dfdaa
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 20 deletions.
53 changes: 39 additions & 14 deletions httpx/private_ip_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"net"
"net/http"
"net/url"

"github.com/ory/x/stringsx"
"syscall"
"time"

"github.com/pkg/errors"
)
Expand Down Expand Up @@ -80,29 +80,54 @@ var _ http.RoundTripper = (*NoInternalIPRoundTripper)(nil)

// NoInternalIPRoundTripper is a RoundTripper that disallows internal IP addresses.
type NoInternalIPRoundTripper struct {
http.RoundTripper
internalIPExceptions []string
}

func (n NoInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
rt := http.DefaultTransport
if n.RoundTripper != nil {
rt = n.RoundTripper
}

incoming := IncomingRequestURL(request)
incoming.RawQuery = ""
incoming.RawFragment = ""
for _, exception := range n.internalIPExceptions {
if incoming.String() == exception {
return rt.RoundTrip(request)
return http.DefaultTransport.RoundTrip(request)
}
}

host, _, _ := net.SplitHostPort(request.Host)
if err := DisallowIPPrivateAddresses(stringsx.Coalesce(host, request.Host)); err != nil {
return nil, err
}
return NoInternalTransport.RoundTrip(request)
}

var NoInternalDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Control: func(network, address string, _ syscall.RawConn) error {
if !(network == "tcp4" || network == "tcp6") {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a safe network type", network))
}

host, _, err := net.SplitHostPort(address)
if err != nil {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a valid host/port pair: %s", address, err))
}

ip := net.ParseIP(host)
if ip == nil {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a valid IP address", host))
}

if ip.IsPrivate() || ip.IsLoopback() || ip.IsUnspecified() {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip))
}

return nil
},
}

return rt.RoundTrip(request)
var NoInternalTransport http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: NoInternalDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
60 changes: 58 additions & 2 deletions httpx/private_ip_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
package httpx

import (
"net"
"net/http"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -47,7 +49,7 @@ func (n noOpRoundTripper) RoundTrip(request *http.Request) (*http.Response, erro
var _ http.RoundTripper = new(noOpRoundTripper)

func TestAllowExceptions(t *testing.T) {
rt := &NoInternalIPRoundTripper{RoundTripper: new(noOpRoundTripper), internalIPExceptions: []string{"http://localhost/asdf"}}
rt := &NoInternalIPRoundTripper{internalIPExceptions: []string{"http://localhost/asdf"}}

_, err := rt.RoundTrip(&http.Request{
Host: "localhost",
Expand All @@ -56,7 +58,12 @@ func TestAllowExceptions(t *testing.T) {
"Host": []string{"localhost"},
},
})
require.NoError(t, err)
// assert that the error is eiher nil or a dial error.
if err != nil {
opErr := new(net.OpError)
require.ErrorAs(t, err, &opErr)
require.Equal(t, "dial", opErr.Op)
}

_, err = rt.RoundTrip(&http.Request{
Host: "localhost",
Expand All @@ -67,3 +74,52 @@ func TestAllowExceptions(t *testing.T) {
})
require.Error(t, err)
}

func assertErrorContains(msg string) assert.ErrorAssertionFunc {
return func(t assert.TestingT, err error, i ...interface{}) bool {
if !assert.Error(t, err, i...) {
return false
}
return assert.Contains(t, err.Error(), msg)
}
}

func TestNoInternalDialer(t *testing.T) {
for _, tt := range []struct {
name string
network string
address string
assertErr assert.ErrorAssertionFunc
}{{
name: "TCP public is allowed",
network: "tcp",
address: "www.google.de:443",
assertErr: assert.NoError,
}, {
name: "TCP private is denied",
network: "tcp",
address: "localhost:443",
assertErr: assertErrorContains("is not a public IP address"),
}, {
name: "UDP public is denied",
network: "udp",
address: "www.google.de:443",
assertErr: assertErrorContains("not a safe network type"),
}, {
name: "UDP public is denied",
network: "udp",
address: "www.google.de:443",
assertErr: assertErrorContains("not a safe network type"),
}, {
name: "UNIX sockets are denied",
network: "unix",
address: "/etc/passwd",
assertErr: assertErrorContains("not a safe network type"),
}} {

t.Run("case="+tt.name, func(t *testing.T) {
_, err := NoInternalDialer.Dial(tt.network, tt.address)
tt.assertErr(t, err)
})
}
}
1 change: 0 additions & 1 deletion httpx/resilient_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ func NewResilientClient(opts ...ResilientOptions) *retryablehttp.Client {

if o.noInternalIPs == true {
o.c.Transport = &NoInternalIPRoundTripper{
RoundTripper: o.c.Transport,
internalIPExceptions: o.internalIPExceptions,
}
}
Expand Down
6 changes: 3 additions & 3 deletions httpx/resilient_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

func TestNoPrivateIPs(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("Hello, world!"))
}))
t.Cleanup(ts.Close)
Expand Down Expand Up @@ -46,15 +46,15 @@ func TestNoPrivateIPs(t *testing.T) {
_, err := c.Get(destination)
if !passes {
require.Error(t, err)
assert.Contains(t, err.Error(), "is in the")
assert.Contains(t, err.Error(), "is not a public IP address")
} else {
require.NoError(t, err)
}
}
}

func TestClientWithTracer(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("Hello, world!"))
}))
t.Cleanup(ts.Close)
Expand Down

0 comments on commit d5dfdaa

Please sign in to comment.