diff --git a/go.mod b/go.mod index 45749d6..d0b4a1f 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,11 @@ module meow.tf/astra/gavalink require ( - github.com/bwmarrin/discordgo v0.20.3 // indirect + github.com/bwmarrin/discordgo v0.20.3 github.com/foxbot/gavalink v0.0.0-20181105223750-6252b1245300 github.com/gorilla/websocket v1.4.0 github.com/valyala/fastjson v1.4.1 + golang.org/x/sync v0.0.0-20201207232520-09787c993a3a ) go 1.13 diff --git a/go.sum b/go.sum index 9f8fa13..1b80f5f 100644 --- a/go.sum +++ b/go.sum @@ -8,3 +8,5 @@ github.com/valyala/fastjson v1.4.1 h1:hrltpHpIpkaxll8QltMU8c3QZ5+qIiCL8yKqPFJI/y github.com/valyala/fastjson v1.4.1/go.mod h1:nV6MsjxL2IMJQUoHDIrjEI7oLyeqK6aBD7EFWPsvP8o= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16 h1:y6ce7gCWtnH+m3dCjzQ1PCuwl28DDIc3VNnvY29DlIA= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/lavalink.go b/lavalink.go index 8631c5e..73acf52 100644 --- a/lavalink.go +++ b/lavalink.go @@ -1,7 +1,9 @@ package gavalink import ( + "context" "errors" + "golang.org/x/sync/errgroup" "log" "net/http" "os" @@ -65,41 +67,29 @@ func (l *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error { Timeout: 60 * time.Second, } - wg := &sync.WaitGroup{} - - errCh := make(chan error) + eg, ctx := errgroup.WithContext(context.Background()) for _, c := range nodeConfigs { - n := &Node{ - config: c, - manager: l, - client: client, - } + eg.Go(func() error { + n := &Node{ + config: c, + manager: l, + client: client, + } - wg.Add(1) - - go func(n *Node, wg *sync.WaitGroup, errCh chan error) { - defer wg.Done() - - err := n.open() + err := n.open(ctx) if err != nil { - errCh <- err - return + return err } l.nodes = append(l.nodes, n) - }(n, wg, errCh) + + return nil + }) } - wg.Wait() - - select { - case err := <-errCh: - return err - default: - return nil - } + return eg.Wait() } // RemoveNode removes a node from the manager diff --git a/node.go b/node.go index b1e4929..dc2a1d6 100644 --- a/node.go +++ b/node.go @@ -1,6 +1,7 @@ package gavalink import ( + "context" "encoding/json" "fmt" "github.com/valyala/fastjson" @@ -15,16 +16,21 @@ import ( // NodeConfig configures a Lavalink Node type NodeConfig struct { + // Node identifier (uuid, hostname, etc) + Identifier string + // REST is the host where Lavalink's REST server runs // // This value is expected without a trailing slash, e.g. like // `http://localhost:2333` REST string + // WebSocket is the host where Lavalink's WebSocket server runs // // This value is expected without a trailing slash, e.g. like // `http://localhost:8012` WebSocket string + // Password is the expected Authorization header for the Node Password string } @@ -68,7 +74,7 @@ type FrameStats struct { } // Opens the connection to the Lavalink server -func (node *Node) open() error { +func (node *Node) open(ctx context.Context) error { header := http.Header{} header.Set("User-Agent", gavalinkUserAgent()) @@ -92,7 +98,7 @@ func (node *Node) open() error { header.Set("Capabilities", strings.Join(v, ";")) } - ws, resp, err := websocket.DefaultDialer.Dial(node.config.WebSocket, header) + ws, resp, err := websocket.DefaultDialer.DialContext(ctx, node.config.WebSocket, header) if err != nil { return err @@ -136,7 +142,7 @@ func (node *Node) listen() { if err != nil { Log.Println(err) // try to reconnect - oerr := node.open() + oerr := node.open(context.Background()) if oerr != nil { Log.Println("node", node.config.WebSocket, "failed and could not reconnect, destroying.", err, oerr) diff --git a/player.go b/player.go index 5c7bcf3..45ded88 100644 --- a/player.go +++ b/player.go @@ -184,7 +184,9 @@ func (player *Player) Forward(sessionID string, event VoiceServerUpdate) error { func (player *Player) ChangeNode(node *Node) error { player.node = node - player.Forward(player.sessionID, player.lastVoiceServerUpdate) + if err := player.Forward(player.sessionID, player.lastVoiceServerUpdate); err != nil { + return err + } return player.PlayAt(player.track, player.position, 0) }