Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
admpub committed Nov 1, 2024
1 parent b1124b3 commit 9d42e2f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 18 deletions.
14 changes: 14 additions & 0 deletions handler/oauth2/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package oauth2

import "errors"

var (
ErrStateTokenMismatch = errors.New("state token mismatch")
ErrSessionDismatched = errors.New("could not find a matching session for this request")
ErrMustSelectProvider = errors.New("you must select a provider")

// Unpack Value
ErrIPAddressDismatched = errors.New(`IP address does not match`)
ErrUserAgentDismatched = errors.New(`UserAgent does not match`)
ErrDataExpired = errors.New(`data has expired`)
)
60 changes: 42 additions & 18 deletions handler/oauth2/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import (
"io"
"net/url"
"strings"
"time"

"github.com/admpub/goth"
"github.com/admpub/log"
"github.com/webx-top/com"
"github.com/webx-top/echo"
)
Expand All @@ -38,7 +40,7 @@ const StateSessionName = "EchoGothState"
const redirectURIQueryString = `redirect_uri=%2F`

var (
_ goth.Params = url.Values{}
_ goth.Params = (*url.Values)(nil)
EmptyUser = goth.User{}
)

Expand Down Expand Up @@ -87,7 +89,7 @@ var SetState = func(ctx echo.Context) (string, error) {
nonceBytes := make([]byte, 64)
_, err := io.ReadFull(rand.Reader, nonceBytes)
if err != nil {
err = errors.New("gothic: source of randomness unavailable: " + err.Error())
err = fmt.Errorf("gothic: source of randomness unavailable: %w", err)
return state, err
}
return base64.URLEncoding.EncodeToString(nonceBytes), nil
Expand All @@ -104,8 +106,6 @@ var GetState = func(ctx echo.Context) string {
return state
}

var ErrStateTokenMismatch = errors.New("state token mismatch")

/*
GetAuthURL starts the authentication process with the requested provided.
It will return a URL that should be used to send users to.
Expand Down Expand Up @@ -190,7 +190,7 @@ func fetchUser(ctx echo.Context) (goth.User, error) {

sv, ok := ctx.Session().Get(SessionName).(string)
if !ok || len(sv) == 0 {
return EmptyUser, errors.New("could not find a matching session for this request")
return EmptyUser, ErrSessionDismatched
}

defer func() {
Expand Down Expand Up @@ -246,7 +246,7 @@ var CompleteUserAuth = func(ctx echo.Context) (goth.User, error) {
return EmptyUser, fmt.Errorf(providerName+`: %w`, err)
}
if len(sv) == 0 {
return EmptyUser, errors.New("could not find a matching session for this request")
return EmptyUser, ErrSessionDismatched
}

defer func() {
Expand Down Expand Up @@ -312,23 +312,49 @@ var GetProviderName = getProviderName

func getProviderName(ctx echo.Context) (string, error) {
provider := ctx.Param("provider")
if len(provider) == 0 {
provider = ctx.Query("provider")
} else {
if len(provider) > 0 {
return provider, nil
}
provider = ctx.Query("provider")
if len(provider) == 0 {
return provider, errors.New("you must select a provider")
return provider, ErrMustSelectProvider
}
return provider, nil
}

func PackValue(ip, ua, value string) string {
return ip + `|` + com.Md5(ua) + `@` + com.String(time.Now().Unix()) + `|` + value
}

func UnpackValue(ip, ua, value string, maxAge time.Duration) (string, error) {
parts := strings.SplitN(value, `|`, 3)
if len(parts) != 3 {
return "", nil
}
if parts[0] != ip {
return "", fmt.Errorf(`%w: %q != %q`, ErrIPAddressDismatched, parts[0], ip)
}
parts2 := strings.SplitN(parts[1], `@`, 2)
if len(parts2) == 2 {
if ts := com.Int64(parts2[1]); ts > 0 {
ti := time.Unix(ts, 0)
if ti.Before(time.Now().Add(-maxAge)) {
return "", fmt.Errorf(`%w: %s (maxAge: %s)`, ErrDataExpired, ti.Format(time.DateTime), maxAge.String())
}
}
}
if uaMd5 := com.Md5(ua); parts2[0] != uaMd5 {
return "", fmt.Errorf(`%w: %q != %q`, ErrUserAgentDismatched, parts2[0], uaMd5)
}
return parts[2], nil
}

func EncryptValue(ctx echo.Context, value string) (string, error) {
if len(value) == 0 {
return value, nil
}
var err error
value = ctx.RealIP() + `|` + com.Md5(ctx.Request().UserAgent()) + `|` + value
value = PackValue(ctx.RealIP(), ctx.Request().UserAgent(), value)
value, err = CompressValue(value)
if err != nil {
return "", err
Expand Down Expand Up @@ -366,14 +392,12 @@ func DecryptValue(ctx echo.Context, value string) (string, error) {
if err != nil {
return "", err
}
parts := strings.SplitN(value, `|`, 3)
if len(parts) != 3 {
return "", nil
}
if parts[0] != ctx.RealIP() || parts[1] != com.Md5(ctx.Request().UserAgent()) {
return "", nil
var unpackErr error
value, unpackErr = UnpackValue(ctx.RealIP(), ctx.Request().UserAgent(), value, time.Hour)
if unpackErr != nil {
log.Warn(unpackErr.Error())
}
return parts[2], nil
return value, err
}

func UncompressValue(value string) (string, error) {
Expand Down
9 changes: 9 additions & 0 deletions handler/oauth2/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oauth2
import (
"net/url"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/webx-top/com"
Expand Down Expand Up @@ -34,3 +35,11 @@ func TestCompressValue(t *testing.T) {
require.Error(t, err)
t.Log(v)
}

func TestPackValue(t *testing.T) {
r := PackValue(`127.0.0.1`, `echo/1.1`, `test`)
t.Log(r)
r, err := UnpackValue(`127.0.0.1`, `echo/1.1`, r, time.Minute)
require.NoError(t, err)
require.Equal(t, `test`, r)
}

0 comments on commit 9d42e2f

Please sign in to comment.