diff --git a/resolver.go b/resolver.go index a83df5f..b3a095f 100644 --- a/resolver.go +++ b/resolver.go @@ -11,6 +11,8 @@ import ( "time" "github.com/miekg/dns" + "errors" + "crypto/tls" ) type ResolvError struct { @@ -33,6 +35,10 @@ type Resolver struct { servers []string domain_server *suffixTreeNode config *ResolvSettings + + tcpClient *dns.Client + udpClient *dns.Client + httpsClient *dns.Client } func NewResolver(c ResolvSettings) *Resolver { @@ -64,6 +70,26 @@ func NewResolver(c ResolvSettings) *Resolver { r.servers = append([]string{c.DOHServer}, r.servers...) } + timeout := r.Timeout() + + r.udpClient = &dns.Client{ + Net: "udp", + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + r.tcpClient = &dns.Client{ + Net: "tcp", + ReadTimeout: timeout, + WriteTimeout: timeout, + } + + r.httpsClient = &dns.Client{ + Net: "https", + ReadTimeout: timeout, + WriteTimeout: timeout, + } + return r } @@ -149,24 +175,6 @@ func (r *Resolver) ReadServerListFile(path string) { // in every second, and return as early as possbile (have an answer). // It returns an error if no request has succeeded. func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error) { - c := &dns.Client{ - Net: net, - ReadTimeout: r.Timeout(), - WriteTimeout: r.Timeout(), - } - - httpC := &dns.Client{ - Net: "https", - ReadTimeout: r.Timeout(), - WriteTimeout: r.Timeout(), - } - - tlsClient := &dns.Client{ - Net: "tcp-tls", - ReadTimeout: r.Timeout(), - WriteTimeout: r.Timeout(), - } - if net == "udp" && settings.ResolvConfig.SetEDNS0 { req = req.SetEdns0(65535, true) } @@ -175,21 +183,18 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error res := make(chan *RResp, 1) var wg sync.WaitGroup - L := func(nameserver string) { + L := func(resolver *Resolver, nameserver string) { defer wg.Done() - var r *dns.Msg - var rtt time.Duration - var err error + c, err := resolver.resolverFor(net, nameserver) - if strings.HasPrefix(nameserver, "https") { - r, rtt, err = httpC.Exchange(req, nameserver) - } else if strings.HasPrefix(nameserver, ":853") { - r, rtt, err = tlsClient.Exchange(req, nameserver) - } else { - r, rtt, err = c.Exchange(req, nameserver) + if err != nil { + logger.Warn("error:%s", err.Error()) + return } + r, rtt, err := c.Exchange(req, nameserver) + if err != nil { logger.Warn("%s socket error on %s", qname, nameserver) logger.Warn("error:%s", err.Error()) @@ -218,7 +223,7 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error nameservers := r.Nameservers(qname) for _, nameserver := range nameservers { wg.Add(1) - go L(nameserver) + go L(r, nameserver) // but exit early, if we have an answer select { case re := <-res: @@ -239,6 +244,28 @@ func (r *Resolver) Lookup(net string, req *dns.Msg) (message *dns.Msg, err error } } +func (r *Resolver) resolverFor(net, nameserver string) (*dns.Client, error) { + if strings.HasPrefix(nameserver, "https") { + return r.httpsClient, nil + } else if strings.HasSuffix(nameserver, ":853") { + // TODO We need to set the server name so we can confirm the TLS connection. This may require a rewrite of storing nameservers. + return &dns.Client{ + Net: "tcp-tls", + ReadTimeout: r.Timeout(), + WriteTimeout: r.Timeout(), + TLSConfig: &tls.Config{ + ServerName: "", + }, + }, nil + } else if net == "udp" { + return r.udpClient, nil + } else if net == "tcp" { + return r.tcpClient, nil + } + + return nil, errors.New("no client for nameserver") +} + // Namservers return the array of nameservers, with port number appended. // '#' in the name is treated as port separator, as with dnsmasq.