diff --git a/changelog.d/6786.misc b/changelog.d/6786.misc
new file mode 100644
index 0000000000000000000000000000000000000000..94c692e53a7b0e7be368f1628774b4f99fd56ff4
--- /dev/null
+++ b/changelog.d/6786.misc
@@ -0,0 +1 @@
+Attempt to resync remote users' devices when detected as stale.
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 5c5fe77be2e13b61fda99b1a344189dfe32e2494..05c4b3eec0e9e566217be83483942abcd3c4ea9a 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -21,6 +21,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
     get_active_span_text_map,
     log_kv,
@@ -48,6 +49,8 @@ class DeviceMessageHandler(object):
             "m.direct_to_device", self.on_direct_to_device_edu
         )
 
+        self._device_list_updater = hs.get_device_handler().device_list_updater
+
     @defer.inlineCallbacks
     def on_direct_to_device_edu(self, origin, content):
         local_messages = {}
@@ -134,8 +137,11 @@ class DeviceMessageHandler(object):
                 unknown_devices,
             )
             yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
-            # TODO: Poke something to start trying to refetch user's
-            # keys.
+
+            # Immediately attempt a resync in the background
+            run_in_background(
+                self._device_list_updater.user_device_resync, sender_user_id
+            )
 
     @defer.inlineCallbacks
     def send_device_message(self, sender_user_id, message_type, messages):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a67020a259037b7ce24e7bf5ef02938449a901bb..ca484e545837d35a57a75c7c8ab1e86099fd388d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -57,6 +57,7 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.logging.utils import log_function
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationFederationSendEventsRestServlet,
@@ -156,6 +157,13 @@ class FederationHandler(BaseHandler):
             hs
         )
 
+        if hs.config.worker_app:
+            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+                hs
+            )
+        else:
+            self._device_list_updater = hs.get_device_handler().device_list_updater
+
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
         self._room_pdu_linearizer = Linearizer("fed_room_pdu")
@@ -759,8 +767,14 @@ class FederationHandler(BaseHandler):
                     await self.store.mark_remote_user_device_cache_as_stale(
                         event.sender
                     )
-                    # TODO: Poke something to start trying to refetch user's
-                    # keys.
+
+                    # Immediately attempt a resync in the background
+                    if self.config.worker_app:
+                        return run_in_background(self._user_device_resync, event.sender)
+                    else:
+                        return run_in_background(
+                            self._device_list_updater.user_device_resync, event.sender
+                        )
 
     @log_function
     async def backfill(self, dest, room_id, limit, extremities):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 596ddc6970504422790fbd387c9d5e11de436197..68b9847bd2b2ebf5502635c16b0bb7af4269f901 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -81,6 +81,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ]
         )
 
+        # the tests assume that we are starting at unix time 1000
+        reactor.pump((1000,))
+
         hs = self.setup_test_homeserver(
             notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
         )
@@ -90,9 +93,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor, clock, hs):
-        # the tests assume that we are starting at unix time 1000
-        reactor.pump((1000,))
-
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event