// Source: godns/hosts/hosts_bolt.go
// (file-viewer metadata: 154 lines, 2.3 KiB, Go)

package hosts
import (
"encoding/json"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
"strings"
)
// recordBucket names the single bbolt bucket that stores every domain's
// JSON-encoded host records.
const recordBucket = "records"
// BoltHosts serves host records from a bbolt database. Each key in the
// records bucket is a domain name, and each value is a JSON array of Host
// records for that domain (written by Set, decoded by List and Get).
type BoltHosts struct {
// Provider is embedded, presumably so that *BoltHosts satisfies the Provider
// interface. NOTE(review): NewBoltProvider never assigns it, so any Provider
// method that *BoltHosts does not define itself would be invoked on a nil
// interface and panic — confirm every Provider method is implemented here.
Provider
// db is the open bbolt database that holds the records bucket.
db *bolt.DB
}
// NewBoltProvider opens (creating it if needed) the bbolt database at file
// with mode 0600 and ensures the records bucket exists.
//
// Both failing to open the database and failing to create the bucket are
// fatal: they are logged with logrus Fatalln, which terminates the process.
// Previously the bucket-creation error was silently discarded, which left a
// provider whose List/Get/Set would dereference a nil bucket later.
func NewBoltProvider(file string) Provider {
	db, err := bolt.Open(file, 0600, &bolt.Options{})
	if err != nil {
		log.WithError(err).Fatalln("Unable to open database")
	}

	err = db.Update(func(tx *bolt.Tx) error {
		_, err := tx.CreateBucketIfNotExists([]byte(recordBucket))
		return err
	})
	if err != nil {
		// Release the database (and its file lock) before exiting.
		_ = db.Close()
		log.WithError(err).Fatalln("Unable to create records bucket")
	}

	return &BoltHosts{
		db: db,
	}
}
// List returns every record in the database, keyed by domain.
//
// Values that cannot be decoded as a JSON array of Host are skipped, with a
// warning logged, so that one corrupt entry does not hide all the others.
// If the records bucket is missing, an empty map is returned.
// A non-nil error is returned only if the read transaction itself fails.
func (b *BoltHosts) List() (HostMap, error) {
	hosts := make(HostMap)
	err := b.db.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(recordBucket))
		if bucket == nil {
			return nil
		}
		return bucket.ForEach(func(k, v []byte) error {
			var domainRecords []Host
			if err := json.Unmarshal(v, &domainRecords); err != nil {
				log.WithError(err).WithField("domain", string(k)).Warn("Skipping undecodable bolt record")
				return nil
			}
			// string(k) copies the key, so it stays valid after the transaction ends.
			hosts[string(k)] = domainRecords
			return nil
		})
	})
	if err != nil {
		return nil, err
	}
	return hosts, nil
}
// Get returns the first record of type queryType stored for domain.
//
// The domain is lowercased before lookup. The exact key is tried first. Only
// if no non-empty value exists for it is the key "*." + domain tried. If the
// chosen entry has no record of queryType, errRecordNotFound is returned; the
// wildcard key is not consulted in that case. A stored value that fails to
// decode yields the JSON error.
//
// NOTE(review): "*." + domain looks up a wildcard registered for the queried
// name itself, not for its parent (e.g. a query for a.example.com checks
// "*.a.example.com", not "*.example.com") — confirm this is intended.
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
	log.WithFields(log.Fields{
		"queryType": dns.TypeToString[queryType],
		"question":  domain,
	}).Debug("Checking bolt provider")

	domain = strings.ToLower(domain)

	var records []Host
	err := b.db.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(recordBucket))
		if bucket == nil {
			return errRecordNotFound
		}
		v := bucket.Get([]byte(domain))
		if len(v) == 0 {
			v = bucket.Get([]byte("*." + domain))
		}
		if len(v) == 0 {
			return errRecordNotFound
		}
		// Bytes returned by Bucket.Get are only valid for the life of the
		// transaction, so they must be decoded before View returns.
		return json.Unmarshal(v, &records)
	})
	if err != nil {
		return nil, err
	}

	for i := range records {
		if records[i].Type == queryType {
			return &records[i], nil
		}
	}
	return nil, errRecordNotFound
}
func (b *BoltHosts) Set(domain string, host *Host) error {
err := b.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(recordBucket))
hosts := []*Host{host}
existing := b.Get([]byte(domain))
if existing != nil {
err := json.Unmarshal(existing, &hosts)
if err != nil {
return err
}
hosts = append(hosts, host)
}
hostBytes, err := json.Marshal(hosts)
if err != nil {
return err
}
err = b.Put([]byte(domain), hostBytes)
if err != nil {
return err
}
return nil
})
return err
}