godns/hosts/hosts_redis.go

112 lines
2.0 KiB
Go
Raw Permalink Normal View History

2020-01-25 17:43:02 +00:00
package hosts
import (
"encoding/json"
"github.com/go-redis/redis/v7"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"strings"
)
// RedisHosts is a hosts provider backed by a single Redis hash: each hash
// field is a domain name (Set writes it lower-cased; wildcard entries use the
// form "*.example.com") and each value is a JSON array of Host records.
type RedisHosts struct {
	// Provider is embedded but NewRedisProvider never sets it, so it is nil.
	// NOTE(review): any Provider method that *RedisHosts does not define
	// itself is promoted from this nil interface and would panic if called.
	Provider
	redis *redis.Client // client used for every hash operation
	key   string        // name of the Redis hash holding the records
}
func NewRedisProvider(rc *redis.Client, key string) Provider {
rh := &RedisHosts{
redis: rc,
2019-09-26 04:43:17 +00:00
key: key,
}
return rh
}
2021-04-15 04:41:06 +00:00
// List returns every entry of the Redis hash, keyed by hash field (domain).
//
// Each value is decoded as a JSON array of Host records. Entries that do not
// decode are skipped (and logged) so that one malformed field does not hide
// the rest. An error is returned only when the hash itself cannot be read.
func (r *RedisHosts) List() (HostMap, error) {
	res, err := r.redis.HGetAll(r.key).Result()
	if err != nil {
		return nil, err
	}

	hosts := make(HostMap, len(res))
	for domain, raw := range res {
		var records []Host
		if err := json.Unmarshal([]byte(raw), &records); err != nil {
			// Best-effort: skip the bad entry but leave a trace of it.
			log.WithFields(log.Fields{
				"key":    r.key,
				"domain": domain,
			}).WithError(err).Warn("Skipping malformed redis hosts entry")
			continue
		}
		hosts[domain] = records
	}
	return hosts, nil
}
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
log.WithFields(log.Fields{
"queryType": dns.TypeToString[queryType],
"question": domain,
}).Debug("Checking redis provider")
2018-08-05 04:16:15 +00:00
domain = strings.ToLower(domain)
if res, err := r.redis.HGet(r.key, domain).Result(); res != "" && err == nil {
var h []Host
if err = json.Unmarshal([]byte(res), &h); err != nil {
return nil, err
}
2018-09-01 01:27:26 +00:00
for _, host := range h {
if host.Type == queryType {
return &host, nil
}
}
}
2018-09-01 01:27:26 +00:00
if idx := strings.Index(domain, "."); idx != -1 {
2019-09-26 04:43:17 +00:00
wildcard := "*." + domain[strings.Index(domain, ".")+1:]
2018-09-01 01:27:26 +00:00
if res, err := r.redis.HGet(r.key, wildcard).Result(); res != "" && err == nil {
var h []Host
if err = json.Unmarshal([]byte(res), &h); err != nil {
return nil, err
}
for _, host := range h {
if host.Type == queryType {
return &host, nil
}
}
}
}
2018-09-01 01:27:26 +00:00
return nil, errRecordNotFound
}
// Set appends host to the records stored for domain and writes the list back
// as a JSON array.
//
// The domain is lower-cased before both the read and the write, so every
// spelling of a name accumulates on the same field that Get consults. If the
// existing entry cannot be read from Redis, or is not valid JSON, Set returns
// that error instead of overwriting the stored records.
//
// NOTE(review): records are appended, never replaced, and Get returns the
// first record of a matching type, so adding a second record of a type that
// is already present does not change what Get returns. The read-modify-write
// is also not atomic; concurrent Set calls for one domain can lose updates.
func (r *RedisHosts) Set(domain string, host *Host) error {
	domain = strings.ToLower(domain)

	// Decode into a fresh slice. Decoding into a slice that already holds
	// host would make encoding/json reuse that pointer and overwrite the
	// caller's Host with the first stored record.
	var hosts []*Host
	res, err := r.redis.HGet(r.key, domain).Result()
	switch {
	case err == redis.Nil:
		// No existing entry: start a new list.
	case err != nil:
		// Fail rather than clobber records we could not read.
		return err
	case res != "":
		if err := json.Unmarshal([]byte(res), &hosts); err != nil {
			return err
		}
	}
	hosts = append(hosts, host)

	b, err := json.Marshal(hosts)
	if err != nil {
		return err
	}
	return r.redis.HSet(r.key, domain, b).Err()
}