diff --git a/hosts.go b/hosts.go index 1208eb6..1153ad5 100644 --- a/hosts.go +++ b/hosts.go @@ -12,22 +12,22 @@ import ( ) type Hosts struct { - FileHosts map[string]string + FileHosts *FileHosts RedisHosts *RedisHosts } func NewHosts(hs HostsSettings, rs RedisSettings) Hosts { fileHosts := &FileHosts{hs.HostsFile} - var rc *redis.Client + var redisHosts *RedisHosts if hs.RedisEnable { - rc = &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password} + rc := &redis.Client{Addr: rs.Addr(), Db: rs.DB, Password: rs.Password} + redisHosts = &RedisHosts{rc, hs.RedisKey} } else { - rc = nil + redisHosts = new(RedisHosts) } - redisHosts := &RedisHosts{rc, hs.RedisKey} - hosts := Hosts{fileHosts.GetAll(), redisHosts} + hosts := Hosts{fileHosts, redisHosts} return hosts } @@ -41,7 +41,7 @@ func NewHosts(hs HostsSettings, rs RedisSettings) Hosts { func (h *Hosts) Get(domain string, family int) (ip net.IP, ok bool) { var sip string - if sip, ok = h.FileHosts[domain]; !ok { + if sip, ok = h.FileHosts.Get(domain); !ok { if sip, ok = h.RedisHosts.Get(domain); !ok { return nil, false } @@ -64,7 +64,7 @@ func (h *Hosts) GetAll() map[string]string { for domain, ip := range h.RedisHosts.GetAll() { m[domain] = ip } - for domain, ip := range h.FileHosts { + for domain, ip := range h.FileHosts.GetAll() { m[domain] = ip } return m @@ -103,6 +103,12 @@ type FileHosts struct { file string } +func (f *FileHosts) Get(domain string) (ip string, ok bool) { + hosts := f.GetAll() + ip, ok = hosts[domain] + return +} + func (f *FileHosts) GetAll() map[string]string { var hosts = make(map[string]string)