diff --git a/handlers/common/common.go b/handlers/common/common.go index cd706184..5d2ee945 100644 --- a/handlers/common/common.go +++ b/handlers/common/common.go @@ -13,10 +13,11 @@ var ( log = cfg.Cfg.Logger ) -func PrepareTokensAndClient(r *http.Request, ptokens *structs.PTokens, setpid bool) (error, *http.Client, *oauth2.Token) { +// PrepareTokensAndClient setup the client, usually for a UserInfo request +func PrepareTokensAndClient(r *http.Request, ptokens *structs.PTokens, setpid bool) (*http.Client, *oauth2.Token, error) { providerToken, err := cfg.OAuthClient.Exchange(context.TODO(), r.URL.Query().Get("code")) if err != nil { - return err, nil, nil + return nil, nil, err } ptokens.PAccessToken = providerToken.AccessToken @@ -33,11 +34,11 @@ func PrepareTokensAndClient(r *http.Request, ptokens *structs.PTokens, setpid bo log.Debugf("ptokens: %+v", ptokens) client := cfg.OAuthClient.Client(context.TODO(), providerToken) - return err, client, providerToken + return client, providerToken, err } +// MapClaims populate CustomClaims from userInfo for each configure claims header func MapClaims(claims []byte, customClaims *structs.CustomClaims) error { - // Create a struct that contains the claims that we want to store from the config. var f interface{} err := json.Unmarshal(claims, &f) if err != nil { diff --git a/handlers/github/github.go b/handlers/github/github.go index c07792e2..74ae5eb3 100644 --- a/handlers/github/github.go +++ b/handlers/github/github.go @@ -81,22 +81,20 @@ func (me Provider) GetUserInfo(r *http.Request, user *structs.User, customClaims org, team := toOrgAndTeam(orgAndTeam) if org != "" { log.Info(org) - var ( - e error - isMember bool - ) + var err error + isMember := false if team != "" { - e, isMember = getTeamMembershipStateFromGitHub(client, user, org, team, ptoken) + isMember, err = getTeamMembershipStateFromGitHub(client, user, org, team, ptoken) } else { - e, isMember = getOrgMembershipStateFromGitHub(client, user, org, ptoken) + isMember, err = getOrgMembershipStateFromGitHub(client, user, org, ptoken) } - if e != nil { - return e - } else { - if isMember { - user.TeamMemberships = append(user.TeamMemberships, orgAndTeam) - } + if err != nil { + return err + } + if isMember { + user.TeamMemberships = append(user.TeamMemberships, orgAndTeam) } + } else { log.Warnf("Invalid org/team format in %s: must be written as /", orgAndTeam) } @@ -108,12 +106,12 @@ func (me Provider) GetUserInfo(r *http.Request, user *structs.User, customClaims return nil } -func getOrgMembershipStateFromGitHub(client *http.Client, user *structs.User, orgId string, ptoken *oauth2.Token) (rerr error, isMember bool) { - replacements := strings.NewReplacer(":org_id", orgId, ":username", user.Username) +func getOrgMembershipStateFromGitHub(client *http.Client, user *structs.User, orgID string, ptoken *oauth2.Token) (isMember bool, rerr error) { + replacements := strings.NewReplacer(":org_id", orgID, ":username", user.Username) orgMembershipResp, err := client.Get(replacements.Replace(cfg.GenOAuth.UserOrgURL) + ptoken.AccessToken) if err != nil { log.Error(err) - return err, false + return false, err } if orgMembershipResp.StatusCode == 302 { @@ -126,22 +124,22 @@ func getOrgMembershipStateFromGitHub(client *http.Client, user *structs.User, or if orgMembershipResp.StatusCode == 204 { log.Debug("getOrgMembershipStateFromGitHub isMember: true") - return nil, true + return true, nil } else if orgMembershipResp.StatusCode == 404 { log.Debug("getOrgMembershipStateFromGitHub isMember: false") - return nil, false + return false, nil } else { log.Errorf("getOrgMembershipStateFromGitHub: unexpected status code %d", orgMembershipResp.StatusCode) - return errors.New("Unexpected response status " + orgMembershipResp.Status), false + return false, errors.New("Unexpected response status " + orgMembershipResp.Status) } } -func getTeamMembershipStateFromGitHub(client *http.Client, user *structs.User, orgId string, team string, ptoken *oauth2.Token) (rerr error, isMember bool) { - replacements := strings.NewReplacer(":org_id", orgId, ":team_slug", team, ":username", user.Username) +func getTeamMembershipStateFromGitHub(client *http.Client, user *structs.User, orgID string, team string, ptoken *oauth2.Token) (isMember bool, rerr error) { + replacements := strings.NewReplacer(":org_id", orgID, ":team_slug", team, ":username", user.Username) membershipStateResp, err := client.Get(replacements.Replace(cfg.GenOAuth.UserTeamURL) + ptoken.AccessToken) if err != nil { log.Error(err) - return err, false + return false, err } defer func() { if err := membershipStateResp.Body.Close(); err != nil { @@ -154,16 +152,15 @@ func getTeamMembershipStateFromGitHub(client *http.Client, user *structs.User, o ghTeamState := structs.GitHubTeamMembershipState{} if err = json.Unmarshal(data, &ghTeamState); err != nil { log.Error(err) - return err, false + return false, err } - log.Debug("getTeamMembershipStateFromGitHub ghTeamState") - log.Debug(ghTeamState) - return nil, ghTeamState.State == "active" + log.Debugf("getTeamMembershipStateFromGitHub ghTeamState %s", ghTeamState) + return ghTeamState.State == "active", nil } else if membershipStateResp.StatusCode == 404 { log.Debug("getTeamMembershipStateFromGitHub isMember: false") - return nil, false + return false, err } else { log.Errorf("getTeamMembershipStateFromGitHub: unexpected status code %d", membershipStateResp.StatusCode) - return errors.New("Unexpected response status " + membershipStateResp.Status), false + return false, errors.New("Unexpected response status " + membershipStateResp.Status) } } diff --git a/handlers/github/github_test.go b/handlers/github/github_test.go index 9fb2093f..9ecb98cf 100644 --- a/handlers/github/github_test.go +++ b/handlers/github/github_test.go @@ -97,7 +97,7 @@ func TestGetTeamMembershipStateFromGitHubActive(t *testing.T) { setUp() mockResponse(regexMatcher(".*"), http.StatusOK, map[string]string{}, []byte("{\"state\": \"active\"}")) - err, isMember := getTeamMembershipStateFromGitHub(client, user, "org1", "team1", token) + isMember, err := getTeamMembershipStateFromGitHub(client, user, "org1", "team1", token) assert.Nil(t, err) assert.True(t, isMember) @@ -107,7 +107,7 @@ func TestGetTeamMembershipStateFromGitHubInactive(t *testing.T) { setUp() mockResponse(regexMatcher(".*"), http.StatusOK, map[string]string{}, []byte("{\"state\": \"inactive\"}")) - err, isMember := getTeamMembershipStateFromGitHub(client, user, "org1", "team1", token) + isMember, err := getTeamMembershipStateFromGitHub(client, user, "org1", "team1", token) assert.Nil(t, err) assert.False(t, isMember) @@ -117,7 +117,7 @@ func TestGetTeamMembershipStateFromGitHubNotAMember(t *testing.T) { setUp() mockResponse(regexMatcher(".*"), http.StatusNotFound, map[string]string{}, []byte("")) - err, isMember := getTeamMembershipStateFromGitHub(client, user, "org1", "team1", token) + isMember, err := getTeamMembershipStateFromGitHub(client, user, "org1", "team1", token) assert.Nil(t, err) assert.False(t, isMember) @@ -127,7 +127,7 @@ func TestGetOrgMembershipStateFromGitHubNotFound(t *testing.T) { setUp() mockResponse(regexMatcher(".*"), http.StatusNotFound, map[string]string{}, []byte("")) - err, isMember := getOrgMembershipStateFromGitHub(client, user, "myorg", token) + isMember, err := getOrgMembershipStateFromGitHub(client, user, "myorg", token) assert.Nil(t, err) assert.False(t, isMember) @@ -143,7 +143,7 @@ func TestGetOrgMembershipStateFromGitHubNoOrgAccess(t *testing.T) { mockResponse(regexMatcher(".*orgs/myorg/members.*"), http.StatusFound, map[string]string{"Location": location}, []byte("")) mockResponse(regexMatcher(".*orgs/myorg/public_members.*"), http.StatusNoContent, map[string]string{}, []byte("")) - err, isMember := getOrgMembershipStateFromGitHub(client, user, "myorg", token) + isMember, err := getOrgMembershipStateFromGitHub(client, user, "myorg", token) assert.Nil(t, err) assert.True(t, isMember)