API Implementation, patches
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Tyler 2021-04-15 00:41:06 -04:00
parent e3958febc7
commit b6efd0df0c
9 changed files with 304 additions and 38 deletions

View File

@ -6,12 +6,22 @@ import (
"net/http" "net/http"
) )
func Start() error { func New() *API {
r := chi.NewRouter() r := chi.NewRouter()
r.Use(render.SetContentType(render.ContentTypeJSON)) r.Use(render.SetContentType(render.ContentTypeJSON))
r.Get("/hosts", hostsGet) return &API{router: r}
}
return http.ListenAndServe(":8080", r) type API struct {
router chi.Router
}
func (a *API) Router() chi.Router {
return a.router
}
func (a *API) Start() error {
return http.ListenAndServe(":8080", a.router)
} }

View File

@ -1,7 +0,0 @@
package api
import "net/http"
func hostsGet(w http.ResponseWriter, r *http.Request) {
}

View File

@ -17,7 +17,7 @@ type Handler struct {
resolver *resolver.Resolver resolver *resolver.Resolver
middleware []MiddlewareFunc middleware []MiddlewareFunc
cache, negCache cache.Cache cache, negCache cache.Cache
hosts hosts.Hosts hosts *hosts.ProviderList
} }
type MiddlewareFunc func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg type MiddlewareFunc func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg
@ -34,10 +34,15 @@ func TsigMiddleware(secretKey string) MiddlewareFunc {
} }
} }
func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hosts.Hosts) *Handler { func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h *hosts.ProviderList) *Handler {
return &Handler{r, make([]MiddlewareFunc, 0), resolverCache, negCache, h} return &Handler{r, make([]MiddlewareFunc, 0), resolverCache, negCache, h}
} }
func (h *Handler) Use(f MiddlewareFunc) *Handler {
h.middleware = append(h.middleware, f)
return h
}
// do handles a dns request. // do handles a dns request.
// network will decide which network type it is (udp, tcp, https, etc) // network will decide which network type it is (udp, tcp, https, etc)
func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) { func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) {

View File

@ -5,9 +5,7 @@ import (
"time" "time"
) )
var ( type HostMap map[string][]Host
zeroDuration = time.Duration(0)
)
type Host struct { type Host struct {
Type uint16 `json:"type"` Type uint16 `json:"type"`
@ -19,23 +17,58 @@ func (h *Host) TypeString() string {
return dns.TypeToString[h.Type] return dns.TypeToString[h.Type]
} }
type Hosts interface {
Get(queryType uint16, domain string) (*Host, error)
}
type ProviderList struct { type ProviderList struct {
providers []Provider providers []Provider
} }
// Provider is an interface specifying a host source
// Each source should support AT LEAST List and Get, but can support Writer as well
type Provider interface { type Provider interface {
List() (HostMap, error)
Get(queryType uint16, domain string) (*Host, error) Get(queryType uint16, domain string) (*Host, error)
Set(domain string, host *Host) error
} }
func NewHosts(providers []Provider) Hosts { // Writer is an interface to modify hosts.
// Examples of this include Redis, Bolt, MySQL, etc.
type Writer interface {
Set(domain string, host *Host) error
Delete(domain string) error
}
type ProviderWriter interface {
Provider
Writer
}
func NewHosts(providers []Provider) *ProviderList {
return &ProviderList{providers} return &ProviderList{providers}
} }
// List returns all results, merged into one HostMap
func (h *ProviderList) List() (HostMap, error) {
hostMap := make(HostMap)
for _, provider := range h.providers {
hosts, err := provider.List()
if err != nil {
continue
}
for k, v := range hosts {
if existing, ok := hostMap[k]; ok {
existing = append(existing, v...)
hostMap[k] = existing
} else {
hostMap[k] = v
}
}
}
return hostMap, nil
}
// Get Matches values to providers, loping each in order // Get Matches values to providers, loping each in order
func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) { func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
var host *Host var host *Host
@ -50,4 +83,36 @@ func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
} }
return host, err return host, err
}
// Set invokes each provider, setting the host on the first one to return a nil error
func (h *ProviderList) Set(domain string, host *Host) (err error) {
for _, provider := range h.providers {
if writer, ok := provider.(Writer); ok {
err = writer.Set(domain, host)
if err == nil {
return
}
}
}
err = errUnsupportedOperation
return
}
// Delete invokes each provider, removing the host on the first one to return a nil error
func (h *ProviderList) Delete(domain string) (err error) {
for _, provider := range h.providers {
if writer, ok := provider.(Writer); ok {
err = writer.Delete(domain)
if err == nil {
return
}
}
}
err = errUnsupportedOperation
return
} }

140
hosts/hosts_api.go Normal file
View File

@ -0,0 +1,140 @@
package hosts
import (
"github.com/go-chi/chi"
"github.com/go-chi/render"
"github.com/miekg/dns"
"net/http"
"time"
)
const (
defaultDuration = 600 * time.Second
)
func EnableAPI(h Provider, r chi.Router) {
a := &api{hosts: h}
r.Route("/hosts", func(sub chi.Router) {
sub.Get("/", a.hostsGet)
sub.Post("/", a.hostsCreate)
sub.Patch("/{domain}", a.hostsUpdate)
sub.Delete("/{domain}", a.hostsDelete)
})
}
type api struct {
hosts Provider
}
// hostsGet handles GET requests on /hosts (list records)
func (a *api) hostsGet(w http.ResponseWriter, r *http.Request) {
hosts, err := a.hosts.List()
if err != nil {
return
}
render.JSON(w, r, hosts)
}
type requestBody struct {
Domain string `json:"domain"`
Type string `json:"type"`
Values []string `json:"values"`
TTL int `json:"ttl"`
}
func (b requestBody) TTLDuration() time.Duration {
if b.TTL > 0 {
return time.Duration(b.TTL) * time.Second
}
return defaultDuration
}
// hostsUpdate handles POST requests on /hosts
func (a *api) hostsCreate(w http.ResponseWriter, r *http.Request) {
var writer Writer
var ok bool
if writer, ok = a.hosts.(Writer); !ok {
w.WriteHeader(http.StatusNotImplemented)
return
}
var request requestBody
err := render.DefaultDecoder(r, &request)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
var recordType uint16
if recordType, ok = dns.StringToType[request.Type]; !ok {
w.WriteHeader(http.StatusBadRequest)
return
}
err = writer.Set(request.Domain, &Host{
Type: recordType,
Values: request.Values,
TTL: request.TTLDuration(),
})
if err != nil {
}
}
// hostsUpdate handles PATCH requests on /hosts/:domain
func (a *api) hostsUpdate(w http.ResponseWriter, r *http.Request) {
domain := chi.URLParam(r, "domain")
if domain == "" {
w.WriteHeader(http.StatusBadRequest)
return
}
var writer Writer
var ok bool
if writer, ok = a.hosts.(Writer); !ok {
w.WriteHeader(http.StatusNotImplemented)
return
}
// TODO: Read record from provider, update data from body, save
err := writer.Set(domain, nil)
if err != nil {
}
}
// hostsDelete handles DELETE requests on /hosts/:domain
func (a *api) hostsDelete(w http.ResponseWriter, r *http.Request) {
domain := chi.URLParam(r, "domain")
if domain == "" {
w.WriteHeader(http.StatusBadRequest)
return
}
var writer Writer
var ok bool
if writer, ok = a.hosts.(Writer); !ok {
w.WriteHeader(http.StatusNotImplemented)
return
}
err := writer.Delete(domain)
if err != nil {
}
}

View File

@ -2,8 +2,6 @@ package hosts
import ( import (
"encoding/json" "encoding/json"
"errors"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
"strings" "strings"
@ -40,6 +38,34 @@ func NewBoltProvider(file string) Provider {
} }
} }
func (b *BoltHosts) List() (HostMap, error) {
hosts := make(HostMap)
err := b.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("records"))
c := b.Cursor()
for k, v := c.First(); k != nil; k, v = c.Next() {
var domainRecords []Host
if err := json.Unmarshal(v, &domainRecords); err != nil {
continue
}
hosts[string(k)] = domainRecords
}
return nil
})
if err != nil {
return nil, err
}
return hosts, nil
}
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) { func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking bolt provider for %s : %s", queryType, domain) log.Debug("Checking bolt provider for %s : %s", queryType, domain)
@ -47,25 +73,24 @@ func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
var err error var err error
key := domain + "_" + dns.TypeToString[queryType]
var v []byte var v []byte
err = b.db.View(func(tx *bolt.Tx) error { err = b.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("records")) b := tx.Bucket([]byte("records"))
v = b.Get([]byte(key)) v = b.Get([]byte(domain))
if string(v) == "" { if string(v) != "" {
return errors.New( "Record not found, key: " + key) return nil
} }
v = b.Get([]byte("*." + key)) v = b.Get([]byte("*." + domain))
if string(v) == "" { if string(v) != "" {
return errors.New( "Record not found, key: " + key) return nil
} }
return nil return errRecordNotFound
}) })
if err != nil { if err != nil {

View File

@ -57,6 +57,10 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
return fp return fp
} }
func (f *FileHosts) List() (HostMap, error) {
return nil, errUnsupportedOperation
}
func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) { func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking file provider for %s : %s", queryType, domain) log.Debug("Checking file provider for %s : %s", queryType, domain)
@ -90,10 +94,6 @@ func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
return nil, errRecordNotFound return nil, errRecordNotFound
} }
func (f *FileHosts) Set(domain string, host *Host) error {
return errUnsupportedOperation
}
var ( var (
hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$") hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$")
) )

View File

@ -23,6 +23,28 @@ func NewRedisProvider(rc *redis.Client, key string) Provider {
return rh return rh
} }
func (r *RedisHosts) List() (HostMap, error) {
res, err := r.redis.HGetAll(r.key).Result()
if err != nil {
return nil, err
}
hosts := make(HostMap)
for k, v := range res {
var domainRecords []Host
if err = json.Unmarshal([]byte(v), &domainRecords); err != nil {
continue
}
hosts[k] = domainRecords
}
return nil, errUnsupportedOperation
}
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) { func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
log.Debug("Checking redis provider for %s", domain) log.Debug("Checking redis provider for %s", domain)

12
main.go
View File

@ -108,9 +108,15 @@ func main() {
providers = append(providers, hosts.NewRedisProvider(rc, viper.GetString("hosts.redis.key"))) providers = append(providers, hosts.NewRedisProvider(rc, viper.GetString("hosts.redis.key")))
} }
h := hosts.NewHosts(providers)
a := api.New()
hosts.EnableAPI(h, a.Router())
if viper.GetBool("api.enabled") { if viper.GetBool("api.enabled") {
go func() { go func() {
err := api.Start() err := a.Start()
if err != nil { if err != nil {
log.WithError(err).Fatalln("Unable to bind API") log.WithError(err).Fatalln("Unable to bind API")
@ -118,9 +124,9 @@ func main() {
}() }()
} }
h := hosts.NewHosts(providers) handler := NewHandler(r, resolverCache, negCache, h)
server.Run(NewHandler(r, resolverCache, negCache, h)) server.Run(handler)
log.Infof("joker dns %s (%s)", Version, runtime.Version()) log.Infof("joker dns %s (%s)", Version, runtime.Version())