package hosts

import (
	"encoding/json"
	"strings"

	"github.com/go-redis/redis/v7"
	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"
)

// RedisHosts is a Provider backed by a single Redis hash. Each hash field is a
// lowercased domain name and each value is a JSON-encoded array of Host
// records for that domain.
//
// NOTE(review): the embedded Provider is never assigned by NewRedisProvider,
// so any Provider method not implemented on *RedisHosts would be invoked on a
// nil interface and panic. Confirm the embedding exists only to satisfy the
// interface.
type RedisHosts struct {
	Provider
	redis *redis.Client
	key   string
}

// NewRedisProvider returns a Provider that reads and writes host records in
// the Redis hash stored at key, using the client rc.
func NewRedisProvider(rc *redis.Client, key string) Provider {
	return &RedisHosts{
		redis: rc,
		key:   key,
	}
}

// List returns every domain stored in the hash together with its decoded
// records. Entries whose value is not a valid JSON array of Host are logged
// and skipped, so one bad entry does not hide the rest. Any error from the
// Redis HGETALL call is returned unchanged.
func (r *RedisHosts) List() (HostMap, error) {
	res, err := r.redis.HGetAll(r.key).Result()
	if err != nil {
		return nil, err
	}

	hosts := make(HostMap, len(res))
	for domain, raw := range res {
		var records []Host
		if err := json.Unmarshal([]byte(raw), &records); err != nil {
			// Best-effort: keep listing the remaining domains.
			log.WithError(err).WithField("domain", domain).Warn("Skipping malformed redis host entry")
			continue
		}
		hosts[domain] = records
	}
	return hosts, nil
}

// Get returns the first record of type queryType for domain. The lookup is
// case-insensitive: it tries the lowercased domain first and, if that yields
// no matching record, the single-level wildcard formed by replacing the
// leftmost label with "*" (e.g. "a.example.com" -> "*.example.com").
//
// It returns errRecordNotFound when neither field holds a matching record.
// Redis errors other than a missing field, and JSON decoding errors, are
// returned as-is.
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
	log.WithFields(log.Fields{
		"queryType": dns.TypeToString[queryType],
		"question":  domain,
	}).Debug("Checking redis provider")

	domain = strings.ToLower(domain)

	host, err := r.lookup(domain, queryType)
	if err != nil || host != nil {
		return host, err
	}

	if idx := strings.Index(domain, "."); idx != -1 {
		host, err = r.lookup("*."+domain[idx+1:], queryType)
		if err != nil || host != nil {
			return host, err
		}
	}

	return nil, errRecordNotFound
}

// lookup decodes the records stored under the hash field and returns the
// first one whose Type equals queryType. It returns (nil, nil) when the field
// is absent, holds an empty string, or has no record of that type.
func (r *RedisHosts) lookup(field string, queryType uint16) (*Host, error) {
	res, err := r.redis.HGet(r.key, field).Result()
	if err == redis.Nil || (err == nil && res == "") {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	var records []Host
	if err := json.Unmarshal([]byte(res), &records); err != nil {
		return nil, err
	}
	for i := range records {
		if records[i].Type == queryType {
			return &records[i], nil
		}
	}
	return nil, nil
}

// Set appends host to the records stored for domain (lowercased), creating the
// hash field if it does not exist yet. Errors from reading the existing entry
// (other than the field being absent), from decoding or encoding JSON, and
// from writing back to Redis are returned.
//
// NOTE(review): the read-modify-write below is not atomic; concurrent Set
// calls for the same domain can lose an update.
func (r *RedisHosts) Set(domain string, host *Host) error {
	// Normalize once so the read and the write hit the same field.
	domain = strings.ToLower(domain)

	// Decode into a fresh nil slice: json.Unmarshal reuses existing slice
	// elements, so pre-seeding it with host would overwrite *host.
	var hosts []*Host
	res, err := r.redis.HGet(r.key, domain).Result()
	switch {
	case err == redis.Nil:
		// No records stored yet for this domain.
	case err != nil:
		// Don't overwrite existing data because of a transient read failure.
		return err
	case res != "":
		if err := json.Unmarshal([]byte(res), &hosts); err != nil {
			return err
		}
	}
	hosts = append(hosts, host)

	b, err := json.Marshal(hosts)
	if err != nil {
		return err
	}

	_, err = r.redis.HSet(r.key, domain, b).Result()
	return err
}