diff --git a/bumble/device.py b/bumble/device.py index dbaeb52e..e6b216c9 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -4836,11 +4836,21 @@ async def get_long_term_key( if keys.ltk: return keys.ltk.value - if connection.role == hci.Role.CENTRAL and keys.ltk_central: + # Check both ltk_central and ltk_peripheral by matching EDIV+Rand + if ( + keys.ltk_central + and keys.ltk_central.ediv == ediv + and keys.ltk_central.rand == rand + ): return keys.ltk_central.value - if connection.role == hci.Role.PERIPHERAL and keys.ltk_peripheral: + if ( + keys.ltk_peripheral + and keys.ltk_peripheral.ediv == ediv + and keys.ltk_peripheral.rand == rand + ): return keys.ltk_peripheral.value + return None async def get_link_key(self, address: hci.Address) -> bytes | None: diff --git a/tests/device_test.py b/tests/device_test.py index af18c78c..1d8fddb7 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -51,6 +51,7 @@ Role, ) from bumble.host import DataPacketQueue, Host +from bumble.keys import PairingKeys from .test_utils import TwoDevices, async_barrier @@ -823,6 +824,92 @@ async def test_remote_name_request(): actual_name = await devices[0].request_remote_name(devices[1].public_address) assert actual_name == expected_name +# ----------------------------------------------------------------------------- +@pytest.fixture +def device_with_connection() -> Device: + """Device with a registered LE connection and an SMP manager that has no LTK.""" + device = Device(host=Host(None, None)) + peer_address = Address('AA:BB:CC:DD:EE:FF', address_type=Address.RANDOM_DEVICE_ADDRESS) + connection = Connection( + device=device, + handle=0x0001, + transport=PhysicalTransport.LE, + self_address=Address('11:22:33:44:55:66'), + self_resolvable_address=None, + peer_address=peer_address, + peer_resolvable_address=None, + role=Role.CENTRAL, + parameters=Connection.Parameters( + connection_interval=10.0, + peripheral_latency=0, + supervision_timeout=720.0, + ), + ) + device.connections[0x0001] = connection + device.smp_manager = mock.MagicMock() + device.smp_manager.get_long_term_key.return_value = None + return device + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_get_long_term_key_no_keystore(device_with_connection): + """Returns None when no keystore is configured and SMP has no key.""" + device_with_connection.keystore = None + + result = await device_with_connection.get_long_term_key(connection_handle=0x0001, rand=b'\x00' * 8, ediv=0) + + assert result is None + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'keys, rand, ediv, expected_ltk', + [ + pytest.param( + PairingKeys(ltk=PairingKeys.Key(value=b'\x02' * 16)), + b'\xaa' * 8, + 0x1234, + b'\x02' * 16, + id='legacy_ltk', + ), + pytest.param( + PairingKeys(ltk_central=PairingKeys.Key(value=b'\x03' * 16, ediv=0x5678, rand=b'\xbb' * 8)), + b'\xbb' * 8, + 0x5678, + b'\x03' * 16, + id='ltk_central_matching_ediv_rand', + ), + pytest.param( + PairingKeys(ltk_peripheral=PairingKeys.Key(value=b'\x04' * 16, ediv=0x9ABC, rand=b'\xcc' * 8)), + b'\xcc' * 8, + 0x9ABC, + b'\x04' * 16, + id='ltk_peripheral_matching_ediv_rand', + ), + pytest.param( + PairingKeys(ltk_central=PairingKeys.Key(value=b'\x05' * 16, ediv=0x0001, rand=b'\xdd' * 8)), + b'\xff' * 8, + 0x0002, + None, + id='ltk_central_wrong_ediv_rand', + ), + pytest.param( + None, + b'\x00' * 8, + 0, + None, + id='keystore_no_entry', + ), + ], +) +@pytest.mark.asyncio +async def test_get_long_term_key_from_keystore(device_with_connection, keys, rand, ediv, expected_ltk): + keystore = mock.AsyncMock() + keystore.get.return_value = keys + device_with_connection.keystore = keystore + + result = await device_with_connection.get_long_term_key(connection_handle=0x0001, rand=rand, ediv=ediv) + assert result == expected_ltk # ----------------------------------------------------------------------------- async def run_test_device():