Skip to content

Commit

Permalink
add onion message support
Browse files Browse the repository at this point in the history
  • Loading branch information
accumulator committed Nov 22, 2024
1 parent 8a28239 commit c044bcd
Show file tree
Hide file tree
Showing 10 changed files with 1,386 additions and 33 deletions.
73 changes: 68 additions & 5 deletions electrum/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import io
import sys
import datetime
import copy
Expand All @@ -44,9 +44,12 @@

from . import util
from . import keystore
from .util import (bfh, format_satoshis, json_decode, json_normalize,
is_hash256_str, is_hex_str, to_bytes, parse_max_spend, to_decimal,
UserFacingException, InvalidPassword)
from .lnmsg import OnionWireSerializer
from .logging import Logger
from .onion_message import create_blinded_path, send_onion_message_to
from .util import (bfh, format_satoshis, json_decode, json_normalize, is_hash256_str, is_hex_str, to_bytes,
parse_max_spend, to_decimal, UserFacingException, InvalidPassword)

from . import bitcoin
from .bitcoin import is_address, hash_160, COIN
from .bip32 import BIP32Node
Expand Down Expand Up @@ -173,11 +176,12 @@ async def func_wrapper(*args, **kwargs):
return decorator


class Commands:
class Commands(Logger):

def __init__(self, *, config: 'SimpleConfig',
network: 'Network' = None,
daemon: 'Daemon' = None, callback=None):
Logger.__init__(self)
self.config = config
self.daemon = daemon
self.network = network
Expand Down Expand Up @@ -1413,6 +1417,62 @@ async def convert_currency(self, from_amount=1, from_ccy = '', to_ccy = ''):
"source": self.daemon.fx.exchange.name(),
}

@command('wnl')
async def send_onion_message(self, node_id_or_blinded_path_hex: str, message: str, wallet: Abstract_Wallet = None):
"""
Send an onion message with onionmsg_tlv.message payload to node_id.
"""
assert wallet
assert wallet.lnworker
assert node_id_or_blinded_path_hex
assert message

node_id_or_blinded_path = bfh(node_id_or_blinded_path_hex)
assert len(node_id_or_blinded_path) >= 33

destination_payload = {
'message': {'text': message.encode('utf-8')}
}

try:
await send_onion_message_to(wallet.lnworker, node_id_or_blinded_path, destination_payload)
return {'success': True}
except Exception as e:
msg = str(e)

return {
'success': False,
'msg': msg
}

@command('wnl')
async def get_blinded_path_via(self, node_id: str, dummy_hops: int = 0, wallet: Abstract_Wallet = None):
"""
Create a blinded path with node_id as introduction point. Introduction point must be direct peer of me.
"""
# TODO: allow introduction_point to not be a direct peer and construct a route
assert wallet
assert node_id

pubkey = bfh(node_id)
assert len(pubkey) == 33, 'invalid node_id'

peer = wallet.lnworker.peers[pubkey]
assert peer, 'node_id not a peer'

path = [pubkey, wallet.lnworker.node_keypair.pubkey]
session_key = os.urandom(32)
blinded_path = create_blinded_path(session_key, path=path, final_recipient_data={}, dummy_hops=dummy_hops)

with io.BytesIO() as blinded_path_fd:
OnionWireSerializer._write_complex_field(fd=blinded_path_fd,
field_type='blinded_path',
count=1,
value=blinded_path)
encoded_blinded_path = blinded_path_fd.getvalue()

return encoded_blinded_path.hex()


def eval_bool(x: str) -> bool:
if x == 'false': return False
Expand Down Expand Up @@ -1440,6 +1500,7 @@ def eval_bool(x: str) -> bool:
'redeem_script': 'redeem script (hexadecimal)',
'lightning_amount': "Amount sent or received in a submarine swap. Set it to 'dryrun' to receive a value",
'onchain_amount': "Amount sent or received in a submarine swap. Set it to 'dryrun' to receive a value",
'node_id': "Node pubkey in hex format"
}

command_options = {
Expand Down Expand Up @@ -1494,6 +1555,7 @@ def eval_bool(x: str) -> bool:
'from_ccy': (None, "Currency to convert from"),
'to_ccy': (None, "Currency to convert to"),
'public': (None, 'Channel will be announced'),
'dummy_hops': (None, 'Number of dummy hops to add'),
}


Expand All @@ -1519,6 +1581,7 @@ def eval_bool(x: str) -> bool:
'encrypt_file': eval_bool,
'rbf': eval_bool,
'timeout': float,
'dummy_hops': int,
}

config_variables = {
Expand Down
99 changes: 79 additions & 20 deletions electrum/lnonion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import electrum_ecc as ecc

from .crypto import sha256, hmac_oneshot, chacha20_encrypt, get_ecdh
from .crypto import sha256, hmac_oneshot, chacha20_encrypt, get_ecdh, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
from .util import profiler, xor_bytes, bfh
from .lnutil import (PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH,
NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, OnionFailureCodeMetaFlag)
Expand All @@ -44,20 +44,25 @@
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
TRAMPOLINE_HOPS_DATA_SIZE = 400
PER_HOP_HMAC_SIZE = 32

ONION_MESSAGE_LARGE_SIZE = 32768

class UnsupportedOnionPacketVersion(Exception): pass
class InvalidOnionMac(Exception): pass
class InvalidOnionPubkey(Exception): pass
class InvalidPayloadSize(Exception): pass


class OnionHopsDataSingle: # called HopData in lnd

def __init__(self, *, payload: dict = None):
def __init__(self, *, payload: dict = None, tlv_stream_name: str = 'payload', blind_fields: dict = None):
if payload is None:
payload = {}
self.payload = payload
self.hmac = None
self.tlv_stream_name = tlv_stream_name
if blind_fields is None:
blind_fields = {}
self.blind_fields = blind_fields
self._raw_bytes_payload = None # used in unit tests

def to_bytes(self) -> bytes:
Expand All @@ -69,7 +74,7 @@ def to_bytes(self) -> bytes:
# adding TLV payload. note: legacy hop data format no longer supported.
payload_fd = io.BytesIO()
OnionWireSerializer.write_tlv_stream(fd=payload_fd,
tlv_stream_name="payload",
tlv_stream_name=self.tlv_stream_name,
**self.payload)
payload_bytes = payload_fd.getvalue()
with io.BytesIO() as fd:
Expand All @@ -79,7 +84,7 @@ def to_bytes(self) -> bytes:
return fd.getvalue()

@classmethod
def from_fd(cls, fd: io.BytesIO) -> 'OnionHopsDataSingle':
def from_fd(cls, fd: io.BytesIO, *, tlv_stream_name: str = 'payload') -> 'OnionHopsDataSingle':
first_byte = fd.read(1)
if len(first_byte) == 0:
raise Exception(f"unexpected EOF")
Expand All @@ -95,9 +100,9 @@ def from_fd(cls, fd: io.BytesIO) -> 'OnionHopsDataSingle':
hop_payload = fd.read(hop_payload_length)
if hop_payload_length != len(hop_payload):
raise Exception(f"unexpected EOF")
ret = OnionHopsDataSingle()
ret = OnionHopsDataSingle(tlv_stream_name=tlv_stream_name)
ret.payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
tlv_stream_name="payload")
tlv_stream_name=tlv_stream_name)
ret.hmac = fd.read(PER_HOP_HMAC_SIZE)
assert len(ret.hmac) == PER_HOP_HMAC_SIZE
return ret
Expand All @@ -110,7 +115,7 @@ class OnionPacket:

def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes):
assert len(public_key) == 33
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
assert len(hmac) == PER_HOP_HMAC_SIZE
self.version = 0
self.public_key = public_key
Expand All @@ -127,13 +132,13 @@ def to_bytes(self) -> bytes:
ret += self.public_key
ret += self.hops_data
ret += self.hmac
if len(ret) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]:
if len(ret) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
raise Exception('unexpected length {}'.format(len(ret)))
return ret

@classmethod
def from_bytes(cls, b: bytes):
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]:
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
raise Exception('unexpected length {}'.format(len(b)))
version = b[0]
if version != 0:
Expand All @@ -146,42 +151,70 @@ def from_bytes(cls, b: bytes):


