8 Commits

5 changed files with 185 additions and 34 deletions

51
balancer.go Normal file
View File

@ -0,0 +1,51 @@
package gavalink
import (
"math"
"sort"
)
type balancePenalty struct {
node *Node
penalty int
}
func BestNodeByPenalties(nodes []*Node) (*Node, error) {
penalties := make([]balancePenalty, len(nodes))
var playerPenalty, cpuPenalty, deficitFramePenalty, nullFramePenalty int
for i, node := range nodes {
playerPenalty = 0
cpuPenalty = 0
deficitFramePenalty = 0
nullFramePenalty = 0
if node.stats != nil {
playerPenalty = node.stats.ActivePlayers
cpuPenalty = int(math.Pow(1.05, 100*node.stats.Cpu.SystemLoad)*10 - 10)
if node.stats.Frames != nil && node.stats.Frames.Deficit != -1 {
deficitFramePenalty = int(math.Pow(1.03, 500*float64(node.stats.Frames.Deficit/3000))*600 - 600)
nullFramePenalty = int(math.Pow(1.03, 500*float64(node.stats.Frames.Nulled/3000))*300 - 300)
nullFramePenalty *= 2
}
}
penalties[i] = balancePenalty{node, playerPenalty + cpuPenalty + deficitFramePenalty + nullFramePenalty}
}
sort.SliceStable(penalties, func(i, j int) bool {
return penalties[i].penalty < penalties[j].penalty
})
return penalties[0].node, nil
}
func BestNodeByLoad(n []*Node) (*Node, error) {
sort.SliceStable(n, func(i, j int) bool {
return n[i].stats.Cpu.LavalinkLoad < n[j].stats.Cpu.LavalinkLoad
})
return n[0], nil
}

View File

