// main.go: OCR-based antispam plugin for matrix-media-repo.
package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"regexp"
	"strings"
	"time"

	"github.com/disintegration/imaging"
	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-plugin"
	"github.com/turt2live/matrix-media-repo/plugins/plugin_common"
	"github.com/turt2live/matrix-media-repo/plugins/plugin_interfaces"
	"github.com/turt2live/matrix-media-repo/util"
	"github.com/turt2live/matrix-media-repo/util/idec"
)

// AntispamOCR is the antispam plugin implementation. It sends a pre-processed
// copy of an uploaded image to an OCR server and flags the upload as spam when
// the recognised text matches the configured keyword groups.
type AntispamOCR struct {
	// logger receives the plugin's diagnostic output.
	logger hclog.Logger

	// config is the raw configuration map as handed to HandleConfig.
	config map[string]interface{}

	// ocrServer is the base URL of the OCR service; "/base64" is appended to it.
	ocrServer string

	// topPercentage is the fraction of the image height, measured from the
	// top, that is kept before OCR. Only applied when strictly between 0 and 1.
	topPercentage float64

	// minSize is the smallest decoded upload size, in bytes, that is checked.
	minSize int

	// maxSize is the largest decoded upload size, in bytes, that is checked.
	maxSize int

	// contentTypes lists the content types that are checked.
	contentTypes []string

	// userIdRegex selects the user IDs whose uploads are checked.
	userIdRegex *regexp.Regexp

	// keywordGroups holds the spam keywords. An image is flagged when every
	// group has at least one keyword contained in the lowercased OCR text.
	keywordGroups [][]string
}

// HandleConfig validates the supplied plugin configuration and stores the
// parsed settings on the receiver.
//
// Required keys and the dynamic types they must hold:
//   - "ocrServer": string
//   - "minSizeBytes", "maxSizeBytes": float64 (converted to int)
//   - "percentageOfHeight": float64
//   - "types": []interface{}; every element is stringified with %v
//   - "keywordGroups": []interface{} whose elements are []interface{};
//     every keyword is stringified with %v
//   - "userIds": string holding a regular expression
//
// An error is returned when a key is missing, holds a value of an unexpected
// type, or the "userIds" expression does not compile. The receiver is only
// modified once the whole configuration has been validated, so a rejected
// configuration never leaves the plugin half-configured.
func (a *AntispamOCR) HandleConfig(config map[string]interface{}) error {
	// Small typed accessors: a failed assertion becomes an error rather than a
	// panic that would take down the plugin process.
	getString := func(key string) (string, error) {
		v, ok := config[key].(string)
		if !ok {
			return "", fmt.Errorf("config key %q: expected string, got %T", key, config[key])
		}
		return v, nil
	}
	getFloat := func(key string) (float64, error) {
		v, ok := config[key].(float64)
		if !ok {
			return 0, fmt.Errorf("config key %q: expected number, got %T", key, config[key])
		}
		return v, nil
	}
	getList := func(key string) ([]interface{}, error) {
		v, ok := config[key].([]interface{})
		if !ok {
			return nil, fmt.Errorf("config key %q: expected list, got %T", key, config[key])
		}
		return v, nil
	}

	ocrServer, err := getString("ocrServer")
	if err != nil {
		return err
	}
	minSize, err := getFloat("minSizeBytes")
	if err != nil {
		return err
	}
	maxSize, err := getFloat("maxSizeBytes")
	if err != nil {
		return err
	}
	topPercentage, err := getFloat("percentageOfHeight")
	if err != nil {
		return err
	}

	rawTypes, err := getList("types")
	if err != nil {
		return err
	}
	ctypes := make([]string, 0, len(rawTypes))
	for _, t := range rawTypes {
		ctypes = append(ctypes, fmt.Sprintf("%v", t))
	}

	rawGroups, err := getList("keywordGroups")
	if err != nil {
		return err
	}
	groups := make([][]string, 0, len(rawGroups))
	for i, rawGroup := range rawGroups {
		members, ok := rawGroup.([]interface{})
		if !ok {
			return fmt.Errorf("config key %q: group %d: expected list, got %T", "keywordGroups", i, rawGroup)
		}
		group := make([]string, 0, len(members))
		for _, kw := range members {
			group = append(group, fmt.Sprintf("%v", kw))
		}
		groups = append(groups, group)
	}

	userIds, err := getString("userIds")
	if err != nil {
		return err
	}
	userIdRegex, err := regexp.Compile(userIds)
	if err != nil {
		return fmt.Errorf("config key %q: %w", "userIds", err)
	}

	// Everything validated: commit the new configuration.
	a.config = config
	a.ocrServer = ocrServer
	a.minSize = int(minSize)
	a.maxSize = int(maxSize)
	a.topPercentage = topPercentage
	a.contentTypes = ctypes
	a.keywordGroups = groups
	a.userIdRegex = userIdRegex

	return nil
}

