package hosts

import (
	"encoding/json"
	"strings"

	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"
	bolt "go.etcd.io/bbolt"
)

const (
	// recordBucket is the bbolt bucket holding one JSON-encoded []Host per domain key.
	recordBucket = "records"
)

// BoltHosts is a Provider backed by a bbolt database file.
//
// Provider is embedded but left nil by NewBoltProvider, so *BoltHosts
// satisfies the interface even for methods not defined on *BoltHosts;
// calling such a promoted method panics on the nil embedded value.
type BoltHosts struct {
	Provider
	db *bolt.DB
}

// NewBoltProvider opens (or creates) the bbolt database at file and makes
// sure the records bucket exists. It terminates the process via
// log.Fatalln if either step fails.
func NewBoltProvider(file string) Provider {
	db, err := bolt.Open(file, 0600, &bolt.Options{})
	if err != nil {
		log.WithError(err).Fatalln("Unable to open database")
	}

	err = db.Update(func(tx *bolt.Tx) error {
		_, err := tx.CreateBucketIfNotExists([]byte(recordBucket))
		return err
	})
	if err != nil {
		// Previously ignored: without the bucket every later call would
		// dereference a nil *bolt.Bucket and panic.
		log.WithError(err).Fatalln("Unable to create records bucket")
	}

	return &BoltHosts{db: db}
}

// List returns every stored domain mapped to its decoded records.
// Entries whose value is not valid JSON for []Host are skipped (and
// logged) rather than failing the whole listing.
func (b *BoltHosts) List() (HostMap, error) {
	hosts := make(HostMap)
	err := b.db.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(recordBucket))
		c := bucket.Cursor()
		for k, v := c.First(); k != nil; k, v = c.Next() {
			var domainRecords []Host
			if err := json.Unmarshal(v, &domainRecords); err != nil {
				log.WithError(err).WithField("domain", string(k)).Warn("Skipping malformed bolt record")
				continue
			}
			// string(k) copies the key, so it stays valid after the transaction.
			hosts[string(k)] = domainRecords
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return hosts, nil
}

// Get returns the first stored record of type queryType for domain
// (matched case-insensitively). If no key exists for domain, the key
// "*."+domain is tried instead. It returns errRecordNotFound when neither
// key exists or the stored records contain no entry of queryType.
//
// NOTE(review): the wildcard fallback looks up "*." + the queried name
// itself, not its parent (a query for "a.example.com" checks
// "*.a.example.com", not "*.example.com"), and it is only consulted when
// the exact key is absent — confirm this is the intended semantics.
func (b *BoltHosts) Get(queryType uint16, domain string) (*Host, error) {
	log.WithFields(log.Fields{
		"queryType": dns.TypeToString[queryType],
		"question":  domain,
	}).Debug("Checking bolt provider")

	domain = strings.ToLower(domain)

	var records []Host
	err := b.db.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(recordBucket))
		v := bucket.Get([]byte(domain))
		if len(v) == 0 {
			v = bucket.Get([]byte("*." + domain))
		}
		if len(v) == 0 {
			return errRecordNotFound
		}
		// Bytes returned by Bucket.Get are only valid for the life of the
		// transaction, so they must be decoded before View returns.
		return json.Unmarshal(v, &records)
	})
	if err != nil {
		return nil, err
	}

	for i := range records {
		if records[i].Type == queryType {
			return &records[i], nil
		}
	}
	return nil, errRecordNotFound
}

// Set appends host to the records stored under domain (lower-cased, to
// match the lookup performed by Get), creating the entry if it does not
// exist. Existing records, including ones of the same type, are kept;
// Get returns the first record whose type matches.
func (b *BoltHosts) Set(domain string, host *Host) error {
	key := []byte(strings.ToLower(domain))
	return b.db.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(recordBucket))

		// Decode into a fresh slice: decoding into a slice that already
		// holds host would make encoding/json write the first stored
		// record through the caller's pointer, losing the new record.
		var hosts []*Host
		if existing := bucket.Get(key); existing != nil {
			if err := json.Unmarshal(existing, &hosts); err != nil {
				return err
			}
		}
		hosts = append(hosts, host)

		hostBytes, err := json.Marshal(hosts)
		if err != nil {
			return err
		}
		return bucket.Put(key, hostBytes)
	})
}