diff --git a/amqtt/broker.py b/amqtt/broker.py index a4f8ef1a..94d54720 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -1,7 +1,7 @@ # Copyright (c) 2015 Nicolas JOUANIN # # See the file license.txt for copying permission. -from typing import Optional +from typing import Optional, Tuple import logging import ssl import websockets @@ -33,6 +33,9 @@ "auth": {"allow-anonymous": True, "password-file": None}, } +# Default port numbers +DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} + EVENT_BROKER_PRE_START = "broker_pre_start" EVENT_BROKER_POST_START = "broker_post_start" EVENT_BROKER_PRE_SHUTDOWN = "broker_pre_shutdown" @@ -44,6 +47,48 @@ EVENT_BROKER_MESSAGE_RECEIVED = "broker_message_received" +def split_bindaddr_port(port_str: str, default_port: int) -> Tuple[Optional[str], int]: + """ + Split an address:port pair into separate IP address and port, with IPv6 + special-case handling. + """ + # Address can be specified using one of the following methods: + # 1883 - Port number only (listen all interfaces) + # :1883 - Port number only (listen all interfaces) + # 0.0.0.0:1883 - IPv4 address + # [::]:1883 - IPv6 address + # empty string - all interfaces default port + + def _parse_port(port_str: str) -> int: + if port_str.startswith(":"): + port_str = port_str[1:] + + if not port_str: + return default_port + + return int(port_str) + + if port_str.startswith("["): # IPv6 literal + try: + addr_end = port_str.index("]") + except ValueError: + raise ValueError("Expecting '[' to be followed by ']'") + + return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :])) + elif ":" in port_str: + # Address : port + address, port_str = port_str.rsplit(":", 1) + return (address or None, _parse_port(port_str)) + else: + # Address or port + try: + # Port number? + return (None, _parse_port(port_str)) + except ValueError: + # Address, default port + return (port_str, default_port) + + class Action(Enum): subscribe = "subscribe" publish = "publish" @@ -54,7 +99,6 @@ class BrokerException(Exception): class RetainedApplicationMessage: - __slots__ = ("source_session", "topic", "data", "qos") def __init__(self, source_session, topic, data, qos=None): @@ -294,10 +338,10 @@ async def start(self) -> None: % (listener["certfile"], listener["keyfile"], fnfe) ) - address, s_port = listener["bind"].split(":") - port = 0 try: - port = int(s_port) + address, port = split_bindaddr_port( + listener["bind"], DEFAULT_PORTS[listener["type"]] + ) except ValueError: raise BrokerException( "Invalid port value in bind value: %s" % listener["bind"] @@ -936,7 +980,7 @@ async def _run_broadcast(self, running_tasks: deque): continue subscriptions = self._subscriptions[k_filter] - for (target_session, qos) in subscriptions: + for target_session, qos in subscriptions: qos = broadcast.get("qos", qos) # Retain all messages which cannot be broadcasted diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index bb6d4133..eb9875c1 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -530,8 +530,8 @@ async def _send_packet(self, packet): await self.handle_connection_closed() except asyncio.CancelledError: raise - except Exception as e: - self.logger.warning("Unhandled exception: %s" % e) + except: + self.logger.warning("Unhandled exception", exc_info=True) raise async def mqtt_deliver_next_message(self): diff --git a/tests/test_broker.py b/tests/test_broker.py index f9d06d34..97d26014 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -21,6 +21,7 @@ EVENT_BROKER_CLIENT_SUBSCRIBED, EVENT_BROKER_CLIENT_UNSUBSCRIBED, EVENT_BROKER_MESSAGE_RECEIVED, + split_bindaddr_port, ) from amqtt.client import MQTTClient, ConnectException from amqtt.mqtt import ( @@ -53,6 +54,23 @@ async def async_magic(): MagicMock.__await__ = lambda x: async_magic().__await__() +@pytest.mark.parametrize( + "input_str, output_addr, output_port", + [ + ("1234", None, 1234), + (":1234", None, 1234), + ("0.0.0.0:1234", "0.0.0.0", 1234), + ("[::]:1234", "[::]", 1234), + ("0.0.0.0", "0.0.0.0", 5678), + ("[::]", "[::]", 5678), + ("localhost", "localhost", 5678), + ("localhost:1234", "localhost", 1234), + ], +) +def test_split_bindaddr_port(input_str, output_addr, output_port): + assert split_bindaddr_port(input_str, 5678) == (output_addr, output_port) + + @pytest.mark.asyncio async def test_start_stop(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls(