streamdeck-obs-replay/slobs/client.go

276 lines
4.6 KiB
Go

package slobs
import (
"encoding/json"
"errors"
"github.com/dchest/uniuri"
"github.com/gorilla/websocket"
"log"
"math/rand"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
)
// SubscriptionHandler is the callback invoked with each ResourceEvent that is
// dispatched for a resource registered through Subscribe.
type SubscriptionHandler func(*ResourceEvent)
// ResponseHandler is the callback invoked with the RPCResponse whose ID
// matches a request previously sent with SendRPC.
type ResponseHandler func(*RPCResponse)
// Message type tags carried in ResourceEvent.Type.
const (
// Event is the type tag of messages that handle forwards to the
// SubscriptionHandler registered for the message's ResourceId.
Event = "EVENT"
// Subscription is not referenced by the code visible in this file.
// NOTE(review): presumably the type tag of a subscription acknowledgement — confirm against the server API.
Subscription = "SUBSCRIPTION"
)
// Sentinel errors returned by the response dispatcher (handle).
var (
// ErrNoRequest is returned when a response carries an ID for which no
// ResponseHandler is registered.
ErrNoRequest = errors.New("request not found")
// ErrNoHandler is returned when an event arrives for a ResourceId that has
// no registered SubscriptionHandler.
ErrNoHandler = errors.New("handler not found")
)
// Client is a JSON-RPC 2.0 client that talks to a server over a websocket
// connection. Create one with NewClient, then call Connect before sending
// requests.
type Client struct {
// conn is the websocket connection; it is assigned by Connect.
conn *websocket.Conn
// connected is set to true by Connect and back to false by the read loop
// when reading from conn fails.
connected bool
// Address is the host (and optional port) used to build the ws:// endpoint.
Address string
// requestId is the last request ID handed out; it is updated with
// sync/atomic operations in SendRPC.
requestId int32
// requests maps a pending request ID to the handler awaiting its response.
requests map[int]ResponseHandler
// requestLock guards requests.
requestLock sync.RWMutex
// subscriptions maps a ResourceId to the handler receiving its events.
subscriptions map[string]SubscriptionHandler
// subscriptionLock guards subscriptions.
subscriptionLock sync.RWMutex
}
// NewClient builds a disconnected Client that will dial the given address.
// Both handler tables start out empty; no network activity happens until
// Connect is called.
func NewClient(address string) *Client {
	client := new(Client)
	client.Address = address
	client.requests = map[int]ResponseHandler{}
	client.subscriptions = map[string]SubscriptionHandler{}
	return client
}
// Connected reports the value of the client's connected flag, which Connect
// sets to true and the read loop resets to false when a read fails.
// NOTE(review): the flag is a plain bool read here without synchronization
// while the read loop writes it from another goroutine.
func (c *Client) Connected() bool {
return c.connected
}
// Connect dials the websocket endpoint derived from c.Address, waits for the
// initial open frame (a message starting with 'o') and then starts the
// background read loop.
//
// It returns an error if the client is already connected, if dialing or the
// first read fails, or if the first frame is empty or is not an open frame.
// When the handshake fails after a successful dial, the freshly opened
// connection is closed so the socket is not leaked.
func (c *Client) Connect() error {
	if c.connected {
		return errors.New("already connected")
	}

	// The endpoint embeds a zero-padded random number and a random session
	// identifier between "/api/" and "/websocket".
	endpoint := "ws://" + c.Address + "/api/" + paddedRandomIntn(999) + "/" + uniuri.New() + "/websocket"

	var err error
	c.conn, _, err = websocket.DefaultDialer.Dial(endpoint, http.Header{})
	if err != nil {
		return err
	}

	_, data, err := c.conn.ReadMessage()
	if err != nil {
		// Best-effort cleanup; the read error is the one worth reporting.
		_ = c.conn.Close()
		return err
	}
	// Guard the index: an empty first frame is just as invalid as a wrong one.
	if len(data) == 0 || data[0] != 'o' {
		_ = c.conn.Close()
		return errors.New("invalid initial message")
	}

	c.connected = true
	go c.loop()
	return nil
}
// Auth sends the "auth" method of the "TcpServerService" resource with key as
// its only argument and reports the outcome through callback.
//
// callback is invoked exactly once per outcome:
//   - with nil when the server answers true;
//   - with the server's error message when the response carries an error;
//   - with an error when the response has no result, cannot be decoded, or
//     decodes to false (authentication rejected);
//   - with the send error when the request could not be sent at all.
//
// When the request was sent, callback runs on the goroutine that dispatches
// the response; when sending fails it runs synchronously before Auth returns.
func (c *Client) Auth(key string, callback func(error)) {
	err := c.SendRPC("TcpServerService", "auth", func(response *RPCResponse) {
		if response.Error != nil {
			callback(errors.New(response.Error.Message))
			return
		}
		if response.Result == nil {
			callback(errors.New("auth response has no result"))
			return
		}
		var result bool
		if err := json.Unmarshal(*response.Result, &result); err != nil {
			callback(err)
			return
		}
		if !result {
			// Without this the caller would wait forever on a rejected key.
			callback(errors.New("authentication rejected"))
			return
		}
		callback(nil)
	}, key)
	if err != nil {
		// No response will ever arrive, so report the failure here.
		callback(err)
	}
}
// Disconnect closes the underlying websocket connection and returns the
// error from closing it. It returns an error instead of panicking when no
// connection has been established yet.
func (c *Client) Disconnect() error {
	if c.conn == nil {
		return errors.New("not connected")
	}
	return c.conn.Close()
}
// Subscribe calls method on resource and, once the server answers, registers
// handler for the ResourceId contained in the response so that later events
// for that resource are delivered to it.
//
// It blocks until the response has been processed and returns:
//   - the send error if the request could not be sent;
//   - the server's error message if the response carries an error;
//   - an error if the response has no result or cannot be decoded;
//   - nil once handler has been registered.
//
// NOTE(review): if the request is sent but no response ever arrives (for
// example because the connection drops), this call does not return.
func (c *Client) Subscribe(resource, method string, handler SubscriptionHandler) error {
	// Buffered so the response callback never blocks on the send.
	responseCh := make(chan error, 1)

	err := c.SendRPC(resource, method, func(response *RPCResponse) {
		if response.Error != nil {
			responseCh <- errors.New(response.Error.Message)
			return
		}
		if response.Result == nil {
			responseCh <- errors.New("subscription response has no result")
			return
		}
		res := &ResourceEvent{}
		if err := json.Unmarshal(*response.Result, res); err != nil {
			responseCh <- err
			return
		}
		c.subscriptionLock.Lock()
		c.subscriptions[res.ResourceId] = handler
		c.subscriptionLock.Unlock()
		responseCh <- nil
	})
	if err != nil {
		// Nothing was sent, so waiting on responseCh would block forever.
		return err
	}
	return <-responseCh
}
// SendRPC sends a JSON-RPC 2.0 request that invokes method on resource with
// the given string arguments.
//
// If handler is non-nil it is registered under the request's ID before the
// request is written, and is invoked when a response with that ID is
// dispatched. If the write fails the registration is removed again, since no
// response can be expected.
//
// It returns an error when the client has no connection, when the request
// cannot be encoded, or when writing to the connection fails.
//
// NOTE(review): writes to the connection are not serialized here; confirm
// that callers do not invoke SendRPC from several goroutines at once.
func (c *Client) SendRPC(resource, method string, handler ResponseHandler, args ...string) error {
	if c.conn == nil {
		return errors.New("not connected")
	}

	// Use the value returned by AddInt32: a separate load could observe an
	// increment made by a concurrent caller and hand out a duplicate ID.
	newRequestId := int(atomic.AddInt32(&c.requestId, 1))

	request := &RPCRequest{
		ID:     newRequestId,
		Method: method,
		Params: map[string]interface{}{
			"resource": resource,
			"args":     args,
		},
		JSONRPC: "2.0",
	}
	b, err := json.Marshal(request)
	if err != nil {
		return err
	}

	if handler != nil {
		c.requestLock.Lock()
		c.requests[newRequestId] = handler
		c.requestLock.Unlock()
	}

	// The encoded request is sent as a one-element JSON array of strings,
	// terminated by a newline.
	if err := c.conn.WriteJSON([]string{string(b) + "\n"}); err != nil {
		if handler != nil {
			c.requestLock.Lock()
			delete(c.requests, newRequestId)
			c.requestLock.Unlock()
		}
		return err
	}
	return nil
}
// loop reads frames from the connection until reading fails or the server
// closes the session, dispatching on the first byte of each frame:
//
//	'h'  heartbeat; ignored.
//	'a'  a JSON array of strings, each holding one JSON-encoded RPCResponse;
//	     every decoded response is handled on its own goroutine.
//	'c'  session closed; the payload is logged and the loop stops.
//
// Frames with any other first byte are logged and skipped; empty frames are
// skipped silently. Whenever the loop stops, the connected flag is cleared so
// that Connected reports false and Connect may be called again.
//
// It returns the read error when reading fails and nil when the server
// closed the session.
func (c *Client) loop() error {
	for {
		_, data, err := c.conn.ReadMessage()
		if err != nil {
			c.connected = false
			return err
		}
		if len(data) < 1 {
			continue
		}

		switch data[0] {
		case 'h':
			// Heartbeat
		case 'a':
			// Normal message
			var arr []string
			if err := json.Unmarshal(data[1:], &arr); err != nil {
				log.Printf("Dropping malformed message array: %s", err)
				continue
			}
			for _, message := range arr {
				message = strings.TrimSpace(message)
				log.Println("Handling", message)
				resp := &RPCResponse{}
				if err := json.Unmarshal([]byte(message), resp); err != nil {
					log.Printf("Dropping malformed message: %s", err)
					continue
				}
				go c.handle(resp)
			}
		case 'c':
			// Session closed. Stop reading whether or not the payload
			// parses, and make sure the client no longer reports itself as
			// connected.
			var v []interface{}
			if err := json.Unmarshal(data[1:], &v); err != nil {
				log.Printf("Closing session: %s", err)
			} else {
				log.Printf("Closing session: %v", v)
			}
			c.connected = false
			return nil
		default:
			log.Println("Unknown:", data[0])
		}
	}
}
// handle dispatches one decoded message.
//
// A message with an ID is a response: the ResponseHandler registered for that
// ID is removed from the pending table and invoked once. ErrNoRequest is
// returned when no handler is registered for the ID.
//
// A message without an ID is decoded as a ResourceEvent. If its Type is Event
// it is passed to the SubscriptionHandler registered for its ResourceId, and
// ErrNoHandler is returned when there is none. Other types are ignored.
//
// An error is also returned when an ID-less message has no result or its
// result cannot be decoded.
func (c *Client) handle(resp *RPCResponse) error {
	if resp.ID != nil {
		// Look up and remove under a single lock so a duplicate response for
		// the same ID cannot invoke the handler twice.
		c.requestLock.Lock()
		h, ok := c.requests[*resp.ID]
		delete(c.requests, *resp.ID)
		c.requestLock.Unlock()
		if !ok {
			return ErrNoRequest
		}
		h(resp)
		return nil
	}

	// handle runs on its own goroutine, so a nil dereference here would
	// bring down the whole process.
	if resp.Result == nil {
		return errors.New("message has neither id nor result")
	}
	res := &ResourceEvent{}
	if err := json.Unmarshal(*resp.Result, res); err != nil {
		return err
	}

	if res.Type != Event {
		return nil
	}
	c.subscriptionLock.RLock()
	h, exists := c.subscriptions[res.ResourceId]
	c.subscriptionLock.RUnlock()
	if !exists {
		return ErrNoHandler
	}
	h(res)
	return nil
}
// paddedRandomIntn returns a random integer in [0, max) rendered in decimal
// and left-padded with zeros to the number of digits in max itself.
// For example, with max = 999 the result is always three characters long.
// max must be positive, because rand.Intn panics otherwise.
func paddedRandomIntn(max int) string {
	width := len(strconv.Itoa(max))
	digits := strconv.Itoa(rand.Intn(max))

	missing := width - len(digits)
	if missing <= 0 {
		return digits
	}
	return strings.Repeat("0", missing) + digits
}