Newer
Older
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"math"
"net/http"
"os"
"regexp"
"strings"
"time"
"github.com/disintegration/imaging"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-plugin"
"github.com/turt2live/matrix-media-repo/plugins/plugin_common"
"github.com/turt2live/matrix-media-repo/plugins/plugin_interfaces"
"github.com/turt2live/matrix-media-repo/util"
"github.com/turt2live/matrix-media-repo/util/idec"
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
)
// AntispamOCR is the antispam plugin implementation served by main. Its
// settings are populated by HandleConfig and read by CheckForSpam.
type AntispamOCR struct {
	// logger receives the plugin's log output.
	logger hclog.Logger

	// config is the raw configuration map most recently given to HandleConfig.
	config map[string]interface{}

	// Filters deciding whether an upload is inspected at all.
	userIdRegex  *regexp.Regexp // uploader user IDs that are subject to checking
	contentTypes []string       // content types that are subject to checking
	minSize      int            // smallest decoded payload size, in bytes, that is checked
	maxSize      int            // largest decoded payload size, in bytes, that is checked

	// Settings for the OCR pass itself.
	ocrServer     string     // base URL of the OCR server; "/base64" is appended to it
	topPercentage float64    // when strictly between 0 and 1, only this fraction of the image height (from the top) is scanned
	keywordGroups [][]string // the recognised text must contain at least one keyword from every group
}
// HandleConfig validates the supplied plugin configuration and, if it is
// well-formed, stores the parsed settings on the receiver.
//
// Keys read from config:
//   - "ocrServer" (string): base URL of the OCR server.
//   - "minSizeBytes", "maxSizeBytes" (number): size bounds, truncated to int.
//   - "percentageOfHeight" (number): fraction of the image height to scan.
//   - "types" (list): content types to check; each entry is formatted with %v.
//   - "keywordGroups" (list of lists): keyword groups; each entry is formatted with %v.
//   - "userIds" (string): regular expression matched against uploader user IDs.
//
// An error is returned when a key is missing, has an unexpected type, or when
// "userIds" is not a valid regular expression. In that case the receiver is
// left unchanged rather than half-configured.
func (a *AntispamOCR) HandleConfig(config map[string]interface{}) error {
	// Typed accessors: a missing key yields a nil interface, which fails the
	// assertion the same way a wrongly-typed value does.
	str := func(key string) (string, error) {
		v, ok := config[key].(string)
		if !ok {
			return "", fmt.Errorf("config key %q: expected a string, got %T", key, config[key])
		}
		return v, nil
	}
	num := func(key string) (float64, error) {
		v, ok := config[key].(float64)
		if !ok {
			return 0, fmt.Errorf("config key %q: expected a number, got %T", key, config[key])
		}
		return v, nil
	}
	list := func(key string) ([]interface{}, error) {
		v, ok := config[key].([]interface{})
		if !ok {
			return nil, fmt.Errorf("config key %q: expected a list, got %T", key, config[key])
		}
		return v, nil
	}

	ocrServer, err := str("ocrServer")
	if err != nil {
		return err
	}
	minSize, err := num("minSizeBytes")
	if err != nil {
		return err
	}
	maxSize, err := num("maxSizeBytes")
	if err != nil {
		return err
	}
	topPercentage, err := num("percentageOfHeight")
	if err != nil {
		return err
	}

	rawTypes, err := list("types")
	if err != nil {
		return err
	}
	ctypes := make([]string, 0, len(rawTypes))
	for _, t := range rawTypes {
		ctypes = append(ctypes, fmt.Sprintf("%v", t))
	}

	rawGroups, err := list("keywordGroups")
	if err != nil {
		return err
	}
	groups := make([][]string, 0, len(rawGroups))
	for i, g := range rawGroups {
		rawKeywords, ok := g.([]interface{})
		if !ok {
			return fmt.Errorf("config key %q: entry %d: expected a list, got %T", "keywordGroups", i, g)
		}
		keywords := make([]string, 0, len(rawKeywords))
		for _, kw := range rawKeywords {
			keywords = append(keywords, fmt.Sprintf("%v", kw))
		}
		groups = append(groups, keywords)
	}

	userIds, err := str("userIds")
	if err != nil {
		return err
	}
	userIdRegex, err := regexp.Compile(userIds)
	if err != nil {
		return err
	}

	// Everything validated: commit in one go.
	a.config = config
	a.ocrServer = ocrServer
	a.minSize = int(minSize)
	a.maxSize = int(maxSize)
	a.topPercentage = topPercentage
	a.contentTypes = ctypes
	a.keywordGroups = groups
	a.userIdRegex = userIdRegex
	return nil
}
// CheckForSpam reports whether the uploaded media looks like spam, judged by
// running OCR over the image and looking for the configured keywords.
//
// b64 is the media payload encoded with standard base64. contentType and userId
// are matched against the configured content types and user ID regex. filename,
// origin and mediaId are accepted but not used.
//
// It returns (false, nil) when the upload is outside the configured size
// bounds, content types or user IDs, when the OCR server cannot be reached, or
// when some keyword group has no match in the recognised text. It returns
// (true, nil) when every keyword group has at least one keyword present. A
// non-nil error is returned for undecodable input, image encoding failures, a
// non-200 OCR response, or an OCR response that is not the expected JSON.
func (a *AntispamOCR) CheckForSpam(b64 string, filename string, contentType string, userId string, origin string, mediaId string) (bool, error) {
	b, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return false, err
	}

	// Filters that rule an upload out before any image work is done.
	if len(b) < a.minSize || len(b) > a.maxSize {
		return false, nil
	}
	if !util.ArrayContains(a.contentTypes, contentType) {
		return false, nil
	}
	if !a.userIdRegex.MatchString(userId) {
		return false, nil
	}

	img, err := idec.Decode(bytes.NewBuffer(b))
	if err != nil {
		return false, err
	}

	// For certain kinds of spam we don't need to consider the whole image, just
	// the configured top fraction of it.
	if a.topPercentage < 1.0 && a.topPercentage > 0 {
		newHeight := int(math.Round(float64(img.Bounds().Max.Y) * a.topPercentage))
		img = imaging.Fill(img, img.Bounds().Max.X, newHeight, imaging.Top, imaging.Linear)
	}

	// Steps:
	// 1. Crush the image to reasonable dimensions (helps with later transforms). Use Lanczos to soften lines on letters.
	// 2. Double the image size, using Lanczos again to do a second round of softening.
	// 3. Try to remove any background noise (usually introduced during upload and by resizing).
	// 4. Adjust contrast to make text more obvious on the background.
	// 5. Convert to grayscale, thus avoiding any colour issues with the OCR.
	img = imaging.Fit(img, 512, 512, imaging.Lanczos)
	img = imaging.Fill(img, img.Bounds().Max.X*2, img.Bounds().Max.Y*2, imaging.Top, imaging.Lanczos)
	img = imaging.Sharpen(img, 50)
	img = imaging.AdjustContrast(img, 2)
	img = imaging.Grayscale(img)

	imgData := &bytes.Buffer{}
	err = imaging.Encode(imgData, img, imaging.PNG) // dev note: deliberately png (don't use u.Encode())
	if err != nil {
		return false, err
	}

	b64 = base64.StdEncoding.EncodeToString(imgData.Bytes())
	bodyBytes, err := json.Marshal(map[string]interface{}{
		"base64": b64,
		"trim":   "\n",
	})
	if err != nil {
		return false, err
	}

	ocrUrl := util.MakeUrl(a.ocrServer, "/base64")
	req, err := http.NewRequest("POST", ocrUrl, bytes.NewBuffer(bodyBytes))
	if err != nil {
		return false, err
	}
	req.Header.Set("User-Agent", "matrix-media-repo")

	client := &http.Client{
		Timeout: 20 * time.Second,
	}
	res, err := client.Do(req)
	if err != nil {
		// An unreachable OCR server is deliberately treated as "not spam"
		// rather than as a failure of the check.
		a.logger.Error("non-fatal error checking spam", "error", err)
		return false, nil
	}
	defer res.Body.Close()

	// Check the status before parsing so that an error page is reported as a
	// bad status rather than as a JSON syntax error.
	if res.StatusCode != http.StatusOK {
		return false, fmt.Errorf("unexpected status code: %d", res.StatusCode)
	}

	var contents bytes.Buffer
	if _, err = contents.ReadFrom(res.Body); err != nil {
		return false, err
	}

	var resp map[string]interface{}
	if err = json.Unmarshal(contents.Bytes(), &resp); err != nil {
		return false, err
	}

	result, ok := resp["result"].(string)
	if !ok {
		return false, fmt.Errorf("unexpected ocr result type: %T", resp["result"])
	}
	ocr := strings.ToLower(result)

	// Every group must contribute at least one keyword found in the text.
	// NOTE(review): with no keyword groups configured this loop is skipped and
	// every image that reaches this point is reported as spam - confirm that is
	// intended. Keywords are compared as configured against lower-cased text.
	for _, kwg := range a.keywordGroups {
		hasKeyword := false
		for _, kw := range kwg {
			if strings.Contains(ocr, kw) {
				hasKeyword = true
				break
			}
		}
		if !hasKeyword {
			return false, nil
		}
	}

	a.logger.Warn("spam detected")
	return true, nil
}
// main is the plugin entry point: it creates a trace-level JSON logger writing
// to stderr and serves the AntispamOCR implementation under the "antispam"
// plugin name.
func main() {
	loggerOpts := &hclog.LoggerOptions{
		Level:      hclog.Trace,
		Output:     os.Stderr,
		JSONFormat: true,
	}
	impl := &AntispamOCR{logger: hclog.New(loggerOpts)}

	served := map[string]plugin.Plugin{
		"antispam": &plugin_interfaces.AntispamPlugin{Impl: impl},
	}
	serveConfig := &plugin.ServeConfig{
		HandshakeConfig: plugin_common.HandshakeConfig,
		Plugins:         served,
	}
	plugin.Serve(serveConfig)
}