8 Commits

7 changed files with 214 additions and 32 deletions

4
.gitignore vendored
View File

@ -23,3 +23,7 @@
# End of https://www.gitignore.io/api/go # End of https://www.gitignore.io/api/go
.env .env
# IntelliJ
.idea/
*.iml

51
balancer.go Normal file
View File

@ -0,0 +1,51 @@
package gavalink
import (
"math"
"sort"
)
type balancePenalty struct {
node *Node
penalty int
}
func BestNodeByPenalties(nodes []*Node) (*Node, error) {
penalties := make([]balancePenalty, len(nodes))
var playerPenalty, cpuPenalty, deficitFramePenalty, nullFramePenalty int
for i, node := range nodes {
playerPenalty = 0
cpuPenalty = 0
deficitFramePenalty = 0
nullFramePenalty = 0
if node.stats != nil {
playerPenalty = node.stats.ActivePlayers
cpuPenalty = int(math.Pow(1.05, 100*node.stats.Cpu.SystemLoad)*10 - 10)
if node.stats.Frames != nil && node.stats.Frames.Deficit != -1 {
deficitFramePenalty = int(math.Pow(1.03, 500*float64(node.stats.Frames.Deficit/3000))*600 - 600)
nullFramePenalty = int(math.Pow(1.03, 500*float64(node.stats.Frames.Nulled/3000))*300 - 300)
nullFramePenalty *= 2
}
}
penalties[i] = balancePenalty{node, playerPenalty + cpuPenalty + deficitFramePenalty + nullFramePenalty}
}
sort.SliceStable(penalties, func(i, j int) bool {
return penalties[i].penalty < penalties[j].penalty
})
return penalties[0].node, nil
}
func BestNodeByLoad(n []*Node) (*Node, error) {
sort.SliceStable(n, func(i, j int) bool {
return n[i].stats.Cpu.LavalinkLoad < n[j].stats.Cpu.LavalinkLoad
})
return n[0], nil
}

View File

@ -6,6 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"github.com/valyala/fastjson" "github.com/valyala/fastjson"
"io" "io"
"time"
) )
const trackInfoVersioned int32 = 1 const trackInfoVersioned int32 = 1
@ -109,7 +110,7 @@ func Decode(r io.Reader) (*TrackInfo, error) {
Author: author, Author: author,
URI: url, URI: url,
Stream: stream == 1, Stream: stream == 1,
Length: int(length), Length: time.Duration(length) * time.Millisecond,
} }
return track, nil return track, nil

View File

@ -5,7 +5,6 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"sort"
"time" "time"
) )
@ -24,6 +23,8 @@ type Lavalink struct {
nodes []*Node nodes []*Node
players map[string]*Player players map[string]*Player
capabilities map[string]interface{}
BestNodeFunc func([]*Node) (*Node, error) BestNodeFunc func([]*Node) (*Node, error)
} }
@ -44,7 +45,7 @@ func NewLavalink(shards string, userID string) *Lavalink {
userID: userID, userID: userID,
players: make(map[string]*Player), players: make(map[string]*Player),
BestNodeFunc: BestNodeByLoad, BestNodeFunc: BestNodeByPenalties,
} }
} }
@ -135,10 +136,11 @@ func (lavalink *Lavalink) GetPlayer(guild string) (*Player, error) {
return p, nil return p, nil
} }
func BestNodeByLoad(n []*Node) (*Node, error) { // Add capabilities mappings to the client, letting the server know what we support
sort.SliceStable(n, func(i, j int) bool { func (lavalink *Lavalink) AddCapability(key string, i interface{}) {
return n[i].load < n[j].load if lavalink.capabilities == nil {
}) lavalink.capabilities = make(map[string]interface{})
}
return n[0], nil
lavalink.capabilities[key] = i
} }

View File