@ -3,8 +3,9 @@ package gavalink
import ( import (
"errors" "errors"
"log" "log"
"net/http"
"os" "os"
"sort" "time"
) )
// Log sets the log.Logger gavalink will write to // Log sets the log.Logger gavalink will write to
@ -19,8 +20,12 @@ type Lavalink struct {
shards string shards string
userID string userID string
nodes []Node nodes []*Node
players map[string]*Player players map[string]*Player
capabilities map[string]interface{}
BestNodeFunc func([]*Node) (*Node, error)
} }
var ( var (
@ -36,21 +41,27 @@ var (
// NewLavalink creates a new Lavalink manager // NewLavalink creates a new Lavalink manager
func NewLavalink(shards string, userID string) *Lavalink { func NewLavalink(shards string, userID string) *Lavalink {
return &Lavalink{ return &Lavalink{
shards: shards, shards: shards,
userID: userID, userID: userID,
/* nodes: make([]Node, 1),*/
players: make(map[string]*Player), players: make(map[string]*Player),
BestNodeFunc: BestNodeByPenalties,
} }
} }
// AddNodes adds a node to the Lavalink manager // AddNodes adds a node to the Lavalink manager
func (lavalink *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error { func (lavalink *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error {
nodes := make([]Node, len(nodeConfigs)) nodes := make([]*Node, len(nodeConfigs))
client := &http.Client{
Timeout: 60 * time.Second,
}
for i, c := range nodeConfigs { for i, c := range nodeConfigs {
n := Node{ n := &Node{
config: c, config: c,
manager: lavalink, manager: lavalink,
client: client,
} }
err := n.open() err := n.open()
@ -71,7 +82,7 @@ func (lavalink *Lavalink) AddNodes(nodeConfigs ...NodeConfig) error {
func (lavalink *Lavalink) removeNode(node *Node) error { func (lavalink *Lavalink) removeNode(node *Node) error {
idx := -1 idx := -1
for i, n := range lavalink.nodes { for i, n := range lavalink.nodes {
if n == *node { if n == node {
idx = i idx = i
break break
} }
@ -82,6 +93,18 @@ func (lavalink *Lavalink) removeNode(node *Node) error {
node.stop() node.stop()
for _, player := range lavalink.players {
if player.node == node {
n, err := lavalink.BestNode()
if err != nil {
continue
}
player.ChangeNode(n)
}
}
// temp var for easier reading // temp var for easier reading
n := lavalink.nodes n := lavalink.nodes
z := len(n) - 1 z := len(n) - 1
@ -99,11 +122,7 @@ func (lavalink *Lavalink) BestNode() (*Node, error) {
return nil, errNoNodes return nil, errNoNodes
} }
sort.SliceStable(lavalink.nodes, func(i, j int) bool { return lavalink.BestNodeFunc(lavalink.nodes)
return lavalink.nodes[i].load < lavalink.nodes[j].load
})
return &lavalink.nodes[0], nil
} }
// GetPlayer gets a player for a guild // GetPlayer gets a player for a guild
@ -116,3 +135,12 @@ func (lavalink *Lavalink) GetPlayer(guild string) (*Player, error) {
return p, nil return p, nil
} }
// Add capabilities mappings to the client, letting the server know what we support
func (lavalink *Lavalink) AddCapability(key string, i interface{}) {
if lavalink.capabilities == nil {
lavalink.capabilities = make(map[string]interface{})
}
lavalink.capabilities[key] = i
}

View File

@ -91,6 +91,7 @@ type VoiceProcessingData struct {
io.ReadCloser io.ReadCloser
Client *http.Client Client *http.Client
UserID string
URL string URL string
File string File string

81
node.go
View File

@ -6,6 +6,7 @@ import (
"github.com/valyala/fastjson" "github.com/valyala/fastjson"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@ -29,12 +30,41 @@ type NodeConfig struct {
// Node wraps a Lavalink Node // Node wraps a Lavalink Node
type Node struct { type Node struct {
config NodeConfig config NodeConfig
load float32 stats *RemoteStats
manager *Lavalink manager *Lavalink
wsConn *websocket.Conn wsConn *websocket.Conn
client *http.Client client *http.Client
} }
type RemoteStats struct {
Op string `json:"op"`
Players int `json:"players"`
ActivePlayers int `json:"playingPlayers"`
Uptime int64 `json:"uptime"`
Memory *MemoryStats `json:"memory"`
Cpu *CpuStats `json:"cpu"`
Frames *FrameStats `json:"frameStats"`
}
type MemoryStats struct {
Free uint64 `json:"free"`
Used uint64 `json:"used"`
Allocated uint64 `json:"allocated"`
Reserveable uint64 `json:"reserveable"`
}
type CpuStats struct {
Cores int `json:"cores"`
SystemLoad float64 `json:"systemLoad"`
LavalinkLoad float64 `json:"lavalinkLoad"`
}
type FrameStats struct {
Sent int `json:"sent"`
Nulled int `json:"nulled"`
Deficit int `json:"deficit"`
}
func (node *Node) open() error { func (node *Node) open() error {
header := http.Header{} header := http.Header{}
@ -42,6 +72,22 @@ func (node *Node) open() error {
header.Set("Num-Shards", node.manager.shards) header.Set("Num-Shards", node.manager.shards)
header.Set("User-Id", node.manager.userID) header.Set("User-Id", node.manager.userID)
if node.manager.capabilities != nil {
v := make([]string, 0)
for k, vals := range node.manager.capabilities {
b, err := json.Marshal(vals)
if err != nil {
continue
}
v = append(v, k+"="+string(b))
}
header.Set("Capabilities", strings.Join(v, ";"))
}
ws, resp, err := websocket.DefaultDialer.Dial(node.config.WebSocket, header) ws, resp, err := websocket.DefaultDialer.Dial(node.config.WebSocket, header)
if err != nil { if err != nil {
@ -108,16 +154,25 @@ func (node *Node) listen() {
continue continue
} }
node.onEvent(v) node.onEvent(v, msg)
} }
} }
func (node *Node) onEvent(v *fastjson.Value) error { func (node *Node) onEvent(v *fastjson.Value, msg []byte) error {
op := jsonStringValue(v, "op") op := jsonStringValue(v, "op")
switch op { switch op {
case opStats:
node.stats = &RemoteStats{}
err := json.Unmarshal(msg, &node.stats)
if err != nil {
return err
}
case opPlayerUpdate: case opPlayerUpdate:
player, err := node.manager.GetPlayer(jsonStringValue(v, "guildId")) player, err := node.manager.GetPlayer(jsonStringValue(v, "guildId"))
if err != nil { if err != nil {
return err return err
} }
@ -153,13 +208,13 @@ func (node *Node) onEvent(v *fastjson.Value) error {
track := jsonStringValue(v, "track") track := jsonStringValue(v, "track")
data := &VoiceProcessingData{ data := &VoiceProcessingData{
URL: fmt.Sprintf("%s/audio/%s", node.config.REST, track), Client: node.client,
File: track, UserID: jsonStringValue(v, "userId"),
URL: fmt.Sprintf("%s/audio/%s", node.config.REST, track),
File: track,
} }
return player.handler.OnVoiceProcessed(player, data, v.GetBool("hotword"), v.GetBool("override")) return player.handler.OnVoiceProcessed(player, data, v.GetBool("hotword"), v.GetBool("override"))
case opStats:
node.load = float32(v.GetFloat64("cpu", "lavalinkLoad"))
default: default:
return errUnknownPayload return errUnknownPayload
} }
@ -183,11 +238,13 @@ func (node *Node) CreatePlayer(guildID string, sessionID string, event VoiceServ
} }
player := &Player{ player := &Player{
guildID: guildID, guildID: guildID,
manager: node.manager, sessionID: sessionID,
node: node, manager: node.manager,
handler: handler, node: node,
vol: 100, handler: handler,
vol: 100,
lastVoiceServerUpdate: event,
} }
node.manager.players[guildID] = player node.manager.players[guildID] = player

View File

@ -6,15 +6,17 @@ import (
// Player is a Lavalink player // Player is a Lavalink player
type Player struct { type Player struct {
guildID string guildID string
time int sessionID string
position int time int
paused bool position int
vol int paused bool
track string vol int
manager *Lavalink track string
node *Node manager *Lavalink
handler EventHandler node *Node
handler EventHandler
lastVoiceServerUpdate VoiceServerUpdate
} }
// GuildID returns this player's Guild ID // GuildID returns this player's Guild ID
@ -166,6 +168,8 @@ func (player *Player) UserLeave(userId string) error {
// To move a player to a new Node, first player.Destroy() it, and then // To move a player to a new Node, first player.Destroy() it, and then
// create a new player on the new node. // create a new player on the new node.
func (player *Player) Forward(sessionID string, event VoiceServerUpdate) error { func (player *Player) Forward(sessionID string, event VoiceServerUpdate) error {
player.sessionID = sessionID
msg := voiceUpdateMessage{ msg := voiceUpdateMessage{
Op: opVoiceUpdate, Op: opVoiceUpdate,
GuildID: player.guildID, GuildID: player.guildID,
@ -173,9 +177,19 @@ func (player *Player) Forward(sessionID string, event VoiceServerUpdate) error {
Event: &event, Event: &event,
} }
player.lastVoiceServerUpdate = event
return player.node.wsConn.WriteJSON(msg) return player.node.wsConn.WriteJSON(msg)
} }
func (player *Player) ChangeNode(node *Node) error {
player.node = node
player.Forward(player.sessionID, player.lastVoiceServerUpdate)
return player.PlayAt(player.track, player.position, 0)
}
// Destroy will destroy this player // Destroy will destroy this player
func (player *Player) Destroy() error { func (player *Player) Destroy() error {
msg := basicMessage{ msg := basicMessage{