def get_bolt04_onion_key(key_type: bytes, secret: bytes) -> bytes:
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad'):
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad', b'blinded_node_id'):
raise Exception('invalid key_type {}'.format(key_type))
key = hmac_oneshot(key_type, msg=secret, digest=hashlib.sha256)
return key


def get_shared_secrets_along_route(payment_path_pubkeys: Sequence[bytes],
session_key: bytes) -> Sequence[bytes]:
session_key: bytes) -> Tuple[Sequence[bytes], Sequence[bytes]]:
num_hops = len(payment_path_pubkeys)
hop_shared_secrets = num_hops * [b'']
hop_blinded_node_ids = num_hops * [b'']
ephemeral_key = session_key
# compute shared key for each hop
for i in range(0, num_hops):
hop_shared_secrets[i] = get_ecdh(ephemeral_key, payment_path_pubkeys[i])
hop_blinded_node_ids[i] = get_blinded_node_id(payment_path_pubkeys[i], hop_shared_secrets[i])
ephemeral_pubkey = ecc.ECPrivkey(ephemeral_key).get_public_key_bytes()
blinding_factor = sha256(ephemeral_pubkey + hop_shared_secrets[i])
blinding_factor_int = int.from_bytes(blinding_factor, byteorder="big")
ephemeral_key_int = int.from_bytes(ephemeral_key, byteorder="big")
ephemeral_key_int = ephemeral_key_int * blinding_factor_int % ecc.CURVE_ORDER
ephemeral_key = ephemeral_key_int.to_bytes(32, byteorder="big")
return hop_shared_secrets
return hop_shared_secrets, hop_blinded_node_ids


def get_blinded_node_id(node_id: bytes, shared_secret: bytes):
# blinded node id
# B(i) = HMAC256("blinded_node_id", ss(i)) * N(i)
ss_bni_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret)
ss_bni_hmac_int = int.from_bytes(ss_bni_hmac, byteorder="big")
blinded_node_id = ecc.ECPubkey(node_id) * ss_bni_hmac_int
return blinded_node_id.get_public_key_bytes()


def new_onion_packet(
payment_path_pubkeys: Sequence[bytes],
session_key: bytes,
hops_data: Sequence[OnionHopsDataSingle],
*,
associated_data: bytes,
associated_data: bytes = b'',
trampoline: bool = False,
onion_message: bool = False
) -> OnionPacket:
num_hops = len(payment_path_pubkeys)
assert num_hops == len(hops_data)
hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key)

payload_size = 0
for i in range(num_hops):
# FIXME: serializing here and again below. cache bytes in OnionHopsDataSingle? _raw_bytes_payload?
payload_size += PER_HOP_HMAC_SIZE + len(hops_data[i].to_bytes())
if trampoline:
data_size = TRAMPOLINE_HOPS_DATA_SIZE
elif onion_message:
if payload_size <= HOPS_DATA_SIZE:
data_size = HOPS_DATA_SIZE
else:
data_size = ONION_MESSAGE_LARGE_SIZE
else:
data_size = HOPS_DATA_SIZE

if payload_size > data_size:
raise InvalidPayloadSize(f'payload too big for onion packet (max={data_size}, required={payload_size})')

data_size = TRAMPOLINE_HOPS_DATA_SIZE if trampoline else HOPS_DATA_SIZE
filler = _generate_filler(b'rho', hops_data, hop_shared_secrets, data_size)
next_hmac = bytes(PER_HOP_HMAC_SIZE)

Expand Down Expand Up @@ -211,6 +244,30 @@ def new_onion_packet(
hmac=next_hmac)


def encrypt_onionmsg_data_tlv(*, shared_secret, **kwargs):
rho_key = get_bolt04_onion_key(b'rho', shared_secret)
with io.BytesIO() as encrypted_data_tlv_fd:
OnionWireSerializer.write_tlv_stream(
fd=encrypted_data_tlv_fd,
tlv_stream_name='encrypted_data_tlv',
**kwargs)
encrypted_data_tlv_bytes = encrypted_data_tlv_fd.getvalue()
encrypted_recipient_data = chacha20_poly1305_encrypt(
key=rho_key, nonce=bytes(12),
data=encrypted_data_tlv_bytes)
return encrypted_recipient_data


