Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 17 additions & 29 deletions api/library/python/iterm2/iterm2/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,12 @@
import os
import sys
import traceback
import types
import typing
import websockets
import websockets.exceptions

gDisconnectCallbacks: typing.List[typing.Callable[[], None]] = []

# websockets 9.0 moved client into legacy.client and didn't document how to
# migrate to the new API :(. Stick with the old one until I have time to deal
# with this.
websockets_client: types.ModuleType
try:
import websockets.legacy.client
websockets_client = websockets.legacy.client
from websockets.legacy.client import connect as websockets_connect
except:
websockets_client = websockets.client
from websockets import connect as websockets_connect # type: ignore[assignment]

import iterm2.api_pb2
from iterm2._version import __version__

Expand Down Expand Up @@ -114,8 +102,8 @@ async def async_create() -> 'Connection':
connection._async_dispatch_forever(
connection, asyncio.get_running_loop()))
return connection
except websockets.exceptions.InvalidStatusCode as status_code_exception: # type: ignore[attr-defined]
if status_code_exception.status_code == 401:
except websockets.exceptions.InvalidStatus as status_exception:
if status_exception.response.status_code == 401:
if have_fresh_cookie:
raise
# Force request a cookie and try one more time.
Expand All @@ -124,7 +112,7 @@ async def async_create() -> 'Connection':
if not have_fresh_cookie:
# Didn't get a cookie, so no point trying again.
raise
elif status_code_exception.status_code == 406:
elif status_exception.response.status_code == 406:
print("This version of the iterm2 module is too old for " +
"the current version of iTerm2. Please upgrade.")
sys.exit(1)
Expand Down Expand Up @@ -328,9 +316,9 @@ def iterm2_protocol_version(self):
version of iTerm2 that doesn't report its version or it's unknown.
"""
key = "X-iTerm2-Protocol-Version"
if key not in self.websocket.response_headers:
if key not in self.websocket.response.headers:
return (0, 0)
header_value = self.websocket.response_headers[key]
header_value = self.websocket.response.headers[key]
parts = header_value.split(".")
if len(parts) != 2:
return (0, 0)
Expand Down Expand Up @@ -358,23 +346,23 @@ def _unix_domain_socket_path(self):
def _get_unix_connect_coro(self):
"""Experimental: connect with unix domain socket."""
path = self._unix_domain_socket_path()
return websockets_client.unix_connect(
return websockets.unix_connect(
path,
"ws://localhost/",
ping_interval=None,
close_timeout=0,
extra_headers=_headers(),
additional_headers=_headers(),
subprotocols=_subprotocols(),
max_size=None)


def _get_tcp_connect_coro(self):
"""Legacy: connect with tcp socket."""
return websockets_connect(_uri(),
ping_interval=None,
close_timeout=0,
extra_headers=_headers(),
subprotocols=_subprotocols())
"""Connect with tcp socket."""
return websockets.connect(_uri(),
ping_interval=None,
close_timeout=0,
additional_headers=_headers(),
subprotocols=_subprotocols())

def authenticate(self, force):
"""
Expand Down Expand Up @@ -429,8 +417,8 @@ async def async_connect(self, coro, retry=False):
except Exception as _err:
traceback.print_exc()
sys.exit(1)
except websockets.exceptions.InvalidStatusCode as exception: # type: ignore[attr-defined]
if exception.status_code == 401:
except websockets.exceptions.InvalidStatus as exception:
if exception.response.status_code == 401:
# Auth failure.
if retry:
# Sleep and try to authenticate until successful.
Expand All @@ -449,7 +437,7 @@ async def async_connect(self, coro, retry=False):
if not have_fresh_cookie:
# Failed to get a cookie. Give up.
raise
elif exception.status_code == 406:
elif exception.response.status_code == 406:
print("This version of the iterm2 module is too old " +
"for the current version of iTerm2. Please upgrade.")
sys.exit(1)
Expand Down
2 changes: 1 addition & 1 deletion api/library/python/iterm2/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def readme():
packages=['iterm2'],
install_requires=[
'protobuf',
'websockets'
'websockets>=14.0'
],
extras_require={
'full': ['pyobjc']
Expand Down
220 changes: 220 additions & 0 deletions api/library/python/iterm2/tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Tests for iterm2.connection module."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import websockets.exceptions

from iterm2.connection import Connection


class TestIterm2ProtocolVersion:
"""Tests for the iterm2_protocol_version property."""

def _make_connection_with_headers(self, headers):
"""Create a Connection with a mocked websocket having the given response headers."""
conn = Connection()
conn.websocket = SimpleNamespace(
response=SimpleNamespace(headers=headers)
)
return conn

def test_returns_version_tuple(self):
"""Test that a valid version header returns the correct tuple."""
conn = self._make_connection_with_headers(
{"X-iTerm2-Protocol-Version": "1.5"}
)
assert conn.iterm2_protocol_version == (1, 5)

def test_returns_zero_when_header_missing(self):
"""Test that missing header returns (0, 0)."""
conn = self._make_connection_with_headers({})
assert conn.iterm2_protocol_version == (0, 0)

def test_returns_zero_when_header_malformed(self):
"""Test that a malformed version header returns (0, 0)."""
conn = self._make_connection_with_headers(
{"X-iTerm2-Protocol-Version": "invalid"}
)
assert conn.iterm2_protocol_version == (0, 0)

def test_returns_zero_when_header_has_too_many_parts(self):
"""Test that a version header with too many parts returns (0, 0)."""
conn = self._make_connection_with_headers(
{"X-iTerm2-Protocol-Version": "1.2.3"}
)
assert conn.iterm2_protocol_version == (0, 0)


class TestInvalidStatusHandling:
"""Tests for InvalidStatus exception handling in async_connect."""

def _make_invalid_status(self, status_code):
"""Create an InvalidStatus exception with the given status code."""
response = SimpleNamespace(status_code=status_code)
return websockets.exceptions.InvalidStatus(response)

@pytest.mark.asyncio
async def test_async_connect_exits_on_406(self):
"""Test that async_connect calls sys.exit(1) on 406 status."""
conn = Connection()
exc = self._make_invalid_status(406)

mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(side_effect=exc)
mock_context.__aexit__ = AsyncMock(return_value=False)

with patch.object(conn, 'authenticate', return_value=True), \
patch.object(conn, '_remove_auth'), \
patch.object(conn, '_get_connect_coro', return_value=mock_context), \
pytest.raises(SystemExit) as exc_info:
await conn.async_connect(AsyncMock(), retry=False)
assert exc_info.value.code == 1

@pytest.mark.asyncio
async def test_async_connect_raises_on_other_status(self):
"""Test that async_connect re-raises on unexpected status codes."""
conn = Connection()
exc = self._make_invalid_status(500)

mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(side_effect=exc)
mock_context.__aexit__ = AsyncMock(return_value=False)

with patch.object(conn, 'authenticate', return_value=True), \
patch.object(conn, '_remove_auth'), \
patch.object(conn, '_get_connect_coro', return_value=mock_context), \
pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
await conn.async_connect(AsyncMock(), retry=False)
assert exc_info.value.response.status_code == 500

@pytest.mark.asyncio
async def test_async_connect_retries_on_401_without_fresh_cookie(self):
"""Test that async_connect re-authenticates on 401 when cookie is not fresh."""
conn = Connection()
exc_401 = self._make_invalid_status(401)

mock_websocket = AsyncMock()

mock_context_1 = AsyncMock()
mock_context_1.__aenter__ = AsyncMock(side_effect=exc_401)
mock_context_1.__aexit__ = AsyncMock(return_value=False)

mock_context_2 = AsyncMock()
mock_context_2.__aenter__ = AsyncMock(return_value=mock_websocket)
mock_context_2.__aexit__ = AsyncMock(return_value=False)

coro_results = [mock_context_1, mock_context_2]

# authenticate is called 3 times:
# 1. Top of loop, 1st iteration → False (not fresh)
# 2. Inside 401 handler → True (re-authenticated)
# 3. Top of loop, 2nd iteration → True
auth_results = [False, True, True]

mock_coro = AsyncMock(return_value="ok")
mock_authenticate = MagicMock(side_effect=auth_results)

with patch.object(conn, 'authenticate', mock_authenticate), \
patch.object(conn, '_remove_auth'), \
patch.object(conn, '_get_connect_coro', side_effect=coro_results):
await conn.async_connect(mock_coro, retry=False)

mock_authenticate.assert_any_call(True)
mock_coro.assert_awaited_once()

@pytest.mark.asyncio
async def test_async_connect_raises_401_when_fresh_cookie(self):
"""Test that async_connect raises on 401 when cookie was already fresh."""
conn = Connection()
exc_401 = self._make_invalid_status(401)

mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(side_effect=exc_401)
mock_context.__aexit__ = AsyncMock(return_value=False)

with patch.object(conn, 'authenticate', return_value=True), \
patch.object(conn, '_remove_auth'), \
patch.object(conn, '_get_connect_coro', return_value=mock_context), \
pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
await conn.async_connect(AsyncMock(), retry=False)
assert exc_info.value.response.status_code == 401


class TestAsyncCreate:
"""Tests for the async_create static method."""

def _make_invalid_status(self, status_code):
"""Create an InvalidStatus exception with the given status code."""
response = SimpleNamespace(status_code=status_code)
return websockets.exceptions.InvalidStatus(response)

@pytest.mark.asyncio
async def test_async_create_exits_on_406(self):
"""Test that async_create calls sys.exit(1) on 406 status."""
exc = self._make_invalid_status(406)

async def raise_exc():
raise exc

with patch.object(Connection, 'authenticate', return_value=True), \
patch.object(Connection, '_remove_auth'), \
patch.object(Connection, '_get_connect_coro', return_value=raise_exc()), \
pytest.raises(SystemExit) as exc_info:
await Connection.async_create()
assert exc_info.value.code == 1

@pytest.mark.asyncio
async def test_async_create_raises_on_other_status(self):
"""Test that async_create re-raises on unexpected status codes."""
exc = self._make_invalid_status(500)

async def raise_exc():
raise exc

with patch.object(Connection, 'authenticate', return_value=True), \
patch.object(Connection, '_remove_auth'), \
patch.object(Connection, '_get_connect_coro', return_value=raise_exc()), \
pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
await Connection.async_create()
assert exc_info.value.response.status_code == 500


class TestConnectCoroutineArgs:
"""Tests that connect methods pass correct arguments."""

@patch('iterm2.connection._headers', return_value={"x-test": "value"})
@patch('iterm2.connection._subprotocols', return_value=['api.iterm2.com'])
@patch('iterm2.connection._uri', return_value='ws://localhost:1912')
@patch('websockets.connect')
def test_tcp_connect_uses_additional_headers(
self, mock_connect, mock_uri, mock_subprotocols, mock_headers):
"""Test that _get_tcp_connect_coro uses additional_headers parameter."""
conn = Connection()
conn._get_tcp_connect_coro()
mock_connect.assert_called_once_with(
'ws://localhost:1912',
ping_interval=None,
close_timeout=0,
additional_headers={"x-test": "value"},
subprotocols=['api.iterm2.com'],
)

@patch('iterm2.connection._headers', return_value={"x-test": "value"})
@patch('iterm2.connection._subprotocols', return_value=['api.iterm2.com'])
@patch('websockets.unix_connect')
def test_unix_connect_uses_additional_headers(
self, mock_unix_connect, mock_subprotocols, mock_headers):
"""Test that _get_unix_connect_coro uses additional_headers parameter."""
conn = Connection()
with patch.object(conn, '_unix_domain_socket_path', return_value='/tmp/test.sock'):
conn._get_unix_connect_coro()
mock_unix_connect.assert_called_once_with(
'/tmp/test.sock',
'ws://localhost/',
ping_interval=None,
close_timeout=0,
additional_headers={"x-test": "value"},
subprotocols=['api.iterm2.com'],
max_size=None,
)