From 40e5c2e557216bccaf4c027e1464030acd45971b Mon Sep 17 00:00:00 2001 From: mchtech Date: Mon, 11 Nov 2024 20:38:58 +0800 Subject: [PATCH] add CAS connector Signed-off-by: mchtech --- connector/cas/cas.go | 129 +++++++++++++++++ connector/cas/cas_test.go | 192 +++++++++++++++++++++++++ connector/cas/testdata/cas_failure.xml | 5 + connector/cas/testdata/cas_success.xml | 15 ++ go.mod | 4 +- go.sum | 4 + server/server.go | 2 + 7 files changed, 350 insertions(+), 1 deletion(-) create mode 100644 connector/cas/cas.go create mode 100644 connector/cas/cas_test.go create mode 100644 connector/cas/testdata/cas_failure.xml create mode 100644 connector/cas/testdata/cas_success.xml diff --git a/connector/cas/cas.go b/connector/cas/cas.go new file mode 100644 index 0000000000..a99eed681f --- /dev/null +++ b/connector/cas/cas.go @@ -0,0 +1,129 @@ +// Package cas provides authentication strategies using CAS. +package cas + +import ( + "fmt" + "log/slog" + "net/http" + "net/url" + + "github.com/dexidp/dex/connector" + "github.com/pkg/errors" + "gopkg.in/cas.v2" +) + +// Config holds configuration options for CAS logins. +type Config struct { + Portal string `json:"portal"` + Mapping map[string]string `json:"mapping"` +} + +// Open returns a strategy for logging in through CAS. +func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) { + casURL, err := url.Parse(c.Portal) + if err != nil { + return "", fmt.Errorf("failed to parse casURL %q: %v", c.Portal, err) + } + return &casConnector{ + client: http.DefaultClient, + portal: casURL, + mapping: c.Mapping, + logger: logger.With(slog.Group("connector", "type", "cas", "id", id)), + pathSuffix: "/" + id, + }, nil +} + +var _ connector.CallbackConnector = (*casConnector)(nil) + +type casConnector struct { + client *http.Client + portal *url.URL + mapping map[string]string + logger *slog.Logger + pathSuffix string +} + +// LoginURL returns the URL to redirect the user to login with. +func (m *casConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { + u, err := url.Parse(callbackURL) + if err != nil { + return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) + } + u.Path += m.pathSuffix + // context = $callbackURL + $m.pathSuffix + v := u.Query() + v.Set("context", u.String()) // without query params + v.Set("state", state) + u.RawQuery = v.Encode() + + loginURL := *m.portal + loginURL.Path += "/login" + // encode service url to context, which used in `HandleCallback` + // service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix + q := loginURL.Query() + q.Set("service", u.String()) // service = ...?state=...&context=... + loginURL.RawQuery = q.Encode() + return loginURL.String(), nil +} + +// HandleCallback parses the request and returns the user's identity +func (m *casConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { + state := r.URL.Query().Get("state") + ticket := r.URL.Query().Get("ticket") + // service=context = $callbackURL + $m.pathSuffix + serviceURL, err := url.Parse(r.URL.Query().Get("context")) + if err != nil { + return connector.Identity{}, fmt.Errorf("failed to parse serviceURL %q: %v", r.URL.Query().Get("context"), err) + } + // service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix + q := serviceURL.Query() + q.Set("context", serviceURL.String()) + q.Set("state", state) + serviceURL.RawQuery = q.Encode() + + user, err := m.getCasUserByTicket(ticket, serviceURL) + if err != nil { + return connector.Identity{}, err + } + m.logger.Info("cas user", "user", user) + return user, nil +} + +func (m *casConnector) getCasUserByTicket(ticket string, serviceURL *url.URL) (connector.Identity, error) { + id := connector.Identity{} + // validate ticket + validator := cas.NewServiceTicketValidator(m.client, m.portal) + resp, err := validator.ValidateTicket(serviceURL, ticket) + if err != nil { + return id, errors.Wrapf(err, "failed to validate ticket via %q with ticket %q", serviceURL, ticket) + } + // fill identity + id.UserID = resp.User + id.Groups = resp.MemberOf + if len(m.mapping) == 0 { + return id, nil + } + if username, ok := m.mapping["username"]; ok { + id.Username = resp.Attributes.Get(username) + if id.Username == "" && username == "userid" { + id.Username = resp.User + } + } + if preferredUsername, ok := m.mapping["preferred_username"]; ok { + id.PreferredUsername = resp.Attributes.Get(preferredUsername) + if id.PreferredUsername == "" && preferredUsername == "userid" { + id.PreferredUsername = resp.User + } + } + if email, ok := m.mapping["email"]; ok { + id.Email = resp.Attributes.Get(email) + if id.Email != "" { + id.EmailVerified = true + } + } + // override memberOf + if groups, ok := m.mapping["groups"]; ok { + id.Groups = resp.Attributes[groups] + } + return id, nil +} diff --git a/connector/cas/cas_test.go b/connector/cas/cas_test.go new file mode 100644 index 0000000000..47a1171e90 --- /dev/null +++ b/connector/cas/cas_test.go @@ -0,0 +1,192 @@ +package cas + +import ( + "fmt" + "log/slog" + "math/rand" + "net/http" + "net/url" + "os" + "reflect" + "testing" + "time" + + "github.com/dexidp/dex/connector" + "github.com/pkg/errors" + "gopkg.in/yaml.v3" +) + +type tcase struct { + xml string + mapping map[string]string + id connector.Identity + err string +} + +func TestOpen(t *testing.T) { + configSection := ` +portal: https://example.org/cas +mapping: + username: name + preferred_username: username + email: email + groups: affiliation +` + + var config Config + if err := yaml.Unmarshal([]byte(configSection), &config); err != nil { + t.Errorf("parse config: %v", err) + return + } + + conn, err := config.Open("cas", slog.Default()) + if err != nil { + t.Errorf("open connector: %v", err) + return + } + + casConnector, _ := conn.(*casConnector) + if casConnector.portal.String() != config.Portal { + t.Errorf("expected portal %q, got %q", config.Portal, casConnector.portal.String()) + return + } + if !reflect.DeepEqual(casConnector.mapping, config.Mapping) { + t.Errorf("expected mapping %v, got %v", config.Mapping, casConnector.mapping) + return + } +} + +func TestCAS(t *testing.T) { + callback := "https://dex.example.org/dex/callback" + casURL, _ := url.Parse("https://example.org/cas") + scope := connector.Scopes{Groups: true} + + cases := []tcase{{ + xml: "testdata/cas_success.xml", + mapping: map[string]string{ + "username": "name", + "preferred_username": "username", + "email": "email", + }, + id: connector.Identity{ + UserID: "123456", + Username: "jdoe", + PreferredUsername: "jdoe", + Email: "jdoe@example.org", + EmailVerified: true, + Groups: []string{"A", "B"}, + ConnectorData: nil, + }, + err: "", + }, { + xml: "testdata/cas_success.xml", + mapping: map[string]string{ + "username": "name", + "preferred_username": "username", + "email": "email", + "groups": "affiliation", + }, + id: connector.Identity{ + UserID: "123456", + Username: "jdoe", + PreferredUsername: "jdoe", + Email: "jdoe@example.org", + EmailVerified: true, + Groups: []string{"staff", "faculty"}, + ConnectorData: nil, + }, + err: "", + }, { + xml: "testdata/cas_failure.xml", + mapping: map[string]string{}, + id: connector.Identity{}, + err: "INVALID_TICKET: Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized", + }} + + seed := rand.NewSource(time.Now().UnixNano()) + for _, tc := range cases { + ticket := fmt.Sprintf("ST-%d", seed.Int63()) + state := fmt.Sprintf("%d", seed.Int63()) + + conn := &casConnector{ + portal: casURL, + mapping: tc.mapping, + logger: slog.Default(), + pathSuffix: "/cas", + client: &http.Client{ + Transport: &mockTransport{ + ticket: ticket, + file: tc.xml, + }, + }, + } + + // login + login, err := conn.LoginURL(scope, callback, state) + if err != nil { + t.Errorf("get login url: %v", err) + return + } + loginURL, err := url.Parse(login) + if err != nil { + t.Errorf("parse login url: %v", err) + return + } + + // cas server + queryService := loginURL.Query().Get("service") + serviceURL, err := url.Parse(queryService) + if err != nil { + t.Errorf("parse service url: %v", err) + return + } + serviceQueryState := serviceURL.Query().Get("state") + if serviceQueryState != state { + t.Errorf("state: expected %#v, got %#v", state, serviceQueryState) + return + } + req, _ := http.NewRequest(http.MethodGet, queryService, nil) + q := req.URL.Query() + q.Set("ticket", ticket) + req.URL.RawQuery = q.Encode() + + // validate + id, err := conn.HandleCallback(scope, req) + if err != nil { + if c := errors.Cause(err); c != nil && tc.err != "" && c.Error() == tc.err { + continue + } + t.Errorf("handle callback: %v", err) + return + } + if !reflect.DeepEqual(id, tc.id) { + t.Errorf("identity: expected %#v, got %#v", tc.id, id) + return + } + } +} + +type mockTransport struct { + ticket string + file string +} + +func (f *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + file, err := os.Open(f.file) + if err != nil { + return nil, err + } + + if ticket := req.URL.Query().Get("ticket"); ticket != f.ticket { + return nil, fmt.Errorf("ticket: expected %#v, got %#v", f.ticket, ticket) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: file, + Header: http.Header{ + "Content-Type": []string{"text/xml"}, + }, + Request: req, + }, nil +} diff --git a/connector/cas/testdata/cas_failure.xml b/connector/cas/testdata/cas_failure.xml new file mode 100644 index 0000000000..0e21ba8583 --- /dev/null +++ b/connector/cas/testdata/cas_failure.xml @@ -0,0 +1,5 @@ + + + Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized + + \ No newline at end of file diff --git a/connector/cas/testdata/cas_success.xml b/connector/cas/testdata/cas_success.xml new file mode 100644 index 0000000000..560b5c20bc --- /dev/null +++ b/connector/cas/testdata/cas_success.xml @@ -0,0 +1,15 @@ + + + 123456 + + jdoe + jdoe + jdoe@example.org + staff + faculty + A + B + + PGTIOU-84678-8a9d... + + \ No newline at end of file diff --git a/go.mod b/go.mod index dfa9e39364..f8f70254b4 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,8 @@ require ( google.golang.org/api v0.203.0 google.golang.org/grpc v1.67.1 google.golang.org/protobuf v1.35.1 + gopkg.in/cas.v2 v2.2.2 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -63,6 +65,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/inflect v0.19.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/glog v1.2.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-cmp v0.6.0 // indirect @@ -101,7 +104,6 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) replace github.com/dexidp/dex/api/v2 => ./api/v2 diff --git a/go.sum b/go.sum index e7a0ec0c64..a74b6bb6cf 100644 --- a/go.sum +++ b/go.sum @@ -90,6 +90,8 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= +github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -394,6 +396,8 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/cas.v2 v2.2.2 h1:teLr/JI7VDEQu6qkXKndYac9w5tfy57sWlV+eNYHH+o= +gopkg.in/cas.v2 v2.2.2/go.mod h1:mlmjh4qM/Jm3eSDD0QVr5GaaSW3nOonSUSWkLLvNYnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/server/server.go b/server/server.go index 5c5faa3003..4a2e44bc01 100644 --- a/server/server.go +++ b/server/server.go @@ -32,6 +32,7 @@ import ( "github.com/dexidp/dex/connector/atlassiancrowd" "github.com/dexidp/dex/connector/authproxy" "github.com/dexidp/dex/connector/bitbucketcloud" + "github.com/dexidp/dex/connector/cas" "github.com/dexidp/dex/connector/gitea" "github.com/dexidp/dex/connector/github" "github.com/dexidp/dex/connector/gitlab" @@ -663,6 +664,7 @@ var ConnectorsConfig = map[string]func() ConnectorConfig{ "bitbucket-cloud": func() ConnectorConfig { return new(bitbucketcloud.Config) }, "openshift": func() ConnectorConfig { return new(openshift.Config) }, "atlassian-crowd": func() ConnectorConfig { return new(atlassiancrowd.Config) }, + "cas": func() ConnectorConfig { return new(cas.Config) }, // Keep around for backwards compatibility. "samlExperimental": func() ConnectorConfig { return new(saml.Config) }, }