Skip to content
Snippets Groups Projects
Unverified Commit 85ece3df authored by Erik Johnston's avatar Erik Johnston Committed by GitHub
Browse files

Merge pull request #5191 from matrix-org/erikj/refactor_pagination_bounds

Make generating SQL bounds for pagination generic
parents ce5bcefc 8dd9cca8
No related branches found
No related tags found
No related merge requests found
Make generating SQL bounds for pagination generic.
......@@ -64,59 +64,135 @@ _EventDictReturn = namedtuple(
)
def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else ""
if token.topological is None:
return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % (
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
def generate_pagination_where_clause(
direction, column_names, from_token, to_token, engine,
):
"""Creates an SQL expression to bound the columns by the pagination
tokens.
For example creates an SQL expression like:
(6, 7) >= (topological_ordering, stream_ordering)
AND (5, 3) < (topological_ordering, stream_ordering)
would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
Note that tokens are considered to be after the row they are in, e.g. if
a row A has a token T, then we consider A to be before T. This convention
is important when figuring out inequalities for the generated SQL, and
produces the following result:
- If paginating forwards then we exclude any rows matching the from
token, but include those that match the to token.
- If paginating backwards then we include any rows matching the from
token, but include those that match the to token.
Args:
direction (str): Whether we're paginating backwards("b") or
forwards ("f").
column_names (tuple[str, str]): The column names to bound. Must *not*
be user defined as these get inserted directly into the SQL
statement without escapes.
from_token (tuple[int, int]|None): The start point for the pagination.
This is an exclusive minimum bound if direction is "f", and an
inclusive maximum bound if direction is "b".
to_token (tuple[int, int]|None): The endpoint point for the pagination.
This is an inclusive maximum bound if direction is "f", and an
exclusive minimum bound if direction is "b".
engine: The database engine to generate the clauses for
Returns:
str: The sql expression
"""
assert direction in ("b", "f")
where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
bound=">=" if direction == "b" else "<",
column_names=column_names,
values=from_token,
engine=engine,
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
def upper_bound(token, engine, inclusive=True):
inclusive = "=" if inclusive else ""
if token.topological is None:
return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % (
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
)
if to_token:
where_clause.append(
_make_generic_sql_bound(
bound="<" if direction == "b" else ">=",
column_names=column_names,
values=to_token,
engine=engine,
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
return " AND ".join(where_clause)
def _make_generic_sql_bound(bound, column_names, values, engine):
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
Only works with two columns.
Older versions of SQLite don't support that syntax so we have to expand it
out manually.
Args:
bound (str): The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
names (tuple[str, str]): The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
values (tuple[int|None, int]): The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
str
"""
assert(bound in (">", "<", ">=", "<="))
name1, name2 = column_names
val1, val2 = values
if val1 is None:
val2 = int(val2)
return "(%d %s %s)" % (val2, bound, name2)
val1 = int(val1)
val2 = int(val2)
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) %s (%s,%s))" % (
val1, val2,
bound,
name1, name2,
)
# We want to generate queries of e.g. the form:
#
# (val1 < name1 OR (val1 = name1 AND val2 <= name2))
#
# which is equivalent to (val1, val2) < (name1, name2)
return """(
{val1:d} {strict_bound} {name1}
OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
)""".format(
name1=name1,
val1=val1,
name2=name2,
val2=val2,
strict_bound=bound[0], # The first bound must always be strict equality here
bound=bound,
)
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
......@@ -762,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
bounds = upper_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (
bounds,
lower_bound(to_token, self.database_engine),
)
else:
order = "ASC"
bounds = lower_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (
bounds,
upper_bound(to_token, self.database_engine),
)
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token,
to_token=to_token,
engine=self.database_engine,
)
filter_clause, filter_args = filter_to_clause(event_filter)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment