154 lines
2.3 KiB
Go
154 lines
2.3 KiB
Go
package hosts
|
|
|
|
import (
|
|
"encoding/json"
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
bolt "go.etcd.io/bbolt"
|
|
"strings"
|
|
)
|
|
|
|
// recordBucket names the bolt bucket that holds the JSON-encoded host
// records, keyed by domain name.
const recordBucket = "records"
|
|
|
|
type BoltHosts struct {
|
|
Provider
|
|
|
|
db *bolt.DB
|
|
}
|
|
|
|
func NewBoltProvider(file string) Provider {
|
|
db, err := bolt.Open(file, 0600, &bolt.Options{})
|
|
|
|
if err != nil {
|
|
log.WithError(err).Fatalln("Unable to open database")
|
|
}
|
|
|
|
err = db.Update(func(tx *bolt.Tx) error {
|
|
_, err := tx.CreateBucketIfNotExists([]byte(recordBucket))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return &BoltHosts{
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
func (b *BoltHosts) List() (HostMap, error) {
|
|
hosts := make(HostMap)
|
|
|
|
err := b.db.View(func(tx *bolt.Tx) error {
|
|
b := tx.Bucket([]byte("records"))
|
|
|
|
c := b.Cursor()
|
|
|
|
for k, v := c.First(); k != nil; k, v = c.Next() {
|
|
var domainRecords []Host
|
|
|
|
if err := json.Unmarshal(v, &domainRecords); err != nil {
|
|
continue
|
|
}
|
|
|
|
hosts[string(k)] = domainRecords
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return hosts, nil
|
|
}
|
|
|
|
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
|
|
log.WithFields(log.Fields{
|
|
"queryType": dns.TypeToString[queryType],
|
|
"question": domain,
|
|
}).Debug("Checking bolt provider")
|
|
|
|
domain = strings.ToLower(domain)
|
|
|
|
var err error
|
|
|
|
var v []byte
|
|
|
|
err = b.db.View(func(tx *bolt.Tx) error {
|
|
b := tx.Bucket([]byte("records"))
|
|
|
|
v = b.Get([]byte(domain))
|
|
|
|
if string(v) != "" {
|
|
return nil
|
|
}
|
|
|
|
v = b.Get([]byte("*." + domain))
|
|
|
|
if string(v) != "" {
|
|
return nil
|
|
}
|
|
|
|
return errRecordNotFound
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var h []Host
|
|
|
|
if err = json.Unmarshal(v, &h); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, host := range h {
|
|
if host.Type == queryType {
|
|
return &host, nil
|
|
}
|
|
}
|
|
|
|
return nil, errRecordNotFound
|
|
}
|
|
|
|
func (b *BoltHosts) Set(domain string, host *Host) error {
|
|
err := b.db.Update(func(tx *bolt.Tx) error {
|
|
b := tx.Bucket([]byte(recordBucket))
|
|
|
|
hosts := []*Host{host}
|
|
|
|
existing := b.Get([]byte(domain))
|
|
|
|
if existing != nil {
|
|
err := json.Unmarshal(existing, &hosts)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hosts = append(hosts, host)
|
|
}
|
|
|
|
hostBytes, err := json.Marshal(hosts)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = b.Put([]byte(domain), hostBytes)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return err
|
|
}
|