Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform topic check on publish #42

Closed
wants to merge 11 commits into from
16 changes: 12 additions & 4 deletions hbmqtt/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ async def client_connected(
client_id=client_session.client_id,
message=app_message,
)
await self._broadcast_message(
await self._broadcast_message_acl(
client_session, app_message.topic, app_message.data
)
if app_message.publish_packet.retain_flag:
Expand Down Expand Up @@ -710,7 +710,7 @@ async def authenticate(self, session: Session, listener):
# If all plugins returned True, authentication is success
return auth_result

async def topic_filtering(self, session: Session, topic):
async def topic_filtering(self, session: Session, topic, command):
"""
This method call the topic_filtering method on registered plugins to check that the subscription is allowed.
User is considered allowed if all plugins called return True.
Expand All @@ -720,7 +720,8 @@ async def topic_filtering(self, session: Session, topic):
- None if topic filtering can't be achieved (then plugin result is then ignored)
:param session:
:param listener:
:param topic: Topic in which the client wants to subscribe
:param topic: Topic in which the client wants to publish or subscribe
:param command: Whether it's a publish (1) or subscibe (0) command
:return:
"""
topic_plugins = None
Expand All @@ -731,6 +732,7 @@ async def topic_filtering(self, session: Session, topic):
"topic_filtering",
session=session,
topic=topic,
command=command,
filter_plugins=topic_plugins,
)
topic_result = True
Expand Down Expand Up @@ -774,7 +776,7 @@ async def add_subscription(self, subscription, session):
# [MQTT-4.7.1-3] + wildcard character must occupy entire level
return 0x80
# Check if the client is authorised to connect to the topic
permitted = await self.topic_filtering(session, topic=a_filter)
permitted = await self.topic_filtering(session, topic=a_filter, command=0)
if not permitted:
return 0x80
qos = subscription[1]
Expand Down Expand Up @@ -931,6 +933,12 @@ async def _broadcast_loop(self):
await asyncio.wait(running_tasks, loop=self._loop)
raise # reraise per CancelledError semantics

async def _broadcast_message_acl(self, session, topic, data, force_qos=None):
permitted = await self.topic_filtering(session, topic=topic, command=1)

if permitted:
await self._broadcast_message(session, topic, data, force_qos)

async def _broadcast_message(self, session, topic, data, force_qos=None):
broadcast = {"session": session, "topic": topic, "data": data}
if force_qos:
Expand Down
12 changes: 10 additions & 2 deletions hbmqtt/plugins/topic_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,24 @@ def topic_filtering(self, *args, **kwargs):
class TopicTabooPlugin(BaseTopicPlugin):
def __init__(self, context):
super().__init__(context)
self._taboo = ["prohibited", "top-secret", "data/classified"]
self._taboo = self.topic_config["taboo"]
self._taboo_command = self.topic_config.get(
"taboo_command"
) # If None, allow neither

async def topic_filtering(self, *args, **kwargs):
filter_result = super().topic_filtering(*args, **kwargs)
if filter_result:
session = kwargs.get("session", None)
topic = kwargs.get("topic", None)
command = kwargs.get("command")
if session.username and session.username == "admin":
return True
if topic and topic in self._taboo:
if (
topic
and topic in self._taboo
and (self._taboo_command is None or self._taboo_command == command)
):
return False
return True
return filter_result
Expand Down
69 changes: 69 additions & 0 deletions tests/plugins/test_topic_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2015 Nicolas JOUANIN
#
# See the file license.txt for copying permission.

import asyncio
from hbmqtt.plugins.manager import BaseContext
from hbmqtt.plugins.topic_checking import BaseTopicPlugin, TopicTabooPlugin
from hbmqtt.session import Session

formatter = (
"[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.DEBUG, format=formatter)


class TestTopicCheckingPlugin(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()

def test_taboo_topic(self):
context = BaseContext()
context.logger = logging.getLogger(__name__)
context.config = {
"topic-check": {
"taboo": ["prohibited", "top-secret", "data/classified"],
"taboo_command": 1,
}
}
s = Session()
s.username = "admin"
taboo_topic_plugin = TopicTabooPlugin(context)
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="prohibited", command=1, session=s)
)
assert ret
s.username = "normal-user"
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="prohibited", command=1, session=s)
)
assert ret is False
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="prohibited", command=0, session=s)
)
assert ret
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="allowed", command=1, session=s)
)
assert ret
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="allowed", command=0, session=s)
)
context.config["topic-check"]["taboo_command"] = None
taboo_topic_plugin = TopicTabooPlugin(context)
assert ret
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="prohibited", command=1, session=s)
)
assert ret is False
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="prohibited", command=0, session=s)
)
assert ret is False
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="allowed", command=1, session=s)
)
assert ret
ret = self.loop.run_until_complete(
taboo_topic_plugin.topic_filtering(topic="allowed", command=0, session=s)
)