API Implementation, patches
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
e3958febc7
commit
b6efd0df0c
18
api/api.go
18
api/api.go
@ -6,12 +6,22 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func Start() error {
|
||||
func New() *API {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
|
||||
r.Get("/hosts", hostsGet)
|
||||
|
||||
return http.ListenAndServe(":8080", r)
|
||||
return &API{router: r}
|
||||
}
|
||||
|
||||
type API struct {
|
||||
router chi.Router
|
||||
}
|
||||
|
||||
func (a *API) Router() chi.Router {
|
||||
return a.router
|
||||
}
|
||||
|
||||
func (a *API) Start() error {
|
||||
return http.ListenAndServe(":8080", a.router)
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
package api
|
||||
|
||||
import "net/http"
|
||||
|
||||
func hostsGet(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}
|
@ -17,7 +17,7 @@ type Handler struct {
|
||||
resolver *resolver.Resolver
|
||||
middleware []MiddlewareFunc
|
||||
cache, negCache cache.Cache
|
||||
hosts hosts.Hosts
|
||||
hosts *hosts.ProviderList
|
||||
}
|
||||
|
||||
type MiddlewareFunc func(w dns.ResponseWriter, r *dns.Msg, m *dns.Msg) *dns.Msg
|
||||
@ -34,10 +34,15 @@ func TsigMiddleware(secretKey string) MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h hosts.Hosts) *Handler {
|
||||
func NewHandler(r *resolver.Resolver, resolverCache, negCache cache.Cache, h *hosts.ProviderList) *Handler {
|
||||
return &Handler{r, make([]MiddlewareFunc, 0), resolverCache, negCache, h}
|
||||
}
|
||||
|
||||
func (h *Handler) Use(f MiddlewareFunc) *Handler {
|
||||
h.middleware = append(h.middleware, f)
|
||||
return h
|
||||
}
|
||||
|
||||
// do handles a dns request.
|
||||
// network will decide which network type it is (udp, tcp, https, etc)
|
||||
func (h *Handler) do(network string, w dns.ResponseWriter, req *dns.Msg) {
|
||||
|
@ -5,9 +5,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
zeroDuration = time.Duration(0)
|
||||
)
|
||||
type HostMap map[string][]Host
|
||||
|
||||
type Host struct {
|
||||
Type uint16 `json:"type"`
|
||||
@ -19,23 +17,58 @@ func (h *Host) TypeString() string {
|
||||
return dns.TypeToString[h.Type]
|
||||
}
|
||||
|
||||
type Hosts interface {
|
||||
Get(queryType uint16, domain string) (*Host, error)
|
||||
}
|
||||
|
||||
type ProviderList struct {
|
||||
providers []Provider
|
||||
}
|
||||
|
||||
// Provider is an interface specifying a host source
|
||||
// Each source should support AT LEAST List and Get, but can support Writer as well
|
||||
type Provider interface {
|
||||
List() (HostMap, error)
|
||||
Get(queryType uint16, domain string) (*Host, error)
|
||||
Set(domain string, host *Host) error
|
||||
}
|
||||
|
||||
func NewHosts(providers []Provider) Hosts {
|
||||
// Writer is an interface to modify hosts.
|
||||
// Examples of this include Redis, Bolt, MySQL, etc.
|
||||
type Writer interface {
|
||||
Set(domain string, host *Host) error
|
||||
Delete(domain string) error
|
||||
}
|
||||
|
||||
type ProviderWriter interface {
|
||||
Provider
|
||||
Writer
|
||||
}
|
||||
|
||||
func NewHosts(providers []Provider) *ProviderList {
|
||||
return &ProviderList{providers}
|
||||
}
|
||||
|
||||
// List returns all results, merged into one HostMap
|
||||
func (h *ProviderList) List() (HostMap, error) {
|
||||
hostMap := make(HostMap)
|
||||
|
||||
for _, provider := range h.providers {
|
||||
hosts, err := provider.List()
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for k, v := range hosts {
|
||||
if existing, ok := hostMap[k]; ok {
|
||||
existing = append(existing, v...)
|
||||
|
||||
hostMap[k] = existing
|
||||
} else {
|
||||
hostMap[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return hostMap, nil
|
||||
}
|
||||
|
||||
// Get Matches values to providers, loping each in order
|
||||
func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
|
||||
var host *Host
|
||||
@ -51,3 +84,35 @@ func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
|
||||
|
||||
return host, err
|
||||
}
|
||||
|
||||
// Set invokes each provider, setting the host on the first one to return a nil error
|
||||
func (h *ProviderList) Set(domain string, host *Host) (err error) {
|
||||
for _, provider := range h.providers {
|
||||
if writer, ok := provider.(Writer); ok {
|
||||
err = writer.Set(domain, host)
|
||||
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = errUnsupportedOperation
|
||||
return
|
||||
}
|
||||
|
||||
// Delete invokes each provider, removing the host on the first one to return a nil error
|
||||
func (h *ProviderList) Delete(domain string) (err error) {
|
||||
for _, provider := range h.providers {
|
||||
if writer, ok := provider.(Writer); ok {
|
||||
err = writer.Delete(domain)
|
||||
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = errUnsupportedOperation
|
||||
return
|
||||
}
|
140
hosts/hosts_api.go
Normal file
140
hosts/hosts_api.go
Normal file
@ -0,0 +1,140 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/miekg/dns"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDuration = 600 * time.Second
|
||||
)
|
||||
|
||||
func EnableAPI(h Provider, r chi.Router) {
|
||||
a := &api{hosts: h}
|
||||
|
||||
r.Route("/hosts", func(sub chi.Router) {
|
||||
sub.Get("/", a.hostsGet)
|
||||
sub.Post("/", a.hostsCreate)
|
||||
sub.Patch("/{domain}", a.hostsUpdate)
|
||||
sub.Delete("/{domain}", a.hostsDelete)
|
||||
})
|
||||
}
|
||||
|
||||
type api struct {
|
||||
hosts Provider
|
||||
}
|
||||
|
||||
// hostsGet handles GET requests on /hosts (list records)
|
||||
func (a *api) hostsGet(w http.ResponseWriter, r *http.Request) {
|
||||
hosts, err := a.hosts.List()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, r, hosts)
|
||||
}
|
||||
|
||||
type requestBody struct {
|
||||
Domain string `json:"domain"`
|
||||
Type string `json:"type"`
|
||||
Values []string `json:"values"`
|
||||
TTL int `json:"ttl"`
|
||||
}
|
||||
|
||||
func (b requestBody) TTLDuration() time.Duration {
|
||||
if b.TTL > 0 {
|
||||
return time.Duration(b.TTL) * time.Second
|
||||
}
|
||||
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
// hostsUpdate handles POST requests on /hosts
|
||||
func (a *api) hostsCreate(w http.ResponseWriter, r *http.Request) {
|
||||
var writer Writer
|
||||
var ok bool
|
||||
|
||||
if writer, ok = a.hosts.(Writer); !ok {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
var request requestBody
|
||||
|
||||
err := render.DefaultDecoder(r, &request)
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var recordType uint16
|
||||
|
||||
if recordType, ok = dns.StringToType[request.Type]; !ok {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = writer.Set(request.Domain, &Host{
|
||||
Type: recordType,
|
||||
Values: request.Values,
|
||||
TTL: request.TTLDuration(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// hostsUpdate handles PATCH requests on /hosts/:domain
|
||||
func (a *api) hostsUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
domain := chi.URLParam(r, "domain")
|
||||
|
||||
if domain == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var writer Writer
|
||||
var ok bool
|
||||
|
||||
if writer, ok = a.hosts.(Writer); !ok {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Read record from provider, update data from body, save
|
||||
err := writer.Set(domain, nil)
|
||||
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// hostsDelete handles DELETE requests on /hosts/:domain
|
||||
func (a *api) hostsDelete(w http.ResponseWriter, r *http.Request) {
|
||||
domain := chi.URLParam(r, "domain")
|
||||
|
||||
if domain == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var writer Writer
|
||||
var ok bool
|
||||
|
||||
if writer, ok = a.hosts.(Writer); !ok {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
err := writer.Delete(domain)
|
||||
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
}
|
@ -2,8 +2,6 @@ package hosts
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
"strings"
|
||||
@ -40,6 +38,34 @@ func NewBoltProvider(file string) Provider {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BoltHosts) List() (HostMap, error) {
|
||||
hosts := make(HostMap)
|
||||
|
||||
err := b.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte("records"))
|
||||
|
||||
c := b.Cursor()
|
||||
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
var domainRecords []Host
|
||||
|
||||
if err := json.Unmarshal(v, &domainRecords); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
hosts[string(k)] = domainRecords
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
log.Debug("Checking bolt provider for %s : %s", queryType, domain)
|
||||
|
||||
@ -47,25 +73,24 @@ func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
|
||||
var err error
|
||||
|
||||
key := domain + "_" + dns.TypeToString[queryType]
|
||||
var v []byte
|
||||
|
||||
err = b.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte("records"))
|
||||
|
||||
v = b.Get([]byte(key))
|
||||
|
||||
if string(v) == "" {
|
||||
return errors.New( "Record not found, key: " + key)
|
||||
}
|
||||
|
||||
v = b.Get([]byte("*." + key))
|
||||
|
||||
if string(v) == "" {
|
||||
return errors.New( "Record not found, key: " + key)
|
||||
}
|
||||
v = b.Get([]byte(domain))
|
||||
|
||||
if string(v) != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
v = b.Get([]byte("*." + domain))
|
||||
|
||||
if string(v) != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errRecordNotFound
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -57,6 +57,10 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
|
||||
return fp
|
||||
}
|
||||
|
||||
func (f *FileHosts) List() (HostMap, error) {
|
||||
return nil, errUnsupportedOperation
|
||||
}
|
||||
|
||||
func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
log.Debug("Checking file provider for %s : %s", queryType, domain)
|
||||
|
||||
@ -90,10 +94,6 @@ func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
func (f *FileHosts) Set(domain string, host *Host) error {
|
||||
return errUnsupportedOperation
|
||||
}
|
||||
|
||||
var (
|
||||
hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$")
|
||||
)
|
||||
|
@ -23,6 +23,28 @@ func NewRedisProvider(rc *redis.Client, key string) Provider {
|
||||
return rh
|
||||
}
|
||||
|
||||
func (r *RedisHosts) List() (HostMap, error) {
|
||||
res, err := r.redis.HGetAll(r.key).Result()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hosts := make(HostMap)
|
||||
|
||||
for k, v := range res {
|
||||
var domainRecords []Host
|
||||
|
||||
if err = json.Unmarshal([]byte(v), &domainRecords); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
hosts[k] = domainRecords
|
||||
}
|
||||
|
||||
return nil, errUnsupportedOperation
|
||||
}
|
||||
|
||||
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
|
||||
log.Debug("Checking redis provider for %s", domain)
|
||||
|
||||
|
12
main.go
12
main.go
@ -108,9 +108,15 @@ func main() {
|
||||
providers = append(providers, hosts.NewRedisProvider(rc, viper.GetString("hosts.redis.key")))
|
||||
}
|
||||
|
||||
h := hosts.NewHosts(providers)
|
||||
|
||||
a := api.New()
|
||||
|
||||
hosts.EnableAPI(h, a.Router())
|
||||
|
||||
if viper.GetBool("api.enabled") {
|
||||
go func() {
|
||||
err := api.Start()
|
||||
err := a.Start()
|
||||
|
||||
if err != nil {
|
||||
log.WithError(err).Fatalln("Unable to bind API")
|
||||
@ -118,9 +124,9 @@ func main() {
|
||||
}()
|
||||
}
|
||||
|
||||
h := hosts.NewHosts(providers)
|
||||
handler := NewHandler(r, resolverCache, negCache, h)
|
||||
|
||||
server.Run(NewHandler(r, resolverCache, negCache, h))
|
||||
server.Run(handler)
|
||||
|
||||
log.Infof("joker dns %s (%s)", Version, runtime.Version())
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user