unasync, no add_peer in create_onion_message_route_to, add manager tests
accumulator committed Jan 15, 2025
1 parent 4ba665c commit bc6be0b
Showing 3 changed files with 204 additions and 32 deletions.
2 changes: 1 addition & 1 deletion electrum/commands.py
@@ -1443,7 +1443,7 @@ async def send_onion_message(self, node_id_or_blinded_path_hex: str, message: st
}

try:
await send_onion_message_to(wallet.lnworker, node_id_or_blinded_path, destination_payload)
send_onion_message_to(wallet.lnworker, node_id_or_blinded_path, destination_payload)
return {'success': True}
except Exception as e:
msg = str(e)
76 changes: 48 additions & 28 deletions electrum/onion_message.py
@@ -46,6 +46,7 @@
from electrum.lnworker import LNWallet
from electrum.network import Network
from electrum.lnrouter import NodeInfo
from electrum.lntransport import LNPeerAddr
from asyncio import Task

logger = get_logger(__name__)
@@ -59,6 +60,12 @@
FORWARD_MAX_QUEUE = 3


class NoRouteFound(Exception):
def __init__(self, *args, peer_address: 'LNPeerAddr' = None):
Exception.__init__(self, *args)
self.peer_address = peer_address
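
The new NoRouteFound exception carries an optional LNPeerAddr so that the caller, rather than the route builder, decides whether to open a connection. A minimal sketch of that pattern, assuming an async caller (route_or_connect is a hypothetical helper, not part of this commit):

async def route_or_connect(lnwallet, node_id: bytes):
    # Try to build a route; if only an address is known, connect and let the
    # caller retry on a later attempt (mirrors the manager's retry loop below).
    try:
        return create_onion_message_route_to(lnwallet, node_id)
    except NoRouteFound as e:
        if e.peer_address:
            await lnwallet.add_peer(str(e.peer_address))
        raise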


def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, *,
hop_extras: Optional[Sequence[dict]] = None,
dummy_hops: Optional[int] = 0) -> dict:
@@ -135,7 +142,7 @@ def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets):
hops_data[i].payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data}


async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[PathEdge]:
def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[PathEdge]:
"""Constructs a route to the destination node_id, first by starting with peers with existing channels,
and if no route found, opening a direct peer connection if node_id is found with an address in
channel_db."""
@@ -145,7 +152,7 @@ async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) ->
chan.is_active() and not chan.is_frozen_for_sending()]
my_sending_channels = {chan.short_channel_id: chan for chan in my_active_channels
if chan.short_channel_id is not None}
# strat1: find route to introduction point over existing channel mesh
# find route to introduction point over existing channel mesh
# NOTE: nodes that are in channel_db but are offline are not removed from the set
if lnwallet.network.path_finder:
if path := lnwallet.network.path_finder.find_path_for_payment(
@@ -156,17 +163,19 @@ async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) ->
my_sending_channels=my_sending_channels
): return path

# strat2: dest node has host:port in channel_db? then open direct peer connection
# alt: dest is existing peer?
if lnwallet.peers.get(node_id):
return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)]

# if we have an address, pass it.
if lnwallet.channel_db:
if peer_addr := lnwallet.channel_db.get_last_good_address(node_id):
peer = await lnwallet.add_peer(str(peer_addr))
await peer.initialized
return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)]
raise NoRouteFound('no path found, peer_addr available', peer_address=peer_addr)

raise Exception('no path found')
raise NoRouteFound('no path found')


async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None):
def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None):
if session_key is None:
session_key = os.urandom(32)

@@ -226,7 +235,7 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b
# start of blinded path is our peer
blinding = blinded_path['blinding']
else:
path = await create_onion_message_route_to(lnwallet, introduction_point)
path = create_onion_message_route_to(lnwallet, introduction_point)

# first edge must be to our peer
peer = lnwallet.peers.get(path[0].end_node)
@@ -303,7 +312,7 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b
# destination is our direct peer, no need to route-find
path = [PathEdge(short_channel_id=None, start_node=None, end_node=pubkey)]
else:
path = await create_onion_message_route_to(lnwallet, pubkey)
path = create_onion_message_route_to(lnwallet, pubkey)

# first edge must be to our peer
peer = lnwallet.peers.get(path[0].end_node)
@@ -340,9 +349,9 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b
)


async def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
max_paths: int = REQUEST_REPLY_PATHS_MAX,
preferred_node_id: bytes = None) -> List[dict]:
def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
max_paths: int = REQUEST_REPLY_PATHS_MAX,
preferred_node_id: bytes = None) -> List[dict]:
# TODO: build longer paths and/or add dummy hops to increase privacy
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
@@ -376,7 +385,7 @@ class OnionMessageManager(Logger):
- association between onion message and their replies
- manage re-send attempts, TODO: iterate through routes (both directions)"""

def __init__(self, lnwallet: 'LNWallet'):
def __init__(self, lnwallet: 'LNWallet', *, request_reply_timeout=REQUEST_REPLY_TIMEOUT):
Logger.__init__(self)
self.network = None # type: Optional['Network']
self.taskgroup = None # type: OldTaskGroup
@@ -389,6 +398,8 @@ def __init__(self, lnwallet: 'LNWallet'):
self.forwardqueue = queue.PriorityQueue()
self.forwardqueue_notempty = asyncio.Event()

self.request_reply_timeout = request_reply_timeout

def start_network(self, *, network: 'Network'):
assert network
assert self.network is None, "already started"
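
Taking request_reply_timeout as a constructor argument makes the reply window injectable, presumably so the manager tests added in this commit can use a short timeout instead of waiting out the REQUEST_REPLY_TIMEOUT default. A rough sketch of how a test might wire it up (illustrative only; the test file is not shown here and the mock fixture names are assumptions):

# hypothetical test setup with mocked wallet/network fixtures
manager = OnionMessageManager(mock_lnwallet, request_reply_timeout=2)
manager.start_network(network=mock_network)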
@@ -415,13 +426,13 @@ async def process_forward_queue(self):
try:
scheduled, expires, onion_packet, blinding, node_id = self.forwardqueue.get_nowait()
except queue.Empty:
self.logger.debug(f'fwd queue empty')
self.logger.info(f'forward queue empty')
self.forwardqueue_notempty.clear()
await self.forwardqueue_notempty.wait()
continue

if expires <= now():
self.logger.debug(f'fwd expired {node_id=}')
self.logger.debug(f'forward expired {node_id=}')
continue
if scheduled > now():
# return to queue
@@ -448,7 +459,7 @@ def submit_forward(self, *,
blinding: bytes,
node_id: bytes):
if self.forwardqueue.qsize() >= FORWARD_MAX_QUEUE:
self.logger.debug('fwd queue full, dropping packet')
self.logger.debug('forward queue full, dropping packet')
return
expires = now() + FORWARD_RETRY_TIMEOUT
queueitem = (now(), expires, onion_packet, blinding, node_id)
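
The queues hold plain tuples, so with a PriorityQueue the ordering falls out of tuple comparison: the first element (the scheduled time) decides when an item becomes eligible, and re-queueing with now() plus a delay pushes it behind fresher work. A standalone toy illustration of that ordering (not Electrum code):

import queue
import time

q = queue.PriorityQueue()
q.put_nowait((time.time() + 10, time.time() + 60, 'retry-later'))  # scheduled 10s out
q.put_nowait((time.time(), time.time() + 60, 'send-now'))          # eligible immediately
scheduled, expires, item = q.get_nowait()
print(item)  # 'send-now' -- the earliest scheduled entry comes out first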
@@ -460,9 +471,13 @@ async def process_request_reply_queue(self):
try:
scheduled, expires, key = self.requestreply_queue.get_nowait()
except queue.Empty:
self.logger.debug(f'requestreply queue empty')
self.logger.info(f'requestreply queue empty')
self.requestreply_queue_notempty.clear()
await self.requestreply_queue_notempty.wait()
try:
self.requestreply_queue_notempty.clear()
await self.requestreply_queue_notempty.wait() # NOTE: quirk, see note below
except Exception as e:
self.logger.info(f'Exception e={e!r}')
continue

requestreply = self.get_requestreply(key)
@@ -483,11 +498,16 @@ async def process_request_reply_queue(self):
continue

try:
await self._send_pending_requestreply(key)
self._send_pending_requestreply(key)
except BaseException as e:
self.logger.debug(f'error while sending {key=}')
self._set_requestreply_result(key, e)
self.logger.debug(f'error while sending {key=} {e!r}')
self._set_requestreply_result(key, copy.copy(e))
# NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in
# queue_notempty.wait() later (??). pass a copy instead.
if isinstance(e, NoRouteFound) and e.peer_address:
await self.lnwallet.add_peer(str(e.peer_address))
else:
self.logger.debug(f'resubmit {key=}')
self.requestreply_queue.put_nowait((now() + REQUEST_REPLY_RETRY_DELAY, expires, key))

def get_requestreply(self, key):
@@ -498,6 +518,7 @@ def _set_requestreply_result(self, key, result):
with self.pending_lock:
requestreply = self.pending.get(key)
if requestreply is None:
self.logger.error(f'requestreply with {key=} not found!')
return
self.pending[key]['result'] = result
requestreply['ev'].set()
@@ -537,7 +558,7 @@ def submit_requestreply(self, *,
}

# tuple = (when to process, when it expires, key)
expires = now() + REQUEST_REPLY_TIMEOUT
expires = now() + self.request_reply_timeout
queueitem = (now(), expires, key)
self.requestreply_queue.put_nowait(queueitem)
task = asyncio.create_task(self._requestreply_task(key))
@@ -560,12 +581,12 @@ async def _requestreply_task(self, key):
assert requestreply
result = requestreply.get('result')
if isinstance(result, Exception):
raise result
raise result # raising in the task requires caller to explicitly extract exception.
return result
finally:
self._remove_requestreply(key)
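
Since _requestreply_task re-raises the stored exception inside the task, whoever awaits it has to handle the error there. A hedged caller-side sketch, to be run inside some async caller (it assumes submit_requestreply returns the task it creates and that these keyword argument names match; neither is visible in this hunk):

task = manager.submit_requestreply(node_id_or_blinded_path=dest, payload=payload)
try:
    reply = await task
except NoRouteFound:
    ...  # no route over channels and no usable address for the destination
except Exception:
    ...  # timeout or other send failure re-raised from the task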

async def _send_pending_requestreply(self, key):
def _send_pending_requestreply(self, key):
"""adds reply_path to payload"""
data = self.get_requestreply(key)
payload = data.get('payload')
@@ -577,18 +598,17 @@ async def _send_pending_requestreply(self, key):
if 'reply_path' not in final_payload:
# unless explicitly set in payload, generate reply_path here
path_id = self._path_id_from_payload_and_key(payload, key)
reply_paths = await get_blinded_reply_paths(self.lnwallet, path_id, max_paths=1)
reply_paths = get_blinded_reply_paths(self.lnwallet, path_id, max_paths=1)
if not reply_paths:
raise Exception(f'Could not create a reply_path for {key=}')

final_payload['reply_path'] = {'path': reply_paths}

# TODO: we should try alternate paths when retrying, this is currently not done.
# (send_onion_message_to decides path, without knowledge of prev attempts)
await send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload)
send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload)

def _path_id_from_payload_and_key(self, payload: dict, key: bytes) -> bytes:
# TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed
# TODO: use payload to determine prefix?
return b'electrum' + key

[diff for the third changed file not shown]
