Skip to content

Commit

Permalink
onion_messages: request-reply queue
Browse files Browse the repository at this point in the history
  • Loading branch information
accumulator committed Jul 17, 2024
1 parent ae51d2c commit 7bb7596
Showing 1 changed file with 168 additions and 82 deletions.
250 changes: 168 additions & 82 deletions electrum/onion_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import copy
import io
import os
import queue
import threading

from typing import TYPE_CHECKING, Optional, List, Sequence

from electrum import ecc
Expand All @@ -46,6 +48,10 @@
logger = get_logger(__name__)


REQUEST_REPLY_TIMEOUT = 120
REQUEST_REPLY_RETRY_DELAY = 5


def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, hop_extras: Optional[Sequence[dict]] = None):
introduction_point = path[0]

Expand Down Expand Up @@ -114,10 +120,7 @@ def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets):
hops_data[i].payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data}


# TODO: integrate this with OnionMessageManager below for retry/rate-limit etc
def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None):
assert wallet.lnworker, 'not a lightning wallet'

def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None):
if session_key is None:
session_key = os.urandom(32)

Expand All @@ -130,28 +133,28 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de
logger.debug(f'blinded path: {blinded_path!r}')
except Exception as e:
logger.error(f'e!r')
return
raise

introduction_point = blinded_path['first_node_id']

hops_data = []
blinded_node_ids = []

if wallet.lnworker.node_keypair.pubkey == introduction_point:
if lnwallet.node_keypair.pubkey == introduction_point:
# blinded path introduction point is me
our_blinding = blinded_path['blinding']
our_payload = blinded_path['path'][0]
remaining_blinded_path = blinded_path['path'][1:]
assert len(remaining_blinded_path) > 0, 'sending to myself?'

# decrypt
shared_secret = get_ecdh(wallet.lnworker.node_keypair.privkey, our_blinding)
shared_secret = get_ecdh(lnwallet.node_keypair.privkey, our_blinding)
recipient_data = decrypt_encrypted_data_tlv(
shared_secret=shared_secret,
encrypted_recipient_data=our_payload['encrypted_recipient_data']
)

peer = wallet.lnworker.peers.get(recipient_data['next_node_id']['node_id'])
peer = lnwallet.peers.get(recipient_data['next_node_id']['node_id'])
assert peer, 'next_node_id not a peer'

# blinding override?
Expand All @@ -170,7 +173,7 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de
else:
# we need a route to introduction point
remaining_blinded_path = blinded_path['path']
peer = wallet.lnworker.peers.get(introduction_point)
peer = lnwallet.peers.get(introduction_point)
# if blinded path introduction point is our direct peer, no need to route-find
if peer:
# start of blinded path is our peer
Expand All @@ -187,7 +190,7 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de
raise Exception('no path found')

# first hop must be our peer
peer = wallet.lnworker.peers.get(path[0].end_node)
peer = lnwallet.peers.get(path[0].end_node)
assert peer, 'first hop not a peer'

# last hop is introduction point and start of blinded path. remove from route
Expand Down Expand Up @@ -247,19 +250,19 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de
else: # node pubkey
pubkey = node_id_or_blinded_path

if wallet.lnworker.node_keypair.pubkey == pubkey:
if lnwallet.node_keypair.pubkey == pubkey:
raise Exception('cannot send to myself')

hops_data = []
peer = wallet.lnworker.peers.get(pubkey)
peer = lnwallet.peers.get(pubkey)