// CheckForSpam reports whether the uploaded media looks like spam according to
// the text an OCR server recognises in it.
//
// b64 is the upload encoded with standard base64. contentType and userId are
// used to decide whether the upload is checked at all; filename, origin and
// mediaId are not used.
//
// The upload is only checked when its decoded size is within
// [minSize, maxSize], its content type is one of the configured types, its
// user ID matches the configured expression, and at least one keyword group
// is configured. It is reported as spam when every keyword group has at least
// one keyword contained in the lowercased OCR text (keywords are compared
// as-is, so they need to be lowercase to match).
//
// A failure to reach the OCR server is logged and treated as "not spam".
// Errors decoding the input, encoding the processed image, or interpreting
// the OCR response (including a non-200 status) are returned.
func (a *AntispamOCR) CheckForSpam(b64 string, filename string, contentType string, userId string, origin string, mediaId string) (bool, error) {
	b, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return false, err
	}

	if len(b) < a.minSize || len(b) > a.maxSize {
		return false, nil
	}
	if !util.ArrayContains(a.contentTypes, contentType) {
		return false, nil
	}
	if !a.userIdRegex.MatchString(userId) {
		return false, nil
	}
	// With no keyword groups the "every group matches" test below would be
	// vacuously true and flag every image; there is nothing to look for.
	if len(a.keywordGroups) == 0 {
		return false, nil
	}

	img, err := idec.Decode(bytes.NewBuffer(b))
	if err != nil {
		return false, err
	}

	// For certain kinds of spam we don't really need to consider the whole image but just the
	// configured top fraction of it.
	if a.topPercentage < 1.0 && a.topPercentage > 0 {
		newHeight := int(math.Round(float64(img.Bounds().Max.Y) * a.topPercentage))
		img = imaging.Fill(img, img.Bounds().Max.X, newHeight, imaging.Top, imaging.Linear)
	}

	// Steps:
	// 1. Crush the image to reasonable dimensions (helps with later transforms). Use Lanczos to soften lines on letters.
	// 2. Double the image size, using Lanczos again to do a second round of softening.
	// 3. Try to remove any background noise (usually introduced during upload and by resizing).
	// 4. Adjust contrast to make text more obvious on the background.
	// 5. Convert to grayscale, thus avoiding any colour issues with the OCR.
	img = imaging.Fit(img, 512, 512, imaging.Lanczos)
	img = imaging.Fill(img, img.Bounds().Max.X*2, img.Bounds().Max.Y*2, imaging.Top, imaging.Lanczos)
	img = imaging.Sharpen(img, 50)
	img = imaging.AdjustContrast(img, 2)
	img = imaging.Grayscale(img)

	imgData := &bytes.Buffer{}
	err = imaging.Encode(imgData, img, imaging.PNG) // dev note: deliberately png (don't use u.Encode())
	if err != nil {
		return false, err
	}
	b64 = base64.StdEncoding.EncodeToString(imgData.Bytes())

	bodyBytes, err := json.Marshal(map[string]interface{}{
		"base64": b64,
		"trim":   "\n",
	})
	if err != nil {
		return false, err
	}

	ocrUrl := util.MakeUrl(a.ocrServer, "/base64")
	req, err := http.NewRequest("POST", ocrUrl, bytes.NewBuffer(bodyBytes))
	if err != nil {
		return false, err
	}
	req.Header.Set("User-Agent", "matrix-media-repo")
	client := &http.Client{
		Timeout: 20 * time.Second,
	}
	res, err := client.Do(req)
	if err != nil {
		// An unreachable OCR server must not block uploads: log and let the media through.
		a.logger.Error("non-fatal error checking spam", "error", err)
		return false, nil
	}
	defer res.Body.Close()

	// Check the status before parsing: an error reply is not guaranteed to be JSON.
	if res.StatusCode != http.StatusOK {
		return false, fmt.Errorf("unexpected status code: %d", res.StatusCode)
	}

	contents, err := io.ReadAll(res.Body)
	if err != nil {
		return false, err
	}
	var resp map[string]interface{}
	err = json.Unmarshal(contents, &resp)
	if err != nil {
		return false, err
	}
	result, ok := resp["result"].(string)
	if !ok {
		return false, fmt.Errorf("unexpected OCR response: \"result\" is %T, not a string", resp["result"])
	}
	ocr := strings.ToLower(result)

	// Every group must contribute at least one keyword for the image to count as spam.
	for _, kwg := range a.keywordGroups {
		hasKeyword := false
		for _, kw := range kwg {
			if strings.Contains(ocr, kw) {
				hasKeyword = true
				break
			}
		}
		if !hasKeyword {
			return false, nil
		}
	}

	a.logger.Warn("spam detected")
	return true, nil
}

// main builds the plugin implementation with a JSON logger writing to stderr
// at trace level and serves it under the name "antispam" via go-plugin.
func main() {
	loggerOpts := &hclog.LoggerOptions{
		Level:      hclog.Trace,
		Output:     os.Stderr,
		JSONFormat: true,
	}
	impl := &AntispamOCR{logger: hclog.New(loggerOpts)}

	served := map[string]plugin.Plugin{
		"antispam": &plugin_interfaces.AntispamPlugin{Impl: impl},
	}
	serveConfig := &plugin.ServeConfig{
		HandshakeConfig: plugin_common.HandshakeConfig,
		Plugins:         served,
	}
	plugin.Serve(serveConfig)
}