From 169cc4f563122be7c62c1586ebbed4374d78f09c Mon Sep 17 00:00:00 2001 From: Tyler Date: Fri, 19 Jun 2020 23:19:33 -0400 Subject: [PATCH] Proper locking, start to User-Agent population --- example/main.go | 2 +- lavalink.go | 57 ++++++++++++++++++++++++++++++++++++------------- node.go | 19 +++++++++++------ player.go | 3 ++- 4 files changed, 58 insertions(+), 23 deletions(-) diff --git a/example/main.go b/example/main.go index b7ecadf..2c0223a 100644 --- a/example/main.go +++ b/example/main.go @@ -55,7 +55,7 @@ func ready(s *discordgo.Session, event *discordgo.Ready) { log.Println("discordgo ready!") s.UpdateStatus(0, "gavalink") - lavalink = gavalink.NewLavalink("1", event.User.ID) + lavalink = gavalink.NewLavalink(1, event.User.ID) err := lavalink.AddNodes(gavalink.NodeConfig{ REST: "http://localhost:2333", diff --git a/lavalink.go b/lavalink.go index 8261eac..e9e2015 100644 --- a/lavalink.go +++ b/lavalink.go @@ -5,6 +5,7 @@ import ( "log" "net/http" "os" + "runtime" "sync" "time" ) @@ -18,11 +19,13 @@ func init() { // Lavalink manages a connection to Lavalink Nodes type Lavalink struct { - shards string + shards int userID string - nodes []*Node - players map[string]*Player + nodes []*Node + + players map[string]*Player + playersMu sync.RWMutex // Event handlers handlersMu sync.RWMutex @@ -41,11 +44,10 @@ var ( errVolumeOutOfRange = errors.New("Volume is out of range, must be within [0, 1000]") errInvalidVersion = errors.New("This library requires Lavalink >= 3") errUnknownPayload = errors.New("Lavalink sent an unknown payload") - errNilHandler = errors.New("You must provide an event handler. Use gavalink.DummyEventHandler if you wish to ignore events") ) // NewLavalink creates a new Lavalink manager -func NewLavalink(shards string, userID string) *Lavalink { +func NewLavalink(shards int, userID string) *Lavalink { return &Lavalink{ shards: shards, userID: userID, @@ -56,32 +58,48 @@ func NewLavalink(shards string, userID string) *Lavalink { } // AddNodes adds a node to the Lavalink manager +// This function calls all of the node connect methods at once. +// TODO perhaps add a pool/max at a time limit? func (l *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error { - nodes := make([]*Node, len(nodeConfigs)) - client := &http.Client{ Timeout: 60 * time.Second, } - for i, c := range nodeConfigs { + wg := &sync.WaitGroup{} + + errCh := make(chan error) + + for _, c := range nodeConfigs { n := &Node{ config: c, manager: l, client: client, } - err := n.open() + wg.Add(1) - if err != nil { - return err - } + go func(n *Node, wg *sync.WaitGroup, errCh chan error) { + defer wg.Done() - nodes[i] = n + err := n.open() + + if err != nil { + errCh <- err + return + } + + l.nodes = append(l.nodes, n) + }(n, wg, errCh) } - l.nodes = append(l.nodes, nodes...) + wg.Wait() - return nil + select { + case err := <-errCh: + return err + default: + return nil + } } // RemoveNode removes a node from the manager @@ -99,6 +117,7 @@ func (l *Lavalink) removeNode(node *Node) error { node.stop() + l.playersMu.RLock() for _, player := range l.players { if player.node == node { n, err := l.BestNode() @@ -110,6 +129,7 @@ func (l *Lavalink) removeNode(node *Node) error { player.ChangeNode(n) } } + l.playersMu.RUnlock() // temp var for easier reading n := l.nodes @@ -133,6 +153,9 @@ func (l *Lavalink) BestNode() (*Node, error) { // GetPlayer gets a player for a guild func (l *Lavalink) GetPlayer(guild string) (*Player, error) { + l.playersMu.RLock() + defer l.playersMu.RUnlock() + p, ok := l.players[guild] if !ok { @@ -150,3 +173,7 @@ func (l *Lavalink) AddCapability(key string, i interface{}) { l.capabilities[key] = i } + +func gavalinkUserAgent() string { + return "Gavalink (v1.0, " + runtime.Version() + ")" +} diff --git a/node.go b/node.go index 80cdc0a..b1e4929 100644 --- a/node.go +++ b/node.go @@ -67,11 +67,13 @@ type FrameStats struct { Deficit int `json:"deficit"` } +// Opens the connection to the Lavalink server func (node *Node) open() error { header := http.Header{} + header.Set("User-Agent", gavalinkUserAgent()) header.Set("Authorization", node.config.Password) - header.Set("Num-Shards", node.manager.shards) + header.Set("Num-Shards", strconv.Itoa(node.manager.shards)) header.Set("User-Id", node.manager.userID) if node.manager.capabilities != nil { @@ -160,6 +162,7 @@ func (node *Node) listen() { } } +// Handle an event from the node func (node *Node) onEvent(v *fastjson.Value, msg []byte) error { op := jsonStringValue(v, "op") @@ -193,13 +196,13 @@ func (node *Node) onEvent(v *fastjson.Value, msg []byte) error { case eventTrackStart: player.track = track - player.handle(eventTrackStart, &TrackStart{ + node.manager.handle(player, eventTrackStart, &TrackStart{ Track: track, }) case eventTrackEnd: player.track = "" - player.handle(eventTrackEnd, &TrackEnd{ + node.manager.handle(player, eventTrackEnd, &TrackEnd{ Track: track, Reason: jsonStringValue(v, "reason"), }) @@ -215,9 +218,9 @@ func (node *Node) onEvent(v *fastjson.Value, msg []byte) error { ex.Exception = exception } - player.handle(eventTrackException, ex) + node.manager.handle(player, eventTrackException, ex) case eventTrackStuck: - player.handle(eventTrackStuck, &TrackStuck{ + node.manager.handle(player, eventTrackStuck, &TrackStuck{ Track: track, Threshold: time.Duration(v.GetInt("thresholdMs")) * time.Millisecond, }) @@ -229,7 +232,7 @@ func (node *Node) onEvent(v *fastjson.Value, msg []byte) error { File: track, } - player.handle(eventVoiceProcessed, &VoiceProcessed{ + node.manager.handle(player, eventVoiceProcessed, &VoiceProcessed{ Data: data, Hotword: v.GetBool("hotword"), Override: v.GetBool("override"), @@ -268,6 +271,8 @@ func (node *Node) CreatePlayer(guildID string, sessionID string, event VoiceServ lastVoiceServerUpdate: event, } + node.manager.playersMu.Lock() + defer node.manager.playersMu.Unlock() node.manager.players[guildID] = player return player, nil @@ -291,6 +296,7 @@ func (node *Node) LoadTracks(query string) (*Tracks, error) { return nil, err } + req.Header.Set("User-Agent", gavalinkUserAgent()) req.Header.Set("Authorization", node.config.Password) resp, err := node.client.Do(req) @@ -308,6 +314,7 @@ func (node *Node) LoadTracks(query string) (*Tracks, error) { return tracks, nil } +// Write a JSON message via the node's websocket connection func (node *Node) writeMessage(v interface{}) error { return node.wsConn.WriteJSON(v) } diff --git a/player.go b/player.go index 7bad5f7..a32295a 100644 --- a/player.go +++ b/player.go @@ -2,7 +2,6 @@ package gavalink import ( "strconv" - "sync" ) // Player is a Lavalink player @@ -203,6 +202,8 @@ func (player *Player) Destroy() error { return err } + player.manager.playersMu.Lock() + defer player.manager.playersMu.Unlock() delete(player.manager.players, player.guildID) return nil }