package slobs import ( "encoding/json" "errors" "github.com/dchest/uniuri" "github.com/gorilla/websocket" "log" "math/rand" "net/http" "strconv" "strings" "sync" "sync/atomic" ) type SubscriptionHandler func(*ResourceEvent) type ResponseHandler func(*RPCResponse) const ( Event = "EVENT" Subscription = "SUBSCRIPTION" ) var ( ErrNoRequest = errors.New("request not found") ErrNoHandler = errors.New("handler not found") ) type Client struct { conn *websocket.Conn connected bool Address string requestId int32 requests map[int]ResponseHandler requestLock sync.RWMutex subscriptions map[string]SubscriptionHandler subscriptionLock sync.RWMutex } func NewClient(address string) *Client { return &Client{ Address: address, requests: make(map[int]ResponseHandler), subscriptions: make(map[string]SubscriptionHandler), } } func (c *Client) Connected() bool { return c.connected } func (c *Client) Connect() error { if c.connected { return errors.New("already connected") } endpoint := "ws://" + c.Address + "/api/" + paddedRandomIntn(999) + "/" + uniuri.New() + "/websocket" var err error c.conn, _, err = websocket.DefaultDialer.Dial(endpoint, http.Header{}) if err != nil { return err } _, data, err := c.conn.ReadMessage() if err != nil { return err } if data[0] != 'o' { return errors.New("invalid initial message") } c.connected = true go c.loop() return nil } func (c *Client) Disconnect() error { return c.conn.Close() } func (c *Client) Subscribe(resource, method string, handler SubscriptionHandler) { c.SendRPC(resource, method, func(response *RPCResponse) { res := &ResourceEvent{} json.Unmarshal(*response.Result, &res) c.subscriptionLock.Lock() c.subscriptions[res.ResourceId] = handler c.subscriptionLock.Unlock() }) } func (c *Client) SendRPC(resource, method string, handler ResponseHandler) error { m := make(map[string]interface{}) m["resource"] = resource m["args"] = []string{} atomic.AddInt32(&c.requestId, 1) newRequestId := int(atomic.LoadInt32(&c.requestId)) request := &RPCRequest{ ID: newRequestId, Method: method, Params: m, JSONRPC: "2.0", } b, err := json.Marshal(request) if err != nil { return err } if handler != nil { c.requestLock.Lock() c.requests[newRequestId] = handler c.requestLock.Unlock() } return c.conn.WriteJSON([]string{string(b) + "\n"}) } func (c *Client) loop() error { for { _, data, err := c.conn.ReadMessage() if err != nil { c.connected = false return err } if len(data) < 1 { continue } switch data[0] { case 'h': // Heartbeat continue case 'a': // Normal message arr := make([]string, 0) err := json.Unmarshal(data[1:], &arr) if err != nil { continue } for _, message := range arr { resp := &RPCResponse{} message = strings.TrimSpace(message) log.Println("Handling", message) err = json.Unmarshal([]byte(message), &resp) if err != nil { continue } go c.handle(resp) } case 'c': // Session closed var v []interface{} if err := json.Unmarshal(data[1:], &v); err != nil { log.Printf("Closing session: %s", err) return nil } break default: log.Println("Unknown:", data[0]) } } } func (c *Client) handle(resp *RPCResponse) error { if resp.ID != nil { c.requestLock.RLock() h, ok := c.requests[*resp.ID] c.requestLock.RUnlock() if !ok { return ErrNoRequest } h(resp) c.requestLock.Lock() delete(c.requests, *resp.ID) c.requestLock.Unlock() return nil } res := &ResourceEvent{} err := json.Unmarshal(*resp.Result, &res) if err != nil { return err } switch res.Type { case Event: c.subscriptionLock.RLock() h, exists := c.subscriptions[res.ResourceId] c.subscriptionLock.RUnlock() if !exists { return ErrNoHandler } h(res) } return nil } func paddedRandomIntn(max int) string { var ( ml = len(strconv.Itoa(max)) ri = rand.Intn(max) is = strconv.Itoa(ri) ) if len(is) < ml { is = strings.Repeat("0", ml-len(is)) + is } return is }