Skip to content
Snippets Groups Projects
Commit cd405659 authored by Travis Ralston's avatar Travis Ralston
Browse files

Pipelines shouldn't be handling range requests

parent e4235079
No related branches found
No related tags found
No related merge requests found
......@@ -56,8 +56,6 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
media, stream, err := pipeline_download.Execute(rctx, server, mediaId, pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: downloadRemote,
StartByte: -1,
EndByte: -1,
BlockForReadUntil: blockFor,
})
if err != nil {
......
......@@ -109,8 +109,6 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
thumbnail, stream, err := pipeline_thumbnail.Execute(rctx, server, mediaId, pipeline_thumbnail.ThumbnailOpts{
DownloadOpts: pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: downloadRemote,
StartByte: -1,
EndByte: -1,
BlockForReadUntil: blockFor,
RecordOnly: false, // overridden
},
......
......@@ -79,8 +79,6 @@ func MediaInfo(r *http.Request, rctx rcontext.RequestContext, user _apimeta.User
record, stream, err := pipeline_download.Execute(rctx, server, mediaId, pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: downloadRemote,
StartByte: -1,
EndByte: -1,
BlockForReadUntil: 30 * time.Second,
RecordOnly: false,
})
......
......@@ -59,8 +59,6 @@ func LocalCopy(r *http.Request, rctx rcontext.RequestContext, user _apimeta.User
record, stream, err := pipeline_download.Execute(rctx, server, mediaId, pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: downloadRemote,
StartByte: -1,
EndByte: -1,
BlockForReadUntil: 30 * time.Second,
RecordOnly: false,
})
......
......@@ -42,8 +42,6 @@ func ExportEntityData(ctx rcontext.RequestContext, exportId string, entityId str
ctx.Log.Debugf("Downloading %s", mxc)
_, s, err := pipeline_download.Execute(ctx, media.Origin, media.MediaId, pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: false,
StartByte: -1,
EndByte: -1,
BlockForReadUntil: 1 * time.Minute,
RecordOnly: false,
})
......
......@@ -33,5 +33,6 @@ func Download(ctx rcontext.RequestContext, ds config.DatastoreConfig, dsFileName
return nil, errors.New("unknown datastore type - contact developer")
}
// TODO(TR-1): @@ Return seekable stream
return rsc, err
}
......@@ -4,29 +4,14 @@ import (
"errors"
"io"
"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/redislib"
)
type limitedCloser struct {
io.ReadCloser
lm io.Reader
rs io.ReadCloser
}
func (r limitedCloser) Read(p []byte) (int, error) {
return r.lm.Read(p)
}
func (r limitedCloser) Close() error {
return r.rs.Close()
}
func OpenStream(ctx rcontext.RequestContext, media *database.Locatable, startByte int64, endByte int64) (io.ReadCloser, error) {
reader, err := redislib.TryGetMedia(ctx, media.Sha256Hash, startByte, endByte)
func OpenStream(ctx rcontext.RequestContext, media *database.Locatable) (io.ReadCloser, error) {
reader, err := redislib.TryGetMedia(ctx, media.Sha256Hash)
if err != nil || reader != nil {
ctx.Log.Debugf("Got %s from cache", media.Sha256Hash)
return io.NopCloser(reader), err
......@@ -37,44 +22,5 @@ func OpenStream(ctx rcontext.RequestContext, media *database.Locatable, startByt
return nil, errors.New("unable to locate datastore for media")
}
rsc, err := datastores.Download(ctx, ds, media.Location)
if err != nil {
return nil, err
}
return CreateLimitedStream(ctx, rsc, startByte, endByte)
}
func CreateLimitedStream(ctx rcontext.RequestContext, r io.ReadCloser, startByte int64, endByte int64) (io.ReadCloser, error) {
if startByte >= 0 {
if rsc, ok := r.(io.ReadSeekCloser); ok {
if _, err := rsc.Seek(startByte, io.SeekStart); err != nil {
err2 := rsc.Close()
if err2 != nil {
ctx.Log.Errorf("Error while closing datastore stream due to other error: %s", err2)
sentry.CaptureException(err2)
}
return nil, err
}
} else {
_, err := io.CopyN(io.Discard, r, startByte)
if err != nil {
err2 := r.Close()
if err2 != nil {
ctx.Log.Errorf("Error while closing datastore stream due to other error: %s", err2)
sentry.CaptureException(err2)
}
return nil, err
}
}
}
var lm io.Reader = r
if endByte >= 1 {
if startByte < 0 {
startByte = 0
}
lm = io.LimitReader(r, endByte-startByte)
}
return &limitedCloser{lm: lm, rs: r}, nil
return datastores.Download(ctx, ds, media.Location)
}
......@@ -5,10 +5,9 @@ import (
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/pipelines/_steps/download"
)
func ReturnAppropriateThing(ctx rcontext.RequestContext, isDownload bool, recordOnly bool, width int, height int, startByte int64, endByte int64) (io.ReadCloser, error) {
func ReturnAppropriateThing(ctx rcontext.RequestContext, isDownload bool, recordOnly bool, width int, height int) (io.ReadCloser, error) {
flag := ctx.Config.Quarantine.ReplaceDownloads
if !isDownload {
flag = ctx.Config.Quarantine.ReplaceThumbnails
......@@ -16,14 +15,6 @@ func ReturnAppropriateThing(ctx rcontext.RequestContext, isDownload bool, record
if !flag || recordOnly {
return nil, common.ErrMediaQuarantined
} else {
if qr, err := MakeThumbnail(ctx, width, height); err != nil {
return nil, err
} else {
if r, err2 := download.CreateLimitedStream(ctx, qr, startByte, endByte); err2 != nil {
return nil, err2
} else {
return r, common.ErrMediaQuarantined
}
}
return MakeThumbnail(ctx, width, height)
}
}
......@@ -36,7 +36,7 @@ func Generate(ctx rcontext.RequestContext, mediaRecord *database.DbMedia, width
"origin": mediaRecord.Origin,
})
mediaStream, err := download.OpenStream(ctx, mediaRecord.Locatable, -1, -1)
mediaStream, err := download.OpenStream(ctx, mediaRecord.Locatable)
if err != nil {
ch <- generateResult{err: err}
return
......
......@@ -24,14 +24,12 @@ var recordSf = sfcache.NewSingleflightCache[*database.DbMedia]()
type DownloadOpts struct {
FetchRemoteIfNeeded bool
StartByte int64
EndByte int64
BlockForReadUntil time.Duration
RecordOnly bool
}
func (o DownloadOpts) String() string {
return fmt.Sprintf("f=%t,s=%d,e=%d,b=%s", o.FetchRemoteIfNeeded, o.StartByte, o.EndByte, o.BlockForReadUntil.String())
return fmt.Sprintf("f=%t,b=%s,r=%t", o.FetchRemoteIfNeeded, o.BlockForReadUntil.String(), o.RecordOnly)
}
func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) {
......@@ -63,13 +61,13 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
// Step 3: Do we already have the media? Serve it if yes.
if record != nil {
if record.Quarantined {
return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte)
return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512)
}
meta.FlagAccess(ctx, record.Sha256Hash, record.CreationTs)
if opts.RecordOnly {
return nil, nil
}
return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte)
return download.OpenStream(ctx, record.Locatable)
}
// Step 4: Media record unknown - download it (if possible)
......@@ -82,7 +80,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
}
recordSf.OverwriteCacheKey(sfKey, record)
if record.Quarantined {
return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte)
return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512)
}
meta.FlagAccess(ctx, record.Sha256Hash, record.CreationTs)
if opts.RecordOnly {
......@@ -90,12 +88,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
return nil, nil
}
// Step 5: Limit the stream if needed
r, err = download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte)
if err != nil {
return nil, err
}
// Step 5: Return the stream
return r, nil
})
if errors.Is(err, common.ErrMediaQuarantined) {
......
......@@ -40,10 +40,6 @@ func (o ThumbnailOpts) ImpliedDownloadOpts() pipeline_download.DownloadOpts {
FetchRemoteIfNeeded: o.FetchRemoteIfNeeded,
BlockForReadUntil: o.BlockForReadUntil,
RecordOnly: true,
// We remove the range parameters to ensure we get a useful download stream
StartByte: -1,
EndByte: -1,
}
}
......@@ -83,7 +79,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
if dr != nil {
dr.Close()
}
return quarantine.ReturnAppropriateThing(ctx, false, opts.RecordOnly, opts.Width, opts.Height, opts.StartByte, opts.EndByte)
return quarantine.ReturnAppropriateThing(ctx, false, opts.RecordOnly, opts.Width, opts.Height)
}
return nil, err
}
......@@ -104,14 +100,14 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
if opts.RecordOnly {
return nil, nil
}
return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte)
return download.OpenStream(ctx, record.Locatable)
}
// Step 6: Generate the thumbnail and return that
record, r, err := thumbnails.Generate(ctx, mediaRecord, opts.Width, opts.Height, opts.Method, opts.Animated)
if err != nil {
if !opts.RecordOnly && errors.Is(err, common.ErrMediaDimensionsTooSmall) {
d, err := download.OpenStream(ctx, mediaRecord.Locatable, opts.StartByte, opts.EndByte)
d, err := download.OpenStream(ctx, mediaRecord.Locatable)
if err != nil {
return nil, err
} else {
......@@ -126,8 +122,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
return nil, nil
}
// Step 7: Create a limited stream
return download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte)
// Step 7: Return stream
return r, nil
})
if errors.Is(err, common.ErrMediaQuarantined) || errors.Is(err, common.ErrMediaDimensionsTooSmall) {
if r != nil {
......
......@@ -63,7 +63,7 @@ func StoreMedia(ctx rcontext.RequestContext, hash string, content io.Reader, siz
return nil
}
func TryGetMedia(ctx rcontext.RequestContext, hash string, startByte int64, endByte int64) (io.Reader, error) {
func TryGetMedia(ctx rcontext.RequestContext, hash string) (io.Reader, error) {
makeConnection()
if ring == nil {
return nil, nil
......@@ -73,17 +73,10 @@ func TryGetMedia(ctx rcontext.RequestContext, hash string, startByte int64, endB
defer cancel()
var result *redis.StringCmd
if startByte >= 0 && endByte >= 1 {
ctx.Log.Debugf("Getting range from cache for %s (bytes %d-%d)", hash, startByte, endByte)
if startByte < endByte {
result = ring.GetRange(timeoutCtx, hash, startByte, endByte)
} else {
return nil, errors.New("invalid range - start must be before end")
}
} else {
ctx.Log.Debugf("Getting whole cached object for %s", hash)
result = ring.Get(timeoutCtx, hash)
}
// TODO(TR-1): @@ Return seekable stream
ctx.Log.Debugf("Getting whole cached object for %s", hash)
result = ring.Get(timeoutCtx, hash)
s, err := result.Bytes()
if err != nil {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment