// yarascanner/main.go
//
// (source listing: 252 lines, 4.3 KiB, Go)

package main
import (
"encoding/json"
"github.com/go-chi/chi"
"github.com/hillu/go-yara/v4"
"github.com/package-url/packageurl-go"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"io"
"net/http"
"os"
"os/signal"
"runtime"
"strings"
"sync"
"syscall"
"time"
)
type (
	// callbackFunc is invoked by a worker with the rules that matched a
	// scanned payload.
	callbackFunc func(rules yara.MatchRules)

	// Job is one unit of scanning work handed to the workers. Data is the
	// content to scan (the worker closes it when finished) and Callback is
	// handed the match results.
	Job struct {
		Data     io.ReadCloser
		Callback callbackFunc
	}
)

// Package-level state shared between main, the HTTP handler and the workers.
var (
	// client downloads remote rule files; it is assigned in main.
	client *http.Client
	// jobChan carries jobs from HTTP handlers to workers. It is unbuffered,
	// so submitting a job blocks until a worker receives it.
	jobChan = make(chan *Job)
)
// main reads its settings through viper (AutomaticEnv, i.e. the THREADS,
// RULES and BIND environment variables, with the defaults below), compiles
// the configured YARA rule packages, starts the scanning workers and serves
// the HTTP API until SIGINT or SIGTERM is received.
func main() {
	viper.SetDefault("threads", runtime.NumCPU())
	viper.SetDefault("rules", "pkg:github/Neo23x0/signature-base#yara")
	viper.SetDefault("bind", ":8080")
	viper.AutomaticEnv()

	client = &http.Client{
		Timeout: 15 * time.Second,
	}

	c, err := yara.NewCompiler()
	if err != nil {
		log.WithError(err).Fatal("Unable to setup new compiler")
	}
	// Declare the external variables that loaded rules may reference. They are
	// left empty because scans run on in-memory data rather than files on disk.
	// Externals must be defined before rules that use them are compiled.
	for _, name := range []string{"filename", "filepath", "extension", "filetype"} {
		if err := c.DefineVariable(name, ""); err != nil {
			log.WithError(err).WithField("variable", name).Fatal("Unable to define external variable")
		}
	}

	log.Info("Loading rules")
	loadRules(c)
	rules, err := c.GetRules()
	if err != nil {
		log.WithError(err).Fatal("Unable to compile rules")
	}

	// With no workers every request would block forever on jobChan, so a
	// non-positive setting is clamped to a single worker.
	threads := viper.GetInt("threads")
	if threads < 1 {
		threads = 1
	}
	log.WithField("workers", threads).Info("Starting workers")
	for i := 0; i < threads; i++ {
		go worker(rules)
	}

	r := chi.NewRouter()
	r.Post("/scan", scanHandler)

	bind := viper.GetString("bind")
	log.WithField("bind", bind).Info("Binding to address")
	srv := &http.Server{
		Addr:    bind,
		Handler: r,
		// Bound how long a client may take to send its request headers
		// (slow-header protection). Bodies are not time-limited because
		// uploads may legitimately be large.
		ReadHeaderTimeout: 10 * time.Second,
	}
	go func() {
		// ListenAndServe always returns a non-nil error. Failing to bind must
		// not leave the process running without a listener.
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.WithError(err).Fatal("HTTP server failed")
		}
	}()

	// signal.Notify does not block when delivering, so the channel needs a
	// buffer or a signal can be dropped. SIGKILL cannot be caught, so it is
	// not registered.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT)
	<-ch
}
// HTTP handler for scanning files
func scanHandler(w http.ResponseWriter, r *http.Request) {
contentType := r.Header.Get("Content-Type")
if idx := strings.Index(contentType, ";"); idx != -1 {
contentType = contentType[0:idx]
}
switch contentType {
case "multipart/form-data":
if r.MultipartForm == nil {
r.ParseMultipartForm(32 << 20)
}
wg := &sync.WaitGroup{}
results := make([]yara.MatchRules, 0)
jobCallback := func(m yara.MatchRules) {
results = append(results, m)
wg.Done()
}
log.WithField("contentType", contentType).Debug("Adding files from multipart form as jobs")
fileCount := 0
for _, files := range r.MultipartForm.File {
// Append files
for _, file := range files {
wg.Add(1)
fileCount++
f, err := file.Open()
if err != nil {
continue
}
job := &Job{
Data: f,
Callback: jobCallback,
}
jobChan <- job
}
}
log.WithField("count", fileCount).Debug("Waiting for jobs to finish")
wg.Wait()
json.NewEncoder(w).Encode(results)
default:
job := &Job{
Data: r.Body,
Callback: func(m yara.MatchRules) {
json.NewEncoder(w).Encode(m)
},
}
log.WithField("contentType", contentType).Debug("Scanning contents of body")
jobChan <- job
}
}
// loadRules adds to c every rule package listed in the comma-separated
// "rules" setting. Each entry must be a package URL (purl):
//   - git, bitbucket, github and gitlab types are loaded with loadRulesFromGit;
//   - http and https types are loaded with loadRulesFromHttp.
//
// Surrounding whitespace and empty entries are ignored. An invalid,
// unsupported or unloadable entry terminates the process, as does a setting
// that lists no packages at all.
func loadRules(c *yara.Compiler) {
	loaded := 0
	for _, p := range strings.Split(viper.GetString("rules"), ",") {
		// Tolerate "a, b" style lists and trailing commas.
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		log.WithField("package", p).Info("Loading rules from package")
		instance, err := packageurl.FromString(p)
		if err != nil {
			log.WithFields(log.Fields{
				"error":   err,
				"package": p,
			}).Fatalln("Invalid rule URL")
		}
		switch instance.Type {
		case "git", "bitbucket", "github", "gitlab":
			err = loadRulesFromGit(instance, c)
		case "http", "https":
			err = loadRulesFromHttp(instance, c)
		default:
			// Fail loudly rather than running without the requested rules.
			log.WithFields(log.Fields{
				"type":    instance.Type,
				"package": p,
			}).Fatalln("Unsupported rule package type")
		}
		if err != nil {
			log.WithFields(log.Fields{
				"error":   err,
				"package": p,
			}).Fatalln("Unable to load rules")
		}
		loaded++
	}
	if loaded == 0 {
		log.WithField("rules", viper.GetString("rules")).Fatalln("No rule packages configured")
	}
}
// loadRulesFromHttp downloads a single YARA source file from the URL held in
// pkg.Name and adds it to c with an empty namespace. A response outside the
// 2xx range is reported as an error instead of being compiled.
//
// NOTE(review): pkg.Name is used verbatim as the URL; confirm how http/https
// purls are expected to encode the full address.
// TODO: Support archive files alongside standard yar files
func loadRulesFromHttp(pkg packageurl.PackageURL, c *yara.Compiler) error {
	res, err := client.Get(pkg.Name)
	if err != nil {
		return err
	}
	defer res.Body.Close()
	// Without this check an HTML error page would be handed to the YARA
	// compiler and surface as a confusing syntax error.
	if res.StatusCode < 200 || res.StatusCode > 299 {
		return &httpStatusError{url: pkg.Name, status: res.Status}
	}
	b, err := io.ReadAll(res.Body)
	if err != nil {
		return err
	}
	return c.AddString(string(b), "")
}

// httpStatusError reports that downloading a rule file returned a non-2xx
// HTTP status.
type httpStatusError struct {
	url    string
	status string
}

func (e *httpStatusError) Error() string {
	return "unexpected HTTP status " + e.status + " fetching rules from " + e.url
}
// worker creates its own scanner for rules and then processes jobs from
// jobChan one at a time until the channel is closed. Each worker owns its
// scanner, so a scanner is never used from more than one goroutine. Failure
// to create the scanner terminates the process, matching how main treats
// rule setup failures.
func worker(rules *yara.Rules) {
	s, err := yara.NewScanner(rules)
	if err != nil {
		log.WithError(err).Fatal("Unable to create scanner")
	}
	for job := range jobChan {
		log.Debug("Processing job")
		processJob(s, job)
	}
}
// processJob reads all of job.Data into memory and scans it with s. When at
// least one rule matched, it passes the matches to job.Callback.
//
// Read and scan failures are logged and the callback is not invoked.
// job.Data is always closed before processJob returns, after the callback
// (if any) has run.
func processJob(s *yara.Scanner, job *Job) {
	defer job.Data.Close()

	b, err := io.ReadAll(job.Data)
	if err != nil {
		log.WithError(err).Warn("Unable to read job data")
		return
	}

	var m yara.MatchRules
	if err := s.SetCallback(&m).ScanMem(b); err != nil {
		log.WithError(err).Warn("Unable to scan job data")
		return
	}

	// Only report back when something matched.
	if len(m) == 0 {
		return
	}
	job.Callback(m)
}