diff --git a/changelog.d/17916.feature b/changelog.d/17916.feature new file mode 100644 index 00000000000..118997c5e59 --- /dev/null +++ b/changelog.d/17916.feature @@ -0,0 +1 @@ +Module developers will have access to user id of requester when adding `check_username_for_spam` callbacks to `spam_checker_module_callbacks`. Contributed by Wilson@Pangea.chat. \ No newline at end of file diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index ec306d81abf..c7f8606fd07 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -245,7 +245,7 @@ this callback. _First introduced in Synapse v1.37.0_ ```python -async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool +async def check_username_for_spam(user_profile: synapse.module_api.UserProfile, requester_id: str) -> bool ``` Called when computing search results in the user directory. The module must return a @@ -264,6 +264,8 @@ The profile is represented as a dictionary with the following keys: The module is given a copy of the original dictionary, so modifying it from within the module cannot modify a user's profile when included in user directory search results. +The requester_id parameter is the ID of the user that called the user directory API. + If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first callback that does not return `False` will be used. If this happens, Synapse will not call diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 1b6d814937c..4ace3512b33 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -72,8 +72,8 @@ class ExampleSpamChecker: async def user_may_publish_room(self, userid, room_id): return True # allow publishing of all rooms - async def check_username_for_spam(self, user_profile): - return False # allow all usernames + async def check_username_for_spam(self, user_profile, requester_id): + return False # allow all usernames regardless of requester async def check_registration_for_spam( self, diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index a343637b82a..1281929d384 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -161,7 +161,7 @@ async def search_users( non_spammy_users = [] for user in results["results"]: if not await self._spam_checker_module_callbacks.check_username_for_spam( - user + user, user_id ): non_spammy_users.append(user) results["results"] = non_spammy_users diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py index 17079ff781c..a2f328cafed 100644 --- a/synapse/module_api/callbacks/spamchecker_callbacks.py +++ b/synapse/module_api/callbacks/spamchecker_callbacks.py @@ -31,6 +31,7 @@ Optional, Tuple, Union, + cast, ) # `Literal` appears with Python 3.8. @@ -168,7 +169,10 @@ ] ], ] -CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]] +CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[ + Callable[[UserProfile], Awaitable[bool]], + Callable[[UserProfile, str], Awaitable[bool]], +] LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ Optional[dict], @@ -716,7 +720,9 @@ async def user_may_publish_room( return self.NOT_SPAM - async def check_username_for_spam(self, user_profile: UserProfile) -> bool: + async def check_username_for_spam( + self, user_profile: UserProfile, requester_id: str + ) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in @@ -727,15 +733,33 @@ async def check_username_for_spam(self, user_profile: UserProfile) -> bool: * user_id * display_name * avatar_url + requester_id: The user ID of the user making the user directory search request. Returns: True if the user is spammy. """ for callback in self._check_username_for_spam_callbacks: with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): + checker_args = inspect.signature(callback) # Make a copy of the user profile object to ensure the spam checker cannot # modify it. - res = await delay_cancellation(callback(user_profile.copy())) + # Also ensure backwards compatibility with spam checker callbacks + # that don't expect the requester_id argument. + if len(checker_args.parameters) == 2: + callback_with_requester_id = cast( + Callable[[UserProfile, str], Awaitable[bool]], callback + ) + res = await delay_cancellation( + callback_with_requester_id(user_profile.copy(), requester_id) + ) + else: + callback_without_requester_id = cast( + Callable[[UserProfile], Awaitable[bool]], callback + ) + res = await delay_cancellation( + callback_without_requester_id(user_profile.copy()) + ) + if res: return True diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 878d9683b6a..a75095a79f1 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -796,6 +796,7 @@ def test_spam_checker(self) -> None: s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) + # Kept old spam checker without `requester_id` tests for backwards compatibility. async def allow_all(user_profile: UserProfile) -> bool: # Allow all users. return False @@ -809,6 +810,7 @@ async def allow_all(user_profile: UserProfile) -> bool: s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) + # Kept old spam checker without `requester_id` tests for backwards compatibility. # Configure a spam checker that filters all users. async def block_all(user_profile: UserProfile) -> bool: # All users are spammy. @@ -820,6 +822,40 @@ async def block_all(user_profile: UserProfile) -> bool: s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) + async def allow_all_expects_requester_id( + user_profile: UserProfile, requester_id: str + ) -> bool: + self.assertEqual(requester_id, u1) + # Allow all users. + return False + + # Configure a spam checker that does not filter any users. + spam_checker = self.hs.get_module_api_callbacks().spam_checker + spam_checker._check_username_for_spam_callbacks = [ + allow_all_expects_requester_id + ] + + # The results do not change: + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that filters all users. + async def block_all_expects_requester_id( + user_profile: UserProfile, requester_id: str + ) -> bool: + self.assertEqual(requester_id, u1) + # All users are spammy. + return True + + spam_checker._check_username_for_spam_callbacks = [ + block_all_expects_requester_id + ] + + # User1 now gets no search results for any of the other users. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + @override_config( { "spam_checker": {