Skip to content
Snippets Groups Projects
Commit 8e024941 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Delete refresh tokens when deleting devices

parent d34e9f93
No related branches found
No related tags found
No related merge requests found
......@@ -138,8 +138,10 @@ class DeviceHandler(BaseHandler):
else:
raise
yield self.store.user_delete_access_tokens(user_id,
device_id=device_id)
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
......
......@@ -252,20 +252,36 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[],
device_id=None):
def f(txn):
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
device_id=None,
delete_refresh_tokens=False):
"""
Invalidate access/refresh tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
delete_refresh_tokens (bool): True to delete refresh tokens as
well as access tokens.
Returns:
defer.Deferred:
"""
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
if except_token_ids:
if except_tokens:
sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_token_ids]),
",".join(["?" for _ in except_tokens]),
)
clauses += except_token_ids
clauses += except_tokens
txn.execute(sql, clauses)
......@@ -274,16 +290,33 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
for row in chunk:
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM access_tokens WHERE token in (%s)" % (
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
yield self.runInteraction("user_delete_access_tokens", f)
# delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token):
def f(txn):
......@@ -306,9 +339,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args:
token (str): The access token of a user.
Returns:
dict: Including the name (user_id) and the ID of their access token.
Raises:
StoreError if no user was found.
defer.Deferred: None, if the token did not match, ootherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
......
......@@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
generator = TokenGenerator()
refresh_token = generator.generate(self.user_id)
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
self.device_id)
yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
self.device_id)
# now delete some
yield self.store.user_delete_access_tokens(
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
# check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertIsNone(user, "access token was not deleted by device_id")
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(refresh_token,
generator.generate)
# check the one not associated with the device was not deleted
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertEqual(self.user_id, user["name"])
# now delete the rest
yield self.store.user_delete_access_tokens(
self.user_id, delete_refresh_tokens=True)
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user,
"access token was not deleted without device_id")
class TokenGenerator:
def __init__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment