Skip to content

Commit

Permalink
implement SessionManager
Browse files Browse the repository at this point in the history
from #26
fix static check lint issues
  • Loading branch information
switchupcb committed Jul 7, 2023
1 parent da09a8a commit 5e0e580
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 39 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ bot := &disgo.Client{
Authorization: &disgo.Authorization{ ... },
Config: disgo.DefaultConfig(),
Handlers: new(disgo.Handlers),
Sessions: new(disgo.SessionManager)
Sessions: disgo.NewSessionManager()
}
```

Expand Down
27 changes: 24 additions & 3 deletions wrapper/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ func (e ErrorRequest) Error() string {

// Status Code Error Messages.
const (
errStatusCodeKnown = "Status Code %d: %v"
errStatusCodeUnknown = "Status Code %d: Unknown status code error from Discord"
errStatusCodeKnown = "status code %d: %v"
errStatusCodeUnknown = "status code %d: unknown status code error from Discord"
)

// StatusCodeError handles a Discord API HTTP Status Code and returns the relevant error message.
Expand Down Expand Up @@ -124,6 +124,26 @@ func (e ErrorEvent) Error() string {
e.ClientID, e.Event, e.Action, e.Err).Error()
}

// Discord Gateway Error Messages
const (
errNoSessionManager = `The client must contain a non-nil SessionManager to connect to the Discord Gateway.
Set the *Client.SessionManager using one of the following methods.
--- 1
bot := &disgo.Client{
...
Sessions: disgo.NewSessionManager()
}
--- 2
bot.Sessions = disgo.NewSessionManager()
`
)

// ErrorSession represents a WebSocket Session error that occurs during an active session.
type ErrorSession struct {
// Err represents the error that occurred.
Expand Down Expand Up @@ -158,5 +178,6 @@ func (e ErrorDisconnect) Error() string {
return fmt.Errorf("error disconnecting from %q\n"+
"\tDisconnect(): %v\n"+
"\treason: %w\n",
e.Connection, e.Err, e.Action).Error()
e.Connection, e.Err, e.Action,
).Error() //lint:ignore ST1005 readability
}
2 changes: 2 additions & 0 deletions wrapper/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ func putSession(s *Session) {
s.ID = ""
s.Seq = 0
s.Endpoint = ""
s.Shard = nil
s.Context = nil
s.Conn = nil
s.heartbeat = nil
s.manager = nil
s.client_manager = nil

spool.Put(s)
}
Expand Down
2 changes: 1 addition & 1 deletion wrapper/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (r *RateLimit) GetBucket(routeid string, resourceid string) *Bucket {
r.SetBucketID(requestid, requestid)

// DefaultBucket (Per-Route) = RateLimit.DefaultBucket
if "" == resourceid {
if resourceid == "" {
if r.DefaultBucket == nil {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion wrapper/request_form.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func createFormFile(m *multipart.Writer, name, filename, contentType string) (io
h.Set("Content-Disposition",
fmt.Sprintf(`form-data; name="%s"; filename="%s"`, name, quoteEscaper.Replace(filename)))

if "" == contentType {
if contentType == "" {
contentType = contentTypeOctetStreamString
}

Expand Down
27 changes: 26 additions & 1 deletion wrapper/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Session struct {
// Endpoint represents the endpoint that is used to reconnect to the Gateway.
Endpoint string

// Shard represents the [shard_id, num_shards] for this session.
// Shard represents the [shard_id, num_shards] for the Session.
//
// https://discord.com/developers/docs/topics/gateway#sharding
Shard *[2]int
Expand All @@ -52,6 +52,9 @@ type Session struct {
// manager represents a manager of a Session's goroutines.
manager *manager

// client_manager represents the *Client Session Manager of the Session.
client_manager *SessionManager

// RWMutex is used to protect the Session's variables from data races
// by providing transactional functionality.
sync.RWMutex
Expand Down Expand Up @@ -88,6 +91,12 @@ func (s *Session) Connect(bot *Client) error {

// connect connects a session to a WebSocket Connection.
func (s *Session) connect(bot *Client) error {
if bot.Sessions == nil {
return fmt.Errorf(errNoSessionManager) //lint:ignore ST1005 format help message.
}

s.client_manager = bot.Sessions

if s.isConnected() {
return fmt.Errorf("session %q is already connected", s.ID)
}
Expand Down Expand Up @@ -295,11 +304,15 @@ func (s *Session) initial(bot *Client, attempt int) error {

LogSession(Logger.Info(), ready.SessionID).Msg("received Ready event")

// Configure the session.
s.ID = ready.SessionID
atomic.StoreInt64(&s.Seq, 0)
s.Endpoint = ready.ResumeGatewayURL
bot.ApplicationID = ready.Application.ID

// Store the session in the session manager.
s.client_manager.Gateway.Store(s.ID, s)

if bot.Config.Gateway.ShardManager != nil {
bot.Config.Gateway.ShardManager.Ready(bot, s, ready)
}
Expand All @@ -313,6 +326,9 @@ func (s *Session) initial(bot *Client, attempt int) error {
case *payload.EventName == FlagGatewayEventNameResumed:
LogSession(Logger.Info(), s.ID).Msg("received Resumed event")

// Store the session in the session manager.
s.client_manager.Gateway.Store(s.ID, s)

for _, handler := range bot.Handlers.Resumed {
go handler(&Resumed{})
}
Expand All @@ -332,6 +348,9 @@ func (s *Session) initial(bot *Client, attempt int) error {
if replayed.Op == FlagGatewayOpcodeDispatch && *replayed.EventName == FlagGatewayEventNameResumed {
LogSession(Logger.Info(), s.ID).Msg("received Resumed event")

// Store the session in the session manager.
s.client_manager.Gateway.Store(s.ID, s)

for _, handler := range bot.Handlers.Resumed {
go handler(&Resumed{})
}
Expand All @@ -346,6 +365,9 @@ func (s *Session) initial(bot *Client, attempt int) error {
// When the maximum concurrency limit has been reached while connecting, or when
// the session does NOT reconnect in time, the Discord Gateway send an Opcode 9 Invalid Session.
case FlagGatewayOpcodeInvalidSession:
// Remove the session from the session manager.
s.client_manager.Gateway.Store(s.ID, nil)

if attempt < 1 {
// wait for Discord to close the session, then complete a fresh connect.
<-time.NewTimer(invalidSessionWaitTime).C
Expand Down Expand Up @@ -410,6 +432,9 @@ func (s *Session) disconnect(code int) error {
// cancel the context to kill the goroutines of the Session.
defer s.manager.cancel()

// Remove the session from the session manager.
s.client_manager.Gateway.Store(s.ID, nil)

if err := s.Conn.Close(websocket.StatusCode(code), ""); err != nil {
return fmt.Errorf("%w", err)
}
Expand Down
32 changes: 0 additions & 32 deletions wrapper/session_bot_manager.go

This file was deleted.

22 changes: 22 additions & 0 deletions wrapper/session_client_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package wrapper

import "sync"

// SessionManager manages sessions.
type SessionManager struct {
// Gateway represents a map of Discord Gateway (TCP WebSocket Connections) session IDs to Sessions.
// map[ID]Session (map[string]*Session)
Gateway *sync.Map

// Voice represents a map of Discord Voice (UDP WebSocket Connection) session IDs to Sessions.
// map[ID]Session (map[string]*Session)
Voice *sync.Map
}

// NewSessionManager creates a new SessionManager.
func NewSessionManager() *SessionManager {
return &SessionManager{
Gateway: new(sync.Map),
Voice: new(sync.Map),
}
}
3 changes: 3 additions & 0 deletions wrapper/session_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ func (s *Session) onPayload(bot *Client, payload GatewayPayload) error {

// in the context of onPayload, an Invalid Session occurs when an active session is invalidated.
case FlagGatewayOpcodeInvalidSession:
// Remove the session from the session manager.
s.client_manager.Gateway.Store(s.ID, nil)

// wait for Discord to close the session, then complete a fresh connect.
<-time.NewTimer(invalidSessionWaitTime).C

Expand Down

0 comments on commit 5e0e580

Please sign in to comment.