Skip to content
Snippets Groups Projects
Commit 0b735656 authored by Travis Ralston's avatar Travis Ralston
Browse files

Stop parsing range requests manually

parent cd405659
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"math"
"mime"
"net/http"
"net/url"
......@@ -17,6 +16,7 @@ import (
"github.com/alioygur/is"
"github.com/gabriel-vasile/mimetype"
"github.com/getsentry/sentry-go"
"github.com/t2bot/gotd-contrib/http_range"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
......@@ -88,11 +88,19 @@ func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
beforeParseDownload:
log.Infof("Replying with result: %T %+v", res, res)
if downloadRes, isDownload := res.(*_responses.DownloadResponse); isDownload {
doRange, rangeStart, rangeEnd, rangeErrMsg := parseRange(r, downloadRes)
if doRange && rangeErrMsg != "" {
ranges, err := http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes)
if errors.Is(err, http_range.ErrInvalid) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest(rangeErrMsg)
doRange = false
res = _responses.BadRequest("invalid range header")
goto beforeParseDownload // reprocess `res`
} else if errors.Is(err, http_range.ErrNoOverlap) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("out of range")
goto beforeParseDownload // reprocess `res`
}
if len(ranges) > 1 {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("only 1 range is supported")
goto beforeParseDownload // reprocess `res`
}
......@@ -100,17 +108,31 @@ beforeParseDownload:
expectedBytes = downloadRes.SizeBytes
// Don't rely on user-supplied values for content-type
br := readers.NewBufferReadsReader(downloadRes.Data)
if mimeType, err := mimetype.DetectReader(br); err != nil {
var rsc io.ReadSeekCloser
var mimeType *mimetype.MIME
if t, ok := downloadRes.Data.(io.ReadSeekCloser); ok {
rsc = t
mimeType, err = mimetype.DetectReader(rsc)
if _, err2 := rsc.Seek(0, io.SeekStart); err2 != nil {
rctx.Log.Error("Error seeking after detecting mimetype: ", err2)
sentry.CaptureException(err2)
res = _responses.InternalServerError("Unexpected Error")
goto beforeParseDownload // reprocess `res`
}
} else {
br := readers.NewBufferReadsReader(downloadRes.Data)
mimeType, err = mimetype.DetectReader(br)
ogReader := downloadRes.Data
downloadRes.Data = readers.NewCancelCloser(io.NopCloser(br.GetRewoundReader()), func() {
_ = ogReader.Close()
})
}
if err != nil {
rctx.Log.Warn("Non-fatal error sniffing mime type of download: ", err)
sentry.CaptureException(err)
} else if mimeType != nil {
contentType = mimeType.String()
}
ogReader := downloadRes.Data
downloadRes.Data = readers.NewCancelCloser(io.NopCloser(br.GetRewoundReader()), func() {
_ = ogReader.Close()
})
if contentType != downloadRes.ContentType {
rctx.Log.Debugf("Expected '%s' content type but ended up with '%s'", downloadRes.ContentType, contentType)
......@@ -158,11 +180,23 @@ beforeParseDownload:
headers.Set("Content-Disposition", disposition+"; filename*=utf-8''"+url.QueryEscape(fname))
}
if _, ok := stream.(io.ReadSeekCloser); ok && doRange {
headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", rangeStart, rangeEnd, downloadRes.SizeBytes))
proposedStatusCode = http.StatusPartialContent
}
stream = downloadRes.Data
if len(ranges) > 0 {
if rsc, ok := stream.(io.ReadSeekCloser); ok {
target := ranges[0] // we only use the first range (validated up above)
if _, err = rsc.Seek(target.Start, io.SeekStart); err != nil {
rctx.Log.Warn("Non-fatal error seeking for Range request: ", err)
sentry.CaptureException(err)
} else {
headers.Set("Content-Range", target.ContentRange(downloadRes.SizeBytes))
proposedStatusCode = http.StatusPartialContent
stream = readers.NewCancelCloser(io.NopCloser(io.LimitReader(rsc, target.Length)), func() {
_ = rsc.Close()
})
expectedBytes = target.Length
}
}
}
}
// Try to find a suitable error code, if one is needed
......@@ -257,59 +291,3 @@ func writeStatusCode(w http.ResponseWriter, r *http.Request, statusCode int) *ht
w.WriteHeader(statusCode)
return r.WithContext(context.WithValue(r.Context(), common.ContextStatusCode, statusCode))
}
func parseRange(r *http.Request, res *_responses.DownloadResponse) (bool, int64, int64, string) {
rangeHeader := r.Header.Get("Range")
if rangeHeader == "" || res.SizeBytes <= 0 {
return false, 0, 0, ""
}
if !strings.HasPrefix(rangeHeader, "bytes=") {
return true, 0, 0, "Improper range units"
}
if !strings.Contains(rangeHeader, ",") && !strings.HasPrefix(rangeHeader, "bytes=-") {
parts := strings.Split(rangeHeader[len("bytes="):], "-")
if len(parts) <= 2 {
rstart, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return true, 0, 0, "Improper start of range"
}
if rstart < 0 {
return true, 0, 0, "Improper start of range: negative"
}
rend := int64(-1)
if len(parts) > 1 && parts[1] != "" {
rend, err = strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return true, 0, 0, "Improper end of range"
}
if rend < 1 {
return true, 0, 0, "Improper end of range: negative"
}
if rend >= res.SizeBytes {
return true, 0, 0, "Improper end of range: out of bounds"
}
if rend <= rstart {
return true, 0, 0, "Start must be before end"
}
if (rstart + rend) >= res.SizeBytes {
return true, 0, 0, "Range too large"
}
} else {
add := int64(10485760) // 10mb default
conf := GetDomainConfig(r)
if conf.Downloads.DefaultRangeChunkSizeBytes > 0 {
add = conf.Downloads.DefaultRangeChunkSizeBytes
}
rend = int64(math.Min(float64(rstart+add), float64(res.SizeBytes-1)))
}
if (rend - rstart) <= 0 {
return true, 0, 0, "Range invalid at last pass"
}
return true, rstart, rend, ""
}
}
return false, 0, 0, ""
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment