package _routers

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"github.com/alioygur/is"
	"github.com/gabriel-vasile/mimetype"
	"github.com/getsentry/sentry-go"
	"github.com/t2bot/gotd-contrib/http_range"
	"github.com/turt2live/matrix-media-repo/api/_responses"
	"github.com/turt2live/matrix-media-repo/common"
	"github.com/turt2live/matrix-media-repo/common/rcontext"
	"github.com/turt2live/matrix-media-repo/util"
	"github.com/turt2live/matrix-media-repo/util/readers"
)

// GeneratorFn builds the response object for a single request. It is handed
// the raw request plus the per-request context assembled by
// RContextRouter.ServeHTTP; returning nil yields an empty response.
type GeneratorFn = func(req *http.Request, reqCtx rcontext.RequestContext) interface{}

// RContextRouter is an http.Handler that builds a rcontext.RequestContext for
// each request, asks generatorFn for a response object, writes that object to
// the client and then delegates to next when one is configured.
type RContextRouter struct {
	// generatorFn produces the response object for every request.
	generatorFn GeneratorFn

	// next, when non-nil, runs after the response has been written
	// (HTML responses return early and skip it).
	next http.Handler
}

// NewRContextRouter returns a router that serves the responses built by
// generatorFn and afterwards hands the request to next; next may be nil.
func NewRContextRouter(generatorFn GeneratorFn, next http.Handler) *RContextRouter {
	router := new(RContextRouter)
	router.generatorFn = generatorFn
	router.next = next
	return router
}

// ServeHTTP runs the router's generator for the request, writes whatever it
// returned to w, and then passes control to c.next when one is set.
//
// The generator's result is interpreted as follows:
//   - nil is replaced with an *_responses.EmptyResponse;
//   - a *_responses.DoNotCacheResponse is unwrapped to its Payload and
//     suppresses the Cache-Control header;
//   - a *_responses.HtmlResponse is written as text/html with status 200 and
//     the method returns immediately (c.next is not invoked);
//   - a *_responses.DownloadResponse is streamed with a sniffed Content-Type,
//     a Content-Disposition header and support for a single byte range;
//     Range problems turn the reply into a 416 JSON error instead;
//   - an _responses.ErrorResponse (value or pointer) picks its status code
//     from InternalCode, unless an earlier step already chose one;
//   - anything that did not produce a stream is JSON-encoded.
//
// Failures after the status line is committed (write errors, JSON encoding
// errors, transfer-size mismatches) panic; presumably an outer middleware
// recovers from these — TODO confirm.
func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log := GetLogger(r)
	rctx := rcontext.RequestContext{
		Context: r.Context(),
		Log:     log,
		Config:  *GetDomainConfig(r),
		Request: r,
	}

	res := c.generatorFn(r, rctx)
	if res == nil {
		res = &_responses.EmptyResponse{}
	}

	// DoNotCacheResponse is a marker wrapper: unwrap it and remember to skip
	// the Cache-Control header.
	shouldCache := true
	if wrappedRes, isNoCache := res.(*_responses.DoNotCacheResponse); isNoCache {
		shouldCache = false
		res = wrappedRes.Payload
	}

	headers := w.Header()

	// Check for HTML response and reply accordingly
	if htmlRes, isHtml := res.(*_responses.HtmlResponse); isHtml {
		log.Infof("Replying with result: %T <%d chars of html>", res, len(htmlRes.HTML))

		if shouldCache {
			headers.Set("Cache-Control", "private, max-age=259200") // 3 days
		}
		headers.Set("Content-Type", "text/html; charset=UTF-8")

		// Clear the CSP because we're serving HTML
		headers.Set("Content-Security-Policy", "")
		headers.Set("X-Content-Security-Policy", "")

		writeStatusCode(w, r, http.StatusOK)
		if _, err := w.Write([]byte(htmlRes.HTML)); err != nil {
			panic(fmt.Errorf("error sending HtmlResponse: %w", err))
		}
		return // don't continue (c.next is not invoked for HTML responses)
	}

	// Next try handling the response as a download, which might turn into an
	// error. Error paths replace `res` and jump back to the label so the new
	// value runs through the same pipeline.
	proposedStatusCode := http.StatusOK
	var stream io.ReadCloser
	expectedBytes := int64(0)
	var contentType string
