Skip to content

Commit

Permalink
Merge pull request #2007 from facchettos/proxy-refactor-and-test
Browse files Browse the repository at this point in the history
Add testing to the proxy
  • Loading branch information
bennerv authored Apr 13, 2022
2 parents 31ee47d + bfb4d4b commit eab506d
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 19 deletions.
54 changes: 35 additions & 19 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package proxy
import (
"crypto/tls"
"crypto/x509"
"errors"
"io"
"io/ioutil"
"net"
Expand All @@ -25,13 +26,15 @@ type Server struct {
KeyFile string
ClientCertFile string
Subnet string
subnet *net.IPNet
}

func (s *Server) Run() error {
_, subnet, err := net.ParseCIDR(s.Subnet)
if err != nil {
return err
}
s.subnet = subnet

b, err := ioutil.ReadFile(s.ClientCertFile)
if err != nil {
Expand Down Expand Up @@ -92,25 +95,38 @@ func (s *Server) Run() error {
return err
}

return http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}

ip, _, err := net.SplitHostPort(r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

if !subnet.Contains(net.ParseIP(ip)) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}

Proxy(s.Log, w, r, 0)
}))
return http.Serve(l, http.HandlerFunc(s.proxyHandler))
}

func (s Server) proxyHandler(w http.ResponseWriter, r *http.Request) {
err := s.validateProxyRequest(w, r)
if err != nil {
return
}
Proxy(s.Log, w, r, 0)
}

// validateProxyRequest checks that the request is valid. If not, it writes the
// appropriate http headers and returns an error.
func (s Server) validateProxyRequest(w http.ResponseWriter, r *http.Request) error {

ip, _, err := net.SplitHostPort(r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return err
}

if r.Method != http.MethodConnect {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return errors.New("Request is not valid, method is not CONNECT")
}

if !s.subnet.Contains(net.ParseIP(ip)) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return errors.New("Request is not allowed, the originating IP is not part of the allowed subnet")
}

return nil
}

// Proxy takes an HTTP/1.x CONNECT Request and ResponseWriter from the Golang
Expand Down
82 changes: 82 additions & 0 deletions pkg/proxy/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package proxy

// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
)

func TestRequestValidation(t *testing.T) {
tests := []struct {
name string
method string
subnet string
hostname string
wantStatus int
wantErr bool
}{
{
name: "get https same subnet",
method: http.MethodGet,
subnet: "127.0.0.1/24",
hostname: "https://127.0.0.2:123",
wantStatus: http.StatusMethodNotAllowed,
wantErr: true,
},
{
name: "connect http same subnet",
method: http.MethodConnect,
subnet: "127.0.0.1/24",
hostname: "127.0.0.2:123",
wantStatus: http.StatusOK,
wantErr: false,
},
{
name: "connect http different subnet",
method: http.MethodConnect,
subnet: "127.0.0.1/24",
hostname: "10.0.0.1:123",
wantStatus: http.StatusForbidden,
wantErr: true,
},
{
name: "wrong hostname",
method: http.MethodGet,
subnet: "127.0.0.1/24",
hostname: "https://127.0.0.1::",
wantStatus: http.StatusBadRequest,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := Server{Subnet: tt.subnet}
_, subnet, err := net.ParseCIDR(server.Subnet)
if err != nil {
t.FailNow()
}
server.subnet = subnet

recorder := httptest.NewRecorder()
request := httptest.NewRequest(tt.method, tt.hostname, nil)

err = server.validateProxyRequest(recorder, request)
if (err != nil && !tt.wantErr) || (err == nil && tt.wantErr) {
t.Error(err)
}

response := recorder.Result()

if response.StatusCode != tt.wantStatus {
fmt.Println(response.StatusCode, tt.wantStatus)
t.Error(tt.hostname)
}
})
}
}

0 comments on commit eab506d

Please sign in to comment.