diff --git a/nexios_contrib/accepts/__init__.py b/nexios_contrib/accepts/__init__.py index 49ba034..c0cdf7f 100644 --- a/nexios_contrib/accepts/__init__.py +++ b/nexios_contrib/accepts/__init__.py @@ -4,39 +4,40 @@ This module provides content negotiation and Accept header processing for Nexios applications. """ + from __future__ import annotations + +from .dependency import AcceptsDepend from .helpers import ( AcceptItem, - parse_accept_header, - parse_accept_language, - parse_accept_charset, - parse_accept_encoding, - negotiate_content_type, - negotiate_language, - negotiate_charset, - negotiate_encoding, - matches_media_type, - get_best_match, - get_accepts_info, + AcceptsInfo, create_vary_header, - get_accepted_content_types, - get_accepted_languages, get_accepted_charsets, + get_accepted_content_types, get_accepted_encodings, + get_accepted_languages, + get_accepts_from_request, + get_accepts_info, get_best_accepted_content_type, get_best_accepted_language, - get_accepts_from_request, - AcceptsInfo, - + get_best_match, + matches_media_type, + negotiate_charset, + negotiate_content_type, + negotiate_encoding, + negotiate_language, + parse_accept_charset, + parse_accept_encoding, + parse_accept_header, + parse_accept_language, ) from .middleware import ( Accepts, - AcceptsMiddleware, ContentNegotiationMiddleware, StrictContentNegotiationMiddleware, ) -from .dependency import AcceptsDepend + __all__ = [ "AcceptItem", "parse_accept_header", diff --git a/nexios_contrib/accepts/dependency.py b/nexios_contrib/accepts/dependency.py index 9828488..33ce57b 100644 --- a/nexios_contrib/accepts/dependency.py +++ b/nexios_contrib/accepts/dependency.py @@ -4,19 +4,20 @@ This module provides dependency injection utilities for accessing parsed Accept header information from requests. """ -from __future__ import annotations - -from typing import Any -from nexios.dependencies import Depend, Context -from nexios.http import Request -from .helpers import AcceptsInfo +from __future__ import annotations +from typing import cast +from nexios.dependencies import Context, Depend +from nexios.http import Request +from .helpers import AcceptsInfo -def get_accepts_info_from_request(request: Request, attribute_name: str = "accepts") -> AcceptsInfo: +def get_accepts_info_from_request( + request: Request, attribute_name: str = "accepts" +) -> AcceptsInfo: """ Get AcceptsInfo object from request. @@ -54,7 +55,9 @@ async def get_users( accepted_types = accepts.get_accepted_types() return {"accepted_types": accepted_types} """ + def _wrap(request: Request = Context().request) -> AcceptsInfo: return get_accepts_info_from_request(request, attribute_name) - return Depend(_wrap) + return cast(AcceptsInfo, Depend(_wrap)) + diff --git a/nexios_contrib/accepts/helpers.py b/nexios_contrib/accepts/helpers.py index dac6038..3582466 100644 --- a/nexios_contrib/accepts/helpers.py +++ b/nexios_contrib/accepts/helpers.py @@ -5,13 +5,15 @@ content negotiation for HTTP requests, as well as helper functions for accessing parsed accepts information from requests. """ + from __future__ import annotations +from ty_extensions import Unknown -import re -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional from nexios.http import Request + class AcceptsInfo: """ Container for parsed accepts information from a request. @@ -38,44 +40,58 @@ def __init__(self, request: Request): self._state_accept_encoding = None @property - def accept(self) -> List[Dict[str, Any]]: + def accept(self) -> Any | List[AcceptItem]: """Get parsed Accept header items from state or parse fresh.""" if self._state_accept is None: - if hasattr(self.request.state, 'accepts_parsed'): - item = getattr(self.request.state, 'accepts_parsed', {}) - self._state_accept = item.get('accept', []) if item else [] + if hasattr(self.request.state, "accepts_parsed"): + item = getattr(self.request.state, "accepts_parsed", {}) + self._state_accept: Any | list[Unknown] = item.get("accept", []) if item else [] else: - self._state_accept = parse_accept_header(self.request.headers.get('Accept', '')) + self._state_accept = parse_accept_header( + self.request.headers.get("Accept", "") + ) return self._state_accept @property - def accept_language(self) -> List[Dict[str, Any]]: + def accept_language(self) -> List[AcceptItem]: """Get parsed Accept-Language header items from state or parse fresh.""" if self._state_accept_language is None: - if hasattr(self.request.state, 'accepts_parsed'): - self._state_accept_language = getattr(self.request.state, 'accepts_parsed', {}).get('accept_language', []) + if hasattr(self.request.state, "accepts_parsed"): + self._state_accept_language = getattr( + self.request.state, "accepts_parsed", {} + ).get("accept_language", []) else: - self._state_accept_language = parse_accept_language(self.request.headers.get('Accept-Language', '')) + self._state_accept_language = parse_accept_language( + self.request.headers.get("Accept-Language", "") + ) return self._state_accept_language @property - def accept_charset(self) -> List[Dict[str, Any]]: + def accept_charset(self) -> List[AcceptItem]: """Get parsed Accept-Charset header items from state or parse fresh.""" if self._state_accept_charset is None: - if hasattr(self.request.state, 'accepts_parsed'): - self._state_accept_charset = getattr(self.request.state, 'accepts_parsed', {}).get('accept_charset', []) + if hasattr(self.request.state, "accepts_parsed"): + self._state_accept_charset = getattr( + self.request.state, "accepts_parsed", {} + ).get("accept_charset", []) else: - self._state_accept_charset = parse_accept_charset(self.request.headers.get('Accept-Charset', '')) + self._state_accept_charset = parse_accept_charset( + self.request.headers.get("Accept-Charset", "") + ) return self._state_accept_charset @property - def accept_encoding(self) -> List[Dict[str, Any]]: + def accept_encoding(self) -> List[AcceptItem]: """Get parsed Accept-Encoding header items from state or parse fresh.""" if self._state_accept_encoding is None: - if hasattr(self.request.state, 'accepts_parsed'): - self._state_accept_encoding = getattr(self.request.state, 'accepts_parsed', {}).get('accept_encoding', []) + if hasattr(self.request.state, "accepts_parsed"): + self._state_accept_encoding = getattr( + self.request.state, "accepts_parsed", {} + ).get("accept_encoding", []) else: - self._state_accept_encoding = parse_accept_encoding(self.request.headers.get('Accept-Encoding', '')) + self._state_accept_encoding = parse_accept_encoding( + self.request.headers.get("Accept-Encoding", "") + ) return self._state_accept_encoding def get_accepted_types(self) -> List[str]: @@ -114,12 +130,15 @@ def get_accepted_encodings(self) -> List[str]: """ return [item.value for item in self.accept_encoding if item.quality > 0] + class AcceptItem: """ Represents a single item in an Accept header with type/subtype and parameters. """ - def __init__(self, value: str, quality: float = 1.0, params: Optional[Dict[str, str]] = None): + def __init__( + self, value: str, quality: float = 1.0, params: Optional[Dict[str, str]] = None + ): """ Initialize an AcceptItem. @@ -151,7 +170,7 @@ def parse_accept_header(accept_header: str) -> List[AcceptItem]: items = [] - for part in accept_header.split(','): + for part in accept_header.split(","): part = part.strip() if not part: continue @@ -160,19 +179,19 @@ def parse_accept_header(accept_header: str) -> List[AcceptItem]: quality = 1.0 params = {} - if ';' in part: - media_range, param_str = part.split(';', 1) + if ";" in part: + media_range, param_str = part.split(";", 1) media_range = media_range.strip() # Parse parameters - for param in param_str.split(';'): + for param in param_str.split(";"): param = param.strip() - if '=' in param: - key, value = param.split('=', 1) + if "=" in param: + key, value = param.split("=", 1) key = key.strip().lower() value = value.strip() - if key == 'q': + if key == "q": try: quality = max(0.0, min(1.0, float(value))) except ValueError: @@ -188,7 +207,7 @@ def parse_accept_header(accept_header: str) -> List[AcceptItem]: items.append(AcceptItem(media_range, quality, params)) # Sort by quality (highest first), then by specificity - items.sort(key=lambda x: (-x.quality, x.value.count('/'), -len(x.value))) + items.sort(key=lambda x: (-x.quality, x.value.count("/"), -len(x.value))) return items @@ -232,7 +251,9 @@ def parse_accept_encoding(accept_encoding: str) -> List[AcceptItem]: return parse_accept_header(accept_encoding) -def negotiate_content_type(accept_header: str, available_types: List[str]) -> Optional[str]: +def negotiate_content_type( + accept_header: str, available_types: List[str] +) -> Optional[str]: """ Perform content negotiation for media types. @@ -263,19 +284,21 @@ def negotiate_content_type(accept_header: str, available_types: List[str]) -> Op continue # Check for */* or type/* - if accept_item.value == '*/*': + if accept_item.value == "*/*": return available_types[0] # Return first available type - if '/*' in accept_item.value: - accept_type = accept_item.value.split('/')[0] + if "/*" in accept_item.value: + accept_type = accept_item.value.split("/")[0] for available_type in available_types: - if available_type.startswith(accept_type + '/'): + if available_type.startswith(accept_type + "/"): return available_type return None -def negotiate_language(accept_language: str, available_languages: List[str]) -> Optional[str]: +def negotiate_language( + accept_language: str, available_languages: List[str] +) -> Optional[str]: """ Perform language negotiation. @@ -300,10 +323,10 @@ def negotiate_language(accept_language: str, available_languages: List[str]) -> return accept_item.value # Language prefix match (e.g., "en" matches "en-US") - if '-' in accept_item.value: - lang_prefix = accept_item.value.split('-')[0] + if "-" in accept_item.value: + lang_prefix = accept_item.value.split("-")[0] for available_lang in available_languages: - if available_lang.startswith(lang_prefix + '-'): + if available_lang.startswith(lang_prefix + "-"): return available_lang if available_lang == lang_prefix: return available_lang @@ -311,7 +334,9 @@ def negotiate_language(accept_language: str, available_languages: List[str]) -> return available_languages[0] if available_languages else None -def negotiate_charset(accept_charset: str, available_charsets: List[str]) -> Optional[str]: +def negotiate_charset( + accept_charset: str, available_charsets: List[str] +) -> Optional[str]: """ Perform charset negotiation. @@ -335,13 +360,15 @@ def negotiate_charset(accept_charset: str, available_charsets: List[str]) -> Opt return accept_item.value # Handle * wildcard - if accept_item.value == '*': + if accept_item.value == "*": return available_charsets[0] return available_charsets[0] if available_charsets else None -def negotiate_encoding(accept_encoding: str, available_encodings: List[str]) -> List[str]: +def negotiate_encoding( + accept_encoding: str, available_encodings: List[str] +) -> List[str]: """ Perform encoding negotiation. @@ -363,8 +390,10 @@ def negotiate_encoding(accept_encoding: str, available_encodings: List[str]) -> continue # Handle identity encoding - if accept_item.value == 'identity' or accept_item.value == '*': - accepted_encodings.extend([enc for enc in available_encodings if enc != 'identity']) + if accept_item.value == "identity" or accept_item.value == "*": + accepted_encodings.extend( + [enc for enc in available_encodings if enc != "identity"] + ) continue # Check for specific encoding match @@ -388,12 +417,12 @@ def matches_media_type(pattern: str, media_type: str) -> bool: if pattern == media_type: return True - if pattern == '*/*': + if pattern == "*/*": return True - if pattern.endswith('/*'): + if pattern.endswith("/*"): pattern_type = pattern[:-2] - return media_type.startswith(pattern_type + '/') + return media_type.startswith(pattern_type + "/") return False @@ -425,7 +454,7 @@ def get_best_match(accept_header: str, options: List[str]) -> Optional[str]: return options[0] if options else None -def get_accepts_info(request: Request) -> Dict[str, any]: +def get_accepts_info(request: Request) -> Dict[str, Any]: """ Extract and parse all Accept-related headers from a request. @@ -436,14 +465,20 @@ def get_accepts_info(request: Request) -> Dict[str, any]: Dict[str, any]: Dictionary containing parsed accept information. """ return { - 'accept': parse_accept_header(request.headers.get('Accept', '')), - 'accept_language': parse_accept_language(request.headers.get('Accept-Language', '')), - 'accept_charset': parse_accept_charset(request.headers.get('Accept-Charset', '')), - 'accept_encoding': parse_accept_encoding(request.headers.get('Accept-Encoding', '')), - 'raw_accept': request.headers.get('Accept', ''), - 'raw_accept_language': request.headers.get('Accept-Language', ''), - 'raw_accept_charset': request.headers.get('Accept-Charset', ''), - 'raw_accept_encoding': request.headers.get('Accept-Encoding', ''), + "accept": parse_accept_header(request.headers.get("Accept", "")), + "accept_language": parse_accept_language( + request.headers.get("Accept-Language", "") + ), + "accept_charset": parse_accept_charset( + request.headers.get("Accept-Charset", "") + ), + "accept_encoding": parse_accept_encoding( + request.headers.get("Accept-Encoding", "") + ), + "raw_accept": request.headers.get("Accept", ""), + "raw_accept_language": request.headers.get("Accept-Language", ""), + "raw_accept_charset": request.headers.get("Accept-Charset", ""), + "raw_accept_encoding": request.headers.get("Accept-Encoding", ""), } @@ -459,20 +494,23 @@ def create_vary_header(existing_vary: Optional[str], new_fields: List[str]) -> s str: Updated Vary header value. """ if not existing_vary: - return ', '.join(new_fields) + return ", ".join(new_fields) - existing_fields = [field.strip() for field in existing_vary.split(',')] + existing_fields = [field.strip() for field in existing_vary.split(",")] for field in new_fields: if field not in existing_fields: existing_fields.append(field) - return ', '.join(existing_fields) + return ", ".join(existing_fields) # Helper functions for accessing accepts information from requests -def get_accepts_from_request(request: Request, attribute_name: str = "accepts") -> AcceptsInfo: + +def get_accepts_from_request( + request: Request, attribute_name: str = "accepts" +) -> AcceptsInfo: """ Get AcceptsInfo object from request. @@ -486,7 +524,9 @@ def get_accepts_from_request(request: Request, attribute_name: str = "accepts") return AcceptsInfo(request) -def get_accepted_content_types(request: Request, attribute_name: str = "accepts_parsed") -> List[str]: +def get_accepted_content_types( + request: Request, attribute_name: str = "accepts_parsed" +) -> List[str]: """ Get accepted content types from request. @@ -498,12 +538,14 @@ def get_accepted_content_types(request: Request, attribute_name: str = "accepts_ List[str]: List of accepted content types ordered by quality. """ accepts_parsed = getattr(request.state, attribute_name, {}) - accept_items = accepts_parsed.get('accept', []) + accept_items = accepts_parsed.get("accept", []) return [item.value for item in accept_items if item.quality > 0] -def get_accepted_languages(request: Request, attribute_name: str = "accepts_parsed") -> List[str]: +def get_accepted_languages( + request: Request, attribute_name: str = "accepts_parsed" +) -> List[str]: """ Get accepted languages from request. @@ -515,12 +557,14 @@ def get_accepted_languages(request: Request, attribute_name: str = "accepts_pars List[str]: List of accepted languages ordered by quality. """ accepts_parsed = getattr(request.state, attribute_name, {}) - accept_items = accepts_parsed.get('accept_language', []) + accept_items = accepts_parsed.get("accept_language", []) return [item.value for item in accept_items if item.quality > 0] -def get_accepted_charsets(request: Request, attribute_name: str = "accepts_parsed") -> List[str]: +def get_accepted_charsets( + request: Request, attribute_name: str = "accepts_parsed" +) -> List[str]: """ Get accepted charsets from request. @@ -532,12 +576,14 @@ def get_accepted_charsets(request: Request, attribute_name: str = "accepts_parse List[str]: List of accepted charsets ordered by quality. """ accepts_parsed = getattr(request.state, attribute_name, {}) - accept_items = accepts_parsed.get('accept_charset', []) + accept_items = accepts_parsed.get("accept_charset", []) return [item.value for item in accept_items if item.quality > 0] -def get_accepted_encodings(request: Request, attribute_name: str = "accepts_parsed") -> List[str]: +def get_accepted_encodings( + request: Request, attribute_name: str = "accepts_parsed" +) -> List[str]: """ Get accepted encodings from request. @@ -549,12 +595,14 @@ def get_accepted_encodings(request: Request, attribute_name: str = "accepts_pars List[str]: List of accepted encodings ordered by quality. """ accepts_parsed = getattr(request.state, attribute_name, {}) - accept_items = accepts_parsed.get('accept_encoding', []) + accept_items = accepts_parsed.get("accept_encoding", []) return [item.value for item in accept_items if item.quality > 0] -def get_best_accepted_content_type(request: Request, available_types: List[str], attribute_name: str = "accepts_parsed") -> Optional[str]: +def get_best_accepted_content_type( + request: Request, available_types: List[str], attribute_name: str = "accepts_parsed" +) -> Optional[str]: """ Get the best matching content type from available types. @@ -577,7 +625,11 @@ def get_best_accepted_content_type(request: Request, available_types: List[str], return available_types[0] if available_types else None -def get_best_accepted_language(request: Request, available_languages: List[str], attribute_name: str = "accepts_parsed") -> Optional[str]: +def get_best_accepted_language( + request: Request, + available_languages: List[str], + attribute_name: str = "accepts_parsed", +) -> Optional[str]: """ Get the best matching language from available languages. @@ -597,21 +649,25 @@ def get_best_accepted_language(request: Request, available_languages: List[str], return accepted_lang # Language prefix match (e.g., "en" matches "en-US") - if '-' in accepted_lang: - lang_prefix = accepted_lang.split('-')[0] + if "-" in accepted_lang: + lang_prefix = accepted_lang.split("-")[0] for available_lang in available_languages: - if available_lang.startswith(lang_prefix + '-'): + if available_lang.startswith(lang_prefix + "-"): return available_lang if available_lang == lang_prefix: return available_lang # Fallback to first available language if no specific match + + class AcceptItem: """ Represents a single item in an Accept header with type/subtype and parameters. """ - def __init__(self, value: str, quality: float = 1.0, params: Optional[Dict[str, str]] = None): + def __init__( + self, value: str, quality: float = 1.0, params: Optional[Dict[str, str]] = None + ): """ Initialize an AcceptItem. @@ -643,7 +699,7 @@ def parse_accept_header(accept_header: str) -> List[AcceptItem]: items = [] - for part in accept_header.split(','): + for part in accept_header.split(","): part = part.strip() if not part: continue @@ -652,19 +708,19 @@ def parse_accept_header(accept_header: str) -> List[AcceptItem]: quality = 1.0 params = {} - if ';' in part: - media_range, param_str = part.split(';', 1) + if ";" in part: + media_range, param_str = part.split(";", 1) media_range = media_range.strip() # Parse parameters - for param in param_str.split(';'): + for param in param_str.split(";"): param = param.strip() - if '=' in param: - key, value = param.split('=', 1) + if "=" in param: + key, value = param.split("=", 1) key = key.strip().lower() value = value.strip() - if key == 'q': + if key == "q": try: quality = max(0.0, min(1.0, float(value))) except ValueError: @@ -680,7 +736,7 @@ def parse_accept_header(accept_header: str) -> List[AcceptItem]: items.append(AcceptItem(media_range, quality, params)) # Sort by quality (highest first), then by specificity - items.sort(key=lambda x: (-x.quality, x.value.count('/'), -len(x.value))) + items.sort(key=lambda x: (-x.quality, x.value.count("/"), -len(x.value))) return items @@ -724,7 +780,9 @@ def parse_accept_encoding(accept_encoding: str) -> List[AcceptItem]: return parse_accept_header(accept_encoding) -def negotiate_content_type(accept_header: str, available_types: List[str]) -> Optional[str]: +def negotiate_content_type( + accept_header: str, available_types: List[str] +) -> Optional[str]: """ Perform content negotiation for media types. @@ -755,19 +813,21 @@ def negotiate_content_type(accept_header: str, available_types: List[str]) -> Op continue # Check for */* or type/* - if accept_item.value == '*/*': + if accept_item.value == "*/*": return available_types[0] # Return first available type - if '/*' in accept_item.value: - accept_type = accept_item.value.split('/')[0] + if "/*" in accept_item.value: + accept_type = accept_item.value.split("/")[0] for available_type in available_types: - if available_type.startswith(accept_type + '/'): + if available_type.startswith(accept_type + "/"): return available_type return None -def negotiate_language(accept_language: str, available_languages: List[str]) -> Optional[str]: +def negotiate_language( + accept_language: str, available_languages: List[str] +) -> Optional[str]: """ Perform language negotiation. @@ -792,10 +852,10 @@ def negotiate_language(accept_language: str, available_languages: List[str]) -> return accept_item.value # Language prefix match (e.g., "en" matches "en-US") - if '-' in accept_item.value: - lang_prefix = accept_item.value.split('-')[0] + if "-" in accept_item.value: + lang_prefix = accept_item.value.split("-")[0] for available_lang in available_languages: - if available_lang.startswith(lang_prefix + '-'): + if available_lang.startswith(lang_prefix + "-"): return available_lang if available_lang == lang_prefix: return available_lang @@ -803,7 +863,9 @@ def negotiate_language(accept_language: str, available_languages: List[str]) -> return available_languages[0] if available_languages else None -def negotiate_charset(accept_charset: str, available_charsets: List[str]) -> Optional[str]: +def negotiate_charset( + accept_charset: str, available_charsets: List[str] +) -> Optional[str]: """ Perform charset negotiation. @@ -827,13 +889,15 @@ def negotiate_charset(accept_charset: str, available_charsets: List[str]) -> Opt return accept_item.value # Handle * wildcard - if accept_item.value == '*': + if accept_item.value == "*": return available_charsets[0] return available_charsets[0] if available_charsets else None -def negotiate_encoding(accept_encoding: str, available_encodings: List[str]) -> List[str]: +def negotiate_encoding( + accept_encoding: str, available_encodings: List[str] +) -> List[str]: """ Perform encoding negotiation. @@ -855,8 +919,10 @@ def negotiate_encoding(accept_encoding: str, available_encodings: List[str]) -> continue # Handle identity encoding - if accept_item.value == 'identity' or accept_item.value == '*': - accepted_encodings.extend([enc for enc in available_encodings if enc != 'identity']) + if accept_item.value == "identity" or accept_item.value == "*": + accepted_encodings.extend( + [enc for enc in available_encodings if enc != "identity"] + ) continue # Check for specific encoding match @@ -880,12 +946,12 @@ def matches_media_type(pattern: str, media_type: str) -> bool: if pattern == media_type: return True - if pattern == '*/*': + if pattern == "*/*": return True - if pattern.endswith('/*'): + if pattern.endswith("/*"): pattern_type = pattern[:-2] - return media_type.startswith(pattern_type + '/') + return media_type.startswith(pattern_type + "/") return False @@ -928,14 +994,20 @@ def get_accepts_info(request: Request) -> Dict[str, Any]: Dict[str, any]: Dictionary containing parsed accept information. """ return { - 'accept': parse_accept_header(request.headers.get('Accept', '')), - 'accept_language': parse_accept_language(request.headers.get('Accept-Language', '')), - 'accept_charset': parse_accept_charset(request.headers.get('Accept-Charset', '')), - 'accept_encoding': parse_accept_encoding(request.headers.get('Accept-Encoding', '')), - 'raw_accept': request.headers.get('Accept', ''), - 'raw_accept_language': request.headers.get('Accept-Language', ''), - 'raw_accept_charset': request.headers.get('Accept-Charset', ''), - 'raw_accept_encoding': request.headers.get('Accept-Encoding', ''), + "accept": parse_accept_header(request.headers.get("Accept", "")), + "accept_language": parse_accept_language( + request.headers.get("Accept-Language", "") + ), + "accept_charset": parse_accept_charset( + request.headers.get("Accept-Charset", "") + ), + "accept_encoding": parse_accept_encoding( + request.headers.get("Accept-Encoding", "") + ), + "raw_accept": request.headers.get("Accept", ""), + "raw_accept_language": request.headers.get("Accept-Language", ""), + "raw_accept_charset": request.headers.get("Accept-Charset", ""), + "raw_accept_encoding": request.headers.get("Accept-Encoding", ""), } @@ -951,12 +1023,12 @@ def create_vary_header(existing_vary: Optional[str], new_fields: List[str]) -> s str: Updated Vary header value. """ if not existing_vary: - return ', '.join(new_fields) + return ", ".join(new_fields) - existing_fields = [field.strip() for field in existing_vary.split(',')] + existing_fields = [field.strip() for field in existing_vary.split(",")] for field in new_fields: if field not in existing_fields: existing_fields.append(field) - return ', '.join(existing_fields) + return ", ".join(existing_fields) diff --git a/nexios_contrib/accepts/middleware.py b/nexios_contrib/accepts/middleware.py index 88e2000..7598218 100644 --- a/nexios_contrib/accepts/middleware.py +++ b/nexios_contrib/accepts/middleware.py @@ -4,22 +4,23 @@ This middleware provides automatic content negotiation and Accept header processing for Nexios applications. """ + from __future__ import annotations -from typing import Any,List, Optional +from typing import Any, List, Optional from nexios.http import Request, Response from nexios.middleware.base import BaseMiddleware from .helpers import ( + create_vary_header, + get_accepts_info, negotiate_content_type, negotiate_language, - get_accepts_info, - create_vary_header, - parse_accept_header, - parse_accept_language, parse_accept_charset, parse_accept_encoding, + parse_accept_header, + parse_accept_language, ) @@ -87,32 +88,35 @@ async def process_request( Returns: Any: The result from the next middleware or handler. """ - # Store parsed accepts information in request state if enabled + # Store parsed accepts information in request state if enabled if self.store_accepts_info: accepts_info = get_accepts_info(request) request.state.accepts = accepts_info - + # Store individual components for easier access request.state.accepts_parsed = { - 'accept': parse_accept_header(request.headers.get('Accept', '')), - 'accept_language': parse_accept_language(request.headers.get('Accept-Language', '')), - 'accept_charset': parse_accept_charset(request.headers.get('Accept-Charset', '')), - 'accept_encoding': parse_accept_encoding(request.headers.get('Accept-Encoding', '')), + "accept": parse_accept_header(request.headers.get("Accept", "")), + "accept_language": parse_accept_language( + request.headers.get("Accept-Language", "") + ), + "accept_charset": parse_accept_charset( + request.headers.get("Accept-Charset", "") + ), + "accept_encoding": parse_accept_encoding( + request.headers.get("Accept-Encoding", "") + ), } # Set Vary header if requested if self.set_vary_header: - if request.headers.get('Accept'): - self.vary.append('Accept') - if request.headers.get('Accept-Language'): - self.vary.append('Accept-Language') - if request.headers.get('Accept-Charset'): - self.vary.append('Accept-Charset') - if request.headers.get('Accept-Encoding'): - self.vary.append('Accept-Encoding') - - - + if request.headers.get("Accept"): + self.vary.append("Accept") + if request.headers.get("Accept-Language"): + self.vary.append("Accept-Language") + if request.headers.get("Accept-Charset"): + self.vary.append("Accept-Charset") + if request.headers.get("Accept-Encoding"): + self.vary.append("Accept-Encoding") return await call_next() @@ -132,21 +136,24 @@ async def process_response( Any: The response object. """ if self.vary: - existing_vary = response.headers.get('Vary') - response.set_header('Vary', create_vary_header(existing_vary, self.vary),overide=True) + existing_vary = response.headers.get("Vary") + response.set_header( + "Vary", create_vary_header(existing_vary, self.vary), overide=True + ) # Set default content type if not already set and Content-Type header is missing - if not response.headers.get('Content-Type') and self.default_content_type: + if not response.headers.get("Content-Type") and self.default_content_type: # Try to negotiate content type based on Accept header - accept_header = request.headers.get('Accept') + accept_header = request.headers.get("Accept") if accept_header: negotiated_type = negotiate_content_type( - accept_header, - [self.default_content_type] + accept_header, [self.default_content_type] ) if negotiated_type: - response.set_header('Content-Type', negotiated_type,overide=True) + response.set_header("Content-Type", negotiated_type, overide=True) else: - response.set_header('Content-Type', self.default_content_type,overide=True) + response.set_header( + "Content-Type", self.default_content_type, overide=True + ) return response @@ -196,7 +203,7 @@ def negotiate_content_type( self, request: Request, available_types: List[str], - default_type: Optional[str] = None + default_type: Optional[str] = None, ) -> str: """ Negotiate the best content type for this request. @@ -209,7 +216,7 @@ def negotiate_content_type( Returns: str: The best matching content type. """ - accept_header = request.headers.get('Accept') + accept_header = request.headers.get("Accept") if accept_header: negotiated = negotiate_content_type(accept_header, available_types) if negotiated: @@ -221,7 +228,7 @@ def negotiate_language( self, request: Request, available_languages: List[str], - default_language: Optional[str] = None + default_language: Optional[str] = None, ) -> str: """ Negotiate the best language for this request. @@ -234,7 +241,7 @@ def negotiate_language( Returns: str: The best matching language. """ - accept_language = request.headers.get('Accept-Language') + accept_language = request.headers.get("Accept-Language") if accept_language: negotiated = negotiate_language(accept_language, available_languages) if negotiated: @@ -252,9 +259,9 @@ def get_accepted_types(self, request: Request) -> List[str]: Returns: List[str]: List of accepted content types. """ - accepts_parsed = getattr(request.state, 'accepts_parsed', {}) - accept_items = accepts_parsed.get('accept', []) - + accepts_parsed = getattr(request.state, "accepts_parsed", {}) + accept_items = accepts_parsed.get("accept", []) + return [item.value for item in accept_items if item.quality > 0] def get_accepted_languages(self, request: Request) -> List[str]: @@ -267,8 +274,8 @@ def get_accepted_languages(self, request: Request) -> List[str]: Returns: List[str]: List of accepted languages. """ - accepts_parsed = getattr(request.state, 'accepts_parsed', {}) - accept_items = accepts_parsed.get('accept_language', []) + accepts_parsed = getattr(request.state, "accepts_parsed", {}) + accept_items = accepts_parsed.get("accept_language", []) return [item.value for item in accept_items if item.quality > 0] @@ -286,7 +293,7 @@ def __init__( *, available_types: List[str], available_languages: Optional[List[str]] = None, - **kwargs + **kwargs, ): """ Initialize the StrictContentNegotiationMiddleware. @@ -298,7 +305,7 @@ def __init__( """ super().__init__(**kwargs) self.available_types = available_types - self.available_languages = available_languages or ['en'] + self.available_languages = available_languages or ["en"] async def process_request( self, @@ -311,31 +318,29 @@ async def process_request( """ # Perform strict content negotiation best_type = self.negotiate_content_type( - request, - self.available_types, - self.default_content_type + request, self.available_types, self.default_content_type ) # Check if client accepts the best available type - accept_header = request.headers.get('Accept') + accept_header = request.headers.get("Accept") if accept_header and best_type not in self.available_types: # Client doesn't accept any of our available types response.status(406) - response.set_header('Content-Type', 'application/json') - return response.json({ - "error": "Not Acceptable", - "message": "Client does not accept any available content types", - "available_types": self.available_types - }) + response.set_header("Content-Type", "application/json") + return response.json( + { + "error": "Not Acceptable", + "message": "Client does not accept any available content types", + "available_types": self.available_types, + } + ) # Store negotiation results in request - setattr(request, 'negotiated_content_type', best_type) + setattr(request, "negotiated_content_type", best_type) best_language = self.negotiate_language( - request, - self.available_languages, - self.default_language + request, self.available_languages, self.default_language ) - setattr(request, 'negotiated_language', best_language) + setattr(request, "negotiated_language", best_language) return await call_next() diff --git a/nexios_contrib/etag/__init__.py b/nexios_contrib/etag/__init__.py index 9db8dd0..ba4e959 100644 --- a/nexios_contrib/etag/__init__.py +++ b/nexios_contrib/etag/__init__.py @@ -1,3 +1,5 @@ +from typing import Iterable + from .helper import ( compute_and_set_etag, etag_matches, @@ -9,7 +11,6 @@ set_response_etag, ) from .middleware import ETagMiddleware -from typing import Iterable __all__ = [ "ETagMiddleware", @@ -23,5 +24,8 @@ "is_fresh", ] -def ETag(weak: bool = True, methods: Iterable[str] = ("GET", "HEAD"), override: bool = False) -> ETagMiddleware: - return ETagMiddleware(weak=weak, methods=methods, override=override) \ No newline at end of file + +def ETag( + weak: bool = True, methods: Iterable[str] = ("GET", "HEAD"), override: bool = False +) -> ETagMiddleware: + return ETagMiddleware(weak=weak, methods=methods, override=override) diff --git a/nexios_contrib/etag/helper.py b/nexios_contrib/etag/helper.py index cd501b7..936156c 100644 --- a/nexios_contrib/etag/helper.py +++ b/nexios_contrib/etag/helper.py @@ -4,16 +4,17 @@ This module provides functions for generating, validating, and setting ETag headers. It is designed to work with Nexios's Request and Response abstractions. """ + from __future__ import annotations import re from base64 import b64encode from hashlib import sha1 -from typing import Iterable, Optional +from typing import Iterable from nexios.http import Request, Response -_WEAK_PREFIX = 'W/' +_WEAK_PREFIX = "W/" _ETAG_TOKEN_RE = re.compile(r"^(W\/)?\s*\"[^\"]*\"\s*$") @@ -42,7 +43,7 @@ def normalize_etag(tag: str) -> str: tag = tag.strip() if not _ETAG_TOKEN_RE.match(tag): # Try to coerce when unquoted simple tokens are passed - if not tag.startswith('W/'): + if not tag.startswith("W/"): clean = tag.strip('"') tag = f'"{clean}"' else: @@ -67,14 +68,16 @@ def set_response_etag(response: Response, etag: str, override: bool = True) -> N response.set_header("etag", tag, overide=override) -def compute_and_set_etag(response: Response, body :bytes = b'' ,weak: bool = True, override: bool = False) -> str: +def compute_and_set_etag( + response: Response, body: bytes = b"", weak: bool = True, override: bool = False +) -> str: """ Compute an ETag for the current response body and set it if not present (or if override=True). Returns the ETag value used. """ # Response.body returns bytes - + tag = generate_etag_from_bytes(body, weak=weak) set_response_etag(response, tag, override=override) return tag @@ -112,7 +115,9 @@ def parse_if_match(request: Request) -> list[str]: return tags -def etag_matches(etag: str, candidates: Iterable[str], weak_compare: bool = True) -> bool: +def etag_matches( + etag: str, candidates: Iterable[str], weak_compare: bool = True +) -> bool: """ Check if an ETag matches any in candidates. diff --git a/nexios_contrib/etag/middleware.py b/nexios_contrib/etag/middleware.py index 356cb5d..66f30c1 100644 --- a/nexios_contrib/etag/middleware.py +++ b/nexios_contrib/etag/middleware.py @@ -5,13 +5,13 @@ handles conditional GET/HEAD using the If-None-Match header to return 304 Not Modified when appropriate. """ + from __future__ import annotations -from typing import Any, Iterable, Tuple +from typing import Any, Iterable from nexios.http import Request, Response from nexios.middleware.base import BaseMiddleware -from nexios.http.response import BaseResponse from .helper import compute_and_set_etag, is_fresh @@ -39,8 +39,6 @@ def __init__( self.methods = tuple(m.upper() for m in methods) self.override = override - - async def __call__( self, request: Request, @@ -59,12 +57,11 @@ async def __call__( async for chunk in stream.content_iterator: body += chunk compute_and_set_etag(response, body, weak=self.weak, override=True) - print("cnext header b4 isfresh",response.headers,has_existing) + print("cnext header b4 isfresh", response.headers, has_existing) # Handle If-None-Match freshness for conditional requests if is_fresh(request, response, weak_compare=True): # Per RFC 9110, a 304 response must not include a message body # Ensure body is empty; BaseResponse avoids content-length for 304 - response.status(304) - + response.status(304) - return stream + return stream diff --git a/nexios_contrib/graphql/plugin.py b/nexios_contrib/graphql/plugin.py index a76dda5..28a7495 100644 --- a/nexios_contrib/graphql/plugin.py +++ b/nexios_contrib/graphql/plugin.py @@ -1,12 +1,10 @@ -import json -from typing import Optional, Any +from typing import Any import strawberry -from strawberry.types import ExecutionResult - from nexios.application import NexiosApp from nexios.http import Request, Response from nexios.routing import Route +from strawberry.types import ExecutionResult class GraphQL: @@ -25,7 +23,7 @@ def __init__( self.schema = schema self.path = path self.graphiql = graphiql - + self._setup() def _setup(self): @@ -45,10 +43,14 @@ async def handle_request(self, req: Request, res: Response): try: data = await req.json except Exception: - return res.status(400).json({"errors": [{"message": "Invalid JSON body"}]}) + return res.status(400).json( + {"errors": [{"message": "Invalid JSON body"}]} + ) if not isinstance(data, dict): - return res.status(400).json({"errors": [{"message": "JSON body must be an object"}]}) + return res.status(400).json( + {"errors": [{"message": "JSON body must be an object"}]} + ) query = data.get("query") variables = data.get("variables") @@ -68,7 +70,7 @@ async def handle_request(self, req: Request, res: Response): response_data["data"] = result.data if result.errors: response_data["errors"] = [err.formatted for err in result.errors] - + return res.json(response_data) def _get_graphiql_html(self) -> str: diff --git a/nexios_contrib/jrpc/__init__.py b/nexios_contrib/jrpc/__init__.py index d02d6b8..d254c1a 100644 --- a/nexios_contrib/jrpc/__init__.py +++ b/nexios_contrib/jrpc/__init__.py @@ -3,9 +3,15 @@ """ from .client import JsonRpcClient -from .server import JsonRpcPlugin +from .exceptions import ( + JsonRpcClientError, + JsonRpcError, + JsonRpcInvalidParams, + JsonRpcInvalidRequest, + JsonRpcMethodNotFound, +) from .registry import JsonRpcRegistry, get_registry -from .exceptions import JsonRpcError, JsonRpcMethodNotFound, JsonRpcInvalidParams, JsonRpcInvalidRequest, JsonRpcClientError +from .server import JsonRpcPlugin __all__ = [ "JsonRpcClient", diff --git a/nexios_contrib/jrpc/client.py b/nexios_contrib/jrpc/client.py index 658aaff..1ba7454 100644 --- a/nexios_contrib/jrpc/client.py +++ b/nexios_contrib/jrpc/client.py @@ -1,7 +1,7 @@ import asyncio import json -import urllib.request import urllib.error +import urllib.request from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Union @@ -24,7 +24,9 @@ def __init__(self, base_url: str): """ self.base_url = base_url self.request_id = 0 - self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="jsonrpc-client") + self._executor = ThreadPoolExecutor( + max_workers=4, thread_name_prefix="jsonrpc-client" + ) def _generate_request_id(self) -> int: """Generate a unique request ID.""" @@ -143,7 +145,9 @@ def method_caller(*args, **kwargs): return method_caller - async def acall(self, method: str, params: Union[Dict[str, Any], List[Any]] = None) -> Any: + async def acall( + self, method: str, params: Union[Dict[str, Any], List[Any]] = None + ) -> Any: """ Call a JSON-RPC method asynchronously. @@ -171,4 +175,4 @@ def call(self, method: str, params: Union[Dict[str, Any], List[Any]] = None) -> """ if params is None: params = {} - return self._make_request_sync(method, params) \ No newline at end of file + return self._make_request_sync(method, params) diff --git a/nexios_contrib/jrpc/registry.py b/nexios_contrib/jrpc/registry.py index 9df5bb2..99c48e1 100644 --- a/nexios_contrib/jrpc/registry.py +++ b/nexios_contrib/jrpc/registry.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Optional +from typing import Callable, Optional from .exceptions import JsonRpcMethodNotFound @@ -43,4 +43,4 @@ def get_registry() -> JsonRpcRegistry: Returns: JsonRpcRegistry: The singleton registry instance """ - return JsonRpcRegistry() \ No newline at end of file + return JsonRpcRegistry() diff --git a/nexios_contrib/jrpc/server.py b/nexios_contrib/jrpc/server.py index c63a521..837a073 100644 --- a/nexios_contrib/jrpc/server.py +++ b/nexios_contrib/jrpc/server.py @@ -116,4 +116,4 @@ async def _call_method(self, method: Callable, params: Dict[str, Any]) -> Any: else: return method(*bound.args, **bound.kwargs) except TypeError as e: - raise JsonRpcInvalidParams(str(e)) \ No newline at end of file + raise JsonRpcInvalidParams(str(e)) diff --git a/nexios_contrib/mail/__init__.py b/nexios_contrib/mail/__init__.py index 94e5b91..0353c90 100644 --- a/nexios_contrib/mail/__init__.py +++ b/nexios_contrib/mail/__init__.py @@ -17,56 +17,56 @@ from .config import MailConfig from .dependency import MailDepend, get_mail_client from .models import EmailMessage, EmailResult -from .tasks import MailTaskManager, add_task_support, send_email_async, send_template_email_async +from .tasks import ( + MailTaskManager, + add_task_support, + send_email_async, + send_template_email_async, +) __all__ = [ # Main classes - 'MailClient', - 'MailConfig', - 'EmailMessage', - 'EmailResult', - + "MailClient", + "MailConfig", + "EmailMessage", + "EmailResult", # Dependency injection - 'MailDepend', - 'get_mail_client', - + "MailDepend", + "get_mail_client", # Background tasks - 'MailTaskManager', - 'add_task_support', - 'send_email_async', - 'send_template_email_async', - + "MailTaskManager", + "add_task_support", + "send_email_async", + "send_template_email_async", # Setup functions - 'setup_mail', - 'get_mail_from_request', + "setup_mail", + "get_mail_from_request", ] -def setup_mail( - app: NexiosApp, - config: Optional[MailConfig] = None -) -> MailClient: + +def setup_mail(app: NexiosApp, config: Optional[MailConfig] = None) -> MailClient: """Set up the mail client for a Nexios application. - + This function initializes the mail client and registers it with the Nexios app. It should be called during application startup. - + Args: app: The Nexios application instance. config: Optional configuration for the mail client. - + Returns: The initialized MailClient instance. - + Example: ```python from nexios import NexiosApp from nexios_contrib.mail import setup_mail, MailConfig - + app = NexiosApp() - + # Initialize with default configuration mail_client = setup_mail(app) - + # Or with custom configuration config = MailConfig( smtp_host="smtp.gmail.com", @@ -78,41 +78,42 @@ def setup_mail( mail_client = setup_mail(app, config=config) ``` """ - if not hasattr(app, 'mail_client'): + if not hasattr(app, "mail_client"): mail_client = MailClient(config=config) - app.mail_client = mail_client + app.mail_client = mail_client # ty:ignore[invalid-assignment] app.on_startup(mail_client.start) app.on_shutdown(mail_client.stop) - + # Add background task support if available try: add_task_support(mail_client) except Exception: # Tasks not available, continue without them pass - - return app.mail_client + + return app.mail_client # ty:ignore[unresolved-attribute] + def get_mail_from_request(request: Request) -> MailClient: """Get the mail client from a request. - + This is a convenience function to get the mail client instance from a request object. - + Args: request: The current request object. - + Returns: The MailClient instance. - + Raises: AttributeError: If the mail client is not initialized. - + Example: ```python from nexios import Request from nexios_contrib.mail import get_mail_from_request - + @app.post("/send-email") async def send_email_endpoint(request: Request): mail_client = get_mail_from_request(request) @@ -124,7 +125,7 @@ async def send_email_endpoint(request: Request): return {"status": "sent", "message_id": result.message_id} ``` """ - mail_client = getattr(request.base_app, 'mail_client', None) + mail_client = getattr(request.base_app, "mail_client", None) if mail_client is None: raise AttributeError( "Mail client not initialized. Call setup_mail(app) during application startup." diff --git a/nexios_contrib/mail/client.py b/nexios_contrib/mail/client.py index 1c73ac7..3be6ccb 100644 --- a/nexios_contrib/mail/client.py +++ b/nexios_contrib/mail/client.py @@ -11,32 +11,32 @@ import logging import smtplib from email.mime.multipart import MIMEMultipart -from email.utils import formataddr from pathlib import Path from typing import Any, Dict, List, Optional, Union try: import jinja2 + JINJA2_AVAILABLE = True except ImportError: JINJA2_AVAILABLE = False from .config import MailConfig -from .models import EmailMessage, EmailResult, EmailError +from .models import EmailMessage, EmailResult logger = logging.getLogger(__name__) class MailClient: """Main mail client for sending emails with SMTP and template support. - + This client provides a high-level interface for sending emails through SMTP servers, with support for HTML templates, attachments, and background tasks. - + Example: ```python from nexios_contrib.mail import MailClient, MailConfig - + config = MailConfig( smtp_host="smtp.gmail.com", smtp_port=587, @@ -44,10 +44,10 @@ class MailClient: smtp_password="your-app-password", use_tls=True ) - + mail_client = MailClient(config=config) await mail_client.start() - + result = await mail_client.send_email( to="recipient@example.com", subject="Hello World", @@ -55,10 +55,10 @@ class MailClient: ) ``` """ - + def __init__(self, config: Optional[MailConfig] = None) -> None: """Initialize the mail client. - + Args: config: Optional mail configuration. If not provided, uses default config. """ @@ -66,51 +66,51 @@ def __init__(self, config: Optional[MailConfig] = None) -> None: self._smtp_pool: Optional[smtplib.SMTP] = None self._template_env: Optional[jinja2.Environment] = None self._is_started = False - + # Setup template environment if Jinja2 is available if JINJA2_AVAILABLE and self.config.template_directory: self._setup_template_environment() - + def _setup_template_environment(self) -> None: """Setup the Jinja2 template environment.""" if not self.config.template_directory: return - + template_path = Path(self.config.template_directory) if not template_path.exists(): logger.warning(f"Template directory not found: {template_path}") return - + loader = jinja2.FileSystemLoader(str(template_path)) self._template_env = jinja2.Environment( loader=loader, autoescape=self.config.template_auto_escape, trim_blocks=True, - lstrip_blocks=True + lstrip_blocks=True, ) - + # Add custom filters self._template_env.filters["format_date"] = self._format_date_filter - + def _format_date_filter(self, value: Any, format_str: str = "%Y-%m-%d") -> str: """Jinja2 filter for formatting dates. - + Args: value: Date value to format. format_str: Date format string. - + Returns: Formatted date string. """ if hasattr(value, "strftime"): return value.strftime(format_str) return str(value) - + async def start(self) -> None: """Start the mail client and initialize SMTP connection.""" if self._is_started: return - + try: if not self.config.suppress_send: await self._create_smtp_connection() @@ -119,12 +119,12 @@ async def start(self) -> None: except Exception as e: logger.error(f"Failed to start mail client: {e}") raise - + async def stop(self) -> None: """Stop the mail client and close SMTP connection.""" if not self._is_started: return - + try: if self._smtp_pool: self._smtp_pool.quit() @@ -133,41 +133,42 @@ async def stop(self) -> None: logger.info("Mail client stopped successfully") except Exception as e: logger.error(f"Error stopping mail client: {e}") - + async def _create_smtp_connection(self) -> None: """Create and configure SMTP connection.""" if self.config.use_ssl: self._smtp_pool = smtplib.SMTP_SSL( self.config.smtp_host, self.config.smtp_port, - timeout=self.config.smtp_timeout + timeout=self.config.smtp_timeout, ) else: self._smtp_pool = smtplib.SMTP( self.config.smtp_host, self.config.smtp_port, - timeout=self.config.smtp_timeout + timeout=self.config.smtp_timeout, ) - + if self.config.use_tls: self._smtp_pool.starttls() - + # Enable debug mode if configured if self.config.debug: self._smtp_pool.set_debuglevel(1) - + # Authenticate if credentials are provided if self.config.smtp_username and self.config.smtp_password: try: self._smtp_pool.login( - self.config.smtp_username, - self.config.smtp_password + self.config.smtp_username, self.config.smtp_password + ) + logger.info( + f"SMTP authentication successful for {self.config.smtp_username}" ) - logger.info(f"SMTP authentication successful for {self.config.smtp_username}") except Exception as e: logger.error(f"SMTP authentication failed: {e}") raise - + async def send_email( self, to: Union[str, List[str]], @@ -181,10 +182,10 @@ async def send_email( attachments: Optional[List[Any]] = None, template_name: Optional[str] = None, template_context: Optional[Dict[str, Any]] = None, - **kwargs: Any + **kwargs: Any, ) -> EmailResult: """Send an email. - + Args: to: Recipient email address(es). subject: Email subject. @@ -198,7 +199,7 @@ async def send_email( template_name: Name of the template to use. template_context: Context variables for the template. **kwargs: Additional email parameters. - + Returns: EmailResult indicating success or failure. """ @@ -214,25 +215,25 @@ async def send_email( bcc=bcc, template_name=template_name, template_context=template_context, - **kwargs + **kwargs, ) - + # Add attachments if provided if attachments: for attachment in attachments: if isinstance(attachment, dict): message.add_attachment(**attachment) else: - message.add_attachment(attachment) - + message.add_attachment(attachment) # ty:ignore[missing-argument] + return await self.send_message(message) - + async def send_message(self, message: EmailMessage) -> EmailResult: """Send an EmailMessage. - + Args: message: The EmailMessage to send. - + Returns: EmailResult indicating success or failure. """ @@ -240,116 +241,121 @@ async def send_message(self, message: EmailMessage) -> EmailResult: # Render template if specified if message.template_name and self._template_env: await self._render_template(message) - + # Use default from email if not specified from_email = message.from_email or self.config.default_from if not from_email: raise ValueError("No 'from' email address specified") - + # Create MIME message mime_message = message.to_mime_message(from_email) - + # Add default CC/BCC if not specified if self.config.default_cc and not message.cc: mime_message["Cc"] = ", ".join(self.config.default_cc) - message.cc.extend(self.config.default_cc) - + message.cc.extend(self.config.default_cc) # ty:ignore[unresolved-attribute] + if self.config.default_bcc and not message.bcc: - message.bcc.extend(self.config.default_bcc) - + message.bcc.extend(self.config.default_bcc) # ty:ignore[unresolved-attribute] + # Prepare recipient list recipients = list(message.to) if message.cc: recipients.extend(message.cc) if message.bcc: recipients.extend(message.bcc) - + # Send email if self.config.suppress_send: - logger.info(f"Email sending suppressed: {message.subject} to {recipients}") + logger.info( + f"Email sending suppressed: {message.subject} to {recipients}" + ) return EmailResult( success=True, message_id=message.message_id, to=recipients, subject=message.subject, - provider_response={"suppressed": True} + provider_response={"suppressed": True}, ) - + # Send in a thread pool to avoid blocking loop = asyncio.get_event_loop() await loop.run_in_executor( - None, - self._send_mime_message, - mime_message, - recipients + None, self._send_mime_message, mime_message, recipients ) - - logger.info(f"Email sent successfully: {message.message_id} to {recipients}") - + + logger.info( + f"Email sent successfully: {message.message_id} to {recipients}" + ) + return EmailResult( success=True, message_id=message.message_id, to=recipients, - subject=message.subject + subject=message.subject, ) - + except Exception as e: error_msg = str(e) logger.error(f"Failed to send email: {error_msg}") - + return EmailResult( success=False, message_id=message.message_id, to=list(message.to), subject=message.subject, - error=error_msg + error=error_msg, ) - - def _send_mime_message(self, mime_message: MIMEMultipart, recipients: List[str]) -> None: + + def _send_mime_message( + self, mime_message: MIMEMultipart, recipients: List[str] + ) -> None: """Send MIME message using SMTP. - + Args: mime_message: The MIME message to send. recipients: List of recipient email addresses. """ if not self._smtp_pool: raise RuntimeError("SMTP connection not established") - + self._smtp_pool.sendmail( - mime_message["From"], - recipients, - mime_message.as_string() + mime_message["From"], recipients, mime_message.as_string() ) - + async def _render_template(self, message: EmailMessage) -> None: """Render email template. - + Args: message: The email message to render template for. """ if not self._template_env or not message.template_name: return - + try: # Try to render HTML template - html_template = self._template_env.get_template(f"{message.template_name}.html") - message.html_body = html_template.render(**message.template_context) - + html_template = self._template_env.get_template( + f"{message.template_name}.html" + ) + message.html_body = html_template.render(**message.template_context) # ty:ignore[invalid-argument-type] + # Try to render text template try: - text_template = self._template_env.get_template(f"{message.template_name}.txt") - message.body = text_template.render(**message.template_context) + text_template = self._template_env.get_template( + f"{message.template_name}.txt" + ) + message.body = text_template.render(**message.template_context) # ty:ignore[invalid-argument-type] except jinja2.TemplateNotFound: # Text template is optional pass - + except jinja2.TemplateNotFound as e: logger.error(f"Template not found: {e}") raise except jinja2.TemplateError as e: logger.error(f"Template rendering error: {e}") raise - + async def send_template_email( self, to: Union[str, List[str]], @@ -357,10 +363,10 @@ async def send_template_email( template_name: str, context: Optional[Dict[str, Any]] = None, from_email: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> EmailResult: """Send an email using a template. - + Args: to: Recipient email address(es). subject: Email subject. @@ -368,7 +374,7 @@ async def send_template_email( context: Template context variables. from_email: Sender email address. **kwargs: Additional email parameters. - + Returns: EmailResult indicating success or failure. """ @@ -378,22 +384,19 @@ async def send_template_email( template_name=template_name, template_context=context, from_email=from_email, - **kwargs + **kwargs, ) - + def create_message( - self, - to: Union[str, List[str]], - subject: str, - **kwargs: Any + self, to: Union[str, List[str]], subject: str, **kwargs: Any ) -> EmailMessage: """Create an EmailMessage object. - + Args: to: Recipient email address(es). subject: Email subject. **kwargs: Additional message parameters. - + Returns: EmailMessage instance. """ diff --git a/nexios_contrib/mail/config.py b/nexios_contrib/mail/config.py index 8910ad0..b49e275 100644 --- a/nexios_contrib/mail/config.py +++ b/nexios_contrib/mail/config.py @@ -9,52 +9,77 @@ import os from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, List @dataclass class MailConfig: """Configuration for the mail client. - + This class contains all the settings needed to configure the SMTP connection and email sending behavior. """ - + # SMTP Configuration smtp_host: str = field(default_factory=lambda: os.getenv("SMTP_HOST", "localhost")) smtp_port: int = field(default_factory=lambda: int(os.getenv("SMTP_PORT", "587"))) - smtp_username: Optional[str] = field(default_factory=lambda: os.getenv("SMTP_USERNAME")) - smtp_password: Optional[str] = field(default_factory=lambda: os.getenv("SMTP_PASSWORD")) - use_tls: bool = field(default_factory=lambda: os.getenv("SMTP_USE_TLS", "true").lower() == "true") - use_ssl: bool = field(default_factory=lambda: os.getenv("SMTP_USE_SSL", "false").lower() == "true") - + smtp_username: Optional[str] = field( + default_factory=lambda: os.getenv("SMTP_USERNAME") + ) + smtp_password: Optional[str] = field( + default_factory=lambda: os.getenv("SMTP_PASSWORD") + ) + use_tls: bool = field( + default_factory=lambda: os.getenv("SMTP_USE_TLS", "true").lower() == "true" + ) + use_ssl: bool = field( + default_factory=lambda: os.getenv("SMTP_USE_SSL", "false").lower() == "true" + ) + # Email defaults - default_from: Optional[str] = field(default_factory=lambda: os.getenv("MAIL_DEFAULT_FROM")) - default_reply_to: Optional[str] = field(default_factory=lambda: os.getenv("MAIL_DEFAULT_REPLY_TO")) + default_from: Optional[str] = field( + default_factory=lambda: os.getenv("MAIL_DEFAULT_FROM") + ) + default_reply_to: Optional[str] = field( + default_factory=lambda: os.getenv("MAIL_DEFAULT_REPLY_TO") + ) default_cc: Optional[List[str]] = None default_bcc: Optional[List[str]] = None - + # Connection settings - smtp_timeout: float = field(default_factory=lambda: float(os.getenv("SMTP_TIMEOUT", "30"))) - max_connections: int = field(default_factory=lambda: int(os.getenv("SMTP_MAX_CONNECTIONS", "10"))) - + smtp_timeout: float = field( + default_factory=lambda: float(os.getenv("SMTP_TIMEOUT", "30")) + ) + max_connections: int = field( + default_factory=lambda: int(os.getenv("SMTP_MAX_CONNECTIONS", "10")) + ) + # Template settings - template_directory: Optional[str] = field(default_factory=lambda: os.getenv("MAIL_TEMPLATE_DIR")) + template_directory: Optional[str] = field( + default_factory=lambda: os.getenv("MAIL_TEMPLATE_DIR") + ) template_auto_escape: bool = True - + # Background task settings use_background_tasks: bool = True - task_timeout: Optional[float] = field(default_factory=lambda: float(os.getenv("MAIL_TASK_TIMEOUT", "300"))) - + task_timeout: Optional[float] = field( + default_factory=lambda: float(os.getenv("MAIL_TASK_TIMEOUT", "300")) + ) + # Debug settings - debug: bool = field(default_factory=lambda: os.getenv("MAIL_DEBUG", "false").lower() == "true") - suppress_send: bool = field(default_factory=lambda: os.getenv("MAIL_SUPPRESS_SEND", "false").lower() == "true") - + debug: bool = field( + default_factory=lambda: os.getenv("MAIL_DEBUG", "false").lower() == "true" + ) + suppress_send: bool = field( + default_factory=lambda: os.getenv("MAIL_SUPPRESS_SEND", "false").lower() + == "true" + ) + def __post_init__(self) -> None: """Validate configuration after initialization.""" if self.use_ssl and self.use_tls: raise ValueError("Cannot use both SSL and TLS. Choose one.") - + if self.smtp_port == 465 and not self.use_ssl: # Port 465 is typically used for SSL self.use_ssl = True @@ -63,10 +88,10 @@ def __post_init__(self) -> None: # Port 587 is typically used for TLS self.use_tls = True self.use_ssl = False - + def to_dict(self) -> Dict[str, Any]: """Convert configuration to dictionary. - + Returns: Dictionary representation of the configuration. """ @@ -90,28 +115,28 @@ def to_dict(self) -> Dict[str, Any]: "debug": self.debug, "suppress_send": self.suppress_send, } - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> MailConfig: """Create configuration from dictionary. - + Args: data: Dictionary containing configuration values. - + Returns: MailConfig instance. """ return cls(**data) - + @classmethod def for_gmail(cls, username: str, password: str, **kwargs: Any) -> MailConfig: """Create configuration for Gmail SMTP. - + Args: username: Gmail username or email address. password: Gmail app password (not regular password). **kwargs: Additional configuration options. - + Returns: MailConfig configured for Gmail. """ @@ -121,18 +146,18 @@ def for_gmail(cls, username: str, password: str, **kwargs: Any) -> MailConfig: smtp_username=username, smtp_password=password, use_tls=True, - **kwargs + **kwargs, ) - + @classmethod def for_outlook(cls, username: str, password: str, **kwargs: Any) -> MailConfig: """Create configuration for Outlook/Office 365 SMTP. - + Args: username: Outlook username or email address. password: Outlook password. **kwargs: Additional configuration options. - + Returns: MailConfig configured for Outlook. """ @@ -142,17 +167,17 @@ def for_outlook(cls, username: str, password: str, **kwargs: Any) -> MailConfig: smtp_username=username, smtp_password=password, use_tls=True, - **kwargs + **kwargs, ) - + @classmethod def for_sendgrid(cls, api_key: str, **kwargs: Any) -> MailConfig: """Create configuration for SendGrid SMTP. - + Args: api_key: SendGrid API key. **kwargs: Additional configuration options. - + Returns: MailConfig configured for SendGrid. """ @@ -162,5 +187,5 @@ def for_sendgrid(cls, api_key: str, **kwargs: Any) -> MailConfig: smtp_username="apikey", smtp_password=api_key, use_tls=True, - **kwargs + **kwargs, ) diff --git a/nexios_contrib/mail/dependency.py b/nexios_contrib/mail/dependency.py index 3ae0eda..e6218f5 100644 --- a/nexios_contrib/mail/dependency.py +++ b/nexios_contrib/mail/dependency.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Optional, TypeVar, cast +from typing import TypeVar, cast from nexios.dependencies import Depend, current_context from nexios.http import Request @@ -19,14 +19,14 @@ class MailDepend(Depend[MailClient]): """Dependency provider for the mail client. - + This class provides a dependency injection wrapper for the mail client, allowing it to be easily injected into route handlers and other components. - + Example: ```python from nexios_contrib.mail import MailDepend - + @app.post("/send-email") async def send_email( mail_client: MailClient = MailDepend() @@ -39,17 +39,17 @@ async def send_email( return {"status": "sent", "message_id": result.message_id} ``` """ - + def __init__(self) -> None: """Initialize the mail dependency.""" super().__init__(self._get_mail_client) - + async def _get_mail_client(self) -> MailClient: """Get the mail client from the current context. - + Returns: The MailClient instance from the current request context. - + Raises: RuntimeError: If no mail client is found in the context. """ @@ -59,7 +59,7 @@ async def _get_mail_client(self) -> MailClient: return get_mail_from_request(ctx.request) except LookupError: pass - + raise RuntimeError( "Mail client not found in current context. " "Make sure setup_mail(app) was called during application startup." @@ -68,24 +68,24 @@ async def _get_mail_client(self) -> MailClient: def get_mail_client(request: Request) -> MailClient: """Get the mail client from a request. - + This is a convenience function that retrieves the mail client from the Nexios application instance attached to the request. - + Args: request: The current request object. - + Returns: The MailClient instance. - + Raises: AttributeError: If the mail client is not initialized. - + Example: ```python from nexios import Request from nexios_contrib.mail import get_mail_client - + @app.post("/send-email") async def send_email(request: Request): mail_client = get_mail_client(request) @@ -97,7 +97,7 @@ async def send_email(request: Request): return {"status": "sent", "message_id": result.message_id} ``` """ - mail_client = getattr(request.base_app, 'mail_client', None) + mail_client = getattr(request.base_app, "mail_client", None) if mail_client is None: raise AttributeError( "Mail client not initialized. Call setup_mail(app) during application startup." @@ -107,10 +107,10 @@ async def send_email(request: Request): def get_mail_from_request(request: Request) -> MailClient: """Alias for get_mail_client for backward compatibility. - + Args: request: The current request object. - + Returns: The MailClient instance. """ diff --git a/nexios_contrib/mail/models.py b/nexios_contrib/mail/models.py index a0bf588..007d68d 100644 --- a/nexios_contrib/mail/models.py +++ b/nexios_contrib/mail/models.py @@ -8,7 +8,7 @@ import uuid from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone, UTC from email.mime.base import MIMEBase from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText @@ -19,12 +19,12 @@ @dataclass class EmailAttachment: """Represents an email attachment.""" - + filename: str content: Union[bytes, str] content_type: Optional[str] = None content_id: Optional[str] = None # For inline images - + def __post_init__(self) -> None: """Validate and process attachment after initialization.""" if isinstance(self.content, str): @@ -35,10 +35,11 @@ def __post_init__(self) -> None: if not self.content_type: # Guess content type from file extension import mimetypes + self.content_type, _ = mimetypes.guess_type(str(path)) else: raise FileNotFoundError(f"Attachment file not found: {self.content}") - + if not self.content_type: self.content_type = "application/octet-stream" @@ -46,50 +47,50 @@ def __post_init__(self) -> None: @dataclass class EmailMessage: """Represents an email message.""" - + # Required fields to: Union[str, List[str]] subject: str - + # Content fields body: Optional[str] = None html_body: Optional[str] = None template_name: Optional[str] = None template_context: Optional[Dict[str, Any]] = None - + # Optional fields from_email: Optional[str] = None reply_to: Optional[Union[str, List[str]]] = None cc: Optional[Union[str, List[str]]] = None bcc: Optional[Union[str, List[str]]] = None attachments: Optional[List[EmailAttachment]] = None - + # Metadata message_id: str = field(default_factory=lambda: str(uuid.uuid4())) headers: Optional[Dict[str, str]] = None priority: Optional[int] = None # 1 (high), 3 (normal), 5 (low) - + def __post_init__(self) -> None: """Normalize and validate email addresses after initialization.""" # Normalize single addresses to lists if isinstance(self.to, str): self.to = [self.to] - + if isinstance(self.cc, str): self.cc = [self.cc] elif self.cc is None: self.cc = [] - + if isinstance(self.bcc, str): self.bcc = [self.bcc] elif self.bcc is None: self.bcc = [] - + if isinstance(self.reply_to, str): self.reply_to = [self.reply_to] elif self.reply_to is None: self.reply_to = [] - + # Initialize empty lists/dicts if None if self.attachments is None: self.attachments = [] @@ -97,16 +98,16 @@ def __post_init__(self) -> None: self.headers = {} if self.template_context is None: self.template_context = {} - + def add_attachment( self, filename: str, content: Union[bytes, str], content_type: Optional[str] = None, - content_id: Optional[str] = None + content_id: Optional[str] = None, ) -> None: """Add an attachment to the email. - + Args: filename: Name of the attachment file. content: File content as bytes or file path string. @@ -115,29 +116,31 @@ def add_attachment( """ if self.attachments is None: self.attachments = [] - + attachment = EmailAttachment( filename=filename, content=content, content_type=content_type, - content_id=content_id + content_id=content_id, ) self.attachments.append(attachment) - - def set_template(self, template_name: str, context: Optional[Dict[str, Any]] = None) -> None: + + def set_template( + self, template_name: str, context: Optional[Dict[str, Any]] = None + ) -> None: """Set the template for the email. - + Args: template_name: Name of the template file. context: Template context variables. """ self.template_name = template_name if context: - self.template_context = {**self.template_context, **context} - + self.template_context = {**self.template_context, **context} # ty:ignore[invalid-argument-type] + def add_header(self, name: str, value: str) -> None: """Add a custom header to the email. - + Args: name: Header name. value: Header value. @@ -145,87 +148,87 @@ def add_header(self, name: str, value: str) -> None: if self.headers is None: self.headers = {} self.headers[name] = value - + def to_mime_message(self, from_email: Optional[str] = None) -> MIMEMultipart: """Convert the email message to a MIME message. - + Args: from_email: Override the from email address. - + Returns: MIMEMultipart message ready for sending. """ # Create the main message msg = MIMEMultipart("alternative") - + # Set headers msg["Subject"] = self.subject msg["To"] = ", ".join(self.to) msg["From"] = from_email or self.from_email or "" msg["Message-ID"] = self.message_id - + # Set optional headers if self.cc: msg["Cc"] = ", ".join(self.cc) - + if self.reply_to: msg["Reply-To"] = ", ".join(self.reply_to) - + if self.priority: priority_map = {1: "High", 3: "Normal", 5: "Low"} msg["X-Priority"] = str(self.priority) msg["Priority"] = priority_map.get(self.priority, "Normal") - + # Add custom headers if self.headers: for name, value in self.headers.items(): msg[name] = value - + # Add body parts if self.body: text_part = MIMEText(self.body, "plain", "utf-8") msg.attach(text_part) - + if self.html_body: html_part = MIMEText(self.html_body, "html", "utf-8") msg.attach(html_part) - + # Add attachments if self.attachments: for attachment in self.attachments: - part = MIMEBase(*attachment.content_type.split("/", 1)) + part = MIMEBase(*attachment.content_type.split("/", 1)) # ty:ignore[unresolved-attribute] part.set_payload(attachment.content) import email.encoders + email.encoders.encode_base64(part) - + part.add_header( - "Content-Disposition", - f"attachment; filename={attachment.filename}" + "Content-Disposition", f"attachment; filename={attachment.filename}" ) - + if attachment.content_id: part.add_header("Content-ID", f"<{attachment.content_id}>") - + msg.attach(part) - + return msg @dataclass class EmailResult: """Represents the result of sending an email.""" - + success: bool message_id: str to: List[str] subject: str - sent_at: datetime = field(default_factory=datetime.utcnow) + sent_at: datetime = field(default_factory=datetime.now(UTC)) # ty:ignore[no-matching-overload] error: Optional[str] = None provider_response: Optional[Dict[str, Any]] = None - + def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary. - + Returns: Dictionary representation of the result. """ @@ -243,12 +246,12 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class EmailError: """Represents an email sending error.""" - + message: str error_code: Optional[str] = None provider: Optional[str] = None details: Optional[Dict[str, Any]] = None - + def __str__(self) -> str: """String representation of the error.""" if self.error_code: diff --git a/nexios_contrib/mail/tasks.py b/nexios_contrib/mail/tasks.py index 3ceb742..169eb4d 100644 --- a/nexios_contrib/mail/tasks.py +++ b/nexios_contrib/mail/tasks.py @@ -13,32 +13,35 @@ from nexios.http import Request try: - from ..tasks import create_task, Task + from ..tasks import Task, create_task + TASKS_AVAILABLE = True except ImportError: TASKS_AVAILABLE = False from .client import MailClient from .models import EmailMessage, EmailResult +from . import get_mail_from_request + logger = logging.getLogger(__name__) class MailTaskManager: """Manager for email background tasks. - + This class provides methods to send emails in the background using the nexios-contrib tasks system. """ - + def __init__(self, mail_client: MailClient) -> None: """Initialize the mail task manager. - + Args: mail_client: The mail client instance. """ self.mail_client = mail_client - + async def send_email_async( self, to: Union[str, List[str]], @@ -54,10 +57,10 @@ async def send_email_async( template_context: Optional[Dict[str, Any]] = None, priority: str = "normal", timeout: Optional[float] = None, - **kwargs: Any + **kwargs: Any, ) -> Optional[Task]: """Send an email in the background. - + Args: to: Recipient email address(es). subject: Email subject. @@ -73,12 +76,14 @@ async def send_email_async( priority: Task priority ("low", "normal", "high"). timeout: Task timeout in seconds. **kwargs: Additional email parameters. - + Returns: Task instance if tasks are available, None otherwise. """ if not TASKS_AVAILABLE: - logger.warning("Background tasks not available, sending email synchronously") + logger.warning( + "Background tasks not available, sending email synchronously" + ) await self.mail_client.send_email( to=to, subject=subject, @@ -91,12 +96,12 @@ async def send_email_async( attachments=attachments, template_name=template_name, template_context=template_context, - **kwargs + **kwargs, ) return None - + # Create the background task - task = await create_task( + task = create_task( self._send_email_task, to=to, subject=subject, @@ -110,44 +115,46 @@ async def send_email_async( template_name=template_name, template_context=template_context, name=f"send_email_{subject}", - timeout=timeout or self.mail_client.config.task_timeout + timeout=timeout or self.mail_client.config.task_timeout, ) - + logger.info(f"Email task created: {task.id} for {subject}") return task - + async def send_message_async( self, message: EmailMessage, priority: str = "normal", - timeout: Optional[float] = None + timeout: Optional[float] = None, ) -> Optional[Task]: """Send an EmailMessage in the background. - + Args: message: The EmailMessage to send. priority: Task priority ("low", "normal", "high"). timeout: Task timeout in seconds. - + Returns: Task instance if tasks are available, None otherwise. """ if not TASKS_AVAILABLE: - logger.warning("Background tasks not available, sending message synchronously") + logger.warning( + "Background tasks not available, sending message synchronously" + ) await self.mail_client.send_message(message) return None - + # Create the background task - task = await create_task( + task = create_task( self._send_message_task, message, name=f"send_message_{message.subject}", - timeout=timeout or self.mail_client.config.task_timeout + timeout=timeout or self.mail_client.config.task_timeout, ) - + logger.info(f"Email message task created: {task.id} for {message.subject}") return task - + async def send_template_email_async( self, to: Union[str, List[str]], @@ -157,10 +164,10 @@ async def send_template_email_async( from_email: Optional[str] = None, priority: str = "normal", timeout: Optional[float] = None, - **kwargs: Any + **kwargs: Any, ) -> Optional[Task]: """Send a template email in the background. - + Args: to: Recipient email address(es). subject: Email subject. @@ -170,24 +177,26 @@ async def send_template_email_async( priority: Task priority ("low", "normal", "high"). timeout: Task timeout in seconds. **kwargs: Additional email parameters. - + Returns: Task instance if tasks are available, None otherwise. """ if not TASKS_AVAILABLE: - logger.warning("Background tasks not available, sending template email synchronously") + logger.warning( + "Background tasks not available, sending template email synchronously" + ) await self.mail_client.send_template_email( to=to, subject=subject, template_name=template_name, context=context, from_email=from_email, - **kwargs + **kwargs, ) return None - + # Create the background task - task = await create_task( + task = create_task( self._send_template_email_task, to=to, subject=subject, @@ -196,19 +205,19 @@ async def send_template_email_async( from_email=from_email, name=f"send_template_email_{subject}", timeout=timeout or self.mail_client.config.task_timeout, - **kwargs + **kwargs, ) - + logger.info(f"Template email task created: {task.id} for {subject}") return task - + async def _send_email_task(self, *args: Any, **kwargs: Any) -> EmailResult: """Background task for sending emails. - + Args: *args: Positional arguments for send_email. **kwargs: Keyword arguments for send_email. - + Returns: EmailResult from the mail client. """ @@ -222,27 +231,29 @@ async def _send_email_task(self, *args: Any, **kwargs: Any) -> EmailResult: except Exception as e: logger.error(f"Background email task error: {e}") raise - + async def _send_message_task(self, message: EmailMessage) -> EmailResult: """Background task for sending EmailMessage. - + Args: message: The EmailMessage to send. - + Returns: EmailResult from the mail client. """ try: result = await self.mail_client.send_message(message) if result.success: - logger.info(f"Background message sent successfully: {result.message_id}") + logger.info( + f"Background message sent successfully: {result.message_id}" + ) else: logger.error(f"Background message failed: {result.error}") return result except Exception as e: logger.error(f"Background message task error: {e}") raise - + async def _send_template_email_task( self, to: Union[str, List[str]], @@ -250,10 +261,10 @@ async def _send_template_email_task( template_name: str, context: Optional[Dict[str, Any]] = None, from_email: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> EmailResult: """Background task for sending template emails. - + Args: to: Recipient email address(es). subject: Email subject. @@ -261,7 +272,7 @@ async def _send_template_email_task( context: Template context variables. from_email: Sender email address. **kwargs: Additional email parameters. - + Returns: EmailResult from the mail client. """ @@ -272,10 +283,12 @@ async def _send_template_email_task( template_name=template_name, context=context, from_email=from_email, - **kwargs + **kwargs, ) if result.success: - logger.info(f"Background template email sent successfully: {result.message_id}") + logger.info( + f"Background template email sent successfully: {result.message_id}" + ) else: logger.error(f"Background template email failed: {result.error}") return result @@ -287,44 +300,40 @@ async def _send_template_email_task( # Add task manager to mail client def add_task_support(mail_client: MailClient) -> MailTaskManager: """Add background task support to a mail client. - + Args: mail_client: The mail client to extend. - + Returns: MailTaskManager instance. """ task_manager = MailTaskManager(mail_client) - mail_client.tasks = task_manager + mail_client.tasks = task_manager # ty:ignore[unresolved-attribute] return task_manager # Convenience functions for background email sending async def send_email_async( - request: Request, - to: Union[str, List[str]], - subject: str, - **kwargs: Any + request: Request, to: Union[str, List[str]], subject: str, **kwargs: Any ) -> Optional[Task]: """Send an email in the background from a request context. - + Args: request: The current request object. to: Recipient email address(es). subject: Email subject. **kwargs: Additional email parameters. - + Returns: Task instance if tasks are available, None otherwise. """ - from . import get_mail_from_request - + mail_client = get_mail_from_request(request) - - if not hasattr(mail_client, 'tasks'): + + if not hasattr(mail_client, "tasks"): add_task_support(mail_client) - - return await mail_client.tasks.send_email_async(to=to, subject=subject, **kwargs) + + return await mail_client.tasks.send_email_async(to=to, subject=subject, **kwargs) # ty:ignore[unresolved-attribute] async def send_template_email_async( @@ -333,10 +342,10 @@ async def send_template_email_async( subject: str, template_name: str, context: Optional[Dict[str, Any]] = None, - **kwargs: Any + **kwargs: Any, ) -> Optional[Task]: """Send a template email in the background from a request context. - + Args: request: The current request object. to: Recipient email address(es). @@ -344,21 +353,16 @@ async def send_template_email_async( template_name: Name of the template to use. context: Template context variables. **kwargs: Additional email parameters. - + Returns: Task instance if tasks are available, None otherwise. """ - from . import get_mail_from_request - + mail_client = get_mail_from_request(request) - - if not hasattr(mail_client, 'tasks'): + + if not hasattr(mail_client, "tasks"): add_task_support(mail_client) - - return await mail_client.tasks.send_template_email_async( - to=to, - subject=subject, - template_name=template_name, - context=context, - **kwargs + + return await mail_client.tasks.send_template_email_async( # ty:ignore[unresolved-attribute] + to=to, subject=subject, template_name=template_name, context=context, **kwargs ) diff --git a/nexios_contrib/proxy/__init__.py b/nexios_contrib/proxy/__init__.py index 40e0853..943f731 100644 --- a/nexios_contrib/proxy/__init__.py +++ b/nexios_contrib/proxy/__init__.py @@ -4,22 +4,23 @@ This module provides proxy header handling and security for applications running behind proxy servers, load balancers, or CDNs. """ + from __future__ import annotations from .helper import ( + build_forwarded_header, + get_client_ip_from_headers, + get_host_from_headers, + get_protocol_from_headers, + is_trusted_proxy, parse_forwarded_header, parse_x_forwarded_for, - parse_x_forwarded_proto, parse_x_forwarded_host, parse_x_forwarded_port, - get_client_ip_from_headers, - get_protocol_from_headers, - get_host_from_headers, - is_trusted_proxy, + parse_x_forwarded_proto, validate_proxy_headers, - build_forwarded_header, ) -from .middleware import ProxyMiddleware, Proxy, TrustedProxyMiddleware +from .middleware import Proxy, ProxyMiddleware, TrustedProxyMiddleware __all__ = [ "ProxyMiddleware", diff --git a/nexios_contrib/proxy/helper.py b/nexios_contrib/proxy/helper.py index 8b6a519..acb36f8 100644 --- a/nexios_contrib/proxy/helper.py +++ b/nexios_contrib/proxy/helper.py @@ -4,11 +4,11 @@ This module provides utilities for handling proxy headers and managing applications behind proxy servers. """ + from __future__ import annotations import ipaddress -from typing import List, Optional, Union -from urllib.parse import urlparse +from typing import List, Optional from nexios.http import Request @@ -28,7 +28,7 @@ def parse_forwarded_header(forwarded_header: str) -> dict: return result # Split by comma and semicolon - for forwarded in forwarded_header.split(','): + for forwarded in forwarded_header.split(","): forwarded = forwarded.strip() if not forwarded: continue @@ -41,10 +41,10 @@ def parse_forwarded_header(forwarded_header: str) -> dict: while i < len(forwarded): char = forwarded[i] - if char == '=': + if char == "=": current_param = current_value.strip() current_value = "" - elif char == ';': + elif char == ";": if current_param: params[current_param] = current_value.strip() current_param = "" @@ -59,14 +59,14 @@ def parse_forwarded_header(forwarded_header: str) -> dict: # Process standard forwarded parameters for key, value in params.items(): - if key.lower() == 'for': - result['for'] = value - elif key.lower() == 'by': - result['by'] = value - elif key.lower() == 'host': - result['host'] = value - elif key.lower() == 'proto': - result['proto'] = value + if key.lower() == "for": + result["for"] = value + elif key.lower() == "by": + result["by"] = value + elif key.lower() == "host": + result["host"] = value + elif key.lower() == "proto": + result["proto"] = value return result @@ -84,7 +84,7 @@ def parse_x_forwarded_for(x_forwarded_for: str) -> List[str]: if not x_forwarded_for: return [] - return [ip.strip() for ip in x_forwarded_for.split(',') if ip.strip()] + return [ip.strip() for ip in x_forwarded_for.split(",") if ip.strip()] def parse_x_forwarded_proto(x_forwarded_proto: str) -> Optional[str]: @@ -101,7 +101,7 @@ def parse_x_forwarded_proto(x_forwarded_proto: str) -> Optional[str]: return None proto = x_forwarded_proto.strip().lower() - return proto if proto in ['http', 'https'] else None + return proto if proto in ["http", "https"] else None def parse_x_forwarded_host(x_forwarded_host: str) -> Optional[str]: @@ -143,7 +143,9 @@ def parse_x_forwarded_port(x_forwarded_port: str) -> Optional[int]: return None -def get_client_ip_from_headers(request: Request, trusted_proxies: List[str] = None) -> Optional[str]: +def get_client_ip_from_headers( + request: Request, trusted_proxies: List[str] = None +) -> Optional[str]: """ Extract the real client IP from proxy headers. @@ -157,7 +159,7 @@ def get_client_ip_from_headers(request: Request, trusted_proxies: List[str] = No trusted_proxies = trusted_proxies or [] # Try X-Forwarded-For first - x_forwarded_for = request.headers.get('X-Forwarded-For') + x_forwarded_for = request.headers.get("X-Forwarded-For") if x_forwarded_for: forwarded_ips = parse_x_forwarded_for(x_forwarded_for) if forwarded_ips: @@ -167,16 +169,16 @@ def get_client_ip_from_headers(request: Request, trusted_proxies: List[str] = No return ip # Try Forwarded header - forwarded = request.headers.get('Forwarded') + forwarded = request.headers.get("Forwarded") if forwarded: parsed = parse_forwarded_header(forwarded) - if 'for' in parsed: - for_value = parsed['for'] + if "for" in parsed: + for_value = parsed["for"] if for_value not in trusted_proxies: return for_value # Fall back to direct client IP - return getattr(request, 'client_ip', None) + return getattr(request, "client_ip", None) def get_protocol_from_headers(request: Request) -> str: @@ -190,19 +192,23 @@ def get_protocol_from_headers(request: Request) -> str: str: The protocol (http/https). """ # Check X-Forwarded-Proto first - proto = parse_x_forwarded_proto(request.headers.get('X-Forwarded-Proto')) + proto = parse_x_forwarded_proto(request.headers.get("X-Forwarded-Proto","")) if proto: return proto # Check Forwarded header - forwarded = request.headers.get('Forwarded') + forwarded = request.headers.get("Forwarded") if forwarded: parsed = parse_forwarded_header(forwarded) - if 'proto' in parsed: - return parsed['proto'] + if "proto" in parsed: + return parsed["proto"] # Fall back to request URL scheme - return getattr(request, 'url', '').split('://')[0] if getattr(request, 'url', '') else 'http' + return ( + getattr(request, "url", "").split("://")[0] + if getattr(request, "url", "") + else "http" + ) def get_host_from_headers(request: Request) -> Optional[str]: @@ -216,19 +222,19 @@ def get_host_from_headers(request: Request) -> Optional[str]: Optional[str]: The real host. """ # Check X-Forwarded-Host first - host = parse_x_forwarded_host(request.headers.get('X-Forwarded-Host')) + host = parse_x_forwarded_host(request.headers.get("X-Forwarded-Host","")) if host: return host # Check Forwarded header - forwarded = request.headers.get('Forwarded') + forwarded = request.headers.get("Forwarded") if forwarded: parsed = parse_forwarded_header(forwarded) - if 'host' in parsed: - return parsed['host'] + if "host" in parsed: + return parsed["host"] # Fall back to request host - return getattr(request, 'host', None) + return getattr(request, "host", None) def is_trusted_proxy(client_ip: str, trusted_proxies: List[str]) -> bool: @@ -249,7 +255,7 @@ def is_trusted_proxy(client_ip: str, trusted_proxies: List[str]) -> bool: for proxy in trusted_proxies: try: - if '/' in proxy: + if "/" in proxy: # CIDR notation network = ipaddress.ip_network(proxy, strict=False) if client_addr in network: @@ -275,30 +281,31 @@ def validate_proxy_headers(request: Request, trusted_proxies: List[str] = None) Returns: dict: Validated proxy information. """ - client_ip = getattr(request, 'client_ip', None) + client_ip = getattr(request, "client_ip", None) # Only process proxy headers if client is from trusted proxy if client_ip and trusted_proxies and is_trusted_proxy(client_ip, trusted_proxies): return { - 'client_ip': get_client_ip_from_headers(request, trusted_proxies), - 'protocol': get_protocol_from_headers(request), - 'host': get_host_from_headers(request), - 'trusted_proxy': True + "client_ip": get_client_ip_from_headers(request, trusted_proxies), + "protocol": get_protocol_from_headers(request), + "host": get_host_from_headers(request), + "trusted_proxy": True, } return { - 'client_ip': client_ip, - 'protocol': getattr(request, 'url', '').split('://')[0] if getattr(request, 'url', '') else 'http', - 'host': getattr(request, 'host', None), - 'trusted_proxy': False + "client_ip": client_ip, + "protocol": ( + getattr(request, "url", "").split("://")[0] + if getattr(request, "url", "") + else "http" + ), + "host": getattr(request, "host", None), + "trusted_proxy": False, } def build_forwarded_header( - client_ip: str, - protocol: str = None, - host: str = None, - by: str = None + client_ip: str, protocol: str = None, host: str = None, by: str = None ) -> str: """ Build a Forwarded header according to RFC 7239. diff --git a/nexios_contrib/proxy/middleware.py b/nexios_contrib/proxy/middleware.py index e71d261..6130b81 100644 --- a/nexios_contrib/proxy/middleware.py +++ b/nexios_contrib/proxy/middleware.py @@ -4,6 +4,7 @@ This middleware handles applications running behind proxy servers by properly processing X-Forwarded-* headers and managing proxy-related security. """ + from __future__ import annotations from typing import Any, List, Optional @@ -12,13 +13,10 @@ from nexios.middleware.base import BaseMiddleware from .helper import ( - get_client_ip_from_headers, - get_protocol_from_headers, - get_host_from_headers, is_trusted_proxy, - validate_proxy_headers, - parse_x_forwarded_for, parse_forwarded_header, + parse_x_forwarded_for, + validate_proxy_headers, ) @@ -90,40 +88,51 @@ async def process_request( Returns: Any: The result from the next middleware or handler. """ - original_client_ip = getattr(request, 'client_ip', None) - original_host = getattr(request, 'host', None) - original_url = getattr(request, 'url', '') + original_client_ip = getattr(request, "client_ip", None) + original_host = getattr(request, "host", None) + original_url = getattr(request, "url", "") # Validate proxy headers and extract real client information proxy_info = validate_proxy_headers(request, self.trusted_proxies) # Update request with real client IP if from trusted proxy - if proxy_info['trusted_proxy'] and proxy_info['client_ip']: - setattr(request, 'client_ip', proxy_info['client_ip']) + if proxy_info["trusted_proxy"] and proxy_info["client_ip"]: + setattr(request, "client_ip", proxy_info["client_ip"]) # Update request URL scheme if protocol changed - if proxy_info['protocol'] and proxy_info['protocol'] != original_url.split('://')[0]: + if ( + proxy_info["protocol"] + and proxy_info["protocol"] != original_url.split("://")[0] + ): if original_url: # Reconstruct URL with correct scheme - new_scheme = proxy_info['protocol'] - rest_of_url = original_url.split('://', 1)[1] if '://' in original_url else original_url - setattr(request, 'url', f"{new_scheme}://{rest_of_url}") + new_scheme = proxy_info["protocol"] + rest_of_url = ( + original_url.split("://", 1)[1] + if "://" in original_url + else original_url + ) + setattr(request, "url", f"{new_scheme}://{rest_of_url}") # Update host if provided and not preserving original - if not self.preserve_host_header and proxy_info['host']: - setattr(request, 'host', proxy_info['host']) + if not self.preserve_host_header and proxy_info["host"]: + setattr(request, "host", proxy_info["host"]) # Store proxy information in request for later access if self.store_proxy_info: - setattr(request, 'proxy_info', proxy_info) - setattr(request, 'x_forwarded_for', parse_x_forwarded_for( - request.headers.get('X-Forwarded-For', '') - )) + setattr(request, "proxy_info", proxy_info) + setattr( + request, + "x_forwarded_for", + parse_x_forwarded_for(request.headers.get("X-Forwarded-For", "")), + ) if self.trust_forwarded_header: - forwarded_data = parse_forwarded_header(request.headers.get('Forwarded', '')) + forwarded_data = parse_forwarded_header( + request.headers.get("Forwarded", "") + ) if forwarded_data: - setattr(request, 'forwarded_header', forwarded_data) + setattr(request, "forwarded_header", forwarded_data) return await call_next() @@ -143,9 +152,11 @@ async def process_response( Any: The response object. """ # Add X-Client-IP header if we have real client IP info - proxy_info = getattr(request, 'proxy_info', None) - if proxy_info and proxy_info['client_ip'] != getattr(request, 'client_ip', None): - response.set_header('X-Client-IP', proxy_info['client_ip']) + proxy_info = getattr(request, "proxy_info", None) + if proxy_info and proxy_info["client_ip"] != getattr( + request, "client_ip", None + ): + response.set_header("X-Client-IP", proxy_info["client_ip"]) return response @@ -208,7 +219,7 @@ def __init__( trusted_proxies=trusted_proxies, trust_forwarded_headers=True, trust_forwarded_header=True, - **kwargs + **kwargs, ) self.require_https = require_https self.max_forwards = max_forwards @@ -223,11 +234,11 @@ async def process_request( Process request with enhanced security checks. """ # Check if client is from trusted proxy - client_ip = getattr(request, 'client_ip', None) + client_ip = getattr(request, "client_ip", None) if not client_ip or not is_trusted_proxy(client_ip, self.trusted_proxies): # Not from trusted proxy, don't process proxy headers if self.store_proxy_info: - setattr(request, 'proxy_info', {'trusted_proxy': False}) + setattr(request, "proxy_info", {"trusted_proxy": False}) return await call_next() # Process proxy headers @@ -235,16 +246,18 @@ async def process_request( # Additional security checks if self.require_https: - proxy_info = getattr(request, 'proxy_info', {}) - if proxy_info.get('protocol') != 'https': + proxy_info = getattr(request, "proxy_info", {}) + if proxy_info.get("protocol") != "https": # Could redirect to HTTPS or return error - response.status_code = 400 + response.status(400) return {"error": "HTTPS required when behind proxy"} # Check for too many forwards - x_forwarded_for = parse_x_forwarded_for(request.headers.get('X-Forwarded-For', '')) + x_forwarded_for = parse_x_forwarded_for( + request.headers.get("X-Forwarded-For", "") + ) if len(x_forwarded_for) > self.max_forwards: - response.status_code = 400 + response.status(400) return {"error": "Too many proxy hops"} return result diff --git a/nexios_contrib/redis/__init__.py b/nexios_contrib/redis/__init__.py index 0384f30..295217c 100644 --- a/nexios_contrib/redis/__init__.py +++ b/nexios_contrib/redis/__init__.py @@ -8,15 +8,14 @@ import json import logging -from typing import Dict, List, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Optional, Union -from nexios import NexiosApp, Depend +from nexios import Depend, NexiosApp from nexios.http import Request, Response from .client import RedisClient, RedisOperationError from .config import RedisConfig - if TYPE_CHECKING: from nexios.dependencies import Context @@ -50,7 +49,7 @@ def init_redis( health_check_interval: int = 30, max_connections: Optional[int] = None, retry_on_timeout: bool = False, -) -> None: +) -> RedisClient: """ Initialize Redis client for use in a Nexios application. @@ -113,7 +112,7 @@ def init_redis( retry_on_timeout=retry_on_timeout, ) - _redis_client = RedisClient(config) + _redis_client = RedisClient(config) # ty:ignore[invalid-argument-type] # Store Redis client in app state for easy access app.state["redis"] = _redis_client diff --git a/nexios_contrib/redis/client.py b/nexios_contrib/redis/client.py index d37607e..9305bf8 100644 --- a/nexios_contrib/redis/client.py +++ b/nexios_contrib/redis/client.py @@ -1,6 +1,7 @@ """ Redis client wrapper for Nexios integration. """ + from __future__ import annotations import asyncio @@ -12,14 +13,16 @@ # Check for async redis availability at module load time try: import redis.asyncio as async_redis + _REDIS_AVAILABLE = True except ImportError: - async_redis = None + async_redis = None # ty:ignore[invalid-assignment] _REDIS_AVAILABLE = False class RedisOperationError(Exception): """Raised when there's an error performing a Redis operation.""" + pass @@ -90,7 +93,9 @@ async def json_get(self, key: str, path: str = ".") -> Any: except Exception as e: raise RedisOperationError(f"Failed to get JSON from key '{key}': {e}") - async def json_set(self, key: str, path: str, value: Any, nx: bool = False, xx: bool = False) -> bool: + async def json_set( + self, key: str, path: str, value: Any, nx: bool = False, xx: bool = False + ) -> bool: """ Set JSON value in Redis. diff --git a/nexios_contrib/redis/config.py b/nexios_contrib/redis/config.py index e05397f..5bf59f3 100644 --- a/nexios_contrib/redis/config.py +++ b/nexios_contrib/redis/config.py @@ -1,12 +1,13 @@ """ Redis configuration for Nexios Redis integration. """ + from __future__ import annotations import os from typing import Any, Dict, Optional -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator class RedisConfig(BaseModel): @@ -18,61 +19,36 @@ class RedisConfig(BaseModel): """ url: str = Field( - default="redis://localhost:6379", - description="Redis connection URL" - ) - db: int = Field( - default=0, - ge=0, - le=15, - description="Redis database number (0-15)" - ) - password: Optional[str] = Field( - default=None, - description="Redis password" + default="redis://localhost:6379", description="Redis connection URL" ) + db: int = Field(default=0, ge=0, le=15, description="Redis database number (0-15)") + password: Optional[str] = Field(default=None, description="Redis password") decode_responses: bool = Field( - default=True, - description="Whether to decode responses as strings" - ) - encoding: str = Field( - default="utf-8", - description="Encoding for string operations" + default=True, description="Whether to decode responses as strings" ) + encoding: str = Field(default="utf-8", description="Encoding for string operations") encoding_errors: str = Field( - default="strict", - description="Error handling for encoding" + default="strict", description="Error handling for encoding" ) socket_timeout: Optional[float] = Field( - default=None, - description="Socket timeout in seconds" + default=None, description="Socket timeout in seconds" ) socket_connect_timeout: Optional[float] = Field( - default=None, - description="Socket connect timeout in seconds" - ) - socket_keepalive: bool = Field( - default=False, - description="Enable TCP keepalive" + default=None, description="Socket connect timeout in seconds" ) + socket_keepalive: bool = Field(default=False, description="Enable TCP keepalive") socket_keepalive_options: Optional[Dict[int, int]] = Field( - default=None, - description="TCP keepalive options" + default=None, description="TCP keepalive options" ) health_check_interval: int = Field( - default=30, - description="Health check interval in seconds" + default=30, description="Health check interval in seconds" ) max_connections: Optional[int] = Field( - default=None, - description="Maximum connection pool size" - ) - retry_on_timeout: bool = Field( - default=False, - description="Retry on timeout" + default=None, description="Maximum connection pool size" ) + retry_on_timeout: bool = Field(default=False, description="Retry on timeout") - @validator("url") + @field_validator("url") def validate_url(cls, v: str) -> str: """Validate Redis URL format.""" if not v.startswith(("redis://", "rediss://", "unix://")): @@ -108,20 +84,25 @@ def from_env(cls, prefix: str = "REDIS_") -> RedisConfig: if env_value is not None: # Type conversion for specific fields - if field.type_ in (int, float): + if field.type_ in (int, float): #ty: ignore try: - if field.type_ == int: + if isinstance(field,int): env_vars[field_name] = int(env_value) else: env_vars[field_name] = float(env_value) except ValueError: continue # Skip invalid values - elif field.type_ == bool: - env_vars[field_name] = env_value.lower() in ("true", "1", "yes", "on") + elif isinstance(field,bool): + env_vars[field_name] = env_value.lower() in ( + "true", + "1", + "yes", + "on", + ) else: env_vars[field_name] = env_value - return cls(**env_vars) + return cls(**env_vars) # ty:ignore[invalid-argument-type] def to_connection_kwargs(self) -> Dict[str, Any]: """ @@ -130,7 +111,7 @@ def to_connection_kwargs(self) -> Dict[str, Any]: Returns: Dictionary of kwargs to pass to Redis client """ - kwargs = self.dict(exclude={"url"}) + kwargs = self.model_dump(exclude={"url"}) # Handle URL parsing for host, port, etc. if self.url.startswith(("redis://", "rediss://")): @@ -154,7 +135,7 @@ def to_connection_kwargs(self) -> Dict[str, Any]: def __str__(self) -> str: """String representation of Redis config (without sensitive data).""" - safe_dict = self.dict() + safe_dict = self.model_dump() if self.password: safe_dict["password"] = "***" return f"RedisConfig({safe_dict})" diff --git a/nexios_contrib/redis/dependency.py b/nexios_contrib/redis/dependency.py index edb5480..fdae826 100644 --- a/nexios_contrib/redis/dependency.py +++ b/nexios_contrib/redis/dependency.py @@ -6,11 +6,12 @@ """ from __future__ import annotations +from typing import cast -from nexios.dependencies import Depend, Context +from nexios.dependencies import Depend -from .client import RedisClient from . import get_redis +from .client import RedisClient def RedisDepend() -> RedisClient: @@ -42,4 +43,4 @@ async def get_cached_data( ``` """ - return Depend(get_redis) + return cast(RedisClient,Depend(get_redis)) diff --git a/nexios_contrib/redis/utils.py b/nexios_contrib/redis/utils.py index e214535..45ab8e2 100644 --- a/nexios_contrib/redis/utils.py +++ b/nexios_contrib/redis/utils.py @@ -4,9 +4,9 @@ This module provides convenient functions for common Redis operations that can be used directly in Nexios route handlers. """ + from __future__ import annotations -import json from typing import Any, Dict, List, Optional, Union from nexios_contrib.redis import get_redis_client @@ -31,8 +31,8 @@ async def redis_get(key: str) -> Optional[str]: print(f"User data: {value}") ``` """ - redis = get_redis_client() - return await redis.get(key) + redis = get_redis_client() + return await redis.get(key) # ty:ignore[invalid-return-type] async def redis_set( @@ -69,7 +69,7 @@ async def redis_set( ``` """ redis = get_redis_client() - return await redis.set(key, value, ex=ex, px=px, nx=nx, xx=xx) + return await redis.set(key, value, ex=ex, px=px, nx=nx, xx=xx) # ty:ignore[invalid-return-type] async def redis_delete(*keys: str) -> int: @@ -192,7 +192,7 @@ async def redis_incr(key: str, amount: int = 1) -> int: ``` """ redis = get_redis_client() - return await redis.incr(key, amount) + return await redis.incr(key, amount) # ty:ignore[invalid-await] async def redis_decr(key: str, amount: int = 1) -> int: @@ -215,7 +215,7 @@ async def redis_decr(key: str, amount: int = 1) -> int: ``` """ redis = get_redis_client() - return await redis.decr(key, amount) + return await redis.decr(key, amount) # ty:ignore[invalid-await] async def redis_json_get(key: str, path: str = ".") -> Any: @@ -295,7 +295,7 @@ async def redis_hget(key: str, field: str) -> Optional[str]: ``` """ redis = get_redis_client() - return await redis.hget(key, field) + return await redis.hget(key, field) # ty:ignore[invalid-return-type] async def redis_hset(key: str, field: str, value: Union[str, int, float]) -> int: @@ -341,7 +341,7 @@ async def redis_hgetall(key: str) -> Dict[str, str]: ``` """ redis = get_redis_client() - return await redis.hgetall(key) + return await redis.hgetall(key) # ty:ignore[invalid-return-type] async def redis_lpush(key: str, *values: Union[str, int, float]) -> int: @@ -390,7 +390,9 @@ async def redis_rpush(key: str, *values: Union[str, int, float]) -> int: return await redis.rpush(key, *values) -async def redis_lpop(key: str, count: Optional[int] = None) -> Union[Optional[str], List[str]]: +async def redis_lpop( + key: str, count: Optional[int] = None +) -> Union[Optional[str], List[str]]: """ Remove and return the first element(s) of a list. @@ -411,10 +413,12 @@ async def redis_lpop(key: str, count: Optional[int] = None) -> Union[Optional[st ``` """ redis = get_redis_client() - return await redis.lpop(key, count) + return await redis.lpop(key, count) # ty:ignore[invalid-return-type] -async def redis_rpop(key: str, count: Optional[int] = None) -> Union[Optional[str], List[str]]: +async def redis_rpop( + key: str, count: Optional[int] = None +) -> Union[Optional[str], List[str]]: """ Remove and return the last element(s) of a list. @@ -435,7 +439,7 @@ async def redis_rpop(key: str, count: Optional[int] = None) -> Union[Optional[st ``` """ redis = get_redis_client() - return await redis.rpop(key, count) + return await redis.rpop(key, count) # ty:ignore[invalid-return-type] async def redis_llen(key: str) -> int: @@ -502,7 +506,7 @@ async def redis_smembers(key: str) -> List[str]: ``` """ redis = get_redis_client() - return await redis.smembers(key) + return await redis.smembers(key) # ty:ignore[invalid-return-type] async def redis_srem(key: str, *members: Union[str, int, float]) -> int: @@ -569,7 +573,7 @@ async def redis_keys(pattern: str = "*") -> List[str]: ``` """ redis = get_redis_client() - return await redis.keys(pattern) + return await redis.keys(pattern) # ty:ignore[invalid-return-type] async def redis_flushdb(asynchronous: bool = False) -> bool: diff --git a/nexios_contrib/request_id/__init__.py b/nexios_contrib/request_id/__init__.py index 8f16e35..a22cbc5 100644 --- a/nexios_contrib/request_id/__init__.py +++ b/nexios_contrib/request_id/__init__.py @@ -4,18 +4,19 @@ This module provides automatic request ID generation and management for better request tracing and debugging in Nexios applications. """ + from __future__ import annotations from .helper import ( generate_request_id, + get_or_generate_request_id, get_request_id_from_header, + get_request_id_from_request, set_request_id_header, - get_or_generate_request_id, - validate_request_id, store_request_id_in_request, - get_request_id_from_request, + validate_request_id, ) -from .middleware import RequestIdMiddleware, RequestId +from .middleware import RequestId, RequestIdMiddleware __all__ = [ "RequestIdMiddleware", diff --git a/nexios_contrib/request_id/dependency.py b/nexios_contrib/request_id/dependency.py index 4efef8a..2fc535e 100644 --- a/nexios_contrib/request_id/dependency.py +++ b/nexios_contrib/request_id/dependency.py @@ -1,8 +1,10 @@ -from nexios.dependencies import Depend,Context +from nexios.dependencies import Context, Depend + from nexios_contrib.request_id import get_request_id_from_request + def RequestIdDepend(attribute_name: str = "request_id"): - def _wrap(ctx = Context()): - return get_request_id_from_request(ctx.request, attribute_name) + def _wrap(ctx=Context()): + return get_request_id_from_request(ctx.request, attribute_name) # ty:ignore[invalid-argument-type] + return Depend(_wrap) - \ No newline at end of file diff --git a/nexios_contrib/request_id/helper.py b/nexios_contrib/request_id/helper.py index da6c0a4..766f3c3 100644 --- a/nexios_contrib/request_id/helper.py +++ b/nexios_contrib/request_id/helper.py @@ -4,6 +4,7 @@ This module provides utilities for generating, managing, and working with request IDs in the Nexios framework. """ + from __future__ import annotations import uuid @@ -23,8 +24,7 @@ def generate_request_id() -> str: def get_request_id_from_header( - request: Request, - header_name: str = "X-Request-ID" + request: Request, header_name: str = "X-Request-ID" ) -> Optional[str]: """ Extract request ID from request headers. @@ -40,9 +40,7 @@ def get_request_id_from_header( def set_request_id_header( - response: Response, - request_id: str, - header_name: str = "X-Request-ID" + response: Response, request_id: str, header_name: str = "X-Request-ID" ) -> None: """ Set the request ID in the response headers. @@ -52,12 +50,11 @@ def set_request_id_header( request_id: The request ID to set. header_name: The header name to use (default: "X-Request-ID"). """ - response.set_header(header_name, request_id,overide=True) + response.set_header(header_name, request_id, overide=True) def get_or_generate_request_id( - request: Request, - header_name: str = "X-Request-ID" + request: Request, header_name: str = "X-Request-ID" ) -> str: """ Get request ID from request headers or generate a new one. @@ -93,9 +90,7 @@ def validate_request_id(request_id: str) -> bool: def store_request_id_in_request( - request: Request, - request_id: str, - attribute_name: str = "request_id" + request: Request, request_id: str, attribute_name: str = "request_id" ) -> None: """ Store request ID in the request object for later access. @@ -109,8 +104,7 @@ def store_request_id_in_request( def get_request_id_from_request( - request: Request, - attribute_name: str = "request_id" + request: Request, attribute_name: str = "request_id" ) -> Optional[str]: """ Retrieve request ID from the request object. @@ -122,4 +116,4 @@ def get_request_id_from_request( Returns: Optional[str]: The stored request ID if found, None otherwise. """ - return getattr(request.state,attribute_name,None) + return getattr(request.state, attribute_name, None) diff --git a/nexios_contrib/request_id/middleware.py b/nexios_contrib/request_id/middleware.py index cf37820..3b7488e 100644 --- a/nexios_contrib/request_id/middleware.py +++ b/nexios_contrib/request_id/middleware.py @@ -4,9 +4,10 @@ This middleware automatically generates or extracts request IDs from incoming requests and includes them in response headers for better request tracing and debugging. """ + from __future__ import annotations -from typing import Any, Optional +from typing import Any from nexios.http import Request, Response from nexios.middleware.base import BaseMiddleware @@ -17,7 +18,6 @@ get_request_id_from_header, set_request_id_header, store_request_id_in_request, - get_request_id_from_request, ) @@ -96,7 +96,9 @@ async def process_request( self.request_id = request_id # Store request ID in request object if enabled if self.store_in_request: - store_request_id_in_request(request, request_id, self.request_attribute_name) + store_request_id_in_request( + request, request_id, self.request_attribute_name + ) # Store request ID in response headers if enabled if self.include_in_response: diff --git a/nexios_contrib/scalar/__init__.py b/nexios_contrib/scalar/__init__.py index 2aafe0b..9c788b6 100644 --- a/nexios_contrib/scalar/__init__.py +++ b/nexios_contrib/scalar/__init__.py @@ -6,7 +6,19 @@ # Re-export scalar_doc classes for convenience try: - from scalar_doc import ScalarConfiguration, ScalarTheme, ScalarHeader, ScalarColorSchema - __all__ = ["Scalar", "ScalarConfiguration", "ScalarTheme", "ScalarHeader", "ScalarColorSchema"] + from scalar_doc import ( + ScalarColorSchema, + ScalarConfiguration, + ScalarHeader, + ScalarTheme, + ) + + __all__ = [ + "Scalar", + "ScalarConfiguration", + "ScalarTheme", + "ScalarHeader", + "ScalarColorSchema", + ] except ImportError: __all__ = ["Scalar"] diff --git a/nexios_contrib/scalar/plugin.py b/nexios_contrib/scalar/plugin.py index e1f1358..4ea303e 100644 --- a/nexios_contrib/scalar/plugin.py +++ b/nexios_contrib/scalar/plugin.py @@ -2,13 +2,19 @@ Scalar DOC plugin for Nexios - Beautiful OpenAPI documentation using Scalar. """ -from typing import Optional, Dict, Any, Union +from typing import Any, Dict, Optional, Union + from nexios.application import NexiosApp from nexios.http import Request, Response from nexios.routing import Route try: - from scalar_doc import ScalarDoc, ScalarConfiguration, ScalarTheme, ScalarHeader, ScalarColorSchema + from scalar_doc import ( + ScalarConfiguration, + ScalarDoc, + ScalarHeader, + ScalarTheme, + ) except ImportError: raise ImportError( "scalar_doc is required for the Scalar plugin. " @@ -19,10 +25,10 @@ class Scalar: """ Scalar DOC plugin for Nexios. - + Provides beautiful, interactive OpenAPI documentation using Scalar. """ - + def __init__( self, app: NexiosApp, @@ -37,7 +43,7 @@ def __init__( ): """ Initialize Scalar documentation. - + Args: app: NexiosApp instance path: URL path for the documentation @@ -58,19 +64,17 @@ def __init__( self.header = header self.custom_spec = custom_spec self.spec_mode = spec_mode - + self._setup() - + def _setup(self): """Register the Scalar documentation route.""" - self.app.add_route( - Route(self.path, self.handle_request, methods=["GET"]) - ) - + self.app.add_route(Route(self.path, self.handle_request, methods=["GET"])) + async def handle_request(self, req: Request, res: Response): """Handle Scalar documentation requests.""" return res.html(self._generate_html()) - + def _generate_html(self) -> str: """Generate the Scalar HTML documentation using scalar_doc.""" # Determine the OpenAPI spec source @@ -87,45 +91,40 @@ def _generate_html(self) -> str: else: # Use the app's OpenAPI URL docs = ScalarDoc.from_spec(spec=self.openapi_url, mode="url") - + # Set title docs.set_title(self.title) - + # Set configuration if provided if self.configuration: docs.set_configuration(self.configuration) - + # Set theme if provided if self.theme: docs.set_theme(self.theme) - + # Set header if provided if self.header: docs.set_header(self.header) - + # Generate HTML using scalar_doc return docs.to_html() - + @classmethod def from_spec( cls, app: NexiosApp, spec: Union[str, Dict[str, Any]], mode: str = "url", - **kwargs + **kwargs, ): """ Create Scalar instance from a custom OpenAPI spec. - + Args: app: NexiosApp instance spec: OpenAPI spec (URL, JSON string, or dict) mode: Mode for the spec ("url", "json", or "dict") **kwargs: Additional arguments for Scalar constructor """ - return cls( - app=app, - custom_spec=spec, - spec_mode=mode, - **kwargs - ) + return cls(app=app, custom_spec=spec, spec_mode=mode, **kwargs) diff --git a/nexios_contrib/scheduler/__init__.py b/nexios_contrib/scheduler/__init__.py index 45d47e6..865aeee 100644 --- a/nexios_contrib/scheduler/__init__.py +++ b/nexios_contrib/scheduler/__init__.py @@ -4,6 +4,7 @@ Provides interval-based, cron-based, and one-time job scheduling integrated with the Nexios application lifecycle and dependency injection. """ + from __future__ import annotations from typing import Optional @@ -73,9 +74,9 @@ async def my_task(): """ if not hasattr(app, "scheduler"): scheduler = SchedulerManager(app, config=config) - app.scheduler = scheduler + app.scheduler = scheduler # ty:ignore[invalid-assignment] app.on_startup(scheduler.start) - return app.scheduler + return app.scheduler # ty:ignore[unresolved-attribute] def get_scheduler(app: NexiosApp) -> SchedulerManager: diff --git a/nexios_contrib/scheduler/config.py b/nexios_contrib/scheduler/config.py index c11ee3b..6a5a5f7 100644 --- a/nexios_contrib/scheduler/config.py +++ b/nexios_contrib/scheduler/config.py @@ -3,6 +3,7 @@ This module provides configuration options and enums for the scheduler system. """ + from __future__ import annotations import logging @@ -57,12 +58,7 @@ def __post_init__(self) -> None: def as_seconds(self) -> float: """Return the total interval in seconds.""" - return ( - self.days * 86400 - + self.hours * 3600 - + self.minutes * 60 - + self.seconds - ) + return self.days * 86400 + self.hours * 3600 + self.minutes * 60 + self.seconds @dataclass @@ -163,8 +159,6 @@ def get_next_run(self, from_timestamp: float) -> float: Uses a simple minute-resolution iteration starting from ``from_timestamp``. """ - import calendar - import time as time_module from datetime import datetime, timedelta, timezone dt = datetime.fromtimestamp(from_timestamp, tz=timezone.utc) diff --git a/nexios_contrib/scheduler/dependency.py b/nexios_contrib/scheduler/dependency.py index 97454c8..0040c28 100644 --- a/nexios_contrib/scheduler/dependency.py +++ b/nexios_contrib/scheduler/dependency.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast from nexios.dependencies import Context, Depend from nexios.http import Request @@ -25,7 +25,7 @@ class SchedulerDepend: def __init__(self, request: Request) -> None: self.request = request - self.scheduler: SchedulerManager = request.base_app.scheduler + self.scheduler: SchedulerManager = request.base_app.scheduler # ty:ignore[unresolved-attribute] def add_job( self, @@ -76,8 +76,9 @@ def resume_job(self, job_id: str) -> bool: def _get_scheduler_depend(ctx: Context = Context()) -> SchedulerDepend: """Factory used by the ``SchedulerDepends`` callable.""" - return SchedulerDepend(ctx.request) + + return SchedulerDepend(ctx.request) # ty:ignore[invalid-argument-type] def SchedulerDepends() -> SchedulerDepend: - return Depend(_get_scheduler_depend) + return cast(typ=SchedulerDepend, val=Depend(_get_scheduler_depend)) \ No newline at end of file diff --git a/nexios_contrib/scheduler/manager.py b/nexios_contrib/scheduler/manager.py index 15e672b..c845543 100644 --- a/nexios_contrib/scheduler/manager.py +++ b/nexios_contrib/scheduler/manager.py @@ -4,6 +4,7 @@ This module provides the SchedulerManager class which is responsible for managing scheduled jobs and their execution lifecycle in a Nexios application. """ + from __future__ import annotations import asyncio @@ -131,9 +132,10 @@ def add_job( name=name, args=args or (), kwargs=kwargs or {}, - max_instances=max_instances or self.config.job_defaults.get("max_instances", 3), + max_instances=max_instances + or self.config.job_defaults.get("max_instances", 3), misfire_grace_time=misfire_grace_time - or self.config.job_defaults.get("misfire_grace_time", 30), + or self.config.job_defaults.get("misfire_grace_time", 30), coalesce=coalesce or self.config.job_defaults.get("coalesce", True), id=id, ) @@ -205,9 +207,7 @@ def resume_job(self, job_id: str) -> bool: @property def _active_count(self) -> int: - return sum( - 1 for j in self._jobs.values() if j.status == JobStatus.ACTIVE - ) + return sum(1 for j in self._jobs.values() if j.status == JobStatus.ACTIVE) async def _ticker_loop(self) -> None: """Background loop that checks every second for due jobs.""" diff --git a/nexios_contrib/scheduler/models.py b/nexios_contrib/scheduler/models.py index 7bec6ad..259e7e2 100644 --- a/nexios_contrib/scheduler/models.py +++ b/nexios_contrib/scheduler/models.py @@ -4,11 +4,11 @@ This module defines the Job class that represents a scheduled task along with its trigger configuration and execution state. """ + from __future__ import annotations import logging import time -from dataclasses import dataclass, field from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar from uuid import uuid4 @@ -45,7 +45,7 @@ def __init__( id: Optional[str] = None, ) -> None: self.id = id or str(uuid4()) - self.name = name or func.__name__ + self.name = name or func.__name__ # ty:ignore[unresolved-attribute] self.func = func self.trigger = trigger self.args = args or () @@ -128,9 +128,7 @@ async def run(self) -> Any: except Exception as exc: self._last_error = str(exc) self._status = JobStatus.FAILED - self._logger.exception( - "Job %s (%s) failed: %s", self.id, self.name, exc - ) + self._logger.exception("Job %s (%s) failed: %s", self.id, self.name, exc) raise finally: self._current_instances -= 1 diff --git a/nexios_contrib/slashes/__init__.py b/nexios_contrib/slashes/__init__.py index 9154ef2..4c8e839 100644 --- a/nexios_contrib/slashes/__init__.py +++ b/nexios_contrib/slashes/__init__.py @@ -2,7 +2,6 @@ Nexios URL Normalization contrib package. """ -from .middleware import SlashesMiddleware, SlashAction from .helpers import ( add_trailing_slash, clean_url_path, @@ -14,6 +13,7 @@ remove_trailing_slash, should_skip_path_processing, ) +from .middleware import SlashAction, SlashesMiddleware __all__ = [ "SlashesMiddleware", @@ -29,10 +29,11 @@ "should_skip_path_processing", ] + def Slashes( slash_action: SlashAction = SlashAction.REDIRECT_REMOVE, auto_remove_double_slashes: bool = True, - redirect_status_code: int = 301 + redirect_status_code: int = 301, ) -> SlashesMiddleware: """ Create a SlashesMiddleware instance with the given configuration. @@ -48,5 +49,5 @@ def Slashes( return SlashesMiddleware( slash_action=slash_action, auto_remove_double_slashes=auto_remove_double_slashes, - redirect_status_code=redirect_status_code + redirect_status_code=redirect_status_code, ) diff --git a/nexios_contrib/slashes/helpers.py b/nexios_contrib/slashes/helpers.py index 5cfa9b3..07e651f 100644 --- a/nexios_contrib/slashes/helpers.py +++ b/nexios_contrib/slashes/helpers.py @@ -1,15 +1,17 @@ """ Helper functions for URL normalization and path handling. """ + from __future__ import annotations from enum import Enum -from typing import List, Optional +from typing import List from urllib.parse import urlparse, urlunparse class SlashAction(Enum): """Actions for handling trailing slashes.""" + ADD = "add" REMOVE = "remove" REDIRECT_ADD = "redirect_add" @@ -77,7 +79,7 @@ def build_normalized_url( base_url: str, path: str, preserve_query: bool = True, - preserve_fragment: bool = True + preserve_fragment: bool = True, ) -> str: """ Build a normalized URL from components. @@ -100,7 +102,7 @@ def build_normalized_url( path, parsed.params, parsed.query if preserve_query else "", - parsed.fragment if preserve_fragment else "" + parsed.fragment if preserve_fragment else "", ] return urlunparse(components) @@ -122,14 +124,16 @@ def clean_url_path(url: str) -> str: normalized_path = normalize_path(parsed.path) # Rebuild URL with normalized path - return urlunparse(( - parsed.scheme, - parsed.netloc, - normalized_path, - parsed.params, - parsed.query, - parsed.fragment - )) + return urlunparse( + ( + parsed.scheme, + parsed.netloc, + normalized_path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) def get_path_segments(path: str) -> List[str]: @@ -204,14 +208,17 @@ def normalize_url(url: str, preserve_case: bool = True) -> str: normalized_path = normalize_path(parsed.path) # Build normalized URL - return urlunparse(( - parsed.scheme, - parsed.netloc, - normalized_path, - parsed.params, - parsed.query, - parsed.fragment - )) - -def is_double_slash(path:str): - return "//" in path \ No newline at end of file + return urlunparse( + ( + parsed.scheme, + parsed.netloc, + normalized_path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + + +def is_double_slash(path: str): + return "//" in path diff --git a/nexios_contrib/slashes/middleware.py b/nexios_contrib/slashes/middleware.py index c11bec7..8ad47d9 100644 --- a/nexios_contrib/slashes/middleware.py +++ b/nexios_contrib/slashes/middleware.py @@ -4,24 +4,26 @@ This middleware cleans up URLs by handling trailing slashes, double slashes, and other common URL normalization issues. """ + from __future__ import annotations from enum import Enum from typing import Any -from urllib.parse import urlparse, urlunparse +from urllib.parse import urlunparse from nexios.http import Request, Response from nexios.middleware.base import BaseMiddleware -from .helpers import is_double_slash + class SlashAction(Enum): """Actions for handling trailing slashes.""" - ADD = "add" # Add trailing slash if missing - REMOVE = "remove" # Remove trailing slash if present - REDIRECT_ADD = "redirect_add" # Redirect to add trailing slash + + ADD = "add" # Add trailing slash if missing + REMOVE = "remove" # Remove trailing slash if present + REDIRECT_ADD = "redirect_add" # Redirect to add trailing slash REDIRECT_REMOVE = "redirect_remove" # Redirect to remove trailing slash - IGNORE = "ignore" # Leave as-is + IGNORE = "ignore" # Leave as-is class SlashesMiddleware(BaseMiddleware): @@ -51,7 +53,6 @@ def __init__( def _normalize_path(self, path: str) -> str: """Normalize a path by removing double slashes.""" - # Remove double slashes while "//" in path: @@ -108,7 +109,6 @@ async def process_request( # Update the request path request.scope["path"] = normalized_path - elif self.slash_action == SlashAction.ADD: # Add trailing slash if not self._has_trailing_slash(normalized_path): @@ -121,7 +121,10 @@ async def process_request( new_path = self._remove_trailing_slash(normalized_path) request.scope["path"] = new_path - elif self.slash_action in [SlashAction.REDIRECT_ADD, SlashAction.REDIRECT_REMOVE]: + elif self.slash_action in [ + SlashAction.REDIRECT_ADD, + SlashAction.REDIRECT_REMOVE, + ]: # Handle redirects should_redirect = False redirect_path = normalized_path @@ -135,20 +138,23 @@ async def process_request( if self._has_trailing_slash(normalized_path): redirect_path = self._remove_trailing_slash(normalized_path) should_redirect = True - + if should_redirect: # Build the redirect URL - redirect_url = urlunparse(( - request.url.scheme, - request.url.netloc, - redirect_path, - request.path_params, - request.url.query, - request.url.fragment - )) + redirect_url = urlunparse( + ( + request.url.scheme, + request.url.netloc, + redirect_path, + request.path_params, + request.url.query, + request.url.fragment, + ) + ) # Return redirect response - return response.redirect(redirect_url, status_code=self.redirect_status_code) - + return response.redirect( + redirect_url, status_code=self.redirect_status_code + ) await call_next() diff --git a/nexios_contrib/tasks/__init__.py b/nexios_contrib/tasks/__init__.py index 2203466..b5c7ef2 100644 --- a/nexios_contrib/tasks/__init__.py +++ b/nexios_contrib/tasks/__init__.py @@ -4,14 +4,15 @@ This module provides a robust and efficient way to manage background tasks in Nexios applications. It includes features like task lifecycle management, error handling, and result callbacks. """ + from __future__ import annotations +import warnings from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast from nexios import NexiosApp from nexios.dependencies import Depend, current_context from nexios.http import Request -import warnings from .config import TaskConfig, TaskStatus from .dependency import TaskDepend, TaskDependency, get_task_dependency @@ -21,55 +22,51 @@ # Re-export public API __all__ = [ # Main classes - 'Task', - 'TaskManager', - 'TaskConfig', - 'TaskStatus', - 'TaskResult', - 'TaskError', - + "Task", + "TaskManager", + "TaskConfig", + "TaskStatus", + "TaskResult", + "TaskError", # Dependency injection - 'TaskDepend', - 'TaskDependency', - 'get_task_dependency', - + "TaskDepend", + "TaskDependency", + "get_task_dependency", # Utility functions - 'setup_tasks', - 'get_task_manager', - 'create_task', + "setup_tasks", + "get_task_manager", + "create_task", ] # Type variables for generic type hints -T = TypeVar('T') +T = TypeVar("T") TaskCallback = Callable[..., Awaitable[Any]] TaskResultCallback = Callable[[str, Any, Optional[Exception]], Awaitable[None]] -def setup_tasks( - app: NexiosApp, - config: Optional[TaskConfig] = None -) -> TaskManager: + +def setup_tasks(app: NexiosApp, config: Optional[TaskConfig] = None) -> TaskManager: """Set up the task manager for a Nexios application. - + This function initializes the task manager and registers it with the Nexios app. It should be called during application startup. - + Args: app: The Nexios application instance. config: Optional configuration for the task manager. - + Returns: The initialized TaskManager instance. - + Example: ```python from nexios import NexiosApp from nexios_contrib.tasks import setup_tasks, TaskConfig - + app = NexiosApp() - + # Initialize with default configuration task_manager = setup_tasks(app) - + # Or with custom configuration config = TaskConfig( max_concurrent_tasks=50, @@ -79,32 +76,33 @@ def setup_tasks( task_manager = setup_tasks(app, config=config) ``` """ - if not hasattr(app, 'task_manager'): + if not hasattr(app, "task_manager"): task_manager = TaskManager(app, config=config) - app.task_manager = task_manager + app.task_manager = task_manager # ty:ignore[invalid-assignment] app.on_startup(task_manager.start) - return app.task_manager + return app.task_manager # ty:ignore[unresolved-attribute] + def get_task_manager(request: Request) -> TaskManager: """Get the task manager from a request. - + This is a convenience function to get the task manager instance from a request object. - + Args: request: The current request object. - + Returns: The TaskManager instance. - + Raises: AttributeError: If the task manager is not initialized. - + Example: ```python from nexios import Request from nexios_contrib.tasks import get_task_manager - + @app.get("/tasks/{task_id}") async def get_task_status(request: Request): task_manager = get_task_manager(request) @@ -113,31 +111,32 @@ async def get_task_status(request: Request): return {"status": task.status if task else "not_found"} ``` """ - task_manager = getattr(request.base_app, 'task_manager', None) + task_manager = getattr(request.base_app, "task_manager", None) if task_manager is None: raise AttributeError( "Task manager not initialized. Call setup_tasks(app) during application startup." ) return task_manager + def create_task( request_or_func: Union[Request, TaskCallback], func_or_arg: Optional[Union[TaskCallback, Any]] = None, name: Optional[str] = None, timeout: Optional[float] = None, *args: Any, - **kwargs: Any + **kwargs: Any, ) -> Task: """Create and schedule a new background task. - + This function creates a new background task. It can be called in two ways: - + 1. New (Recommended): create_task(func, *args, **kwargs) The request/task manager is automatically resolved from the current context. - + 2. Deprecated: create_task(request, func, *args, **kwargs) Explicitly passing the request object. - + Args: request_or_func: Context request or the task function. func_or_arg: Task function (if request passed first) or first task argument. @@ -145,7 +144,7 @@ def create_task( name: Optional name for the task. timeout: Optional timeout in seconds. **kwargs: Keyword arguments for the task. - + Returns: The created Task instance. """ @@ -163,23 +162,25 @@ def create_task( stacklevel=2, ) request = cast(Request, request_or_func) - + if func_or_arg is None or not callable(func_or_arg): - raise ValueError("When passing request explicitly, the second argument must be a callable task function.") - + raise ValueError( + "When passing request explicitly, the second argument must be a callable task function." + ) + func = cast(TaskCallback, func_or_arg) task_args = list(args) else: # New usage: create_task(func, arg1, arg2...) func = cast(TaskCallback, request_or_func) - + # In this mode, func_or_arg is actually the first argument for the task (if present) if func_or_arg is not None: task_args = [func_or_arg] + list(args) else: task_args = list(args) - + # Resolve request from context try: ctx = current_context.get() @@ -197,9 +198,5 @@ def create_task( task_manager = get_task_manager(request) return task_manager.create_task( - func=func, - *task_args, - name=name, - timeout=timeout, - **kwargs - ) + func=func, *task_args, name=name, timeout=timeout, **kwargs # ty:ignore[parameter-already-assigned] + ) # ty:ignore[invalid-return-type] diff --git a/nexios_contrib/tasks/config.py b/nexios_contrib/tasks/config.py index df0d8b7..7ed8961 100644 --- a/nexios_contrib/tasks/config.py +++ b/nexios_contrib/tasks/config.py @@ -3,23 +3,27 @@ This module provides configuration options for the task management system. """ -from typing import Optional, Dict, Any, Union, List, Callable, Awaitable -from dataclasses import dataclass, field -from enum import Enum + import logging +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional + class TaskStatus(str, Enum): """Status of a background task.""" + PENDING = "PENDING" RUNNING = "RUNNING" COMPLETED = "COMPLETED" FAILED = "FAILED" CANCELLED = "CANCELLED" + @dataclass class TaskConfig: """Configuration for the task manager. - + Attributes: max_concurrent_tasks: Maximum number of tasks that can run concurrently. default_timeout: Default timeout in seconds for tasks. @@ -27,12 +31,13 @@ class TaskConfig: enable_task_history: Whether to keep a history of completed tasks. log_level: Logging level for task-related logs. """ + max_concurrent_tasks: int = 100 default_timeout: Optional[float] = None task_result_ttl: int = 3600 # 1 hour enable_task_history: bool = True log_level: int = logging.INFO - + def to_dict(self) -> Dict[str, Any]: """Convert the configuration to a dictionary.""" return { @@ -43,5 +48,6 @@ def to_dict(self) -> Dict[str, Any]: "log_level": self.log_level, } + # Default configuration DEFAULT_CONFIG = TaskConfig() diff --git a/nexios_contrib/tasks/dependency.py b/nexios_contrib/tasks/dependency.py index 3cda1d9..28c5fd9 100644 --- a/nexios_contrib/tasks/dependency.py +++ b/nexios_contrib/tasks/dependency.py @@ -3,105 +3,103 @@ This module provides dependency injection utilities for working with background tasks. """ + from __future__ import annotations -from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, Generic, Type -from uuid import UUID, uuid4 -import asyncio import logging +from typing import ( + Any, + Awaitable, + Callable, + Generic, + Optional, + TypeVar, cast, +) -from nexios.dependencies import Depend,Context +from nexios.dependencies import Context, Depend from nexios.http import Request -from .config import TaskStatus, TaskConfig -from .models import Task, TaskResult, TaskError +from .models import Task -T = TypeVar('T') +T = TypeVar("T") TaskCallback = Callable[..., Awaitable[Any]] TaskResultCallback = Callable[[str, Any, Optional[Exception]], Awaitable[None]] + class TaskDepend(Generic[T]): """Dependency for working with background tasks. - + This class provides methods for creating and managing background tasks. It's designed to be used as a dependency in route handlers. """ - + def __init__(self, request: Request): """Initialize the task dependency. - + Args: request: The current request object. """ self.request = request - self.task_manager = request.base_app.task_manager + self.task_manager = request.base_app.task_manager # ty:ignore[unresolved-attribute] self.logger = logging.getLogger("nexios.tasks") - + async def create( self, func: TaskCallback, *args: Any, name: Optional[str] = None, timeout: Optional[float] = None, - **kwargs: Any + **kwargs: Any, ) -> Task: """Create and schedule a new background task. - + Args: func: The coroutine function to execute. *args: Positional arguments to pass to the function. name: Optional name for the task. timeout: Optional timeout in seconds. **kwargs: Keyword arguments to pass to the function. - + Returns: The created task instance. """ return await self.task_manager.create_task( - func=func, - *args, - name=name, - timeout=timeout, - **kwargs + func=func, *args, name=name, timeout=timeout, **kwargs ) - + async def get_task(self, task_id: str) -> Optional[Task]: """Get a task by its ID. - + Args: task_id: The ID of the task to retrieve. - + Returns: The task instance, or None if not found. """ return self.task_manager.get_task(task_id) - - async def wait_for_task( - self, - task_id: str, - timeout: Optional[float] = None - ) -> Any: + + async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Any: """Wait for a task to complete and return its result. - + Args: task_id: The ID of the task to wait for. timeout: Optional timeout in seconds. - + Returns: The result of the task. - + Raises: asyncio.TimeoutError: If the task times out. Exception: If the task raises an exception. """ return await self.task_manager.wait_for_task(task_id, timeout=timeout) - + async def cancel_task(self, task_id: str) -> bool: """Cancel a running task. - + Args: task_id: The ID of the task to cancel. - + Returns: True if the task was cancelled, False otherwise. """ @@ -109,15 +107,16 @@ async def cancel_task(self, task_id: str) -> bool: def get_task_dependency( - ctx = Context(), + ctx=Context(), ) -> TaskDepend: """Dependency function to get a TaskDepend instance. - + This is the recommended way to get a TaskDepend instance in route handlers. """ + if not isinstance(ctx.request, Request): + raise TypeError("Task dependency requires a Request object") return TaskDepend(ctx.request) def TaskDependency() -> TaskDepend: - return Depend(get_task_dependency) - + return cast(TaskDepend,Depend(get_task_dependency)) diff --git a/nexios_contrib/tasks/manager.py b/nexios_contrib/tasks/manager.py index 4886976..d8b5edf 100644 --- a/nexios_contrib/tasks/manager.py +++ b/nexios_contrib/tasks/manager.py @@ -4,32 +4,36 @@ This module provides the TaskManager class which is responsible for managing background tasks in a Nexios application. """ + from __future__ import annotations import asyncio import logging +import time from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union from nexios import NexiosApp from .config import TaskConfig, TaskStatus -from .models import Task, TaskResult, TaskError -import time +from .models import Task, TaskResult -T = TypeVar('T') +T = TypeVar("T") TaskCallback = Callable[..., Awaitable[Any]] TaskResultCallback = Callable[[str, Any, Optional[Exception]], Awaitable[None]] + class TaskManager: """Manages background tasks for Nexios applications. - + This class provides methods to create, monitor, and manage background tasks. It's designed to be used as a singleton per application instance. """ - - def __init__(self, app: Optional[NexiosApp] = None, config: Optional[TaskConfig] = None) -> None: + + def __init__( + self, app: Optional[NexiosApp] = None, config: Optional[TaskConfig] = None + ) -> None: """Initialize the task manager. - + Args: app: The Nexios application instance. config: Configuration for the task manager. @@ -40,97 +44,98 @@ def __init__(self, app: Optional[NexiosApp] = None, config: Optional[TaskConfig] self._shutdown = False self._task_callbacks: Dict[str, List[TaskResultCallback]] = {} self.logger = logging.getLogger("nexios.tasks") - + # Configure logging logging.basicConfig(level=self.config.log_level) - + async def start(self) -> None: """Initialize the task manager. - + This method should be called during application startup. """ if self.app is not None: self.app.on_shutdown(self.shutdown) self.logger.info("Task manager started") - + async def shutdown(self) -> None: """Shutdown the task manager and cancel all running tasks. - + This method is automatically called during application shutdown. """ self._shutdown = True self.logger.info("Shutting down task manager...") - + # Cancel all running tasks - tasks_to_cancel = [task for task in self.tasks.values() - if task._task and not task.is_done] - + tasks_to_cancel = [ + task for task in self.tasks.values() if task._task and not task.is_done + ] + if tasks_to_cancel: self.logger.info("Cancelling %d running tasks...", len(tasks_to_cancel)) for task in tasks_to_cancel: if task._task: task._task.cancel() - + # Wait for tasks to complete cancellation await asyncio.gather( *(task._task for task in tasks_to_cancel if task._task), - return_exceptions=True + return_exceptions=True, ) - + self.logger.info("Task manager shutdown complete") - + async def create_task( self, func: TaskCallback, *args: Any, name: Optional[str] = None, timeout: Optional[float] = None, - **kwargs: Any + **kwargs: Any, ) -> Task: """Create and schedule a new background task. - + Args: func: The coroutine function to execute. *args: Positional arguments to pass to the function. name: Optional name for the task. timeout: Optional timeout in seconds. **kwargs: Keyword arguments to pass to the function. - + Returns: The created task instance. - + Raises: RuntimeError: If the task manager is shutting down. asyncio.TimeoutError: If the task times out. """ if self._shutdown: raise RuntimeError("Cannot create new tasks during shutdown") - + # Create the task - task = Task(func, *args, name=name or func.__name__, **kwargs) + task = Task(func, *args, name=name or func.__name__, **kwargs) # ty:ignore[unresolved-attribute] self.tasks[task.id] = task - + # Create the asyncio task task._task = asyncio.create_task(self._run_task(task, timeout)) - + self.logger.debug("Created task %s (ID: %s)", task.name, task.id) return task - + async def _run_task(self, task: Task, timeout: Optional[float] = None) -> None: """Internal method to run a task and handle its completion.""" try: # Use the configured timeout if none provided timeout = timeout or self.config.default_timeout - + if timeout is not None: # Run with timeout if specified await asyncio.wait_for(task.run(), timeout=timeout) else: # Run without timeout await task.run() - + self.logger.debug("Task %s completed successfully", task.id) - + except asyncio.CancelledError: # Task was cancelled self.logger.debug("Task %s was cancelled", task.id) @@ -139,13 +144,13 @@ async def _run_task(self, task: Task, timeout: Optional[float] = None) -> None: task_id=task.id, result=None, status=TaskStatus.CANCELLED, - error=asyncio.CancelledError("Task was cancelled"), - completed_at=time.time() + error=asyncio.CancelledError("Task was cancelled"), # ty:ignore[invalid-argument-type] + completed_at=time.time(), ) if task.id in self.tasks: del self.tasks[task.id] raise - + except asyncio.TimeoutError: # Task timed out error_msg = f"Task {task.id} timed out after {timeout} seconds" @@ -156,9 +161,9 @@ async def _run_task(self, task: Task, timeout: Optional[float] = None) -> None: result=None, status=TaskStatus.FAILED, error=TimeoutError(error_msg), - completed_at=time.time() + completed_at=time.time(), ) - + except Exception as e: # Task failed with an exception self.logger.exception("Task %s failed with error: %s", task.id, str(e)) @@ -168,61 +173,61 @@ async def _run_task(self, task: Task, timeout: Optional[float] = None) -> None: result=None, status=TaskStatus.FAILED, error=e, - completed_at=time.time() + completed_at=time.time(), ) - + finally: # Ensure the task is marked as done task._completed_at = time.time() - if hasattr(task, '_done_event'): + if hasattr(task, "_done_event"): task._done_event.set() - + # Execute any registered callbacks await self._execute_callbacks(task) - + # Clean up if needed if not self.config.enable_task_history and task.id in self.tasks: del self.tasks[task.id] - + async def _execute_callbacks(self, task: Task) -> None: """Execute all registered callbacks for a task.""" if task.id not in self._task_callbacks: return - + for callback in self._task_callbacks[task.id]: try: - await callback(task.id, task._result, task._result.error if task._result else None) + await callback( + task.id, task._result, task._result.error if task._result else None + ) except Exception as e: - self.logger.exception("Error in task callback for %s: %s", task.id, str(e)) - + self.logger.exception( + "Error in task callback for %s: %s", task.id, str(e) + ) + # Clean up callbacks del self._task_callbacks[task.id] - + def get_task(self, task_id: str) -> Optional[Task]: """Get a task by its ID. - + Args: task_id: The ID of the task to retrieve. - + Returns: The task instance, or None if not found. """ return self.tasks.get(task_id) - - async def wait_for_task( - self, - task_id: str, - timeout: Optional[float] = None - ) -> Any: + + async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Any: """Wait for a task to complete and return its result. - + Args: task_id: The ID of the task to wait for. timeout: Optional timeout in seconds. - + Returns: The result of the task. - + Raises: ValueError: If the task ID is not found. asyncio.TimeoutError: If the task times out. @@ -231,49 +236,47 @@ async def wait_for_task( task = self.get_task(task_id) if not task: raise ValueError(f"Task {task_id} not found") - + return await task.wait(timeout=timeout) - + async def cancel_task(self, task_id: str) -> bool: """Cancel a running task. - + Args: task_id: The ID of the task to cancel. - + Returns: True if the task was cancelled, False otherwise. """ task = self.get_task(task_id) if not task or not task._task or task.is_done: return False - + task._task.cancel() return True - + def list_tasks(self, status: Optional[TaskStatus] = None) -> List[Task]: """Get a list of all tasks, optionally filtered by status. - + Args: status: Optional status to filter tasks by. - + Returns: A list of tasks matching the criteria. """ if status is None: return list(self.tasks.values()) return [task for task in self.tasks.values() if task.status == status] - + def add_callback( - self, - task: Union[str, Task], - callback: TaskResultCallback + self, task: Union[str, Task], callback: TaskResultCallback ) -> None: """Add a callback to be called when a task completes. - + Args: task_id: The ID of the task to add the callback to. callback: The callback function to call when the task completes. - + Raises: ValueError: If the task ID is not found. """ @@ -281,29 +284,25 @@ def add_callback( task = task.id if task not in self.tasks: raise ValueError(f"Task {task} not found") - + if task not in self._task_callbacks: self._task_callbacks[task] = [] - + self._task_callbacks[task].append(callback) - - def remove_callback( - self, - task_id: str, - callback: TaskResultCallback - ) -> bool: + + def remove_callback(self, task_id: str, callback: TaskResultCallback) -> bool: """Remove a callback from a task. - + Args: task_id: The ID of the task to remove the callback from. callback: The callback function to remove. - + Returns: True if the callback was removed, False otherwise. """ if task_id not in self._task_callbacks: return False - + try: self._task_callbacks[task_id].remove(callback) return True diff --git a/nexios_contrib/tasks/models.py b/nexios_contrib/tasks/models.py index 33a408e..3e59791 100644 --- a/nexios_contrib/tasks/models.py +++ b/nexios_contrib/tasks/models.py @@ -3,30 +3,31 @@ This module defines the core data structures used by the task system. """ + from __future__ import annotations import asyncio import time from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum -from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union -from uuid import UUID, uuid4 +from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar +from uuid import uuid4 from .config import TaskStatus -T = TypeVar('T') +T = TypeVar("T") + @dataclass class TaskResult: """Represents the result of a completed task.""" + task_id: str result: Any status: TaskStatus error: Optional[Exception] = None created_at: float = field(default_factory=time.time) completed_at: Optional[float] = None - + def to_dict(self) -> Dict[str, Any]: """Convert the task result to a dictionary.""" return { @@ -36,40 +37,43 @@ def to_dict(self) -> Dict[str, Any]: "error": str(self.error) if self.error else None, "created_at": self.created_at, "completed_at": self.completed_at, - "duration": (self.completed_at or time.time()) - self.created_at + "duration": (self.completed_at or time.time()) - self.created_at, } + @dataclass class TaskError: """Represents an error that occurred during task execution.""" + message: str exception_type: str traceback: Optional[str] = None - + def to_dict(self) -> Dict[str, Any]: """Convert the error to a dictionary.""" return { "message": self.message, "exception_type": self.exception_type, - "traceback": self.traceback + "traceback": self.traceback, } + class Task: """Represents a background task. - + This class encapsulates a coroutine that runs in the background and provides methods to monitor and control its execution. """ - + def __init__( self, func: Callable[..., Awaitable[Any]], *args: Any, name: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """Initialize a new task. - + Args: func: The coroutine function to execute. *args: Positional arguments to pass to the function. @@ -88,37 +92,41 @@ def __init__( self._started_at: Optional[float] = None self._completed_at: Optional[float] = None self._done_event = asyncio.Event() - + @property def status(self) -> TaskStatus: """Get the current status of the task.""" return self._status - + @property def result(self) -> Optional[TaskResult]: """Get the result of the task if it has completed.""" return self._result - + @property def is_done(self) -> bool: """Check if the task has completed (successfully or not).""" - return self._status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED) - + return self._status in ( + TaskStatus.COMPLETED, + TaskStatus.FAILED, + TaskStatus.CANCELLED, + ) + async def run(self) -> None: """Execute the task.""" if self._status != TaskStatus.PENDING: raise RuntimeError(f"Task {self.id} has already been executed") - + self._status = TaskStatus.RUNNING self._started_at = time.time() - + try: result = await self.func(*self.args, **self.kwargs) self._result = TaskResult( task_id=self.id, result=result, status=TaskStatus.COMPLETED, - completed_at=time.time() + completed_at=time.time(), ) self._status = TaskStatus.COMPLETED except asyncio.CancelledError: @@ -127,33 +135,33 @@ async def run(self) -> None: task_id=self.id, result=None, status=TaskStatus.CANCELLED, - completed_at=time.time() + completed_at=time.time(), ) raise except Exception as e: - import traceback + self._status = TaskStatus.FAILED self._result = TaskResult( task_id=self.id, result=None, status=TaskStatus.FAILED, error=e, - completed_at=time.time() + completed_at=time.time(), ) raise finally: self._completed_at = time.time() self._done_event.set() - + async def wait(self, timeout: Optional[float] = None) -> Any: """Wait for the task to complete and return its result. - + Args: timeout: Optional timeout in seconds. - + Returns: The result of the task. - + Raises: asyncio.TimeoutError: If the task times out. Exception: If the task raises an exception. @@ -162,15 +170,21 @@ async def wait(self, timeout: Optional[float] = None) -> Any: if self._status == TaskStatus.FAILED and self._result.error: raise self._result.error return self._result.result - + try: await asyncio.wait_for(self._done_event.wait(), timeout=timeout) - if self._result and self._status == TaskStatus.FAILED and self._result.error: + if ( + self._result + and self._status == TaskStatus.FAILED + and self._result.error + ): raise self._result.error return self._result.result if self._result else None except asyncio.TimeoutError: - raise asyncio.TimeoutError(f"Task {self.id} timed out after {timeout} seconds") - + raise asyncio.TimeoutError( + f"Task {self.id} timed out after {timeout} seconds" + ) + def to_dict(self) -> Dict[str, Any]: """Convert the task to a dictionary.""" return { @@ -180,6 +194,11 @@ def to_dict(self) -> Dict[str, Any]: "created_at": self._created_at, "started_at": self._started_at, "completed_at": self._completed_at, - "duration": (self._completed_at or time.time()) - (self._started_at or self._created_at) if self._started_at else None, - "result": self._result.to_dict() if self._result else None + "duration": ( + (self._completed_at or time.time()) + - (self._started_at or self._created_at) + if self._started_at + else None + ), + "result": self._result.to_dict() if self._result else None, } diff --git a/nexios_contrib/timeout/helper.py b/nexios_contrib/timeout/helper.py index bd5e897..2d52066 100644 --- a/nexios_contrib/timeout/helper.py +++ b/nexios_contrib/timeout/helper.py @@ -4,7 +4,9 @@ This module provides utilities for timeout handling and request timing for Nexios applications. """ + from __future__ import annotations +from nexios.http.response import JSONResponse, BaseResponse import asyncio import time @@ -53,6 +55,7 @@ def timeout_after( async def slow_operation(): await asyncio.sleep(60) # This will timeout after 30 seconds """ + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: if timeout <= 0: @@ -64,7 +67,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: if exception: raise exception raise TimeoutException(timeout) + return wrapper + return decorator @@ -112,8 +117,8 @@ def get_request_duration(request: Request) -> float: Returns: Duration of the request in seconds, or 0 if not available. """ - if hasattr(request, 'start_time'): - return time.time() - request.start_time + if hasattr(request, "start_time"): + return time.time() - int(str(request.start_time)) return 0.0 @@ -124,7 +129,7 @@ def set_request_start_time(request: Request) -> None: Args: request: The HTTP request object to set the start time for. """ - request.start_time = time.time() + request.start_time = time.time() # ty:ignore[unresolved-attribute] def get_request_start_time(request: Request) -> Optional[float]: @@ -140,10 +145,7 @@ def get_request_start_time(request: Request) -> Optional[float]: return getattr(request, "start_time", None) -def get_timeout_from_request( - request: Request, - default_timeout: float = 30.0 -) -> float: +def get_timeout_from_request(request: Request, default_timeout: float = 30.0) -> float: """ Extract timeout value from request headers or query parameters. @@ -159,7 +161,7 @@ def get_timeout_from_request( Timeout duration in seconds. """ # Check for timeout in headers - timeout_header = request.headers.get('X-Request-Timeout') + timeout_header = request.headers.get("X-Request-Timeout") if timeout_header: try: return max(0.1, float(timeout_header)) @@ -167,7 +169,7 @@ def get_timeout_from_request( pass # Check for timeout in query parameters - timeout_param = request.query_params.get('timeout') + timeout_param = request.query_params.get("timeout") if timeout_param: try: return max(0.1, float(timeout_param)) @@ -178,10 +180,9 @@ def get_timeout_from_request( def create_timeout_response( - response: Response, timeout: float, detail: Any = None, -) -> Response: +) -> BaseResponse: """ Create a timeout error response. @@ -193,10 +194,11 @@ def create_timeout_response( Returns: HTTP response indicating a timeout error. """ - response.set_header("X-Timeout", str(timeout)) - return response.json( + return JSONResponse( {"error": "Request Timeout", "timeout": timeout, "detail": detail}, status_code=408, + headers={"X-Timeout": str(timeout)} + ) diff --git a/nexios_contrib/timeout/middleware.py b/nexios_contrib/timeout/middleware.py index 9d3e0dd..1d820fb 100644 --- a/nexios_contrib/timeout/middleware.py +++ b/nexios_contrib/timeout/middleware.py @@ -4,23 +4,22 @@ This middleware provides request timeout handling and automatic timeout responses for Nexios applications. """ + from __future__ import annotations +from nexios.http.response import BaseResponse import asyncio -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union from nexios.http import Request, Response from nexios.middleware.base import BaseMiddleware -from nexios.exceptions import HTTPException from .helper import ( TimeoutException, create_timeout_response, get_request_duration, - get_timeout_from_request, is_timeout_error, set_request_start_time, - timeout_after, ) @@ -107,18 +106,19 @@ async def process_request( timeout = self._get_request_timeout(request) # Store timeout information in request for later use - request.timeout = timeout - request.timeout_config = { - 'default_timeout': self.default_timeout, - 'max_timeout': self.max_timeout, - 'min_timeout': self.min_timeout, - 'timeout_header': self.timeout_header, - 'timeout_param': self.timeout_param, + request.timeout = timeout # ty:ignore[unresolved-attribute] + request.timeout_config = { # ty:ignore[unresolved-attribute] + "default_timeout": self.default_timeout, + "max_timeout": self.max_timeout, + "min_timeout": self.min_timeout, + "timeout_header": self.timeout_header, + "timeout_param": self.timeout_param, } try: # Apply timeout to the next handler if timeout > 0: + async def timeout_wrapper() -> Any: try: return await asyncio.wait_for(call_next(), timeout=timeout) @@ -138,10 +138,7 @@ async def timeout_wrapper() -> Any: return self._create_timeout_response(request, e) else: # Return a basic timeout response - return response.json( - status_code=408, - content="Request Timeout" - ) + return response.json(status_code=408, data="Request Timeout") except Exception as e: # Handle other exceptions if is_timeout_error(e): @@ -167,14 +164,13 @@ async def process_response( Response: The modified HTTP response object. """ # Add request duration to response headers if tracking is enabled - if self.track_duration and hasattr(request, 'start_time'): + if self.track_duration and hasattr(request, "start_time"): duration = get_request_duration(request) - response.set_header('X-Request-Duration',str(duration)) + response.set_header("X-Request-Duration", str(duration)) # Add timeout information if available - if hasattr(request, 'timeout'): - response.set_header('X-Request-Timeout',str(request.timeout)) - + if hasattr(request, "timeout"): + response.set_header("X-Request-Timeout", str(request.timeout)) return response @@ -234,7 +230,7 @@ def _create_timeout_response( self, request: Request, timeout_exception: Union[TimeoutException, Exception], - ) -> Response: + ) -> BaseResponse: """ Create a timeout error response. @@ -245,19 +241,18 @@ def _create_timeout_response( Returns: Response: HTTP response indicating a timeout error. """ - timeout = getattr(timeout_exception, 'timeout', self.default_timeout) + timeout = getattr(timeout_exception, "timeout", self.default_timeout) # Create the response response = create_timeout_response( timeout=timeout, - detail=str(timeout_exception) if str(timeout_exception) else None + detail=str(timeout_exception) if str(timeout_exception) else None, ) # Add timing information - if hasattr(request, 'start_time'): + if hasattr(request, "start_time"): duration = get_request_duration(request) - response.set_header('X-Actual-Duration',str(duration)) - + response.set_header("X-Actual-Duration", str(duration)) return response diff --git a/nexios_contrib/tortoise/__init__.py b/nexios_contrib/tortoise/__init__.py index 23befbc..6b5cfee 100644 --- a/nexios_contrib/tortoise/__init__.py +++ b/nexios_contrib/tortoise/__init__.py @@ -4,10 +4,11 @@ This module provides Tortoise ORM initialization, connection management, and exception handling for Nexios applications. """ + from __future__ import annotations import logging -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from nexios import NexiosApp @@ -15,7 +16,6 @@ from .config import TortoiseConfig from .exceptions import handle_tortoise_exceptions - __version__ = "0.1.0" # Global Tortoise client instance @@ -24,11 +24,6 @@ logger = logging.getLogger("nexios.tortoise") -class TortoiseConnectionError(Exception): - """Raised when there's an error connecting to the database.""" - pass - - def init_tortoise( app: NexiosApp, db_url: str, @@ -82,7 +77,7 @@ def init_tortoise( db_url=db_url, modules=modules or {"models": []}, generate_schemas=generate_schemas, - **kwargs + **kwargs, ) _tortoise_client = TortoiseClient(config) @@ -144,5 +139,7 @@ async def get_user(request, response): """ global _tortoise_client if _tortoise_client is None: - raise TortoiseConnectionError("Tortoise ORM client not initialized. Call init_tortoise() first.") - return _tortoise_client \ No newline at end of file + raise TortoiseConnectionError( + "Tortoise ORM client not initialized. Call init_tortoise() first." + ) + return _tortoise_client diff --git a/nexios_contrib/tortoise/client.py b/nexios_contrib/tortoise/client.py index b6f45bf..65b0f30 100644 --- a/nexios_contrib/tortoise/client.py +++ b/nexios_contrib/tortoise/client.py @@ -1,10 +1,11 @@ """ Tortoise ORM client wrapper for Nexios integration. """ + from __future__ import annotations import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from .config import TortoiseConfig @@ -13,6 +14,7 @@ class TortoiseConnectionError(Exception): """Raised when there's an error connecting to the database.""" + pass @@ -44,7 +46,7 @@ async def init(self) -> None: # Convert config to Tortoise.init() kwargs tortoise_config = self.config.to_tortoise_config() - + await Tortoise.init(**tortoise_config) # Generate schemas if requested @@ -140,7 +142,7 @@ async def execute_query(self, query: str, *args: Any) -> List[Dict[str, Any]]: from tortoise import Tortoise connection = Tortoise.get_connection("default") - return await connection.execute_query(query, args) + return await connection.execute_query(query, args) # ty:ignore[invalid-return-type, invalid-argument-type] except Exception as e: logger.error(f"Failed to execute query: {e}") @@ -192,4 +194,4 @@ def get_models(self) -> Dict[str, Any]: def __repr__(self) -> str: """String representation of TortoiseClient.""" status = "initialized" if self._initialized else "not initialized" - return f"TortoiseClient({status}, config={self.config})" \ No newline at end of file + return f"TortoiseClient({status}, config={self.config})" diff --git a/nexios_contrib/tortoise/config.py b/nexios_contrib/tortoise/config.py index 4ca5536..1762c17 100644 --- a/nexios_contrib/tortoise/config.py +++ b/nexios_contrib/tortoise/config.py @@ -1,12 +1,13 @@ """ Tortoise ORM configuration for Nexios Tortoise integration. """ + from __future__ import annotations import os from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator class TortoiseConfig(BaseModel): @@ -17,40 +18,37 @@ class TortoiseConfig(BaseModel): via Tortoise ORM and provides validation and sensible defaults. """ - db_url: str = Field( - description="Database connection URL" - ) + db_url: str = Field(description="Database connection URL") modules: Dict[str, List[str]] = Field( default_factory=lambda: {"models": []}, - description="Dictionary mapping app names to model module paths" + description="Dictionary mapping app names to model module paths", ) generate_schemas: bool = Field( - default=False, - description="Whether to generate database schemas on startup" + default=False, description="Whether to generate database schemas on startup" ) use_tz: bool = Field( - default=False, - description="Whether to use timezone-aware datetime objects" + default=False, description="Whether to use timezone-aware datetime objects" ) timezone: str = Field( - default="UTC", - description="Default timezone for datetime objects" + default="UTC", description="Default timezone for datetime objects" ) connections: Optional[Dict[str, Any]] = Field( - default=None, - description="Custom connection configurations" + default=None, description="Custom connection configurations" ) apps: Optional[Dict[str, Any]] = Field( - default=None, - description="Custom app configurations" + default=None, description="Custom app configurations" ) - @validator("db_url") + @field_validator("db_url") def validate_db_url(cls, v: str) -> str: """Validate database URL format.""" supported_schemes = [ - "sqlite://", "postgres://", "postgresql://", - "mysql://", "asyncpg://", "aiomysql://" + "sqlite://", + "postgres://", + "postgresql://", + "mysql://", + "asyncpg://", + "aiomysql://", ] if not any(v.startswith(scheme) for scheme in supported_schemes): raise ValueError( @@ -58,20 +56,22 @@ def validate_db_url(cls, v: str) -> str: ) return v - @validator("modules") + @field_validator("modules") def validate_modules(cls, v: Dict[str, List[str]]) -> Dict[str, List[str]]: """Validate modules configuration.""" if not isinstance(v, dict): raise ValueError("modules must be a dictionary") - + for app_name, module_list in v.items(): if not isinstance(module_list, list): raise ValueError(f"modules['{app_name}'] must be a list of strings") - + for module in module_list: if not isinstance(module, str): - raise ValueError(f"All modules in modules['{app_name}'] must be strings") - + raise ValueError( + f"All modules in modules['{app_name}'] must be strings" + ) + return v @classmethod @@ -96,27 +96,32 @@ def from_env(cls, prefix: str = "TORTOISE_") -> TortoiseConfig: ``` """ env_vars = {} - + # Handle simple fields for field_name, field in cls.__fields__.items(): if field_name in ["modules", "connections", "apps"]: continue # Skip complex fields - + env_name = f"{prefix}{field_name.upper()}" env_value = os.getenv(env_name) if env_value is not None: # Type conversion for specific fields - if field.type_ in (int, float): + if isinstance(field,(int, float)): try: - if field.type_ == int: + if field.type_ is int: env_vars[field_name] = int(env_value) else: env_vars[field_name] = float(env_value) except ValueError: continue # Skip invalid values - elif field.type_ == bool: - env_vars[field_name] = env_value.lower() in ("true", "1", "yes", "on") + elif isinstance(field, bool): + env_vars[field_name] = env_value.lower() in ( + "true", + "1", + "yes", + "on", + ) else: env_vars[field_name] = env_value @@ -125,11 +130,12 @@ def from_env(cls, prefix: str = "TORTOISE_") -> TortoiseConfig: if modules_env: try: import json + env_vars["modules"] = json.loads(modules_env) except (json.JSONDecodeError, ValueError): pass # Skip invalid JSON - return cls(**env_vars) + return cls(**env_vars) # ty:ignore[invalid-argument-type] def to_tortoise_config(self) -> Dict[str, Any]: """ @@ -147,7 +153,7 @@ def to_tortoise_config(self) -> Dict[str, Any]: if self.connections: config["connections"] = self.connections - + if self.apps: config["apps"] = self.apps @@ -155,7 +161,7 @@ def to_tortoise_config(self) -> Dict[str, Any]: def __str__(self) -> str: """String representation of Tortoise config (without sensitive data).""" - safe_dict = self.dict() + safe_dict = self.model_dump() # Mask password in db_url if present if "://" in self.db_url and "@" in self.db_url: parts = self.db_url.split("://") @@ -167,5 +173,5 @@ def __str__(self) -> str: if ":" in auth_part: user, _ = auth_part.split(":", 1) safe_dict["db_url"] = f"{scheme}://{user}:***@{host_part}" - - return f"TortoiseConfig({safe_dict})" \ No newline at end of file + + return f"TortoiseConfig({safe_dict})" diff --git a/nexios_contrib/tortoise/exceptions.py b/nexios_contrib/tortoise/exceptions.py index 0f9db94..18ea2d2 100644 --- a/nexios_contrib/tortoise/exceptions.py +++ b/nexios_contrib/tortoise/exceptions.py @@ -1,12 +1,13 @@ from __future__ import annotations + import logging from typing import TYPE_CHECKING from tortoise.exceptions import ( - IntegrityError, DoesNotExist, - ValidationError, + IntegrityError, OperationalError, + ValidationError, ) if TYPE_CHECKING: @@ -19,39 +20,49 @@ def handle_tortoise_exceptions(app: "NexiosApp") -> None: @app.add_exception_handler(IntegrityError) - async def handle_integrity(request: Request, response: Response, exc: IntegrityError): + async def handle_integrity( + request: Request, response: Response, exc: IntegrityError + ): logger.warning(f"Tortoise IntegrityError: {exc}") - return response.json({ - "error": "Integrity constraint violation", - "detail": str(exc), - "type": "integrity_error" - }).status(400) + return response.json( + { + "error": "Integrity constraint violation", + "detail": str(exc), + "type": "integrity_error", + } + ).status(400) @app.add_exception_handler(DoesNotExist) async def handle_not_found(request: Request, response: Response, exc: DoesNotExist): logger.info(f"Tortoise DoesNotExist: {exc}") - return response.json({ - "error": "Record not found", - "detail": str(exc), - "type": "not_found_error" - }).status(404) + return response.json( + {"error": "Record not found", "detail": str(exc), "type": "not_found_error"} + ).status(404) @app.add_exception_handler(ValidationError) - async def handle_validation(request: Request, response: Response, exc: ValidationError): + async def handle_validation( + request: Request, response: Response, exc: ValidationError + ): logger.warning(f"Tortoise ValidationError: {exc}") - return response.json({ - "error": "Validation failed", - "detail": str(exc), - "type": "validation_error" - }).status(422) + return response.json( + { + "error": "Validation failed", + "detail": str(exc), + "type": "validation_error", + } + ).status(422) @app.add_exception_handler(OperationalError) - async def handle_operational(request: Request, response: Response, exc: OperationalError): + async def handle_operational( + request: Request, response: Response, exc: OperationalError + ): logger.error(f"Tortoise OperationalError: {exc}") - return response.json({ - "error": "Database operational error", - "detail": "Service temporarily unavailable", - "type": "operational_error" - }).status(503) + return response.json( + { + "error": "Database operational error", + "detail": "Service temporarily unavailable", + "type": "operational_error", + } + ).status(503) logger.info("Tortoise exception handlers (4 types) registered.") diff --git a/nexios_contrib/trusted/__init__.py b/nexios_contrib/trusted/__init__.py index 3b21f26..bbb039b 100644 --- a/nexios_contrib/trusted/__init__.py +++ b/nexios_contrib/trusted/__init__.py @@ -2,7 +2,6 @@ Nexios Trusted Host contrib package. """ -from .middleware import TrustedHostMiddleware from .helpers import ( get_host_from_headers, is_www_host, @@ -10,6 +9,7 @@ strip_www_prefix, validate_host_against_patterns, ) +from .middleware import TrustedHostMiddleware __all__ = [ "TrustedHostMiddleware", @@ -20,14 +20,13 @@ "validate_host_against_patterns", ] + def TrustedHost( - allowed_hosts: list[str], - allowed_ports: list[int] = None, - www_redirect: bool = True + allowed_hosts: list[str], allowed_ports: list[int] = None, www_redirect: bool = True ) -> TrustedHostMiddleware: """Create a TrustedHostMiddleware instance with the given configuration.""" return TrustedHostMiddleware( allowed_hosts=allowed_hosts, allowed_ports=allowed_ports, - www_redirect=www_redirect + www_redirect=www_redirect, ) diff --git a/nexios_contrib/trusted/helpers.py b/nexios_contrib/trusted/helpers.py index 7bcbf84..82da994 100644 --- a/nexios_contrib/trusted/helpers.py +++ b/nexios_contrib/trusted/helpers.py @@ -1,6 +1,7 @@ """ Helper functions for trusted host validation. """ + from __future__ import annotations from typing import List, Optional, Set @@ -45,9 +46,7 @@ def matches_wildcard_pattern(host: str, pattern: str) -> bool: def validate_host_against_patterns( - host: str, - allowed_patterns: List[str], - allowed_ports: Optional[Set[int]] = None + host: str, allowed_patterns: List[str], allowed_ports: Optional[Set[int]] = None ) -> bool: """ Validate a host against a list of allowed patterns. diff --git a/nexios_contrib/trusted/middleware.py b/nexios_contrib/trusted/middleware.py index 2ba0cd2..fb0873d 100644 --- a/nexios_contrib/trusted/middleware.py +++ b/nexios_contrib/trusted/middleware.py @@ -4,6 +4,7 @@ This middleware validates the Host header to ensure requests are coming from trusted hosts/domains. This is a security feature to prevent Host header attacks. """ + from __future__ import annotations from typing import Any, List, Optional, Set @@ -94,13 +95,15 @@ async def process_request( # Check if host is allowed if not self._is_host_allowed(host): - return response.json({"error": f"Host '{host}' is not allowed"}, status_code=400) + return response.json( + {"error": f"Host '{host}' is not allowed"}, status_code=400 + ) # Handle www redirect if enabled if self.www_redirect and host.startswith("www."): www_prefix = "www." if host.startswith(www_prefix): - clean_host = host[len(www_prefix):] + clean_host = host[len(www_prefix) :] # Check if the clean host is allowed if self._is_host_allowed(clean_host): # In a real implementation, you'd redirect here diff --git a/pyproject.toml b/pyproject.toml index d9dc1d9..8f784b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,3 +216,10 @@ version_files = [ ] tag_format = "v$version" update_changelog_on_bump = true + + +[tool.ty.rules] +unresolved-import = "ignore" +invalid-parameter-default = "ignore" +unused-ignore-comment = "ignore" +unresolved-attribute = "ignore" diff --git a/tests/conftest.py b/tests/conftest.py index 6785f8a..92b49a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,9 @@ """ import functools -from typing import Callable, Optional, Any +from typing import Any, Callable, Optional import pytest - from nexios import NexiosApp from nexios.testclient import TestClient @@ -20,6 +19,7 @@ def test_client_factory(): @pytest.fixture def app_factory(): """Factory for creating NexiosApp instances with optional middleware.""" + def _create_app(middleware: Optional[Any] = None): app = NexiosApp() if middleware: diff --git a/tests/etag/test_configuration.py b/tests/etag/test_configuration.py index 63252c3..cf47595 100644 --- a/tests/etag/test_configuration.py +++ b/tests/etag/test_configuration.py @@ -3,9 +3,9 @@ """ import pytest - from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.etag import ETagMiddleware @@ -51,7 +51,7 @@ async def strong_handler(request, response): assert resp2.headers["etag"].startswith('W/"') # Strong should produce strong ETags - assert not resp3.headers["etag"].startswith('W/') + assert not resp3.headers["etag"].startswith("W/") assert resp3.headers["etag"].startswith('"') def test_methods_configuration(self): @@ -136,11 +136,13 @@ async def with_override_handler(request, response): def test_combined_configuration_options(self): """Test combined configuration options.""" app = NexiosApp() - app.add_middleware(ETagMiddleware( - weak=False, # Strong ETags - methods=["GET", "POST"], # Include POST - override=True # Override manual ETags - )) + app.add_middleware( + ETagMiddleware( + weak=False, # Strong ETags + methods=["GET", "POST"], # Include POST + override=True, # Override manual ETags + ) + ) @app.get("/get-endpoint") async def get_handler(request, response): @@ -154,12 +156,12 @@ async def post_handler(request, response): # GET should use computed strong ETag (override manual) resp1 = client.get("/get-endpoint") - assert not resp1.headers["etag"].startswith('W/') + assert not resp1.headers["etag"].startswith("W/") assert resp1.headers["etag"] != '"manual-etag"' # POST should also use computed strong ETag (override manual) resp2 = client.post("/post-endpoint", json={}) - assert not resp2.headers["etag"].startswith('W/') + assert not resp2.headers["etag"].startswith("W/") assert resp2.headers["etag"] != '"manual-etag"' def test_case_insensitive_methods(self): @@ -261,9 +263,9 @@ async def bytes_handler(request, response): resp2 = client.get("/text") resp3 = client.get("/bytes") - assert not resp1.headers["etag"].startswith('W/') - assert not resp2.headers["etag"].startswith('W/') - assert not resp3.headers["etag"].startswith('W/') + assert not resp1.headers["etag"].startswith("W/") + assert not resp2.headers["etag"].startswith("W/") + assert not resp3.headers["etag"].startswith("W/") # All should be different ETags since content is different etags = [resp1.headers["etag"], resp2.headers["etag"], resp3.headers["etag"]] diff --git a/tests/etag/test_edge_cases.py b/tests/etag/test_edge_cases.py index 9f8bcd2..553c5c6 100644 --- a/tests/etag/test_edge_cases.py +++ b/tests/etag/test_edge_cases.py @@ -3,10 +3,10 @@ """ import pytest - from nexios import NexiosApp from nexios.http import Request, Response from nexios.testclient import TestClient + from nexios_contrib.etag import ( ETagMiddleware, etag_matches, @@ -184,7 +184,9 @@ async def handler(request, response): assert resp2.headers["etag"] == valid_etag # Request with mix of valid and invalid ETags - resp3 = client.get("/test", headers={"if-none-match": f'"valid", invalid, {valid_etag}'}) + resp3 = client.get( + "/test", headers={"if-none-match": f'"valid", invalid, {valid_etag}'} + ) assert resp3.status_code == 304 # Should match the valid one def test_malformed_if_none_match_header(self): @@ -228,9 +230,6 @@ def test_etag_matches_with_invalid_inputs(self): assert not etag_matches(valid_etag, []) assert not etag_matches("", []) - - - def test_generate_etag_from_bytes_edge_cases(self): """Test generate_etag_from_bytes with edge cases.""" # Empty bytes diff --git a/tests/etag/test_helper_functions.py b/tests/etag/test_helper_functions.py index d3865a8..6f58e48 100644 --- a/tests/etag/test_helper_functions.py +++ b/tests/etag/test_helper_functions.py @@ -5,6 +5,7 @@ import typing from nexios.http import Request, Response + from nexios_contrib.etag import ( compute_and_set_etag, etag_matches, @@ -36,7 +37,7 @@ def create_mock_scope( method: str = "GET", path: str = "/", headers: typing.Optional[typing.Dict[str, str]] = None, - query_string: bytes = b"" + query_string: bytes = b"", ) -> Scope: """Create a mock ASGI scope for testing.""" return { @@ -45,7 +46,9 @@ def create_mock_scope( "path": path, "raw_path": path.encode("utf-8"), "query_string": query_string, - "headers": [(k.lower().encode(), v.encode()) for k, v in (headers or {}).items()], + "headers": [ + (k.lower().encode(), v.encode()) for k, v in (headers or {}).items() + ], "server": ("testserver", 80), "client": ("testclient", 12345), "scheme": "http", @@ -73,7 +76,7 @@ def test_generate_etag_from_bytes_strong(self): """Test generating strong ETag.""" data = b"Hello, World!" etag = generate_etag_from_bytes(data, weak=False) - assert not etag.startswith('W/') + assert not etag.startswith("W/") assert etag.endswith('"') # Should be consistent for same input etag2 = generate_etag_from_bytes(data, weak=False) @@ -95,7 +98,7 @@ def test_generate_etag_from_bytes_empty_data(self): def test_generate_etag_from_bytes_unicode(self): """Test generating ETag from unicode data.""" - data = "Hello, 世界!".encode('utf-8') + data = "Hello, 世界!".encode("utf-8") etag = generate_etag_from_bytes(data) assert etag.startswith('W/"') assert etag.endswith('"') @@ -134,8 +137,6 @@ def test_normalize_etag_with_spaces(self): normalized = normalize_etag(etag) assert normalized == '"abc123"' - - class TestSetResponseEtag: """Test set_response_etag function.""" @@ -205,7 +206,7 @@ def test_compute_and_set_etag_weak_false(self): request = Request(scope, mock_receive, mock_send) response = Response(request).empty() etag = compute_and_set_etag(response, body=b"Hello, World!", weak=False) - assert not etag.startswith('W/') + assert not etag.startswith("W/") assert response.headers.get("etag") == etag @@ -247,7 +248,6 @@ def test_parse_if_none_match_empty(self): etags = parse_if_none_match(request) assert etags == [] - class TestParseIfMatch: """Test parse_if_match function.""" @@ -273,8 +273,6 @@ def test_parse_if_match_empty(self): etags = parse_if_match(request) assert etags == [] - - class TestEtagMatches: """Test etag_matches function.""" @@ -298,8 +296,12 @@ def test_etag_matches_weak_compare_true(self): def test_etag_matches_multiple_candidates(self): """Test matching against multiple candidates.""" - assert etag_matches('"abc123"', ['"def456"', '"abc123"', '"ghi789"'], weak_compare=False) - assert not etag_matches('"abc123"', ['"def456"', '"xyz123"'], weak_compare=False) + assert etag_matches( + '"abc123"', ['"def456"', '"abc123"', '"ghi789"'], weak_compare=False + ) + assert not etag_matches( + '"abc123"', ['"def456"', '"xyz123"'], weak_compare=False + ) def test_etag_matches_invalid_etag(self): """Test matching with invalid ETag.""" diff --git a/tests/etag/test_integration.py b/tests/etag/test_integration.py index 3b97ddb..9ee0d39 100644 --- a/tests/etag/test_integration.py +++ b/tests/etag/test_integration.py @@ -3,9 +3,9 @@ """ import pytest - from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.etag import ETagMiddleware @@ -182,16 +182,12 @@ def test_etag_override_behavior(self): async def handler_no_override(request, response): response.json({"data": "test"}).set_header("etag", '"manual-etag"') - - client = TestClient(app) # Without override - should keep manual ETag resp1 = client.get("/no-override") assert resp1.headers["etag"] == '"manual-etag"' - - def test_etag_strong_vs_weak(self): """Test strong vs weak ETag generation.""" weak_app = NexiosApp() @@ -218,7 +214,7 @@ async def strong_handler(request, response): assert weak_resp.headers["etag"].startswith('W/"') # Strong ETag should not start with W/ - assert not strong_resp.headers["etag"].startswith('W/') + assert not strong_resp.headers["etag"].startswith("W/") assert strong_resp.headers["etag"].startswith('"') # Weak comparison should work between weak and strong diff --git a/tests/etag/test_middleware.py b/tests/etag/test_middleware.py index 3f1064e..50a1045 100644 --- a/tests/etag/test_middleware.py +++ b/tests/etag/test_middleware.py @@ -3,10 +3,10 @@ """ import pytest - from nexios import NexiosApp from nexios.http import Request, Response from nexios.testclient import TestClient + from nexios_contrib.etag import ETagMiddleware @@ -95,7 +95,7 @@ async def handler(request, response): with test_client_factory(app) as client: resp = client.get("/test") - assert not resp.headers["etag"].startswith('W/') + assert not resp.headers["etag"].startswith("W/") assert resp.headers["etag"].startswith('"') def test_middleware_does_not_override_existing_etag(self, test_client_factory): @@ -105,7 +105,9 @@ def test_middleware_does_not_override_existing_etag(self, test_client_factory): @app.get("/test") async def handler(request, response): - return response.json({"message": "Hello, World!"}).set_header("etag", '"custom-etag"') + return response.json({"message": "Hello, World!"}).set_header( + "etag", '"custom-etag"' + ) with test_client_factory(app) as client: resp = client.get("/test") @@ -118,7 +120,9 @@ def test_middleware_overrides_existing_etag(self, test_client_factory): @app.get("/test") async def handler(request, response): - return response.json({"message": "Hello, World!"}).set_header("etag", '"custom-etag"') + return response.json({"message": "Hello, World!"}).set_header( + "etag", '"custom-etag"' + ) with test_client_factory(app) as client: resp = client.get("/test") @@ -237,7 +241,10 @@ async def handler(request, response): etag = resp1.headers["etag"] # Second request with multiple ETags including the matching one - resp2 = client.get("/test", headers={"if-none-match": f'"other-etag", {etag}, "another-etag"'}) + resp2 = client.get( + "/test", + headers={"if-none-match": f'"other-etag", {etag}, "another-etag"'}, + ) assert resp2.status_code == 304 def test_conditional_get_weak_etag_match(self, test_client_factory): @@ -255,7 +262,9 @@ async def handler(request, response): etag = resp1.headers["etag"] # Second request with weak version of same ETag - weak_etag = etag.replace('W/"', '"') if etag.startswith('W/') else f'W/{etag}' + weak_etag = ( + etag.replace('W/"', '"') if etag.startswith("W/") else f"W/{etag}" + ) resp2 = client.get("/test", headers={"if-none-match": weak_etag}) assert resp2.status_code == 304 @@ -286,7 +295,7 @@ def test_conditional_request_no_etag_on_response(self, test_client_factory): @app.get("/test") async def handler(request, response): # Manually remove any ETag that might be set - return response.json({"message":"hell world"}).remove_header("etag") + return response.json({"message": "hell world"}).remove_header("etag") with test_client_factory(app) as client: resp = client.get("/test", headers={"if-none-match": '"some-etag"'}) diff --git a/tests/jrpc/conftest.py b/tests/jrpc/conftest.py index b27ae4e..b5bcd46 100644 --- a/tests/jrpc/conftest.py +++ b/tests/jrpc/conftest.py @@ -3,12 +3,12 @@ """ import functools -from typing import Callable, Optional, Any +from typing import Any, Callable, Optional import pytest - from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.jrpc import JsonRpcPlugin, JsonRpcRegistry, get_registry @@ -21,6 +21,7 @@ def test_client_factory(): @pytest.fixture def app_factory(): """Factory for creating NexiosApp instances with optional JRPC plugin.""" + def _create_app(jrpc_config: Optional[dict] = None): app = NexiosApp() if jrpc_config: diff --git a/tests/jrpc/test_client.py b/tests/jrpc/test_client.py index a59bf20..a17b085 100644 --- a/tests/jrpc/test_client.py +++ b/tests/jrpc/test_client.py @@ -3,10 +3,11 @@ """ import json -import pytest +import pytest from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.jrpc import JsonRpcClient, JsonRpcRegistry @@ -33,6 +34,7 @@ def divide(a: float, b: float) -> float: app = NexiosApp() from nexios_contrib.jrpc import JsonRpcPlugin + JsonRpcPlugin(app) return app @@ -90,7 +92,7 @@ def test_call_method_sync(self, client_app): # This would normally make HTTP requests # For testing purposes, we test the method structure - assert hasattr(client, 'call') + assert hasattr(client, "call") assert callable(client.call) # Test parameters @@ -104,10 +106,9 @@ def test_acall_method_async(self, client_app): client = JsonRpcClient("http://localhost:8000/rpc") # This would normally make async HTTP requests - assert hasattr(client, 'acall') + assert hasattr(client, "acall") assert callable(client.acall) - def test_client_with_different_base_urls(self, client_app): """Test client with different base URLs.""" diff --git a/tests/jrpc/test_configuration.py b/tests/jrpc/test_configuration.py index 1839780..e8251f9 100644 --- a/tests/jrpc/test_configuration.py +++ b/tests/jrpc/test_configuration.py @@ -3,9 +3,9 @@ """ import pytest - from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.jrpc import JsonRpcPlugin, JsonRpcRegistry @@ -26,7 +26,7 @@ def test_method(a: int) -> int: "jsonrpc": "2.0", "method": "test_method", "params": {"a": 5}, - "id": 1 + "id": 1, } response = client.post("/rpc", json=payload) @@ -47,7 +47,7 @@ def test_method(a: int) -> int: "jsonrpc": "2.0", "method": "test_method", "params": {"a": 4}, - "id": 1 + "id": 1, } # Should work with custom prefix @@ -73,7 +73,7 @@ def test_method(a: int) -> int: "jsonrpc": "2.0", "method": "test_method", "params": {"a": 5}, - "id": 1 + "id": 1, } response = client.post("/", json=payload) @@ -94,7 +94,7 @@ def test_method(a: int) -> int: "jsonrpc": "2.0", "method": "test_method", "params": {"a": 3}, - "id": 1 + "id": 1, } # Should work with nested prefix @@ -106,7 +106,9 @@ def test_method(a: int) -> int: response_short = client.post("/api/rpc", json=payload) assert response_short.status_code == 404 - def test_plugin_without_explicit_config(self, app_factory, test_client_factory, registry): + def test_plugin_without_explicit_config( + self, app_factory, test_client_factory, registry + ): """Test plugin with default configuration.""" @registry.register() @@ -122,7 +124,7 @@ def test_method(a: int) -> int: "jsonrpc": "2.0", "method": "test_method", "params": {"a": 5}, - "id": 1 + "id": 1, } response = client.post("/rpc", json=payload) @@ -151,30 +153,22 @@ def method_two(a: int) -> int: return a + 2 # Test first app - payload1 = { - "jsonrpc": "2.0", - "method": "method1", - "params": {"a": 5}, - "id": 1 - } + payload1 = {"jsonrpc": "2.0", "method": "method1", "params": {"a": 5}, "id": 1} response1 = client1.post("/api1", json=payload1) assert response1.status_code == 200 assert response1.json()["result"] == 6 # Test second app - payload2 = { - "jsonrpc": "2.0", - "method": "method2", - "params": {"a": 5}, - "id": 1 - } + payload2 = {"jsonrpc": "2.0", "method": "method2", "params": {"a": 5}, "id": 1} response2 = client2.post("/api2", json=payload2) assert response2.status_code == 200 assert response2.json()["result"] == 7 - def test_plugin_integration_with_other_routes(self, app_factory, test_client_factory, registry): + def test_plugin_integration_with_other_routes( + self, app_factory, test_client_factory, registry + ): """Test JRPC plugin alongside other routes.""" app = app_factory({"path_prefix": "/jsonrpc"}) @@ -207,7 +201,7 @@ async def get_users(request, response): "jsonrpc": "2.0", "method": "calculate", "params": {"a": 10, "b": 20}, - "id": 1 + "id": 1, } jrpc_response = client.post("/jsonrpc", json=jrpc_payload) diff --git a/tests/jrpc/test_edge_cases.py b/tests/jrpc/test_edge_cases.py index 0c3314e..d492eb8 100644 --- a/tests/jrpc/test_edge_cases.py +++ b/tests/jrpc/test_edge_cases.py @@ -6,9 +6,9 @@ from typing import Any, Dict import pytest - from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.jrpc import JsonRpcPlugin, JsonRpcRegistry @@ -26,7 +26,7 @@ def echo(value: Any) -> Any: "jsonrpc": "2.0", "method": "echo", "params": {"value": "test"}, - "id": 999999999999999999999999999999999999999999999999999999999999999999 + "id": 999999999999999999999999999999999999999999999999999999999999999999, } response = jrpc_client.post("/rpc", json=payload) @@ -34,7 +34,10 @@ def echo(value: Any) -> Any: result = response.json() assert result["result"] == "test" - assert result["id"] == 999999999999999999999999999999999999999999999999999999999999999999 + assert ( + result["id"] + == 999999999999999999999999999999999999999999999999999999999999999999 + ) def test_string_request_id(self, jrpc_app, jrpc_client, registry): """Test with string request ID.""" @@ -47,7 +50,7 @@ def echo(value: Any) -> Any: "jsonrpc": "2.0", "method": "echo", "params": {"value": "test"}, - "id": "unique-request-id-123" + "id": "unique-request-id-123", } response = jrpc_client.post("/rpc", json=payload) @@ -68,7 +71,7 @@ def echo(value: Any) -> Any: "jsonrpc": "2.0", "method": "echo", "params": {"value": "test"}, - "id": None + "id": None, } response = jrpc_client.post("/rpc", json=payload) @@ -78,29 +81,19 @@ def echo(value: Any) -> Any: assert result["result"] == "test" assert result["id"] is None - - - - - - def test_method_with_mixed_param_types(self, jrpc_app, jrpc_client, registry): """Test method with mixed parameter types.""" @registry.register() def mixed_params( - number: int, - text: str, - flag: bool, - numbers: list, - config: dict + number: int, text: str, flag: bool, numbers: list, config: dict ) -> dict: return { "number": number, "text": text, "flag": flag, "sum": sum(numbers), - "config_keys": list(config.keys()) + "config_keys": list(config.keys()), } payload = { @@ -111,9 +104,9 @@ def mixed_params( "text": "test", "flag": True, "numbers": [1, 2, 3], - "config": {"host": "localhost", "port": 8080} + "config": {"host": "localhost", "port": 8080}, }, - "id": 1 + "id": 1, } response = jrpc_client.post("/rpc", json=payload) @@ -140,7 +133,7 @@ def increment(counter: int) -> int: "jsonrpc": "2.0", "method": "increment", "params": {"counter": i}, - "id": i + "id": i, } response = jrpc_client.post("/rpc", json=payload) @@ -151,7 +144,9 @@ def increment(counter: int) -> int: expected = list(range(1, 11)) assert results == expected - def test_method_with_special_characters_in_name(self, jrpc_app, jrpc_client, registry): + def test_method_with_special_characters_in_name( + self, jrpc_app, jrpc_client, registry + ): """Test method with special characters in name.""" @registry.register("method_with_special-chars.test") @@ -162,7 +157,7 @@ def special_method(value: str) -> str: "jsonrpc": "2.0", "method": "method_with_special-chars.test", "params": {"value": "test"}, - "id": 1 + "id": 1, } response = jrpc_client.post("/rpc", json=payload) @@ -172,12 +167,7 @@ def special_method(value: str) -> str: def test_empty_method_name(self, jrpc_app, jrpc_client): """Test empty method name error.""" - payload = { - "jsonrpc": "2.0", - "method": "", - "params": {}, - "id": 1 - } + payload = {"jsonrpc": "2.0", "method": "", "params": {}, "id": 1} response = jrpc_client.post("/rpc", json=payload) assert response.status_code == 200 @@ -198,7 +188,7 @@ def method_with_spaces() -> str: "jsonrpc": "2.0", "method": "method_with_spaces", "params": {}, - "id": 1 + "id": 1, } response = jrpc_client.post("/rpc", json=payload) @@ -214,12 +204,7 @@ def test_very_long_method_name(self, jrpc_app, jrpc_client, registry): def very_long_name_method() -> str: return "success" - payload = { - "jsonrpc": "2.0", - "method": long_name, - "params": {}, - "id": 1 - } + payload = {"jsonrpc": "2.0", "method": long_name, "params": {}, "id": 1} response = jrpc_client.post("/rpc", json=payload) assert response.status_code == 200 @@ -236,7 +221,7 @@ def generate_large_data(size: int) -> list: "jsonrpc": "2.0", "method": "generate_large_data", "params": {"size": 1000}, - "id": 1 + "id": 1, } response = jrpc_client.post("/rpc", json=payload) diff --git a/tests/jrpc/test_exceptions.py b/tests/jrpc/test_exceptions.py index 87ed8f9..e696b6c 100644 --- a/tests/jrpc/test_exceptions.py +++ b/tests/jrpc/test_exceptions.py @@ -5,33 +5,26 @@ import json import pytest - from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.jrpc import JsonRpcPlugin, JsonRpcRegistry from nexios_contrib.jrpc.exceptions import ( + JsonRpcClientError, JsonRpcError, - JsonRpcMethodNotFound, JsonRpcInvalidParams, JsonRpcInvalidRequest, - JsonRpcClientError + JsonRpcMethodNotFound, ) class TestJRPCExceptions: """Tests for JRPC exception handling.""" - - - def test_invalid_request_missing_method(self, jrpc_app, jrpc_client): """Test invalid request - missing method.""" - payload = { - "jsonrpc": "2.0", - "params": {}, - "id": 1 - } + payload = {"jsonrpc": "2.0", "params": {}, "id": 1} response = jrpc_client.post("/rpc", json=payload) assert response.status_code == 200 @@ -44,12 +37,7 @@ def test_invalid_request_missing_method(self, jrpc_app, jrpc_client): def test_invalid_request_invalid_method_type(self, jrpc_app, jrpc_client): """Test invalid request - method is not a string.""" - payload = { - "jsonrpc": "2.0", - "method": 123, - "params": {}, - "id": 1 - } + payload = {"jsonrpc": "2.0", "method": 123, "params": {}, "id": 1} response = jrpc_client.post("/rpc", json=payload) assert response.status_code == 200 @@ -66,7 +54,7 @@ def test_invalid_request_invalid_params_type(self, jrpc_app, jrpc_client): "jsonrpc": "2.0", "method": "test", "params": "invalid_params", - "id": 1 + "id": 1, } response = jrpc_client.post("/rpc", json=payload) @@ -84,7 +72,7 @@ def test_invalid_request_invalid_id_type(self, jrpc_app, jrpc_client): "jsonrpc": "2.0", "method": "test", "params": {}, - "id": {"invalid": "id"} + "id": {"invalid": "id"}, } response = jrpc_client.post("/rpc", json=payload) @@ -95,8 +83,6 @@ def test_invalid_request_invalid_id_type(self, jrpc_app, jrpc_client): assert result["error"]["code"] == -32600 assert "Id must be a string or number" in result["error"]["message"] - - def test_method_exception_handling(self, jrpc_app, jrpc_client, registry): """Test handling of exceptions raised in methods.""" @@ -104,12 +90,7 @@ def test_method_exception_handling(self, jrpc_app, jrpc_client, registry): def failing_method() -> str: raise ValueError("Something went wrong") - payload = { - "jsonrpc": "2.0", - "method": "failing_method", - "params": {}, - "id": 1 - } + payload = {"jsonrpc": "2.0", "method": "failing_method", "params": {}, "id": 1} response = jrpc_client.post("/rpc", json=payload) assert response.status_code == 200 @@ -120,7 +101,6 @@ def failing_method() -> str: assert "Internal error" in result["error"]["message"] assert "Something went wrong" in result["error"]["data"] - def test_notification_request(self, jrpc_app, jrpc_client, registry): """Test notification request (no id).""" @@ -132,7 +112,7 @@ def log_message(message: str) -> str: payload = { "jsonrpc": "2.0", "method": "log_message", - "params": {"message": "test"} + "params": {"message": "test"}, } response = jrpc_client.post("/rpc", json=payload) @@ -145,10 +125,6 @@ def log_message(message: str) -> str: # But no id field for notifications assert result["id"] is None - - - - def test_json_rpc_error_class(self): """Test JsonRpcError class.""" @@ -159,7 +135,9 @@ def test_json_rpc_error_class(self): assert error.data is None # Test with data - error_with_data = JsonRpcError(code=-32001, message="Error with data", data={"extra": "info"}) + error_with_data = JsonRpcError( + code=-32001, message="Error with data", data={"extra": "info"} + ) assert error_with_data.data == {"extra": "info"} def test_json_rpc_method_not_found_exception(self): diff --git a/tests/jrpc/test_integration.py b/tests/jrpc/test_integration.py index 7003212..a19b09b 100644 --- a/tests/jrpc/test_integration.py +++ b/tests/jrpc/test_integration.py @@ -3,12 +3,12 @@ """ import json -from typing import Dict, Any +from typing import Any, Dict import pytest - from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.jrpc import JsonRpcPlugin, JsonRpcRegistry, get_registry @@ -27,7 +27,7 @@ def add(a: int, b: int) -> int: "jsonrpc": "2.0", "method": "add", "params": {"a": 5, "b": 3}, - "id": 1 + "id": 1, } response = jrpc_client.post("/rpc", json=payload) @@ -45,12 +45,7 @@ def test_method_call_with_positional_params(self, jrpc_app, jrpc_client, registr def multiply(a: int, b: int) -> int: return a * b - payload = { - "jsonrpc": "2.0", - "method": "multiply", - "params": [4, 7], - "id": 2 - } + payload = {"jsonrpc": "2.0", "method": "multiply", "params": [4, 7], "id": 2} response = jrpc_client.post("/rpc", json=payload) assert response.status_code == 200 @@ -70,7 +65,7 @@ async def async_add(a: int, b: int) -> int: "jsonrpc": "2.0", "method": "async_add", "params": {"a": 10, "b": 5}, - "id": 3 + "id": 3, } response = jrpc_client.post("/rpc", json=payload) @@ -96,7 +91,7 @@ def divide(a: float, b: float) -> float: "jsonrpc": "2.0", "method": "subtract", "params": {"a": 10, "b": 3}, - "id": 4 + "id": 4, } response1 = jrpc_client.post("/rpc", json=payload1) @@ -108,7 +103,7 @@ def divide(a: float, b: float) -> float: "jsonrpc": "2.0", "method": "divide", "params": {"a": 15.0, "b": 3.0}, - "id": 5 + "id": 5, } response2 = jrpc_client.post("/rpc", json=payload2) @@ -127,7 +122,7 @@ def greet(name: str, greeting: str = "Hello") -> str: "jsonrpc": "2.0", "method": "greet", "params": {"name": "Alice"}, - "id": 6 + "id": 6, } response1 = jrpc_client.post("/rpc", json=payload1) @@ -139,7 +134,7 @@ def greet(name: str, greeting: str = "Hello") -> str: "jsonrpc": "2.0", "method": "greet", "params": {"name": "Bob", "greeting": "Hi"}, - "id": 7 + "id": 7, } response2 = jrpc_client.post("/rpc", json=payload2) @@ -153,18 +148,12 @@ def test_method_without_parameters(self, jrpc_app, jrpc_client, registry): def get_time() -> str: return "2023-01-01" - payload = { - "jsonrpc": "2.0", - "method": "get_time", - "params": {}, - "id": 8 - } + payload = {"jsonrpc": "2.0", "method": "get_time", "params": {}, "id": 8} response = jrpc_client.post("/rpc", json=payload) assert response.status_code == 200 assert response.json()["result"] == "2023-01-01" - def test_custom_method_name(self, jrpc_app, jrpc_client, registry): """Test method registration with custom name.""" @@ -176,7 +165,7 @@ def add_numbers(a: int, b: int) -> int: "jsonrpc": "2.0", "method": "custom_sum", "params": {"a": 7, "b": 8}, - "id": 9 + "id": 9, } response = jrpc_client.post("/rpc", json=payload) @@ -207,41 +196,31 @@ def get_dict() -> dict: return {"key": "value", "number": 123} # Test string - response1 = jrpc_client.post("/rpc", json={ - "jsonrpc": "2.0", - "method": "get_string", - "id": 10 - }) + response1 = jrpc_client.post( + "/rpc", json={"jsonrpc": "2.0", "method": "get_string", "id": 10} + ) assert response1.json()["result"] == "test string" # Test number - response2 = jrpc_client.post("/rpc", json={ - "jsonrpc": "2.0", - "method": "get_number", - "id": 11 - }) + response2 = jrpc_client.post( + "/rpc", json={"jsonrpc": "2.0", "method": "get_number", "id": 11} + ) assert response2.json()["result"] == 42 # Test boolean - response3 = jrpc_client.post("/rpc", json={ - "jsonrpc": "2.0", - "method": "get_boolean", - "id": 12 - }) + response3 = jrpc_client.post( + "/rpc", json={"jsonrpc": "2.0", "method": "get_boolean", "id": 12} + ) assert response3.json()["result"] is True # Test list - response4 = jrpc_client.post("/rpc", json={ - "jsonrpc": "2.0", - "method": "get_list", - "id": 13 - }) + response4 = jrpc_client.post( + "/rpc", json={"jsonrpc": "2.0", "method": "get_list", "id": 13} + ) assert response4.json()["result"] == [1, 2, 3, 4, 5] # Test dict - response5 = jrpc_client.post("/rpc", json={ - "jsonrpc": "2.0", - "method": "get_dict", - "id": 14 - }) + response5 = jrpc_client.post( + "/rpc", json={"jsonrpc": "2.0", "method": "get_dict", "id": 14} + ) assert response5.json()["result"] == {"key": "value", "number": 123} diff --git a/tests/jrpc/test_registry.py b/tests/jrpc/test_registry.py index fc04198..a0d07fc 100644 --- a/tests/jrpc/test_registry.py +++ b/tests/jrpc/test_registry.py @@ -128,8 +128,6 @@ def second_method() -> str: assert registry.get_method("first_method") is first_method assert registry.get_method("second_method") is second_method - - def test_registry_isolation_between_tests(self, registry): """Test that registry is clean between tests.""" diff --git a/tests/request_id/test_configuration.py b/tests/request_id/test_configuration.py index 3ec0e09..03bf5b1 100644 --- a/tests/request_id/test_configuration.py +++ b/tests/request_id/test_configuration.py @@ -3,11 +3,12 @@ """ import uuid -import pytest +import pytest from nexios import NexiosApp from nexios.testclient import TestClient -from nexios_contrib.request_id import RequestIdMiddleware, RequestId + +from nexios_contrib.request_id import RequestId, RequestIdMiddleware from nexios_contrib.request_id.dependency import RequestIdDepend @@ -20,7 +21,7 @@ def test_request_id_dependency_injection(self, test_client_factory): app.add_middleware(RequestIdMiddleware()) @app.get("/test") - async def handler(request,response,request_id: str = RequestIdDepend()): + async def handler(request, response, request_id: str = RequestIdDepend()): return {"request_id": request_id} with test_client_factory(app) as client: @@ -45,7 +46,9 @@ def test_request_id_dependency_custom_attribute(self, test_client_factory): app.add_middleware(RequestIdMiddleware(request_attribute_name="custom_id")) @app.get("/test") - async def handler(request, response, request_id: str = RequestIdDepend("custom_id")): + async def handler( + request, response, request_id: str = RequestIdDepend("custom_id") + ): return {"request_id": request_id} with test_client_factory(app) as client: @@ -67,7 +70,7 @@ def test_request_id_dependency_without_middleware(self, test_client_factory): # Note: No middleware added @app.get("/test") - async def handler(request, response,request_id: str = RequestIdDepend()): + async def handler(request, response, request_id: str = RequestIdDepend()): return {"request_id": request_id} with test_client_factory(app) as client: @@ -85,11 +88,11 @@ def test_request_id_dependency_multiple_endpoints(self, test_client_factory): app.add_middleware(RequestIdMiddleware()) @app.get("/endpoint1") - async def handler1(request, response,request_id: str = RequestIdDepend()): + async def handler1(request, response, request_id: str = RequestIdDepend()): return {"endpoint": "1", "request_id": request_id} @app.get("/endpoint2") - async def handler2(request, response,request_id: str = RequestIdDepend()): + async def handler2(request, response, request_id: str = RequestIdDepend()): return {"endpoint": "2", "request_id": request_id} with test_client_factory(app) as client: @@ -117,8 +120,6 @@ async def handler2(request, response,request_id: str = RequestIdDepend()): class TestRequestIdConfiguration: """Test Request ID configuration options.""" - - def test_force_generate_configuration(self, test_client_factory): """Test force_generate configuration option.""" app = NexiosApp() @@ -218,11 +219,13 @@ async def handler(request, response): def test_convenience_function_configuration(self, test_client_factory): """Test RequestId convenience function with custom configuration.""" app = NexiosApp() - app.add_middleware(RequestId( - header_name="X-Custom-Request-ID", - force_generate=True, - request_attribute_name="custom_id" - )) + app.add_middleware( + RequestId( + header_name="X-Custom-Request-ID", + force_generate=True, + request_attribute_name="custom_id", + ) + ) @app.get("/test") async def handler(request, response): @@ -249,13 +252,15 @@ async def handler(request, response): def test_all_configuration_options_together(self, test_client_factory): """Test all configuration options working together.""" app = NexiosApp() - app.add_middleware(RequestIdMiddleware( - header_name="X-Trace-ID", - force_generate=True, - store_in_request=True, - request_attribute_name="trace_id", - include_in_response=True - )) + app.add_middleware( + RequestIdMiddleware( + header_name="X-Trace-ID", + force_generate=True, + store_in_request=True, + request_attribute_name="trace_id", + include_in_response=True, + ) + ) @app.get("/test") async def handler(request, response): diff --git a/tests/request_id/test_helper_functions.py b/tests/request_id/test_helper_functions.py index bbb48f7..55d31bf 100644 --- a/tests/request_id/test_helper_functions.py +++ b/tests/request_id/test_helper_functions.py @@ -3,17 +3,18 @@ """ import uuid -import pytest +import pytest from nexios.http import Request, Response + from nexios_contrib.request_id.helper import ( generate_request_id, + get_or_generate_request_id, get_request_id_from_header, + get_request_id_from_request, set_request_id_header, - get_or_generate_request_id, - validate_request_id, store_request_id_in_request, - get_request_id_from_request, + validate_request_id, ) @@ -36,6 +37,7 @@ def test_generate_request_id(self): def test_get_request_id_from_header_found(self): """Test getting request ID from header when present.""" + # Mock request object with headers class MockRequest: def __init__(self, headers): @@ -48,6 +50,7 @@ def __init__(self, headers): def test_get_request_id_from_header_not_found(self): """Test getting request ID from header when not present.""" + # Mock request object without the header class MockRequest: def __init__(self, headers): @@ -60,22 +63,22 @@ def __init__(self, headers): def test_get_request_id_from_header_custom_header(self): """Test getting request ID from custom header.""" + # Mock request object with custom header class MockRequest: def __init__(self, headers): self.headers = headers - request = MockRequest({"X-Custom-Request-ID": "550e8400-e29b-41d4-a716-446655440000"}) + request = MockRequest( + {"X-Custom-Request-ID": "550e8400-e29b-41d4-a716-446655440000"} + ) result = get_request_id_from_header(request, "X-Custom-Request-ID") assert result == "550e8400-e29b-41d4-a716-446655440000" - - - - def test_get_or_generate_request_id_from_header(self): """Test get_or_generate_request_id when header is present.""" + # Mock request object with header class MockRequest: def __init__(self, headers): @@ -88,6 +91,7 @@ def __init__(self, headers): def test_get_or_generate_request_id_generate_new(self): """Test get_or_generate_request_id when header is not present.""" + # Mock request object without header class MockRequest: def __init__(self, headers): @@ -121,6 +125,7 @@ def test_validate_request_id_none(self): def test_store_request_id_in_request(self): """Test storing request ID in request object.""" + # Mock request object with state class MockState: def __init__(self): @@ -142,6 +147,7 @@ def __init__(self): def test_store_request_id_in_request_custom_attribute(self): """Test storing request ID in request object with custom attribute name.""" + # Mock request object with state class MockState: def __init__(self): @@ -162,11 +168,6 @@ def __init__(self): assert request.state.data["custom_request_id"] == request_id assert "request_id" not in request.state.data - - - - - def test_uuid_generation_consistency(self): """Test that generated UUIDs are unique.""" uuids = [generate_request_id() for _ in range(100)] @@ -177,5 +178,3 @@ def test_uuid_generation_consistency(self): # All should be valid UUIDs for uid in uuids: uuid.UUID(uid) # Should not raise exception - - diff --git a/tests/request_id/test_integration.py b/tests/request_id/test_integration.py index 114c88a..2a87982 100644 --- a/tests/request_id/test_integration.py +++ b/tests/request_id/test_integration.py @@ -3,10 +3,10 @@ """ import pytest - from nexios import NexiosApp from nexios.testclient import TestClient -from nexios_contrib.request_id import RequestIdMiddleware, RequestId + +from nexios_contrib.request_id import RequestId, RequestIdMiddleware class TestRequestIdIntegration: @@ -33,6 +33,7 @@ async def get_users(request, response): # Validate UUID format import uuid + assert uuid.UUID(request_id) is not None def test_request_id_persistence_across_requests(self): @@ -57,7 +58,11 @@ async def get_data(request, response): assert "X-Request-ID" in resp3.headers # All request IDs should be unique - request_ids = [resp1.headers["X-Request-ID"], resp2.headers["X-Request-ID"], resp3.headers["X-Request-ID"]] + request_ids = [ + resp1.headers["X-Request-ID"], + resp2.headers["X-Request-ID"], + resp3.headers["X-Request-ID"], + ] assert len(set(request_ids)) == 3 def test_request_id_extraction_from_headers(self): @@ -140,7 +145,7 @@ async def delete_user(request, response): request_ids = [ resp1.headers["X-Request-ID"], resp2.headers["X-Request-ID"], - resp3.headers["X-Request-ID"] + resp3.headers["X-Request-ID"], ] assert len(set(request_ids)) == 3 @@ -191,6 +196,7 @@ async def get_data(request, response): # Should be a valid UUID import uuid + assert uuid.UUID(resp.headers["X-Request-ID"]) is not None def test_request_id_custom_header_name(self): @@ -213,6 +219,7 @@ async def test_endpoint(request, response): # Should be valid UUID import uuid + assert uuid.UUID(resp.headers["X-Custom-Request-ID"]) is not None def test_request_id_without_response_inclusion(self): @@ -252,4 +259,5 @@ async def test_endpoint(request, response): # Should be valid UUID import uuid + assert uuid.UUID(resp.headers["X-Request-ID"]) is not None diff --git a/tests/request_id/test_middleware.py b/tests/request_id/test_middleware.py index accfbf3..4574e91 100644 --- a/tests/request_id/test_middleware.py +++ b/tests/request_id/test_middleware.py @@ -2,11 +2,12 @@ Tests for RequestIdMiddleware. """ -import pytest import uuid +import pytest from nexios import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.request_id import RequestIdMiddleware @@ -94,7 +95,7 @@ async def handler(request, response): request_ids = [ resp1.headers["X-Request-ID"], resp2.headers["X-Request-ID"], - resp3.headers["X-Request-ID"] + resp3.headers["X-Request-ID"], ] # All request IDs should be unique @@ -168,7 +169,9 @@ async def handler(request, response): with test_client_factory(app) as client: custom_request_id = "550e8400-e29b-41d4-a716-446655440000" - resp = client.get("/test", headers={"X-Custom-Request-ID": custom_request_id}) + resp = client.get( + "/test", headers={"X-Custom-Request-ID": custom_request_id} + ) assert resp.status_code == 200 assert resp.headers["X-Custom-Request-ID"] == custom_request_id @@ -200,7 +203,9 @@ async def handler(request, response): def test_middleware_custom_attribute_name(self, test_client_factory): """Test middleware with custom request attribute name.""" app = NexiosApp() - app.add_middleware(RequestIdMiddleware(request_attribute_name="custom_request_id")) + app.add_middleware( + RequestIdMiddleware(request_attribute_name="custom_request_id") + ) @app.get("/test") async def handler(request, response): diff --git a/tests/slashes/test_middleware.py b/tests/slashes/test_middleware.py index 7a5823d..9769408 100644 --- a/tests/slashes/test_middleware.py +++ b/tests/slashes/test_middleware.py @@ -1,24 +1,27 @@ """ Integration tests for SlashesMiddleware. """ + import pytest from httpx import Response as HttpxResponse -from nexios_contrib.slashes.middleware import SlashesMiddleware, SlashAction +from nexios_contrib.slashes.middleware import SlashAction, SlashesMiddleware @pytest.mark.asyncio async def test_remove_trailing_slash_redirect(app_factory, test_client_factory): """Test that trailing slashes are removed with REDIRECT_REMOVE action.""" - app = app_factory(middleware=SlashesMiddleware(slash_action=SlashAction.REDIRECT_REMOVE)) - + app = app_factory( + middleware=SlashesMiddleware(slash_action=SlashAction.REDIRECT_REMOVE) + ) + @app.get("/test") - async def test_endpoint(request,response): + async def test_endpoint(request, response): return {"message": "success"} - - client = test_client_factory(app,follow_redirects=False) + + client = test_client_factory(app, follow_redirects=False) response = client.get("/test/") - + assert response.status_code == 301 assert response.headers["location"] == "http://testserver/test" @@ -26,15 +29,17 @@ async def test_endpoint(request,response): @pytest.mark.asyncio async def test_add_trailing_slash_redirect(app_factory, test_client_factory): """Test that missing trailing slashes are added with REDIRECT_ADD action.""" - app = app_factory(middleware=SlashesMiddleware(slash_action=SlashAction.REDIRECT_ADD)) - + app = app_factory( + middleware=SlashesMiddleware(slash_action=SlashAction.REDIRECT_ADD) + ) + @app.get("/test/") - async def test_endpoint(request,response): + async def test_endpoint(request, response): return {"message": "success"} - + client = test_client_factory(app) - response = client.get("/test", follow_redirects=False) - + response = client.get("/test", follow_redirects=False) + assert response.status_code == 301 assert response.headers["location"] == "http://testserver/test/" @@ -43,14 +48,14 @@ async def test_endpoint(request,response): async def test_remove_trailing_slash_inplace(app_factory, test_client_factory): """Test that trailing slashes are removed in-place with REMOVE action.""" app = app_factory(middleware=SlashesMiddleware(slash_action=SlashAction.REMOVE)) - + @app.get("/test") - async def test_endpoint(request,response): + async def test_endpoint(request, response): return {"message": "success"} - + client = test_client_factory(app) response = client.get("/test/") - + assert response.status_code == 200 assert response.json() == {"message": "success"} @@ -59,14 +64,14 @@ async def test_endpoint(request,response): async def test_add_trailing_slash_inplace(app_factory, test_client_factory): """Test that trailing slashes are added in-place with ADD action.""" app = app_factory(middleware=SlashesMiddleware(slash_action=SlashAction.ADD)) - + @app.get("/test/") - async def test_endpoint(request,response): + async def test_endpoint(request, response): return {"message": "success"} - + client = test_client_factory(app) response = client.get("/test") - + assert response.status_code == 200 assert response.json() == {"message": "success"} @@ -75,33 +80,33 @@ async def test_endpoint(request,response): async def test_double_slash_removal(app_factory, test_client_factory): """Test that double slashes are removed from URLs.""" app = app_factory(middleware=SlashesMiddleware(slash_action=SlashAction.IGNORE)) - + @app.get("/test/path") - async def test_endpoint(request,response): + async def test_endpoint(request, response): return {"message": "success"} - + client = test_client_factory(app) - response = client.get("/test//path") - + response = client.get("/test//path") + assert response.status_code == 200 assert response.json() == {"message": "success"} - - @pytest.mark.asyncio async def test_skip_processing_for_query_params(app_factory, test_client_factory): """Test that URLs with query parameters are handled correctly.""" - app = app_factory(middleware=SlashesMiddleware(slash_action=SlashAction.REDIRECT_REMOVE)) - + app = app_factory( + middleware=SlashesMiddleware(slash_action=SlashAction.REDIRECT_REMOVE) + ) + @app.get("/search") async def search_endpoint(request, response): q = request.query_params.get("q") return {"query": q} - + client = test_client_factory(app) - response = client.get("/search/?q=test", follow_redirects=False) - + response = client.get("/search/?q=test", follow_redirects=False) + # Should redirect to remove the trailing slash before the query assert response.status_code == 301 assert response.headers["location"] == "http://testserver/search?q=test" @@ -110,18 +115,20 @@ async def search_endpoint(request, response): @pytest.mark.asyncio async def test_custom_redirect_status_code(app_factory, test_client_factory): """Test that custom redirect status codes work.""" - app = app_factory(middleware=SlashesMiddleware( - slash_action=SlashAction.REDIRECT_REMOVE, - redirect_status_code=308 # Permanent Redirect - )) - + app = app_factory( + middleware=SlashesMiddleware( + slash_action=SlashAction.REDIRECT_REMOVE, + redirect_status_code=308, # Permanent Redirect + ) + ) + @app.get("/test") async def test_endpoint(request, response): return {"message": "success"} - + client = test_client_factory(app) response = client.get("/test/", follow_redirects=False) - + assert response.status_code == 308 assert response.headers["location"] == "http://testserver/test" @@ -130,17 +137,17 @@ async def test_endpoint(request, response): async def test_root_path_handling(app_factory, test_client_factory): """Test that root path is handled correctly.""" app = app_factory(middleware=SlashesMiddleware(slash_action=SlashAction.IGNORE)) - + @app.get("/") async def root_endpoint(request, response): return {"message": "root"} - + client = test_client_factory(app) response = client.get("/") - + assert response.status_code == 200 assert response.json() == {"message": "root"} - + # Root with trailing slash should also work response = client.get("//") assert response.status_code == 200 diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 341709d..318dcfc 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -1,45 +1,46 @@ +from typing import Optional + import pytest import strawberry from nexios.application import NexiosApp from nexios.testclient import TestClient + from nexios_contrib.graphql import GraphQL -from typing import Optional + @strawberry.type class Query: @strawberry.field def hello(self) -> str: return "Hello World" - + @strawberry.field def get_user_agent(self, info: strawberry.Info) -> str: """Get user agent from request context.""" request = info.context["request"] return request.headers.get("user-agent", "Unknown") - + @strawberry.field def get_request_method(self, info: strawberry.Info) -> str: """Get request method from context.""" request = info.context["request"] return request.method + schema = strawberry.Schema(query=Query) + def test_graphql_query(): app = NexiosApp() GraphQL(app, schema) client = TestClient(app) - response = client.post( - "/graphql", - json={ - "query": "{ hello }" - } - ) - + response = client.post("/graphql", json={"query": "{ hello }"}) + assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello World"}} + def test_graphql_context_user_agent(): """Test accessing request headers through context.""" app = NexiosApp() @@ -48,38 +49,33 @@ def test_graphql_context_user_agent(): response = client.post( "/graphql", - json={ - "query": "{ getUserAgent }" - }, - headers={"User-Agent": "TestAgent/1.0"} + json={"query": "{ getUserAgent }"}, + headers={"User-Agent": "TestAgent/1.0"}, ) - + assert response.status_code == 200 assert response.json() == {"data": {"getUserAgent": "TestAgent/1.0"}} + def test_graphql_context_request_method(): """Test accessing request method through context.""" app = NexiosApp() GraphQL(app, schema) client = TestClient(app) - response = client.post( - "/graphql", - json={ - "query": "{ getRequestMethod }" - } - ) - + response = client.post("/graphql", json={"query": "{ getRequestMethod }"}) + assert response.status_code == 200 assert response.json() == {"data": {"getRequestMethod": "POST"}} + def test_graphiql_html(): app = NexiosApp() GraphQL(app, schema, graphiql=True) client = TestClient(app) response = client.get("/graphql") - + assert response.status_code == 200 assert "" in response.text assert "GraphiQL" in response.text diff --git a/tests/test_redis/__init__.py b/tests/test_redis/__init__.py index 88f262a..15f5121 100644 --- a/tests/test_redis/__init__.py +++ b/tests/test_redis/__init__.py @@ -1,3 +1,3 @@ """ Integration tests for Nexios Redis contrib module. -""" \ No newline at end of file +""" diff --git a/tests/test_redis/conftest.py b/tests/test_redis/conftest.py index 078ecf2..9799e5c 100644 --- a/tests/test_redis/conftest.py +++ b/tests/test_redis/conftest.py @@ -1,16 +1,17 @@ """ Test configuration and fixtures for Redis integration tests. """ + import asyncio -import pytest -import pytest_asyncio from typing import AsyncGenerator, Generator from unittest.mock import AsyncMock, MagicMock +import pytest +import pytest_asyncio from nexios import NexiosApp from nexios.testclient import TestClient -from nexios_contrib.redis import RedisClient, RedisConfig, init_redis +from nexios_contrib.redis import RedisClient, RedisConfig, init_redis # @pytest.fixture(scope="session") # def event_loop(): @@ -28,7 +29,7 @@ def redis_config(): db=15, # Use a test database decode_responses=True, socket_timeout=5.0, - socket_connect_timeout=5.0 + socket_connect_timeout=5.0, ) @@ -36,7 +37,7 @@ def redis_config(): def mock_redis(): """Create a mock Redis client for testing without actual Redis connection.""" mock = AsyncMock() - + # Mock common Redis operations mock.ping.return_value = True mock.get.return_value = None @@ -49,30 +50,30 @@ def mock_redis(): mock.decr.return_value = 0 mock.keys.return_value = [] mock.flushdb.return_value = True - + # Hash operations mock.hget.return_value = None mock.hset.return_value = 1 mock.hgetall.return_value = {} - + # List operations mock.lpush.return_value = 1 mock.rpush.return_value = 1 mock.lpop.return_value = None mock.rpop.return_value = None mock.llen.return_value = 0 - + # Set operations mock.sadd.return_value = 1 mock.smembers.return_value = set() mock.srem.return_value = 1 mock.scard.return_value = 0 - + # JSON operations (mock for redis-py with JSON support) mock.json.return_value = mock - + mock.close.return_value = None - + return mock @@ -83,12 +84,30 @@ async def redis_client(redis_config, mock_redis): client._connected = True _mocked_methods = [ - 'ping', 'get', 'set', 'delete', 'exists', 'expire', 'ttl', - 'incr', 'decr', 'keys', 'flushdb', - 'hget', 'hset', 'hgetall', - 'lpush', 'rpush', 'lpop', 'rpop', 'llen', - 'sadd', 'smembers', 'srem', 'scard', - 'execute_command', + "ping", + "get", + "set", + "delete", + "exists", + "expire", + "ttl", + "incr", + "decr", + "keys", + "flushdb", + "hget", + "hset", + "hgetall", + "lpush", + "rpush", + "lpop", + "rpop", + "llen", + "sadd", + "smembers", + "srem", + "scard", + "execute_command", ] for method in _mocked_methods: setattr(client, method, getattr(mock_redis, method)) @@ -108,15 +127,12 @@ async def redis_client(redis_config, mock_redis): def app_with_redis(): """Create a Nexios app with Redis initialized.""" app = NexiosApp() - + # Initialize Redis with test configuration init_redis( - app, - url="redis://localhost:6379", - db=15, # Test database - decode_responses=True + app, url="redis://localhost:6379", db=15, decode_responses=True # Test database ) - + return app @@ -134,12 +150,30 @@ def app_with_mock_redis(mock_redis): client._connected = True _mocked_methods = [ - 'ping', 'get', 'set', 'delete', 'exists', 'expire', 'ttl', - 'incr', 'decr', 'keys', 'flushdb', - 'hget', 'hset', 'hgetall', - 'lpush', 'rpush', 'lpop', 'rpop', 'llen', - 'sadd', 'smembers', 'srem', 'scard', - 'execute_command', + "ping", + "get", + "set", + "delete", + "exists", + "expire", + "ttl", + "incr", + "decr", + "keys", + "flushdb", + "hget", + "hset", + "hgetall", + "lpush", + "rpush", + "lpop", + "rpop", + "llen", + "sadd", + "smembers", + "srem", + "scard", + "execute_command", ] for method in _mocked_methods: setattr(client, method, getattr(mock_redis, method)) @@ -153,6 +187,7 @@ def app_with_mock_redis(mock_redis): # Store in app state and global variable app.state["redis"] = client import nexios_contrib.redis + nexios_contrib.redis._redis_client = client yield app @@ -167,7 +202,6 @@ def test_client_with_redis(app_with_mock_redis): return TestClient(app_with_mock_redis) - @pytest.fixture def sample_data(): """Sample data for testing.""" @@ -176,17 +210,17 @@ def sample_data(): "id": "123", "name": "John Doe", "email": "john@example.com", - "age": 30 + "age": 30, }, "session": { "id": "session_abc123", "user_id": "123", - "expires_at": "2024-12-31T23:59:59Z" + "expires_at": "2024-12-31T23:59:59Z", }, "cache_data": { "key": "expensive_computation", - "value": {"result": 42, "computed_at": "2024-01-01T00:00:00Z"} - } + "value": {"result": 42, "computed_at": "2024-01-01T00:00:00Z"}, + }, } @@ -200,5 +234,5 @@ def redis_keys(): "cache": "cache:expensive_computation", "list": "messages:inbox", "set": "tags:article:123", - "hash": "profile:123" - } \ No newline at end of file + "hash": "profile:123", + } diff --git a/tests/test_redis/run_redis_tests.py b/tests/test_redis/run_redis_tests.py index 3302dd5..10fb194 100644 --- a/tests/test_redis/run_redis_tests.py +++ b/tests/test_redis/run_redis_tests.py @@ -7,9 +7,10 @@ 2. Integration tests with real Redis server 3. All tests including performance benchmarks """ -import sys -import subprocess + import argparse +import subprocess +import sys from pathlib import Path @@ -17,7 +18,8 @@ def check_redis_available(): """Check if Redis server is available.""" try: import redis - r = redis.Redis(host='localhost', port=6379, decode_responses=True) + + r = redis.Redis(host="localhost", port=6379, decode_responses=True) r.ping() return True except (ImportError, redis.ConnectionError, redis.TimeoutError): @@ -28,13 +30,16 @@ def run_unit_tests(): """Run unit tests with mocked Redis.""" print("Running Redis unit tests with mocked Redis...") cmd = [ - sys.executable, "-m", "pytest", + sys.executable, + "-m", + "pytest", "tests/test_redis/", "-v", "--tb=short", - "-m", "not integration", + "-m", + "not integration", "--cov=nexios_contrib.redis", - "--cov-report=term-missing" + "--cov-report=term-missing", ] return subprocess.run(cmd).returncode @@ -45,14 +50,17 @@ def run_integration_tests(): print("❌ Redis server not available. Please start Redis server first.") print(" docker run -d -p 6379:6379 redis:latest") return 1 - + print("Running Redis integration tests with real Redis server...") cmd = [ - sys.executable, "-m", "pytest", + sys.executable, + "-m", + "pytest", "tests/test_redis/test_redis_real_integration.py", "-v", "--tb=short", - "-m", "integration" + "-m", + "integration", ] return subprocess.run(cmd).returncode @@ -60,18 +68,18 @@ def run_integration_tests(): def run_all_tests(): """Run all Redis tests.""" print("Running all Redis tests...") - + # Run unit tests first unit_result = run_unit_tests() if unit_result != 0: print("❌ Unit tests failed") return unit_result - + print("✅ Unit tests passed") - + # Run integration tests if Redis is available if check_redis_available(): - print("\n" + "="*50) + print("\n" + "=" * 50) integration_result = run_integration_tests() if integration_result != 0: print("❌ Integration tests failed") @@ -79,7 +87,7 @@ def run_all_tests(): print("✅ Integration tests passed") else: print("⚠️ Skipping integration tests (Redis not available)") - + return 0 @@ -87,10 +95,12 @@ def run_specific_test(test_file): """Run a specific test file.""" print(f"Running specific test: {test_file}") cmd = [ - sys.executable, "-m", "pytest", + sys.executable, + "-m", + "pytest", f"tests/test_redis/{test_file}", "-v", - "--tb=short" + "--tb=short", ] return subprocess.run(cmd).returncode @@ -102,20 +112,17 @@ def main(): "--mode", choices=["unit", "integration", "all"], default="unit", - help="Test mode to run (default: unit)" + help="Test mode to run (default: unit)", ) parser.add_argument( - "--file", - help="Run specific test file (e.g., test_redis_client.py)" + "--file", help="Run specific test file (e.g., test_redis_client.py)" ) parser.add_argument( - "--setup-redis", - action="store_true", - help="Show Redis setup instructions" + "--setup-redis", action="store_true", help="Show Redis setup instructions" ) - + args = parser.parse_args() - + if args.setup_redis: print("Redis Setup Instructions:") print("========================") @@ -135,10 +142,10 @@ def main(): print(" redis-cli ping") print(" # Should return: PONG") return 0 - + if args.file: return run_specific_test(args.file) - + if args.mode == "unit": return run_unit_tests() elif args.mode == "integration": @@ -148,4 +155,4 @@ def main(): if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/tests/test_redis/test_redis_client.py b/tests/test_redis/test_redis_client.py index 6f3366d..1648065 100644 --- a/tests/test_redis/test_redis_client.py +++ b/tests/test_redis/test_redis_client.py @@ -1,10 +1,12 @@ """ Integration tests for RedisClient class. """ -import pytest + import json from unittest.mock import AsyncMock +import pytest + from nexios_contrib.redis.client import RedisClient, RedisOperationError from nexios_contrib.redis.config import RedisConfig @@ -180,21 +182,27 @@ async def test_execute_error_handling(self, redis_client, mock_redis): """Test execute error raises RedisOperationError.""" mock_redis.execute_command.side_effect = Exception("command failed") - with pytest.raises(RedisOperationError, match="Failed to execute Redis command"): + with pytest.raises( + RedisOperationError, match="Failed to execute Redis command" + ): await redis_client.execute("INVALID") async def test_json_get_error_handling(self, redis_client, mock_redis): """Test json_get error raises RedisOperationError.""" mock_redis.get.side_effect = Exception("parse error") - with pytest.raises(RedisOperationError, match="Failed to get JSON from key 'bad_key'"): + with pytest.raises( + RedisOperationError, match="Failed to get JSON from key 'bad_key'" + ): await redis_client.json_get("bad_key") async def test_json_set_error_handling(self, redis_client, mock_redis): """Test json_set error raises RedisOperationError.""" mock_redis.set.side_effect = Exception("set failed") - with pytest.raises(RedisOperationError, match="Failed to set JSON for key 'bad_key'"): + with pytest.raises( + RedisOperationError, match="Failed to set JSON for key 'bad_key'" + ): await redis_client.json_set("bad_key", ".", {"a": 1}) def test_client_repr(self, redis_client, redis_config): diff --git a/tests/test_redis/test_redis_dependencies.py b/tests/test_redis/test_redis_dependencies.py index 36d9092..dd98811 100644 --- a/tests/test_redis/test_redis_dependencies.py +++ b/tests/test_redis/test_redis_dependencies.py @@ -1,17 +1,16 @@ """ Integration tests for Redis dependency injection. """ -import pytest + from unittest.mock import AsyncMock -from nexios import NexiosApp, Depend +import pytest +from nexios import Depend, NexiosApp +from nexios.dependencies import Context from nexios.http import Request, Response from nexios.testclient import TestClient -from nexios.dependencies import Context -from nexios_contrib.redis.dependency import ( - RedisDepend -) +from nexios_contrib.redis.dependency import RedisDepend class TestRedisDependencies: @@ -20,14 +19,14 @@ class TestRedisDependencies: def test_redis_depend_basic(self, test_client_with_redis, mock_redis): """Test basic RedisDepend functionality.""" app = test_client_with_redis.app - + @app.get("/test") - async def test_endpoint(request,response,redis=RedisDepend()): + async def test_endpoint(request, response, redis=RedisDepend()): # Redis client should be injected assert redis is not None - assert hasattr(redis, 'get') + assert hasattr(redis, "get") return {"status": "ok"} - + response = test_client_with_redis.get("/test") assert response.status_code == 200 assert response.json() == {"status": "ok"} @@ -35,32 +34,32 @@ async def test_endpoint(request,response,redis=RedisDepend()): def test_redis_depend_in_route(self, test_client_with_redis, mock_redis): """Test RedisDepend in actual route usage.""" app = test_client_with_redis.app - + @app.get("/cache/{key}") - async def get_value(request: Request,response,key, redis=RedisDepend()): + async def get_value(request: Request, response, key, redis=RedisDepend()): value = await redis.get(key) return {"key": key, "value": value} - + mock_redis.get.return_value = "test_value" - + response = test_client_with_redis.get("/cache/mykey") assert response.status_code == 200 assert response.json() == {"key": "mykey", "value": "test_value"} - mock_redis.get.assert_called_with("mykey") - + mock_redis.get.assert_called_with("mykey") def test_redis_depend_with_context(self, app_with_mock_redis, mock_redis): """Test Redis dependencies with explicit context.""" - from nexios_contrib.redis.dependency import RedisDepend from nexios.dependencies import Context - + + from nexios_contrib.redis.dependency import RedisDepend + # Create a dependency function redis_dep = RedisDepend() - + # Create a mock context # Call the dependency function redis_client = redis_dep.dependency() - + assert redis_client is not None - assert hasattr(redis_client, 'get') - assert hasattr(redis_client, 'set') \ No newline at end of file + assert hasattr(redis_client, "get") + assert hasattr(redis_client, "set") diff --git a/tests/test_redis/test_redis_integration.py b/tests/test_redis/test_redis_integration.py index 91ef480..9172e7f 100644 --- a/tests/test_redis/test_redis_integration.py +++ b/tests/test_redis/test_redis_integration.py @@ -1,15 +1,19 @@ """ Integration tests for Redis with Nexios application. """ -import pytest + from unittest.mock import AsyncMock, patch -from nexios import NexiosApp, Depend +import pytest +from nexios import Depend, NexiosApp from nexios.http import Request, Response from nexios.testclient import TestClient from nexios_contrib.redis import ( - init_redis, get_redis, get_redis_client, RedisConnectionError + RedisConnectionError, + get_redis, + get_redis_client, + init_redis, ) @@ -19,10 +23,10 @@ class TestRedisIntegration: def test_init_redis_basic(self): """Test basic Redis initialization.""" app = NexiosApp() - + # Initialize Redis init_redis(app) - + # Check that Redis client is stored in app state assert "redis" in app.state assert app.state["redis"] is not None @@ -30,7 +34,7 @@ def test_init_redis_basic(self): def test_init_redis_with_custom_config(self): """Test Redis initialization with custom configuration.""" app = NexiosApp() - + # Initialize Redis with custom settings init_redis( app, @@ -38,9 +42,9 @@ def test_init_redis_with_custom_config(self): db=2, password="test_password", decode_responses=False, - socket_timeout=15.0 + socket_timeout=15.0, ) - + # Check configuration redis_client = app.state["redis"] assert redis_client.config.url == "redis://localhost:6380" @@ -52,15 +56,15 @@ def test_init_redis_with_custom_config(self): async def test_get_redis_dependency(self, app_with_mock_redis): """Test get_redis dependency injection.""" redis_client = get_redis() - + assert redis_client is not None - assert hasattr(redis_client, 'get') - assert hasattr(redis_client, 'set') + assert hasattr(redis_client, "get") + assert hasattr(redis_client, "set") async def test_get_redis_client_alias(self, app_with_mock_redis): """Test get_redis_client alias.""" redis_client = get_redis_client() - + assert redis_client is not None assert redis_client == get_redis() @@ -68,61 +72,63 @@ def test_get_redis_not_initialized(self): """Test get_redis when Redis is not initialized.""" # Clear global Redis client import nexios_contrib.redis + original_client = nexios_contrib.redis._redis_client nexios_contrib.redis._redis_client = None - + try: - with pytest.raises(RedisConnectionError, match="Redis client not initialized"): + with pytest.raises( + RedisConnectionError, match="Redis client not initialized" + ): get_redis() finally: # Restore original client nexios_contrib.redis._redis_client = original_client - - def test_redis_list_operations_route(self, test_client_with_redis, mock_redis): """Test Redis list operations in routes.""" app = test_client_with_redis.app - + @app.post("/messages") - async def add_message(request: Request, response,redis=Depend(get_redis)): + async def add_message(request: Request, response, redis=Depend(get_redis)): data = await request.json message = data.get("message") - + length = await redis.lpush("messages", message) return {"message": message, "queue_length": length} - + @app.get("/messages/next") - async def get_next_message(request: Request,response, redis=Depend(get_redis)): + async def get_next_message(request: Request, response, redis=Depend(get_redis)): message = await redis.rpop("messages") if not message: return {"message": None} - + return {"message": message} - + @app.get("/messages/count") - async def get_message_count(request: Request, response, redis=Depend(get_redis)): + async def get_message_count( + request: Request, response, redis=Depend(get_redis) + ): count = await redis.llen("messages") return {"count": count} - + # Mock Redis responses mock_redis.lpush.return_value = 3 mock_redis.rpop.return_value = "Hello World" mock_redis.llen.return_value = 2 - + # Test add message response = test_client_with_redis.post( - "/messages", - json={"message": "Hello World"} + "/messages", json={"message": "Hello World"} ) assert response.status_code == 200 assert response.json() == {"message": "Hello World", "queue_length": 3} - + # Test get next message response = test_client_with_redis.get("/messages/next") assert response.status_code == 200 assert response.json() == {"message": "Hello World"} - + # Test get message count response = test_client_with_redis.get("/messages/count") assert response.status_code == 200 @@ -131,37 +137,43 @@ async def get_message_count(request: Request, response, redis=Depend(get_redis)) def test_redis_hash_operations_route(self, test_client_with_redis, mock_redis): """Test Redis hash operations in routes.""" app = test_client_with_redis.app - + @app.post("/user/{user_id}/profile") - async def update_profile(request: Request, response,user_id,redis=Depend(get_redis)): + async def update_profile( + request: Request, response, user_id, redis=Depend(get_redis) + ): data = await request.json - + profile_key = f"profile:{user_id}" - + for field, value in data.items(): await redis.hset(profile_key, field, str(value)) - + return {"user_id": user_id, "status": "updated"} - + @app.get("/user/{user_id}/profile") - async def get_profile(request: Request, response,user_id,redis=Depend(get_redis)): + async def get_profile( + request: Request, response, user_id, redis=Depend(get_redis) + ): profile_key = f"profile:{user_id}" - + profile = await redis.hgetall(profile_key) return {"user_id": user_id, "profile": profile} - + # Mock Redis responses mock_redis.hset.return_value = 1 - mock_redis.hgetall.return_value = {"name": "John Doe", "email": "john@example.com"} - + mock_redis.hgetall.return_value = { + "name": "John Doe", + "email": "john@example.com", + } + # Test update profile response = test_client_with_redis.post( - "/user/123/profile", - json={"name": "John Doe", "email": "john@example.com"} + "/user/123/profile", json={"name": "John Doe", "email": "john@example.com"} ) assert response.status_code == 200 assert response.json() == {"user_id": "123", "status": "updated"} - + # Test get profile response = test_client_with_redis.get("/user/123/profile") assert response.status_code == 200 @@ -170,48 +182,46 @@ async def get_profile(request: Request, response,user_id,redis=Depend(get_redis) assert data["profile"]["name"] == "John Doe" assert data["profile"]["email"] == "john@example.com" - - - @patch('nexios_contrib.redis.client.RedisClient.connect') + @patch("nexios_contrib.redis.client.RedisClient.connect") async def test_startup_connection_success(self, mock_connect): """Test successful Redis connection on app startup.""" app = NexiosApp() init_redis(app) - + # Mock successful connection mock_connect.return_value = None - + # Simulate app startup for handler in app.startup_handlers: await handler() - + mock_connect.assert_called_once() - @patch('nexios_contrib.redis.client.RedisClient.connect') + @patch("nexios_contrib.redis.client.RedisClient.connect") async def test_startup_connection_failure(self, mock_connect): """Test Redis connection failure on app startup.""" app = NexiosApp() init_redis(app) - + # Mock connection failure mock_connect.side_effect = Exception("Connection failed") - + # Simulate app startup - should raise RedisConnectionError with pytest.raises(RedisConnectionError, match="Failed to connect to Redis"): for handler in app.startup_handlers: await handler() - @patch('nexios_contrib.redis.client.RedisClient.close') + @patch("nexios_contrib.redis.client.RedisClient.close") async def test_shutdown_cleanup(self, mock_close): """Test Redis cleanup on app shutdown.""" app = NexiosApp() init_redis(app) - + # Mock successful close mock_close.return_value = None - + # Simulate app shutdown for handler in app.shutdown_handlers: await handler() - - mock_close.assert_called_once() \ No newline at end of file + + mock_close.assert_called_once() diff --git a/tests/test_redis/test_redis_real_integration.py b/tests/test_redis/test_redis_real_integration.py index aa1830e..0d21840 100644 --- a/tests/test_redis/test_redis_real_integration.py +++ b/tests/test_redis/test_redis_real_integration.py @@ -4,20 +4,24 @@ These tests require a running Redis server and are marked with @pytest.mark.integration so they can be skipped in CI/CD environments where Redis is not available. """ -import pytest + import asyncio import json from typing import Optional -from nexios import NexiosApp, Depend +import pytest +from nexios import Depend, NexiosApp from nexios.http import Request, Response from nexios.testclient import TestClient from nexios_contrib.redis import ( - init_redis, get_redis, RedisClient, RedisConfig, RedisConnectionError + RedisClient, + RedisConfig, + RedisConnectionError, + get_redis, + init_redis, ) - # Skip these tests if Redis is not available pytestmark = pytest.mark.integration @@ -27,7 +31,8 @@ def redis_available(): """Check if Redis server is available for testing.""" try: import redis - r = redis.Redis(host='localhost', port=6379, db=15, decode_responses=True) + + r = redis.Redis(host="localhost", port=6379, db=15, decode_responses=True) r.ping() return True except (ImportError, redis.ConnectionError, redis.TimeoutError): @@ -38,19 +43,17 @@ def redis_available(): async def real_redis_client(redis_available): """Create a real Redis client for testing.""" config = RedisConfig( - url="redis://localhost:6379", - db=15, # Use test database - decode_responses=True + url="redis://localhost:6379", db=15, decode_responses=True # Use test database ) - + client = RedisClient(config) await client.connect() - + # Clean up test database before tests await client.flushdb() - + yield client - + # Clean up after tests await client.flushdb() await client.close() @@ -60,14 +63,11 @@ async def real_redis_client(redis_available): def real_redis_app(redis_available): """Create a Nexios app with real Redis connection.""" app = NexiosApp() - + init_redis( - app, - url="redis://localhost:6379", - db=15, # Test database - decode_responses=True + app, url="redis://localhost:6379", db=15, decode_responses=True # Test database ) - + return app @@ -92,15 +92,15 @@ async def test_real_basic_operations(self, real_redis_client): await real_redis_client.set("test_key", "test_value") value = await real_redis_client.get("test_key") assert value == "test_value" - + # Test EXISTS exists = await real_redis_client.exists("test_key") assert exists == 1 - + # Test DELETE deleted = await real_redis_client.delete("test_key") assert deleted == 1 - + # Verify deletion value = await real_redis_client.get("test_key") assert value is None @@ -109,14 +109,14 @@ async def test_real_expiration(self, real_redis_client): """Test Redis expiration with real server.""" # Set key with expiration await real_redis_client.set("expire_key", "expire_value", ex=1) - + # Check TTL ttl = await real_redis_client.ttl("expire_key") assert 0 < ttl <= 1 - + # Wait for expiration await asyncio.sleep(1.1) - + # Key should be expired value = await real_redis_client.get("expire_key") assert value is None @@ -126,10 +126,10 @@ async def test_real_counter_operations(self, real_redis_client): # Test INCR value1 = await real_redis_client.incr("counter") assert value1 == 1 - + value2 = await real_redis_client.incr("counter", 5) assert value2 == 6 - + # Test DECR value3 = await real_redis_client.decr("counter", 2) assert value3 == 4 @@ -139,14 +139,14 @@ async def test_real_hash_operations(self, real_redis_client): # Test HSET result = await real_redis_client.hset("user:123", "name", "John") assert result == 1 - + result = await real_redis_client.hset("user:123", "email", "john@example.com") assert result == 1 - + # Test HGET name = await real_redis_client.hget("user:123", "name") assert name == "John" - + # Test HGETALL user_data = await real_redis_client.hgetall("user:123") assert user_data == {"name": "John", "email": "john@example.com"} @@ -156,21 +156,21 @@ async def test_real_list_operations(self, real_redis_client): # Test LPUSH and RPUSH length1 = await real_redis_client.lpush("messages", "msg1") assert length1 == 1 - + length2 = await real_redis_client.rpush("messages", "msg2") assert length2 == 2 - + # Test LLEN length = await real_redis_client.llen("messages") assert length == 2 - + # Test LPOP and RPOP left_msg = await real_redis_client.lpop("messages") assert left_msg == "msg1" - + right_msg = await real_redis_client.rpop("messages") assert right_msg == "msg2" - + # List should be empty length = await real_redis_client.llen("messages") assert length == 0 @@ -180,49 +180,45 @@ async def test_real_set_operations(self, real_redis_client): # Test SADD added = await real_redis_client.sadd("tags", "python", "redis", "cache") assert added == 3 - + # Test SCARD size = await real_redis_client.scard("tags") assert size == 3 - + # Test SMEMBERS members = await real_redis_client.smembers("tags") assert set(members) == {"python", "redis", "cache"} - + # Test SREM removed = await real_redis_client.srem("tags", "cache") assert removed == 1 - + # Verify removal size = await real_redis_client.scard("tags") assert size == 2 - - async def test_real_keys_operation(self, real_redis_client): """Test Redis KEYS operation with real server.""" # Set up test keys await real_redis_client.set("user:123", "data1") await real_redis_client.set("user:456", "data2") await real_redis_client.set("session:abc", "session_data") - + # Test KEYS with pattern user_keys = await real_redis_client.keys("user:*") assert set(user_keys) == {"user:123", "user:456"} - + all_keys = await real_redis_client.keys("*") assert len(all_keys) >= 3 - async def test_real_connection_error_handling(self): """Test connection error handling with invalid Redis config.""" config = RedisConfig( - url="redis://localhost:9999", # Invalid port - socket_connect_timeout=1.0 + url="redis://localhost:9999", socket_connect_timeout=1.0 # Invalid port ) - + client = RedisClient(config) - + with pytest.raises(ConnectionError): await client.connect() @@ -231,33 +227,33 @@ async def test_real_database_isolation(self, redis_available): # Create clients for different databases config1 = RedisConfig(url="redis://localhost:6379", db=14) config2 = RedisConfig(url="redis://localhost:6379", db=15) - + client1 = RedisClient(config1) client2 = RedisClient(config2) - + try: await client1.connect() await client2.connect() - + # Set value in database 14 await client1.set("isolation_test", "db14_value") - + # Check that it doesn't exist in database 15 value_db15 = await client2.get("isolation_test") assert value_db15 is None - + # Set different value in database 15 await client2.set("isolation_test", "db15_value") - + # Verify both databases have their own values value_db14 = await client1.get("isolation_test") value_db15 = await client2.get("isolation_test") - + assert value_db14 == "db14_value" assert value_db15 == "db15_value" - + finally: await client1.flushdb() await client2.flushdb() await client1.close() - await client2.close() \ No newline at end of file + await client2.close() diff --git a/tests/test_redis/test_redis_utils.py b/tests/test_redis/test_redis_utils.py index ea0e671..42a4197 100644 --- a/tests/test_redis/test_redis_utils.py +++ b/tests/test_redis/test_redis_utils.py @@ -1,289 +1,298 @@ """ Integration tests for Redis utility functions. """ -import pytest + import json from unittest.mock import patch +import pytest + from nexios_contrib.redis import utils class TestRedisUtils: """Test Redis utility functions.""" - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_get(self, mock_get_client, mock_redis): """Test redis_get utility function.""" mock_get_client.return_value = mock_redis mock_redis.get.return_value = "test_value" - + result = await utils.redis_get("test_key") - + assert result == "test_value" mock_redis.get.assert_called_with("test_key") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_set(self, mock_get_client, mock_redis): """Test redis_set utility function.""" mock_get_client.return_value = mock_redis mock_redis.set.return_value = True - + result = await utils.redis_set("test_key", "test_value", ex=300) - + assert result is True - mock_redis.set.assert_called_with("test_key", "test_value", ex=300, px=None, nx=False, xx=False) + mock_redis.set.assert_called_with( + "test_key", "test_value", ex=300, px=None, nx=False, xx=False + ) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_delete(self, mock_get_client, mock_redis): """Test redis_delete utility function.""" mock_get_client.return_value = mock_redis mock_redis.delete.return_value = 2 - + result = await utils.redis_delete("key1", "key2") - + assert result == 2 mock_redis.delete.assert_called_with("key1", "key2") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_exists(self, mock_get_client, mock_redis): """Test redis_exists utility function.""" mock_get_client.return_value = mock_redis mock_redis.exists.return_value = 1 - + result = await utils.redis_exists("test_key") - + assert result == 1 mock_redis.exists.assert_called_with("test_key") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_expire(self, mock_get_client, mock_redis): """Test redis_expire utility function.""" mock_get_client.return_value = mock_redis mock_redis.expire.return_value = True - + result = await utils.redis_expire("test_key", 300, nx=True) - + assert result is True - mock_redis.expire.assert_called_with("test_key", 300, nx=True, xx=False, gt=False, lt=False) + mock_redis.expire.assert_called_with( + "test_key", 300, nx=True, xx=False, gt=False, lt=False + ) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_ttl(self, mock_get_client, mock_redis): """Test redis_ttl utility function.""" mock_get_client.return_value = mock_redis mock_redis.ttl.return_value = 250 - + result = await utils.redis_ttl("test_key") - + assert result == 250 - mock_redis.ttl.assert_called_with("test_key") - @patch('nexios_contrib.redis.utils.get_redis_client') + mock_redis.ttl.assert_called_with("test_key") + + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_incr(self, mock_get_client, mock_redis): """Test redis_incr utility function.""" mock_get_client.return_value = mock_redis mock_redis.incr.return_value = 5 - + result = await utils.redis_incr("counter", 3) - + assert result == 5 mock_redis.incr.assert_called_with("counter", 3) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_decr(self, mock_get_client, mock_redis): """Test redis_decr utility function.""" mock_get_client.return_value = mock_redis mock_redis.decr.return_value = 2 - + result = await utils.redis_decr("counter", 1) - + assert result == 2 mock_redis.decr.assert_called_with("counter", 1) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_json_get(self, mock_get_client, mock_redis): """Test redis_json_get utility function.""" mock_get_client.return_value = mock_redis test_data = {"name": "John", "age": 30} mock_redis.json_get.return_value = test_data - + result = await utils.redis_json_get("user:123", ".") - + assert result == test_data mock_redis.json_get.assert_called_with("user:123", ".") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_json_set(self, mock_get_client, mock_redis): """Test redis_json_set utility function.""" mock_get_client.return_value = mock_redis mock_redis.json_set.return_value = True test_data = {"name": "John", "age": 30} - + result = await utils.redis_json_set("user:123", ".", test_data, nx=True) - + assert result is True - mock_redis.json_set.assert_called_with("user:123", ".", test_data, nx=True, xx=False) + mock_redis.json_set.assert_called_with( + "user:123", ".", test_data, nx=True, xx=False + ) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_hget(self, mock_get_client, mock_redis): """Test redis_hget utility function.""" mock_get_client.return_value = mock_redis mock_redis.hget.return_value = "John" - + result = await utils.redis_hget("user:123", "name") - + assert result == "John" mock_redis.hget.assert_called_with("user:123", "name") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_hset(self, mock_get_client, mock_redis): """Test redis_hset utility function.""" mock_get_client.return_value = mock_redis mock_redis.hset.return_value = 1 - + result = await utils.redis_hset("user:123", "name", "John") - + assert result == 1 mock_redis.hset.assert_called_with("user:123", "name", "John") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_hgetall(self, mock_get_client, mock_redis): """Test redis_hgetall utility function.""" mock_get_client.return_value = mock_redis test_hash = {"name": "John", "email": "john@example.com"} mock_redis.hgetall.return_value = test_hash - + result = await utils.redis_hgetall("user:123") - + assert result == test_hash mock_redis.hgetall.assert_called_with("user:123") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_lpush(self, mock_get_client, mock_redis): """Test redis_lpush utility function.""" mock_get_client.return_value = mock_redis mock_redis.lpush.return_value = 3 - + result = await utils.redis_lpush("messages", "msg1", "msg2") - + assert result == 3 mock_redis.lpush.assert_called_with("messages", "msg1", "msg2") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_rpush(self, mock_get_client, mock_redis): """Test redis_rpush utility function.""" mock_get_client.return_value = mock_redis mock_redis.rpush.return_value = 4 - + result = await utils.redis_rpush("messages", "msg3") - + assert result == 4 mock_redis.rpush.assert_called_with("messages", "msg3") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_lpop(self, mock_get_client, mock_redis): """Test redis_lpop utility function.""" mock_get_client.return_value = mock_redis mock_redis.lpop.return_value = "msg1" - + result = await utils.redis_lpop("messages") - + assert result == "msg1" mock_redis.lpop.assert_called_with("messages", None) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_rpop(self, mock_get_client, mock_redis): """Test redis_rpop utility function.""" mock_get_client.return_value = mock_redis mock_redis.rpop.return_value = "msg3" - + result = await utils.redis_rpop("messages", 2) - + assert result == "msg3" mock_redis.rpop.assert_called_with("messages", 2) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_llen(self, mock_get_client, mock_redis): """Test redis_llen utility function.""" mock_get_client.return_value = mock_redis mock_redis.llen.return_value = 5 - + result = await utils.redis_llen("messages") - + assert result == 5 mock_redis.llen.assert_called_with("messages") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_sadd(self, mock_get_client, mock_redis): """Test redis_sadd utility function.""" mock_get_client.return_value = mock_redis mock_redis.sadd.return_value = 2 - + result = await utils.redis_sadd("tags", "python", "redis") - + assert result == 2 mock_redis.sadd.assert_called_with("tags", "python", "redis") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_smembers(self, mock_get_client, mock_redis): """Test redis_smembers utility function.""" mock_get_client.return_value = mock_redis mock_redis.smembers.return_value = ["python", "redis"] - + result = await utils.redis_smembers("tags") - + assert result == ["python", "redis"] mock_redis.smembers.assert_called_with("tags") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_srem(self, mock_get_client, mock_redis): """Test redis_srem utility function.""" mock_get_client.return_value = mock_redis mock_redis.srem.return_value = 1 - + result = await utils.redis_srem("tags", "python") - + assert result == 1 mock_redis.srem.assert_called_with("tags", "python") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_scard(self, mock_get_client, mock_redis): """Test redis_scard utility function.""" mock_get_client.return_value = mock_redis mock_redis.scard.return_value = 3 - + result = await utils.redis_scard("tags") - + assert result == 3 mock_redis.scard.assert_called_with("tags") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_keys(self, mock_get_client, mock_redis): """Test redis_keys utility function.""" mock_get_client.return_value = mock_redis mock_redis.keys.return_value = ["user:123", "user:456"] - + result = await utils.redis_keys("user:*") - + assert result == ["user:123", "user:456"] mock_redis.keys.assert_called_with("user:*") - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_flushdb(self, mock_get_client, mock_redis): """Test redis_flushdb utility function.""" mock_get_client.return_value = mock_redis mock_redis.flushdb.return_value = True - + result = await utils.redis_flushdb(asynchronous=True) - + assert result is True mock_redis.flushdb.assert_called_with(True) - @patch('nexios_contrib.redis.utils.get_redis_client') + @patch("nexios_contrib.redis.utils.get_redis_client") async def test_redis_execute(self, mock_get_client, mock_redis): """Test redis_execute utility function.""" mock_get_client.return_value = mock_redis mock_redis.execute.return_value = "PONG" - + result = await utils.redis_execute("PING") - + assert result == "PONG" - mock_redis.execute.assert_called_with("PING") \ No newline at end of file + mock_redis.execute.assert_called_with("PING") diff --git a/tests/test_scheduler/test_config.py b/tests/test_scheduler/test_config.py index f9d347b..eb4172d 100644 --- a/tests/test_scheduler/test_config.py +++ b/tests/test_scheduler/test_config.py @@ -1,6 +1,7 @@ """ Tests for the scheduler configuration module. """ + import time from datetime import datetime, timezone diff --git a/tests/test_scheduler/test_manager.py b/tests/test_scheduler/test_manager.py index a6cea4b..6400c87 100644 --- a/tests/test_scheduler/test_manager.py +++ b/tests/test_scheduler/test_manager.py @@ -1,10 +1,10 @@ """ Tests for the SchedulerManager class. """ + import asyncio import pytest - from nexios import NexiosApp from nexios_contrib.scheduler.config import ( @@ -224,7 +224,9 @@ async def test_job_failure_logged(self): async def failing_task(): raise RuntimeError("job failed") - job = scheduler.add_job(failing_task, IntervalTrigger(seconds=1, start_now=True)) + job = scheduler.add_job( + failing_task, IntervalTrigger(seconds=1, start_now=True) + ) await scheduler.start() await asyncio.sleep(1.5) diff --git a/tests/test_scheduler/test_models.py b/tests/test_scheduler/test_models.py index d54e4ae..32bea7e 100644 --- a/tests/test_scheduler/test_models.py +++ b/tests/test_scheduler/test_models.py @@ -1,6 +1,7 @@ """ Tests for the scheduler models module. """ + import time import pytest diff --git a/tests/test_tasks/test_config.py b/tests/test_tasks/test_config.py index c34e3fc..66ba790 100644 --- a/tests/test_tasks/test_config.py +++ b/tests/test_tasks/test_config.py @@ -1,10 +1,13 @@ """ Tests for the task configuration in nexios_contrib.tasks.config. """ + import logging + import pytest -from nexios_contrib.tasks.config import TaskConfig, TaskStatus, DEFAULT_CONFIG +from nexios_contrib.tasks.config import DEFAULT_CONFIG, TaskConfig, TaskStatus + def test_task_status_enum(): """Test the TaskStatus enum values.""" @@ -14,16 +17,18 @@ def test_task_status_enum(): assert TaskStatus.FAILED == "FAILED" assert TaskStatus.CANCELLED == "CANCELLED" + def test_task_config_defaults(): """Test TaskConfig default values.""" config = TaskConfig() - + assert config.max_concurrent_tasks == 100 assert config.default_timeout is None assert config.task_result_ttl == 3600 # 1 hour assert config.enable_task_history is True assert config.log_level == logging.INFO + def test_task_config_custom_values(): """Test TaskConfig with custom values.""" config = TaskConfig( @@ -31,15 +36,16 @@ def test_task_config_custom_values(): default_timeout=30.0, task_result_ttl=1800, # 30 minutes enable_task_history=False, - log_level=logging.DEBUG + log_level=logging.DEBUG, ) - + assert config.max_concurrent_tasks == 50 assert config.default_timeout == 30.0 assert config.task_result_ttl == 1800 assert config.enable_task_history is False assert config.log_level == logging.DEBUG + def test_task_config_to_dict(): """Test converting TaskConfig to dictionary.""" config = TaskConfig( @@ -47,11 +53,11 @@ def test_task_config_to_dict(): default_timeout=60.0, task_result_ttl=7200, enable_task_history=True, - log_level=logging.WARNING + log_level=logging.WARNING, ) - + config_dict = config.to_dict() - + assert config_dict == { "max_concurrent_tasks": 10, "default_timeout": 60.0, @@ -60,6 +66,7 @@ def test_task_config_to_dict(): "log_level": logging.WARNING, } + def test_default_config(): """Test the default configuration constant.""" assert isinstance(DEFAULT_CONFIG, TaskConfig) @@ -68,4 +75,3 @@ def test_default_config(): assert DEFAULT_CONFIG.task_result_ttl == 3600 assert DEFAULT_CONFIG.enable_task_history is True assert DEFAULT_CONFIG.log_level == logging.INFO - diff --git a/tests/test_tasks/test_dependency.py b/tests/test_tasks/test_dependency.py index 16f68d9..df649b1 100644 --- a/tests/test_tasks/test_dependency.py +++ b/tests/test_tasks/test_dependency.py @@ -1,14 +1,21 @@ """ Tests for the task dependencies in nexios_contrib.tasks.dependency. """ + import asyncio -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest from nexios.http import Request -from nexios_contrib.tasks.dependency import TaskDepend, get_task_dependency, TaskDependency -from nexios_contrib.tasks.models import Task, TaskResult + from nexios_contrib.tasks.config import TaskStatus +from nexios_contrib.tasks.dependency import ( + TaskDepend, + TaskDependency, + get_task_dependency, +) +from nexios_contrib.tasks.models import Task, TaskResult + @pytest.fixture def mock_request(): @@ -17,6 +24,7 @@ def mock_request(): request.base_app.task_manager = MagicMock() return request + @pytest.fixture def task_depend(mock_request): """Create a TaskDepend instance with a mock request.""" @@ -29,56 +37,58 @@ async def test_task_depend_get_task(task_depend, mock_request): # Setup mock_task = MagicMock() mock_request.base_app.task_manager.get_task.return_value = mock_task - + # Test task_id = "test-task-id" result = await task_depend.get_task(task_id) - + # Verify assert result == mock_task mock_request.base_app.task_manager.get_task.assert_called_once_with(task_id) + @pytest.mark.asyncio async def test_task_depend_wait_for_task(task_depend, mock_request): """Test waiting for a task to complete.""" # Setup expected_result = "task result" - mock_request.base_app.task_manager.wait_for_task = AsyncMock(return_value=expected_result) - + mock_request.base_app.task_manager.wait_for_task = AsyncMock( + return_value=expected_result + ) + # Test task_id = "test-task-id" result = await task_depend.wait_for_task(task_id, timeout=5.0) - + # Verify assert result == expected_result mock_request.base_app.task_manager.wait_for_task.assert_called_once_with( - task_id, - timeout=5.0 + task_id, timeout=5.0 ) + @pytest.mark.asyncio async def test_task_depend_cancel_task(task_depend, mock_request): """Test cancelling a task.""" # Setup mock_request.base_app.task_manager.cancel_task = AsyncMock(return_value=True) - + # Test task_id = "test-task-id" result = await task_depend.cancel_task(task_id) - + # Verify assert result is True mock_request.base_app.task_manager.cancel_task.assert_called_once_with(task_id) - def test_task_dependency(): """Test the TaskDependency function.""" # Test result = TaskDependency() - + # Verify from nexios.dependencies import Depend + assert isinstance(result, Depend) assert result.dependency == get_task_dependency - diff --git a/tests/test_tasks/test_models.py b/tests/test_tasks/test_models.py index bbd24ca..086034e 100644 --- a/tests/test_tasks/test_models.py +++ b/tests/test_tasks/test_models.py @@ -1,50 +1,57 @@ """ Tests for the task models in nexios_contrib.tasks.models. """ + import asyncio import time from datetime import datetime -import pytest from unittest.mock import AsyncMock, MagicMock, patch -from nexios_contrib.tasks.models import Task, TaskResult, TaskError +import pytest + from nexios_contrib.tasks.config import TaskStatus +from nexios_contrib.tasks.models import Task, TaskError, TaskResult + @pytest.fixture async def sample_task(): """Create a sample task for testing.""" + async def task_func(x, y): return x + y - + return Task(task_func, 2, 3, name="test_task") + @pytest.mark.asyncio async def test_task_initialization(): """Test task initialization with different parameters.""" + async def task_func(): return "result" - + # Test with minimal parameters task = Task(task_func) assert task.name.startswith("task-") assert task.status == TaskStatus.PENDING assert task.args == () assert task.kwargs == {} - + # Test with all parameters task = Task(task_func, 1, 2, 3, name="test", x=10, y=20) assert task.name == "test" assert task.args == (1, 2, 3) assert task.kwargs == {"x": 10, "y": 20} + @pytest.mark.asyncio async def test_task_run(sample_task): """Test running a task and getting its result.""" assert sample_task.status == TaskStatus.PENDING - + # Run the task await sample_task.run() - + # Verify the task completed successfully assert sample_task.status == TaskStatus.COMPLETED assert sample_task.result is not None @@ -52,31 +59,34 @@ async def test_task_run(sample_task): assert sample_task.result.result == 5 # 2 + 3 = 5 assert sample_task.result.error is None + @pytest.mark.asyncio async def test_task_wait(sample_task): """Test waiting for a task to complete.""" # Start the task in the background asyncio.create_task(sample_task.run()) - + # Wait for the task to complete result = await sample_task.wait(timeout=1.0) - + # Verify the result assert result == 5 # 2 + 3 = 5 assert sample_task.status == TaskStatus.COMPLETED + @pytest.mark.asyncio async def test_task_error_handling(): """Test error handling in tasks.""" + async def failing_task(): raise ValueError("Test error") - + task = Task(failing_task, name="failing_task") - + # Run the task and expect an exception with pytest.raises(ValueError, match="Test error"): await task.run() - + # Verify the task failed with the correct error assert task.status == TaskStatus.FAILED assert task.result is not None @@ -85,23 +95,20 @@ async def failing_task(): assert "Test error" in str(task.result.error) - @pytest.mark.asyncio async def test_task_timeout(): """Test task timeout handling.""" + async def long_running_task(): await asyncio.sleep(10) return "too late" - + task = Task(long_running_task, name="timeout_test") - + # Wait for the task with a short timeout with pytest.raises(asyncio.TimeoutError): await task.wait(timeout=0.1) - - - - + def test_task_result_serialization(): """Test serialization of task results.""" @@ -112,9 +119,9 @@ def test_task_result_serialization(): status=TaskStatus.COMPLETED, error=None, created_at=1000.0, - completed_at=1001.5 + completed_at=1001.5, ) - + result_dict = result.to_dict() assert result_dict == { "task_id": "123", @@ -123,24 +130,24 @@ def test_task_result_serialization(): "error": None, "created_at": 1000.0, "completed_at": 1001.5, - "duration": 1.5 + "duration": 1.5, } - + # Test error result try: raise ValueError("Test error") except ValueError as e: error = e - + error_result = TaskResult( task_id="456", result=None, status=TaskStatus.FAILED, error=error, created_at=2000.0, - completed_at=2000.5 + completed_at=2000.5, ) - + error_dict = error_result.to_dict() assert error_dict["task_id"] == "456" assert error_dict["status"] == "FAILED" diff --git a/tests/test_tasks/test_task_manager.py b/tests/test_tasks/test_task_manager.py index 828d358..d2a9202 100644 --- a/tests/test_tasks/test_task_manager.py +++ b/tests/test_tasks/test_task_manager.py @@ -1,26 +1,31 @@ """ Tests for the TaskManager class in nexios_contrib.tasks.manager. """ + import asyncio from pickle import FALSE -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest from nexios import NexiosApp + +from nexios_contrib.tasks.config import TaskConfig, TaskStatus from nexios_contrib.tasks.manager import TaskManager -from nexios_contrib.tasks.models import Task, TaskResult, TaskError -from nexios_contrib.tasks.config import TaskStatus, TaskConfig +from nexios_contrib.tasks.models import Task, TaskError, TaskResult + @pytest.fixture def app(): """Create a test Nexios app.""" return NexiosApp() + @pytest.fixture def task_manager(app): """Create a task manager instance for testing.""" return TaskManager(app) + @pytest.fixture async def async_task_manager(task_manager): """Create and start a task manager for async tests.""" @@ -30,105 +35,116 @@ async def async_task_manager(task_manager): finally: await task_manager.shutdown() + @pytest.mark.asyncio async def test_task_creation(task_manager): """Test creating a new task.""" + async def sample_task(x, y): return x + y - + task = await task_manager.create_task(sample_task, 2, 3) - + assert task is not None assert task.status == TaskStatus.PENDING assert len(task_manager.tasks) == 1 assert task_manager.tasks[task.id] == task + @pytest.mark.asyncio async def test_task_execution(async_task_manager): """Test task execution and result retrieval.""" + async def sample_task(x, y): return x * y - + task = await async_task_manager.create_task(sample_task, 4, 5) result = await task.wait() - + assert result == 20 assert task.status == TaskStatus.COMPLETED assert task.result is not None assert task.result.status == TaskStatus.COMPLETED assert task.result.result == 20 + @pytest.mark.asyncio async def test_task_error_handling(async_task_manager): """Test error handling in tasks.""" + async def failing_task(): raise ValueError("Something went wrong") - + task = await async_task_manager.create_task(failing_task) - + with pytest.raises(ValueError): await task.wait() - + assert task.status == TaskStatus.FAILED assert task.result is not None assert task.result.status == TaskStatus.FAILED assert "Something went wrong" in str(task.result.error) + @pytest.mark.asyncio async def test_task_cancellation(async_task_manager): """Test task cancellation.""" + async def long_running_task(): try: await asyncio.sleep(10) return "completed" except asyncio.CancelledError: return "cancelled" - + task = await async_task_manager.create_task(long_running_task) await asyncio.sleep(0.1) # Let the task start - + canceled = await async_task_manager.cancel_task(task.id) assert canceled is True - + result = await task.wait() assert result == "cancelled" + @pytest.mark.asyncio async def test_task_callbacks(async_task_manager): """Test task completion callbacks.""" - called = False + called = False + async def sample_task(): return "success" + async def callback(task_id, result, error): nonlocal called called = True - + task = await async_task_manager.create_task(sample_task) async_task_manager.add_callback(task, callback) - + await task.wait() - + assert called - + @pytest.mark.asyncio async def test_task_manager_shutdown(async_task_manager): """Test task manager shutdown behavior.""" task_completed = asyncio.Event() - + async def long_running_task(): try: await asyncio.sleep(10) except asyncio.CancelledError: task_completed.set() raise - + task = await async_task_manager.create_task(long_running_task) await asyncio.sleep(0.1) # Let the task start - + # Shutdown should cancel running tasks await async_task_manager.shutdown() - + # Wait for the task to be cancelled await asyncio.wait_for(task_completed.wait(), timeout=1.0) assert task.status == TaskStatus.CANCELLED diff --git a/tests/tortoise/__init__.py b/tests/tortoise/__init__.py index cb68f6c..dce7260 100644 --- a/tests/tortoise/__init__.py +++ b/tests/tortoise/__init__.py @@ -1 +1 @@ -# Tortoise ORM tests for nexios-contrib \ No newline at end of file +# Tortoise ORM tests for nexios-contrib diff --git a/tests/tortoise/conftest.py b/tests/tortoise/conftest.py index 362e164..748e521 100644 --- a/tests/tortoise/conftest.py +++ b/tests/tortoise/conftest.py @@ -2,10 +2,11 @@ Test configuration and fixtures for Tortoise ORM tests. """ -import pytest import asyncio from typing import Generator +import pytest + @pytest.fixture(scope="session") def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: @@ -20,7 +21,8 @@ async def reset_tortoise_client(): """Reset the global Tortoise client before each test.""" # Reset the global client to ensure clean state import nexios_contrib.tortoise + nexios_contrib.tortoise._tortoise_client = None yield # Clean up after test - nexios_contrib.tortoise._tortoise_client = None \ No newline at end of file + nexios_contrib.tortoise._tortoise_client = None diff --git a/tests/tortoise/test_config.py b/tests/tortoise/test_config.py index 51ebf3a..aa1f362 100644 --- a/tests/tortoise/test_config.py +++ b/tests/tortoise/test_config.py @@ -3,6 +3,7 @@ """ import os + import pytest from pydantic import ValidationError @@ -17,9 +18,9 @@ def test_basic_config_creation(self): config = TortoiseConfig( db_url="sqlite://:memory:", modules={"models": ["app.models"]}, - generate_schemas=True + generate_schemas=True, ) - + assert config.db_url == "sqlite://:memory:" assert config.modules == {"models": ["app.models"]} assert config.generate_schemas is True @@ -29,7 +30,7 @@ def test_basic_config_creation(self): def test_default_values(self): """Test default configuration values.""" config = TortoiseConfig(db_url="sqlite://:memory:") - + assert config.modules == {"models": []} assert config.generate_schemas is False assert config.use_tz is False @@ -49,7 +50,7 @@ def test_db_url_validation(self): "asyncpg://user:pass@localhost:5432/db", "aiomysql://user:pass@localhost:3306/db", ] - + for url in valid_urls: config = TortoiseConfig(db_url=url) assert config.db_url == url @@ -61,7 +62,7 @@ def test_db_url_validation(self): "ftp://example.com", "redis://localhost:6379", ] - + for url in invalid_urls: with pytest.raises(ValidationError): TortoiseConfig(db_url=url) @@ -74,7 +75,7 @@ def test_modules_validation(self): {"app1": ["app1.models"], "app2": ["app2.models"]}, {"models": ["app.models", "app.user_models"]}, ] - + for modules in valid_modules: config = TortoiseConfig(db_url="sqlite://:memory:", modules=modules) assert config.modules == modules @@ -85,7 +86,7 @@ def test_modules_validation(self): {"models": "not_a_list"}, {"models": [123, 456]}, # Non-string modules ] - + for modules in invalid_modules: with pytest.raises(ValidationError): TortoiseConfig(db_url="sqlite://:memory:", modules=modules) @@ -97,15 +98,15 @@ def test_to_tortoise_config(self): modules={"models": ["app.models"]}, generate_schemas=True, use_tz=True, - timezone="America/New_York" + timezone="America/New_York", ) - + tortoise_config = config.to_tortoise_config() - + expected_keys = ["db_url", "modules", "use_tz", "timezone"] for key in expected_keys: assert key in tortoise_config - + assert tortoise_config["db_url"] == "sqlite://:memory:" assert tortoise_config["modules"] == {"models": ["app.models"]} assert tortoise_config["use_tz"] is True @@ -113,51 +114,32 @@ def test_to_tortoise_config(self): def test_to_tortoise_config_with_connections_and_apps(self): """Test conversion with custom connections and apps.""" - connections = { - "default": "sqlite://:memory:", - "cache": "sqlite://cache.db" - } - apps = { - "models": { - "models": ["app.models"], - "default_connection": "default" - } - } - + connections = {"default": "sqlite://:memory:", "cache": "sqlite://cache.db"} + apps = {"models": {"models": ["app.models"], "default_connection": "default"}} + config = TortoiseConfig( - db_url="sqlite://:memory:", - connections=connections, - apps=apps + db_url="sqlite://:memory:", connections=connections, apps=apps ) - + tortoise_config = config.to_tortoise_config() - + assert tortoise_config["connections"] == connections assert tortoise_config["apps"] == apps - - - - - - def test_str_representation(self): """Test string representation of config.""" config = TortoiseConfig( - db_url="sqlite://:memory:", - modules={"models": ["app.models"]} + db_url="sqlite://:memory:", modules={"models": ["app.models"]} ) - + config_str = str(config) assert "TortoiseConfig" in config_str assert "sqlite://:memory:" in config_str def test_str_representation_with_password(self): """Test string representation masks password in URL.""" - config = TortoiseConfig( - db_url="postgres://user:secret@localhost:5432/db" - ) - + config = TortoiseConfig(db_url="postgres://user:secret@localhost:5432/db") + config_str = str(config) assert "secret" not in config_str assert "***" in config_str @@ -168,20 +150,14 @@ def test_complex_configuration(self): connections = { "default": "sqlite://:memory:", "users": "postgres://user:pass@localhost:5432/users", - "analytics": "mysql://user:pass@localhost:3306/analytics" + "analytics": "mysql://user:pass@localhost:3306/analytics", } - + apps = { - "models": { - "models": ["app.models"], - "default_connection": "default" - }, - "users": { - "models": ["users.models"], - "default_connection": "users" - } + "models": {"models": ["app.models"], "default_connection": "default"}, + "users": {"models": ["users.models"], "default_connection": "users"}, } - + config = TortoiseConfig( db_url="sqlite://:memory:", modules={"models": ["app.models", "users.models"]}, @@ -189,9 +165,9 @@ def test_complex_configuration(self): use_tz=True, timezone="UTC", connections=connections, - apps=apps + apps=apps, ) - + assert config.db_url == "sqlite://:memory:" assert config.modules == {"models": ["app.models", "users.models"]} assert config.generate_schemas is True @@ -199,8 +175,8 @@ def test_complex_configuration(self): assert config.timezone == "UTC" assert config.connections == connections assert config.apps == apps - + # Test conversion tortoise_config = config.to_tortoise_config() assert tortoise_config["connections"] == connections - assert tortoise_config["apps"] == apps \ No newline at end of file + assert tortoise_config["apps"] == apps diff --git a/tests/tortoise/test_init.py b/tests/tortoise/test_init.py index 28d7135..ee6206c 100644 --- a/tests/tortoise/test_init.py +++ b/tests/tortoise/test_init.py @@ -2,14 +2,15 @@ Tests for Tortoise ORM initialization and main module functions. """ -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest from nexios import NexiosApp + from nexios_contrib.tortoise import ( - init_tortoise, - get_tortoise_client, TortoiseConnectionError, + get_tortoise_client, + init_tortoise, ) from nexios_contrib.tortoise.config import TortoiseConfig @@ -20,18 +21,18 @@ class TestTortoiseInit: def test_init_tortoise_basic(self): """Test basic Tortoise initialization.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client - + init_tortoise( app, db_url="sqlite://:memory:", modules={"models": ["app.models"]}, - generate_schemas=True + generate_schemas=True, ) - + # Verify client was created with correct config mock_client_class.assert_called_once() config_arg = mock_client_class.call_args[0][0] @@ -39,10 +40,10 @@ def test_init_tortoise_basic(self): assert config_arg.db_url == "sqlite://:memory:" assert config_arg.modules == {"models": ["app.models"]} assert config_arg.generate_schemas is True - + # Verify client was stored in app state assert app.state["tortoise"] == mock_client - + # Verify startup and shutdown handlers were registered assert len(app.startup_handlers) == 1 assert len(app.shutdown_handlers) == 1 @@ -50,10 +51,10 @@ def test_init_tortoise_basic(self): def test_init_tortoise_with_defaults(self): """Test Tortoise initialization with default values.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: init_tortoise(app, db_url="sqlite://:memory:") - + config_arg = mock_client_class.call_args[0][0] assert config_arg.modules == {"models": []} assert config_arg.generate_schemas is False @@ -61,37 +62,37 @@ def test_init_tortoise_with_defaults(self): def test_init_tortoise_without_exception_handlers(self): """Test initialization without exception handlers.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient"): - with patch("nexios_contrib.tortoise.handle_tortoise_exceptions") as mock_handle: + with patch( + "nexios_contrib.tortoise.handle_tortoise_exceptions" + ) as mock_handle: init_tortoise( - app, - db_url="sqlite://:memory:", - add_exception_handlers=False + app, db_url="sqlite://:memory:", add_exception_handlers=False ) - + # Exception handlers should not be added mock_handle.assert_not_called() def test_init_tortoise_with_exception_handlers(self): """Test initialization with exception handlers.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient"): - with patch("nexios_contrib.tortoise.handle_tortoise_exceptions") as mock_handle: + with patch( + "nexios_contrib.tortoise.handle_tortoise_exceptions" + ) as mock_handle: init_tortoise( - app, - db_url="sqlite://:memory:", - add_exception_handlers=True + app, db_url="sqlite://:memory:", add_exception_handlers=True ) - + # Exception handlers should be added mock_handle.assert_called_once_with(app) def test_init_tortoise_with_kwargs(self): """Test initialization with additional kwargs.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: init_tortoise( app, @@ -99,9 +100,9 @@ def test_init_tortoise_with_kwargs(self): modules={"models": ["app.models"]}, use_tz=True, timezone="America/New_York", - custom_param="custom_value" + custom_param="custom_value", ) - + config_arg = mock_client_class.call_args[0][0] assert config_arg.use_tz is True assert config_arg.timezone == "America/New_York" @@ -110,17 +111,17 @@ def test_init_tortoise_with_kwargs(self): async def test_startup_handler_success(self): """Test successful startup handler execution.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client = MagicMock() mock_client.init = AsyncMock() mock_client_class.return_value = mock_client - + init_tortoise(app, db_url="sqlite://:memory:") - + # Execute startup handler await app._startup() - + # Verify client.init was called mock_client.init.assert_called_once() @@ -128,18 +129,18 @@ async def test_startup_handler_success(self): async def test_startup_handler_failure(self): """Test startup handler failure.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client = MagicMock() mock_client.init = AsyncMock(side_effect=Exception("Init failed")) mock_client_class.return_value = mock_client - + init_tortoise(app, db_url="sqlite://:memory:") - + # Startup should raise TortoiseConnectionError with pytest.raises(TortoiseConnectionError) as exc_info: await app._startup() - + assert "Failed to initialize Tortoise ORM" in str(exc_info.value) assert "Init failed" in str(exc_info.value) @@ -147,17 +148,17 @@ async def test_startup_handler_failure(self): async def test_shutdown_handler_success(self): """Test successful shutdown handler execution.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client = MagicMock() mock_client.close = AsyncMock() mock_client_class.return_value = mock_client - + init_tortoise(app, db_url="sqlite://:memory:") - + # Execute shutdown handler await app._shutdown() - + # Verify client.close was called mock_client.close.assert_called_once() @@ -165,30 +166,30 @@ async def test_shutdown_handler_success(self): async def test_shutdown_handler_failure(self): """Test shutdown handler with error (should not raise).""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client = MagicMock() mock_client.close = AsyncMock(side_effect=Exception("Close failed")) mock_client_class.return_value = mock_client - + init_tortoise(app, db_url="sqlite://:memory:") - + # Shutdown should not raise (errors are logged) await app._shutdown() - + # Verify client.close was called mock_client.close.assert_called_once() def test_get_tortoise_client_success(self): """Test successful client retrieval.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client - + init_tortoise(app, db_url="sqlite://:memory:") - + client = get_tortoise_client() assert client == mock_client @@ -196,28 +197,29 @@ def test_get_tortoise_client_not_initialized(self): """Test client retrieval when not initialized.""" # Reset global client import nexios_contrib.tortoise + nexios_contrib.tortoise._tortoise_client = None - + with pytest.raises(TortoiseConnectionError) as exc_info: get_tortoise_client() - + assert "Tortoise ORM client not initialized" in str(exc_info.value) assert "Call init_tortoise() first" in str(exc_info.value) def test_multiple_init_calls(self): """Test that multiple init calls work correctly.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client1 = MagicMock() mock_client2 = MagicMock() mock_client_class.side_effect = [mock_client1, mock_client2] - + # First init init_tortoise(app, db_url="sqlite://:memory:") client1 = get_tortoise_client() assert client1 == mock_client1 - + # Second init (should replace the first) init_tortoise(app, db_url="sqlite://test.db") client2 = get_tortoise_client() @@ -227,12 +229,12 @@ def test_multiple_init_calls(self): def test_init_tortoise_complex_config(self): """Test initialization with complex configuration.""" app = NexiosApp() - + modules = { "models": ["app.models", "app.user_models"], - "analytics": ["analytics.models"] + "analytics": ["analytics.models"], } - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: init_tortoise( app, @@ -241,9 +243,9 @@ def test_init_tortoise_complex_config(self): generate_schemas=False, add_exception_handlers=True, use_tz=True, - timezone="Europe/London" + timezone="Europe/London", ) - + config_arg = mock_client_class.call_args[0][0] assert config_arg.db_url == "postgres://user:pass@localhost:5432/db" assert config_arg.modules == modules @@ -254,13 +256,13 @@ def test_init_tortoise_complex_config(self): def test_app_state_storage(self): """Test that client is stored in app state.""" app = NexiosApp() - + with patch("nexios_contrib.tortoise.TortoiseClient") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client - + init_tortoise(app, db_url="sqlite://:memory:") - + # Verify client is in app state assert "tortoise" in app.state assert app.state["tortoise"] == mock_client @@ -268,13 +270,13 @@ def test_app_state_storage(self): def test_handler_registration_count(self): """Test that exactly one startup and shutdown handler are registered.""" app = NexiosApp() - + initial_startup_count = len(app.startup_handlers) initial_shutdown_count = len(app.shutdown_handlers) - + with patch("nexios_contrib.tortoise.TortoiseClient"): init_tortoise(app, db_url="sqlite://:memory:") - + # Should add exactly one of each handler assert len(app.startup_handlers) == initial_startup_count + 1 - assert len(app.shutdown_handlers) == initial_shutdown_count + 1 \ No newline at end of file + assert len(app.shutdown_handlers) == initial_shutdown_count + 1