diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index dd9c7bef38f..135187e1303 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -6,6 +6,7 @@ package proxy import ( "crypto/tls" "crypto/x509" + "errors" "io" "io/ioutil" "net" @@ -24,6 +25,7 @@ type Server struct { KeyFile string ClientCertFile string Subnet string + subnet *net.IPNet } func (s *Server) Run() error { @@ -31,6 +33,7 @@ func (s *Server) Run() error { if err != nil { return err } + s.subnet = subnet b, err := ioutil.ReadFile(s.ClientCertFile) if err != nil { @@ -91,25 +94,38 @@ func (s *Server) Run() error { return err } - return http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } - - ip, _, err := net.SplitHostPort(r.Host) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - if !subnet.Contains(net.ParseIP(ip)) { - http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) - return - } - - Proxy(s.Log, w, r, 0) - })) + return http.Serve(l, http.HandlerFunc(s.proxyHandler)) +} + +func (s Server) proxyHandler(w http.ResponseWriter, r *http.Request) { + err := s.validateProxyRequest(w, r) + if err != nil { + return + } + Proxy(s.Log, w, r, 0) +} + +// validateProxyRequest checks that the request is valid. If not, it writes the +// appropriate http headers and returns an error. +func (s Server) validateProxyRequest(w http.ResponseWriter, r *http.Request) error { + + ip, _, err := net.SplitHostPort(r.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if r.Method != http.MethodConnect { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return errors.New("Request is not valid, method is not CONNECT") + } + + if !s.subnet.Contains(net.ParseIP(ip)) { + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return errors.New("Request is not allowed, the originating IP is not part of the allowed subnet") + } + + return nil } // Proxy takes an HTTP/1.x CONNECT Request and ResponseWriter from the Golang diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go new file mode 100644 index 00000000000..e71a1b56c45 --- /dev/null +++ b/pkg/proxy/proxy_test.go @@ -0,0 +1,82 @@ +package proxy + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestValidation(t *testing.T) { + tests := []struct { + name string + method string + subnet string + hostname string + wantStatus int + wantErr bool + }{ + { + name: "get https same subnet", + method: http.MethodGet, + subnet: "127.0.0.1/24", + hostname: "https://127.0.0.2:123", + wantStatus: http.StatusMethodNotAllowed, + wantErr: true, + }, + { + name: "connect http same subnet", + method: http.MethodConnect, + subnet: "127.0.0.1/24", + hostname: "127.0.0.2:123", + wantStatus: http.StatusOK, + wantErr: false, + }, + { + name: "connect http different subnet", + method: http.MethodConnect, + subnet: "127.0.0.1/24", + hostname: "10.0.0.1:123", + wantStatus: http.StatusForbidden, + wantErr: true, + }, + { + name: "wrong hostname", + method: http.MethodGet, + subnet: "127.0.0.1/24", + hostname: "https://127.0.0.1::", + wantStatus: http.StatusBadRequest, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := Server{Subnet: tt.subnet} + _, subnet, err := net.ParseCIDR(server.Subnet) + if err != nil { + t.FailNow() + } + server.subnet = subnet + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(tt.method, tt.hostname, nil) + + err = server.validateProxyRequest(recorder, request) + if (err != nil && !tt.wantErr) || (err == nil && tt.wantErr) { + t.Error(err) + } + + response := recorder.Result() + + if response.StatusCode != tt.wantStatus { + fmt.Println(response.StatusCode, tt.wantStatus) + t.Error(tt.hostname) + } + }) + } +}