From b6efd0df0cfb9b046ce892f4d32d6bd14594aba7 Mon Sep 17 00:00:00 2001 From: Tyler Date: Thu, 15 Apr 2021 00:41:06 -0400 Subject: [PATCH] API Implementation, patches --- api/api.go | 16 ++++- api/hosts.go | 7 --- handler.go | 9 ++- hosts/hosts.go | 83 ++++++++++++++++++++++--- hosts/hosts_api.go | 140 +++++++++++++++++++++++++++++++++++++++++++ hosts/hosts_bolt.go | 45 ++++++++++---- hosts/hosts_file.go | 8 +-- hosts/hosts_redis.go | 22 +++++++ main.go | 12 +++- 9 files changed, 304 insertions(+), 38 deletions(-) delete mode 100644 api/hosts.go create mode 100644 hosts/hosts_api.go diff --git a/api/api.go b/api/api.go index fc2cdef..08014a5 100644 --- a/api/api.go +++ b/api/api.go @@ -6,12 +6,22 @@ import ( "net/http" ) -func Start() error { +func New() *API { r := chi.NewRouter() r.Use(render.SetContentType(render.ContentTypeJSON)) - r.Get("/hosts", hostsGet) + return &API{router: r} +} - return http.ListenAndServe(":8080", r) +type API struct { + router chi.Router +} + +func (a *API) Router() chi.Router { + return a.router +} + +func (a *API) Start() error { + return http.ListenAndServe(":8080", a.router) } \ No newline at end of file diff --git a/api/hosts.go b/api/hosts.go deleted file mode 100644 index 2737da5..0000000 --- a/api/hosts.go +++ /dev/null @@ -1,7 +0,0 @@ -package api - -import "net/http" - -func hostsGet(w http.ResponseWriter, r *http.Request) { - -} diff --git a/handler.go b/handler.go index 524c9a3..f3c2b7e 100644 --- a/handler.go +++ b/handler.go @@ -17,7 +17,7 @@ type Handler struct { resolver *resolver.Resolver middleware []MiddlewareFunc cache, negCache cache.Cache - hosts hosts.Hosts + hosts *hosts.ProviderList } type MiddlewareFunc func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg @@ -34,10 +34,15 @@ func TsigMiddleware(secretKey string) MiddlewareFunc { } } -func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hosts.Hosts) *Handler { +func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h *hosts.ProviderList) *Handler { return &Handler{r, make([]MiddlewareFunc, 0), resolverCache, negCache, h} } +func (h *Handler) Use(f MiddlewareFunc) *Handler { + h.middleware = append(h.middleware, f) + return h +} + // do handles a dns request. // network will decide which network type it is (udp, tcp, https, etc) func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) { diff --git a/hosts/hosts.go b/hosts/hosts.go index 636f48b..601981e 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -5,9 +5,7 @@ import ( "time" ) -var ( - zeroDuration = time.Duration(0) -) +type HostMap map[string][]Host type Host struct { Type uint16 `json:"type"` @@ -19,23 +17,58 @@ func (h *Host) TypeString() string { return dns.TypeToString[h.Type] } -type Hosts interface { - Get(queryType uint16, domain string) (*Host, error) -} - type ProviderList struct { providers []Provider } +// Provider is an interface specifying a host source +// Each source should support AT LEAST List and Get, but can support Writer as well type Provider interface { + List() (HostMap, error) Get(queryType uint16, domain string) (*Host, error) - Set(domain string, host *Host) error } -func NewHosts(providers []Provider) Hosts { +// Writer is an interface to modify hosts. +// Examples of this include Redis, Bolt, MySQL, etc. +type Writer interface { + Set(domain string, host *Host) error + Delete(domain string) error +} + +type ProviderWriter interface { + Provider + Writer +} + +func NewHosts(providers []Provider) *ProviderList { return &ProviderList{providers} } +// List returns all results, merged into one HostMap +func (h *ProviderList) List() (HostMap, error) { + hostMap := make(HostMap) + + for _, provider := range h.providers { + hosts, err := provider.List() + + if err != nil { + continue + } + + for k, v := range hosts { + if existing, ok := hostMap[k]; ok { + existing = append(existing, v...) + + hostMap[k] = existing + } else { + hostMap[k] = v + } + } + } + + return hostMap, nil +} + // Get Matches values to providers, loping each in order func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) { var host *Host @@ -50,4 +83,36 @@ func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) { } return host, err +} + +// Set invokes each provider, setting the host on the first one to return a nil error +func (h *ProviderList) Set(domain string, host *Host) (err error) { + for _, provider := range h.providers { + if writer, ok := provider.(Writer); ok { + err = writer.Set(domain, host) + + if err == nil { + return + } + } + } + + err = errUnsupportedOperation + return +} + +// Delete invokes each provider, removing the host on the first one to return a nil error +func (h *ProviderList) Delete(domain string) (err error) { + for _, provider := range h.providers { + if writer, ok := provider.(Writer); ok { + err = writer.Delete(domain) + + if err == nil { + return + } + } + } + + err = errUnsupportedOperation + return } \ No newline at end of file diff --git a/hosts/hosts_api.go b/hosts/hosts_api.go new file mode 100644 index 0000000..5624682 --- /dev/null +++ b/hosts/hosts_api.go @@ -0,0 +1,140 @@ +package hosts + +import ( + "github.com/go-chi/chi" + "github.com/go-chi/render" + "github.com/miekg/dns" + "net/http" + "time" +) + +const ( + defaultDuration = 600 * time.Second +) + +func EnableAPI(h Provider, r chi.Router) { + a := &api{hosts: h} + + r.Route("/hosts", func(sub chi.Router) { + sub.Get("/", a.hostsGet) + sub.Post("/", a.hostsCreate) + sub.Patch("/{domain}", a.hostsUpdate) + sub.Delete("/{domain}", a.hostsDelete) + }) +} + +type api struct { + hosts Provider +} + +// hostsGet handles GET requests on /hosts (list records) +func (a *api) hostsGet(w http.ResponseWriter, r *http.Request) { + hosts, err := a.hosts.List() + + if err != nil { + return + } + + render.JSON(w, r, hosts) +} + +type requestBody struct { + Domain string `json:"domain"` + Type string `json:"type"` + Values []string `json:"values"` + TTL int `json:"ttl"` +} + +func (b requestBody) TTLDuration() time.Duration { + if b.TTL > 0 { + return time.Duration(b.TTL) * time.Second + } + + return defaultDuration +} + +// hostsUpdate handles POST requests on /hosts +func (a *api) hostsCreate(w http.ResponseWriter, r *http.Request) { + var writer Writer + var ok bool + + if writer, ok = a.hosts.(Writer); !ok { + w.WriteHeader(http.StatusNotImplemented) + return + } + + var request requestBody + + err := render.DefaultDecoder(r, &request) + + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + var recordType uint16 + + if recordType, ok = dns.StringToType[request.Type]; !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + + err = writer.Set(request.Domain, &Host{ + Type: recordType, + Values: request.Values, + TTL: request.TTLDuration(), + }) + + if err != nil { + + } +} + +// hostsUpdate handles PATCH requests on /hosts/:domain +func (a *api) hostsUpdate(w http.ResponseWriter, r *http.Request) { + domain := chi.URLParam(r, "domain") + + if domain == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + var writer Writer + var ok bool + + if writer, ok = a.hosts.(Writer); !ok { + w.WriteHeader(http.StatusNotImplemented) + return + } + + // TODO: Read record from provider, update data from body, save + err := writer.Set(domain, nil) + + if err != nil { + + } +} + +// hostsDelete handles DELETE requests on /hosts/:domain +func (a *api) hostsDelete(w http.ResponseWriter, r *http.Request) { + domain := chi.URLParam(r, "domain") + + if domain == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + var writer Writer + var ok bool + + if writer, ok = a.hosts.(Writer); !ok { + w.WriteHeader(http.StatusNotImplemented) + return + } + + err := writer.Delete(domain) + + if err != nil { + + } +} \ No newline at end of file diff --git a/hosts/hosts_bolt.go b/hosts/hosts_bolt.go index 09ce2eb..fe213ee 100644 --- a/hosts/hosts_bolt.go +++ b/hosts/hosts_bolt.go @@ -2,8 +2,6 @@ package hosts import ( "encoding/json" - "errors" - "github.com/miekg/dns" log "github.com/sirupsen/logrus" bolt "go.etcd.io/bbolt" "strings" @@ -40,6 +38,34 @@ func NewBoltProvider(file string) Provider { } } +func (b *BoltHosts) List() (HostMap, error) { + hosts := make(HostMap) + + err := b.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("records")) + + c := b.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + var domainRecords []Host + + if err := json.Unmarshal(v, &domainRecords); err != nil { + continue + } + + hosts[string(k)] = domainRecords + } + + return nil + }) + + if err != nil { + return nil, err + } + + return hosts, nil +} + func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) { log.Debug("Checking bolt provider for %s : %s", queryType, domain) @@ -47,25 +73,24 @@ func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) { var err error - key := domain + "_" + dns.TypeToString[queryType] var v []byte err = b.db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("records")) - v = b.Get([]byte(key)) + v = b.Get([]byte(domain)) - if string(v) == "" { - return errors.New( "Record not found, key: " + key) + if string(v) != "" { + return nil } - v = b.Get([]byte("*." + key)) + v = b.Get([]byte("*." + domain)) - if string(v) == "" { - return errors.New( "Record not found, key: " + key) + if string(v) != "" { + return nil } - return nil + return errRecordNotFound }) if err != nil { diff --git a/hosts/hosts_file.go b/hosts/hosts_file.go index 297a24d..66db4dd 100644 --- a/hosts/hosts_file.go +++ b/hosts/hosts_file.go @@ -57,6 +57,10 @@ func NewFileProvider(file string, ttl time.Duration) Provider { return fp } +func (f *FileHosts) List() (HostMap, error) { + return nil, errUnsupportedOperation +} + func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) { log.Debug("Checking file provider for %s : %s", queryType, domain) @@ -90,10 +94,6 @@ func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) { return nil, errRecordNotFound } -func (f *FileHosts) Set(domain string, host *Host) error { - return errUnsupportedOperation -} - var ( hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$") ) diff --git a/hosts/hosts_redis.go b/hosts/hosts_redis.go index b61abd7..43e5327 100644 --- a/hosts/hosts_redis.go +++ b/hosts/hosts_redis.go @@ -23,6 +23,28 @@ func NewRedisProvider(rc *redis.Client, key string) Provider { return rh } +func (r *RedisHosts) List() (HostMap, error) { + res, err := r.redis.HGetAll(r.key).Result() + + if err != nil { + return nil, err + } + + hosts := make(HostMap) + + for k, v := range res { + var domainRecords []Host + + if err = json.Unmarshal([]byte(v), &domainRecords); err != nil { + continue + } + + hosts[k] = domainRecords + } + + return nil, errUnsupportedOperation +} + func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) { log.Debug("Checking redis provider for %s", domain) diff --git a/main.go b/main.go index 2eaae6d..a19aa44 100644 --- a/main.go +++ b/main.go @@ -108,9 +108,15 @@ func main() { providers = append(providers, hosts.NewRedisProvider(rc, viper.GetString("hosts.redis.key"))) } + h := hosts.NewHosts(providers) + + a := api.New() + + hosts.EnableAPI(h, a.Router()) + if viper.GetBool("api.enabled") { go func() { - err := api.Start() + err := a.Start() if err != nil { log.WithError(err).Fatalln("Unable to bind API") @@ -118,9 +124,9 @@ func main() { }() } - h := hosts.NewHosts(providers) + handler := NewHandler(r, resolverCache, negCache, h) - server.Run(NewHandler(r, resolverCache, negCache, h)) + server.Run(handler) log.Infof("joker dns %s (%s)", Version, runtime.Version())