From 0f314d1dd965424d898f4d8c564449ee42c21e97 Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 9 Sep 2025 13:10:17 +0200 Subject: [PATCH 01/17] lnpeer/lnworker: refactor htlc_switch refactor `htlc_switch` to new architecture to make it more robust against partial settlement of htlc sets and increase maintainability. Htlcs are now processed in two steps, first the htlcs are collected into sets from the channels, and potentially failed on their own already. Then a second loop iterates over the htlc sets and finalizes only on whole sets. # Conflicts: # electrum/lnpeer.py --- electrum/commands.py | 2 +- electrum/lnchannel.py | 7 +- electrum/lnonion.py | 60 ++- electrum/lnpeer.py | 860 ++++++++++++++++++++++++------------ electrum/lnsweep.py | 4 + electrum/lnutil.py | 80 +++- electrum/lnworker.py | 331 +++++++++----- electrum/submarine_swaps.py | 15 +- electrum/wallet_db.py | 107 ++++- tests/test_commands.py | 10 +- tests/test_lnpeer.py | 14 +- 11 files changed, 1040 insertions(+), 450 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 94d0d192b461..f249f55f9396 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1508,7 +1508,7 @@ async def check_hold_invoice(self, payment_hash: str, wallet: Abstract_Wallet = payment_key: str = wallet.lnworker._get_payment_key(bfh(payment_hash)).hex() htlc_status = wallet.lnworker.received_mpp_htlcs[payment_key] result["closest_htlc_expiry_height"] = min( - htlc.cltv_abs for _, htlc in htlc_status.htlc_set + mpp_htlc.htlc.cltv_abs for mpp_htlc in htlc_status.htlcs ) elif wallet.lnworker.get_preimage_hex(payment_hash) is not None \ and payment_hash not in wallet.lnworker.dont_settle_htlcs: diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index dd5fedaca18b..36c2dad7f433 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -783,8 +783,8 @@ def __init__(self, state: 'StoredDict', *, name=None, lnworker=None, initial_fee self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) - self.unfulfilled_htlcs = state["unfulfilled_htlcs"] # type: Dict[int, Tuple[str, Optional[str]]] - # ^ htlc_id -> onion_packet_hex, forwarding_key + self.unfulfilled_htlcs = state["unfulfilled_htlcs"] # type: Dict[int, Optional[str]] + # ^ htlc_id -> onion_packet_hex self._state = ChannelState[state['state']] self.peer_state = PeerState.DISCONNECTED self._outgoing_channel_update = None # type: Optional[bytes] @@ -1112,6 +1112,7 @@ def _assert_can_add_htlc(self, *, htlc_proposer: HTLCOwner, amount_msat: int, if amount_msat <= 0: raise PaymentFailure("HTLC value must be positive") if amount_msat < chan_config.htlc_minimum_msat: + # todo: for incoming htlcs this could be handled more gracefully with `amount_below_minimum` raise PaymentFailure(f'HTLC value too small: {amount_msat} msat') if self.htlc_slots_left(htlc_proposer) == 0: @@ -1226,7 +1227,7 @@ def receive_htlc(self, htlc: UpdateAddHtlc, onion_packet:bytes = None) -> Update with self.db_lock: self.hm.recv_htlc(htlc) if onion_packet: - self.unfulfilled_htlcs[htlc.htlc_id] = onion_packet.hex(), None + self.unfulfilled_htlcs[htlc.htlc_id] = onion_packet.hex() self.logger.info("receive_htlc") return htlc diff --git a/electrum/lnonion.py b/electrum/lnonion.py index f624627e1bb0..b57c0479211b 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -26,7 +26,8 @@ import io import hashlib from functools import cached_property -from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, Mapping +from typing import (Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, + Mapping, Iterator) from enum import IntEnum from dataclasses import dataclass, field, replace from types import MappingProxyType @@ -485,6 +486,55 @@ def process_onion_packet( return ProcessedOnionPacket(are_we_final, hop_data, next_onion_packet, trampoline_onion_packet) +def compare_trampoline_onions( + trampoline_onions: Iterator[Optional[ProcessedOnionPacket]], + *, + exclude_amt_to_fwd: bool = False, +) -> bool: + """ + compare values of trampoline onions payloads and are_we_final. + If we are receiver of a multi trampoline payment amt_to_fwd can differ between the trampoline + parts of the payment, so it needs to be excluded from the comparison when comparing all trampoline + onions of the whole payment (however it can be compared between the onions in a single trampoline part). + """ + try: + first_onion = next(trampoline_onions) + except StopIteration: + raise ValueError("nothing to compare") + + if first_onion is None: + # we don't support mixed mpp sets of htlcs with trampoline onions and regular non-trampoline htlcs. + # In theory this could happen if a sender e.g. uses trampoline as fallback to deliver + # outstanding mpp parts if local pathfinding wasn't successful for the whole payment, + # resulting in a mixed payment. However, it's not even clear if the spec allows for such a constellation. + return all(onion is None for onion in trampoline_onions) + assert isinstance(first_onion, ProcessedOnionPacket), f"{first_onion=}" + + are_we_final = first_onion.are_we_final + payload = first_onion.hop_data.payload + total_msat = first_onion.total_msat + outgoing_cltv = first_onion.outgoing_cltv_value + payment_secret = first_onion.payment_secret + for onion in trampoline_onions: + if onion is None: + return False + assert isinstance(onion, ProcessedOnionPacket), f"{onion=}" + assert onion.trampoline_onion_packet is None, f"{onion=} cannot have trampoline_onion_packet" + if onion.are_we_final != are_we_final: + return False + if not exclude_amt_to_fwd: + if onion.hop_data.payload != payload: + return False + else: + if onion.total_msat != total_msat: + return False + if onion.outgoing_cltv_value != outgoing_cltv: + return False + if onion.payment_secret != payment_secret: + return False + return True + + class FailedToDecodeOnionError(Exception): pass @@ -521,6 +571,14 @@ def decode_data(self) -> Optional[Dict[str, Any]]: payload = None return payload + def to_wire_msg(self, onion_packet: OnionPacket, privkey: bytes, local_height: int) -> bytes: + onion_error = construct_onion_error(self, onion_packet.public_key, privkey, local_height) + error_bytes = obfuscate_onion_error(onion_error, onion_packet.public_key, privkey) + return error_bytes + + +class OnionParsingError(OnionRoutingFailure): pass + def construct_onion_error( error: OnionRoutingFailure, diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 72db12657bd2..377dd375f226 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -21,6 +21,7 @@ import aiorpcx from aiorpcx import ignore_after +from .lrucache import LRUCache from .crypto import sha256, sha256d, privkey_to_pubkey from . import bitcoin, util from . import constants @@ -31,11 +32,11 @@ from .bitcoin import make_op_return, DummyAddress from .transaction import PartialTxOutput, match_script_against_template, Sighash from .logging import Logger -from .lnrouter import RouteEdge -from .lnonion import (new_onion_packet, OnionFailureCode, calc_hops_data_for_payment, process_onion_packet, - OnionPacket, construct_onion_error, obfuscate_onion_error, OnionRoutingFailure, - ProcessedOnionPacket, UnsupportedOnionPacketVersion, InvalidOnionMac, InvalidOnionPubkey, - OnionFailureCodeMetaFlag) +from . import lnonion +from .lnonion import (OnionFailureCode, OnionPacket, obfuscate_onion_error, + OnionRoutingFailure, ProcessedOnionPacket, UnsupportedOnionPacketVersion, + InvalidOnionMac, InvalidOnionPubkey, OnionFailureCodeMetaFlag, + OnionParsingError) from .lnchannel import Channel, RevokeAndAck, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL from . import lnutil from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConfig, @@ -46,17 +47,15 @@ ln_compare_features, MIN_FINAL_CLTV_DELTA_ACCEPTED, RemoteMisbehaving, ShortChannelID, IncompatibleLightningFeatures, ChannelType, LNProtocolWarning, validate_features, - IncompatibleOrInsaneFeatures, FeeBudgetExceeded, + IncompatibleOrInsaneFeatures, ReceivedMPPStatus, ReceivedMPPHtlc, GossipForwardingMessage, GossipTimestampFilter, channel_id_from_funding_tx, - PaymentFeeBudget, serialize_htlc_key, Keypair, RecvMPPResolution) + serialize_htlc_key, Keypair, RecvMPPResolution) from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg from .interface import GracefulDisconnect -from .lnrouter import fee_for_edge_msat from .json_db import StoredDict from .invoices import PR_PAID from .fee_policy import FEE_LN_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING -from .trampoline import decode_routing_info if TYPE_CHECKING: from .lnworker import LNGossip, LNWallet @@ -132,6 +131,7 @@ def __init__( self.downstream_htlc_resolved_event = asyncio.Event() self.register_callbacks() self._num_gossip_messages_forwarded = 0 + self._processed_onion_cache = LRUCache(maxsize=100) # type: LRUCache[bytes, ProcessedOnionPacket] def send_message(self, message_name: str, **kwargs): assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!" @@ -2136,177 +2136,151 @@ def _check_accepted_final_htlc( return payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd - def check_mpp_is_waiting( - self, - *, - payment_secret: bytes, - short_channel_id: ShortChannelID, + def _check_unfulfilled_htlc( + self, *, + chan: Channel, htlc: UpdateAddHtlc, - expected_msat: int, - exc_incorrect_or_unknown_pd: OnionRoutingFailure, - log_fail_reason: Callable[[str], None], - ) -> bool: - mpp_resolution = self.lnworker.check_mpp_status( - payment_secret=payment_secret, - short_channel_id=short_channel_id, - htlc=htlc, - expected_msat=expected_msat, - ) - if mpp_resolution == RecvMPPResolution.WAITING: - return True - elif mpp_resolution == RecvMPPResolution.EXPIRED: - log_fail_reason(f"MPP_TIMEOUT") - raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') - elif mpp_resolution == RecvMPPResolution.FAILED: - log_fail_reason(f"mpp_resolution is FAILED") - raise exc_incorrect_or_unknown_pd - elif mpp_resolution == RecvMPPResolution.COMPLETE: - return False - else: - raise Exception(f"unexpected {mpp_resolution=}") - - def maybe_fulfill_htlc( - self, *, - chan: Channel, - htlc: UpdateAddHtlc, - processed_onion: ProcessedOnionPacket, - outer_onion_payment_secret: bytes = None, # used to group trampoline htlcs for forwarding - onion_packet_bytes: bytes, - already_forwarded: bool = False, - ) -> Tuple[Optional[bytes], Optional[Tuple[str, Callable[[], Awaitable[Optional[str]]]]]]: + processed_onion: ProcessedOnionPacket, + outer_onion_payment_secret: bytes = None, # used to group trampoline htlcs for forwarding + ) -> str: """ - Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. - Return (preimage, (payment_key, callback)) with at most a single element not None. + Does additional checks on the incoming htlc and return the payment key if the tests pass, + otherwise raises OnionRoutingError which will get the htlc failed. """ - if not processed_onion.are_we_final: - if not self.lnworker.enable_htlc_forwarding: - return None, None - # use the htlc key if we are forwarding - payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) - callback = lambda: self.lnworker.maybe_forward_htlc( - incoming_chan=chan, - htlc=htlc, - processed_onion=processed_onion) - return None, (payment_key, callback) + _log_fail_reason = self._log_htlc_fail_reason_cb(chan.short_channel_id, htlc, processed_onion.hop_data.payload) - def log_fail_reason(reason: str): - self.logger.info( - f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. " - f"{reason}. htlc={str(htlc)}. onion_payload={processed_onion.hop_data.payload}") - - chain = self.network.blockchain() # Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height. # We should not release the preimage for an HTLC that its sender could already time out as # then they might try to force-close and it becomes a race. - if chain.is_tip_stale() and not already_forwarded: - log_fail_reason(f"our chain tip is stale") - raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + chain = self.network.blockchain() local_height = chain.height() + blocks_to_expiry = max(htlc.cltv_abs - local_height, 0) + if chain.is_tip_stale(): + _log_fail_reason(f"our chain tip is stale: {local_height=}") + raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') - # parse parameters and perform checks that are invariant - payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = self._check_accepted_final_htlc( - chan=chan, - htlc=htlc, - processed_onion=processed_onion, - is_trampoline_onion=bool(outer_onion_payment_secret), - log_fail_reason=log_fail_reason) - - # payment key for final onions payment_hash = htlc.payment_hash - payment_key = (payment_hash + payment_secret_from_onion).hex() + if not processed_onion.are_we_final: + if outer_onion_payment_secret: + # this is a trampoline forwarding htlc, multiple incoming trampoline htlcs can be collected + payment_key = (payment_hash + outer_onion_payment_secret).hex() + return payment_key + # this is a regular htlc to forward, it will get its own set of size 1 keyed by htlc_key + # Additional checks required only for forwarding nodes will be done in maybe_forward_htlc(). + payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) + return payment_key - if self.check_mpp_is_waiting( - payment_secret=payment_secret_from_onion, - short_channel_id=chan.get_scid_or_local_alias(), + # parse parameters and perform checks that are invariant + payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = ( + self._check_accepted_final_htlc( + chan=chan, htlc=htlc, - expected_msat=total_msat, - exc_incorrect_or_unknown_pd=exc_incorrect_or_unknown_pd, - log_fail_reason=log_fail_reason, - ): - return None, None - - # TODO check against actual min_final_cltv_expiry_delta from invoice (and give 2-3 blocks of leeway?) - # note: payment_bundles might get split here, e.g. one payment is "already forwarded" and the other is not. - # In practice, for the swap prepayment use case, this does not matter. - if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs and not already_forwarded: - log_fail_reason(f"htlc.cltv_abs is unreasonably close") + processed_onion=processed_onion, + is_trampoline_onion=bool(outer_onion_payment_secret), + log_fail_reason=_log_fail_reason, + )) + # trampoline htlcs of which we are the final receiver will first get grouped by the outer + # onions secret to allow grouping a multi-trampoline mpp in different sets. Once a trampoline + # payment part is completed (sum(htlcs) >= (trampoline-)amt_to_forward), its htlcs get moved into + # the htlc set representing the whole payment (payment key derived from trampoline/invoice secret). + payment_key = (payment_hash + (outer_onion_payment_secret or payment_secret_from_onion)).hex() + + if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: + # this check should be done here for new htlcs and ongoing on pending sets. + # Here it is done so that invalid received htlcs will never get added to a set, + # so the set still has a chance to succeed until mpp timeout. + _log_fail_reason(f"htlc.cltv_abs is unreasonably close: {htlc.cltv_abs=}, {local_height=}") raise exc_incorrect_or_unknown_pd - # detect callback - # if there is a trampoline_onion, maybe_fulfill_htlc will be called again - # order is important: if we receive a trampoline onion for a hold invoice, we need to peel the onion first. - + # extract trampoline if processed_onion.trampoline_onion_packet: - # TODO: we should check that all trampoline_onions are the same - trampoline_onion = self.process_onion_packet( + trampoline_onion = self._process_incoming_onion_packet( processed_onion.trampoline_onion_packet, payment_hash=payment_hash, - onion_packet_bytes=onion_packet_bytes, is_trampoline=True) + + # compare trampoline onion against outer onion according to: + # https://github.com/lightning/bolts/blob/9938ab3d6160a3ba91f3b0e132858ab14bfe4f81/04-onion-routing.md?plain=1#L547-L553 if trampoline_onion.are_we_final: - # trampoline- we are final recipient of HTLC - # note: the returned payment_key will contain the inner payment_secret - return self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=trampoline_onion, - outer_onion_payment_secret=payment_secret_from_onion, - onion_packet_bytes=onion_packet_bytes, - already_forwarded=already_forwarded, - ) - else: - callback = lambda: self.lnworker.maybe_forward_trampoline( - payment_hash=payment_hash, - inc_cltv_abs=htlc.cltv_abs, # TODO: use max or enforce same value across mpp parts - outer_onion=processed_onion, - trampoline_onion=trampoline_onion, - fw_payment_key=payment_key) - return None, (payment_key, callback) - - # TODO don't accept payments twice for same invoice - # note: we don't check invoice expiry (bolt11 'x' field) on the receiver-side. - # - semantics are weird: would make sense for simple-payment-receives, but not - # if htlc is expected to be pending for a while, e.g. for a hold-invoice. + try: + assert not processed_onion.outgoing_cltv_value < trampoline_onion.outgoing_cltv_value + is_mpp = processed_onion.total_msat > processed_onion.amt_to_forward + if is_mpp: + assert not processed_onion.total_msat < trampoline_onion.amt_to_forward + else: + assert not processed_onion.amt_to_forward < trampoline_onion.amt_to_forward + except AssertionError: + _log_fail_reason(f'incorrect trampoline onion {processed_onion=}\n{trampoline_onion=}') + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') + + return self._check_unfulfilled_htlc( + chan=chan, + htlc=htlc, + processed_onion=trampoline_onion, + outer_onion_payment_secret=payment_secret_from_onion, + ) + info = self.lnworker.get_payment_info(payment_hash) if info is None: - log_fail_reason(f"no payment_info found for RHASH {htlc.payment_hash.hex()}") + _log_fail_reason(f"no payment_info found for RHASH {payment_hash.hex()}") + raise exc_incorrect_or_unknown_pd + elif info.status == PR_PAID: + _log_fail_reason(f"invoice already paid: {payment_hash.hex()=}") + raise exc_incorrect_or_unknown_pd + elif blocks_to_expiry < info.min_final_cltv_delta: + _log_fail_reason( + f"min final cltv delta lower than requested: " + f"{payment_hash.hex()=} {htlc.cltv_abs=} {blocks_to_expiry=}" + ) + raise exc_incorrect_or_unknown_pd + elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts + _log_fail_reason(f"not accepting htlc for expired invoice") raise exc_incorrect_or_unknown_pd - preimage = self.lnworker.get_preimage(payment_hash) - expected_payment_secret = self.lnworker.get_payment_secret(htlc.payment_hash) - if payment_secret_from_onion != expected_payment_secret: - log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secret.hex()}') + expected_payment_secret = self.lnworker.get_payment_secret(payment_hash) + if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret): + _log_fail_reason(f'incorrect payment secret: {payment_secret_from_onion.hex()=}') raise exc_incorrect_or_unknown_pd + invoice_msat = info.amount_msat if channel_opening_fee: + # deduct just-in-time channel fees from invoice amount invoice_msat -= channel_opening_fee if not (invoice_msat is None or invoice_msat <= total_msat <= 2 * invoice_msat): - log_fail_reason(f"total_msat={total_msat} too different from invoice_msat={invoice_msat}") + _log_fail_reason(f"{total_msat=} too different from {invoice_msat=}") raise exc_incorrect_or_unknown_pd - hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) - if hold_invoice_callback and not preimage: - callback = lambda: hold_invoice_callback(payment_hash) - return None, (payment_key, callback) - - if payment_hash.hex() in self.lnworker.dont_settle_htlcs: - return None, None - - if not preimage: - if not already_forwarded: - log_fail_reason(f"missing preimage and no hold invoice callback {payment_hash.hex()}") - raise exc_incorrect_or_unknown_pd - else: - return None, None - - chan.opening_fee = None - self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}") - return preimage, None + return payment_key + + def _fulfill_htlc_set(self, payment_key: str, preimage: bytes): + htlc_set = self.lnworker.received_mpp_htlcs[payment_key] + assert len(htlc_set.htlcs) > 0, f"{htlc_set=}" + assert htlc_set.resolution == RecvMPPResolution.SETTLING + assert htlc_set.parent_set_key is None, f"Must not settle child {htlc_set=}" + # get payment hash of any htlc in the set (they are all the same) + payment_hash = htlc_set.get_payment_hash() + assert payment_hash is not None, htlc_set + for mpp_htlc in list(htlc_set.htlcs): + htlc_id = mpp_htlc.htlc.htlc_id + chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + if chan.channel_id not in self.channels: + # this htlc belongs to another peer and has to be settled in their htlc_switch + continue + if not chan.can_update_ctx(proposer=LOCAL): + continue + self.logger.info(f"fulfill htlc: {chan.short_channel_id}. {htlc_id=}. {payment_hash.hex()=}") + if chan.hm.was_htlc_preimage_released(htlc_id=htlc_id, htlc_proposer=REMOTE): + # this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash + self.logger.debug(f"{mpp_htlc=} was already settled before, dropping it.") + htlc_set.htlcs.remove(mpp_htlc) + continue + self._fulfill_htlc(chan, htlc_id, preimage) + htlc_set.htlcs.remove(mpp_htlc) + # reset just-in-time opening fee of channel + chan.opening_fee = None - def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): - self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") - assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" + def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) self.received_htlcs_pending_removal.add((chan, htlc_id)) chan.settle_htlc(preimage, htlc_id) @@ -2316,6 +2290,61 @@ def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): id=htlc_id, payment_preimage=preimage) + def _fail_htlc_set( + self, + payment_key: str, + error_tuple: Tuple[Optional[bytes], Optional[OnionFailureCode | int], Optional[bytes]], + ): + htlc_set = self.lnworker.received_mpp_htlcs[payment_key] + assert htlc_set.resolution in (RecvMPPResolution.FAILED, RecvMPPResolution.EXPIRED) + + raw_error, error_code, error_data = error_tuple + local_height = self.network.blockchain().height() + for mpp_htlc in list(htlc_set.htlcs): + chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + htlc_id = mpp_htlc.htlc.htlc_id + if chan.channel_id not in self.channels: + # this htlc belongs to another peer and has to be settled in their htlc_switch + continue + if not chan.can_update_ctx(proposer=LOCAL): + continue + assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) + if chan.hm.was_htlc_failed(htlc_id=htlc_id, htlc_proposer=REMOTE): + # this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash + self.logger.debug(f"{mpp_htlc=} was already failed before, dropping it.") + htlc_set.htlcs.remove(mpp_htlc) + continue + onion_packet = self._parse_onion_packet(mpp_htlc.unprocessed_onion) + processed_onion_packet = self._process_incoming_onion_packet( + onion_packet, + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=False, + ) + if raw_error: + error_bytes = obfuscate_onion_error(raw_error, onion_packet.public_key, self.privkey) + else: + assert isinstance(error_code, (OnionFailureCode, int)) + if error_code == OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS: + amount_to_forward = processed_onion_packet.amt_to_forward + # if this was a trampoline htlc we use the inner amount_to_forward as this is + # the value known by the sender + if processed_onion_packet.trampoline_onion_packet: + processed_trampoline_onion_packet = self._process_incoming_onion_packet( + processed_onion_packet.trampoline_onion_packet, + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=True, + ) + amount_to_forward = processed_trampoline_onion_packet.amt_to_forward + error_data = amount_to_forward.to_bytes(8, byteorder="big") + e = OnionRoutingFailure(code=error_code, data=error_data or b'') + error_bytes = e.to_wire_msg(onion_packet, self.privkey, local_height) + self.fail_htlc( + chan=chan, + htlc_id=htlc_id, + error_bytes=error_bytes, + ) + htlc_set.htlcs.remove(mpp_htlc) + def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" @@ -2328,7 +2357,7 @@ def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): len=len(error_bytes), reason=error_bytes) - def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure): + def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionParsingError): self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32): @@ -2764,79 +2793,107 @@ async def htlc_switch(self): @util.profiler(min_threshold=0.02) def _run_htlc_switch_iteration(self): self._maybe_cleanup_received_htlcs_pending_removal() - # In this loop, an item of chan.unfulfilled_htlcs may go through 4 stages: - # - 1. not forwarded yet: (None, onion_packet_hex) - # - 2. forwarded: (forwarding_key, onion_packet_hex) - # - 3. processed: (forwarding_key, None), not irrevocably removed yet - # - 4. done: (forwarding_key, None), irrevocably removed + # htlc processing happens in two steps: + # 1. Step: Iterating through all channels and their pending htlcs, doing validation + # feasible for single htlcs (some checks only make sense on the whole mpp set) and + # then collecting these htlcs in a mpp set by payment key. + # HTLCs failing these checks will get failed directly and won't be added to any set. + # No htlcs will get settled in this step, settling only happens on complete mpp sets. + # If a new htlc belongs to a set which has already been failed, the htlc will be failed + # and not added to any set. + # Each htlc is only supposed to go through this first loop once when being received. for chan_id, chan in self.channels.items(): if not chan.can_update_ctx(proposer=LOCAL): continue self.maybe_send_commitment(chan) - done = set() unfulfilled = chan.unfulfilled_htlcs - for htlc_id, (onion_packet_hex, forwarding_key) in unfulfilled.items(): + for htlc_id, onion_packet_hex in list(unfulfilled.items()): if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): continue + htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) - if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): - assert onion_packet_hex is None - self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) - if forwarding_key: - self.lnworker.maybe_cleanup_forwarding(forwarding_key) - done.add(htlc_id) - continue - if onion_packet_hex is None: - # has been processed already + try: + onion_packet = self._parse_onion_packet(onion_packet_hex) + except OnionParsingError as e: + self.fail_malformed_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + reason=e, + ) + del unfulfilled[htlc_id] continue - error_reason = None # type: Optional[OnionRoutingFailure] - error_bytes = None # type: Optional[bytes] - preimage = None - onion_packet_bytes = bytes.fromhex(onion_packet_hex) - onion_packet = None + try: - onion_packet = OnionPacket.from_bytes(onion_packet_bytes) - except OnionRoutingFailure as e: - error_reason = e - else: - try: - preimage, _forwarding_key, error_bytes = self.process_unfulfilled_htlc( - chan=chan, - htlc=htlc, - forwarding_key=forwarding_key, - onion_packet_bytes=onion_packet_bytes, - onion_packet=onion_packet) - if _forwarding_key: - assert forwarding_key is None - unfulfilled[htlc_id] = onion_packet_hex, _forwarding_key - except OnionRoutingFailure as e: - error_bytes = construct_onion_error(e, onion_packet.public_key, self.privkey, self.network.get_local_height()) - if error_bytes: - error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey) - - if preimage or error_reason or error_bytes: - if preimage: - self.lnworker.set_request_status(htlc.payment_hash, PR_PAID) - if not self.lnworker.enable_htlc_settle: - continue - self.fulfill_htlc(chan, htlc.htlc_id, preimage) - elif error_bytes: - self.fail_htlc( - chan=chan, - htlc_id=htlc.htlc_id, - error_bytes=error_bytes) + processed_onion_packet = self._process_incoming_onion_packet( + onion_packet, + payment_hash=htlc.payment_hash, + is_trampoline=False, + ) + payment_key: str = self._check_unfulfilled_htlc( + chan=chan, + htlc=htlc, + processed_onion=processed_onion_packet, + ) + self.lnworker.update_or_create_mpp_with_received_htlc( + payment_key=payment_key, + scid=chan.short_channel_id, + htlc=htlc, + unprocessed_onion_packet=onion_packet_hex, # outer onion if trampoline + ) + except OnionParsingError as e: # could be raised when parsing the inner trampoline onion + self.fail_malformed_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + reason=e, + ) + except Exception as e: + # Fail the htlc directly if it fails to pass these tests, it will not get added to a htlc set. + # https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/04-onion-routing.md?plain=1#L388 + reraise = False + if isinstance(e, OnionRoutingFailure): + orf = e else: - self.fail_malformed_htlc( - chan=chan, - htlc_id=htlc.htlc_id, - reason=error_reason) - # blank onion field to mark it as processed - unfulfilled[htlc_id] = None, forwarding_key - - # cleanup - for htlc_id in done: - unfulfilled.pop(htlc_id) - self.maybe_send_commitment(chan) + orf = OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + reraise = True # propagate this out, as this might suggest a bug + error_bytes = orf.to_wire_msg(onion_packet, self.privkey, self.network.get_local_height()) + self.fail_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + error_bytes=error_bytes, + ) + if reraise: + raise + finally: + del unfulfilled[htlc_id] + + # 2. Step: Acting on sets of htlcs. + # Doing further checks that have to be done on sets of htlcs (e.g. total amount checks) + # and checks that have to be done continuously like checking for timeout. + # A set marked as failed once must never settle any htlcs associated to it. + # The sets are shared between all peers, so each peers htlc_switch acts on the same sets. + for payment_key, htlc_set in list(self.lnworker.received_mpp_htlcs.items()): + any_error, preimage, callback = self._check_unfulfilled_htlc_set(payment_key, htlc_set) + assert bool(any_error) + bool(preimage) + bool(callback) <= 1, \ + f"{any_error=}, {bool(preimage)=}, {callback=}" + if any_error: + error_tuple = self.lnworker.set_htlc_set_error(payment_key, any_error) + self._fail_htlc_set(payment_key, error_tuple) + if preimage: + if self.lnworker.enable_htlc_settle: + self.lnworker.set_request_status(htlc_set.get_payment_hash(), PR_PAID) + self._fulfill_htlc_set(payment_key, preimage) + if callback: + task = asyncio.create_task(callback()) + task.add_done_callback( # log exceptions occurring in callback + lambda t, pk=payment_key: self.logger.exception( + f"cb failed: " + f"{self.lnworker.received_mpp_htlcs[pk]=}", exc_info=t.exception()) if t.exception() else None + ) + + if len(self.lnworker.received_mpp_htlcs[payment_key].htlcs) == 0: + self.logger.debug(f"deleting resolved mpp set: {payment_key=}") + del self.lnworker.received_mpp_htlcs[payment_key] + self.lnworker.maybe_cleanup_forwarding(payment_key) def _maybe_cleanup_received_htlcs_pending_removal(self) -> None: done = set() @@ -2861,107 +2918,332 @@ async def htlc_switch_iteration(): await group.spawn(htlc_switch_iteration()) await group.spawn(self.got_disconnected.wait()) - def process_unfulfilled_htlc( - self, *, - chan: Channel, - htlc: UpdateAddHtlc, - forwarding_key: Optional[str], - onion_packet_bytes: bytes, - onion_packet: OnionPacket) -> Tuple[Optional[bytes], Optional[str], Optional[bytes]]: + def _log_htlc_fail_reason_cb( + self, + scid: ShortChannelID, + htlc: UpdateAddHtlc, + onion_payload: dict + ) -> Callable[[str], None]: + def _log_fail_reason(reason: str) -> None: + self.logger.info(f"will FAIL HTLC: {str(scid)=}. {reason=}. {str(htlc)=}. {onion_payload=}") + return _log_fail_reason + + def _log_htlc_set_fail_reason_cb(self, mpp_set: ReceivedMPPStatus) -> Callable[[str], None]: + def log_fail_reason(reason: str): + for mpp_htlc in mpp_set.htlcs: + try: + processed_onion = self._process_incoming_onion_packet( + onion_packet=self._parse_onion_packet(mpp_htlc.unprocessed_onion), + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=False, + ) + onion_payload = processed_onion.hop_data.payload + except Exception: + onion_payload = {} + + self._log_htlc_fail_reason_cb( + mpp_htlc.scid, + mpp_htlc.htlc, + onion_payload, + )(f"mpp set {id(mpp_set)} failed: {reason}") + + return log_fail_reason + + def _check_unfulfilled_htlc_set( + self, + payment_key: str, + mpp_set: ReceivedMPPStatus + ) -> Tuple[ + Optional[Union[OnionRoutingFailure, OnionFailureCode, bytes]], # error types used to fail the set + Optional[bytes], # preimage to settle the set + Optional[Callable[[], Awaitable[None]]], # callback + ]: """ - return (preimage, payment_key, error_bytes) with at most a single element that is not None - raise an OnionRoutingFailure if we need to fail the htlc + Returns what to do next with the given set of htlcs: + * Fail whole set -> returns error code + * Settle whole set -> Returns preimage + * call callback (e.g. forwarding, hold invoice) + May modify the mpp set in lnworker.received_mpp_htlcs (e.g. by setting its resolution to COMPLETE). """ - payment_hash = htlc.payment_hash - processed_onion = self.process_onion_packet( - onion_packet, - payment_hash=payment_hash, - onion_packet_bytes=onion_packet_bytes) + _log_fail_reason = self._log_htlc_set_fail_reason_cb(mpp_set) - preimage, forwarding_info = self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=processed_onion, - onion_packet_bytes=onion_packet_bytes, - already_forwarded=bool(forwarding_key)) - - if not forwarding_key: - if forwarding_info: - # HTLC we are supposed to forward, but haven't forwarded yet - payment_key, forwarding_callback = forwarding_info + if (final_state := self._check_final_mpp_set_state(payment_key, mpp_set)) is not None: + return final_state + + assert mpp_set.resolution in (RecvMPPResolution.WAITING, RecvMPPResolution.COMPLETE) + chain = self.network.blockchain() + local_height = chain.height() + if chain.is_tip_stale(): + _log_fail_reason(f"our chain tip is stale: {local_height=}") + return OnionFailureCode.TEMPORARY_NODE_FAILURE, None, None + + amount_msat: int = 0 # sum(amount_msat of each htlc) + total_msat = None # type: Optional[int] + payment_hash = mpp_set.get_payment_hash() + closest_cltv_abs = mpp_set.get_closest_cltv_abs() + first_htlc_timestamp = mpp_set.get_first_htlc_timestamp() + processed_onions = {} # type: dict[ReceivedMPPHtlc, Tuple[ProcessedOnionPacket, Optional[ProcessedOnionPacket]]] + for mpp_htlc in mpp_set.htlcs: + processed_onion = self._process_incoming_onion_packet( + onion_packet=self._parse_onion_packet(mpp_htlc.unprocessed_onion), + payment_hash=payment_hash, + is_trampoline=False, # this is always the outer onion + ) + processed_onions[mpp_htlc] = (processed_onion, None) + inner_onion = None + if processed_onion.trampoline_onion_packet: + inner_onion = self._process_incoming_onion_packet( + onion_packet=processed_onion.trampoline_onion_packet, + payment_hash=payment_hash, + is_trampoline=True, + ) + processed_onions[mpp_htlc] = (processed_onion, inner_onion) + + total_msat_outer_onion = processed_onion.total_msat + total_msat_inner_onion = inner_onion.total_msat if inner_onion else None + if total_msat is None: + total_msat = total_msat_inner_onion or total_msat_outer_onion + + # check total_msat is equal for all htlcs of the set + if total_msat != (total_msat_inner_onion or total_msat_outer_onion): + _log_fail_reason(f"total_msat is not uniform: {total_msat=} != {processed_onion.total_msat=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + amount_msat += mpp_htlc.htlc.amount_msat + + # If the set contains outer onions with different payment secrets, the set's payment_key is + # derived from the trampoline/invoice/inner payment secret, so it is the second stage of a + # multi-trampoline payment in which all the trampoline parts/htlcs got combined. + # In this case the amt_to_forward cannot be compared as it may differ between the trampoline parts. + # However, amt_to_forward should be similar for all onions of a single trampoline part and gets + # compared in the first stage where the htlc set represents a single trampoline part. + outer_onions = [onions[0] for onions in processed_onions.values()] + can_have_different_amt_to_fwd = not all(o.payment_secret == outer_onions[0].payment_secret for o in outer_onions) + trampoline_onions = iter(onions[1] for onions in processed_onions.values()) + if not lnonion.compare_trampoline_onions(trampoline_onions, exclude_amt_to_fwd=can_have_different_amt_to_fwd): + _log_fail_reason(f"got inconsistent {trampoline_onions=}") + return OnionFailureCode.INVALID_ONION_PAYLOAD, None, None + + if len(processed_onions) == 1: + outer_onion, inner_onion = next(iter(processed_onions.values())) + if not outer_onion.are_we_final: + assert inner_onion is None, f"{outer_onion=}\n{inner_onion=}" if not self.lnworker.enable_htlc_forwarding: return None, None, None - if payment_key not in self.lnworker.active_forwardings: - async def wrapped_callback(): - forwarding_coro = forwarding_callback() - try: - next_htlc = await forwarding_coro - if next_htlc: - htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) - self.lnworker.active_forwardings[payment_key].append(next_htlc) - self.lnworker.downstream_to_upstream_htlc[next_htlc] = htlc_key - except OnionRoutingFailure as e: - if len(self.lnworker.active_forwardings[payment_key]) == 0: - self.lnworker.save_forwarding_failure(payment_key, failure_message=e) - # TODO what about other errors? e.g. TxBroadcastError for a swap. - # - malicious electrum server could fake TxBroadcastError - # Could we "catch-all Exception" and fail back the htlcs with e.g. TEMPORARY_NODE_FAILURE? - # - we don't want to fail the inc-HTLC for a syntax error that happens in the callback - # If we don't call save_forwarding_failure(), the inc-HTLC gets stuck until expiry - # and then the inc-channel will get force-closed. - # => forwarding_callback() could have an API with two exceptions types: - # - type1, such as OnionRoutingFailure, that signals we need to fail back the inc-HTLC - # - type2, such as TxBroadcastError, that signals we want to retry the callback - # add to list - assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0 - self.lnworker.active_forwardings[payment_key] = [] - fut = asyncio.ensure_future(wrapped_callback()) - # return payment_key so this branch will not be executed again - return None, payment_key, None - elif preimage: - return preimage, None, None - else: - # we are waiting for mpp consolidation or preimage + # this is a single (non-trampoline) htlc set which needs to be forwarded. + # set to settling state so it will not be failed or forwarded twice. + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + fwd_cb = lambda: self.lnworker.maybe_forward_htlc_set(payment_key, processed_htlc_set=processed_onions) + return None, None, fwd_cb + + assert payment_hash is not None and total_msat is not None + # check for expiry over time and potentially fail the whole set if any + # htlc's cltv becomes too close + blocks_to_expiry = max(0, closest_cltv_abs - local_height) + if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: + _log_fail_reason(f"htlc.cltv_abs is unreasonably close") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + # check for mpp expiry (if incomplete and expired -> fail) + if mpp_set.resolution == RecvMPPResolution.WAITING \ + or not self.lnworker.is_payment_bundle_complete(payment_key): + # maybe this set is COMPLETE but the bundle is not yet completed, so the bundle can be considered WAITING + if int(time.time()) - first_htlc_timestamp > self.lnworker.MPP_EXPIRY \ + or self.lnworker.stopping_soon: + _log_fail_reason(f"MPP TIMEOUT (> {self.lnworker.MPP_EXPIRY} sec)") + return OnionFailureCode.MPP_TIMEOUT, None, None + + if mpp_set.resolution == RecvMPPResolution.WAITING: + # check if set is first stage multi-trampoline payment to us + # first stage trampoline payment: + # is a trampoline payment + we_are_final + payment key is derived from outer onion's payment secret + # (so it is not the payment secret we requested in the invoice, but some secret set by a + # trampoline forwarding node on the route). + # if it is first stage, check if sum(htlcs) >= amount_to_forward of the trampoline_payload. + # If this part is complete, move the htlcs to the overall mpp set of the payment (keyed by inner secret). + # Once the second stage set (the set containing all htlcs of the separate trampoline parts) + # is complete, the payment gets fulfilled. + trampoline_payment_key = None + any_trampoline_onion = next(iter(processed_onions.values()))[1] + if any_trampoline_onion and any_trampoline_onion.are_we_final: + trampoline_payment_secret = any_trampoline_onion.payment_secret + assert trampoline_payment_secret == self.lnworker.get_payment_secret(payment_hash) + trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex() + + if trampoline_payment_key and trampoline_payment_key != payment_key: + # first stage of trampoline payment, the first stage must never get set COMPLETE + if amount_msat >= any_trampoline_onion.amt_to_forward: + # setting the parent key will mark the htlcs to be moved to the parent set + self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, " + f"{amount_msat=}. setting parent key: {trampoline_payment_key}") + self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace( + parent_set_key=trampoline_payment_key, + ) + elif amount_msat >= total_msat: + # set mpp_set as completed as we have received the full total_msat + mpp_set = self.lnworker.set_mpp_resolution( + payment_key=payment_key, + new_resolution=RecvMPPResolution.COMPLETE, + ) + + # check if this set is a trampoline forwarding and potentially return forwarding callback + # note: all inner trampoline onions are equal (enforced above) + _, any_inner_onion = next(iter(processed_onions.values())) + if any_inner_onion and not any_inner_onion.are_we_final: + # this is a trampoline forwarding + can_forward = mpp_set.resolution == RecvMPPResolution.COMPLETE and self.lnworker.enable_htlc_forwarding + if not can_forward: return None, None, None + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + fwd_cb = lambda: self.lnworker.maybe_forward_htlc_set(payment_key, processed_htlc_set=processed_onions) + return None, None, fwd_cb + + # -- from here on it's assumed this set is a payment for us (not something to forward) -- + payment_info = self.lnworker.get_payment_info(payment_hash) + if payment_info is None: + _log_fail_reason(f"payment info has been deleted") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + # check invoice expiry, fail set if the invoice has expired before it was completed + if mpp_set.resolution == RecvMPPResolution.WAITING: + if int(time.time()) > payment_info.expiration_ts: + _log_fail_reason(f"invoice is expired {payment_info.expiration_ts=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + return None, None, None + + if payment_hash.hex() in self.lnworker.dont_settle_htlcs: + # used by hold invoice cli to prevent the htlcs from getting fulfilled automatically + return None, None, None + + preimage = self.lnworker.get_preimage(payment_hash) + hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) + if not preimage and not hold_invoice_callback: + _log_fail_reason(f"cannot settle, no preimage or callback found for {payment_hash.hex()=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + if not self.lnworker.is_payment_bundle_complete(payment_key): + # don't allow settling before all sets of the bundle are COMPLETE + return None, None, None else: - # HTLC we are supposed to forward, and have already forwarded - # for final trampoline onions, forwarding failures are stored with forwarding_key (which is the inner key) - payment_key = forwarding_key - preimage = self.lnworker.get_preimage(payment_hash) - error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) - if error_bytes: - return None, None, error_bytes - if error_reason: - raise error_reason - if preimage: - return preimage, None, None + # If this set is part of a bundle now all parts are COMPLETE so the bundle can be deleted + # so the individual sets will get fulfilled. + self.lnworker.delete_payment_bundle(payment_key=bytes.fromhex(payment_key)) + + assert mpp_set.resolution == RecvMPPResolution.COMPLETE, "should return earlier if set is incomplete" + if not preimage: + assert hold_invoice_callback is not None, "should have been failed before" + async def callback(): + try: + await hold_invoice_callback(payment_hash) + except OnionRoutingFailure as e: # todo: should this catch all exceptions? + _log_fail_reason(f"hold invoice callback raised {e}") + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) + # mpp set must not be failed unless the consumer calls unregister_hold_invoice and + # callback must only be called once. This is enforced by setting the set to SETTLING. + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + return None, None, callback + + # settle htlc set + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + return None, preimage, None + + def _check_final_mpp_set_state( + self, + payment_key: str, + mpp_set: ReceivedMPPStatus, + ) -> Optional[Tuple[ + Optional[Union[OnionRoutingFailure, OnionFailureCode, bytes]], # error types used to fail the set + Optional[bytes], # preimage to settle the set + None, # callback + ]]: + """ + handle sets that are already in a state eligible for fulfillment or failure and shouldn't + go through another iteration of _check_unfulfilled_htlc_set. + """ + if len(mpp_set.htlcs) == 0: + # stale set, will get deleted on the next iteration return None, None, None - def process_onion_packet( + if mpp_set.resolution == RecvMPPResolution.FAILED: + error_bytes, failure_message = self.lnworker.get_forwarding_failure(payment_key) + if error_bytes or failure_message: + return error_bytes or failure_message, None, None + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + elif mpp_set.resolution == RecvMPPResolution.EXPIRED: + return OnionFailureCode.MPP_TIMEOUT, None, None + + if mpp_set.parent_set_key: + # this is a complete trampoline part of a multi trampoline payment. Move the htlcs to parent. + parent = self.lnworker.received_mpp_htlcs.get(mpp_set.parent_set_key) + if not parent: + parent = ReceivedMPPStatus( + resolution=RecvMPPResolution.WAITING, + htlcs=set(), + ) + self.lnworker.received_mpp_htlcs[mpp_set.parent_set_key] = parent + parent.htlcs.update(mpp_set.htlcs) + mpp_set.htlcs.clear() + return None, None, None # this set will get deleted as there are no htlcs in it anymore + + assert not mpp_set.parent_set_key + if mpp_set.resolution == RecvMPPResolution.SETTLING: + # this is an ongoing forwarding, or a set that has not yet been fully settled (and removed). + # note the htlcs in SETTLING will not get failed automatically, + # even if timeout comes close, so either a forwarding failure or preimage has to be set + error_bytes, failure_message = self.lnworker.get_forwarding_failure(payment_key) + if error_bytes or failure_message: + # this was a forwarding set and it failed + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) + return error_bytes or failure_message, None, None + preimage = self.lnworker.get_preimage(mpp_set.get_payment_hash()) + return None, preimage, None + + return None + + def _parse_onion_packet(self, onion_packet_hex: str) -> OnionPacket: + """ + https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/02-peer-protocol.md?plain=1#L2352 + """ + onion_packet_bytes = None + try: + onion_packet_bytes = bytes.fromhex(onion_packet_hex) + onion_packet = OnionPacket.from_bytes(onion_packet_bytes) + except Exception as parsing_exc: + self.logger.warning(f"unable to parse onion: {str(parsing_exc)}") + onion_parsing_error = OnionParsingError( + code=OnionFailureCodeMetaFlag.BADONION, + data=sha256(onion_packet_bytes or b''), + ) + raise onion_parsing_error + return onion_packet + + def _process_incoming_onion_packet( self, onion_packet: OnionPacket, *, payment_hash: bytes, - onion_packet_bytes: bytes, is_trampoline: bool = False) -> ProcessedOnionPacket: - - failure_data = sha256(onion_packet_bytes) + onion_hash = onion_packet.onion_hash + cache_key = sha256(onion_hash + payment_hash + bytes([is_trampoline])) # type: ignore + if cached_onion := self._processed_onion_cache.get(cache_key): + return cached_onion try: - processed_onion = process_onion_packet( + processed_onion = lnonion.process_onion_packet( onion_packet, our_onion_private_key=self.privkey, associated_data=payment_hash, is_trampoline=is_trampoline) + self._processed_onion_cache[cache_key] = processed_onion except UnsupportedOnionPacketVersion: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=onion_hash) except InvalidOnionPubkey: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_KEY, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_KEY, data=onion_hash) except InvalidOnionMac: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_HMAC, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_HMAC, data=onion_hash) except Exception as e: - self.logger.info(f"error processing onion packet: {e!r}") - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + self.logger.warning(f"error processing onion packet: {e!r}") + raise OnionParsingError(code=OnionFailureCodeMetaFlag.BADONION, data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_AS_MALFORMED: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE: raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') return processed_onion diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py index 014c87ad4ea0..9fededb8a3dc 100644 --- a/electrum/lnsweep.py +++ b/electrum/lnsweep.py @@ -445,6 +445,8 @@ def txs_htlc( if not preimage: # we might not have the preimage if this is a hold invoice continue + if htlc.payment_hash in chan.lnworker.dont_settle_htlcs: + continue else: preimage = None try: @@ -746,6 +748,8 @@ def tx_htlc( if not preimage: # we might not have the preimage if this is a hold invoice continue + if htlc.payment_hash in chan.lnworker.dont_settle_htlcs: + continue else: preimage = None tx_htlc( diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 76a6c876b6d1..28f1f4e36896 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1934,26 +1934,86 @@ def __post_init__(self): # Note: these states are persisted in the wallet file. # Do not modify them without performing a wallet db upgrade +# todo: if this changes again states could also be persisted by name instead of int value as done for ChannelState class RecvMPPResolution(IntEnum): - WAITING = 0 - EXPIRED = 1 - COMPLETE = 2 - FAILED = 3 + WAITING = 0 # set is not complete yet, waiting for arrival of the remaining htlcs + EXPIRED = 1 # preimage must not be revealed + COMPLETE = 2 # set is complete but could still be failed (e.g. due to cltv timeout) + FAILED = 3 # preimage must not be revealed + SETTLING = 4 # Must not be failed, should be settled asap. + # Also used when forwarding (for upstream), in which case a downstream + # forwarding failure could still result in transitioning to FAILED. + + +r = RecvMPPResolution +allowed_mpp_set_transitions = ( + (r.WAITING, r.EXPIRED), + (r.WAITING, r.FAILED), + (r.WAITING, r.COMPLETE), + (r.WAITING, r.SETTLING), # normal htlc forwarding + + (r.COMPLETE, r.SETTLING), + (r.COMPLETE, r.FAILED), + (r.COMPLETE, r.EXPIRED), # this should only realistically happen for payment bundles + + (r.SETTLING, r.FAILED), # forwarding failure, hold invoice callback gets unregistered, and we don't have preimage + + (r.EXPIRED, r.FAILED), # doesn't seem useful but also not dangerous +) +del r + + +class ReceivedMPPHtlc(NamedTuple): + scid: ShortChannelID + htlc: UpdateAddHtlc + unprocessed_onion: str + + def __repr__(self): + return f"{self.scid}, {self.htlc=}, {self.unprocessed_onion[:15]=}..." + + @staticmethod + def from_tuple(scid, htlc, unprocessed_onion) -> 'ReceivedMPPHtlc': + assert is_hex_str(unprocessed_onion) and is_hex_str(scid) + return ReceivedMPPHtlc( + scid=ShortChannelID(bytes.fromhex(scid)), + htlc=UpdateAddHtlc.from_tuple(*htlc), + unprocessed_onion=unprocessed_onion, + ) class ReceivedMPPStatus(NamedTuple): resolution: RecvMPPResolution - expected_msat: int - htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + htlcs: set[ReceivedMPPHtlc] + # parent_set_key is needed as trampoline allows MPP to be nested, the parent_set_key is the + # payment key of the final mpp set (derived from inner trampoline onion payment secret) + # to which the separate trampoline sets htlcs get added once they are complete. + # https://github.com/lightning/bolts/pull/829/commits/bc7a1a0bc97b2293e7f43dd8a06529e5fdcf7cd2 + parent_set_key: str = None + + def get_first_htlc_timestamp(self) -> Optional[int]: + return min([mpp_htlc.htlc.timestamp for mpp_htlc in self.htlcs], default=None) + + def get_closest_cltv_abs(self) -> Optional[int]: + return min([mpp_htlc.htlc.cltv_abs for mpp_htlc in self.htlcs], default=None) + + def get_payment_hash(self) -> Optional[bytes]: + mpp_htlcs = iter(self.htlcs) + first_mpp_htlc = next(mpp_htlcs, None) + payment_hash = first_mpp_htlc.htlc.payment_hash if first_mpp_htlc else None + for mpp_htlc in mpp_htlcs: + assert mpp_htlc.htlc.payment_hash == payment_hash, "mpp set with inconsistent payment hashes" + return payment_hash @staticmethod @stored_in('received_mpp_htlcs', tuple) - def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': - htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list]) + def from_tuple(resolution, htlc_list, parent_set_key=None) -> 'ReceivedMPPStatus': + assert isinstance(resolution, int) + htlc_set = set(ReceivedMPPHtlc.from_tuple(*htlc_data) for htlc_data in htlc_list) return ReceivedMPPStatus( resolution=RecvMPPResolution(resolution), - expected_msat=expected_msat, - htlc_set=htlc_set) + htlcs=htlc_set, + parent_set_key=parent_set_key, + ) class OnionFailureCodeMetaFlag(IntFlag): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d3cf02e0aa59..d4c5bc8cb2a4 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -10,7 +10,7 @@ from enum import IntEnum from typing import ( Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Mapping, Any, Iterable, AsyncGenerator, - Callable, Awaitable + Callable, Awaitable, Union, ) from types import MappingProxyType import threading @@ -70,7 +70,7 @@ ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage, OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget, NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT, - MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, ReceivedMPPStatus, RecvMPPResolution, + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, RecvMPPResolution, ReceivedMPPStatus, ReceivedMPPHtlc, PaymentSuccess, ) from .lnonion import ( @@ -1237,6 +1237,8 @@ def channel_state_changed(self, chan: Channel): if type(chan) is Channel: self.save_channel(chan) self.clear_invoices_cache() + if chan._state == ChannelState.REDEEMED: + self.maybe_cleanup_mpp(chan) util.trigger_callback('channel', self.wallet, chan) def save_channel(self, chan: Channel): @@ -2399,15 +2401,47 @@ def bundle_payments(self, hash_list: Sequence[bytes]) -> None: self._payment_bundles_pkey_to_canon[pkey] = canon_pkey self._payment_bundles_canon_to_pkeylist[canon_pkey] = tuple(payment_keys) - def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]: + def get_payment_bundle(self, payment_key: Union[bytes, str]) -> Sequence[bytes]: with self.lock: + if isinstance(payment_key, str): + try: + payment_key = bytes.fromhex(payment_key) + except ValueError: + # might be a forwarding payment_key which is not hex and will never have a bundle + return [] canon_pkey = self._payment_bundles_pkey_to_canon.get(payment_key) if canon_pkey is None: return [] return self._payment_bundles_canon_to_pkeylist[canon_pkey] - def delete_payment_bundle(self, payment_hash: bytes) -> None: - payment_key = self._get_payment_key(payment_hash) + def is_payment_bundle_complete(self, any_payment_key: str) -> bool: + """ + complete means a htlc set is available for each payment key of the payment bundle and + all htlc sets have a resolution >= COMPLETE (we got the whole payment bundle amount) + """ + # get all payment keys covered by this bundle + bundle_payment_keys = self.get_payment_bundle(any_payment_key) + if not bundle_payment_keys: # there is no payment bundle + return True + for payment_key in bundle_payment_keys: + mpp_set = self.received_mpp_htlcs.get(payment_key.hex()) + if mpp_set is None: + # payment bundle is missing htlc set for payment request + # it might have already been failed and deleted + return False + elif mpp_set.resolution not in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING): + return False + return True + + def delete_payment_bundle( + self, *, + payment_hash: Optional[bytes] = None, + payment_key: Optional[bytes] = None, + ) -> None: + assert (payment_hash is not None) ^ (payment_key is not None), \ + "must provide exactly one of (payment_hash, payment_key)" + if not payment_key: + payment_key = self._get_payment_key(payment_hash) with self.lock: canon_pkey = self._payment_bundles_pkey_to_canon.get(payment_key) if canon_pkey is None: # is it ok for bundle to be missing?? @@ -2478,10 +2512,16 @@ def add_payment_info_for_hold_invoice( self.save_payment_info(info, write_to_disk=False) def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]): + assert self.get_preimage(payment_hash) is None, "hold invoice cb won't get called if preimage is already set" self.hold_invoice_callbacks[payment_hash] = cb def unregister_hold_invoice(self, payment_hash: bytes): - self.hold_invoice_callbacks.pop(payment_hash) + self.hold_invoice_callbacks.pop(payment_hash, None) + payment_key = self._get_payment_key(payment_hash).hex() + if payment_key in self.received_mpp_htlcs: + if self.get_preimage(payment_hash) is None: + # the pending mpp set can be failed as we don't have the preimage to settle it + self.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: assert info.status in SAVED_PR_STATUS @@ -2500,132 +2540,142 @@ def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> if write_to_disk: self.wallet.save_db() - def check_mpp_status( - self, *, - payment_secret: bytes, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, - expected_msat: int, - ) -> RecvMPPResolution: - """Returns the status of the incoming htlc set the given *htlc* belongs to. - - ACCEPTED simply means the mpp set is complete, and we can proceed with further - checks before fulfilling (or failing) the htlcs. - In particular, note that hold-invoice-htlcs typically remain in the ACCEPTED state - for quite some time -- not in the "WAITING" state (which would refer to the mpp set - not yet being complete!). - """ - payment_hash = htlc.payment_hash - payment_key = payment_hash + payment_secret - self.update_mpp_with_received_htlc( - payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) - mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution - # if still waiting, calc resolution now: - if mpp_resolution == RecvMPPResolution.WAITING: - bundle = self.get_payment_bundle(payment_key) - if bundle: - payment_keys = bundle - else: - payment_keys = [payment_key] - first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys]) - if self.get_payment_status(payment_hash) == PR_PAID: - mpp_resolution = RecvMPPResolution.COMPLETE - elif self.stopping_soon: - # try to time out pending HTLCs before shutting down - mpp_resolution = RecvMPPResolution.EXPIRED - elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]): - mpp_resolution = RecvMPPResolution.COMPLETE - elif time.time() - first_timestamp > self.MPP_EXPIRY: - mpp_resolution = RecvMPPResolution.EXPIRED - # save resolution, if any. - if mpp_resolution != RecvMPPResolution.WAITING: - for pkey in payment_keys: - if pkey.hex() in self.received_mpp_htlcs: - self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) - - return mpp_resolution - - def update_mpp_with_received_htlc( + def update_or_create_mpp_with_received_htlc( self, *, - payment_key: bytes, + payment_key: str, scid: ShortChannelID, htlc: UpdateAddHtlc, - expected_msat: int, + unprocessed_onion_packet: str, ): - # add new htlc to set - mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) + # Payment key creation: + # * for regular forwarded htlcs -> "scid.hex() + ':%d' % htlc_id" [htlc key] + # * for trampoline forwarding -> "payment hash + payment secret from outer onion" + # * for final non-trampoline htlcs (we are receiver) -> "payment hash + payment secret from onion" + # * for final trampoline htlcs (we are receiver) -> 2. step grouping: + # 1. grouping of htlcs by "payments hash + outer onion payment secret", a 'multi-trampoline mpp part'. + # 2. once the set of step 1. is COMPLETE (amount_fwd outer onion >= total_amt outer onion) + # the htlcs get moved to the parent mpp set (created once first part is complete) grouped by: + # "payment_hash + inner onion payment secret (the one in the invoice)" + # After moving the htlcs the first set gets deleted. + # + # Add the validated htlc to the htlc set associated with the payment key. + # If no set exists, a new set in WAITING state is created. + mpp_status = self.received_mpp_htlcs.get(payment_key) if mpp_status is None: + self.logger.debug(f"creating new mpp set for {payment_key=}") mpp_status = ReceivedMPPStatus( resolution=RecvMPPResolution.WAITING, - expected_msat=expected_msat, - htlc_set=set(), + htlcs=set(), ) - if expected_msat != mpp_status.expected_msat: - self.logger.info( - f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}") - mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED) - key = (scid, htlc) - if key not in mpp_status.htlc_set: - mpp_status.htlc_set.add(key) # side-effecting htlc_set - self.received_mpp_htlcs[payment_key.hex()] = mpp_status - - def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): - mpp_status = self.received_mpp_htlcs[payment_key.hex()] - self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}') - self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution) - - def is_mpp_amount_reached(self, payment_key: bytes) -> bool: - amounts = self.get_mpp_amounts(payment_key) - if amounts is None: - return False - total, expected = amounts - return total >= expected - def is_complete_mpp(self, payment_hash: bytes) -> bool: + if mpp_status.resolution > RecvMPPResolution.WAITING: + # we are getting a htlc for a set that is not in WAITING state, it cannot be safely added + self.logger.info(f"htlc set cannot accept htlc, failing htlc: {scid=} {htlc.htlc_id=}") + if mpp_status == RecvMPPResolution.EXPIRED: + raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') + raise OnionRoutingFailure( + code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, + data=htlc.amount_msat.to_bytes(8, byteorder="big"), + ) + + new_htlc = ReceivedMPPHtlc( + scid=scid, + htlc=htlc, + unprocessed_onion=unprocessed_onion_packet, + ) + assert new_htlc not in mpp_status.htlcs, "each htlc should make it here only once?" + assert isinstance(unprocessed_onion_packet, str) + mpp_status.htlcs.add(new_htlc) # side-effecting htlc_set + self.received_mpp_htlcs[payment_key] = mpp_status + + def set_mpp_resolution(self, payment_key: str, new_resolution: RecvMPPResolution) -> ReceivedMPPStatus: + mpp_status = self.received_mpp_htlcs[payment_key] + if mpp_status.resolution == new_resolution: + return mpp_status + if not (mpp_status.resolution, new_resolution) in lnutil.allowed_mpp_set_transitions: + raise ValueError(f'forbidden mpp set transition: {mpp_status.resolution} -> {new_resolution}') + self.logger.info(f'set_mpp_resolution {new_resolution.name} {len(mpp_status.htlcs)=}: {payment_key=}') + self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=new_resolution) + self.wallet.save_db() + return self.received_mpp_htlcs[payment_key] + + def set_htlc_set_error( + self, + payment_key: str, + error: Union[bytes, OnionFailureCode, OnionRoutingFailure], + ) -> Optional[Tuple[Optional[bytes], Optional[OnionFailureCode | int], Optional[bytes]]]: + """ + handles different types of errors and sets the htlc set to failed, then returns a more + structured tuple of error types which can then be used to fail the htlc set + """ + htlc_set = self.received_mpp_htlcs[payment_key] + assert htlc_set.resolution != RecvMPPResolution.SETTLING + raw_error, error_code, error_data = None, None, None + if isinstance(error, bytes): + raw_error = error + elif isinstance(error, OnionFailureCode): + error_code = error + elif isinstance(error, OnionRoutingFailure): + error_code, error_data = OnionFailureCode.from_int(error.code), error.data + else: + raise ValueError(f"invalid error type: {repr(error)}") + + if error_code == OnionFailureCode.MPP_TIMEOUT: + self.set_mpp_resolution(payment_key=payment_key, new_resolution=RecvMPPResolution.EXPIRED) + else: + self.set_mpp_resolution(payment_key=payment_key, new_resolution=RecvMPPResolution.FAILED) + + return raw_error, error_code, error_data + + def get_mpp_resolution(self, payment_hash: bytes) -> Optional[RecvMPPResolution]: payment_key = self._get_payment_key(payment_hash) status = self.received_mpp_htlcs.get(payment_key.hex()) - return status and status.resolution == RecvMPPResolution.COMPLETE + return status.resolution if status else None + + def is_complete_mpp(self, payment_hash: bytes) -> bool: + resolution = self.get_mpp_resolution(payment_hash) + if resolution is not None: + return resolution in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING) + return False def get_payment_mpp_amount_msat(self, payment_hash: bytes) -> Optional[int]: """Returns the received mpp amount for given payment hash.""" payment_key = self._get_payment_key(payment_hash) - amounts = self.get_mpp_amounts(payment_key) - if not amounts: + total_msat = self.get_mpp_amounts(payment_key) + if not total_msat: return None - total_msat, _ = amounts return total_msat - def get_mpp_amounts(self, payment_key: bytes) -> Optional[Tuple[int, int]]: - """Returns (total received amount, expected amount) or None.""" + def get_mpp_amounts(self, payment_key: bytes) -> Optional[int]: + """Returns total received amount or None.""" mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return None - total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) - return total, mpp_status.expected_msat - - def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: - mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) - if not mpp_status: - return int(time.time()) - return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) + total = sum([mpp_htlc.htlc.amount_msat for mpp_htlc in mpp_status.htlcs]) + return total def maybe_cleanup_mpp( self, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, + chan: Channel, ) -> None: - - htlc_key = (short_channel_id, htlc) + """ + Remove all remaining mpp htlcs of the given channel after closing. + Usually they get removed in htlc_switch after all htlcs of the set are resolved, + however if there is a force close with pending htlcs they need to be removed after the channel + is closed. + """ + # only cleanup when channel is REDEEMED as mpp set is still required for lnsweep + assert chan._state == ChannelState.REDEEMED for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): - if htlc_key not in mpp_status.htlc_set: - continue - assert mpp_status.resolution != RecvMPPResolution.WAITING - self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') - mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set - if len(mpp_status.htlc_set) == 0: + htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.scid == chan.short_channel_id] + for stale_mpp_htlc in htlcs_to_remove: + assert mpp_status.resolution != RecvMPPResolution.WAITING + self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') + mpp_status.htlcs.remove(stale_mpp_htlc) # side-effecting htlc_set + if len(mpp_status.htlcs) == 0: self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}') - self.received_mpp_htlcs.pop(payment_key_hex) + del self.received_mpp_htlcs[payment_key_hex] self.maybe_cleanup_forwarding(payment_key_hex) def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: @@ -2681,6 +2731,7 @@ def is_forwarded_htlc(self, htlc_key) -> Optional[str]: for payment_key, htlcs in self.active_forwardings.items(): if htlc_key in htlcs: return payment_key + return None def notify_upstream_peer(self, htlc_key: str) -> None: """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. @@ -3436,7 +3487,58 @@ def maybe_add_backup_from_tx(self, tx): util.trigger_callback('channels_updated', self.wallet) self.lnwatcher.add_channel(cb) - async def maybe_forward_htlc( + async def maybe_forward_htlc_set( + self, + payment_key: str, *, + processed_htlc_set: dict[ReceivedMPPHtlc, Tuple[ProcessedOnionPacket, Optional[ProcessedOnionPacket]]], + ) -> None: + assert self.enable_htlc_forwarding + assert payment_key not in self.active_forwardings, "cannot forward set twice" + self.active_forwardings[payment_key] = [] + self.logger.debug(f"adding active_forwarding: {payment_key=}") + + any_mpp_htlc, (any_outer_onion, any_trampoline_onion) = next(iter(processed_htlc_set.items())) + try: + if any_trampoline_onion is None: + assert not any_outer_onion.are_we_final + assert len(processed_htlc_set) == 1, processed_htlc_set + forward_htlc = any_mpp_htlc.htlc + incoming_chan = self.get_channel_by_short_id(any_mpp_htlc.scid) + next_htlc = await self._maybe_forward_htlc( + incoming_chan=incoming_chan, + htlc=forward_htlc, + processed_onion=any_outer_onion, + ) + htlc_key = serialize_htlc_key(incoming_chan.get_scid_or_local_alias(), forward_htlc.htlc_id) + self.active_forwardings[payment_key].append(next_htlc) + self.downstream_to_upstream_htlc[next_htlc] = htlc_key + else: + assert not any_trampoline_onion.are_we_final and any_outer_onion.are_we_final + # trampoline forwarding + min_inc_cltv_abs = min( + mpp_htlc.htlc.cltv_abs + for mpp_htlc in processed_htlc_set.keys()) # take "min" to assume worst-case + await self._maybe_forward_trampoline( + payment_hash=any_mpp_htlc.htlc.payment_hash, + closest_inc_cltv_abs=min_inc_cltv_abs, + total_msat=any_outer_onion.total_msat, + any_trampoline_onion=any_trampoline_onion, + fw_payment_key=payment_key, + ) + except OnionRoutingFailure as e: + self.logger.debug(f"forwarding failed: {e=}") + if len(self.active_forwardings[payment_key]) == 0: + self.save_forwarding_failure(payment_key, failure_message=e) + # TODO what about other errors? + # Could we "catch-all Exception" and fail back the htlcs with e.g. TEMPORARY_NODE_FAILURE? + # - we don't want to fail the inc-HTLC for a syntax error that happens in the callback + # If we don't call save_forwarding_failure(), the inc-HTLC gets stuck until expiry + # and then the inc-channel will get force-closed. + # => forwarding_callback() could have an API with two exceptions types: + # - type1, such as OnionRoutingFailure, that signals we need to fail back the inc-HTLC + # - type2, such as NoPathFound, that signals we want to retry forwarding + + async def _maybe_forward_htlc( self, *, incoming_chan: Channel, htlc: UpdateAddHtlc, @@ -3550,13 +3652,12 @@ def log_fail_reason(reason: str): htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), next_htlc.htlc_id) return htlc_key - @log_exceptions - async def maybe_forward_trampoline( + async def _maybe_forward_trampoline( self, *, payment_hash: bytes, - inc_cltv_abs: int, - outer_onion: ProcessedOnionPacket, - trampoline_onion: ProcessedOnionPacket, + closest_inc_cltv_abs: int, + total_msat: int, # total_msat of the outer onion + any_trampoline_onion: ProcessedOnionPacket, # any trampoline onion of the incoming htlc set, they should be similar fw_payment_key: str, ) -> None: @@ -3565,7 +3666,7 @@ async def maybe_forward_trampoline( if not (forwarding_enabled and forwarding_trampoline_enabled): self.logger.info(f"trampoline forwarding is disabled. failing htlc.") raise OnionRoutingFailure(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'') - payload = trampoline_onion.hop_data.payload + payload = any_trampoline_onion.hop_data.payload payment_data = payload.get('payment_data') try: payment_secret = payment_data['payment_secret'] if payment_data else os.urandom(32) @@ -3583,7 +3684,7 @@ async def maybe_forward_trampoline( else: self.logger.info('forward_trampoline: end-to-end') invoice_features = LnFeatures.BASIC_MPP_OPT - next_trampoline_onion = trampoline_onion.next_packet + next_trampoline_onion = any_trampoline_onion.next_packet r_tags = [] except Exception as e: self.logger.exception('') @@ -3597,13 +3698,12 @@ async def maybe_forward_trampoline( # these are the fee/cltv paid by the sender # pay_to_node will raise if they are not sufficient - total_msat = outer_onion.hop_data.payload["payment_data"]["total_msat"] budget = PaymentFeeBudget( fee_msat=total_msat - amt_to_forward, - cltv=inc_cltv_abs - out_cltv_abs, + cltv=closest_inc_cltv_abs - out_cltv_abs, ) self.logger.info(f'trampoline forwarding. budget={budget}') - self.logger.info(f'trampoline forwarding. {inc_cltv_abs=}, {out_cltv_abs=}') + self.logger.info(f'trampoline forwarding. {closest_inc_cltv_abs=}, {out_cltv_abs=}') # To convert abs vs rel cltvs, we need to guess blockheight used by original sender as "current blockheight". # Blocks might have been mined since. # - if we skew towards the past, we decrease our own cltv_budget accordingly (which is ok) @@ -3715,7 +3815,6 @@ def create_onion_for_route( final_cltv_abs=final_cltv_abs, total_msat=total_msat, payment_secret=payment_secret) - num_hops = len(hops_data) self.logger.info(f"pay len(route)={len(route)}. for payment_hash={payment_hash.hex()}") for i in range(len(route)): self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}") diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 3705aa572474..de6a40a545af 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -37,8 +37,7 @@ run_sync_function_on_asyncio_thread, trigger_callback, NoDynamicFeeEstimates, UserFacingException, ) from . import lnutil -from .lnutil import (hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair, - MIN_FINAL_CLTV_DELTA_ACCEPTED) +from .lnutil import hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair from .lnaddr import lndecode from .json_db import StoredObject, stored_in from . import constants @@ -257,7 +256,7 @@ def __init__(self, *, wallet: 'Abstract_Wallet', lnworker: 'LNWallet'): payment_hash = bytes.fromhex(payment_hash_hex) swap._payment_hash = payment_hash self._add_or_reindex_swap(swap, is_new=False) - if not swap.is_reverse and not swap.is_redeemed: + if not swap.is_reverse and not swap.is_redeemed and not self.lnworker.get_preimage(swap.payment_hash): self.lnworker.register_hold_invoice(payment_hash, self.hold_invoice_callback) self._prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash @@ -399,8 +398,8 @@ def cancel_normal_swap(self, swap: SwapData): def _fail_swap(self, swap: SwapData, reason: str): self.logger.info(f'failing swap {swap.payment_hash.hex()}: {reason}') if not swap.is_reverse and swap.payment_hash in self.lnworker.hold_invoice_callbacks: + # unregister_hold_invoice will fail pending htlcs if there is no preimage available self.lnworker.unregister_hold_invoice(swap.payment_hash) - # Peer.maybe_fulfill_htlc will fail incoming htlcs if there is no payment info self.lnworker.delete_payment_info(swap.payment_hash.hex()) self.lnworker.clear_invoices_cache() self.lnwatcher.remove_callback(swap.lockup_address) @@ -415,7 +414,7 @@ def _fail_swap(self, swap: SwapData, reason: str): self._prepayments.pop(swap.prepay_hash, None) if self.lnworker.get_payment_status(swap.prepay_hash) != PR_PAID: self.lnworker.delete_payment_info(swap.prepay_hash.hex()) - self.lnworker.delete_payment_bundle(swap.payment_hash) + self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) @classmethod def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]: @@ -473,7 +472,7 @@ async def _claim_swap(self, swap: SwapData) -> None: # cleanup self.lnwatcher.remove_callback(swap.lockup_address) if not swap.is_reverse: - self.lnworker.delete_payment_bundle(swap.payment_hash) + self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) self.lnworker.unregister_hold_invoice(swap.payment_hash) if not swap.is_reverse: @@ -690,7 +689,7 @@ def add_normal_swap( self.lnworker.add_payment_info_for_hold_invoice( payment_hash, lightning_amount_sat=invoice_amount_sat, - min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED, + min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) info = self.lnworker.get_payment_info(payment_hash) @@ -709,7 +708,7 @@ def add_normal_swap( if prepay: prepay_hash = self.lnworker.create_payment_info( amount_msat=prepay_amount_sat*1000, - min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED, + min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) info = self.lnworker.get_payment_info(prepay_hash) diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 09cf9d9489f8..bf22f03fbd21 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -22,33 +22,29 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import os -import ast import datetime import json import copy -import threading from collections import defaultdict from typing import (Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union, AbstractSet) -import binascii import time from functools import partial import attr -from . import util, bitcoin -from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, MyEncoder -from .invoices import Invoice, Request +from . import bitcoin +from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, MyEncoder from .keystore import bip44_derivation from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput, BadHeaderMagic from .logging import Logger -from .lnutil import HTLCOwner, ChannelType +from .lnutil import HTLCOwner, ChannelType, RecvMPPResolution from . import json_db -from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject, stored_in, stored_as +from .json_db import JsonDB, locked, modifier, StoredObject, stored_in, stored_as from .plugin import run_hook, plugin_loaders from .version import ELECTRUM_VERSION +from .i18n import _ if TYPE_CHECKING: from .storage import WalletStorage @@ -73,7 +69,7 @@ def __init__(self, wallet_db: 'WalletDB'): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 62 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 63 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -238,6 +234,7 @@ def upgrade(self): self._convert_version_60() self._convert_version_61() self._convert_version_62() + self._convert_version_63() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1182,6 +1179,96 @@ def _convert_version_62(self): swap['claim_to_output'] = None self.data['seed_version'] = 62 + def _convert_version_63(self): + if not self._is_upgrade_method_needed(62, 62): + return + # Old ReceivedMPPStatus: + # class ReceivedMPPStatus(NamedTuple): + # resolution: RecvMPPResolution + # expected_msat: int + # htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + # + # New ReceivedMPPStatus: + # class ReceivedMPPStatus(NamedTuple): + # resolution: RecvMPPResolution + # htlcs: set[ReceivedMPPHtlc] + # + # class ReceivedMPPHtlc(NamedTuple): + # scid: ShortChannelID + # htlc: UpdateAddHtlc + # unprocessed_onion: str + + # previously chan.unfulfilled_htlcs went through 4 stages: + # - 1. not forwarded yet: (onion_packet_hex, None) + # - 2. forwarded: (onion_packet_hex, forwarding_key) + # - 3. processed: (None, forwarding_key), not irrevocably removed yet + # - 4. done: (None, forwarding_key), irrevocably removed + channels = self.data.get('channels', {}) + def _move_unprocessed_onion(short_channel_id: str, htlc_id: Optional[int]) -> Optional[Tuple[str, Optional[str]]]: + if htlc_id is None: + return None + for chan_ in channels.values(): + if chan_['short_channel_id'] != short_channel_id: + continue + unfulfilled_htlcs_ = chan_.get('unfulfilled_htlcs', {}) + htlc_data = unfulfilled_htlcs_.get(str(htlc_id)) + if htlc_data is None: + return None + stored_onion_packet, htlc_forwarding_key = htlc_data + if stored_onion_packet is not None: + htlc_data[0] = None # overwrite the onion so it is not processed again in htlc_switch + return stored_onion_packet, htlc_forwarding_key + return None + + mpp_sets = self.data.get('received_mpp_htlcs', {}) + for payment_key, recv_mpp_status in list(mpp_sets.items()): + assert isinstance(recv_mpp_status, list), f"{recv_mpp_status=}" + del recv_mpp_status[1] # remove expected_msat + + new_type_htlcs = [] + forwarding_key = None + for scid, update_add_htlc in recv_mpp_status[1]: # htlc_set + htlc_info_from_chan = _move_unprocessed_onion(scid, update_add_htlc[3]) + if htlc_info_from_chan is None: + # if there is no onion packet for the htlc it is dropped as it was already + # processed in the old htlc_switch + continue + onion_packet_hex = htlc_info_from_chan[0] + forwarding_key = htlc_info_from_chan[1] if htlc_info_from_chan[1] else forwarding_key + new_type_htlcs.append([ + scid, + update_add_htlc, + onion_packet_hex, + ]) + + if len(new_type_htlcs) == 0: + self.logger.debug(f"_convert_version_62: dropping mpp set {payment_key=}.") + del mpp_sets[payment_key] + else: + recv_mpp_status[1] = new_type_htlcs + self.logger.debug(f"_convert_version_62: migrated mpp set {payment_key=}") + if forwarding_key is not None: + # if the forwarding key is set for the old mpp set it was either a forwarding + # or a swap hold invoice. Assuming users of 4.6.2 don't use forwarding this update + # most likely happens during a swap waiting for the preimage. Setting the mpp set + # to SETTLING prevents us from accidentally failing the htlc set after the update, + # however it carries the risk of the channel getting force closed if the swap fails + # as the htlcs won't get failed due to the new SETTLING state + # unless a forwarding error is set. + recv_mpp_status[0] = 4 # RecvMPPResolution.SETTLING + + # replace Tuple[onion, forwarding_key] with just the onion in chan['unfulfilled_htlcs'] + for chan in channels.values(): + unfulfilled_htlcs = chan.get('unfulfilled_htlcs', {}) + for htlc_id, (unprocessed_onion, forwarding_key) in list(unfulfilled_htlcs.items()): + if unprocessed_onion is None: + # delete all unfulfilled_htlcs with empty onion as they are already processed + del unfulfilled_htlcs[htlc_id] + else: + unfulfilled_htlcs[htlc_id] = unprocessed_onion + + self.data['seed_version'] = 63 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_commands.py b/tests/test_commands.py index 28a80a607bb2..455f7b87e758 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -549,13 +549,13 @@ async def test_hold_invoice_commands(self, mock_save_db): ) mock_htlc1 = mock.Mock() - mock_htlc1.cltv_abs = 800_000 - mock_htlc1.amount_msat = 4_500_000 + mock_htlc1.htlc.cltv_abs = 800_000 + mock_htlc1.htlc.amount_msat = 4_500_000 mock_htlc2 = mock.Mock() - mock_htlc2.cltv_abs = 800_144 - mock_htlc2.amount_msat = 5_500_000 + mock_htlc2.htlc.cltv_abs = 800_144 + mock_htlc2.htlc.amount_msat = 5_500_000 mock_htlc_status = mock.Mock() - mock_htlc_status.htlc_set = [(None, mock_htlc1), (None, mock_htlc2)] + mock_htlc_status.htlcs = [mock_htlc1, mock_htlc2] mock_htlc_status.resolution = RecvMPPResolution.COMPLETE payment_key = wallet.lnworker._get_payment_key(bytes.fromhex(payment_hash)).hex() diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 360ede67e3ec..fbd148f70d51 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -301,7 +301,6 @@ async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: Ln set_request_status = LNWallet.set_request_status set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status - check_mpp_status = LNWallet.check_mpp_status htlc_fulfilled = LNWallet.htlc_fulfilled htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage @@ -334,11 +333,9 @@ async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: Ln unregister_hold_invoice = LNWallet.unregister_hold_invoice add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice - update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc + update_or_create_mpp_with_received_htlc = LNWallet.update_or_create_mpp_with_received_htlc set_mpp_resolution = LNWallet.set_mpp_resolution - is_mpp_amount_reached = LNWallet.is_mpp_amount_reached get_mpp_amounts = LNWallet.get_mpp_amounts - get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp bundle_payments = LNWallet.bundle_payments get_payment_bundle = LNWallet.get_payment_bundle _get_payment_key = LNWallet._get_payment_key @@ -347,11 +344,14 @@ async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: Ln maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw current_low_feerate_per_kw_srk_channel = LNWallet.current_low_feerate_per_kw_srk_channel - maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp create_onion_for_route = LNWallet.create_onion_for_route - maybe_forward_htlc = LNWallet.maybe_forward_htlc - maybe_forward_trampoline = LNWallet.maybe_forward_trampoline + maybe_forward_htlc_set = LNWallet.maybe_forward_htlc_set + _maybe_forward_htlc = LNWallet._maybe_forward_htlc + _maybe_forward_trampoline = LNWallet._maybe_forward_trampoline _maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created = LNWallet._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created + set_htlc_set_error = LNWallet.set_htlc_set_error + is_payment_bundle_complete = LNWallet.is_payment_bundle_complete + delete_payment_bundle = LNWallet.delete_payment_bundle _process_htlc_log = LNWallet._process_htlc_log From da5f59903d53b71a274e4d0ba38cdda07847cca2 Mon Sep 17 00:00:00 2001 From: f321x Date: Thu, 28 Aug 2025 16:14:55 +0200 Subject: [PATCH 02/17] tests: lnpeer: test_reject_invalid_min_final_cltv_delta Add `test_reject_invalid_min_final_cltv_delta` which is supposed to test that the peer rejects incoming htlcs with final cltv delta differing from what has been requested in the lightning invoice. --- tests/test_lnpeer.py | 46 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index fbd148f70d51..2a86f5d7aca6 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -954,6 +954,52 @@ async def f(): with self.assertRaises(SuccessfulTest): await f() + async def test_reject_invalid_min_final_cltv_delta(self): + """ + Tests that htlcs with a final cltv delta < the minimum requested in the invoice get + rejected immediately upon receiving them. + """ + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + assert lnaddr.get_min_final_cltv_delta() == 400 # what the receiver expects + lnaddr.tags = [tag for tag in lnaddr.tags if tag[0] != 'c'] + [['c', 144]] + b11 = lnencode(lnaddr, w2.node_keypair.privkey) + pay_req = Invoice.from_bech32(b11) + assert pay_req._lnaddr.get_min_final_cltv_delta() == 144 # what w1 will use to pay + result, log = await w1.pay_invoice(pay_req) + if not result: + raise PaymentFailure() + raise PaymentDone() + + # create invoice with high min final cltv delta + lnaddr, _pay_req = self.prepare_invoice(w2, min_final_cltv_delta=400) + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_with_too_low_final_cltv_delta(lnaddr)) + + with self.assertRaises(PaymentFailure): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + async def test_payment_race(self): """Alice and Bob pay each other simultaneously. They both send 'update_add_htlc' and receive each other's update From 7840df2e0d4775c5bb3deb28886c8ba27a9256e8 Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 2 Sep 2025 16:08:58 +0200 Subject: [PATCH 03/17] tests: lnpeer: test_reject_payment_for_expired_invoice Test that lnpeer is rejecting incoming htlcs for invoices that are already expired. --- tests/test_lnpeer.py | 48 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 2a86f5d7aca6..e5f98ec12f6e 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -11,6 +11,7 @@ from concurrent import futures from unittest import mock from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence +import time from aiorpcx import timeout_after, TaskTimeout from electrum_ecc import ECPrivkey @@ -559,6 +560,7 @@ def prepare_invoice( payment_hash: bytes = None, invoice_features: LnFeatures = None, min_final_cltv_delta: int = None, + expiry: int = None, ) -> Tuple[LnAddr, Invoice]: amount_btc = amount_msat/Decimal(COIN*1000) if payment_preimage is None and not payment_hash: @@ -586,7 +588,7 @@ def prepare_invoice( direction=RECEIVED, status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, - expiry_delay=LN_EXPIRY_NEVER, + expiry_delay=expiry or LN_EXPIRY_NEVER, ) w2.save_payment_info(info) lnaddr1 = LnAddr( @@ -596,6 +598,7 @@ def prepare_invoice( ('c', min_final_cltv_delta), ('d', 'coffee'), ('9', invoice_features), + ('x', expiry or 3600), ] + routing_hints, payment_secret=payment_secret, ) @@ -1000,6 +1003,49 @@ async def f(): for _test_trampoline in [False, True]: await run_test(_test_trampoline) + async def test_reject_payment_for_expired_invoice(self): + """Tests that new htlcs paying an invoice that has already been expired will get rejected.""" + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + # create lightning invoice in the past, so it is expired + with mock.patch('time.time', return_value=int(time.time()) - 10000): + lnaddr, _pay_req = self.prepare_invoice(w2, expiry=3600) + b11 = lnencode(lnaddr, w2.node_keypair.privkey) + pay_req = Invoice.from_bech32(b11) + + async def try_pay_expired_invoice(pay_req: Invoice, w1=w1): + assert pay_req.has_expired() + assert lnaddr.is_expired() + with mock.patch.object(w1, "_check_bolt11_invoice", return_value=lnaddr): + result, log = await w1.pay_invoice(pay_req) + if not result: + raise PaymentFailure() + raise PaymentDone() + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_expired_invoice(pay_req)) + + with self.assertRaises(PaymentFailure): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + async def test_payment_race(self): """Alice and Bob pay each other simultaneously. They both send 'update_add_htlc' and receive each other's update From a7de8de5a28405332fa978322728145824e05032 Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 2 Sep 2025 18:09:59 +0200 Subject: [PATCH 04/17] tests: lnpeer: test_reject_multiple_payments_of_same_invoice Test that lnpeer rejects incoming htlcs for payments that have already been paid so invoices cannot be paid twice. --- tests/test_lnpeer.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index e5f98ec12f6e..7cf778f943f6 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1046,6 +1046,47 @@ async def f(): for _test_trampoline in [False, True]: await run_test(_test_trampoline) + async def test_reject_multiple_payments_of_same_invoice(self): + """Tests that new htlcs paying an invoice that has already been paid will get rejected.""" + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + lnaddr, _pay_req = self.prepare_invoice(w2) + + async def try_pay_invoice_twice(pay_req: Invoice, w1=w1): + result, log = await w1.pay_invoice(pay_req) + assert result is True + # now pay the same invoice again, the payment should be rejected by w2 + w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID) + result, log = await w1.pay_invoice(pay_req) + if not result: + # w1.pay_invoice returned a payment failure as the payment got rejected by w2 + raise SuccessfulTest() + raise PaymentDone() + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_invoice_twice(_pay_req)) + + with self.assertRaises(SuccessfulTest): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + async def test_payment_race(self): """Alice and Bob pay each other simultaneously. They both send 'update_add_htlc' and receive each other's update From a91f7c519fd3cbf1768e2c4efd7cb9d34f58f72b Mon Sep 17 00:00:00 2001 From: f321x Date: Thu, 4 Sep 2025 10:44:24 +0200 Subject: [PATCH 05/17] tests: lnpeer: test_dont_settle_partial_mpp_trigger_with_invalid_cltv_htlc Adds unittest to verify that lnpeer doesn't settle any htlcs of incomplete mpp. --- tests/test_lnpeer.py | 78 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 7cf778f943f6..26a166fcf1a9 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1336,6 +1336,84 @@ async def on_htlc_failed(*args): with self.assertRaises(SuccessfulTest): await f() + async def test_dont_settle_partial_mpp_trigger_with_invalid_cltv_htlc(self): + """Alice gets two htlcs as part of a mpp, one has a cltv too close to expiry and will get failed. + Test that the other htlc won't get settled if the mpp isn't complete anymore after failing the other htlc. + """ + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + async def pay(): + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) + w2.features |= LnFeatures.BASIC_MPP_OPT + lnaddr1, _pay_req = self.prepare_invoice(w2, amount_msat=10_000, min_final_cltv_delta=144) + self.assertTrue(lnaddr1.get_features().supports(LnFeatures.BASIC_MPP_OPT)) + route = (await w1.create_routes_from_invoice(amount_msat=10_000, decoded_invoice=lnaddr1))[0][0].route + + # now p1 sends two htlcs, one is valid (1 msat), one is invalid (9_999 msat) + p1.pay( + route=route, + chan=alice_channel, + amount_msat=1, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + # this htlc is valid and will get accepted, but it shouldn't get settled + min_final_cltv_delta=400, + payment_secret=lnaddr1.payment_secret, + ) + await asyncio.sleep(0.1) + assert w1.get_preimage(lnaddr1.paymenthash) is None + p1.pay( + route=route, + chan=alice_channel, + amount_msat=9_999, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + # this htlc will get failed directly as the cltv is too close to expiry (< 144) + min_final_cltv_delta=1, + payment_secret=lnaddr1.payment_secret, + ) + + while nhtlc_success + nhtlc_failed < 2: + await htlc_resolved.wait() + # both htlcs of the mpp set should get failed and w2 shouldn't release the preimage + self.assertEqual(0, nhtlc_success, f"{nhtlc_success=} | {nhtlc_failed=}") + self.assertEqual(2, nhtlc_failed, f"{nhtlc_success=} | {nhtlc_failed=}") + assert w1.get_preimage(lnaddr1.paymenthash) is None, "w1 shouldn't get the preimage" + raise SuccessfulTest() + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(pay()) + + htlc_resolved = asyncio.Event() + nhtlc_success = 0 + nhtlc_failed = 0 + async def on_htlc_fulfilled(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_success + nhtlc_success += 1 + async def on_htlc_failed(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_failed + nhtlc_failed += 1 + util.register_callback(on_htlc_fulfilled, ["htlc_fulfilled"]) + util.register_callback(on_htlc_failed, ["htlc_failed"]) + + try: + with self.assertRaises(SuccessfulTest): + await f() + finally: + util.unregister_callback(on_htlc_fulfilled) + util.unregister_callback(on_htlc_failed) + async def test_legacy_shutdown_low(self): await self._test_shutdown(alice_fee=100, bob_fee=150) From f35b3538415a1ce8a4f7c27c40443a1f7884b119 Mon Sep 17 00:00:00 2001 From: f321x Date: Fri, 26 Sep 2025 12:04:46 +0200 Subject: [PATCH 06/17] tests: lnpeer: test_mpp_cleanup_after_expiry 1. Alice sends two HTLCs to Bob, not reaching total_msat, and eventually they MPP_TIMEOUT 2. Bob fails both HTLCs 3. Alice then retries and sends HTLCs again to Bob, for the same RHASH, this time reaching total_msat, and the payment succeeds Test that the sets are properly cleaned up after MPP_TIMEOUT and the sender gets a second chance to pay the same invoice. --- tests/test_lnpeer.py | 102 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 2 deletions(-) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 26a166fcf1a9..97813f5f7ddd 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -40,8 +40,7 @@ from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED from electrum.lnonion import OnionFailureCode, OnionRoutingFailure -from electrum.lnutil import UpdateAddHtlc -from electrum.lnutil import LOCAL, REMOTE +from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER from electrum.interface import GracefulDisconnect from electrum.simple_config import SimpleConfig @@ -1414,6 +1413,105 @@ async def on_htlc_failed(*args): util.unregister_callback(on_htlc_fulfilled) util.unregister_callback(on_htlc_failed) + async def test_mpp_cleanup_after_expiry(self): + """ + 1. Alice sends two HTLCs to Bob, not reaching total_msat, and eventually they MPP_TIMEOUT + 2. Bob fails both HTLCs + 3. Alice then retries and sends HTLCs again to Bob, for the same RHASH, + this time reaching total_msat, and the payment succeeds + + Test that the sets are properly cleaned up after MPP_TIMEOUT + and the sender gets a second chance to pay the same invoice. + """ + async def run_test(test_trampoline: bool): + alice_channel, bob_channel = create_test_channels() + alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000) + + if test_trampoline: + await self._activate_trampoline(alice_wallet) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=bob_wallet.node_keypair.pubkey), + } + + async def _test(): + route = (await alice_wallet.create_routes_from_invoice(amount_msat=10_000, decoded_invoice=lnaddr1))[0][0].route + assert len(bob_wallet.received_mpp_htlcs) == 0 + # now alice sends two small htlcs, so the set stays incomplete + alice_peer.pay( # htlc 1 + route=route, + chan=alice_channel, + amount_msat=lnaddr1.get_amount_msat() // 4, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + min_final_cltv_delta=400, + payment_secret=lnaddr1.payment_secret, + ) + alice_peer.pay( # htlc 2 + route=route, + chan=alice_channel, + amount_msat=lnaddr1.get_amount_msat() // 4, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + min_final_cltv_delta=400, + payment_secret=lnaddr1.payment_secret, + ) + await asyncio.sleep(bob_wallet.MPP_EXPIRY // 2) # give bob time to receive the htlc + bob_payment_key = bob_wallet._get_payment_key(lnaddr1.paymenthash).hex() + assert bob_wallet.received_mpp_htlcs[bob_payment_key].resolution == RecvMPPResolution.WAITING + assert len(bob_wallet.received_mpp_htlcs[bob_payment_key].htlcs) == 2 + # now wait until bob expires the mpp (set) + await asyncio.wait_for(alice_htlc_resolved.wait(), bob_wallet.MPP_EXPIRY * 3) # this can take some time, esp. on CI + # check that bob failed the htlc + assert nhtlc_success == 0 and nhtlc_failed == 2 + # check that bob deleted the mpp set as it should be expired and resolved now + assert bob_payment_key not in bob_wallet.received_mpp_htlcs + alice_wallet._paysessions.clear() + assert alice_wallet.get_preimage(lnaddr1.paymenthash) is None # bob didn't preimage + # now try to pay again, this time the full amount + result, log = await alice_wallet.pay_invoice(pay_req1) + assert result is True + assert alice_wallet.get_preimage(lnaddr1.paymenthash) is not None # bob revealed preimage + assert len(bob_wallet.received_mpp_htlcs) == 0 # bob should also clean up a successful set + raise SuccessfulTest() + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(alice_peer._message_loop()) + await group.spawn(alice_peer.htlc_switch()) + await group.spawn(bob_peer._message_loop()) + await group.spawn(bob_peer.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(_test()) + + alice_htlc_resolved = asyncio.Event() + nhtlc_success = 0 + nhtlc_failed = 0 + async def on_sender_htlc_fulfilled(*args): + alice_htlc_resolved.set() + alice_htlc_resolved.clear() + nonlocal nhtlc_success + nhtlc_success += 1 + async def on_sender_htlc_failed(*args): + alice_htlc_resolved.set() + alice_htlc_resolved.clear() + nonlocal nhtlc_failed + nhtlc_failed += 1 + util.register_callback(on_sender_htlc_fulfilled, ["htlc_fulfilled"]) + util.register_callback(on_sender_htlc_failed, ["htlc_failed"]) + + try: + with self.assertRaises(SuccessfulTest): + await f() + finally: + util.unregister_callback(on_sender_htlc_fulfilled) + util.unregister_callback(on_sender_htlc_failed) + + for use_trampoline in [True, False]: + self.logger.debug(f"test_mpp_cleanup_after_expiry: {use_trampoline=}") + await run_test(use_trampoline) + async def test_legacy_shutdown_low(self): await self._test_shutdown(alice_fee=100, bob_fee=150) From 447d91d7b69b80d63cbdec1bcaf9854c63fa342c Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 6 Oct 2025 17:15:36 +0200 Subject: [PATCH 07/17] tests: lnpeer: test_trampoline_mpp_consolidation_forwarding_amount Add sanity check that bob is not forwarding more sats to carol if than he receives from alice. (he only forwards once and doesn't try to forward multiple times). This should get caught by asserts in lnworker/lnpeer, nevertheless it seems to make sense to just add this test to prevent regressions of this kind. --- tests/test_lnpeer.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 97813f5f7ddd..1ea9ae0ed00b 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -2394,6 +2394,25 @@ async def test_trampoline_mpp_consolidation(self): graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) await self._run_trampoline_payment(graph) + async def test_trampoline_mpp_consolidation_forwarding_amount(self): + """sanity check that bob is forwarding less than he is receiving""" + # alice->bob->carol->dave + graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) + # bump alices trampoline fee level so the first payment succeeds and the htlc sums can be compared usefully below. + alice = graph.workers['alice'] + alice.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 6 + with self.assertRaises(PaymentDone): + await self._run_trampoline_payment(graph, attempts=1) + + # assert bob hasn't forwarded more than he received + bob_alice_channel = graph.channels[('bob', 'alice')] + htlcs_bob_received_from_alice = bob_alice_channel.hm.all_htlcs_ever() + bob_carol_channel = graph.channels[('bob', 'carol')] + htlcs_bob_sent_to_carol = bob_carol_channel.hm.all_htlcs_ever() + sum_bob_received = sum(htlc.amount_msat for (direction, htlc) in htlcs_bob_received_from_alice) + sum_bob_sent = sum(htlc.amount_msat for (direction, htlc) in htlcs_bob_sent_to_carol) + assert sum_bob_sent < sum_bob_received, f"{sum_bob_sent=} > {sum_bob_received=}" + async def test_trampoline_mpp_consolidation_with_hold_invoice(self): with self.assertRaises(PaymentDone): graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) From bb828097b3389b13871990aae3ed883e8f330b66 Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 7 Oct 2025 12:57:33 +0200 Subject: [PATCH 08/17] tests: test_lnpeer: test compare trampoline onions Adds test_forwarder_fails_for_inconsistent_trampoline_onions which checks that a forwarder compares the trampoline onions of a mpp set and fails the set if the onions are not similar. In the test alice sends a mpp through bob with 2 htlcs, in one trampoline onion amt_to_forward is off by 1 msat so bob fails the htlc set instead of initiating the trampoline forwarding. --- tests/test_lnpeer.py | 76 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 1ea9ae0ed00b..9d9ddc699fa3 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import shutil import copy import tempfile @@ -11,6 +12,7 @@ from concurrent import futures from unittest import mock from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence +from types import MappingProxyType import time from aiorpcx import timeout_after, TaskTimeout @@ -39,7 +41,7 @@ from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED -from electrum.lnonion import OnionFailureCode, OnionRoutingFailure +from electrum.lnonion import OnionFailureCode, OnionRoutingFailure, OnionHopsDataSingle, OnionPacket from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER from electrum.interface import GracefulDisconnect @@ -2481,6 +2483,78 @@ async def test_multi_trampoline_payment(self): attempts=30, # the default used in LNWallet.pay_invoice() ) + async def test_forwarder_fails_for_inconsistent_trampoline_onions(self): + """ + verify that the receiver of a trampoline forwarding fails the mpp set + if the trampoline onions are not similar + In this test alice tries to forward through bob, however in one trampoline onion she sends + amt_to_forward is off by one msat. Bob should compare the trampoline onions and fail the set. + """ + + # store a modified trampoline onion to be injected into lnworker.new_onion_packet later when sending the htlcs + modified_trampoline_onion = None + def modified_new_onion_packet_trampoline(payment_path_pubkeys, session_key, hops_data: List[OnionHopsDataSingle], **kwargs): + nonlocal modified_trampoline_onion + assert modified_trampoline_onion is None, "this mock should get called only once" + modified_hops_data = copy.copy(hops_data) + # first payload (i[0]) is for bob who is supposed to forward the trampoline payment, in this + # test he should fail the incoming htlcs as their trampolines are not similar + new_payload = dict(modified_hops_data[0].payload) + amt_to_forward = dict(new_payload['amt_to_forward']) + amt_to_forward['amt_to_forward'] -= 1 + new_payload['amt_to_forward'] = amt_to_forward + modified_hops_data[0] = dataclasses.replace(modified_hops_data[0], payload=new_payload) + self.logger.debug(f"{modified_hops_data=}\nsent_{hops_data=}") + modified_trampoline_onion = electrum.lnonion.new_onion_packet( + payment_path_pubkeys, + session_key, + modified_hops_data, + **kwargs + ) + # return the unmodified onion + return electrum.lnonion.new_onion_packet( + payment_path_pubkeys, + session_key, + hops_data, + **kwargs + ) + + # this gets called in lnworker per sent htlc, for one sent htlc we inject the modified trampoline + # onion created before in the mock above + def modified_new_onion_packet_lnworker(payment_path_pubkeys, session_key, hops_data: List[OnionHopsDataSingle], **kwargs): + nonlocal modified_trampoline_onion + hops_data = copy.copy(hops_data) + if modified_trampoline_onion: + assert isinstance(modified_trampoline_onion, OnionPacket) + assert len(hops_data) == 1 + new_payload = dict(hops_data[0].payload) + new_payload['trampoline_onion_packet'] = { + "version": modified_trampoline_onion.version, + "public_key": modified_trampoline_onion.public_key, + "hops_data": modified_trampoline_onion.hops_data, + "hmac": modified_trampoline_onion.hmac, + } + hops_data[0] = dataclasses.replace(hops_data[0], payload=MappingProxyType(new_payload)) + modified_trampoline_onion = None + return electrum.lnonion.new_onion_packet( + payment_path_pubkeys, + session_key, + hops_data, + **kwargs + ) + + graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) + alice = graph.workers['alice'] + alice.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 6 # set high so the first attempt would succeed + with self.assertRaises(PaymentFailure): + with mock.patch('electrum.trampoline.new_onion_packet', side_effect=modified_new_onion_packet_trampoline), \ + mock.patch('electrum.lnworker.new_onion_packet', side_effect=modified_new_onion_packet_lnworker): + await self._run_trampoline_payment(graph, attempts=1) + bob_alice_channel = graph.channels[('bob', 'alice')] + bob_hm = bob_alice_channel.hm + assert len(bob_hm.all_htlcs_ever()) == 2 + assert all(bob_hm.was_htlc_failed(htlc_id=htlc.htlc_id, htlc_proposer=HTLCOwner.REMOTE) for (_, htlc) in bob_hm.all_htlcs_ever()) + class TestPeerDirectAnchors(TestPeerDirect): TEST_ANCHOR_CHANNELS = True From f56b13b6100461ba07bae42cbdfbe98e03b055e8 Mon Sep 17 00:00:00 2001 From: f321x Date: Wed, 8 Oct 2025 12:59:48 +0200 Subject: [PATCH 09/17] tests: test_lnpeer: test_hold_invoice_set_doesnt_get_exp Add test `test_hold_invoice_set_doesnt_get_expired` to test_lnpeer to ensure a mpp set on which a hold invoice callback doesn't get expired automatically if the cltv_abs falls below MIN_FINAL_CLTV_DELTA_ACCEPTED as these sets should only get failed if the htlcs are safe to fail by the target of the hold invoice callback (e.g. swap got refunded successfully). --- tests/test_lnpeer.py | 85 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 3 deletions(-) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 9d9ddc699fa3..db38087ba900 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -104,11 +104,13 @@ def populate_fee_estimates(self): class MockBlockchain: - - def height(self): + def __init__(self): # Let's return a non-zero, realistic height. # 0 might hide relative vs abs locktime confusion bugs. - return 600_000 + self._height = 600_000 + + def height(self): + return self._height def is_tip_stale(self): return False @@ -1807,6 +1809,83 @@ async def f(): await f() self.assertTrue(isinstance(failing_task.exception().__cause__, lnmsg.UnexpectedEndOfStream)) + async def test_hold_invoice_set_doesnt_get_expired(self): + """ + Alice pays a hold invoice from Bob, Bob doesn't release preimage. Verify that Bob doesn't + expire the htlc set MIN_FINAL_CLTV_DELTA_ACCEPTED blocks before htlc.cltv_abs (as we would do with normal htlc sets). + The htlc set should only get failed if the user of the hold invoice callback explicitly removes the + callback (e.g. after refunding and failing a swap), otherwise it should get timed out onchain (force-close). + + This only tests hold invoice logic for hold invoices registered with `LNWallet.register_hold_invoice()`, + as used e.g. by submarine swaps. It doesn't cover the hold invoices created by the hold invoice CLI + which behave differently and use the persisted `LNWallet.dont_expire_htlcs` dict. + """ + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + alice_p, bob_p, alice_w, bob_w, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + lnaddr, pay_req = self.prepare_invoice(bob_w, min_final_cltv_delta=150) + del bob_w._preimages[pay_req.rhash] # del preimage so bob doesn't settle + payment_key = bob_w._get_payment_key(lnaddr.paymenthash).hex() + + cb_got_called = False + async def cb(_payment_hash): + self.logger.debug(f"hold invoice callback called. {bob_w.network.get_local_height()=}") + nonlocal cb_got_called + cb_got_called = True + + bob_w.register_hold_invoice(lnaddr.paymenthash, cb) + + async def check_mpp_state(): + async def wait_for_resolution(): + while True: + await asyncio.sleep(0.1) + if payment_key not in bob_w.received_mpp_htlcs: + continue + if not bob_w.received_mpp_htlcs[payment_key].resolution == RecvMPPResolution.SETTLING: + continue + return + await util.wait_for2(wait_for_resolution(), timeout=2) + assert cb_got_called + mpp_set = bob_w.received_mpp_htlcs[payment_key] + self.assertEqual(mpp_set.resolution, RecvMPPResolution.SETTLING, msg=mpp_set.resolution) + self.assertEqual(len(mpp_set.htlcs), 1, f"should get only one htlc: {mpp_set.htlcs=}") + left_to_expiry = next(iter(mpp_set.htlcs)).htlc.cltv_abs - bob_w.network.get_local_height() + # now mine up to one block after the expiry + bob_w.network._blockchain._height += left_to_expiry + 1 + await asyncio.sleep(0.2) + # bob still has the mpp set and it is not failed + # it should only get removed once the channel is redeemed + self.assertIn(bob_w.received_mpp_htlcs[payment_key].resolution, (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING)) + # now also check that the mpp set will get set failed if the hold invoice + # is being explicitly unregistered, and we don't have a preimage to settle it + bob_w.unregister_hold_invoice(lnaddr.paymenthash) + self.assertEqual(bob_w.received_mpp_htlcs[payment_key].resolution, RecvMPPResolution.FAILED) + raise SuccessfulTest() + + if test_trampoline: + await self._activate_trampoline(alice_w) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=bob_w.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(alice_p._message_loop()) + await group.spawn(alice_p.htlc_switch()) + await group.spawn(bob_p._message_loop()) + await group.spawn(bob_p.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(alice_w.pay_invoice(pay_req)) + await group.spawn(check_mpp_state()) + + with self.assertRaises(SuccessfulTest): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + class TestPeerForwarding(TestPeer): From 042557da9ba269361e39151f8998366d781ca133 Mon Sep 17 00:00:00 2001 From: f321x Date: Thu, 9 Oct 2025 15:44:52 +0200 Subject: [PATCH 10/17] tests: test_lnpeer: test_htlc_switch_iteration_benchmark Benchmark how long a call to _run_htlc_switch_iteration takes with 10 trampoline mpp sets of 1 htlc each. --- tests/test_lnchannel.py | 16 ++++++---- tests/test_lnpeer.py | 68 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 2ba4c4092f7c..6299042996b1 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -50,7 +50,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, local_amount, remote_amount, privkeys, other_pubkeys, seed, cur, nex, other_node_id, l_dust, r_dust, l_csv, - r_csv, anchor_outputs, local_max_inflight, remote_max_inflight): + r_csv, anchor_outputs, local_max_inflight, remote_max_inflight, + max_accepted_htlcs): #assert local_amount > 0 #assert remote_amount > 0 @@ -71,7 +72,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, to_self_delay=r_csv, dust_limit_sat=r_dust, max_htlc_value_in_flight_msat=remote_max_inflight, - max_accepted_htlcs=5, + max_accepted_htlcs=max_accepted_htlcs, initial_msat=remote_amount, reserve_sat=0, htlc_minimum_msat=1, @@ -91,7 +92,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, to_self_delay=l_csv, dust_limit_sat=l_dust, max_htlc_value_in_flight_msat=local_max_inflight, - max_accepted_htlcs=5, + max_accepted_htlcs=max_accepted_htlcs, initial_msat=local_amount, reserve_sat=0, per_commitment_secret_seed=seed, @@ -133,7 +134,8 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None, alice_name="alice", bob_name="bob", alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None, anchor_outputs=False, - local_max_inflight=None, remote_max_inflight=None): + local_max_inflight=None, remote_max_inflight=None, + max_accepted_htlcs=5): if random_seed is None: # needed for deterministic randomness random_seed = os.urandom(32) random_gen = PRNG(random_seed) @@ -168,7 +170,8 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, other_node_id=bob_pubkey, l_dust=200, r_dust=1300, l_csv=5, r_csv=4, anchor_outputs=anchor_outputs, - local_max_inflight=local_max_inflight, remote_max_inflight=remote_max_inflight + local_max_inflight=local_max_inflight, remote_max_inflight=remote_max_inflight, + max_accepted_htlcs=max_accepted_htlcs, ), name=f"{alice_name}->{bob_name}", initial_feerate=feerate), @@ -178,7 +181,8 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, other_node_id=alice_pubkey, l_dust=1300, r_dust=200, l_csv=4, r_csv=5, anchor_outputs=anchor_outputs, - local_max_inflight=remote_max_inflight, remote_max_inflight=local_max_inflight + local_max_inflight=remote_max_inflight, remote_max_inflight=local_max_inflight, + max_accepted_htlcs=max_accepted_htlcs, ), name=f"{bob_name}->{alice_name}", initial_feerate=feerate) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index db38087ba900..204ebd6f4935 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -14,6 +14,7 @@ from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence from types import MappingProxyType import time +import statistics from aiorpcx import timeout_after, TaskTimeout from electrum_ecc import ECPrivkey @@ -1886,6 +1887,73 @@ async def f(): for _test_trampoline in [False, True]: await run_test(_test_trampoline) + async def test_htlc_switch_iteration_benchmark(self): + """Test how long a call to _run_htlc_switch_iteration takes with 10 trampoline + mpp sets of 1 htlc each. Raise if it takes longer than 20ms (median). + To create flamegraph with py-spy raise NUM_ITERATIONS to 1000 (for more samples) then run: + $ py-spy record -o flamegraph.svg --subprocesses -- python -m pytest tests/test_lnpeer.py::TestPeerDirect::test_htlc_switch_iteration_benchmark + """ + NUM_ITERATIONS = 25 + alice_channel, bob_channel = create_test_channels(max_accepted_htlcs=20) + alice_p, bob_p, alice_w, bob_w, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + await self._activate_trampoline(alice_w) + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=bob_w.node_keypair.pubkey), + } + + # create 10 invoices (10 pending htlc sets with 1 htlc each) + invoices = [] # type: list[tuple[LnAddr, Invoice]] + for i in range(10): + lnaddr, pay_req = self.prepare_invoice(bob_w) + # prevent bob from settling so that htlc switch will have to iterate through all pending htlcs + bob_w.dont_settle_htlcs[pay_req.rhash] = None + invoices.append((lnaddr, pay_req)) + self.assertEqual(len(invoices), 10, msg=len(invoices)) + + iterations = [] + do_benchmark = False + _run_bob_htlc_switch_iteration = bob_p._run_htlc_switch_iteration + def timed_htlc_switch_iteration(): + start = time.perf_counter() + _run_bob_htlc_switch_iteration() + duration = time.perf_counter() - start + if do_benchmark: + iterations.append(duration) + bob_p._run_htlc_switch_iteration = timed_htlc_switch_iteration + + async def benchmark_htlc_switch_iterations(): + waited = 0 + while not len(bob_w.received_mpp_htlcs) == 10 : + waited += 0.1 + await asyncio.sleep(0.1) + if waited > 2: + raise TimeoutError() + nonlocal do_benchmark + do_benchmark = True + while len(iterations) < NUM_ITERATIONS: + await asyncio.sleep(0.05) + # average = sum(iterations) / len(iterations) + median_duration = statistics.median(iterations) + res = f"median duration per htlc switch iteration: {median_duration:.6f}s over {len(iterations)=}" + self.logger.info(res) + self.assertLess(median_duration, 0.02, msg=res) + raise SuccessfulTest() + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(alice_p._message_loop()) + await group.spawn(alice_p.htlc_switch()) + await group.spawn(bob_p._message_loop()) + await group.spawn(bob_p.htlc_switch()) + await asyncio.sleep(0.01) + for _lnaddr, req in invoices: + await group.spawn(alice_w.pay_invoice(req)) + await benchmark_htlc_switch_iterations() + + with self.assertRaises(SuccessfulTest): + await f() + class TestPeerForwarding(TestPeer): From 95729a08efc68a905f54ec7733f37eeb2af7025e Mon Sep 17 00:00:00 2001 From: f321x Date: Thu, 16 Oct 2025 11:50:31 +0200 Subject: [PATCH 11/17] lnpeer: report htlc_switch exceptions to crash reporter It seems useful to report exceptions happening in the htlc_switch to the crash reporter as it shouldn't raise exceptions in theory and this could help catch subtle bugs. --- electrum/lnpeer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 377dd375f226..f09ef97a224b 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2788,7 +2788,15 @@ async def htlc_switch(self): await group.spawn(self.downstream_htlc_resolved_event.wait()) self._htlc_switch_iterstart_event.set() self._htlc_switch_iterstart_event.clear() - self._run_htlc_switch_iteration() + try: + self._run_htlc_switch_iteration() + except Exception as e: + # this is code with many asserts and dense logic so it seems useful to allow the user + # report to exceptions that otherwise might go unnoticed for some time + reported_exc = type(e)("redacted") # text could contain onions, payment hashes etc. + reported_exc.__traceback__ = e.__traceback__ + util.send_exception_to_crash_reporter(reported_exc) + raise e @util.profiler(min_threshold=0.02) def _run_htlc_switch_iteration(self): From b1e58450bde1a557779a6b65ec793a7c1d45ba41 Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 3 Nov 2025 12:26:13 +0100 Subject: [PATCH 12/17] tests: test_lnpeer: add test_payment_bundle_with_hold_invoice Adds test_payment_bundle_with_hold_invoice to simulate the use of a payment bundle in which one invoice of the bundle needs to trigger a hold invoice callback (similar to submarine swaps). Also modifies the test helper _test_simple_payment() to compare the results of all payment attempts instead of just returning after the first (of multiple) payments raises its result causing the test to miss if all payments were successful or not. --- tests/test_lnpeer.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 204ebd6f4935..09cdbc24daaf 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -862,20 +862,24 @@ async def _test_simple_payment( """Alice pays Bob a single HTLC via direct channel.""" alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + results = {} async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) result, log = await w1.pay_invoice(pay_req) if result is True: self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) - raise PaymentDone() + results[lnaddr] = PaymentDone() else: - raise PaymentFailure() + results[lnaddr] = PaymentFailure() lnaddr, pay_req = self.prepare_invoice(w2) + to_pay = [(lnaddr, pay_req)] self.prepare_recipient(w2, lnaddr.paymenthash, test_hold_invoice, test_failure) if test_bundle: lnaddr2, pay_req2 = self.prepare_invoice(w2) w2.bundle_payments([lnaddr.paymenthash, lnaddr2.paymenthash]) + if not test_bundle_timeout: + to_pay.append((lnaddr2, pay_req2)) if test_trampoline: await self._activate_trampoline(w1) @@ -893,9 +897,16 @@ async def f(): await asyncio.sleep(0.01) invoice_features = lnaddr.get_features() self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) - await group.spawn(pay(lnaddr, pay_req)) - if test_bundle and not test_bundle_timeout: - await group.spawn(pay(lnaddr2, pay_req2)) + for lnaddr_to_pay, pay_req_to_pay in to_pay: + await group.spawn(pay(lnaddr_to_pay, pay_req_to_pay)) + elapsed = 0 + while len(results) < len(to_pay) and elapsed < 4: + await asyncio.sleep(0.05) # wait for all payments to finish/fail (or timeout) + elapsed += 0.05 + self.assertEqual(len(results), len(to_pay), msg="timeout") + # all payment results should be similar + self.assertEqual(len(set(type(res) for res in results.values())), 1, msg=results) + raise list(results.values())[0] await f() @@ -919,6 +930,11 @@ async def test_payment_bundle_timeout(self): with self.assertRaises(PaymentFailure): await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True) + async def test_payment_bundle_with_hold_invoice(self): + for test_trampoline in [False, True]: + with self.assertRaises(PaymentDone): + await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_hold_invoice=True) + async def test_simple_payment_success_with_hold_invoice(self): for test_trampoline in [False, True]: with self.assertRaises(PaymentDone): From abc469c8462749d54a99dea540ac002b17f2f064 Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 24 Nov 2025 13:02:42 +0100 Subject: [PATCH 13/17] lnworker: split dont_settle_htlcs Splits `LNWallet.dont_settle_htlcs` into `LNWallet.dont_settle_htlcs` and `LNWallet.dont_expire_htlcs`. Registering a payment hash in dont_settle_htlcs will prevent it from getting fulfilled if we have the preimage stored. The preimage will not be released before the the payment hash gets removed from dont_settle_htlcs. Htlcs can still get expired as usual or failed if no preimage is known. This is only used by Just-in-time channel openings. Registering a payment hash in dont_expire_htlcs allows to overwrite the minimum final cltv delta value after which htlcs would usually get expired. This allows to delay expiry of htlcs or, if the value in the dont_settle_htlcs dict is None, completely prevent expiry and let the htlc get expired onchain. Splitting this up in two different dicts makes it more explicit and easier to reason about what they are actually doing. Please enter the commit message for your changes. Lines starting --- electrum/commands.py | 24 +++++++++++++----------- electrum/lnpeer.py | 22 ++++++++++++++++------ electrum/lnworker.py | 17 ++++++++++++++++- tests/test_commands.py | 5 ++--- tests/test_lnpeer.py | 1 + 5 files changed, 48 insertions(+), 21 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index f249f55f9396..889243d80939 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1389,7 +1389,9 @@ async def add_hold_invoice( ) -> dict: """ Create a lightning hold invoice for the given payment hash. Hold invoices have to get settled manually later. - HTLCs will get failed automatically if block_height + 144 > htlc.cltv_abs. + HTLCs will get failed automatically if block_height + 144 > htlc.cltv_abs, if the intention is to + settle them as late as possible a safety margin of some blocks should be used to prevent them + from getting failed accidentally. arg:str:payment_hash:Hex encoded payment hash to be used for the invoice arg:decimal:amount:Optional requested amount (in btc) @@ -1399,7 +1401,7 @@ async def add_hold_invoice( """ assert len(payment_hash) == 64, f"Invalid payment hash length: {len(payment_hash)} != 64" assert payment_hash not in wallet.lnworker.payment_info, "Payment hash already used!" - assert payment_hash not in wallet.lnworker.dont_settle_htlcs, "Payment hash already used!" + assert payment_hash not in wallet.lnworker.dont_expire_htlcs, "Payment hash already used!" assert wallet.lnworker.get_preimage(bfh(payment_hash)) is None, "Already got a preimage for this payment hash!" assert MIN_FINAL_CLTV_DELTA_ACCEPTED < min_final_cltv_expiry_delta < 576, "Use a sane min_final_cltv_expiry_delta value" amount = amount if amount and satoshis(amount) > 0 else None # make amount either >0 or None @@ -1419,7 +1421,9 @@ async def add_hold_invoice( message=memo, fallback_address=None ) - wallet.lnworker.dont_settle_htlcs[payment_hash] = None + # this prevents incoming htlcs from getting expired while the preimage isn't set. + # If their blocks to expiry fall below MIN_FINAL_CLTV_DELTA_ACCEPTED they will get failed. + wallet.lnworker.dont_expire_htlcs[payment_hash] = MIN_FINAL_CLTV_DELTA_ACCEPTED wallet.set_label(payment_hash, memo) result = { "invoice": invoice @@ -1439,12 +1443,11 @@ async def settle_hold_invoice(self, preimage: str, wallet: Abstract_Wallet = Non assert payment_hash not in wallet.lnworker._preimages, f"Invoice {payment_hash=} already settled" assert payment_hash in wallet.lnworker.payment_info, \ f"Couldn't find lightning invoice for {payment_hash=}" - assert payment_hash in wallet.lnworker.dont_settle_htlcs, f"Invoice {payment_hash=} not a hold invoice?" + assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"Invoice {payment_hash=} not a hold invoice?" assert wallet.lnworker.is_complete_mpp(bfh(payment_hash)), \ f"MPP incomplete, cannot settle hold invoice {payment_hash} yet" info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) assert (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) >= (info.amount_msat or 0) - del wallet.lnworker.dont_settle_htlcs[payment_hash] wallet.lnworker.save_preimage(bfh(payment_hash), bfh(preimage)) util.trigger_callback('wallet_updated', wallet) result = { @@ -1462,15 +1465,15 @@ async def cancel_hold_invoice(self, payment_hash: str, wallet: Abstract_Wallet = assert payment_hash in wallet.lnworker.payment_info, \ f"Couldn't find lightning invoice for payment hash {payment_hash}" assert payment_hash not in wallet.lnworker._preimages, "Cannot cancel anymore, preimage already given." - assert payment_hash in wallet.lnworker.dont_settle_htlcs, f"{payment_hash=} not a hold invoice?" + assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"{payment_hash=} not a hold invoice?" # set to PR_UNPAID so it can get deleted wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID) wallet.lnworker.delete_payment_info(payment_hash) wallet.set_label(payment_hash, None) + del wallet.lnworker.dont_expire_htlcs[payment_hash] while wallet.lnworker.is_complete_mpp(bfh(payment_hash)): - # wait until the htlcs got failed so the payment won't get settled accidentally in a race + # block until the htlcs got failed await asyncio.sleep(0.1) - del wallet.lnworker.dont_settle_htlcs[payment_hash] result = { "cancelled": payment_hash } @@ -1503,15 +1506,14 @@ async def check_hold_invoice(self, payment_hash: str, wallet: Abstract_Wallet = elif not is_complete_mpp and not wallet.lnworker.get_preimage_hex(payment_hash): # is_complete_mpp is False for settled payments result["status"] = "unpaid" - elif is_complete_mpp and payment_hash in wallet.lnworker.dont_settle_htlcs: + elif is_complete_mpp and payment_hash in wallet.lnworker.dont_expire_htlcs: result["status"] = "paid" payment_key: str = wallet.lnworker._get_payment_key(bfh(payment_hash)).hex() htlc_status = wallet.lnworker.received_mpp_htlcs[payment_key] result["closest_htlc_expiry_height"] = min( mpp_htlc.htlc.cltv_abs for mpp_htlc in htlc_status.htlcs ) - elif wallet.lnworker.get_preimage_hex(payment_hash) is not None \ - and payment_hash not in wallet.lnworker.dont_settle_htlcs: + elif wallet.lnworker.get_preimage_hex(payment_hash) is not None: result["status"] = "settled" plist = wallet.lnworker.get_payments(status='settled')[bfh(payment_hash)] _dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(info, plist) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index f09ef97a224b..e263da1ca855 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2185,6 +2185,7 @@ def _check_unfulfilled_htlc( # the htlc set representing the whole payment (payment key derived from trampoline/invoice secret). payment_key = (payment_hash + (outer_onion_payment_secret or payment_secret_from_onion)).hex() + # for safety, still enforce MIN_FINAL_CLTV_DELTA here even if payment_hash is in dont_expire_htlcs if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: # this check should be done here for new htlcs and ongoing on pending sets. # Here it is done so that invalid received htlcs will never get added to a set, @@ -2261,6 +2262,8 @@ def _fulfill_htlc_set(self, payment_key: str, preimage: bytes): # get payment hash of any htlc in the set (they are all the same) payment_hash = htlc_set.get_payment_hash() assert payment_hash is not None, htlc_set + assert payment_hash not in self.lnworker.dont_settle_htlcs + self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) # htlcs wont get expired anymore for mpp_htlc in list(htlc_set.htlcs): htlc_id = mpp_htlc.htlc.htlc_id chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) @@ -2300,6 +2303,10 @@ def _fail_htlc_set( raw_error, error_code, error_data = error_tuple local_height = self.network.blockchain().height() + payment_hash = htlc_set.get_payment_hash() + assert payment_hash is not None, "Empty htlc set?" + self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) + self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed for mpp_htlc in list(htlc_set.htlcs): chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) htlc_id = mpp_htlc.htlc.htlc_id @@ -2317,7 +2324,7 @@ def _fail_htlc_set( onion_packet = self._parse_onion_packet(mpp_htlc.unprocessed_onion) processed_onion_packet = self._process_incoming_onion_packet( onion_packet, - payment_hash=mpp_htlc.htlc.payment_hash, + payment_hash=payment_hash, is_trampoline=False, ) if raw_error: @@ -2331,7 +2338,7 @@ def _fail_htlc_set( if processed_onion_packet.trampoline_onion_packet: processed_trampoline_onion_packet = self._process_incoming_onion_packet( processed_onion_packet.trampoline_onion_packet, - payment_hash=mpp_htlc.htlc.payment_hash, + payment_hash=payment_hash, is_trampoline=True, ) amount_to_forward = processed_trampoline_onion_packet.amt_to_forward @@ -3048,7 +3055,8 @@ def _check_unfulfilled_htlc_set( # check for expiry over time and potentially fail the whole set if any # htlc's cltv becomes too close blocks_to_expiry = max(0, closest_cltv_abs - local_height) - if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: + accepted_expiry_delta = self.lnworker.dont_expire_htlcs.get(payment_hash.hex(), MIN_FINAL_CLTV_DELTA_ACCEPTED) + if accepted_expiry_delta is not None and blocks_to_expiry < accepted_expiry_delta: _log_fail_reason(f"htlc.cltv_abs is unreasonably close") return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None @@ -3119,11 +3127,13 @@ def _check_unfulfilled_htlc_set( return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None return None, None, None - if payment_hash.hex() in self.lnworker.dont_settle_htlcs: - # used by hold invoice cli to prevent the htlcs from getting fulfilled automatically + preimage = self.lnworker.get_preimage(payment_hash) + settling_blocked = preimage is not None and payment_hash.hex() in self.lnworker.dont_settle_htlcs + waiting_for_preimage = preimage is None and payment_hash.hex() in self.lnworker.dont_expire_htlcs + if settling_blocked or waiting_for_preimage: + # used by hold invoice cli and JIT channels to prevent the htlcs from getting fulfilled automatically return None, None, None - preimage = self.lnworker.get_preimage(payment_hash) hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) if not preimage and not hold_invoice_callback: _log_fail_reason(f"cannot settle, no preimage or callback found for {payment_hash.hex()=}") diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d4c5bc8cb2a4..347bcc084f24 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -925,7 +925,22 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv): self.active_forwardings = self.db.get_dict('active_forwardings') # type: Dict[str, List[str]] # Dict: payment_key -> list of htlc_keys self.forwarding_failures = self.db.get_dict('forwarding_failures') # type: Dict[str, Tuple[str, str]] # Dict: payment_key -> (error_bytes, error_message) self.downstream_to_upstream_htlc = {} # type: Dict[str, str] # Dict: htlc_key -> htlc_key (not persisted) - self.dont_settle_htlcs = self.db.get_dict('dont_settle_htlcs') # type: Dict[str, None] # payment_hashes of htlcs that we should not settle back yet even if we have the preimage + + # k: payment_hashes of htlcs that we should not expire even if we don't know the preimage + # v: If `None` the htlcs won't get expired and potentially get timed out in a force close. + # Note: it might not be safe to release the preimage shortly before expiry as this would allow the + # remote node to ignore our fulfill_htlc, wait until expiry and try to time out the htlc onchain + # in a fee race against us and then use our released preimage to fulfill upstream. + # v: If `int`: Overwrites `MIN_FINAL_CLTV_DELTA_ACCEPTED` in htlc switch and allows to set custom + # expiration delta. The htlcs will get expired if their blocks left to expiry are + # below the specified expiration delta. + # htlcs will get settled as soon as the preimage becomes available + self.dont_expire_htlcs = self.db.get_dict('dont_expire_htlcs') # type: Dict[str, Optional[int]] + + # k: payment_hash of payments for which we don't want to release the preimage, no matter + # how close to expiry. Doesn't prevent htlcs from getting expired or failed if there is no + # preimage available. Might be used in combination with dont_expire_htlcs. + self.dont_settle_htlcs = self.db.get_dict('dont_settle_htlcs') # type: Dict[str, None] # payment_hash -> callback: self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]] diff --git a/tests/test_commands.py b/tests/test_commands.py index 455f7b87e758..3016263e1b31 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -510,7 +510,7 @@ async def test_hold_invoice_commands(self, mock_save_db): invoice = lndecode(invoice=result['invoice']) assert invoice.paymenthash.hex() == payment_hash assert payment_hash in wallet.lnworker.payment_info - assert payment_hash in wallet.lnworker.dont_settle_htlcs + assert payment_hash in wallet.lnworker.dont_expire_htlcs assert invoice.get_amount_sat() == 10000 assert invoice.get_description() == "test" assert wallet.get_label_for_rhash(rhash=invoice.paymenthash.hex()) == "test" @@ -521,7 +521,7 @@ async def test_hold_invoice_commands(self, mock_save_db): wallet=wallet, ) assert payment_hash not in wallet.lnworker.payment_info - assert payment_hash not in wallet.lnworker.dont_settle_htlcs + assert payment_hash not in wallet.lnworker.dont_expire_htlcs assert wallet.get_label_for_rhash(rhash=invoice.paymenthash.hex()) == "" assert cancel_result['cancelled'] == payment_hash @@ -571,7 +571,6 @@ async def test_hold_invoice_commands(self, mock_save_db): ) assert settle_result['settled'] == payment_hash assert wallet.lnworker._preimages[payment_hash] == preimage.hex() - assert payment_hash not in wallet.lnworker.dont_settle_htlcs with (mock.patch.object( wallet.lnworker, 'get_payment_value', diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 09cdbc24daaf..e2597b85534b 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -218,6 +218,7 @@ def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_que self._preimages = {} self.stopping_soon = False self.downstream_to_upstream_htlc = {} + self.dont_expire_htlcs = {} self.dont_settle_htlcs = {} self.hold_invoice_callbacks = {} self._payment_bundles_pkey_to_canon = {} # type: Dict[bytes, bytes] From 4f2e1b65f0b4b5f6a6d9a4b09a52de8d47d82aaf Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 3 Nov 2025 17:31:39 +0100 Subject: [PATCH 14/17] tests: test_lnpeer: add test_dont_settle_htlcs Adds test for the dont_settle_htlcs functionality of lnworker used by Just-In-Time channels. --- tests/test_lnpeer.py | 70 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index e2597b85534b..dee627c19492 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1971,6 +1971,76 @@ async def f(): with self.assertRaises(SuccessfulTest): await f() + async def test_dont_settle_htlcs(self): + """ + Test that htlcs registered in LNWallet.dont_settle_htlcs don't get fulfilled if the preimage is available. + """ + async def run_test(test_trampoline, test_failure): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + preimage = os.urandom(32) + lnaddr, pay_req = self.prepare_invoice( + w2, + payment_preimage=preimage, + # use a higher min final cltv delta so we can mine some blocks later + min_final_cltv_delta=244, + ) + + # add payment_hash to dont_settle_htlcs so the htlcs are not getting settled + w2.dont_settle_htlcs[pay_req.rhash] = None + + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) + if result is True: + self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + return PaymentDone() + else: + self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) + return PaymentFailure() + + async def wait_for_htlcs(): + payment_key = w2._get_payment_key(lnaddr.paymenthash) + while payment_key.hex() not in w2.received_mpp_htlcs: + await asyncio.sleep(0.05) + w2.network.blockchain()._height += 25 # mine some blocks, shouldn't affect anything + if test_failure: + # delete preimage, this will fail htlcs even if registered in dont_settle_htlcs + del w2._preimages[pay_req.rhash] + return # pay() should fail now + await asyncio.sleep(0.25) # give w2 some time to do mistakes + self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE) + # remove the payment hash from dont_settle_htlcs so the htlcs can get fulfilled + del w2.dont_settle_htlcs[pay_req.rhash] + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + invoice_features = lnaddr.get_features() + self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) + pay_task = await group.spawn(pay(lnaddr, pay_req)) + await util.wait_for2(wait_for_htlcs(), timeout=2) + raise await pay_task + + await f() + + for test_trampoline in [False, True]: + for test_failure in [False, True]: + with self.assertRaises(PaymentFailure if test_failure else PaymentDone): + await run_test(test_trampoline, test_failure) + class TestPeerForwarding(TestPeer): From 1fd5458b0e93e06af53e5af8fac0b8e69245813b Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 24 Nov 2025 15:51:40 +0100 Subject: [PATCH 15/17] tests: lnpeer: test_dont_expire_htlcs Adds unittest to test the dont_expire_htlcs logic --- tests/test_lnpeer.py | 77 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index dee627c19492..acc5a5ac92b3 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -2041,6 +2041,83 @@ async def f(): with self.assertRaises(PaymentFailure if test_failure else PaymentDone): await run_test(test_trampoline, test_failure) + async def test_dont_expire_htlcs(self): + """ + Test that htlcs registered in LNWallet.dont_expire_htlcs don't get expired before the + specified expiry delta if their preimage isn't available. + Also test that htlcs registered in LNWallet.dont_expire_htlcs get settled right away if their + preimage is available. + """ + async def run_test(test_trampoline, test_expiry): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + preimage = os.urandom(32) + lnaddr, pay_req = self.prepare_invoice(w2, payment_preimage=preimage, min_final_cltv_delta=144) + + # delete preimage, this would fail the htlcs if payment_hash wasn't in dont_expire_htlcs + del w2._preimages[pay_req.rhash] + # add payment_hash to dont_expire_htlcs so the htlcs are not getting failed + w2.dont_expire_htlcs[pay_req.rhash] = None if not test_expiry else 20 + + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) + if result is True: + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + return PaymentDone() + else: + self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) + return PaymentFailure() + + async def wait_for_htlcs(): + payment_key = w2._get_payment_key(lnaddr.paymenthash) + while payment_key.hex() not in w2.received_mpp_htlcs: + await asyncio.sleep(0.05) + if not test_expiry: + # the htlcs should never get expired if the dont_expire_htlcs value is None + w2.network.blockchain()._height += 1000 + await asyncio.sleep(0.25) # give w2 some time to do mistakes + self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE) + if test_expiry: + # we set an expiry delta of 20 blocks before expiry, htlc expiry should be +144 current height + # so adding some blocks should get the htlcs failed + w2.network.blockchain()._height += 50 + await asyncio.sleep(0.1) + # the htlcs should not get failed yet as 144-50 > 20 + self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE) + w2.network.blockchain()._height += 75 + return # the htlcs should get failed and pay should return PaymentFailure + + # saving the preimage should let the htlcs get fulfilled + w2.save_preimage(lnaddr.paymenthash, preimage) + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + invoice_features = lnaddr.get_features() + self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) + pay_task = await group.spawn(pay(lnaddr, pay_req)) + await util.wait_for2(wait_for_htlcs(), timeout=3) + raise await pay_task + + await f() + + for test_trampoline in [False, True]: + for test_expiry in [False, True]: + with self.assertRaises(PaymentFailure if test_expiry else PaymentDone): + await run_test(test_trampoline, test_expiry ) + class TestPeerForwarding(TestPeer): From 16ed7e666c95045bf915a6b777c7eda75d05e792 Mon Sep 17 00:00:00 2001 From: f321x Date: Wed, 26 Nov 2025 13:42:20 +0100 Subject: [PATCH 16/17] lnpeer: use INVALID_ONION_VERSION for unparsable onions Use the `OnionFailureCode.INVALID_ONION_VERSION` (BADONION | PERM | 4) code when sending back `update_fail_malformed_htlc` as just sending a plain `BADONION` is not explicitly mentioned as correct in the spec. --- electrum/lnonion.py | 9 ++++++++- electrum/lnpeer.py | 5 ++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/electrum/lnonion.py b/electrum/lnonion.py index b57c0479211b..7760fd4459c0 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -577,7 +577,14 @@ def to_wire_msg(self, onion_packet: OnionPacket, privkey: bytes, local_height: i return error_bytes -class OnionParsingError(OnionRoutingFailure): pass +class OnionParsingError(OnionRoutingFailure): + """ + Onion parsing error will cause a htlc to get failed with update_fail_malformed_htlc. + Using INVALID_ONION_VERSION as there is no unspecific BADONION failure code defined in the spec + for the case we just cannot parse the onion. + """ + def __init__(self, data: bytes): + OnionRoutingFailure.__init__(self, code=OnionFailureCode.INVALID_ONION_VERSION, data=data) def construct_onion_error( diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index e263da1ca855..1e8da558b468 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -3229,7 +3229,6 @@ def _parse_onion_packet(self, onion_packet_hex: str) -> OnionPacket: except Exception as parsing_exc: self.logger.warning(f"unable to parse onion: {str(parsing_exc)}") onion_parsing_error = OnionParsingError( - code=OnionFailureCodeMetaFlag.BADONION, data=sha256(onion_packet_bytes or b''), ) raise onion_parsing_error @@ -3259,9 +3258,9 @@ def _process_incoming_onion_packet( raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_HMAC, data=onion_hash) except Exception as e: self.logger.warning(f"error processing onion packet: {e!r}") - raise OnionParsingError(code=OnionFailureCodeMetaFlag.BADONION, data=onion_hash) + raise OnionParsingError(data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_AS_MALFORMED: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=onion_hash) + raise OnionParsingError(data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE: raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') return processed_onion From 59586d6f94b464f738ac9c1b57786b77142ff552 Mon Sep 17 00:00:00 2001 From: f321x Date: Wed, 26 Nov 2025 13:45:20 +0100 Subject: [PATCH 17/17] tests: lnpeer: add test_payment_with_malformed_onion Adds a simple forwarding test where the receiver fails a malformed onion with `update_fail_malformed_htlc`. --- tests/test_lnpeer.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index acc5a5ac92b3..3fe6a3729d2a 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -2866,6 +2866,45 @@ def modified_new_onion_packet_lnworker(payment_path_pubkeys, session_key, hops_d assert len(bob_hm.all_htlcs_ever()) == 2 assert all(bob_hm.was_htlc_failed(htlc_id=htlc.htlc_id, htlc_proposer=HTLCOwner.REMOTE) for (_, htlc) in bob_hm.all_htlcs_ever()) + async def test_payment_with_malformed_onion(self): + """ + Alice -> Bob -> Carol. Carol fails htlc with update_fail_malformed_htlc because she is unable + to parse the onion Alice sent to her. + """ + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['line_graph']) + peers = graph.peers.values() + + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash)) + result, log = await graph.workers['alice'].pay_invoice(pay_req) + self.assertEqual(OnionFailureCode.INVALID_ONION_VERSION, log[0].failure_msg.code) + self.assertFalse(result, msg=log) + raise PaymentFailure() + + # this will make carol send update_fail_malformed_htlc + graph.workers['carol'].config.TEST_FAIL_HTLCS_AS_MALFORMED = True + + async def f(): + async with OldTaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + for peer in peers: + await peer.initialized + lnaddr, pay_req = self.prepare_invoice(graph.workers['carol'], include_routing_hints=True) + await group.spawn(pay(lnaddr, pay_req)) + + with self.assertLogs('electrum', level='INFO') as logs: + with self.assertRaises(PaymentFailure): + await f() + self.assertTrue( + any('carol->bob' in msg and 'fail_malformed_htlc' in msg for msg in logs.output) + ) + self.assertTrue( + any('bob->carol' in msg and 'on_update_fail_malformed_htlc' in msg for msg in logs.output) + ) + + class TestPeerDirectAnchors(TestPeerDirect): TEST_ANCHOR_CHANNELS = True