From d0251c1d41085a28621e7a1a3490d503c4b0d026 Mon Sep 17 00:00:00 2001 From: Sander van Grieken Date: Mon, 15 Jul 2024 15:55:01 +0200 Subject: [PATCH] onion_messages: request-reply queue --- electrum/onion_message.py | 250 +++++++++++++++++++++++++------------- 1 file changed, 168 insertions(+), 82 deletions(-) diff --git a/electrum/onion_message.py b/electrum/onion_message.py index de92741ec084..284cc4dedabb 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -24,7 +24,9 @@ import copy import io import os +import queue import threading + from typing import TYPE_CHECKING, Optional, List, Sequence from electrum import ecc @@ -46,6 +48,10 @@ logger = get_logger(__name__) +REQUEST_REPLY_TIMEOUT = 120 +REQUEST_REPLY_RETRY_DELAY = 5 + + def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, hop_extras: Optional[Sequence[dict]] = None): introduction_point = path[0] @@ -114,10 +120,7 @@ def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets): hops_data[i].payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data} -# TODO: integrate this with OnionMessageManager below for retry/rate-limit etc -def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None): - assert wallet.lnworker, 'not a lightning wallet' - +def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None): if session_key is None: session_key = os.urandom(32) @@ -130,14 +133,14 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de logger.debug(f'blinded path: {blinded_path!r}') except Exception as e: logger.error(f'e!r') - return + raise introduction_point = blinded_path['first_node_id'] hops_data = [] blinded_node_ids = [] - if wallet.lnworker.node_keypair.pubkey == introduction_point: + if lnwallet.node_keypair.pubkey == introduction_point: # blinded path introduction point is me our_blinding = blinded_path['blinding'] our_payload = blinded_path['path'][0] @@ -145,13 +148,13 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de assert len(remaining_blinded_path) > 0, 'sending to myself?' # decrypt - shared_secret = get_ecdh(wallet.lnworker.node_keypair.privkey, our_blinding) + shared_secret = get_ecdh(lnwallet.node_keypair.privkey, our_blinding) recipient_data = decrypt_encrypted_data_tlv( shared_secret=shared_secret, encrypted_recipient_data=our_payload['encrypted_recipient_data'] ) - peer = wallet.lnworker.peers.get(recipient_data['next_node_id']['node_id']) + peer = lnwallet.peers.get(recipient_data['next_node_id']['node_id']) assert peer, 'next_node_id not a peer' # blinding override? @@ -170,7 +173,7 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de else: # we need a route to introduction point remaining_blinded_path = blinded_path['path'] - peer = wallet.lnworker.peers.get(introduction_point) + peer = lnwallet.peers.get(introduction_point) # if blinded path introduction point is our direct peer, no need to route-find if peer: # start of blinded path is our peer @@ -187,7 +190,7 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de raise Exception('no path found') # first hop must be our peer - peer = wallet.lnworker.peers.get(path[0].end_node) + peer = lnwallet.peers.get(path[0].end_node) assert peer, 'first hop not a peer' # last hop is introduction point and start of blinded path. remove from route @@ -247,19 +250,19 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de else: # node pubkey pubkey = node_id_or_blinded_path - if wallet.lnworker.node_keypair.pubkey == pubkey: + if lnwallet.node_keypair.pubkey == pubkey: raise Exception('cannot send to myself') hops_data = [] - peer = wallet.lnworker.peers.get(pubkey) + peer = lnwallet.peers.get(pubkey) if peer: # destination is our direct peer, no need to route-find path = [PathEdge(short_channel_id=None, start_node=None, end_node=pubkey)] else: # route-find to pubkey. - path = wallet.lnworker.network.path_finder.find_path_for_payment( - nodeA=wallet.lnworker.node_keypair.pubkey, + path = lnwallet.network.path_finder.find_path_for_payment( + nodeA=lnwallet.node_keypair.pubkey, nodeB=pubkey, invoice_amount_msat=10000, # TODO: do this without amount constraints node_filter=is_onion_message_node @@ -268,7 +271,7 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de raise Exception('no path found') # first hop must be our peer - peer = wallet.lnworker.peers.get(path[0].end_node) + peer = lnwallet.peers.get(path[0].end_node) assert peer, 'first hop not a peer' hops_data = [ @@ -302,6 +305,9 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de ) +class Timeout(Exception): pass + + class OnionMessageManager(Logger): """handle state around onion message sends and receives - association between onion message and their replies @@ -315,6 +321,8 @@ def __init__(self, lnwallet: 'LNWallet'): self.pending = {} self.pending_lock = threading.Lock() + self.reqrpyqueue = queue.PriorityQueue() + self.reqrpyqueue_notempty = asyncio.Event() def start_network(self, *, network: 'Network'): assert network @@ -327,7 +335,7 @@ async def main_loop(self): self.logger.info("starting taskgroup.") try: async with self.taskgroup as group: - await group.spawn(self.process()) + await group.spawn(self.process_request_reply_queue()) except Exception as e: self.logger.exception("taskgroup died.") else: @@ -336,19 +344,135 @@ async def main_loop(self): async def stop(self): await self.taskgroup.cancel_remaining() - async def process(self): - # TODO: naive and barely tested send queue loop, this needs more love. + async def process_request_reply_queue(self): while True: - await asyncio.sleep(2) - with self.pending_lock: - for key, pending_item in self.pending.items(): - state = pending_item[0] - if now() - state['submitted'] > 120: # expired - continue - if now() - state['last_attempt'] > 5: - state['last_attempt'] = now() - self.logger.debug('spawning onionmsg send') - await self.taskgroup.spawn(self.send_pending_onion_message(key)) + try: + scheduled, expires, key = self.reqrpyqueue.get_nowait() + except queue.Empty: + self.logger.debug(f'queue empty') + self.reqrpyqueue_notempty.clear() + await self.reqrpyqueue_notempty.wait() + continue + + reqrpy = self.get_reqrpy(key) + if reqrpy is None: + self.logger.debug(f'no data for key {key=}') + continue + if reqrpy.get('result') is not None: + self.logger.debug(f'has result! {key=}') + continue + if expires <= now(): + self.logger.debug(f'expired {key=}') + self._set_reqrpy_result(key, Timeout()) + continue + if scheduled > now(): + # return to queue + self.reqrpyqueue.put_nowait((scheduled, expires, key)) + await asyncio.sleep(1) # sleep here, as the first queue item wasn't due yet + continue + + try: + await self._send_pending_reqrpy(key) + except BaseException as e: + self.logger.debug(f'error while sending {key=}') + self._set_reqrpy_result(key, e) + else: + self.reqrpyqueue.put_nowait((now() + REQUEST_REPLY_RETRY_DELAY, expires, key)) + + def get_reqrpy(self, key): + with self.pending_lock: + return self.pending.get(key) + + def _set_reqrpy_result(self, key, result): + with self.pending_lock: + reqrpy = self.pending.get(key) + if reqrpy is None: + return + self.pending[key]['result'] = result + reqrpy['ev'].set() + + def _remove_reqrpy(self, key): + with self.pending_lock: + reqrpy = self.pending.get(key) + if reqrpy is None: + return + reqrpy['ev'].set() + del self.pending[key] + + def submit_reqrpy(self, *, + payload: dict, + node_id_or_blinded_path: bytes): + """Add onion message to queue for sending. Queued onion message payloads + are supplied with a path_id and a reply_path to determine which request + corresponds with arriving replies. + returns awaitable task""" + key = os.urandom(8) + self.logger.debug(f'submit_reqrpy {key=} {payload=} {node_id_or_blinded_path=}') + with self.pending_lock: + self.pending[key] = { + 'ev': asyncio.Event(), + 'payload': payload, + 'node_id_or_blinded_path': node_id_or_blinded_path + } + + # tuple = (when to process, when it expires, key) + expires = now() + REQUEST_REPLY_TIMEOUT + queueitem = (now(), expires, key) + self.reqrpyqueue.put_nowait(queueitem) + task = asyncio.create_task(self._reqrpy_task(key)) + self.reqrpyqueue_notempty.set() + return task + + async def _reqrpy_task(self, key): + reqrpy = self.get_reqrpy(key) + assert reqrpy + if reqrpy is None: + return + try: + self.logger.debug(f'wait task start {key}') + await reqrpy['ev'].wait() + finally: + self.logger.debug(f'wait task end {key}') + + try: + reqrpy = self.get_reqrpy(key) + assert reqrpy + result = reqrpy.get('result') + if isinstance(result, Exception): + raise result + return result + finally: + self._remove_reqrpy(key) + + async def _send_pending_reqrpy(self, key): + """adds reply_path to payload""" + data = self.get_reqrpy(key) + payload = data.get('payload') + node_id_or_blinded_path = data.get('node_id_or_blinded_path') + self.logger.debug(f'send_reqrpy {key=} {payload=} {node_id_or_blinded_path=}') + + path_id = self._path_id_from_payload_and_key(payload, key) + final_recipient_data = { + 'path_id': {'data': path_id} + } + + # TODO: decide blinded path introduction point (for now, just my own nodeid) + # Note: blinded path session_key != onion message session_key + rbp_session_key = os.urandom(32) + reply_path_nodes = [self.lnwallet.node_keypair.pubkey] + reply_path = create_blinded_path(rbp_session_key, reply_path_nodes, final_recipient_data) + + final_payload = copy.deepcopy(payload) + final_payload['reply_path'] = {'path': reply_path} + + # TODO: we should try alternate paths when retrying, this is currently not done. + # (send_onion_message_to decides path, without knowledge of prev attempts) + send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload) + + def _path_id_from_payload_and_key(self, payload: dict, key: bytes) -> bytes: + # TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed + # TODO: use payload to determine prefix? + return b'electrum' + key def on_onion_message_received(self, recipient_data, payload): # we are destination, sanity checks @@ -366,27 +490,22 @@ def on_onion_message_received(self, recipient_data, payload): if 'path_id' not in recipient_data: # unsolicited onion_message self.on_onion_message_received_unsolicited(recipient_data, payload) - return - - with self.pending_lock: - # check if this reply is associated with a known request - # TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed - correl_data = recipient_data['path_id'].get('data') - if not correl_data[:15] == b'electrum_invreq': - logger.warning('not a reply to our request') - return - if not correl_data[15:] in self.pending: - logger.warning('not a reply to our request') - return - - del self.pending[correl_data[15:]] + else: + self.on_onion_message_received_reply(recipient_data, payload) - # hardcoded, assumed invoice response - invoice_tlv = payload['invoice']['invoice'] - with io.BytesIO(invoice_tlv) as fd: - invoice_data = OnionWireSerializer.read_tlv_stream(fd=fd, tlv_stream_name='invoice') + def on_onion_message_received_reply(self, recipient_data, payload): + # check if this reply is associated with a known request + correl_data = recipient_data['path_id'].get('data') + if not correl_data[:8] == b'electrum': + logger.warning('not a reply to our request (unknown path_id prefix)') + return + key = correl_data[8:] + reqrpy = self.get_reqrpy(key) + if reqrpy is None: + logger.warning('not a reply to our request (unknown request)') + return - logger.debug(f'invoice {invoice_data!r}') + self._set_reqrpy_result(key, (recipient_data, payload)) def on_onion_message_received_unsolicited(self, recipient_data, payload): logger.debug('unsolicited onion_message received') @@ -439,7 +558,9 @@ def on_onion_message_forward(self, recipient_data, onion_packet, blinding, share next_blinding = next_public_key_int.get_public_key_bytes() onion_packet_b = onion_packet.to_bytes() + # construct onion message + # TODO: add queue, delay to avoid traffic analysis next_peer.send_message( "onion_message", blinding=next_blinding, @@ -485,38 +606,3 @@ def on_onion_message(self, payload): self.on_onion_message_received(recipient_data, payload) else: self.on_onion_message_forward(recipient_data, processed_onion_packet.next_packet, blinding, shared_secret) - - def submit_onion_message(self, *, payload: dict, node_id_or_blinded_path: bytes): - """Add onion message to queue for sending. Queued onion message payloads - are supplied with a path_id and a reply_path to determine which request - corresponds with arriving replies. """ - self.logger.debug('submit_onion_message') - key = os.urandom(16) - state = { - 'submitted': now(), - 'last_attempt': 0 - } - self.pending[key] = (state, payload, node_id_or_blinded_path) - return key - - async def send_pending_onion_message(self, key): - """adds reply_path to payload""" - self.logger.debug('send_pending_onion_message') - - state, payload, node_id_or_blinded_path = self.pending[key] - - # TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed - final_recipient_data = { - 'path_id': {'data': b'electrum_invreq' + key} - } - - # TODO: decide blinded path introduction point (for now, just my own nodeid) - # Note: blinded path session_key != onion message session_key - rbp_session_key = os.urandom(32) - reply_path_nodes = [self.lnwallet.node_keypair.pubkey] - reply_path = create_blinded_path(rbp_session_key, reply_path_nodes, final_recipient_data) - - final_payload = copy.deepcopy(payload) - final_payload['reply_path'] = {'path': reply_path} - - send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload)