Proper locking, start to User-Agent population

This commit is contained in:
Tyler 2020-06-19 23:19:33 -04:00
parent c5910e7f5d
commit 169cc4f563
4 changed files with 58 additions and 23 deletions

View File

@ -55,7 +55,7 @@ func ready(s *discordgo.Session, event *discordgo.Ready) {
log.Println("discordgo ready!") log.Println("discordgo ready!")
s.UpdateStatus(0, "gavalink") s.UpdateStatus(0, "gavalink")
lavalink = gavalink.NewLavalink("1", event.User.ID) lavalink = gavalink.NewLavalink(1, event.User.ID)
err := lavalink.AddNodes(gavalink.NodeConfig{ err := lavalink.AddNodes(gavalink.NodeConfig{
REST: "http://localhost:2333", REST: "http://localhost:2333",

View File

@ -5,6 +5,7 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"runtime"
"sync" "sync"
"time" "time"
) )
@ -18,11 +19,13 @@ func init() {
// Lavalink manages a connection to Lavalink Nodes // Lavalink manages a connection to Lavalink Nodes
type Lavalink struct { type Lavalink struct {
shards string shards int
userID string userID string
nodes []*Node nodes []*Node
players map[string]*Player
players map[string]*Player
playersMu sync.RWMutex
// Event handlers // Event handlers
handlersMu sync.RWMutex handlersMu sync.RWMutex
@ -41,11 +44,10 @@ var (
errVolumeOutOfRange = errors.New("Volume is out of range, must be within [0, 1000]") errVolumeOutOfRange = errors.New("Volume is out of range, must be within [0, 1000]")
errInvalidVersion = errors.New("This library requires Lavalink >= 3") errInvalidVersion = errors.New("This library requires Lavalink >= 3")
errUnknownPayload = errors.New("Lavalink sent an unknown payload") errUnknownPayload = errors.New("Lavalink sent an unknown payload")
errNilHandler = errors.New("You must provide an event handler. Use gavalink.DummyEventHandler if you wish to ignore events")
) )
// NewLavalink creates a new Lavalink manager // NewLavalink creates a new Lavalink manager
func NewLavalink(shards string, userID string) *Lavalink { func NewLavalink(shards int, userID string) *Lavalink {
return &Lavalink{ return &Lavalink{
shards: shards, shards: shards,
userID: userID, userID: userID,
@ -56,32 +58,48 @@ func NewLavalink(shards string, userID string) *Lavalink {
} }
// AddNodes adds a node to the Lavalink manager // AddNodes adds a node to the Lavalink manager
// This function calls all of the node connect methods at once.
// TODO perhaps add a pool/max at a time limit?
func (l *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error { func (l *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error {
nodes := make([]*Node, len(nodeConfigs))
client := &http.Client{ client := &http.Client{
Timeout: 60 * time.Second, Timeout: 60 * time.Second,
} }
for i, c := range nodeConfigs { wg := &sync.WaitGroup{}
errCh := make(chan error)
for _, c := range nodeConfigs {
n := &Node{ n := &Node{
config: c, config: c,
manager: l, manager: l,
client: client, client: client,
} }
err := n.open() wg.Add(1)
if err != nil { go func(n *Node, wg *sync.WaitGroup, errCh chan error) {
return err defer wg.Done()
}
nodes[i] = n err := n.open()
if err != nil {
errCh <- err
return
}
l.nodes = append(l.nodes, n)
}(n, wg, errCh)
} }
l.nodes = append(l.nodes, nodes...) wg.Wait()
return nil select {
case err := <-errCh:
return err
default:
return nil
}
} }
// RemoveNode removes a node from the manager // RemoveNode removes a node from the manager
@ -99,6 +117,7 @@ func (l *Lavalink) removeNode(node *Node) error {
node.stop() node.stop()
l.playersMu.RLock()
for _, player := range l.players { for _, player := range l.players {
if player.node == node { if player.node == node {
n, err := l.BestNode() n, err := l.BestNode()
@ -110,6 +129,7 @@ func (l *Lavalink) removeNode(node *Node) error {
player.ChangeNode(n) player.ChangeNode(n)
} }
} }
l.playersMu.RUnlock()
// temp var for easier reading // temp var for easier reading
n := l.nodes n := l.nodes
@ -133,6 +153,9 @@ func (l *Lavalink) BestNode() (*Node, error) {
// GetPlayer gets a player for a guild // GetPlayer gets a player for a guild
func (l *Lavalink) GetPlayer(guild string) (*Player, error) { func (l *Lavalink) GetPlayer(guild string) (*Player, error) {
l.playersMu.RLock()
defer l.playersMu.RUnlock()
p, ok := l.players[guild] p, ok := l.players[guild]
if !ok { if !ok {
@ -150,3 +173,7 @@ func (l *Lavalink) AddCapability(key string, i interface{}) {
l.capabilities[key] = i l.capabilities[key] = i
} }
func gavalinkUserAgent() string {
return "Gavalink (v1.0, " + runtime.Version() + ")"
}

19
node.go
View File

@ -67,11 +67,13 @@ type FrameStats struct {
Deficit int `json:"deficit"` Deficit int `json:"deficit"`
} }
// Opens the connection to the Lavalink server
func (node *Node) open() error { func (node *Node) open() error {
header := http.Header{} header := http.Header{}
header.Set("User-Agent", gavalinkUserAgent())
header.Set("Authorization", node.config.Password) header.Set("Authorization", node.config.Password)
header.Set("Num-Shards", node.manager.shards) header.Set("Num-Shards", strconv.Itoa(node.manager.shards))
header.Set("User-Id", node.manager.userID) header.Set("User-Id", node.manager.userID)
if node.manager.capabilities != nil { if node.manager.capabilities != nil {
@ -160,6 +162,7 @@ func (node *Node) listen() {
} }
} }
// Handle an event from the node
func (node *Node) onEvent(v *fastjson.Value, msg []byte) error { func (node *Node) onEvent(v *fastjson.Value, msg []byte) error {
op := jsonStringValue(v, "op") op := jsonStringValue(v, "op")
@ -193,13 +196,13 @@ func (node *Node) onEvent(v *fastjson.Value, msg []byte) error {
case eventTrackStart: case eventTrackStart:
player.track = track player.track = track
player.handle(eventTrackStart, &TrackStart{ node.manager.handle(player, eventTrackStart, &TrackStart{
Track: track, Track: track,
}) })
case eventTrackEnd: case eventTrackEnd:
player.track = "" player.track = ""
player.handle(eventTrackEnd, &TrackEnd{ node.manager.handle(player, eventTrackEnd, &TrackEnd{
Track: track, Track: track,
Reason: jsonStringValue(v, "reason"), Reason: jsonStringValue(v, "reason"),
}) })
@ -215,9 +218,9 @@ func (node *Node) onEvent(v *fastjson.Value, msg []byte) error {
ex.Exception = exception ex.Exception = exception
} }
player.handle(eventTrackException, ex) node.manager.handle(player, eventTrackException, ex)
case eventTrackStuck: case eventTrackStuck:
player.handle(eventTrackStuck, &TrackStuck{ node.manager.handle(player, eventTrackStuck, &TrackStuck{
Track: track, Track: track,
Threshold: time.Duration(v.GetInt("thresholdMs")) * time.Millisecond, Threshold: time.Duration(v.GetInt("thresholdMs")) * time.Millisecond,
}) })
@ -229,7 +232,7 @@ func (node *Node) onEvent(v *fastjson.Value, msg []byte) error {
File: track, File: track,
} }
player.handle(eventVoiceProcessed, &VoiceProcessed{ node.manager.handle(player, eventVoiceProcessed, &VoiceProcessed{
Data: data, Data: data,
Hotword: v.GetBool("hotword"), Hotword: v.GetBool("hotword"),
Override: v.GetBool("override"), Override: v.GetBool("override"),
@ -268,6 +271,8 @@ func (node *Node) CreatePlayer(guildID string, sessionID string, event VoiceServ
lastVoiceServerUpdate: event, lastVoiceServerUpdate: event,
} }
node.manager.playersMu.Lock()
defer node.manager.playersMu.Unlock()
node.manager.players[guildID] = player node.manager.players[guildID] = player
return player, nil return player, nil
@ -291,6 +296,7 @@ func (node *Node) LoadTracks(query string) (*Tracks, error) {
return nil, err return nil, err
} }
req.Header.Set("User-Agent", gavalinkUserAgent())
req.Header.Set("Authorization", node.config.Password) req.Header.Set("Authorization", node.config.Password)
resp, err := node.client.Do(req) resp, err := node.client.Do(req)
@ -308,6 +314,7 @@ func (node *Node) LoadTracks(query string) (*Tracks, error) {
return tracks, nil return tracks, nil
} }
// Write a JSON message via the node's websocket connection
func (node *Node) writeMessage(v interface{}) error { func (node *Node) writeMessage(v interface{}) error {
return node.wsConn.WriteJSON(v) return node.wsConn.WriteJSON(v)
} }

View File

@ -2,7 +2,6 @@ package gavalink
import ( import (
"strconv" "strconv"
"sync"
) )
// Player is a Lavalink player // Player is a Lavalink player
@ -203,6 +202,8 @@ func (player *Player) Destroy() error {
return err return err
} }
player.manager.playersMu.Lock()
defer player.manager.playersMu.Unlock()
delete(player.manager.players, player.guildID) delete(player.manager.players, player.guildID)
return nil return nil
} }