@ -1,9 +1,11 @@
package gavalink package gavalink
import ( import (
"encoding/json"
"io" "io"
"net/http" "net/http"
"os" "os"
"time"
) )
const ( const (
@ -27,7 +29,7 @@ type Tracks struct {
// NoMatches, or LoadFailed // NoMatches, or LoadFailed
Type string `json:"loadType"` Type string `json:"loadType"`
PlaylistInfo *PlaylistInfo `json:"playlistInfo"` PlaylistInfo *PlaylistInfo `json:"playlistInfo"`
Tracks []Track `json:"tracks"` Tracks []*Track `json:"tracks"`
} }
// PlaylistInfo contains information about a loaded playlist // PlaylistInfo contains information about a loaded playlist
@ -43,7 +45,7 @@ type PlaylistInfo struct {
type Track struct { type Track struct {
// Data contains the base64 encoded Lavaplayer track // Data contains the base64 encoded Lavaplayer track
Data string `json:"track"` Data string `json:"track"`
Info TrackInfo `json:"info"` Info *TrackInfo `json:"info"`
} }
// TrackInfo contains more data about a loaded track // TrackInfo contains more data about a loaded track
@ -54,10 +56,36 @@ type TrackInfo struct {
URI string `json:"uri"` URI string `json:"uri"`
Seekable bool `json:"isSeekable"` Seekable bool `json:"isSeekable"`
Stream bool `json:"isStream"` Stream bool `json:"isStream"`
Length int `json:"length"` Length time.Duration `json:"length"`
Position int `json:"position"` Position int `json:"position"`
} }
func (t *TrackInfo) MarshalJSON() ([]byte, error) {
type Alias TrackInfo
return json.Marshal(&struct {
Length int64 `json:"length"`
*Alias
}{
Length: int64(t.Length / time.Millisecond),
Alias: (*Alias)(t),
})
}
func (t *TrackInfo) UnmarshalJSON(data []byte) error {
type Alias TrackInfo
aux := &struct {
Length int64 `json:"length"`
*Alias
}{
Alias: (*Alias)(t),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
t.Length = time.Duration(aux.Length) * time.Millisecond
return nil
}
const ( const (
opVoiceUpdate = "voiceUpdate" opVoiceUpdate = "voiceUpdate"
opVoiceProcessed = "voiceProcessed" opVoiceProcessed = "voiceProcessed"
@ -90,7 +118,9 @@ type VoiceServerUpdate struct {
type VoiceProcessingData struct { type VoiceProcessingData struct {
io.ReadCloser io.ReadCloser
Client *http.Client node *Node
UserID string
URL string URL string
File string File string
@ -98,7 +128,11 @@ type VoiceProcessingData struct {
} }
func (v *VoiceProcessingData) open() error { func (v *VoiceProcessingData) open() error {
res, err := v.Client.Get(v.URL) req, err := http.NewRequest(http.MethodGet, v.URL, nil)
req.Header.Set("Authorization", v.node.config.Password)
res, err := v.node.client.Do(req)
if err != nil { if err != nil {
return err return err

33
model_test.go Normal file
View File

@ -0,0 +1,33 @@
package gavalink
import (
"encoding/json"
"testing"
"time"
)
func TestTrackInfo_JSON(t *testing.T) {
i := &TrackInfo{
Length: 10 * time.Second,
}
b, err := json.Marshal(i)
if err != nil {
t.Fatal(err)
}
t.Log(string(b))
deserialize := &TrackInfo{}
if err = json.Unmarshal(b, &deserialize); err != nil {
t.Fatal(err)
}
t.Log("Deserialized length:", deserialize.Length)
if deserialize.Length != time.Second*10 {
t.Fatal("Expected deserialized time to be 10 seconds!")
}
}

71
node.go
View File

@ -5,7 +5,9 @@ import (
"fmt" "fmt"
"github.com/valyala/fastjson" "github.com/valyala/fastjson"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@ -29,12 +31,41 @@ type NodeConfig struct {
// Node wraps a Lavalink Node // Node wraps a Lavalink Node
type Node struct { type Node struct {
config NodeConfig config NodeConfig
load float32 stats *RemoteStats
manager *Lavalink manager *Lavalink
wsConn *websocket.Conn wsConn *websocket.Conn
client *http.Client client *http.Client
} }
type RemoteStats struct {
Op string `json:"op"`
Players int `json:"players"`
ActivePlayers int `json:"playingPlayers"`
Uptime int64 `json:"uptime"`
Memory *MemoryStats `json:"memory"`
Cpu *CpuStats `json:"cpu"`
Frames *FrameStats `json:"frameStats"`
}
type MemoryStats struct {
Free uint64 `json:"free"`
Used uint64 `json:"used"`
Allocated uint64 `json:"allocated"`
Reserveable uint64 `json:"reserveable"`
}
type CpuStats struct {
Cores int `json:"cores"`
SystemLoad float64 `json:"systemLoad"`
LavalinkLoad float64 `json:"lavalinkLoad"`
}
type FrameStats struct {
Sent int `json:"sent"`
Nulled int `json:"nulled"`
Deficit int `json:"deficit"`
}
func (node *Node) open() error { func (node *Node) open() error {
header := http.Header{} header := http.Header{}
@ -42,6 +73,22 @@ func (node *Node) open() error {
header.Set("Num-Shards", node.manager.shards) header.Set("Num-Shards", node.manager.shards)
header.Set("User-Id", node.manager.userID) header.Set("User-Id", node.manager.userID)
if node.manager.capabilities != nil {
v := make([]string, 0)
for k, vals := range node.manager.capabilities {
b, err := json.Marshal(vals)
if err != nil {
continue
}
v = append(v, k+"="+string(b))
}
header.Set("Capabilities", strings.Join(v, ";"))
}
ws, resp, err := websocket.DefaultDialer.Dial(node.config.WebSocket, header) ws, resp, err := websocket.DefaultDialer.Dial(node.config.WebSocket, header)
if err != nil { if err != nil {
@ -108,16 +155,25 @@ func (node *Node) listen() {
continue continue
} }
node.onEvent(v) node.onEvent(v, msg)
} }
} }
func (node *Node) onEvent(v *fastjson.Value) error { func (node *Node) onEvent(v *fastjson.Value, msg []byte) error {
op := jsonStringValue(v, "op") op := jsonStringValue(v, "op")
switch op { switch op {
case opStats:
node.stats = &RemoteStats{}
err := json.Unmarshal(msg, &node.stats)
if err != nil {
return err
}
case opPlayerUpdate: case opPlayerUpdate:
player, err := node.manager.GetPlayer(jsonStringValue(v, "guildId")) player, err := node.manager.GetPlayer(jsonStringValue(v, "guildId"))
if err != nil { if err != nil {
return err return err
} }
@ -153,13 +209,13 @@ func (node *Node) onEvent(v *fastjson.Value) error {
track := jsonStringValue(v, "track") track := jsonStringValue(v, "track")
data := &VoiceProcessingData{ data := &VoiceProcessingData{
node: node,
UserID: jsonStringValue(v, "userId"),
URL: fmt.Sprintf("%s/audio/%s", node.config.REST, track), URL: fmt.Sprintf("%s/audio/%s", node.config.REST, track),
File: track, File: track,
} }
return player.handler.OnVoiceProcessed(player, data, v.GetBool("hotword"), v.GetBool("override")) return player.handler.OnVoiceProcessed(player, data, v.GetBool("hotword"), v.GetBool("override"))
case opStats:
node.load = float32(v.GetFloat64("cpu", "lavalinkLoad"))
default: default:
return errUnknownPayload return errUnknownPayload
} }
@ -206,9 +262,10 @@ func (node *Node) CreatePlayer(guildID string, sessionID string, event VoiceServ
// //
// See the Lavaplayer Source Code for all valid options. // See the Lavaplayer Source Code for all valid options.
func (node *Node) LoadTracks(query string) (*Tracks, error) { func (node *Node) LoadTracks(query string) (*Tracks, error) {
url := fmt.Sprintf("%s/loadtracks?identifier=%s", node.config.REST, query) v := url.Values{}
v.Set("identifier", query)
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/loadtracks?%s", node.config.REST, v.Encode()), nil)
if err != nil { if err != nil {
return nil, err return nil, err