diff --git a/oauth2cli.go b/oauth2cli.go index 578ecd3..56627a0 100644 --- a/oauth2cli.go +++ b/oauth2cli.go @@ -7,8 +7,9 @@ import ( "fmt" "net/http" - "github.com/int128/oauth2cli/oauth2params" "golang.org/x/oauth2" + + "github.com/int128/oauth2cli/oauth2params" ) var noopMiddleware = func(h http.Handler) http.Handler { return h } @@ -56,6 +57,9 @@ type Config struct { // You can set this if your provider does not accept localhost. // Default to localhost. RedirectURLHostname string + // RedirectURLPath is the path of the redirect URL. + // Default to /. + RedirectURLPath string // Options for an authorization request. // You can set oauth2.AccessTypeOffline and the PKCE options here. AuthCodeOptions []oauth2.AuthCodeOption @@ -138,13 +142,12 @@ func (c *Config) validateAndSetDefaults() error { // // This performs the following steps: // -// 1. Start a local server at the port. -// 2. Open a browser and navigate it to the local server. -// 3. Wait for the user authorization. -// 4. Receive a code via an authorization response (HTTP redirect). -// 5. Exchange the code and a token. -// 6. Return the code. -// +// 1. Start a local server at the port. +// 2. Open a browser and navigate it to the local server. +// 3. Wait for the user authorization. +// 4. Receive a code via an authorization response (HTTP redirect). +// 5. Exchange the code and a token. +// 6. Return the code. func GetToken(ctx context.Context, c Config) (*oauth2.Token, error) { if err := c.validateAndSetDefaults(); err != nil { return nil, fmt.Errorf("invalid config: %w", err) diff --git a/server.go b/server.go index 3b5d23f..3a9d370 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/http" + "path" "sync" "time" @@ -94,11 +95,26 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) { } func computeRedirectURL(l net.Listener, c *Config) string { - hostPort := fmt.Sprintf("%s:%d", c.RedirectURLHostname, l.Addr().(*net.TCPAddr).Port) + port := l.Addr().(*net.TCPAddr).Port + scheme := "http" if c.LocalServerCertFile != "" { - return "https://" + hostPort + scheme = "https" } - return "http://" + hostPort + + isDefaultPort := false + if (scheme == "https" && port == 443) || (scheme == "http" && port == 80) { + isDefaultPort = true + } + + u := fmt.Sprintf("%s://%s:%d", scheme, c.RedirectURLHostname, port) + if isDefaultPort { + u = fmt.Sprintf("%s://%s", scheme, c.RedirectURLHostname) + } + + if c.RedirectURLPath != "" { + u = u + path.Join("/", c.RedirectURLPath) + } + return u } type authorizationResponse struct { @@ -114,16 +130,21 @@ type localServerHandler struct { func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() + redirectPath := "/" + if h.config.RedirectURLPath != "" { + redirectPath = h.config.RedirectURLPath + } + switch { - case r.Method == "GET" && r.URL.Path == "/" && q.Get("error") != "": + case r.Method == "GET" && r.URL.Path == redirectPath && q.Get("error") != "": h.onceRespCh.Do(func() { h.respCh <- h.handleErrorResponse(w, r) }) - case r.Method == "GET" && r.URL.Path == "/" && q.Get("code") != "": + case r.Method == "GET" && r.URL.Path == redirectPath && q.Get("code") != "": h.onceRespCh.Do(func() { h.respCh <- h.handleCodeResponse(w, r) }) - case r.Method == "GET" && r.URL.Path == "/": + case r.Method == "GET" && r.URL.Path == redirectPath: h.handleIndex(w, r) default: http.NotFound(w, r)