diff --git a/internal/api/admin_test.go b/internal/api/admin_test.go index 6539be30e..64ee821a6 100644 --- a/internal/api/admin_test.go +++ b/internal/api/admin_test.go @@ -705,7 +705,8 @@ func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() { req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) - *ts.Config = *c.customConfig + ts.Config.JWT = c.customConfig.JWT + ts.Config.External = c.customConfig.External ts.API.handler.ServeHTTP(w, req) require.Equal(ts.T(), c.expected, w.Code) }) diff --git a/internal/api/api.go b/internal/api/api.go index 77bd27590..8bc95bd4b 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -85,6 +85,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Route("/callback", func(r *router) { r.UseBypass(logger) + r.Use(api.isValidExternalHost) r.Use(api.loadFlowState) r.Get("/", api.ExternalProviderCallback) @@ -93,6 +94,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Route("/", func(r *router) { r.UseBypass(logger) + r.Use(api.isValidExternalHost) r.Get("/settings", api.Settings) diff --git a/internal/api/context.go b/internal/api/context.go index cb5004021..107377b50 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -2,6 +2,7 @@ package api import ( "context" + "net/url" jwt "github.com/golang-jwt/jwt" "github.com/supabase/gotrue/internal/models" @@ -28,6 +29,7 @@ const ( oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token oauthVerifierKey = contextKey("oauth_verifier") ssoProviderKey = contextKey("sso_provider") + externalHostKey = contextKey("external_host") flowStateKey = contextKey("flow_state_id") ) @@ -235,3 +237,15 @@ func getSSOProvider(ctx context.Context) *models.SSOProvider { } return obj.(*models.SSOProvider) } + +func withExternalHost(ctx context.Context, u *url.URL) context.Context { + return context.WithValue(ctx, externalHostKey, u) +} + +func getExternalHost(ctx context.Context) *url.URL { + obj := ctx.Value(externalHostKey) + if obj == nil { + return nil + } + return obj.(*url.URL) +} diff --git a/internal/api/external.go b/internal/api/external.go index 4351d994c..582695cb2 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -377,7 +377,8 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. if !emailData.Verified && !config.Mailer.Autoconfirm { mailer := a.Mailer(ctx) referrer := a.getReferrer(r) - if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { + externalURL := getExternalHost(ctx) + if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") } @@ -510,41 +511,59 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) { config := a.config name = strings.ToLower(name) + callbackURL := getExternalHost(ctx).String() + "/callback" switch name { case "apple": + config.External.Apple.RedirectURI = callbackURL return provider.NewAppleProvider(config.External.Apple) case "azure": + config.External.Azure.RedirectURI = callbackURL return provider.NewAzureProvider(config.External.Azure, scopes) case "bitbucket": + config.External.Bitbucket.RedirectURI = callbackURL return provider.NewBitbucketProvider(config.External.Bitbucket) case "discord": + config.External.Discord.RedirectURI = callbackURL return provider.NewDiscordProvider(config.External.Discord, scopes) case "github": + config.External.Github.RedirectURI = callbackURL return provider.NewGithubProvider(config.External.Github, scopes) case "gitlab": + config.External.Gitlab.RedirectURI = callbackURL return provider.NewGitlabProvider(config.External.Gitlab, scopes) case "google": + config.External.Google.RedirectURI = callbackURL return provider.NewGoogleProvider(config.External.Google, scopes) case "keycloak": + config.External.Keycloak.RedirectURI = callbackURL return provider.NewKeycloakProvider(config.External.Keycloak, scopes) case "linkedin": + config.External.Linkedin.RedirectURI = callbackURL return provider.NewLinkedinProvider(config.External.Linkedin, scopes) case "facebook": + config.External.Facebook.RedirectURI = callbackURL return provider.NewFacebookProvider(config.External.Facebook, scopes) case "notion": + config.External.Notion.RedirectURI = callbackURL return provider.NewNotionProvider(config.External.Notion) case "spotify": + config.External.Spotify.RedirectURI = callbackURL return provider.NewSpotifyProvider(config.External.Spotify, scopes) case "slack": + config.External.Slack.RedirectURI = callbackURL return provider.NewSlackProvider(config.External.Slack, scopes) case "twitch": + config.External.Twitch.RedirectURI = callbackURL return provider.NewTwitchProvider(config.External.Twitch, scopes) case "twitter": + config.External.Twitter.RedirectURI = callbackURL return provider.NewTwitterProvider(config.External.Twitter, scopes) case "workos": + config.External.WorkOS.RedirectURI = callbackURL return provider.NewWorkOSProvider(config.External.WorkOS) case "zoom": + config.External.Zoom.RedirectURI = callbackURL return provider.NewZoomProvider(config.External.Zoom) default: return nil, fmt.Errorf("Provider %s could not be found", name) diff --git a/internal/api/invite.go b/internal/api/invite.go index 85a651d15..0d9d054a5 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -79,7 +79,8 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { mailer := a.Mailer(ctx) referrer := a.getReferrer(r) - if err := sendInvite(tx, user, mailer, referrer, config.Mailer.OtpLength); err != nil { + externalURL := getExternalHost(ctx) + if err := sendInvite(tx, user, mailer, referrer, externalURL, config.Mailer.OtpLength); err != nil { return internalServerError("Error inviting user").WithInternalError(err) } return nil diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index fc7b9faea..ad217b19a 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -142,7 +142,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { mailer := a.Mailer(ctx) referrer := a.getReferrer(r) - return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, flowType) + externalURL := getExternalHost(ctx) + return a.sendMagicLink(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType) }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { diff --git a/internal/api/mail.go b/internal/api/mail.go index 73f5b834b..d15507847 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "strings" "time" @@ -205,7 +206,8 @@ func (a *API) GenerateLink(w http.ResponseWriter, r *http.Request) error { return terr } - url, terr = mailer.GetEmailActionLink(user, params.Type, referrer) + externalURL := getExternalHost(ctx) + url, terr = mailer.GetEmailActionLink(user, params.Type, referrer, externalURL) if terr != nil { return terr } @@ -228,7 +230,7 @@ func (a *API) GenerateLink(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, resp) } -func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error { +func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { var err error if u.ConfirmationSentAt != nil && !u.ConfirmationSentAt.Add(maxFrequency).Before(time.Now()) { return MaxFrequencyLimitError @@ -241,7 +243,7 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail token := fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp))) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) now := time.Now() - if err := mailer.ConfirmationMail(u, otp, referrerURL); err != nil { + if err := mailer.ConfirmationMail(u, otp, referrerURL, externalURL); err != nil { u.ConfirmationToken = oldToken return errors.Wrap(err, "Error sending confirmation email") } @@ -249,7 +251,7 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") } -func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, otpLength int) error { +func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error { var err error oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) @@ -258,7 +260,7 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re } u.ConfirmationToken = fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp))) now := time.Now() - if err := mailer.InviteMail(u, otp, referrerURL); err != nil { + if err := mailer.InviteMail(u, otp, referrerURL, externalURL); err != nil { u.ConfirmationToken = oldToken return errors.Wrap(err, "Error sending invite email") } @@ -267,7 +269,7 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") } -func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error { +func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { var err error if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) { return MaxFrequencyLimitError @@ -281,7 +283,7 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile token := fmt.Sprintf("%x", sha256.Sum224([]byte(u.GetEmail()+otp))) u.RecoveryToken = addFlowPrefixToToken(token, flowType) now := time.Now() - if err := mailer.RecoveryMail(u, otp, referrerURL); err != nil { + if err := mailer.RecoveryMail(u, otp, referrerURL, externalURL); err != nil { u.RecoveryToken = oldToken return errors.Wrap(err, "Error sending recovery email") } @@ -313,7 +315,7 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication") } -func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, otpLength int, flowType models.FlowType) error { +func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { var err error // since Magic Link is just a recovery with a different template and behaviour // around new users we will reuse the recovery db timer to prevent potential abuse @@ -329,7 +331,7 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile u.RecoveryToken = addFlowPrefixToToken(token, flowType) now := time.Now() - if err := mailer.MagicLinkMail(u, otp, referrerURL); err != nil { + if err := mailer.MagicLinkMail(u, otp, referrerURL, externalURL); err != nil { u.RecoveryToken = oldToken return errors.Wrap(err, "Error sending magic link email") } @@ -338,7 +340,7 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile } // sendEmailChange sends out an email change token to the new email. -func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfiguration, u *models.User, mailer mailer.Mailer, email string, referrerURL string, otpLength int, flowType models.FlowType) error { +func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfiguration, u *models.User, mailer mailer.Mailer, email, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { var err error if u.EmailChangeSentAt != nil && !u.EmailChangeSentAt.Add(config.SMTP.MaxFrequency).Before(time.Now()) { return MaxFrequencyLimitError @@ -366,7 +368,7 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu u.EmailChangeConfirmStatus = zeroConfirmation now := time.Now() - if err := mailer.EmailChangeMail(u, otpNew, otpCurrent, referrerURL); err != nil { + if err := mailer.EmailChangeMail(u, otpNew, otpCurrent, referrerURL, externalURL); err != nil { return err } diff --git a/internal/api/mail_test.go b/internal/api/mail_test.go index 23763f99c..28517ef96 100644 --- a/internal/api/mail_test.go +++ b/internal/api/mail_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/golang-jwt/jwt" @@ -39,6 +40,11 @@ func (ts *MailTestSuite) SetupTest() { models.TruncateAll(ts.API.db) ts.Config.Mailer.SecureEmailChangeEnabled = true + + // Create User + u, err := models.NewUser("12345678", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating new user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new user") } func (ts *MailTestSuite) TestGenerateLink() { @@ -108,11 +114,14 @@ func (ts *MailTestSuite) TestGenerateLink() { }, } + customDomainUrl, err := url.ParseRequestURI("https://example.gotrue.com") + require.NoError(ts.T(), err) + for _, c := range cases { ts.Run(c.Desc, func() { var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.Body)) - req := httptest.NewRequest(http.MethodPost, "/admin/generate_link", &buffer) + req := httptest.NewRequest(http.MethodPost, customDomainUrl.String()+"/admin/generate_link", &buffer) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) w := httptest.NewRecorder() @@ -131,6 +140,11 @@ func (ts *MailTestSuite) TestGenerateLink() { // check if hashed_token matches hash function of email and the raw otp require.Equal(ts.T(), data["hashed_token"], fmt.Sprintf("%x", sha256.Sum224([]byte(c.Body.Email+data["email_otp"].(string))))) + + // check if the host used in the email link matches the initial request host + u, err := url.ParseRequestURI(data["action_link"].(string)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), req.Host, u.Host) }) } } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 8300a19f4..d868fb483 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -3,7 +3,9 @@ package api import ( "context" "encoding/json" + "fmt" "net/http" + "net/url" "strings" "time" @@ -179,6 +181,32 @@ func isIgnoreCaptchaRoute(req *http.Request) bool { return false } +func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + config := a.config + + var u *url.URL + var err error + + baseUrl := config.API.ExternalURL + xForwardedHost := req.Header.Get("X-Forwarded-Host") + xForwardedProto := req.Header.Get("X-Forwarded-Proto") + if xForwardedHost != "" && xForwardedProto != "" { + baseUrl = fmt.Sprintf("%s://%s", xForwardedProto, xForwardedHost) + } else if req.URL.Scheme != "" && req.URL.Hostname() != "" { + baseUrl = fmt.Sprintf("%s://%s", req.URL.Scheme, req.URL.Hostname()) + } + if u, err = url.ParseRequestURI(baseUrl); err != nil { + // fallback to the default hostname + log := observability.GetLogEntry(req) + log.WithField("request_url", baseUrl).Warn(err) + if u, err = url.ParseRequestURI(config.API.ExternalURL); err != nil { + return ctx, err + } + } + return withExternalHost(ctx, u), nil +} + func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 0d1ad20e4..c3c157588 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "testing" jwt "github.com/golang-jwt/jwt" @@ -229,6 +230,35 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { } } +func (ts *MiddlewareTestSuite) TestIsValidExternalHost() { + cases := []struct { + desc string + requestURL string + expectedURL string + }{ + { + desc: "Valid custom external url", + requestURL: "https://example.custom.com", + expectedURL: "https://example.custom.com", + }, + } + + _, err := url.ParseRequestURI("https://example.custom.com") + require.NoError(ts.T(), err) + + for _, c := range cases { + ts.Run(c.desc, func() { + req := httptest.NewRequest(http.MethodPost, c.requestURL, nil) + w := httptest.NewRecorder() + ctx, err := ts.API.isValidExternalHost(w, req) + require.NoError(ts.T(), err) + + externalURL := getExternalHost(ctx) + require.Equal(ts.T(), c.expectedURL, externalURL.String()) + }) + } +} + func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { cases := []struct { desc string diff --git a/internal/api/recover.go b/internal/api/recover.go index d3276fe3f..661503ecd 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -77,7 +77,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { return terr } } - return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, flowType) + externalURL := getExternalHost(ctx) + return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType) }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { diff --git a/internal/api/resend.go b/internal/api/resend.go index 2403d348b..22ecdc1b8 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -115,13 +115,14 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { err = db.Transaction(func(tx *storage.Connection) error { mailer := a.Mailer(ctx) referrer := a.getReferrer(r) + externalURL := getExternalHost(ctx) switch params.Type { case signupVerification: if terr := models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", nil); terr != nil { return terr } // PKCE not implemented yet - return sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, models.ImplicitFlow) + return sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow) case smsVerification: if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { return terr @@ -132,7 +133,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { } return a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) case emailChangeVerification: - return a.sendEmailChange(tx, config, user, mailer, user.EmailChange, referrer, config.Mailer.OtpLength, models.ImplicitFlow) + return a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow) case phoneChangeVerification: smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { diff --git a/internal/api/signup.go b/internal/api/signup.go index c643efde9..5d123e5ea 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -189,7 +189,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return terr } } - if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, flowType); terr != nil { + externalURL := getExternalHost(ctx) + if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { now := time.Now() left := user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency).Sub(now) / time.Second diff --git a/internal/api/token.go b/internal/api/token.go index 5149970cb..5412854f3 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -525,7 +525,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R mailer := a.Mailer(ctx) referrer := a.getReferrer(r) - if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { + externalURL := getExternalHost(ctx) + if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { return internalServerError("Error sending confirmation mail").WithInternalError(terr) } return unauthorizedError("Error unverified email") diff --git a/internal/api/user.go b/internal/api/user.go index 100f50568..e446bbaca 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -188,7 +188,8 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { return terr } } - if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, config.Mailer.OtpLength, flowType); terr != nil { + externalURL := getExternalHost(ctx) + if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") } diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go index 22054fa33..b73be371c 100644 --- a/internal/mailer/mailer.go +++ b/internal/mailer/mailer.go @@ -15,14 +15,14 @@ import ( // Mailer defines the interface a mailer must implement. type Mailer interface { Send(user *models.User, subject, body string, data map[string]interface{}) error - InviteMail(user *models.User, otp, referrerURL string) error - ConfirmationMail(user *models.User, otp, referrerURL string) error - RecoveryMail(user *models.User, otp, referrerURL string) error - MagicLinkMail(user *models.User, otp, referrerURL string) error - EmailChangeMail(user *models.User, otpNew, otpCurrent, referrerURL string) error + InviteMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error + ConfirmationMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error + RecoveryMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error + MagicLinkMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error + EmailChangeMail(user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error ReauthenticateMail(user *models.User, otp string) error ValidateEmail(email string) error - GetEmailActionLink(user *models.User, actionType, referrerURL string) (string, error) + GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) } // NewMailer returns a new gotrue mailer @@ -64,23 +64,19 @@ func withDefault(value, defaultValue string) string { return value } -func getSiteURL(referrerURL, siteURL, filepath, fragment string) (string, error) { - baseURL := siteURL - if filepath == "" && referrerURL != "" { - baseURL = referrerURL - } - - site, err := url.Parse(baseURL) - if err != nil { - return "", err - } +func getPath(filepath string, params map[string]string) (*url.URL, error) { + path := &url.URL{} if filepath != "" { - path, err := url.Parse(filepath) - if err != nil { - return "", err + if p, err := url.Parse(filepath); err != nil { + return nil, err + } else { + path = p } - site = site.ResolveReference(path) } - site.RawQuery = fragment - return site.String(), nil + v := url.Values{} + for key, val := range params { + v.Add(key, val) + } + path.RawQuery = v.Encode() + return path, nil } diff --git a/internal/mailer/mailer_test.go b/internal/mailer/mailer_test.go index 2cd565a1a..0c37e9383 100644 --- a/internal/mailer/mailer_test.go +++ b/internal/mailer/mailer_test.go @@ -1,6 +1,7 @@ package mailer import ( + "net/url" "regexp" "testing" @@ -13,27 +14,57 @@ func enforceRelativeURL(url string) string { return urlRegexp.ReplaceAllString(url, "") } -func TestGetSiteURL(t *testing.T) { +func TestGetPath(t *testing.T) { cases := []struct { - ReferrerURL string - SiteURL string - Path string - Fragment string - Expected string + SiteURL string + Path string + Params map[string]string + Expected string }{ - {"", "https://test.example.com", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"}, - {"", "https://test.example.com/removedpath", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"}, - {"", "https://test.example.com/", "/trailingslash/", "", "https://test.example.com/trailingslash/"}, - {"", "https://test.example.com", "f", "fragment", "https://test.example.com/f?fragment"}, - {"https://test.example.com/admin", "https://test.example.com", "", "fragment", "https://test.example.com/admin?fragment"}, - {"https://test.example.com/admin", "https://test.example.com", "f", "fragment", "https://test.example.com/f?fragment"}, - {"", "https://test.example.com", "", "fragment", "https://test.example.com?fragment"}, + { + SiteURL: "https://test.example.com", + Path: "/templates/confirm.html", + Params: nil, + Expected: "https://test.example.com/templates/confirm.html", + }, + { + SiteURL: "https://test.example.com/removedpath", + Path: "/templates/confirm.html", + Params: nil, + Expected: "https://test.example.com/templates/confirm.html", + }, + { + SiteURL: "https://test.example.com/", + Path: "/trailingslash/", + Params: nil, + Expected: "https://test.example.com/trailingslash/", + }, + { + SiteURL: "https://test.example.com", + Path: "f", + Params: map[string]string{ + "key": "val", + }, + Expected: "https://test.example.com/f?key=val", + }, + { + SiteURL: "https://test.example.com", + Path: "", + Params: map[string]string{ + "key": "val", + }, + Expected: "https://test.example.com?key=val", + }, } for _, c := range cases { - act, err := getSiteURL(c.ReferrerURL, c.SiteURL, c.Path, c.Fragment) + u, err := url.ParseRequestURI(c.SiteURL) + assert.NoError(t, err, "error parsing URI request") + + path, err := getPath(c.Path, c.Params) + assert.NoError(t, err, c.Expected) - assert.Equal(t, c.Expected, act) + assert.Equal(t, c.Expected, u.ResolveReference(path).String()) } } diff --git a/internal/mailer/template.go b/internal/mailer/template.go index b9e696152..0fe1329fb 100644 --- a/internal/mailer/template.go +++ b/internal/mailer/template.go @@ -21,8 +21,7 @@ type TemplateMailer struct { Mailer MailClient } -func encodeRedirectParam(referrerURL string) string { - redirectParam := "" +func encodeRedirectURL(referrerURL string) string { if len(referrerURL) > 0 { if strings.ContainsAny(referrerURL, "&=#") { // if the string contains &, = or # it has not been URL @@ -30,11 +29,8 @@ func encodeRedirectParam(referrerURL string) string { // encoded by us otherwise, it should be taken as-is referrerURL = url.QueryEscape(referrerURL) } - - redirectParam = "&redirect_to=" + referrerURL } - - return redirectParam + return referrerURL } const defaultInviteMail = `

You have been invited

@@ -79,16 +75,20 @@ func (m TemplateMailer) ValidateEmail(email string) error { } // InviteMail sends a invite mail to a new user -func (m *TemplateMailer) InviteMail(user *models.User, otp, referrerURL string) error { - redirectParam := encodeRedirectParam(referrerURL) +func (m *TemplateMailer) InviteMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Invite, map[string]string{ + "token": user.ConfirmationToken, + "type": "invite", + "redirect_to": encodeRedirectURL(referrerURL), + }) - url, err := getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Invite, "token="+user.ConfirmationToken+"&type=invite"+redirectParam) if err != nil { return err } + data := map[string]interface{}{ "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, + "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, "TokenHash": user.ConfirmationToken, @@ -105,16 +105,19 @@ func (m *TemplateMailer) InviteMail(user *models.User, otp, referrerURL string) } // ConfirmationMail sends a signup confirmation mail to a new user -func (m *TemplateMailer) ConfirmationMail(user *models.User, otp, referrerURL string) error { - redirectParam := encodeRedirectParam(referrerURL) - fragment := "token=" + user.ConfirmationToken + "&type=signup" + redirectParam - url, err := getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Confirmation, fragment) +func (m *TemplateMailer) ConfirmationMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Confirmation, map[string]string{ + "token": user.ConfirmationToken, + "type": "signup", + "redirect_to": encodeRedirectURL(referrerURL), + }) if err != nil { return err } + data := map[string]interface{}{ "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, + "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, "TokenHash": user.ConfirmationToken, @@ -149,7 +152,7 @@ func (m *TemplateMailer) ReauthenticateMail(user *models.User, otp string) error } // EmailChangeMail sends an email change confirmation mail to a user -func (m *TemplateMailer) EmailChangeMail(user *models.User, otpNew, otpCurrent, referrerURL string) error { +func (m *TemplateMailer) EmailChangeMail(user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error { type Email struct { Address string Otp string @@ -178,15 +181,15 @@ func (m *TemplateMailer) EmailChangeMail(user *models.User, otpNew, otpCurrent, }) } - redirectParam := encodeRedirectParam(referrerURL) - errors := make(chan error) for _, email := range emails { - url, err := getSiteURL( - referrerURL, - m.Config.API.ExternalURL, + path, err := getPath( m.Config.Mailer.URLPaths.EmailChange, - "token="+email.TokenHash+"&type=email_change"+redirectParam, + map[string]string{ + "token": email.TokenHash, + "type": "email_change", + "redirect_to": encodeRedirectURL(referrerURL), + }, ) if err != nil { return err @@ -194,7 +197,7 @@ func (m *TemplateMailer) EmailChangeMail(user *models.User, otpNew, otpCurrent, go func(address, token, tokenHash, template string) { data := map[string]interface{}{ "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, + "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.GetEmail(), "NewEmail": user.EmailChange, "Token": token, @@ -222,15 +225,18 @@ func (m *TemplateMailer) EmailChangeMail(user *models.User, otpNew, otpCurrent, } // RecoveryMail sends a password recovery mail -func (m *TemplateMailer) RecoveryMail(user *models.User, otp, referrerURL string) error { - redirectParam := encodeRedirectParam(referrerURL) - url, err := getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, "token="+user.RecoveryToken+"&type=recovery"+redirectParam) +func (m *TemplateMailer) RecoveryMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, map[string]string{ + "token": user.RecoveryToken, + "type": "recovery", + "redirect_to": encodeRedirectURL(referrerURL), + }) if err != nil { return err } data := map[string]interface{}{ "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, + "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, "TokenHash": user.RecoveryToken, @@ -247,17 +253,19 @@ func (m *TemplateMailer) RecoveryMail(user *models.User, otp, referrerURL string } // MagicLinkMail sends a login link mail -func (m *TemplateMailer) MagicLinkMail(user *models.User, otp, referrerURL string) error { - redirectParam := encodeRedirectParam(referrerURL) - fragment := "token=" + user.RecoveryToken + "&type=magiclink" + redirectParam - url, err := getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, fragment) +func (m *TemplateMailer) MagicLinkMail(user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, map[string]string{ + "token": user.RecoveryToken, + "type": "magiclink", + "redirect_to": encodeRedirectURL(referrerURL), + }) if err != nil { return err } data := map[string]interface{}{ "SiteURL": m.Config.SiteURL, - "ConfirmationURL": url, + "ConfirmationURL": externalURL.ResolveReference(path).String(), "Email": user.Email, "Token": otp, "TokenHash": user.RecoveryToken, @@ -285,30 +293,53 @@ func (m TemplateMailer) Send(user *models.User, subject, body string, data map[s } // GetEmailActionLink returns a magiclink, recovery or invite link based on the actionType passed. -func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string) (string, error) { +func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) { var err error + var path *url.URL - redirectParam := encodeRedirectParam(referrerURL) - - var url string + referrerURL = encodeRedirectURL(referrerURL) switch actionType { case "magiclink": - url, err = getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, "token="+user.RecoveryToken+"&type=magiclink"+redirectParam) + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, map[string]string{ + "token": user.RecoveryToken, + "type": "magiclink", + "redirect_to": referrerURL, + }) case "recovery": - url, err = getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Recovery, "token="+user.RecoveryToken+"&type=recovery"+redirectParam) + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, map[string]string{ + "token": user.RecoveryToken, + "type": "recovery", + "redirect_to": referrerURL, + }) case "invite": - url, err = getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Invite, "token="+user.ConfirmationToken+"&type=invite"+redirectParam) + path, err = getPath(m.Config.Mailer.URLPaths.Invite, map[string]string{ + "token": user.ConfirmationToken, + "type": "invite", + "redirect_to": referrerURL, + }) case "signup": - url, err = getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.Confirmation, "token="+user.ConfirmationToken+"&type=signup"+redirectParam) + path, err = getPath(m.Config.Mailer.URLPaths.Confirmation, map[string]string{ + "token": user.ConfirmationToken, + "type": "signup", + "redirect_to": referrerURL, + }) case "email_change_current": - url, err = getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.EmailChange, "token="+user.EmailChangeTokenCurrent+"&type=email_change"+redirectParam) + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, map[string]string{ + "token": user.EmailChangeTokenCurrent, + "type": "email_change", + "redirect_to": referrerURL, + }) case "email_change_new": - url, err = getSiteURL(referrerURL, m.Config.API.ExternalURL, m.Config.Mailer.URLPaths.EmailChange, "token="+user.EmailChangeTokenNew+"&type=email_change"+redirectParam) + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, map[string]string{ + "token": user.EmailChangeTokenNew, + "type": "email_change", + "redirect_to": referrerURL, + }) default: return "", fmt.Errorf("invalid email action link type: %s", actionType) } if err != nil { return "", err } - return url, nil + return externalURL.ResolveReference(path).String(), nil }