Skip to content

Commit

Permalink
test: add TLS SD validation (envoyproxy#2503)
Browse files Browse the repository at this point in the history
* basic TLS test

Signed-off-by: Kuat Yessenov <[email protected]>

* basic TLS test

Signed-off-by: Kuat Yessenov <[email protected]>

* fix a unit test

Signed-off-by: Kuat Yessenov <[email protected]>

* add mTLS certs

Signed-off-by: Kuat Yessenov <[email protected]>

* add principals

Signed-off-by: Kuat Yessenov <[email protected]>
  • Loading branch information
kyessenov authored and istio-testing committed Nov 1, 2019
1 parent 8fde024 commit 6a8fc0d
Show file tree
Hide file tree
Showing 34 changed files with 609 additions and 372 deletions.
32 changes: 27 additions & 5 deletions extensions/common/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ void extractServiceName(const std::string& fqdn, std::string* service_name) {

} // namespace

StringView AuthenticationPolicyString(ServiceAuthenticationPolicy policy) {
switch (policy) {
case ServiceAuthenticationPolicy::None:
return kNone;
case ServiceAuthenticationPolicy::MutualTLS:
return kMutualTLS;
default:
break;
}
return {};
;
}

using google::protobuf::util::JsonStringToMessage;
using google::protobuf::util::MessageToJsonString;

Expand Down Expand Up @@ -181,16 +194,25 @@ void populateHTTPRequestInfo(bool outbound, RequestInfo* request_info) {
->toString();

int64_t destination_port = 0;
std::string tls_version;

if (outbound) {
getValue({"upstream", "port"}, &destination_port);
getValue({"upstream", "mtls"}, &request_info->mTLS);
getStringValue({"upstream", "tls_version"}, &tls_version);
getStringValue({"upstream", "uri_san_peer_certificate"},
&request_info->destination_principal);
getStringValue({"upstream", "uri_san_local_certificate"},
&request_info->source_principal);
} else {
getValue({"destination", "port"}, &destination_port);
getValue({"connection", "mtls"}, &request_info->mTLS);
getStringValue({"connection", "tls_version"}, &tls_version);
bool mtls = false;
if (getValue({"connection", "mtls"}, &mtls)) {
request_info->service_auth_policy =
mtls ? ::Wasm::Common::ServiceAuthenticationPolicy::MutualTLS
: ::Wasm::Common::ServiceAuthenticationPolicy::None;
}
getStringValue({"connection", "uri_san_local_certificate"},
&request_info->destination_principal);
getStringValue({"connection", "uri_san_peer_certificate"},
&request_info->source_principal);
}
request_info->destination_port = destination_port;
}
Expand Down
15 changes: 13 additions & 2 deletions extensions/common/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ const std::string kProtocolGRPC = "grpc";
const std::set<std::string> kGrpcContentTypes{
"application/grpc", "application/grpc+proto", "application/grpc+json"};

enum class ServiceAuthenticationPolicy : int64_t {
Unspecified = 0,
None = 1,
MutualTLS = 2,
};

constexpr StringView kMutualTLS = "MUTUAL_TLS";
constexpr StringView kNone = "NONE";

StringView AuthenticationPolicyString(ServiceAuthenticationPolicy policy);

// RequestInfo represents the information collected from filter stream
// callbacks. This is used to fill metrics and logs.
struct RequestInfo {
Expand Down Expand Up @@ -86,8 +97,8 @@ struct RequestInfo {
// Operation of the request, i.e. HTTP method or gRPC API method.
std::string request_operation;

// Indicates if the request uses mTLS.
bool mTLS = false;
// Service authentication policy (NONE, MUTUAL_TLS)
ServiceAuthenticationPolicy service_auth_policy;

// Principal of source and destination workload extracted from TLS
// certificate.
Expand Down
4 changes: 2 additions & 2 deletions extensions/stackdriver/log/logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ void Logger::addLogEntry(const ::Wasm::Common::RequestInfo& request_info,
(*label_map)["destination_principal"] = request_info.destination_principal;
(*label_map)["source_principal"] = request_info.source_principal;
(*label_map)["service_authentication_policy"] =
request_info.mTLS ? "true" : "false";

std::string(::Wasm::Common::AuthenticationPolicyString(
request_info.service_auth_policy));
// Accumulate estimated size of the request. If the current request exceeds
// the size limit, flush the request out.
size_ += new_entry->ByteSizeLong();
Expand Down
6 changes: 4 additions & 2 deletions extensions/stackdriver/log/logger_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ ::Wasm::Common::RequestInfo requestInfo() {
request_info.request_protocol = "HTTP";
request_info.destination_principal = "destination_principal";
request_info.source_principal = "source_principal";
request_info.mTLS = true;
request_info.service_auth_policy =
::Wasm::Common::ServiceAuthenticationPolicy::MutualTLS;
return request_info;
}

Expand Down Expand Up @@ -117,7 +118,8 @@ google::logging::v2::WriteLogEntriesRequest expectedRequest(
(*label_map)["destination_principal"] = request_info.destination_principal;
(*label_map)["source_principal"] = request_info.source_principal;
(*label_map)["service_authentication_policy"] =
request_info.mTLS ? "true" : "false";
std::string(::Wasm::Common::AuthenticationPolicyString(
request_info.service_auth_policy));
}
return req;
}
Expand Down
9 changes: 4 additions & 5 deletions extensions/stackdriver/metric/record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ namespace Extensions {
namespace Stackdriver {
namespace Metric {

constexpr char kMutualTLS[] = "MUTUAL_TLS";
constexpr char kNone[] = "NONE";

void record(bool is_outbound, const ::wasm::common::NodeInfo &local_node_info,
const ::wasm::common::NodeInfo &peer_node_info,
const ::Wasm::Common::RequestInfo &request_info) {
Expand All @@ -40,7 +37,8 @@ void record(bool is_outbound, const ::wasm::common::NodeInfo &local_node_info,
{requestOperationKey(), request_info.request_operation},
{requestProtocolKey(), request_info.request_protocol},
{serviceAuthenticationPolicyKey(),
request_info.mTLS ? kMutualTLS : kNone},
::Wasm::Common::AuthenticationPolicyString(
request_info.service_auth_policy)},
{destinationServiceNameKey(), request_info.destination_service_host},
{destinationServiceNamespaceKey(), peer_node_info.namespace_()},
{destinationPortKey(), std::to_string(request_info.destination_port)},
Expand All @@ -65,7 +63,8 @@ void record(bool is_outbound, const ::wasm::common::NodeInfo &local_node_info,
{requestOperationKey(), request_info.request_operation},
{requestProtocolKey(), request_info.request_protocol},
{serviceAuthenticationPolicyKey(),
request_info.mTLS ? kMutualTLS : kNone},
::Wasm::Common::AuthenticationPolicyString(
request_info.service_auth_policy)},
{destinationServiceNameKey(), request_info.destination_service_host},
{destinationServiceNamespaceKey(), local_node_info.namespace_()},
{destinationPortKey(), std::to_string(request_info.destination_port)},
Expand Down
5 changes: 2 additions & 3 deletions extensions/stats/plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ constexpr StringView Sep = "#@";
const std::string unknown = "unknown";
const std::string vSource = "source";
const std::string vDest = "destination";
const std::string vMTLS = "mutual_tls";
const std::string vNone = "none";
const std::string vDash = "-";

const std::string default_field_separator = ";.;";
Expand Down Expand Up @@ -191,7 +189,8 @@ struct IstioDimensions {
request.response_flag.empty() ? vDash : request.response_flag;

connection_security_policy =
outbound ? unknown : (request.mTLS ? vMTLS : vNone);
std::string(::Wasm::Common::AuthenticationPolicyString(
request.service_auth_policy));

permissive_response_code = request.rbac_permissive_engine_result.empty()
? "none"
Expand Down
30 changes: 30 additions & 0 deletions test/envoye2e/basic_flow/basic_xds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ filter_chains:
route:
cluster: inbound|9080|http|server.default.svc.cluster.local
timeout: 0s
{{ .Vars.ServerTLSContext | indent 2 }}
`

func TestBasicHTTP(t *testing.T) {
Expand Down Expand Up @@ -103,3 +104,32 @@ func TestBasicHTTP(t *testing.T) {
t.Fatal(err)
}
}

func TestBasicHTTPwithTLS(t *testing.T) {
ports := env.NewPorts(env.BasicHTTPwithTLS)
params := &driver.Params{
Vars: map[string]string{
"BackendPort": fmt.Sprintf("%d", ports.BackendPort),
"ClientPort": fmt.Sprintf("%d", ports.ClientToServerProxyPort),
"ClientAdmin": fmt.Sprintf("%d", ports.ClientAdminPort),
"ServerAdmin": fmt.Sprintf("%d", ports.ServerAdminPort),
"ServerPort": fmt.Sprintf("%d", ports.ProxyToServerProxyPort),
},
XDS: int(ports.XDSPort),
}
params.Vars["ClientTLSContext"] = params.LoadTestData("testdata/transport_socket/client.yaml.tmpl")
params.Vars["ServerTLSContext"] = params.LoadTestData("testdata/transport_socket/server.yaml.tmpl")
if err := (&driver.Scenario{
[]driver.Step{
&driver.XDS{},
&driver.Update{Node: "client", Version: "0", Listeners: []string{ClientHTTPListener}},
&driver.Update{Node: "server", Version: "0", Listeners: []string{ServerHTTPListener}},
&driver.Envoy{Bootstrap: params.LoadTestData("testdata/bootstrap/client.yaml.tmpl")},
&driver.Envoy{Bootstrap: params.LoadTestData("testdata/bootstrap/server.yaml.tmpl")},
&driver.Sleep{1 * time.Second},
&driver.Get{ports.ClientToServerProxyPort, "hello, world!"},
},
}).Run(params); err != nil {
t.Fatal(err)
}
}
6 changes: 1 addition & 5 deletions test/envoye2e/driver/envoy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ import (
type Envoy struct {
// template for the bootstrap
Bootstrap string
// working directory (optional)
Dir string

tmpFile string
cmd *exec.Cmd
Expand Down Expand Up @@ -81,9 +79,7 @@ func (e *Envoy) Run(p *Params) error {
cmd := exec.Command(envoyPath, args...)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
if e.Dir != "" {
cmd.Dir = e.Dir
}
cmd.Dir = BazelWorkspace()

log.Printf("envoy cmd %v", cmd.Args)
e.cmd = cmd
Expand Down
19 changes: 18 additions & 1 deletion test/envoye2e/driver/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"os/exec"
"path/filepath"
"strings"

"github.com/golang/protobuf/proto"
)

// Loads resources in the test data directory
Expand All @@ -33,14 +35,21 @@ func BazelWorkspace() string {
return strings.TrimSuffix(string(workspace), "\n")
}

// Normalizes test data path
func TestPath(testFileName string) string {
return filepath.Join(BazelWorkspace(), testFileName)
}

// Loads a test file content
func LoadTestData(testFileName string) string {
data, err := ioutil.ReadFile(filepath.Join(BazelWorkspace(), testFileName))
data, err := ioutil.ReadFile(TestPath(testFileName))
if err != nil {
panic(err)
}
return string(data)
}

// Loads a test file and fills in template variables
func (p *Params) LoadTestData(testFileName string) string {
data := LoadTestData(testFileName)
out, err := p.Fill(data)
Expand All @@ -49,3 +58,11 @@ func (p *Params) LoadTestData(testFileName string) string {
}
return out
}

// Loads a test file as YAML into a proto and fills in template variables
func (p *Params) LoadTestProto(testFileName string, msg proto.Message) {
data := LoadTestData(testFileName)
if err := p.FillYAML(data, msg); err != nil {
panic(err)
}
}
10 changes: 9 additions & 1 deletion test/envoye2e/driver/scenario.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,15 @@ func (s *Sleep) Run(_ *Params) error {
func (s *Sleep) Cleanup() {}

func (p *Params) Fill(s string) (string, error) {
t := template.Must(template.New("params").Option("missingkey=zero").Parse(s))
t := template.Must(template.New("params").
Option("missingkey=zero").
Funcs(template.FuncMap{
"indent": func(n int, s string) string {
pad := strings.Repeat(" ", n)
return pad + strings.Replace(s, "\n", "\n"+pad, -1)
},
}).
Parse(s))
var b bytes.Buffer
if err := t.Execute(&b, p); err != nil {
return "", err
Expand Down
57 changes: 26 additions & 31 deletions test/envoye2e/driver/stackdriver.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,67 +85,62 @@ func (sd *Stackdriver) Cleanup() {
close(sd.done)
}

func (sd *Stackdriver) Check(ts []string, ls []string) Step {
return &checkStackdriver{
sd: sd,
ts: ts,
ls: ls,
}
}

type checkStackdriver struct {
sd *Stackdriver
ts []string
ls []string
}

func (s *checkStackdriver) Run(p *Params) error {
func (sd *Stackdriver) Check(p *Params, tsFiles []string, lsFiles []string) Step {
// check as sets of strings by marshaling to proto
twant := make(map[string]struct{})
for _, t := range s.ts {
for _, t := range tsFiles {
pb := &monitoring.TimeSeries{}
if err := p.FillYAML(t, pb); err != nil {
return err
}
p.LoadTestProto(t, pb)
twant[proto.MarshalTextString(pb)] = struct{}{}
}
lwant := make(map[string]struct{})
for _, l := range s.ls {
for _, l := range lsFiles {
pb := &logging.WriteLogEntriesRequest{}
if err := p.FillYAML(l, pb); err != nil {
return err
}
p.LoadTestProto(l, pb)
lwant[proto.MarshalTextString(pb)] = struct{}{}
}
return &checkStackdriver{
sd: sd,
twant: twant,
lwant: lwant,
}
}

type checkStackdriver struct {
sd *Stackdriver
twant map[string]struct{}
lwant map[string]struct{}
}

func (s *checkStackdriver) Run(p *Params) error {
foundAllLogs := false
foundAllMetrics := false
for i := 0; i < 30; i++ {
s.sd.Lock()
foundAllLogs = reflect.DeepEqual(s.sd.ls, lwant)
foundAllLogs = reflect.DeepEqual(s.sd.ls, s.lwant)
if !foundAllLogs {
log.Printf("got log entries %d, want %d\n", len(s.sd.ls), len(lwant))
if len(s.sd.ls) >= len(lwant) {
log.Printf("got log entries %d, want %d\n", len(s.sd.ls), len(s.lwant))
if len(s.sd.ls) >= len(s.lwant) {
for got := range s.sd.ls {
log.Println(got)
}
log.Println("--- but want ---")
for want := range lwant {
for want := range s.lwant {
log.Println(want)
}
return fmt.Errorf("failed to receive expected logs")
}
}

foundAllMetrics = reflect.DeepEqual(s.sd.ts, twant)
foundAllMetrics = reflect.DeepEqual(s.sd.ts, s.twant)
if !foundAllMetrics {
log.Printf("got metrics %d, want %d\n", len(s.sd.ts), len(twant))
if len(s.sd.ts) >= len(twant) {
log.Printf("got metrics %d, want %d\n", len(s.sd.ts), len(s.twant))
if len(s.sd.ts) >= len(s.twant) {
for got := range s.sd.ts {
log.Println(got)
}
log.Println("--- but want ---")
for want := range twant {
for want := range s.twant {
log.Println(want)
}
return fmt.Errorf("failed to receive expected metrics")
Expand Down
2 changes: 2 additions & 0 deletions test/envoye2e/env/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ const (

// xDS driven tests
BasicHTTP
BasicHTTPwithTLS
StackDriverPayload
StackDriverPayloadWithTLS
StackDriverReload
StackDriverParallel

Expand Down
2 changes: 1 addition & 1 deletion test/envoye2e/env/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func (s *TestSetup) SetUpClientServerEnvoy() error {
}
}
if s.startTcpBackend {
s.tcpBackend, err = NewTCPServer(s.ports.BackendPort, "hello", s.EnableTls)
s.tcpBackend, err = NewTCPServer(s.ports.BackendPort, "hello", s.EnableTls, s.Dir)
if err != nil {
log.Printf("unable to create TCP server %v", err)
} else {
Expand Down
Loading

0 comments on commit 6a8fc0d

Please sign in to comment.