86 lines
1.6 KiB
Go
86 lines
1.6 KiB
Go
package hosts
|
|
|
|
import (
|
|
"encoding/json"
|
|
"github.com/go-redis/redis/v7"
|
|
log "github.com/sirupsen/logrus"
|
|
"strings"
|
|
)
|
|
|
|
type RedisHosts struct {
|
|
Provider
|
|
|
|
redis *redis.Client
|
|
key string
|
|
}
|
|
|
|
func NewRedisProvider(rc *redis.Client, key string) Provider {
|
|
rh := &RedisHosts{
|
|
redis: rc,
|
|
key: key,
|
|
}
|
|
|
|
return rh
|
|
}
|
|
|
|
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
|
|
log.Debug("Checking redis provider for %s", domain)
|
|
|
|
domain = strings.ToLower(domain)
|
|
|
|
if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
|
|
var h []Host
|
|
|
|
if err = json.Unmarshal([]byte(res), &h); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, host := range h {
|
|
if host.Type == queryType {
|
|
return &host, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
if idx := strings.Index(domain, "."); idx != -1 {
|
|
wildcard := "*." + domain[strings.Index(domain, ".")+1:]
|
|
|
|
if res, err := r.redis.HGet(r.key, wildcard).Result(); res != "" && err == nil {
|
|
var h []Host
|
|
|
|
if err = json.Unmarshal([]byte(res), &h); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, host := range h {
|
|
if host.Type == queryType {
|
|
return &host, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, errRecordNotFound
|
|
}
|
|
|
|
func (r *RedisHosts) Set(domain string, host *Host) error {
|
|
hosts := []*Host{host}
|
|
|
|
if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
|
|
if err = json.Unmarshal([]byte(res), &hosts); err != nil {
|
|
return err
|
|
}
|
|
|
|
hosts = append(hosts, host)
|
|
}
|
|
|
|
b, err := json.Marshal(hosts)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = r.redis.HSet(r.key, strings.ToLower(domain), b).Result()
|
|
|
|
return err
|
|
} |