beforeParseDownload:
	log.Infof("Replying with result: %T %+v", res, res)
	if downloadRes, isDownload := res.(*_responses.DownloadResponse); isDownload {
		// This handler owns downloadRes.Data. The happy path closes it through
		// `stream` (deferred below), so every path that abandons the download in
		// favour of an error reply must close it here, or the reader leaks.
		discardData := func() {
			if downloadRes.Data == nil {
				return
			}
			if closeErr := downloadRes.Data.Close(); closeErr != nil {
				log.Warn("Non-fatal error closing abandoned download stream: ", closeErr)
			}
		}
		// rejectRange abandons the download and prepares a 416 error reply.
		rejectRange := func(reason string) {
			discardData()
			if downloadRes.SizeBytes >= 0 {
				// A 416 SHOULD carry the current representation length (RFC 9110 §15.5.17).
				headers.Set("Content-Range", "bytes */"+strconv.FormatInt(downloadRes.SizeBytes, 10))
			}
			proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
			res = _responses.BadRequest(reason)
		}

		ranges, err := http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes)
		switch {
		case errors.Is(err, http_range.ErrInvalid):
			rejectRange("invalid range header")
			goto beforeParseDownload // reprocess `res`
		case errors.Is(err, http_range.ErrNoOverlap):
			rejectRange("out of range")
			goto beforeParseDownload // reprocess `res`
		case len(ranges) > 1:
			rejectRange("only 1 range is supported")
			goto beforeParseDownload // reprocess `res`
		}

		contentType = "application/octet-stream"
		expectedBytes = downloadRes.SizeBytes

		// Don't rely on user-supplied values for content-type
		var mimeType *mimetype.MIME
		if rsc, ok := downloadRes.Data.(io.ReadSeekCloser); ok {
			mimeType, err = mimetype.DetectReader(rsc)
			if _, seekErr := rsc.Seek(0, io.SeekStart); seekErr != nil {
				rctx.Log.Error("Error seeking after detecting mimetype: ", seekErr)
				sentry.CaptureException(seekErr)
				discardData()
				res = _responses.InternalServerError("Unexpected Error")
				goto beforeParseDownload // reprocess `res`
			}
		} else {
			// Not seekable: presumably NewBufferReadsReader keeps the bytes the
			// sniffer consumes so GetRewoundReader can replay them from the start.
			br := readers.NewBufferReadsReader(downloadRes.Data)
			mimeType, err = mimetype.DetectReader(br)
			ogReader := downloadRes.Data
			downloadRes.Data = readers.NewCancelCloser(io.NopCloser(br.GetRewoundReader()), func() {
				_ = ogReader.Close()
			})
		}
		if err != nil {
			rctx.Log.Warn("Non-fatal error sniffing mime type of download: ", err)
			sentry.CaptureException(err)
		} else if mimeType != nil {
			contentType = mimeType.String()
		}

		if contentType != downloadRes.ContentType {
			rctx.Log.Debugf("Expected '%s' content type but ended up with '%s'", downloadRes.ContentType, contentType)
		}

		if shouldCache {
			headers.Set("Cache-Control", "private, max-age=259200") // 3 days
		}

		if downloadRes.SizeBytes > 0 {
			headers.Set("Accept-Ranges", "bytes")
		}

		disposition := downloadRes.TargetDisposition
		switch disposition {
		case "":
			disposition = "attachment"
		case "infer":
			if contentType != "" && util.CanInline(contentType) {
				disposition = "inline"
			} else {
				disposition = "attachment"
			}
		}

		fname := downloadRes.Filename
		if fname == "" {
			ext := ""
			if exts, extErr := mime.ExtensionsByType(contentType); extErr != nil {
				sentry.CaptureException(extErr)
				log.Warn("Unexpected error inferring file extension: ", extErr)
			} else if len(exts) > 0 {
				ext = exts[0]
			}
			fname = "file" + ext
		}

		// Percent-encode the filename. url.QueryEscape turns spaces into "+",
		// which RFC 8187 (filename*) decoders keep as a literal plus sign; a real
		// "+" is already escaped to "%2B", so rewriting "+" to "%20" is lossless
		// and leaves only unreserved characters and %XX escapes.
		encodedName := strings.ReplaceAll(url.QueryEscape(fname), "+", "%20")
		if is.ASCII(fname) {
			headers.Set("Content-Disposition", disposition+"; filename="+encodedName)
		} else {
			headers.Set("Content-Disposition", disposition+"; filename*=utf-8''"+encodedName)
		}

		stream = downloadRes.Data
		if len(ranges) > 0 {
			// Ranges can only be honoured on seekable data; otherwise the whole
			// body is sent with the current status.
			if rsc, ok := stream.(io.ReadSeekCloser); ok {
				target := ranges[0] // we only use the first range (validated up above)
				if _, err = rsc.Seek(target.Start, io.SeekStart); err != nil {
					rctx.Log.Warn("Non-fatal error seeking for Range request: ", err)
					sentry.CaptureException(err)
				} else {
					headers.Set("Content-Range", target.ContentRange(downloadRes.SizeBytes))
					proposedStatusCode = http.StatusPartialContent
					stream = readers.NewCancelCloser(io.NopCloser(io.LimitReader(rsc, target.Length)), func() {
						_ = rsc.Close()
					})
					expectedBytes = target.Length
				}
			}
		}
	}

	// Normalize a by-value ErrorResponse to a pointer so only one form needs
	// handling below.
	if errRes, isError := res.(_responses.ErrorResponse); isError {
		res = &errRes
	}
	// Map the error's internal code to an HTTP status, unless an earlier step
	// (e.g. a rejected Range) already picked a non-200 status.
	if errRes, isError := res.(*_responses.ErrorResponse); isError && proposedStatusCode == http.StatusOK {
		switch errRes.InternalCode {
		case common.ErrCodeUnknownToken:
			proposedStatusCode = http.StatusUnauthorized
		case common.ErrCodeNotFound:
			proposedStatusCode = http.StatusNotFound
		case common.ErrCodeMediaTooLarge:
			proposedStatusCode = http.StatusRequestEntityTooLarge
		case common.ErrCodeBadRequest:
			proposedStatusCode = http.StatusBadRequest
		case common.ErrCodeMethodNotAllowed:
			proposedStatusCode = http.StatusMethodNotAllowed
		case common.ErrCodeForbidden:
			proposedStatusCode = http.StatusForbidden
		case common.ErrCodeCannotOverwrite:
			proposedStatusCode = http.StatusConflict
		case common.ErrCodeNotYetUploaded:
			proposedStatusCode = http.StatusGatewayTimeout
		default: // Treat as unknown (a generic server error)
			proposedStatusCode = http.StatusInternalServerError
		}
	}

	// Prepare a stream if one isn't set, and assume JSON
	if stream == nil {
		contentType = "application/json"
		b, err := json.Marshal(res)
		if err != nil {
			panic(err) // blow up this request
		}
		stream = io.NopCloser(bytes.NewReader(b))
		expectedBytes = int64(len(b))
	}

	mediaType, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		sentry.CaptureException(err)
		log.Warn("Failed to parse content type header for media on reply: ", err)
	} else {
		// TODO: Maybe we only strip the charset from images? Is it valid to have the param on other types?
		if !strings.HasPrefix(mediaType, "text/") && mediaType != "application/json" {
			delete(params, "charset")
		}
		contentType = mime.FormatMediaType(mediaType, params)
	}
	headers.Set("Content-Type", contentType)

	if expectedBytes > 0 {
		headers.Set("Content-Length", strconv.FormatInt(expectedBytes, 10))
	}

	r = writeStatusCode(w, r, proposedStatusCode)

	defer stream.Close()
	written, err := io.Copy(w, stream)
	if err != nil {
		panic(err) // blow up this request
	}
	if expectedBytes > 0 && written != expectedBytes {
		panic(fmt.Errorf("mismatch transfer size: %d expected, %d sent", expectedBytes, written))
	}

	if c.next != nil {
		c.next.ServeHTTP(w, r)
	}
}

// GetStatusCode reports the HTTP status code that writeStatusCode stored on
// the request's context. When no int value is stored under
// common.ContextStatusCode, it reports http.StatusOK.
func GetStatusCode(r *http.Request) int {
	if code, found := r.Context().Value(common.ContextStatusCode).(int); found {
		return code
	}
	return http.StatusOK
}

// writeStatusCode sends statusCode as the response status line and returns a
// shallow copy of r whose context records that code under
// common.ContextStatusCode, so later handlers can read it with GetStatusCode.
func writeStatusCode(w http.ResponseWriter, r *http.Request, statusCode int) *http.Request {
	w.WriteHeader(statusCode)
	taggedCtx := context.WithValue(r.Context(), common.ContextStatusCode, statusCode)
	return r.WithContext(taggedCtx)
}