diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index 552c7222e92..639e4966c42 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -1,117 +1,74 @@ package proxy import ( + "fmt" "net" "net/http" "net/http/httptest" "testing" ) -func TestProxyRequestValidationMethod(t *testing.T) { - server := Server{Subnet: "127.0.0.1/24"} - _, subnet, err := net.ParseCIDR(server.Subnet) - if err != nil { - t.FailNow() +func TestRequestValidation(t *testing.T) { + tests := []struct { + name string + method string + subnet string + hostname string + wantStatus int + }{ + { + name: "get https same subnet", + method: http.MethodGet, + subnet: "127.0.0.1/24", + hostname: "https://127.0.0.2:123", + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "get https different subnet", + method: http.MethodGet, + subnet: "127.0.0.1/24", + hostname: "https://10.0.0.2:123", + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "connect http same subnet", + method: http.MethodConnect, + subnet: "127.0.0.1/24", + hostname: "127.0.0.2:123", + wantStatus: http.StatusOK, + }, + { + name: "connect http different subnet", + method: http.MethodConnect, + subnet: "127.0.0.1/24", + hostname: "10.0.0.1:123", + wantStatus: http.StatusForbidden, + }, } - server.subnet = subnet - //This should fail because the method is not CONNECT - recorder := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, "https://127.0.0.1:123", nil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := Server{Subnet: tt.subnet} + _, subnet, err := net.ParseCIDR(server.Subnet) + if err != nil { + t.FailNow() + } + server.subnet = subnet - server.validateProxyResquest(recorder, request) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(tt.method, tt.hostname, nil) - response := recorder.Result() - if response.StatusCode != http.StatusMethodNotAllowed { - t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusMethodNotAllowed, response.StatusCode) - t.FailNow() - } - - //This should succeed because the method is CONNECT - recorder = httptest.NewRecorder() - request = httptest.NewRequest(http.MethodConnect, "127.0.0.1:123", nil) - - server.validateProxyResquest(recorder, request) - - response = recorder.Result() - - if response.StatusCode != http.StatusOK { - t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusOK, response.StatusCode) - t.FailNow() - } - -} - -func TestProxyRequestValidationHostname(t *testing.T) { - - server := Server{Subnet: "127.0.0.1/24"} - _, subnet, err := net.ParseCIDR(server.Subnet) - if err != nil { - t.FailNow() - } - server.subnet = subnet - - //This should fail because the hostname in not valid - recorder := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodConnect, "", nil) - - server.validateProxyResquest(recorder, request) - - response := recorder.Result() - - if response.StatusCode != http.StatusBadRequest { - t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusBadRequest, response.StatusCode) - t.FailNow() - } - - //This should succeed because the hostname is valid - recorder = httptest.NewRecorder() - request = httptest.NewRequest(http.MethodConnect, "127.0.0.1:8443", nil) - - server.validateProxyResquest(recorder, request) - - response = recorder.Result() - - if response.StatusCode != http.StatusOK { - t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusOK, response.StatusCode) - t.FailNow() - } - -} - -func TestProxyRequestValidationSubnet(t *testing.T) { - - server := Server{Subnet: "127.0.0.1/24"} - _, subnet, err := net.ParseCIDR(server.Subnet) - if err != nil { - t.FailNow() - } - server.subnet = subnet - - //This should succeed because it is in the subnet - recorder := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodConnect, "127.0.0.1:1234", nil) - - server.validateProxyResquest(recorder, request) - - response := recorder.Result() - - if response.StatusCode != http.StatusOK { - t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusOK, response.StatusCode) - t.FailNow() - } + server.validateProxyResquest(recorder, request) - //This should fail because it is not in the subnet - recorder = httptest.NewRecorder() - request = httptest.NewRequest(http.MethodConnect, "10.0.0.1:1234", nil) + response := recorder.Result() - server.validateProxyResquest(recorder, request) + if response.StatusCode != tt.wantStatus { + fmt.Println(response.StatusCode, tt.wantStatus) + t.Error(tt.hostname) + } - response = recorder.Result() + }) - if response.StatusCode != http.StatusForbidden { - t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusForbidden, response.StatusCode) - t.FailNow() } }