-
Notifications
You must be signed in to change notification settings - Fork 0
/
session_store.go
113 lines (90 loc) · 2.7 KB
/
session_store.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package auth
import (
"errors"
"fmt"
"time"
"github.com/cristosal/orm"
"github.com/cristosal/orm/schema"
)
// SessionRepo is a postgres backed session store.
type SessionRepo struct {
	db orm.DB
}

// sessionRow is the in-database shape of a session: one row of the
// "sessions" table, with the Session itself held in the Data column.
type sessionRow struct {
	ID string
	// UserID is a pointer so a NULL user_id can presumably be represented
	// (e.g. a session not yet tied to a user) — confirm against the schema.
	UserID    *int64
	Data      Session
	CreatedAt time.Time
	UpdatedAt time.Time
	ExpiresAt time.Time
}
// TableName returns the name of the table that sessionRow values are read
// from; presumably consulted by orm when building queries for this type.
func (sessionRow) TableName() string {
	const table = "sessions"
	return table
}
// NewSessionRepo returns a postgres backed session store that runs all of
// its queries through db.
func NewSessionRepo(db orm.DB) *SessionRepo {
	repo := SessionRepo{db: db}
	return &repo
}
// Drop drops the sessions table. The statement has no IF EXISTS clause, so
// dropping a table that is already gone is reported as an error by the
// database.
func (s *SessionRepo) Drop() error {
	const stmt = "drop table sessions"
	return orm.Exec(s.db, stmt)
}
// Save upserts sess into the database.
//
// Every call increments sess.Counter before writing. A session with an
// existing ID has its data and user_id updated, and updated_at set to now();
// expires_at is only written when the row is first inserted.
//
// A session without an ID is treated as new: an ID is generated with
// GenerateToken(16) and a row is inserted. If that insert fails, sess.ID is
// reset to "" so that a later Save retries the insert. Otherwise the retry
// would issue an UPDATE against a row that was never created.
func (s *SessionRepo) Save(sess *Session) error {
	sess.Counter++

	if sess.ID != "" {
		return orm.Exec(s.db, "update sessions set updated_at = now(), data = $1, user_id = $2 where id = $3", sess, sess.UserID(), sess.ID)
	}

	sid, err := GenerateToken(16)
	if err != nil {
		return err
	}

	// The ID must be assigned before the insert because sess itself is
	// stored in the data column and should carry its own ID.
	sess.ID = sid
	if err := orm.Exec(s.db, "insert into sessions (id, user_id, data, expires_at) values ($1, $2, $3, $4)",
		sid, sess.UserID(), sess, sess.ExpiresAt); err != nil {
		// Roll back the in-memory ID so the session is not mistaken for a
		// persisted one.
		sess.ID = ""
		return err
	}
	return nil
}
// ByID fetches the session stored under sessionID.
//
// It returns ErrSessionNotFound when orm reports orm.ErrNotFound; any other
// query error is passed through unchanged.
func (s *SessionRepo) ByID(sessionID string) (*Session, error) {
	var row sessionRow
	err := orm.Get(s.db, &row, "where id = $1", sessionID)
	switch {
	case err == nil:
		return &row.Data, nil
	case errors.Is(err, orm.ErrNotFound):
		return nil, ErrSessionNotFound
	default:
		return nil, err
	}
}
// ByUserID returns all sessions belonging to the user with id uid.
// When the user has no sessions the result is an empty, non-nil slice.
func (s *SessionRepo) ByUserID(uid int64) ([]Session, error) {
	var rows []sessionRow
	// orm appends this fragment directly after its generated select (ByID
	// passes "where id = $1" to orm.Get), so the condition needs its own
	// "where" keyword; without it the query is invalid SQL.
	if err := orm.List(s.db, &rows, "where user_id = $1", uid); err != nil {
		return nil, err
	}
	sessions := make([]Session, 0, len(rows))
	for i := range rows {
		sessions = append(sessions, rows[i].Data)
	}
	return sessions, nil
}
// RemoveByID deletes the session row whose id equals id.
func (s *SessionRepo) RemoveByID(id string) error {
	const stmt = "delete from sessions where id = $1"
	return orm.Exec(s.db, stmt, id)
}
// RemoveByEmails deletes every session belonging to a user whose email
// appears in emails. An empty list is a no-op and returns nil.
func (s *SessionRepo) RemoveByEmails(emails []string) error {
	// With no emails, schema.ValueList presumably renders an empty
	// placeholder list, producing "IN ()", which postgres rejects as a
	// syntax error. Nothing would be deleted anyway, so return early.
	if len(emails) == 0 {
		return nil
	}

	// Placeholders are $1..$n; the emails are bound as arguments, never
	// interpolated into the SQL text.
	query := fmt.Sprintf(`DELETE FROM sessions WHERE user_id IN (SELECT id FROM users WHERE email IN (%s))`,
		schema.ValueList(len(emails), 1))

	args := make([]any, len(emails))
	for i, email := range emails {
		args[i] = email
	}
	return orm.Exec(s.db, query, args...)
}
// RemoveByUserID deletes every session row whose user_id equals uid.
func (s *SessionRepo) RemoveByUserID(uid int64) error {
	const stmt = "delete from sessions where user_id = $1"
	return orm.Exec(s.db, stmt, uid)
}
// RemoveExpired deletes every session whose expires_at lies before the
// database's current time (now()).
func (s *SessionRepo) RemoveExpired() error {
	const stmt = "delete from sessions where expires_at < now()"
	return orm.Exec(s.db, stmt)
}