def decrypt_onionmsg_data_tlv(*, shared_secret: bytes, encrypted_recipient_data: bytes) -> dict:
rho_key = get_bolt04_onion_key(b'rho', shared_secret)
recipient_data_bytes = chacha20_poly1305_decrypt(key=rho_key, nonce=bytes(12), data=encrypted_recipient_data)

with io.BytesIO(recipient_data_bytes) as fd:
recipient_data = OnionWireSerializer.read_tlv_stream(fd=fd, tlv_stream_name='encrypted_data_tlv')

return recipient_data


def calc_hops_data_for_payment(
route: 'LNPaymentRoute',
amount_msat: int, # that final recipient receives
Expand Down Expand Up @@ -299,9 +356,11 @@ class ProcessedOnionPacket(NamedTuple):
# TODO replay protection
def process_onion_packet(
onion_packet: OnionPacket,
associated_data: bytes,
our_onion_private_key: bytes,
is_trampoline=False) -> ProcessedOnionPacket:
*,
associated_data: bytes = b'',
is_trampoline=False,
tlv_stream_name='payload') -> ProcessedOnionPacket:
if not ecc.ECPubkey.is_pubkey_bytes(onion_packet.public_key):
raise InvalidOnionPubkey()
shared_secret = get_ecdh(our_onion_private_key, onion_packet.public_key)
Expand All @@ -319,7 +378,7 @@ def process_onion_packet(
padded_header = onion_packet.hops_data + bytes(data_size)
next_hops_data = xor_bytes(padded_header, stream_bytes)
next_hops_data_fd = io.BytesIO(next_hops_data)
hop_data = OnionHopsDataSingle.from_fd(next_hops_data_fd)
hop_data = OnionHopsDataSingle.from_fd(next_hops_data_fd, tlv_stream_name=tlv_stream_name)
# trampoline
trampoline_onion_packet = hop_data.payload.get('trampoline_onion_packet')
if trampoline_onion_packet:
Expand Down Expand Up @@ -427,7 +486,7 @@ def _decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[byte
session_key: bytes) -> Tuple[bytes, int]:
"""Returns the decoded error bytes, and the index of the sender of the error."""
num_hops = len(payment_path_pubkeys)
hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
for i in range(num_hops):
ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i])
um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i])
Expand Down
10 changes: 7 additions & 3 deletions electrum/lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from .transaction import PartialTxOutput, match_script_against_template, Sighash
from .logging import Logger
from .lnrouter import RouteEdge
from .lnonion import (new_onion_packet, OnionFailureCode, calc_hops_data_for_payment,
process_onion_packet, OnionPacket, construct_onion_error, obfuscate_onion_error, OnionRoutingFailure,
from .lnonion import (new_onion_packet, OnionFailureCode, calc_hops_data_for_payment, process_onion_packet,
OnionPacket, construct_onion_error, obfuscate_onion_error, OnionRoutingFailure,
ProcessedOnionPacket, UnsupportedOnionPacketVersion, InvalidOnionMac, InvalidOnionPubkey,
OnionFailureCodeMetaFlag)
from .lnchannel import Channel, RevokeAndAck, RemoteCtnTooFarInFuture, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL
Expand Down Expand Up @@ -2846,8 +2846,8 @@ def process_onion_packet(
try:
processed_onion = process_onion_packet(
onion_packet,
associated_data=payment_hash,
our_onion_private_key=self.privkey,
associated_data=payment_hash,
is_trampoline=is_trampoline)
except UnsupportedOnionPacketVersion:
raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data)
Expand All @@ -2863,3 +2863,7 @@ def process_onion_packet(
if self.network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE:
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
return processed_onion

def on_onion_message(self, payload):
if hasattr(self.lnworker, 'onion_message_manager'): # only on LNWallet
self.lnworker.onion_message_manager.on_onion_message(payload)
Loading

0 comments on commit c044bcd

Please sign in to comment.