API Implementation, patches
This commit is contained in:
		@ -5,9 +5,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	zeroDuration = time.Duration(0)
 | 
			
		||||
)
 | 
			
		||||
type HostMap map[string][]Host
 | 
			
		||||
 | 
			
		||||
type Host struct {
 | 
			
		||||
	Type uint16 `json:"type"`
 | 
			
		||||
@ -19,23 +17,58 @@ func (h *Host) TypeString() string {
 | 
			
		||||
	return dns.TypeToString[h.Type]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Hosts interface {
 | 
			
		||||
	Get(queryType uint16, domain string) (*Host, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ProviderList struct {
 | 
			
		||||
	providers       []Provider
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Provider is an interface specifying a host source
 | 
			
		||||
// Each source should support AT LEAST List and Get, but can support Writer as well
 | 
			
		||||
type Provider interface {
 | 
			
		||||
	List() (HostMap, error)
 | 
			
		||||
	Get(queryType uint16, domain string) (*Host, error)
 | 
			
		||||
	Set(domain string, host *Host) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewHosts(providers []Provider) Hosts {
 | 
			
		||||
// Writer is an interface to modify hosts.
 | 
			
		||||
// Examples of this include Redis, Bolt, MySQL, etc.
 | 
			
		||||
type Writer interface {
 | 
			
		||||
	Set(domain string, host *Host) error
 | 
			
		||||
	Delete(domain string) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ProviderWriter interface {
 | 
			
		||||
	Provider
 | 
			
		||||
	Writer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewHosts(providers []Provider) *ProviderList {
 | 
			
		||||
	return &ProviderList{providers}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List returns all results, merged into one HostMap
 | 
			
		||||
func (h *ProviderList) List() (HostMap, error) {
 | 
			
		||||
	hostMap := make(HostMap)
 | 
			
		||||
 | 
			
		||||
	for _, provider := range h.providers {
 | 
			
		||||
		hosts, err := provider.List()
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for k, v := range hosts {
 | 
			
		||||
			if existing, ok := hostMap[k]; ok {
 | 
			
		||||
				existing = append(existing, v...)
 | 
			
		||||
 | 
			
		||||
				hostMap[k] = existing
 | 
			
		||||
			} else {
 | 
			
		||||
				hostMap[k] = v
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return hostMap, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get Matches values to providers, loping each in order
 | 
			
		||||
func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	var host *Host
 | 
			
		||||
@ -50,4 +83,36 @@ func (h *ProviderList) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return host, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Set invokes each provider, setting the host on the first one to return a nil error
 | 
			
		||||
func (h *ProviderList) Set(domain string, host *Host) (err error) {
 | 
			
		||||
	for _, provider := range h.providers {
 | 
			
		||||
		if writer, ok := provider.(Writer); ok {
 | 
			
		||||
			err = writer.Set(domain, host)
 | 
			
		||||
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = errUnsupportedOperation
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Delete invokes each provider, removing the host on the first one to return a nil error
 | 
			
		||||
func (h *ProviderList) Delete(domain string) (err error) {
 | 
			
		||||
	for _, provider := range h.providers {
 | 
			
		||||
		if writer, ok := provider.(Writer); ok {
 | 
			
		||||
			err = writer.Delete(domain)
 | 
			
		||||
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = errUnsupportedOperation
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										140
									
								
								hosts/hosts_api.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										140
									
								
								hosts/hosts_api.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,140 @@
 | 
			
		||||
package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/go-chi/chi"
 | 
			
		||||
	"github.com/go-chi/render"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	defaultDuration = 600 * time.Second
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func EnableAPI(h Provider, r chi.Router) {
 | 
			
		||||
	a := &api{hosts: h}
 | 
			
		||||
 | 
			
		||||
	r.Route("/hosts", func(sub chi.Router) {
 | 
			
		||||
		sub.Get("/", a.hostsGet)
 | 
			
		||||
		sub.Post("/", a.hostsCreate)
 | 
			
		||||
		sub.Patch("/{domain}", a.hostsUpdate)
 | 
			
		||||
		sub.Delete("/{domain}", a.hostsDelete)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type api struct {
 | 
			
		||||
	hosts Provider
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// hostsGet handles GET requests on /hosts (list records)
 | 
			
		||||
func (a *api) hostsGet(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	hosts, err := a.hosts.List()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.JSON(w, r, hosts)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type requestBody struct {
 | 
			
		||||
	Domain string `json:"domain"`
 | 
			
		||||
	Type string `json:"type"`
 | 
			
		||||
	Values []string `json:"values"`
 | 
			
		||||
	TTL int `json:"ttl"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b requestBody) TTLDuration() time.Duration {
 | 
			
		||||
	if b.TTL > 0 {
 | 
			
		||||
		return time.Duration(b.TTL) * time.Second
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return defaultDuration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// hostsUpdate handles POST requests on /hosts
 | 
			
		||||
func (a *api) hostsCreate(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	var writer Writer
 | 
			
		||||
	var ok bool
 | 
			
		||||
 | 
			
		||||
	if writer, ok = a.hosts.(Writer); !ok {
 | 
			
		||||
		w.WriteHeader(http.StatusNotImplemented)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var request requestBody
 | 
			
		||||
 | 
			
		||||
	err := render.DefaultDecoder(r, &request)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var recordType uint16
 | 
			
		||||
 | 
			
		||||
	if recordType, ok = dns.StringToType[request.Type]; !ok {
 | 
			
		||||
		w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = writer.Set(request.Domain, &Host{
 | 
			
		||||
		Type: recordType,
 | 
			
		||||
		Values: request.Values,
 | 
			
		||||
		TTL: request.TTLDuration(),
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// hostsUpdate handles PATCH requests on /hosts/:domain
 | 
			
		||||
func (a *api) hostsUpdate(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	domain := chi.URLParam(r, "domain")
 | 
			
		||||
 | 
			
		||||
	if domain == "" {
 | 
			
		||||
		w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var writer Writer
 | 
			
		||||
	var ok bool
 | 
			
		||||
 | 
			
		||||
	if writer, ok = a.hosts.(Writer); !ok {
 | 
			
		||||
		w.WriteHeader(http.StatusNotImplemented)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: Read record from provider, update data from body, save
 | 
			
		||||
	err := writer.Set(domain, nil)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// hostsDelete handles DELETE requests on /hosts/:domain
 | 
			
		||||
func (a *api) hostsDelete(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	domain := chi.URLParam(r, "domain")
 | 
			
		||||
 | 
			
		||||
	if domain == "" {
 | 
			
		||||
		w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var writer Writer
 | 
			
		||||
	var ok bool
 | 
			
		||||
 | 
			
		||||
	if writer, ok = a.hosts.(Writer); !ok {
 | 
			
		||||
		w.WriteHeader(http.StatusNotImplemented)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := writer.Delete(domain)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -2,8 +2,6 @@ package hosts
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	bolt "go.etcd.io/bbolt"
 | 
			
		||||
	"strings"
 | 
			
		||||
@ -40,6 +38,34 @@ func NewBoltProvider(file string) Provider {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *BoltHosts) List() (HostMap, error) {
 | 
			
		||||
	hosts := make(HostMap)
 | 
			
		||||
 | 
			
		||||
	err := b.db.View(func(tx *bolt.Tx) error {
 | 
			
		||||
		b := tx.Bucket([]byte("records"))
 | 
			
		||||
 | 
			
		||||
		c := b.Cursor()
 | 
			
		||||
 | 
			
		||||
		for k, v := c.First(); k != nil; k, v = c.Next() {
 | 
			
		||||
			var domainRecords []Host
 | 
			
		||||
 | 
			
		||||
			if err := json.Unmarshal(v, &domainRecords); err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			hosts[string(k)] = domainRecords
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return hosts, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	log.Debug("Checking bolt provider for %s : %s", queryType, domain)
 | 
			
		||||
 | 
			
		||||
@ -47,25 +73,24 @@ func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	key := domain + "_" + dns.TypeToString[queryType]
 | 
			
		||||
	var v []byte
 | 
			
		||||
 | 
			
		||||
	err = b.db.View(func(tx *bolt.Tx) error {
 | 
			
		||||
		b := tx.Bucket([]byte("records"))
 | 
			
		||||
 | 
			
		||||
		v = b.Get([]byte(key))
 | 
			
		||||
		v = b.Get([]byte(domain))
 | 
			
		||||
 | 
			
		||||
		if string(v) == "" {
 | 
			
		||||
			return errors.New( "Record not found, key:  " + key)
 | 
			
		||||
		if string(v) != "" {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		v = b.Get([]byte("*." + key))
 | 
			
		||||
		v = b.Get([]byte("*." + domain))
 | 
			
		||||
 | 
			
		||||
		if string(v) == "" {
 | 
			
		||||
			return errors.New( "Record not found, key:  " + key)
 | 
			
		||||
		if string(v) != "" {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
		return errRecordNotFound
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
@ -57,6 +57,10 @@ func NewFileProvider(file string, ttl time.Duration) Provider {
 | 
			
		||||
	return fp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) List() (HostMap, error) {
 | 
			
		||||
	return nil, errUnsupportedOperation
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	log.Debug("Checking file provider for %s : %s", queryType, domain)
 | 
			
		||||
 | 
			
		||||
@ -90,10 +94,6 @@ func (f *FileHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	return nil, errRecordNotFound
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *FileHosts) Set(domain string, host *Host) error {
 | 
			
		||||
	return errUnsupportedOperation
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	hostRegexp = regexp.MustCompile("^(.*?)\\s+(.*)$")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,28 @@ func NewRedisProvider(rc *redis.Client, key string) Provider {
 | 
			
		||||
	return rh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) List() (HostMap, error) {
 | 
			
		||||
	res, err := r.redis.HGetAll(r.key).Result()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hosts := make(HostMap)
 | 
			
		||||
 | 
			
		||||
	for k, v := range res {
 | 
			
		||||
		var domainRecords []Host
 | 
			
		||||
 | 
			
		||||
		if err = json.Unmarshal([]byte(v), &domainRecords); err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		hosts[k] = domainRecords
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, errUnsupportedOperation
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisHosts) Get(queryType uint16, domain string) (*Host, error) {
 | 
			
		||||
	log.Debug("Checking redis provider for %s", domain)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user