diff --git a/internal/greenlight/auth.go b/internal/greenlight/auth.go index 4fff812..052f5d4 100644 --- a/internal/greenlight/auth.go +++ b/internal/greenlight/auth.go @@ -19,10 +19,10 @@ func NewAuthService(secret string) *AuthService { } } -func (a *AuthService) Create(ctx context.Context, id int64) (token string, err error) { +func (a *AuthService) CreateToken(ctx context.Context, userID int64) (token string, err error) { now := time.Now() claims := &jwt.RegisteredClaims{ - Subject: strconv.FormatInt(id, 10), + Subject: strconv.FormatInt(userID, 10), IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)), @@ -32,3 +32,21 @@ func (a *AuthService) Create(ctx context.Context, id int64) (token string, err e return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(a.secret)) } + +func (a *AuthService) ParseToken(tokenString string) (userID int64, err error) { + token, err := jwt.Parse( + tokenString, + func(t *jwt.Token) (interface{}, error) { + return []byte(a.secret), nil + }, + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithExpirationRequired(), + jwt.WithIssuer("github.com./denpeshkov/greenlight"), + jwt.WithAudience("github.com./denpeshkov/greenlight"), + ) + + if err != nil { + return 0, NewUnauthorizedError("Invalid or missing authentication token.") + } + return strconv.ParseInt(token.Claims.(*jwt.RegisteredClaims).Subject, 10, 64) +} diff --git a/internal/greenlight/context.go b/internal/greenlight/context.go new file mode 100644 index 0000000..8a2406f --- /dev/null +++ b/internal/greenlight/context.go @@ -0,0 +1,18 @@ +package greenlight + +import "context" + +type ctxKey string + +const ( + userIDCtxKey ctxKey = "userID" +) + +func NewContextWithUserID(ctx context.Context, userID int64) context.Context { + return context.WithValue(ctx, userIDCtxKey, userID) +} + +func UserIDFromContext(ctx context.Context) int64 { + userID, _ := ctx.Value(userIDCtxKey).(int64) + return userID +} diff --git a/internal/http/auth.go b/internal/http/auth.go index c3803fe..a7e3085 100644 --- a/internal/http/auth.go +++ b/internal/http/auth.go @@ -49,7 +49,7 @@ func (s *Server) handleCreateToken(w http.ResponseWriter, r *http.Request) { return } - token, err := s.authService.Create(r.Context(), u.ID) + token, err := s.authService.CreateToken(r.Context(), u.ID) if err != nil { s.Error(w, r, fmt.Errorf("%s: %w", op, err)) return diff --git a/internal/http/middleware.go b/internal/http/middleware.go index d6b669d..4796cf2 100644 --- a/internal/http/middleware.go +++ b/internal/http/middleware.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strings" "sync" "time" @@ -126,3 +127,33 @@ func (s *Server) rateLimit(h http.Handler) http.Handler { h.ServeHTTP(w, r) }) } + +func (s *Server) authenticate(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + op := "http.Server.authenticate" + + // This indicates to any caches that the response may vary based on the value of the Authorization header in the request. + w.Header().Add("Vary", "Authorization") + + authzHeader := r.Header.Get("Authorization") + + if authzHeader == "" { + s.Error(w, r, fmt.Errorf("%s: %w", op, greenlight.NewUnauthorizedError("You must be authenticated to access this resource."))) + return + } + + headerParts := strings.Split(authzHeader, " ") + if len(headerParts) != 2 || headerParts[0] != "Bearer" { + s.Error(w, r, fmt.Errorf("%s: %w", op, greenlight.NewUnauthorizedError("Invalid or missing authentication token."))) + return + } + + userId, err := s.authService.ParseToken(headerParts[1]) + if err != nil { + s.Error(w, r, fmt.Errorf("%s: %w", op, err)) + return + } + r = r.WithContext(greenlight.NewContextWithUserID(r.Context(), userId)) + h.ServeHTTP(w, r) + }) +} diff --git a/internal/http/movie.go b/internal/http/movie.go index eb32b3e..979648c 100644 --- a/internal/http/movie.go +++ b/internal/http/movie.go @@ -12,11 +12,11 @@ import ( ) func (s *Server) registerMovieHandlers() { - s.router.HandleFunc("GET /v1/movies/{id}", s.handleMovieGet) - s.router.HandleFunc("GET /v1/movies", s.handleMoviesGet) - s.router.HandleFunc("POST /v1/movies", s.handleMovieCreate) - s.router.HandleFunc("PATCH /v1/movies/{id}", s.handleMovieUpdate) - s.router.HandleFunc("DELETE /v1/movies/{id}", s.handleMovieDelete) + s.router.Handle("GET /v1/movies/{id}", s.authenticate(http.HandlerFunc(s.handleMovieGet))) + s.router.Handle("GET /v1/movies", s.authenticate(http.HandlerFunc(s.handleMoviesGet))) + s.router.Handle("POST /v1/movies", s.authenticate(http.HandlerFunc(s.handleMovieCreate))) + s.router.Handle("PATCH /v1/movies/{id}", s.authenticate(http.HandlerFunc(s.handleMovieUpdate))) + s.router.Handle("DELETE /v1/movies/{id}", s.authenticate(http.HandlerFunc(s.handleMovieDelete))) } // handleMovieGet handles requests to get a specified movie. diff --git a/internal/http/server.go b/internal/http/server.go index b0892be..06379e8 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -53,9 +53,7 @@ func NewServer(addr string, movieService greenlight.MovieService, userService gr func (s *Server) Open() error { op := "http.Server.Start" - handler := s.recoverPanic(s.rateLimit(s.notFound(s.methodNotAllowed(s.router)))) - - s.server.Handler = handler + s.server.Handler = s.recoverPanic(s.rateLimit(s.notFound(s.methodNotAllowed(s.router)))) s.server.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError) s.server.IdleTimeout = s.opts.idleTimeout