if peer:
# destination is our direct peer, no need to route-find
path = [PathEdge(short_channel_id=None, start_node=None, end_node=pubkey)]
else:
# route-find to pubkey.
path = wallet.lnworker.network.path_finder.find_path_for_payment(
nodeA=wallet.lnworker.node_keypair.pubkey,
path = lnwallet.network.path_finder.find_path_for_payment(
nodeA=lnwallet.node_keypair.pubkey,
nodeB=pubkey,
invoice_amount_msat=10000, # TODO: do this without amount constraints
node_filter=is_onion_message_node
Expand All @@ -268,7 +271,7 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de
raise Exception('no path found')

# first hop must be our peer
peer = wallet.lnworker.peers.get(path[0].end_node)
peer = lnwallet.peers.get(path[0].end_node)
assert peer, 'first hop not a peer'

hops_data = [
Expand Down Expand Up @@ -302,6 +305,9 @@ def send_onion_message_to(wallet: 'LNWallet', node_id_or_blinded_path: bytes, de
)


class Timeout(Exception): pass


class OnionMessageManager(Logger):
"""handle state around onion message sends and receives
- association between onion message and their replies
Expand All @@ -315,6 +321,8 @@ def __init__(self, lnwallet: 'LNWallet'):

self.pending = {}
self.pending_lock = threading.Lock()
self.reqrpyqueue = queue.PriorityQueue()
self.reqrpyqueue_notempty = asyncio.Event()

def start_network(self, *, network: 'Network'):
assert network
Expand All @@ -327,7 +335,7 @@ async def main_loop(self):
self.logger.info("starting taskgroup.")
try:
async with self.taskgroup as group:
await group.spawn(self.process())
await group.spawn(self.process_request_reply_queue())
except Exception as e:
self.logger.exception("taskgroup died.")
else:
Expand All @@ -336,19 +344,135 @@ async def main_loop(self):
async def stop(self):
await self.taskgroup.cancel_remaining()

async def process(self):
# TODO: naive and barely tested send queue loop, this needs more love.
async def process_request_reply_queue(self):
while True:
await asyncio.sleep(2)
with self.pending_lock:
for key, pending_item in self.pending.items():
state = pending_item[0]
if now() - state['submitted'] > 120: # expired
continue
if now() - state['last_attempt'] > 5:
state['last_attempt'] = now()
self.logger.debug('spawning onionmsg send')
await self.taskgroup.spawn(self.send_pending_onion_message(key))
try:
scheduled, expires, key = self.reqrpyqueue.get_nowait()
except queue.Empty:
self.logger.debug(f'queue empty')
self.reqrpyqueue_notempty.clear()
await self.reqrpyqueue_notempty.wait()
continue

reqrpy = self.get_reqrpy(key)
if reqrpy is None:
self.logger.debug(f'no data for key {key=}')
continue
if reqrpy.get('result') is not None:
self.logger.debug(f'has result! {key=}')
continue
if expires <= now():
self.logger.debug(f'expired {key=}')
self._set_reqrpy_result(key, Timeout())
continue
if scheduled > now():
# return to queue
self.reqrpyqueue.put_nowait((scheduled, expires, key))
await asyncio.sleep(1) # sleep here, as the first queue item wasn't due yet
continue

try:
await self._send_pending_reqrpy(key)
except BaseException as e:
self.logger.debug(f'error while sending {key=}')
self._set_reqrpy_result(key, e)
else:
self.reqrpyqueue.put_nowait((now() + REQUEST_REPLY_RETRY_DELAY, expires, key))

def get_reqrpy(self, key):
with self.pending_lock:
return self.pending.get(key)

def _set_reqrpy_result(self, key, result):
with self.pending_lock:
reqrpy = self.pending.get(key)
if reqrpy is None:
return
self.pending[key]['result'] = result
reqrpy['ev'].set()

def _remove_reqrpy(self, key):
with self.pending_lock:
reqrpy = self.pending.get(key)
if reqrpy is None:
return
reqrpy['ev'].set()
del self.pending[key]

def submit_reqrpy(self, *,
payload: dict,
node_id_or_blinded_path: bytes):
"""Add onion message to queue for sending. Queued onion message payloads
are supplied with a path_id and a reply_path to determine which request
corresponds with arriving replies.
returns generated key, for internal tracking and caller cancelling the request"""
key = os.urandom(8)
self.logger.debug(f'submit_reqrpy {key=} {payload=} {node_id_or_blinded_path=}')
with self.pending_lock:
self.pending[key] = {
'ev': asyncio.Event(),
'payload': payload,
'node_id_or_blinded_path': node_id_or_blinded_path
}

# tuple = (when to process, when it expires, key)
expires = now() + REQUEST_REPLY_TIMEOUT
queueitem = (now(), expires, key)
self.reqrpyqueue.put_nowait(queueitem)
task = asyncio.create_task(self._reqrpy_task(key))
self.reqrpyqueue_notempty.set()
return task

async def _reqrpy_task(self, key):
reqrpy = self.get_reqrpy(key)
assert reqrpy
if reqrpy is None:
return
try:
self.logger.debug(f'wait task start {key}')
await reqrpy['ev'].wait()
finally:
self.logger.debug(f'wait task end {key}')

try:
reqrpy = self.get_reqrpy(key)
assert reqrpy
result = reqrpy.get('result')
if isinstance(result, Exception):
raise result
return result
finally:
self._remove_reqrpy(key)

async def _send_pending_reqrpy(self, key):
"""adds reply_path to payload"""
data = self.get_reqrpy(key)
payload = data.get('payload')
node_id_or_blinded_path = data.get('node_id_or_blinded_path')
self.logger.debug(f'send_reqrpy {key=} {payload=} {node_id_or_blinded_path=}')

path_id = self._path_id_from_payload_and_key(payload, key)
final_recipient_data = {
'path_id': {'data': path_id}
}

# TODO: decide blinded path introduction point (for now, just my own nodeid)
# Note: blinded path session_key != onion message session_key
rbp_session_key = os.urandom(32)
reply_path_nodes = [self.lnwallet.node_keypair.pubkey]
reply_path = create_blinded_path(rbp_session_key, reply_path_nodes, final_recipient_data)

final_payload = copy.deepcopy(payload)
final_payload['reply_path'] = {'path': reply_path}

# TODO: we should try alternate paths when retrying, this is currently not done.
# (send_onion_message_to decides path, without knowledge of prev attempts)
send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload)

def _path_id_from_payload_and_key(self, payload: dict, key: bytes) -> bytes:
# TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed
# TODO: use payload to determine prefix?
return b'electrum' + key

def on_onion_message_received(self, recipient_data, payload):
# we are destination, sanity checks
Expand All @@ -366,27 +490,22 @@ def on_onion_message_received(self, recipient_data, payload):
if 'path_id' not in recipient_data:
# unsolicited onion_message
self.on_onion_message_received_unsolicited(recipient_data, payload)
return

with self.pending_lock:
# check if this reply is associated with a known request
# TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed
correl_data = recipient_data['path_id'].get('data')
if not correl_data[:15] == b'electrum_invreq':
logger.warning('not a reply to our request')
return
if not correl_data[15:] in self.pending:
logger.warning('not a reply to our request')
return

del self.pending[correl_data[15:]]
else:
self.on_onion_message_received_reply(recipient_data, payload)

# hardcoded, assumed invoice response
invoice_tlv = payload['invoice']['invoice']
with io.BytesIO(invoice_tlv) as fd:
invoice_data = OnionWireSerializer.read_tlv_stream(fd=fd, tlv_stream_name='invoice')
def on_onion_message_received_reply(self, recipient_data, payload):
# check if this reply is associated with a known request
correl_data = recipient_data['path_id'].get('data')
if not correl_data[:8] == b'electrum':
logger.warning('not a reply to our request (unknown path_id prefix)')
return
key = correl_data[8:]
reqrpy = self.get_reqrpy(key)
if reqrpy is None:
logger.warning('not a reply to our request (unknown request)')
return

logger.debug(f'invoice {invoice_data!r}')
self._set_reqrpy_result(key, (recipient_data, payload))

def on_onion_message_received_unsolicited(self, recipient_data, payload):
logger.debug('unsolicited onion_message received')
Expand Down Expand Up @@ -439,7 +558,9 @@ def on_onion_message_forward(self, recipient_data, onion_packet, blinding, share
next_blinding = next_public_key_int.get_public_key_bytes()

onion_packet_b = onion_packet.to_bytes()

# construct onion message
# TODO: add queue, delay to avoid traffic analysis
next_peer.send_message(
"onion_message",
blinding=next_blinding,
Expand Down Expand Up @@ -485,38 +606,3 @@ def on_onion_message(self, payload):
self.on_onion_message_received(recipient_data, payload)
else:
self.on_onion_message_forward(recipient_data, processed_onion_packet.next_packet, blinding, shared_secret)

def submit_onion_message(self, *, payload: dict, node_id_or_blinded_path: bytes):
"""Add onion message to queue for sending. Queued onion message payloads
are supplied with a path_id and a reply_path to determine which request
corresponds with arriving replies. """
self.logger.debug('submit_onion_message')
key = os.urandom(16)
state = {
'submitted': now(),
'last_attempt': 0
}
self.pending[key] = (state, payload, node_id_or_blinded_path)
return key

async def send_pending_onion_message(self, key):
"""adds reply_path to payload"""
self.logger.debug('send_pending_onion_message')

state, payload, node_id_or_blinded_path = self.pending[key]

# TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed
final_recipient_data = {
'path_id': {'data': b'electrum_invreq' + key}
}

# TODO: decide blinded path introduction point (for now, just my own nodeid)
# Note: blinded path session_key != onion message session_key
rbp_session_key = os.urandom(32)
reply_path_nodes = [self.lnwallet.node_keypair.pubkey]
reply_path = create_blinded_path(rbp_session_key, reply_path_nodes, final_recipient_data)

final_payload = copy.deepcopy(payload)
final_payload['reply_path'] = {'path': reply_path}

send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload)

0 comments on commit 7bb7596

Please sign in to comment.