diff --git a/revoke_handler.go b/revoke_handler.go index 97773e7a..bc071bd3 100644 --- a/revoke_handler.go +++ b/revoke_handler.go @@ -6,11 +6,11 @@ package oauth2 import ( "context" "encoding/json" + "errors" "fmt" "net/http" "authelia.com/provider/oauth2/internal/consts" - "github.com/pkg/errors" "authelia.com/provider/oauth2/internal/errorsx" ) @@ -84,35 +84,36 @@ func (f *Fosite) WriteRevocationResponse(ctx context.Context, rw http.ResponseWr rw.Header().Set("Cache-Control", "no-store") rw.Header().Set("Pragma", "no-cache") - if err == nil { + switch { + case err == nil: rw.WriteHeader(http.StatusOK) - return + case errors.Is(err, ErrInvalidRequest): + f.writeRevocationResponseError(ctx, rw, ErrInvalidRequest) + case errors.Is(err, ErrInvalidClient): + f.writeRevocationResponseError(ctx, rw, ErrInvalidClient) + case errors.Is(err, ErrInvalidGrant): + f.writeRevocationResponseError(ctx, rw, ErrInvalidGrant) + case errors.Is(err, ErrUnauthorizedClient): + f.writeRevocationResponseError(ctx, rw, ErrUnauthorizedClient) + case errors.Is(err, ErrUnsupportedGrantType): + f.writeRevocationResponseError(ctx, rw, ErrUnsupportedGrantType) + case errors.Is(err, ErrInvalidScope): + f.writeRevocationResponseError(ctx, rw, ErrInvalidScope) + default: + rw.WriteHeader(http.StatusInternalServerError) } +} - if errors.Is(err, ErrInvalidRequest) { - rw.Header().Set("Content-Type", "application/json;charset=UTF-8") +func (f *Fosite) writeRevocationResponseError(ctx context.Context, rw http.ResponseWriter, rfc *RFC6749Error) { + rw.Header().Set("Content-Type", "application/json; charset=utf-8") - js, err := json.Marshal(ErrInvalidRequest) - if err != nil { - http.Error(rw, fmt.Sprintf(`{"error": "%s"}`, err.Error()), http.StatusInternalServerError) - return - } - - rw.WriteHeader(ErrInvalidRequest.CodeField) - _, _ = rw.Write(js) - } else if errors.Is(err, ErrInvalidClient) { - rw.Header().Set("Content-Type", "application/json;charset=UTF-8") + js, err := json.Marshal(rfc) + if err != nil { + http.Error(rw, fmt.Sprintf(`{"error": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } - js, err := json.Marshal(ErrInvalidClient) - if err != nil { - http.Error(rw, fmt.Sprintf(`{"error": "%s"}`, err.Error()), http.StatusInternalServerError) - return - } + rw.WriteHeader(rfc.CodeField) - rw.WriteHeader(ErrInvalidClient.CodeField) - _, _ = rw.Write(js) - } else { - // 200 OK - rw.WriteHeader(http.StatusOK) - } + _, _ = rw.Write(js) }