Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and fix #3

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 8 additions & 44 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,10 +757,14 @@ func (c *Context) ClientIP() string {
}
}

remoteIP, trusted := c.RemoteIP()
// It also checks if the remoteIP is a trusted proxy or not.
// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks
// defined by Engine.SetTrustedProxies()
remoteIP := net.ParseIP(c.RemoteIP())
if remoteIP == nil {
return ""
}
trusted := c.engine.isTrustedProxy(remoteIP)

if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil {
for _, headerName := range c.engine.RemoteIPHeaders {
Expand All @@ -773,53 +777,13 @@ func (c *Context) ClientIP() string {
return remoteIP.String()
}

func (e *Engine) isTrustedProxy(ip net.IP) bool {
if e.trustedCIDRs != nil {
for _, cidr := range e.trustedCIDRs {
if cidr.Contains(ip) {
return true
}
}
}
return false
}

// RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port).
// It also checks if the remoteIP is a trusted proxy or not.
// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks
// defined by Engine.SetTrustedProxies()
func (c *Context) RemoteIP() (net.IP, bool) {
func (c *Context) RemoteIP() string {
ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
if err != nil {
return nil, false
}
remoteIP := net.ParseIP(ip)
if remoteIP == nil {
return nil, false
}

return remoteIP, c.engine.isTrustedProxy(remoteIP)
}

func (e *Engine) validateHeader(header string) (clientIP string, valid bool) {
if header == "" {
return "", false
}
items := strings.Split(header, ",")
for i := len(items) - 1; i >= 0; i-- {
ipStr := strings.TrimSpace(items[i])
ip := net.ParseIP(ipStr)
if ip == nil {
return "", false
}

// X-Forwarded-For is appended by proxy
// Check IPs in reverse order and stop when find untrusted proxy
if (i == 0) || (!e.isTrustedProxy(ip)) {
return ipStr, true
}
return ""
}
return
return ip
}

// ContentType returns the Content-Type header of the request.
Expand Down
10 changes: 9 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"html/template"
"io"
"mime/multipart"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -1404,6 +1405,11 @@ func TestContextClientIP(t *testing.T) {
// Tests exercising the TrustedProxies functionality
resetContextForClientIPTests(c)

// IPv6 support
c.Request.RemoteAddr = "[::1]:12345"
assert.Equal(t, "20.20.20.20", c.ClientIP())

resetContextForClientIPTests(c)
// No trusted proxies
_ = c.engine.SetTrustedProxies([]string{})
c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
Expand Down Expand Up @@ -1500,6 +1506,7 @@ func resetContextForClientIPTests(c *Context) {
c.Request.Header.Set("CF-Connecting-IP", "60.60.60.60")
c.Request.RemoteAddr = " 40.40.40.40:42123 "
c.engine.TrustedPlatform = ""
c.engine.trustedCIDRs = defaultTrustedCIDRs
c.engine.AppEngine = false
}

Expand Down Expand Up @@ -2051,7 +2058,8 @@ func TestRemoteIPFail(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request.RemoteAddr = "[:::]:80"
ip, trust := c.RemoteIP()
ip := net.ParseIP(c.RemoteIP())
trust := c.engine.isTrustedProxy(ip)
assert.Nil(t, ip)
assert.False(t, trust)
}
Expand Down
51 changes: 47 additions & 4 deletions gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"net/http"
"os"
"path"
"reflect"
"strings"
"sync"

Expand All @@ -28,7 +27,16 @@ var (

var defaultPlatform string

var defaultTrustedCIDRs = []*net.IPNet{{IP: net.IP{0x0, 0x0, 0x0, 0x0}, Mask: net.IPMask{0x0, 0x0, 0x0, 0x0}}} // 0.0.0.0/0
var defaultTrustedCIDRs = []*net.IPNet{
{ // 0.0.0.0/0 (IPv4)
IP: net.IP{0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0},
},
{ // ::/0 (IPv6)
IP: net.IP{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
}

// HandlerFunc defines the handler used by gin middleware as return value.
type HandlerFunc func(*Context)
Expand Down Expand Up @@ -399,9 +407,9 @@ func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
return engine.parseTrustedProxies()
}

// isUnsafeTrustedProxies compares Engine.trustedCIDRs and defaultTrustedCIDRs, it's not safe if equal (returns true)
// isUnsafeTrustedProxies checks if Engine.trustedCIDRs contains all IPs, it's not safe if it has (returns true)
func (engine *Engine) isUnsafeTrustedProxies() bool {
return reflect.DeepEqual(engine.trustedCIDRs, defaultTrustedCIDRs)
return engine.isTrustedProxy(net.ParseIP("0.0.0.0")) || engine.isTrustedProxy(net.ParseIP("::"))
}

// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs
Expand All @@ -411,6 +419,41 @@ func (engine *Engine) parseTrustedProxies() error {
return err
}

// isTrustedProxy will check whether the IP address is included in the trusted list according to Engine.trustedCIDRs
func (engine *Engine) isTrustedProxy(ip net.IP) bool {
if engine.trustedCIDRs == nil {
return false
}
for _, cidr := range engine.trustedCIDRs {
if cidr.Contains(ip) {
return true
}
}
return false
}

// validateHeader will parse X-Forwarded-For header and return the trusted client IP address
func (engine *Engine) validateHeader(header string) (clientIP string, valid bool) {
if header == "" {
return "", false
}
items := strings.Split(header, ",")
for i := len(items) - 1; i >= 0; i-- {
ipStr := strings.TrimSpace(items[i])
ip := net.ParseIP(ipStr)
if ip == nil {
break
}

// X-Forwarded-For is appended by proxy
// Check IPs in reverse order and stop when find untrusted proxy
if (i == 0) || (!engine.isTrustedProxy(ip)) {
return ipStr, true
}
}
return "", false
}

// parseIP parse a string representation of an IP and returns a net.IP with the
// minimum byte representation or nil if input is invalid.
func parseIP(ip string) net.IP {
Expand Down