From 026be6fe76b10b9945e01bbd412aad3850fd066f Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Wed, 18 Feb 2026 09:18:34 -0700 Subject: [PATCH 1/6] Add basic type checking with pyright to CI --- .github/workflows/tests.yml | 24 ++++++++++++++++++++++++ pyproject.toml | 6 ++++++ 2 files changed, 30 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 93d0b6c0..87f4e315 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,3 +60,27 @@ jobs: with: file: ./coverage.xml fail_ci_if_error: false + + type-check: + runs-on: "ubuntu-22.04" + continue-on-error: true + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: 'pip' + cache-dependency-path: setup.cfg + + - name: install + run: | + pip install --upgrade pip wheel + pip install -r "requirements/latest.txt" + pip install pyright + + - name: pyright + run: pyright adlfs diff --git a/pyproject.toml b/pyproject.toml index e212ff71..f46dac5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,5 +49,11 @@ include = ["adlfs*"] exclude = ["tests"] namespaces = false +[tool.pyright] +include = ["adlfs"] +exclude = ["adlfs/tests", "build"] +pythonVersion = "3.10" +typeCheckingMode = "basic" + [tool.isort] profile = "black" From f7296e5ff1fef3473f3849fb7c72254e62734cd0 Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Wed, 18 Feb 2026 12:43:57 -0700 Subject: [PATCH 2/6] Fix typing --- adlfs/gen1.py | 2 +- adlfs/spec.py | 85 +++++++++++++++++++++++++++++++-------------------- 2 files changed, 53 insertions(+), 34 deletions(-) diff --git a/adlfs/gen1.py b/adlfs/gen1.py index baf8cde0..de3142ee 100644 --- a/adlfs/gen1.py +++ b/adlfs/gen1.py @@ -186,7 +186,7 @@ def __getstate__(self): def __setstate__(self, state): logger.debug("De-serialize with state: %s", state) - self.__dict__.update(state) + self.__dict__.update(state) # type: ignore[reportAttributeAccessIssue] self.do_connect() diff --git a/adlfs/spec.py b/adlfs/spec.py index 082cf260..d9bed1b7 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -15,7 +15,7 @@ from collections import defaultdict from datetime import datetime, timedelta, timezone from glob import has_magic -from typing import Optional, Tuple +from typing import Any, Literal, Optional, Tuple from uuid import uuid4 from azure.core.exceptions import ( @@ -126,7 +126,7 @@ def _create_aio_blob_service_client( location_mode: Optional[str] = None, credential: Optional[str] = None, ) -> AIOBlobServiceClient: - service_client_kwargs = { + service_client_kwargs: dict[str, Any] = { "account_url": account_url, "user_agent": _USER_AGENT, } @@ -264,30 +264,30 @@ class AzureBlobFileSystem(AsyncFileSystem): def __init__( self, - account_name: str = None, - account_key: str = None, - connection_string: str = None, - credential: str = None, - sas_token: str = None, + account_name: str | None = None, + account_key: str | None = None, + connection_string: str | None = None, + credential: str | None = None, + sas_token: str | None = None, request_session=None, socket_timeout=_SOCKET_TIMEOUT_DEFAULT, blocksize: int = _DEFAULT_BLOCK_SIZE, - client_id: str = None, - client_secret: str = None, - tenant_id: str = None, - anon: bool = None, + client_id: str | None = None, + client_secret: str | None = None, + tenant_id: str | None = None, + anon: bool | None = None, location_mode: str = "primary", loop=None, asynchronous: bool = False, default_fill_cache: bool = True, default_cache_type: str = "bytes", version_aware: bool = False, - assume_container_exists: Optional[bool] = None, - max_concurrency: Optional[int] = None, - timeout: Optional[int] = None, - connection_timeout: Optional[int] = None, - read_timeout: Optional[int] = None, - account_host: str = None, + assume_container_exists: bool | None = None, + max_concurrency: int | None = None, + timeout: int | None = None, + connection_timeout: int | None = None, + read_timeout: int | None = None, + account_host: str | None = None, **kwargs, ): self.kwargs = kwargs.copy() @@ -386,13 +386,15 @@ def __init__( weakref.finalize(self, sync, self.loop, close_credential, self) if max_concurrency is None: - batch_size = _get_batch_size() + batch_size: int = _get_batch_size() # type: ignore[assignment] if batch_size > 0: max_concurrency = batch_size + else: + max_concurrency = 1 self.max_concurrency = max_concurrency @classmethod - def _strip_protocol(cls, path: str): + def _strip_protocol(cls, path: str) -> str: """ Remove the protocol from the input path @@ -407,7 +409,7 @@ def _strip_protocol(cls, path: str): Returns a path without the protocol """ if isinstance(path, list): - return [cls._strip_protocol(p) for p in path] + return [cls._strip_protocol(p) for p in path] # type: ignore[return-value] STORE_SUFFIX = ".dfs.core.windows.net" logger.debug(f"_strip_protocol for {path}") @@ -473,6 +475,16 @@ def _get_credential_from_service_principal(self): ------- Tuple of (Async Credential, Sync Credential). """ + if ( + self.tenant_id is None + or self.client_id is None + or self.client_secret is None + ): + raise ValueError( + "tenant_id, client_id, and client_secret must all be provided " + "when authenticating with a service principal." + ) + from azure.identity import ClientSecretCredential from azure.identity.aio import ( ClientSecretCredential as AIOClientSecretCredential, @@ -1635,6 +1647,12 @@ async def _url( account_name = self.account_name account_key = self.account_key + if account_name is None: + raise ValueError( + "account_name is required to generate a SAS URL. " + "Provide account_name or include AccountName in the connection string." + ) + sas_token = generate_blob_sas( account_name=account_name, container_name=container_name, @@ -1653,8 +1671,8 @@ async def _url( url = f"{bc.url}?{sas_token}" return url - def expand_path(self, path, recursive=False, maxdepth=None, skip_noexist=True): - return sync( + def expand_path(self, path, recursive=False, maxdepth=None, skip_noexist=True) -> list[str]: + return sync( # type: ignore[return-value] self.loop, self._expand_path, path, recursive, maxdepth, skip_noexist ) @@ -1887,7 +1905,7 @@ def _open( self, path: str, mode: str = "rb", - block_size: int = None, + block_size: int | None = None, autocommit: bool = True, cache_options: dict = {}, cache_type="readahead", @@ -1954,7 +1972,7 @@ def __init__( fs: AzureBlobFileSystem, path: str, mode: str = "rb", - block_size="default", + block_size: int | Literal["default"] | None = "default", autocommit: bool = True, cache_type: str = "bytes", cache_options: dict = {}, @@ -2017,9 +2035,10 @@ def __init__( self.loop = self._get_loop() self.container_client = self._get_container_client() - self.blocksize = ( - self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size - ) + if block_size == "default" or block_size is None: + self.blocksize: int = self.DEFAULT_BLOCK_SIZE + else: + self.blocksize = block_size self.loc = 0 self.autocommit = autocommit self.end = None @@ -2127,9 +2146,9 @@ def connect_client(self): """ try: if hasattr(self.fs, "account_host"): - self.fs.account_url: str = f"https://{self.fs.account_host}" + self.fs.account_url = f"https://{self.fs.account_host}" else: - self.fs.account_url: str = ( + self.fs.account_url = ( f"https://{self.fs.account_name}.blob.core.windows.net" ) @@ -2164,7 +2183,7 @@ def connect_client(self): f"Unable to fetch container_client with provided params for {e}!!" ) from e - async def _async_fetch_range(self, start: int, end: int = None, **kwargs): + async def _async_fetch_range(self, start: int, end: int | None = None, **kwargs): """ Download a chunk of data specified by start and end @@ -2221,7 +2240,7 @@ async def _stage_block(self, data, start, end, block_id, semaphore): async with self.container_client.get_blob_client(blob=self.blob) as bc: await bc.stage_block( block_id=block_id, - data=data[start:end], + data=data[start:end], # type: ignore[arg-type] length=end - start, ) return block_id @@ -2301,7 +2320,7 @@ async def _async_upload_chunk(self, final: bool = False, **kwargs): await bc.upload_blob( data=data, length=length, - blob_type=BlobType.AppendBlob, + blob_type=BlobType.APPENDBLOB, metadata=self.metadata, ) else: @@ -2329,6 +2348,6 @@ def __getstate__(self): return state def __setstate__(self, state): - self.__dict__.update(state) + self.__dict__.update(state) # type: ignore[reportAttributeAccessIssue] self.loop = self._get_loop() self.container_client = self._get_container_client() From 134cc012eb798baf9220562c0858ee0b1cc09be3 Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Wed, 18 Feb 2026 12:50:33 -0700 Subject: [PATCH 3/6] Update typing syntax to use PEP 585/604 --- adlfs/spec.py | 32 ++++++++++++++++---------------- adlfs/utils.py | 8 +++----- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index d9bed1b7..7bbba8ef 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -9,13 +9,13 @@ import logging import os import re -import typing import warnings import weakref from collections import defaultdict +from collections.abc import Iterable from datetime import datetime, timedelta, timezone from glob import has_magic -from typing import Any, Literal, Optional, Tuple +from typing import Any, Literal from uuid import uuid4 from azure.core.exceptions import ( @@ -105,7 +105,7 @@ def get_running_loop(): return loop -def _coalesce_version_id(*args) -> Optional[str]: +def _coalesce_version_id(*args) -> str | None: """Helper to coalesce a list of version_ids down to one""" version_ids = set(args) if None in version_ids: @@ -123,8 +123,8 @@ def _coalesce_version_id(*args) -> Optional[str]: def _create_aio_blob_service_client( account_url: str, - location_mode: Optional[str] = None, - credential: Optional[str] = None, + location_mode: str | None = None, + credential: str | None = None, ) -> AIOBlobServiceClient: service_client_kwargs: dict[str, Any] = { "account_url": account_url, @@ -585,7 +585,7 @@ def do_connect(self): def split_path( self, path, delimiter="/", return_container: bool = False, **kwargs - ) -> Tuple[str, str, Optional[str]]: + ) -> tuple[str, str, str | None]: """ Normalize ABFS path string into bucket and key. @@ -720,7 +720,7 @@ async def _ls_blobs( path: str, delimiter: str = "/", return_glob: bool = False, - version_id: Optional[str] = None, + version_id: str | None = None, versions: bool = False, **kwargs, ): @@ -811,7 +811,7 @@ async def _ls( invalidate_cache: bool = False, delimiter: str = "/", return_glob: bool = False, - version_id: Optional[str] = None, + version_id: str | None = None, versions: bool = False, **kwargs, ): @@ -879,7 +879,7 @@ async def _details( delimiter="/", return_glob: bool = False, target_path="", - version_id: Optional[str] = None, + version_id: str | None = None, versions: bool = False, **kwargs, ): @@ -1207,9 +1207,9 @@ def makedir(self, path, exist_ok=False): async def _rm( self, - path: typing.Union[str, typing.List[str]], + path: str | list[str], recursive: bool = False, - maxdepth: typing.Optional[int] = None, + maxdepth: int | None = None, delimiter: str = "/", expand_path: bool = True, **kwargs, @@ -1269,7 +1269,7 @@ async def _rm( rm = sync_wrapper(_rm) async def _rm_files( - self, container_name: str, file_paths: typing.Iterable[str], **kwargs + self, container_name: str, file_paths: Iterable[str], **kwargs ): """ Delete the given file(s) @@ -1334,8 +1334,8 @@ async def _rm_file(self, path: str, **kwargs): self.invalidate_cache(self._parent(path)) async def _separate_directory_markers_for_non_empty_directories( - self, file_paths: typing.Iterable[str] - ) -> typing.Tuple[typing.List[str], typing.List[str]]: + self, file_paths: Iterable[str] + ) -> tuple[list[str], list[str]]: """ Distinguish directory markers of non-empty directories from files and directory markers for empty directories. A directory marker is an empty blob who's name is the path of the directory. @@ -1910,7 +1910,7 @@ def _open( cache_options: dict = {}, cache_type="readahead", metadata=None, - version_id: Optional[str] = None, + version_id: str | None = None, **kwargs, ): """Open a file on the datalake, or a block blob @@ -1977,7 +1977,7 @@ def __init__( cache_type: str = "bytes", cache_options: dict = {}, metadata=None, - version_id: Optional[str] = None, + version_id: str | None = None, **kwargs, ): """ diff --git a/adlfs/utils.py b/adlfs/utils.py index d4a00a98..3922212e 100644 --- a/adlfs/utils.py +++ b/adlfs/utils.py @@ -1,5 +1,3 @@ -from typing import Optional - try: from ._version import version as __version__ # type: ignore[import] from ._version import version_tuple # type: ignore[import] @@ -8,7 +6,7 @@ version_tuple = (0, 0, __version__) # type: ignore[assignment] -def match_blob_version(blob, version_id: Optional[str]): +def match_blob_version(blob, version_id: str | None): blob_version_id = blob.get("version_id") return ( version_id is None @@ -20,7 +18,7 @@ async def filter_blobs( blobs, target_path, delimiter="/", - version_id: Optional[str] = None, + version_id: str | None = None, versions: bool = False, ): """ @@ -50,7 +48,7 @@ async def filter_blobs( return finalblobs -async def get_blob_metadata(container_client, path, version_id: Optional[str] = None): +async def get_blob_metadata(container_client, path, version_id: str | None = None): async with container_client.get_blob_client(path) as bc: properties = await bc.get_blob_properties(version_id=version_id) if "metadata" in properties.keys(): From 816c35562b93fac3ab95bbdf22615a790cbba2a4 Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Wed, 18 Feb 2026 13:27:38 -0700 Subject: [PATCH 4/6] Lint changes --- adlfs/spec.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index 7bbba8ef..8bf09ccb 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -1268,9 +1268,7 @@ async def _rm( rm = sync_wrapper(_rm) - async def _rm_files( - self, container_name: str, file_paths: Iterable[str], **kwargs - ): + async def _rm_files(self, container_name: str, file_paths: Iterable[str], **kwargs): """ Delete the given file(s) @@ -1671,7 +1669,9 @@ async def _url( url = f"{bc.url}?{sas_token}" return url - def expand_path(self, path, recursive=False, maxdepth=None, skip_noexist=True) -> list[str]: + def expand_path( + self, path, recursive=False, maxdepth=None, skip_noexist=True + ) -> list[str]: return sync( # type: ignore[return-value] self.loop, self._expand_path, path, recursive, maxdepth, skip_noexist ) From 1392026fd8c86de87c9a8d54236e2233576c46b1 Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Wed, 18 Feb 2026 13:42:00 -0700 Subject: [PATCH 5/6] Add type stub for dynamic version --- adlfs/_version.pyi | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 adlfs/_version.pyi diff --git a/adlfs/_version.pyi b/adlfs/_version.pyi new file mode 100644 index 00000000..996dce94 --- /dev/null +++ b/adlfs/_version.pyi @@ -0,0 +1,6 @@ +version: str +__version__: str +version_tuple: tuple[int | str, ...] +__version_tuple__: tuple[int | str, ...] +commit_id: str | None +__commit_id__: str | None From deb657cd1f2b8e2c4ba195729fb8d583d73680cb Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Wed, 18 Feb 2026 13:50:21 -0700 Subject: [PATCH 6/6] Ignore gen1.py due to dependency typing missing --- adlfs/gen1.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/adlfs/gen1.py b/adlfs/gen1.py index de3142ee..baf8cde0 100644 --- a/adlfs/gen1.py +++ b/adlfs/gen1.py @@ -186,7 +186,7 @@ def __getstate__(self): def __setstate__(self, state): logger.debug("De-serialize with state: %s", state) - self.__dict__.update(state) # type: ignore[reportAttributeAccessIssue] + self.__dict__.update(state) self.do_connect() diff --git a/pyproject.toml b/pyproject.toml index f46dac5a..f8a75d5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ namespaces = false [tool.pyright] include = ["adlfs"] -exclude = ["adlfs/tests", "build"] +exclude = ["adlfs/tests", "adlfs/gen1.py", "build"] pythonVersion = "3.10" typeCheckingMode = "basic"