From e35b203240ff39a295d59075859fa2fd985af4c0 Mon Sep 17 00:00:00 2001 From: Denis Peshkov Date: Sat, 17 Feb 2024 12:56:02 +0400 Subject: [PATCH] fix error handling --- internal/greenlight/user.go | 16 +++--- internal/http/auth.go | 37 ++++++------- internal/http/healtcheck.go | 13 ++--- internal/http/json.go | 10 ++-- internal/http/movie.go | 99 +++++++++++++++------------------- internal/http/server.go | 26 +++++---- internal/http/user.go | 24 ++++----- internal/multierr/error.go | 8 +++ internal/postgres/migration.go | 54 +++++++++---------- internal/postgres/movie.go | 59 ++++++++++---------- internal/postgres/postgres.go | 15 +++--- internal/postgres/user.go | 32 +++++------ 12 files changed, 192 insertions(+), 201 deletions(-) diff --git a/internal/greenlight/user.go b/internal/greenlight/user.go index 41e696d..6c3c430 100644 --- a/internal/greenlight/user.go +++ b/internal/greenlight/user.go @@ -3,10 +3,10 @@ package greenlight import ( "context" "errors" - "fmt" "net/mail" "unicode/utf8" + "github.com/denpeshkov/greenlight/internal/multierr" "golang.org/x/crypto/bcrypt" ) @@ -52,16 +52,16 @@ func (u *User) Valid() error { type Password []byte // NewPasswords generates a hashed password from the plaintext password. -func NewPassword(plaintext string) (Password, error) { - op := "greenlight.NewPassword" +func NewPassword(plaintext string) (_ Password, err error) { + defer multierr.Wrap(&err, "greenlight.NewPassword") if err := PasswordValid(plaintext); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } hash, err := bcrypt.GenerateFromPassword([]byte(plaintext), 12) if err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } return hash, nil } @@ -85,15 +85,15 @@ func PasswordValid(plaintext string) error { } // Matches tests whether the provided plaintext password matches the hashed password. -func (p *Password) Matches(plaintext string) (bool, error) { - op := "greenlight.password.Matches" +func (p *Password) Matches(plaintext string) (_ bool, err error) { + defer multierr.Wrap(&err, "greenlight.password.Matches") if err := bcrypt.CompareHashAndPassword(*p, []byte(plaintext)); err != nil { switch { case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword): return false, nil default: - return false, fmt.Errorf("%s: %w", op, err) + return false, err } } return true, nil diff --git a/internal/http/auth.go b/internal/http/auth.go index a7e3085..48eb257 100644 --- a/internal/http/auth.go +++ b/internal/http/auth.go @@ -2,57 +2,50 @@ package http import ( "errors" - "fmt" "net/http" "github.com/denpeshkov/greenlight/internal/greenlight" + "github.com/denpeshkov/greenlight/internal/multierr" ) func (s *Server) registerAuthHandlers() { - s.router.HandleFunc("POST /v1/auth/token", s.handleCreateToken) + s.router.HandleFunc("POST /v1/auth/token", s.handlerFunc(s.handleCreateToken)) } // handleCreateToken handles requests to create an authentication token. -func (s *Server) handleCreateToken(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleCreateToken" +func (s *Server) handleCreateToken(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleCreateToken") var req struct { Email string `json:"email"` Password string `json:"password"` } - if err := s.readRequest(w, r, &req); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + if err = s.readRequest(w, r, &req); err != nil { + return err } - - if err := greenlight.PasswordValid(req.Password); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + if err = greenlight.PasswordValid(req.Password); err != nil { + return err } u, err := s.userService.Get(r.Context(), req.Email) if err != nil { switch { case errors.Is(err, greenlight.ErrNotFound): - s.Error(w, r, greenlight.NewUnauthorizedError("Invalid credentials.")) + return greenlight.NewUnauthorizedError("Invalid credentials.") default: - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) + return err } - return } if match, err := u.Password.Matches(req.Password); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } else if !match { - s.Error(w, r, greenlight.NewUnauthorizedError("Invalid credentials.")) - return + return greenlight.NewUnauthorizedError("Invalid credentials.") } token, err := s.authService.CreateToken(r.Context(), u.ID) if err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } resp := struct { @@ -62,7 +55,7 @@ func (s *Server) handleCreateToken(w http.ResponseWriter, r *http.Request) { } if err := s.sendResponse(w, r, http.StatusCreated, resp, nil); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } diff --git a/internal/http/healtcheck.go b/internal/http/healtcheck.go index 8d0cbea..1ac7684 100644 --- a/internal/http/healtcheck.go +++ b/internal/http/healtcheck.go @@ -1,24 +1,25 @@ package http import ( - "fmt" "net/http" + + "github.com/denpeshkov/greenlight/internal/multierr" ) func (s *Server) registerHealthCheckHandlers() { - s.router.HandleFunc("GET /v1/healthcheck", s.handleHealthCheck) + s.router.HandleFunc("GET /v1/healthcheck", s.handlerFunc(s.handleHealthCheck)) } // handleHealthCheck handles requests to get application information (status). -func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleHealthCheck" +func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleHealthCheck") info := HealthInfo{"1.0", "UP"} w.Header().Set("Content-Type", "application/json") if err := s.sendResponse(w, r, http.StatusOK, info, nil); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } // Application information. diff --git a/internal/http/json.go b/internal/http/json.go index 78efc12..8e99b4d 100644 --- a/internal/http/json.go +++ b/internal/http/json.go @@ -2,22 +2,22 @@ package http import ( "encoding/json" - "fmt" "net/http" "net/textproto" "strings" "github.com/denpeshkov/greenlight/internal/greenlight" + "github.com/denpeshkov/greenlight/internal/multierr" ) // sendResponse sends a JSON response with a given status. // In case of an error, response (and status) is not send and error is returned. -func (s *Server) sendResponse(w http.ResponseWriter, r *http.Request, status int, resp any, headers http.Header) error { - op := "http.Server.sendResponse" +func (s *Server) sendResponse(w http.ResponseWriter, r *http.Request, status int, resp any, headers http.Header) (err error) { + defer multierr.Wrap(&err, "http.Server.sendResponse") js, err := json.Marshal(resp) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } for k, v := range headers { k = textproto.CanonicalMIMEHeaderKey(k) @@ -26,7 +26,7 @@ func (s *Server) sendResponse(w http.ResponseWriter, r *http.Request, status int w.WriteHeader(status) _, err = w.Write(js) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } diff --git a/internal/http/movie.go b/internal/http/movie.go index a992498..984089e 100644 --- a/internal/http/movie.go +++ b/internal/http/movie.go @@ -9,31 +9,30 @@ import ( "time" "github.com/denpeshkov/greenlight/internal/greenlight" + "github.com/denpeshkov/greenlight/internal/multierr" ) func (s *Server) registerMovieHandlers() { - s.router.Handle("GET /v1/movies/{id}", http.HandlerFunc(s.handleMovieGet)) - s.router.Handle("GET /v1/movies", http.HandlerFunc(s.handleMoviesGet)) - s.router.Handle("POST /v1/movies", http.HandlerFunc(s.handleMovieCreate)) - s.router.Handle("PATCH /v1/movies/{id}", http.HandlerFunc(s.handleMovieUpdate)) - s.router.Handle("DELETE /v1/movies/{id}", http.HandlerFunc(s.handleMovieDelete)) + s.router.Handle("GET /v1/movies/{id}", s.handlerFunc(s.handleMovieGet)) + s.router.Handle("GET /v1/movies", s.handlerFunc(s.handleMoviesGet)) + s.router.Handle("POST /v1/movies", s.handlerFunc(s.handleMovieCreate)) + s.router.Handle("PATCH /v1/movies/{id}", s.handlerFunc(s.handleMovieUpdate)) + s.router.Handle("DELETE /v1/movies/{id}", s.handlerFunc(s.handleMovieDelete)) } // handleMovieGet handles requests to get a specified movie. -func (s *Server) handleMovieGet(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleMovieGet" +func (s *Server) handleMovieGet(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleMovieGet") idRaw := r.PathValue("id") id, err := strconv.ParseInt(idRaw, 10, 64) if err != nil || id < 0 { - s.Error(w, r, greenlight.NewInvalidError(`Invalid "ID" parameter format: %s`, idRaw)) - return + return greenlight.NewInvalidError(`Invalid "ID" parameter format: %s`, idRaw) } m, err := s.movieService.Get(r.Context(), id) if err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } resp := struct { @@ -51,14 +50,14 @@ func (s *Server) handleMovieGet(w http.ResponseWriter, r *http.Request) { } if err := s.sendResponse(w, r, http.StatusOK, resp, nil); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } // handleMoviesGet handles requests to get movies based on provided filter parameters. -func (s *Server) handleMoviesGet(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleMoviesGet" +func (s *Server) handleMoviesGet(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleMoviesGet") filter := greenlight.MovieFilter{ Title: "", @@ -78,8 +77,7 @@ func (s *Server) handleMoviesGet(w http.ResponseWriter, r *http.Request) { pageRaw := vs.Get("page") page, err := strconv.Atoi(pageRaw) if err != nil { - s.Error(w, r, greenlight.NewInvalidError(`Invalid "page" parameter format: %s`, pageRaw)) - return + return greenlight.NewInvalidError(`Invalid "page" parameter format: %s`, pageRaw) } filter.Page = page } @@ -87,8 +85,7 @@ func (s *Server) handleMoviesGet(w http.ResponseWriter, r *http.Request) { pageSzRaw := vs.Get("page_size") pageSz, err := strconv.Atoi(pageSzRaw) if err != nil { - s.Error(w, r, greenlight.NewInvalidError(`Invalid "page_size" parameter format: %s`, pageSzRaw)) - return + return greenlight.NewInvalidError(`Invalid "page_size" parameter format: %s`, pageSzRaw) } filter.PageSize = pageSz } @@ -97,14 +94,12 @@ func (s *Server) handleMoviesGet(w http.ResponseWriter, r *http.Request) { } if err := filter.Valid(); err != nil { - s.Error(w, r, err) - return + return err } movies, err := s.movieService.GetAll(r.Context(), filter) if err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } type respEl struct { @@ -126,14 +121,14 @@ func (s *Server) handleMoviesGet(w http.ResponseWriter, r *http.Request) { } if err := s.sendResponse(w, r, http.StatusOK, resp, nil); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } // handleMovieCreate handles requests to create a movie. -func (s *Server) handleMovieCreate(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleMovieCreate" +func (s *Server) handleMovieCreate(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleMovieCreate") var req struct { Title string `json:"title"` @@ -142,8 +137,7 @@ func (s *Server) handleMovieCreate(w http.ResponseWriter, r *http.Request) { Genres []string `json:"genres"` } if err := s.readRequest(w, r, &req); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } m := &greenlight.Movie{ @@ -153,12 +147,10 @@ func (s *Server) handleMovieCreate(w http.ResponseWriter, r *http.Request) { Genres: req.Genres, } if err := m.Valid(); err != nil { - s.Error(w, r, err) - return + return err } if err := s.movieService.Create(r.Context(), m); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } resp := struct { @@ -178,26 +170,24 @@ func (s *Server) handleMovieCreate(w http.ResponseWriter, r *http.Request) { headers := make(http.Header) headers.Set("Location", fmt.Sprintf("/v1/movies/%d", resp.ID)) if err := s.sendResponse(w, r, http.StatusCreated, resp, headers); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } // handleMovieUpdate handles requests to update a specified movie. -func (s *Server) handleMovieUpdate(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleMovieUpdate" +func (s *Server) handleMovieUpdate(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleMovieUpdate") idRaw := r.PathValue("id") id, err := strconv.ParseInt(idRaw, 10, 64) if err != nil || id < 0 { - s.Error(w, r, greenlight.NewInvalidError("Invalid ID format: %s", idRaw)) - return + return greenlight.NewInvalidError("Invalid ID format: %s", idRaw) } m, err := s.movieService.Get(r.Context(), id) if err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } // use pointers to allow partial updates @@ -208,8 +198,7 @@ func (s *Server) handleMovieUpdate(w http.ResponseWriter, r *http.Request) { Genres []string `json:"genres"` } if err := s.readRequest(w, r, &req); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } if req.Title != nil { @@ -226,12 +215,10 @@ func (s *Server) handleMovieUpdate(w http.ResponseWriter, r *http.Request) { } if err := m.Valid(); err != nil { - s.Error(w, r, err) - return + return err } if err := s.movieService.Update(r.Context(), m); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } resp := struct { @@ -249,30 +236,28 @@ func (s *Server) handleMovieUpdate(w http.ResponseWriter, r *http.Request) { } if err := s.sendResponse(w, r, http.StatusOK, resp, nil); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } // handleMovieDelete handles requests to delete a specified movie. -func (s *Server) handleMovieDelete(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleMovieDelete" +func (s *Server) handleMovieDelete(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleMovieDelete") idRaw := r.PathValue("id") id, err := strconv.ParseInt(idRaw, 10, 64) if err != nil || id < 0 { - s.Error(w, r, greenlight.NewInvalidError("Invalid ID format: %s", idRaw)) - return + return greenlight.NewInvalidError("Invalid ID format: %s", idRaw) } if err := s.movieService.Delete(r.Context(), id); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } if err := s.sendResponse(w, r, http.StatusNoContent, nil, nil); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } // date represents a date in the format "YYYY-MM-DD". diff --git a/internal/http/server.go b/internal/http/server.go index 06379e8..5b1900f 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -3,12 +3,12 @@ package http import ( "context" "errors" - "fmt" "log/slog" "net/http" "os" "github.com/denpeshkov/greenlight/internal/greenlight" + "github.com/denpeshkov/greenlight/internal/multierr" ) // Server represents an HTTP server. @@ -50,8 +50,8 @@ func NewServer(addr string, movieService greenlight.MovieService, userService gr } // Open starts an HTTP server. -func (s *Server) Open() error { - op := "http.Server.Start" +func (s *Server) Open() (err error) { + defer multierr.Wrap(&err, "http.Server.Start") s.server.Handler = s.recoverPanic(s.rateLimit(s.notFound(s.methodNotAllowed(s.router)))) s.server.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError) @@ -60,27 +60,35 @@ func (s *Server) Open() error { s.server.ReadTimeout = s.opts.readTimeout s.server.WriteTimeout = s.opts.writeTimeout - err := s.server.ListenAndServe() + err = s.server.ListenAndServe() if !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } // Close gracefully shuts down the server. -func (s *Server) Close() error { - op := "http.Server.Close" +func (s *Server) Close() (err error) { + defer multierr.Wrap(&err, "http.Server.Close") ctx, cancel := context.WithTimeout(context.Background(), s.opts.shutdownTimeout) defer cancel() - err := s.server.Shutdown(ctx) + err = s.server.Shutdown(ctx) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } +func (s *Server) handlerFunc(h func(http.ResponseWriter, *http.Request) error) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := h(w, r); err != nil { + s.Error(w, r, err) + } + }) +} + func newLogger() *slog.Logger { opts := slog.HandlerOptions{Level: slog.LevelDebug} handler := slog.NewJSONHandler(os.Stderr, &opts) diff --git a/internal/http/user.go b/internal/http/user.go index abce5d7..6990a28 100644 --- a/internal/http/user.go +++ b/internal/http/user.go @@ -1,19 +1,19 @@ package http import ( - "fmt" "net/http" "github.com/denpeshkov/greenlight/internal/greenlight" + "github.com/denpeshkov/greenlight/internal/multierr" ) func (s *Server) registerUserHandlers() { - s.router.HandleFunc("POST /v1/users", s.handleUserCreate) + s.router.HandleFunc("POST /v1/users", s.handlerFunc(s.handleUserCreate)) } // handleUserCreate handles requests to create (register) a user. -func (s *Server) handleUserCreate(w http.ResponseWriter, r *http.Request) { - op := "http.Server.handleUserCreate" +func (s *Server) handleUserCreate(w http.ResponseWriter, r *http.Request) (err error) { + defer multierr.Wrap(&err, "http.Server.handleUserCreate") var req struct { Name string `json:"name"` @@ -21,8 +21,7 @@ func (s *Server) handleUserCreate(w http.ResponseWriter, r *http.Request) { Password string `json:"password"` } if err := s.readRequest(w, r, &req); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } u := &greenlight.User{ @@ -31,18 +30,15 @@ func (s *Server) handleUserCreate(w http.ResponseWriter, r *http.Request) { } pass, err := greenlight.NewPassword(req.Password) if err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } u.Password = pass if err := u.Valid(); err != nil { - s.Error(w, r, err) - return + return err } if err := s.userService.Create(r.Context(), u); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } resp := struct { @@ -56,7 +52,7 @@ func (s *Server) handleUserCreate(w http.ResponseWriter, r *http.Request) { } if err := s.sendResponse(w, r, http.StatusCreated, resp, nil); err != nil { - s.Error(w, r, fmt.Errorf("%s: %w", op, err)) - return + return err } + return nil } diff --git a/internal/multierr/error.go b/internal/multierr/error.go index c2df596..19a5404 100644 --- a/internal/multierr/error.go +++ b/internal/multierr/error.go @@ -3,6 +3,7 @@ package multierr import ( "bytes" + "fmt" ) // joinError represents an error that wraps the errors. @@ -60,3 +61,10 @@ func (e joinError) Unwrap() []error { } return e } + +// Wrap adds context to the error and allows unwrapping the result to recover the original error. +func Wrap(err *error, format string, args ...any) { + if *err != nil { + *err = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *err) + } +} diff --git a/internal/postgres/migration.go b/internal/postgres/migration.go index c31d710..237b1db 100644 --- a/internal/postgres/migration.go +++ b/internal/postgres/migration.go @@ -32,30 +32,30 @@ const ( // FIXME execute migrations in a DB transaction. // Migrate looks at the currently active migration version and applies all up or down migrations, depending on the provided argument. -func (db *DB) Migrate(t MigrationType) error { - op := "postgres.DB.Migrate" +func (db *DB) Migrate(t MigrationType) (err error) { + defer multierr.Wrap(&err, "postgres.DB.Migrate") db.logger.Debug("start database migration", "type", t.String()) // Creates a migrations table. // Table is based on: https://github.com/golang-migrate/migrate. - // version 0 means that no migrations are applied (all rollback-ed) + // Version 0 means that no migrations are applied (all rollback-ed). query := `CREATE TABLE IF NOT EXISTS migrations (version bigint PRIMARY KEY DEFAULT 0, dirty boolean NOT NULL DEFAULT FALSE)` if _, err := db.db.Exec(query); err != nil { - return fmt.Errorf("%s: create migration table: %w", op, err) + return fmt.Errorf("create migration table: %w", err) } version, dirty, err := db.migrationState() if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } if dirty { - return fmt.Errorf("%s: current DB has dirty version=%d", op, version) + return fmt.Errorf("current DB has dirty version=%d", version) } - names, err := read(t) + names, err := readMigrationFiles(t) if err != nil { - return fmt.Errorf("%s: get migration files: %w", op, err) + return fmt.Errorf("get migration files: %w", err) } slices.Sort(names) @@ -66,26 +66,26 @@ func (db *DB) Migrate(t MigrationType) error { } if err != nil { dirty = true - err = fmt.Errorf("%s: DB left with dirty version=%d: %w", op, version, err) + err = fmt.Errorf("DB left with dirty version=%d: %w", version, err) } // Update the migration table even in the presence of an error; version will be of the failed migration and dirty will be true. if _, err2 := db.db.Exec(`TRUNCATE TABLE migrations`); err2 != nil { - return fmt.Errorf("%s: update migration table: %w", op, multierr.Join(err2, err)) + return fmt.Errorf("update migration table: %w", multierr.Join(err2, err)) } if _, err2 := db.db.Exec(`INSERT INTO migrations (version, dirty) VALUES ($1, $2)`, version, dirty); err2 != nil { - return fmt.Errorf("%s: update migration table: %w", op, multierr.Join(err2, err)) + return fmt.Errorf("update migration table: %w", multierr.Join(err2, err)) } return err } -func (db *DB) migrateUp(names []string, version int) (int, error) { - op := "postgres.DB.migrateUp" +func (db *DB) migrateUp(names []string, version int) (_ int, err error) { + defer multierr.Wrap(&err, "postgres.DB.migrateUp") for version < len(names) { name := names[version] if err := db.migrateFile(name); err != nil { - return version, fmt.Errorf("%s: migration file %q: %w", op, name, err) + return version, fmt.Errorf("migration file %q: %w", name, err) } db.logger.Debug("database migration", "migration_type", "UP", "file", name) version++ @@ -93,13 +93,13 @@ func (db *DB) migrateUp(names []string, version int) (int, error) { return version, nil } -func (db *DB) migrateDown(names []string, version int) (int, error) { - op := "postgres.DB.migrateDown" +func (db *DB) migrateDown(names []string, version int) (_ int, err error) { + defer multierr.Wrap(&err, "postgres.DB.migrateDown") for version >= 1 { name := names[version-1] if err := db.migrateFile(name); err != nil { - return version, fmt.Errorf("%s: migration file %q: %w", op, name, err) + return version, fmt.Errorf("migration file %q: %w", name, err) } db.logger.Debug("database migration", "migration_type", "DOWN", "file", name) version-- @@ -108,14 +108,14 @@ func (db *DB) migrateDown(names []string, version int) (int, error) { } // migrate runs a single migration file. -func (db *DB) migrateFile(name string) error { - op := "postgres.DB.migrateFile" +func (db *DB) migrateFile(name string) (err error) { + defer multierr.Wrap(&err, "postgres.DB.migrateFile") // Read and execute migration file. if buf, err := migrationFS.ReadFile(name); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } else if _, err = db.db.Exec(string(buf)); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } @@ -123,22 +123,22 @@ func (db *DB) migrateFile(name string) error { // migrationState returns current migration version and dirty state. // In case of an error returned version is 0 and dirty is false. func (db *DB) migrationState() (version int, dirty bool, err error) { - op := "postgres.DB.migrationState" + defer multierr.Wrap(&err, "postgres.DB.migrationState") if err = db.db.QueryRow(`SELECT version, dirty FROM migrations`).Scan(&version, &dirty); err != nil && err != sql.ErrNoRows { - return 0, false, fmt.Errorf("%s: %w", op, err) + return 0, false, err } return version, dirty, nil } -// read returns the names of all up or down migration files. -func read(t MigrationType) ([]string, error) { - op := "postgres.read" +// readMigrationFiles returns the names of all up or down migration files. +func readMigrationFiles(t MigrationType) (_ []string, err error) { + defer multierr.Wrap(&err, "postgres.readMigrationFiles") pattern := fmt.Sprintf("migrations/*.%s.sql", t.String()) names, err := fs.Glob(migrationFS, pattern) if err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } return names, nil } diff --git a/internal/postgres/movie.go b/internal/postgres/movie.go index 3ba487a..cdb803d 100644 --- a/internal/postgres/movie.go +++ b/internal/postgres/movie.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/denpeshkov/greenlight/internal/greenlight" + "github.com/denpeshkov/greenlight/internal/multierr" "github.com/lib/pq" ) @@ -25,15 +26,15 @@ func NewMovieService(db *DB) *MovieService { } } -func (s *MovieService) Get(ctx context.Context, id int64) (*greenlight.Movie, error) { - op := "postgres.MovieService.Get" +func (s *MovieService) Get(ctx context.Context, id int64) (_ *greenlight.Movie, err error) { + defer multierr.Wrap(&err, "postgres.MovieService.Get(%d)", id) ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } defer tx.Rollback() @@ -45,25 +46,25 @@ func (s *MovieService) Get(ctx context.Context, id int64) (*greenlight.Movie, er case errors.Is(err, sql.ErrNoRows): return nil, greenlight.ErrNotFound default: - return nil, fmt.Errorf("%s: movie with id=%d: %w", op, id, err) + return nil, err } } if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } return &m, nil } -func (s *MovieService) GetAll(ctx context.Context, filter greenlight.MovieFilter) ([]*greenlight.Movie, error) { - op := "postgres.MovieService.GetAll" +func (s *MovieService) GetAll(ctx context.Context, filter greenlight.MovieFilter) (_ []*greenlight.Movie, err error) { + defer multierr.Wrap(&err, "postgres.MovieService.GetAll") ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } defer tx.Rollback() @@ -81,7 +82,7 @@ func (s *MovieService) GetAll(ctx context.Context, filter greenlight.MovieFilter LIMIT $3 OFFSET $4`, sortCol, sortDir) rs, err := tx.QueryContext(ctx, query, filter.Title, pq.Array(filter.Genres), filter.PageSize, (filter.Page-1)*filter.PageSize) if err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } defer rs.Close() @@ -89,29 +90,29 @@ func (s *MovieService) GetAll(ctx context.Context, filter greenlight.MovieFilter for rs.Next() { var m greenlight.Movie if err := rs.Scan(&m.ID, &m.Title, &m.ReleaseDate, &m.Runtime, pq.Array(&m.Genres), &m.Version); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } movies = append(movies, &m) } if err := rs.Err(); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } return movies, nil } -func (s *MovieService) Update(ctx context.Context, m *greenlight.Movie) error { - op := "postgres.MovieService.Update" +func (s *MovieService) Update(ctx context.Context, m *greenlight.Movie) (err error) { + defer multierr.Wrap(&err, "postgres.MovieService.Update(%d)", m.ID) ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } defer tx.Rollback() @@ -120,27 +121,27 @@ func (s *MovieService) Update(ctx context.Context, m *greenlight.Movie) error { if err := tx.QueryRowContext(ctx, query, args...).Scan(&m.Version); err != nil { switch { case errors.Is(err, sql.ErrNoRows): - return greenlight.NewConflictError("Conflicting change for the movie with id=%d", m.ID) + return greenlight.NewConflictError("Conflicting change") default: - return fmt.Errorf("%s: movie with id=%d: %w", op, m.ID, err) + return err } } if err := tx.Commit(); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } -func (s *MovieService) Delete(ctx context.Context, id int64) error { - op := "postgres.MovieService.Delete" +func (s *MovieService) Delete(ctx context.Context, id int64) (err error) { + defer multierr.Wrap(&err, "postgres.MovieService.Delete(%d)", id) ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } defer tx.Rollback() @@ -148,43 +149,43 @@ func (s *MovieService) Delete(ctx context.Context, id int64) error { args := []any{id} rs, err := tx.ExecContext(ctx, query, args...) if err != nil { - return fmt.Errorf("%s: movie with id=%d: %w", op, id, err) + return err } n, err := rs.RowsAffected() if err != nil { - return fmt.Errorf("%s: movie with id=%d: %w", op, id, err) + return err } if n == 0 { return greenlight.ErrNotFound } if err := tx.Commit(); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } -func (s *MovieService) Create(ctx context.Context, m *greenlight.Movie) error { - op := "postgres.MovieService.Create" +func (s *MovieService) Create(ctx context.Context, m *greenlight.Movie) (err error) { + defer multierr.Wrap(&err, "postgres.MovieService.Create") ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } defer tx.Rollback() query := `INSERT INTO movies (title, release_date, runtime, genres) VALUES ($1, $2, $3, $4) RETURNING id, version` args := []any{m.Title, m.ReleaseDate, m.Runtime, pq.Array(m.Genres)} if err := tx.QueryRowContext(ctx, query, args...).Scan(&m.ID, &m.Version); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } if err := tx.Commit(); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } diff --git a/internal/postgres/postgres.go b/internal/postgres/postgres.go index 09525c1..90df301 100644 --- a/internal/postgres/postgres.go +++ b/internal/postgres/postgres.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "errors" - "fmt" "log/slog" "os" @@ -38,14 +37,14 @@ func NewDB(dsn string, opts ...Option) *DB { // Open returns a new instance of an established database connection. func (db *DB) Open() (err error) { - op := "postgres.DB.Open" + defer multierr.Wrap(&err, "postgres.DB.Open") if db.DSN == "" { return errors.New("data source name (DSN) required") } if db.db, err = sql.Open("postgres", db.DSN); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } db.db.SetMaxOpenConns(db.opts.maxOpenConns) @@ -57,22 +56,22 @@ func (db *DB) Open() (err error) { defer cancel() if err = db.db.PingContext(ctx); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } if err = db.Migrate(UP); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } // Close gracefully shuts down the database. -func (db *DB) Close() error { - op := "postgres.DB.Close" +func (db *DB) Close() (err error) { + defer multierr.Wrap(&err, "postgres.DB.Close") err1 := db.Migrate(DOWN) err2 := db.db.Close() if err := multierr.Join(err2, err1); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } diff --git a/internal/postgres/user.go b/internal/postgres/user.go index 710b58e..9b39a52 100644 --- a/internal/postgres/user.go +++ b/internal/postgres/user.go @@ -4,9 +4,9 @@ import ( "context" "database/sql" "errors" - "fmt" "github.com/denpeshkov/greenlight/internal/greenlight" + "github.com/denpeshkov/greenlight/internal/multierr" ) // UserService represents a service for managing users backed by PostgreSQL. @@ -23,15 +23,15 @@ func NewUserService(db *DB) *UserService { } } -func (s *UserService) Get(ctx context.Context, email string) (*greenlight.User, error) { - op := "postgres.UserService.Get" +func (s *UserService) Get(ctx context.Context, email string) (_ *greenlight.User, err error) { + defer multierr.Wrap(&err, "postgres.UserService.Get") ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } defer tx.Rollback() @@ -43,25 +43,25 @@ func (s *UserService) Get(ctx context.Context, email string) (*greenlight.User, case errors.Is(err, sql.ErrNoRows): return nil, greenlight.ErrNotFound default: - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } } if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) + return nil, err } return &u, nil } -func (s *UserService) Create(ctx context.Context, u *greenlight.User) error { - op := "postgres.UserService.Create" +func (s *UserService) Create(ctx context.Context, u *greenlight.User) (err error) { + defer multierr.Wrap(&err, "postgres.UserService.Create") ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } defer tx.Rollback() @@ -72,25 +72,25 @@ func (s *UserService) Create(ctx context.Context, u *greenlight.User) error { case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: return greenlight.NewConflictError("A user with this email already exists.") default: - return fmt.Errorf("%s: %w", op, err) + return err } } if err := tx.Commit(); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil } -func (s *UserService) Update(ctx context.Context, u *greenlight.User) error { - op := "postgres.UserService.Update" +func (s *UserService) Update(ctx context.Context, u *greenlight.User) (err error) { + defer multierr.Wrap(&err, "postgres.UserService.Update") ctx, cancel := context.WithTimeout(ctx, s.db.opts.queryTimeout) defer cancel() tx, err := s.db.db.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } defer tx.Rollback() @@ -101,12 +101,12 @@ func (s *UserService) Update(ctx context.Context, u *greenlight.User) error { case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: return greenlight.NewConflictError("A user with this email already exists.") default: - return fmt.Errorf("%s: %w", op, err) + return err } } if err := tx.Commit(); err != nil { - return fmt.Errorf("%s: %w", op, err) + return err } return nil }