diff --git a/decoder.go b/decoder.go index fc58795..d1f99b8 100644 --- a/decoder.go +++ b/decoder.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "encoding/binary" + "github.com/valyala/fastjson" "io" ) @@ -126,3 +127,19 @@ func readString(r io.Reader) (string, error) { return string(buf), nil } + +func jsonStringValue(v *fastjson.Value, keys ... string) string { + value := v.Get(keys...) + + if value == nil { + return "" + } + + strB, err := value.StringBytes() + + if err != nil { + return "" + } + + return string(strB) +} diff --git a/lavalink.go b/lavalink.go index aca0648..c50e46d 100644 --- a/lavalink.go +++ b/lavalink.go @@ -46,6 +46,7 @@ func NewLavalink(shards string, userID string) *Lavalink { // AddNodes adds a node to the Lavalink manager func (lavalink *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error { nodes := make([]Node, len(nodeConfigs)) + for i, c := range nodeConfigs { n := Node{ config: c, @@ -57,7 +58,9 @@ func (lavalink *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error { } nodes[i] = n } + lavalink.nodes = append(lavalink.nodes, nodes...) + return nil } @@ -103,8 +106,10 @@ func (lavalink *Lavalink) BestNode() (*Node, error) { // GetPlayer gets a player for a guild func (lavalink *Lavalink) GetPlayer(guild string) (*Player, error) { p, ok := lavalink.players[guild] + if !ok { return nil, errPlayerNotFound } + return p, nil } diff --git a/messages.go b/messages.go new file mode 100644 index 0000000..94b0a83 --- /dev/null +++ b/messages.go @@ -0,0 +1,39 @@ +package gavalink + +type basicMessage struct { + Op string `json:"op"` + GuildID string `json:"guildId,omitempty"` +} + +type playMessage struct { + Op string `json:"op"` + GuildID string `json:"guildId,omitempty"` + Track string `json:"track,omitempty"` + StartTime string `json:"startTime,omitempty"` + EndTime string `json:"endTime,omitempty"` +} + +type pauseMessage struct { + Op string `json:"op"` + GuildID string `json:"guildId,omitempty"` + Pause bool `json:"pause,omitempty"` +} + +type seekMessage struct { + Op string `json:"op"` + GuildID string `json:"guildId,omitempty"` + Position *int `json:"position,omitempty"` +} + +type volumeMessage struct { + Op string `json:"op"` + GuildID string `json:"guildId,omitempty"` + Volume int `json:"volume,omitempty"` +} + +type voiceUpdateMessage struct { + Op string `json:"op"` + GuildID string `json:"guildId,omitempty"` + SessionID string `json:"sessionId,omitempty"` + Event *VoiceServerUpdate `json:"event,omitempty"` +} diff --git a/model.go b/model.go index 98221cb..ac3322a 100644 --- a/model.go +++ b/model.go @@ -1,105 +1,76 @@ -package gavalink - -const ( - // TrackLoaded is a Tracks Type for a succesful single track load - TrackLoaded = "TRACK_LOADED" - // PlaylistLoaded is a Tracks Type for a succseful playlist load - PlaylistLoaded = "PLAYLIST_LOADED" - // SearchResult is a Tracks Type for a search containing many tracks - SearchResult = "SEARCH_RESULT" - // NoMatches is a Tracks Type for a query yielding no matches - NoMatches = "NO_MATCHES" - // LoadFailed is a Tracks Type for an internal Lavalink error - LoadFailed = "LOAD_FAILED" -) - -// Tracks contains data for a Lavalink Tracks response -type Tracks struct { - // Type contains the type of response - // - // This will be one of TrackLoaded, PlaylistLoaded, SearchResult, - // NoMatches, or LoadFailed - Type string `json:"loadType"` - PlaylistInfo *PlaylistInfo `json:"playlistInfo"` - Tracks []Track `json:"tracks"` -} - -// PlaylistInfo contains information about a loaded playlist -type PlaylistInfo struct { - // Name is the friendly of the playlist - Name string `json:"name"` - // SelectedTrack is the index of the track that loaded the playlist, - // if one is present. - SelectedTrack int `json:"selectedTrack"` -} - -// Track contains information about a loaded track -type Track struct { - // Data contains the base64 encoded Lavaplayer track - Data string `json:"track"` - Info TrackInfo `json:"info"` -} - -// TrackInfo contains more data about a loaded track -type TrackInfo struct { - Identifier string `json:"identifier"` - Title string `json:"title"` - Author string `json:"author"` - URI string `json:"uri"` - Seekable bool `json:"isSeekable"` - Stream bool `json:"isStream"` - Length int `json:"length"` - Position int `json:"position"` -} - -const ( - opVoiceUpdate = "voiceUpdate" - opPlay = "play" - opStop = "stop" - opPause = "pause" - opSeek = "seek" - opVolume = "volume" - opDestroy = "destroy" - opPlayerUpdate = "playerUpdate" - opEvent = "event" - opStats = "stats" - eventTrackEnd = "TrackEndEvent" - eventTrackException = "TrackExceptionEvent" - eventTrackStuck = "TrackStuckEvent" -) - -type message struct { - Op string `json:"op"` - GuildID string `json:"guildId,omitempty"` - SessionID string `json:"sessionId,omitempty"` - Event *VoiceServerUpdate `json:"event,omitempty"` - Track string `json:"track,omitempty"` - StartTime string `json:"startTime,omitempty"` - EndTime string `json:"endTime,omitempty"` - Pause *bool `json:"pause,omitempty"` - Position *int `json:"position,omitempty"` - Volume *int `json:"volume,omitempty"` - State *state `json:"state,omitempty"` - Type string `json:"type,omitempty"` - Reason string `json:"reason,omitempty"` - Error string `json:"error,omitempty"` - ThresholdMs int `json:"thresholdMs,omitempty"` - StatCPU *statCPU `json:"cpu,omitempty"` - // TODO: stats -} - -type state struct { - Time int `json:"time"` - Position int `json:"position"` -} - -type statCPU struct { - Load float32 `json:"lavalinkLoad"` -} - -// VoiceServerUpdate is a raw Discord VOICE_SERVER_UPDATE event -type VoiceServerUpdate struct { - GuildID string `json:"guild_id"` - Endpoint string `json:"endpoint"` - Token string `json:"token"` -} +package gavalink + +const ( + // TrackLoaded is a Tracks Type for a succesful single track load + TrackLoaded = "TRACK_LOADED" + // PlaylistLoaded is a Tracks Type for a succseful playlist load + PlaylistLoaded = "PLAYLIST_LOADED" + // SearchResult is a Tracks Type for a search containing many tracks + SearchResult = "SEARCH_RESULT" + // NoMatches is a Tracks Type for a query yielding no matches + NoMatches = "NO_MATCHES" + // LoadFailed is a Tracks Type for an internal Lavalink error + LoadFailed = "LOAD_FAILED" +) + +// Tracks contains data for a Lavalink Tracks response +type Tracks struct { + // Type contains the type of response + // + // This will be one of TrackLoaded, PlaylistLoaded, SearchResult, + // NoMatches, or LoadFailed + Type string `json:"loadType"` + PlaylistInfo *PlaylistInfo `json:"playlistInfo"` + Tracks []Track `json:"tracks"` +} + +// PlaylistInfo contains information about a loaded playlist +type PlaylistInfo struct { + // Name is the friendly of the playlist + Name string `json:"name"` + // SelectedTrack is the index of the track that loaded the playlist, + // if one is present. + SelectedTrack int `json:"selectedTrack"` +} + +// Track contains information about a loaded track +type Track struct { + // Data contains the base64 encoded Lavaplayer track + Data string `json:"track"` + Info TrackInfo `json:"info"` +} + +// TrackInfo contains more data about a loaded track +type TrackInfo struct { + Identifier string `json:"identifier"` + Title string `json:"title"` + Author string `json:"author"` + URI string `json:"uri"` + Seekable bool `json:"isSeekable"` + Stream bool `json:"isStream"` + Length int `json:"length"` + Position int `json:"position"` +} + +const ( + opVoiceUpdate = "voiceUpdate" + opPlay = "play" + opStop = "stop" + opPause = "pause" + opSeek = "seek" + opVolume = "volume" + opDestroy = "destroy" + opPlayerUpdate = "playerUpdate" + opEvent = "event" + opStats = "stats" + eventTrackEnd = "TrackEndEvent" + eventTrackException = "TrackExceptionEvent" + eventTrackStuck = "TrackStuckEvent" +) + +// VoiceServerUpdate is a raw Discord VOICE_SERVER_UPDATE event +type VoiceServerUpdate struct { + GuildID string `json:"guild_id"` + Endpoint string `json:"endpoint"` + Token string `json:"token"` +} diff --git a/node.go b/node.go index 9a81197..e00e55b 100644 --- a/node.go +++ b/node.go @@ -3,7 +3,7 @@ package gavalink import ( "encoding/json" "fmt" - "io/ioutil" + "github.com/valyala/fastjson" "net/http" "strconv" @@ -70,8 +70,11 @@ func (node *Node) stop() { } func (node *Node) listen() { + var p fastjson.Parser + for { msgType, msg, err := node.wsConn.ReadMessage() + if err != nil { Log.Println(err) // try to reconnect @@ -84,53 +87,54 @@ func (node *Node) listen() { Log.Println("node", node.config.WebSocket, "reconnected") return } - err = node.onEvent(msgType, msg) - // TODO: better error handling? + + if msgType != websocket.TextMessage { + continue + } + + v, err := p.ParseBytes(msg) if err != nil { - Log.Println(err) + continue } + + node.onEvent(v) } } -func (node *Node) onEvent(msgType int, msg []byte) error { - if msgType != websocket.TextMessage { - return errUnknownPayload - } +func (node *Node) onEvent(v *fastjson.Value) error { + op := jsonStringValue(v, "op") - m := message{} - err := json.Unmarshal(msg, &m) - if err != nil { - return err - } - - switch m.Op { + switch op { case opPlayerUpdate: - player, err := node.manager.GetPlayer(m.GuildID) - if err != nil { - return err - } - player.time = m.State.Time - player.position = m.State.Position - case opEvent: - player, err := node.manager.GetPlayer(m.GuildID) + player, err := node.manager.GetPlayer(jsonStringValue(v, "guildId")) if err != nil { return err } - switch m.Type { + player.time = v.GetInt("state", "time") + player.position = v.GetInt("state", "position") + case opEvent: + player, err := node.manager.GetPlayer(jsonStringValue(v, "guildId")) + if err != nil { + return err + } + + track := jsonStringValue(v, "track") + + switch jsonStringValue(v, "type") { case eventTrackEnd: player.track = "" - err = player.handler.OnTrackEnd(player, m.Track, m.Reason) + err = player.handler.OnTrackEnd(player, track, jsonStringValue(v, "reason")) case eventTrackException: - err = player.handler.OnTrackException(player, m.Track, m.Reason) + err = player.handler.OnTrackException(player, track, jsonStringValue(v, "reason")) case eventTrackStuck: - err = player.handler.OnTrackStuck(player, m.Track, m.ThresholdMs) + err = player.handler.OnTrackStuck(player, track, v.GetInt("thresholdMs")) } return err case opStats: - node.load = m.StatCPU.Load + node.load = float32(v.GetFloat64("cpu", "lavalinkLoad")) default: return errUnknownPayload } @@ -140,20 +144,19 @@ func (node *Node) onEvent(msgType int, msg []byte) error { // CreatePlayer creates an audio player on this node func (node *Node) CreatePlayer(guildID string, sessionID string, event VoiceServerUpdate, handler EventHandler) (*Player, error) { - msg := message{ + msg := voiceUpdateMessage{ Op: opVoiceUpdate, GuildID: guildID, SessionID: sessionID, Event: &event, } - data, err := json.Marshal(msg) - if err != nil { - return nil, err - } - err = node.wsConn.WriteMessage(websocket.TextMessage, data) + + err := node.writeMessage(msg) + if err != nil { return nil, err } + player := &Player{ guildID: guildID, manager: node.manager, @@ -161,7 +164,9 @@ func (node *Node) CreatePlayer(guildID string, sessionID string, event VoiceServ handler: handler, vol: 100, } + node.manager.players[guildID] = player + return player, nil } @@ -175,24 +180,30 @@ func (node *Node) CreatePlayer(guildID string, sessionID string, event VoiceServ // See the Lavaplayer Source Code for all valid options. func (node *Node) LoadTracks(query string) (*Tracks, error) { url := fmt.Sprintf("%s/loadtracks?identifier=%s", node.config.REST, query) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { return nil, err } + req.Header.Set("Authorization", node.config.Password) resp, err := http.DefaultClient.Do(req) + if err != nil { return nil, err } - data, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } + tracks := new(Tracks) - err = json.Unmarshal(data, &tracks) - if err != nil { + + if err := json.NewDecoder(resp.Body).Decode(&tracks); err != nil { return nil, err } + return tracks, nil } + +func (node *Node) writeMessage(v interface{}) error { + return node.wsConn.WriteJSON(v) +} diff --git a/player.go b/player.go index d4699d6..cf7480f 100644 --- a/player.go +++ b/player.go @@ -1,10 +1,7 @@ package gavalink import ( - "encoding/json" "strconv" - - "github.com/gorilla/websocket" ) // Player is a Lavalink player @@ -40,19 +37,15 @@ func (player *Player) PlayAt(track string, startTime int, endTime int) error { start := strconv.Itoa(startTime) end := strconv.Itoa(endTime) - msg := message{ + msg := playMessage{ Op: opPlay, GuildID: player.guildID, Track: track, StartTime: start, EndTime: end, } - data, err := json.Marshal(msg) - if err != nil { - return err - } - err = player.node.wsConn.WriteMessage(websocket.TextMessage, data) - return err + + return player.node.writeMessage(msg) } // Track returns the player's current track @@ -63,33 +56,26 @@ func (player *Player) Track() string { // Stop will stop the currently playing track func (player *Player) Stop() error { player.track = "" - msg := message{ + + msg := basicMessage{ Op: opStop, GuildID: player.guildID, } - data, err := json.Marshal(msg) - if err != nil { - return err - } - err = player.node.wsConn.WriteMessage(websocket.TextMessage, data) - return err + + return player.node.writeMessage(msg) } // Pause will pause or resume the player, depending on the pause parameter func (player *Player) Pause(pause bool) error { player.paused = pause - msg := message{ + msg := pauseMessage{ Op: opPause, GuildID: player.guildID, - Pause: &pause, + Pause: pause, } - data, err := json.Marshal(msg) - if err != nil { - return err - } - err = player.node.wsConn.WriteMessage(websocket.TextMessage, data) - return err + + return player.node.writeMessage(msg) } // Paused returns whether or not the player is currently paused @@ -99,17 +85,13 @@ func (player *Player) Paused() bool { // Seek will seek the player to the speicifed position, in millis func (player *Player) Seek(position int) error { - msg := message{ + msg := seekMessage{ Op: opSeek, GuildID: player.guildID, Position: &position, } - data, err := json.Marshal(msg) - if err != nil { - return err - } - err = player.node.wsConn.WriteMessage(websocket.TextMessage, data) - return err + + return player.node.wsConn.WriteJSON(msg) } // Position returns the player's position, as reported by Lavalink @@ -127,17 +109,13 @@ func (player *Player) Volume(volume int) error { player.vol = volume - msg := message{ + msg := volumeMessage{ Op: opVolume, GuildID: player.guildID, - Volume: &volume, + Volume: volume, } - data, err := json.Marshal(msg) - if err != nil { - return err - } - err = player.node.wsConn.WriteMessage(websocket.TextMessage, data) - return err + + return player.node.wsConn.WriteJSON(msg) } // GetVolume gets the player's volume level @@ -154,34 +132,29 @@ func (player *Player) GetVolume() int { // To move a player to a new Node, first player.Destroy() it, and then // create a new player on the new node. func (player *Player) Forward(sessionID string, event VoiceServerUpdate) error { - msg := message{ + msg := voiceUpdateMessage{ Op: opVoiceUpdate, GuildID: player.guildID, SessionID: sessionID, Event: &event, } - data, err := json.Marshal(msg) - if err != nil { - return err - } - err = player.node.wsConn.WriteMessage(websocket.TextMessage, data) - return err + + return player.node.wsConn.WriteJSON(msg) } // Destroy will destroy this player func (player *Player) Destroy() error { - msg := message{ + msg := basicMessage{ Op: opDestroy, GuildID: player.guildID, } - data, err := json.Marshal(msg) - if err != nil { - return err - } - err = player.node.wsConn.WriteMessage(websocket.TextMessage, data) + + err := player.node.wsConn.WriteJSON(msg) + if err != nil { return err } + delete(player.manager.players, player.guildID) return nil }