From d99d5902a110d96c43555d349f6d402f3ced7a49 Mon Sep 17 00:00:00 2001 From: kenshinx Date: Thu, 12 Feb 2015 17:19:46 +0800 Subject: [PATCH] Merged pull request kenshinx/godns#10 --- resolver.go | 46 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/resolver.go b/resolver.go index b568833..6649b50 100644 --- a/resolver.go +++ b/resolver.go @@ -3,18 +3,19 @@ package main import ( "fmt" "strings" + "sync" "time" "github.com/miekg/dns" ) type ResolvError struct { - qname string + qname, net string nameservers []string } func (e ResolvError) Error() string { - errmsg := fmt.Sprintf("%s resolv failed on %s", e.qname, strings.Join(e.nameservers, "; ")) + errmsg := fmt.Sprintf("%s resolv failed on %s (%s)", e.qname, strings.Join(e.nameservers, "; "), e.net) return errmsg } @@ -22,6 +23,9 @@ type Resolver struct { config *dns.ClientConfig } +// Lookup will ask each nameserver in top-to-bottom fashion, starting a new request +// in every second, and return as early as possbile (have an answer). +// It returns an error if no request has succeeded. func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error) { c := &dns.Client{ Net: net, @@ -31,21 +35,49 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error qname := req.Question[0].Name - for _, nameserver := range r.Nameservers() { + res := make(chan *dns.Msg, 1) + var wg sync.WaitGroup + L := func(nameserver string) { + defer wg.Done() r, rtt, err := c.Exchange(req, nameserver) if err != nil { Debug("%s socket error on %s", qname, nameserver) Debug("error:%s", err.Error()) - continue + return } if r != nil && r.Rcode != dns.RcodeSuccess { Debug("%s failed to get an valid answer on %s", qname, nameserver) + return + } + Debug("%s resolv on %s (%s) ttl: %d", UnFqdn(qname), nameserver, net, rtt) + select { + case res <- r: + default: + } + } + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + // Start lookup on each nameserver top-down, in every second + for _, nameserver := range r.Nameservers() { + wg.Add(1) + go L(nameserver) + // but exit early, if we have an answer + select { + case r := <-res: + return r, nil + case <-ticker.C: continue } - Debug("%s resolv on %s ttl: %d", UnFqdn(qname), nameserver, rtt) - return r, nil } - return nil, ResolvError{qname, r.Nameservers()} + // wait for all the namservers to finish + wg.Wait() + select { + case r := <-res: + return r, nil + default: + return nil, ResolvError{qname, net, r.Nameservers()} + } }