2323import warnings
2424from datetime import datetime
2525from os import getenv
26- from typing import Any , Coroutine , Optional , TypeVar , Union
26+ from typing import Any , Coroutine , Optional , TypeVar , Union , cast
2727
2828import aiohttp
2929import certifi
8181 LAUNCH_NAME_LENGTH_LIMIT ,
8282 LifoQueue ,
8383 agent_name_version ,
84+ compare_semantic_versions ,
85+ extract_server_version ,
8486 root_uri_join ,
8587 uri_join ,
8688)
9496
9597DEFAULT_TASK_TIMEOUT : float = 60.0
9698DEFAULT_SHUTDOWN_TIMEOUT : float = 120.0
99+ MICROSECONDS_MIN_VERSION = "5.13.2"
97100
98101
99102class Client :
@@ -785,6 +788,9 @@ class AsyncRPClient(RP):
785788 __launch_uuid : Optional [str ]
786789 __step_reporter : StepReporter
787790 use_own_launch : bool
791+ _api_info_task : Optional [asyncio .Task [Optional [dict ]]]
792+ _api_info_cache : Optional [dict ]
793+ _use_microseconds : Optional [bool ]
788794
789795 @property
790796 def client (self ) -> Client :
@@ -888,8 +894,35 @@ def __init__(
888894 self .use_own_launch = False
889895 else :
890896 self .use_own_launch = True
897+ self ._api_info_task = None
898+ self ._api_info_cache = None
899+ self ._use_microseconds = None
891900 set_current (self )
892901
902+ def __cache_api_info (self , api_info : Optional [dict ]) -> Optional [dict ]:
903+ if not isinstance (api_info , dict ):
904+ return None
905+ self ._api_info_cache = api_info
906+ version = extract_server_version (api_info )
907+ self ._use_microseconds = bool (version and compare_semantic_versions (version , MICROSECONDS_MIN_VERSION ) >= 0 )
908+ return api_info
909+
910+ async def __prefetch_api_info (self ) -> Optional [dict ]:
911+ try :
912+ api_info = await self .__client .get_api_info ()
913+ return self .__cache_api_info (api_info )
914+ except Exception as exc :
915+ logger .warning ("Unable to prefetch API info in background: %s" , exc )
916+ return None
917+
918+ def __init_api_info_prefetch (self ) -> None :
919+ try :
920+ loop = asyncio .get_running_loop ()
921+ self ._api_info_task = loop .create_task (self .__prefetch_api_info ())
922+ except RuntimeError :
923+ # Construction may happen without an active loop.
924+ self ._api_info_task = None
925+
893926 async def start_launch (
894927 self ,
895928 name : str ,
@@ -1123,7 +1156,22 @@ async def get_api_info(self) -> Optional[dict]:
11231156
11241157 :return: server information.
11251158 """
1126- return await self .__client .get_api_info ()
1159+ if self ._api_info_cache is not None :
1160+ return self .__cache_api_info (self ._api_info_cache )
1161+ if self ._api_info_task :
1162+ return await self ._api_info_task
1163+ api_info = await self .__client .get_api_info ()
1164+ return self .__cache_api_info (api_info )
1165+
1166+ async def use_microseconds (self ) -> Optional [bool ]:
1167+ """Return if current server version supports microseconds."""
1168+ if self ._use_microseconds is not None :
1169+ return self ._use_microseconds
1170+
1171+ await self .get_api_info ()
1172+ if self ._use_microseconds is None :
1173+ self ._use_microseconds = False
1174+ return self ._use_microseconds
11271175
11281176 async def log (
11291177 self ,
@@ -1205,6 +1253,9 @@ class _RPClient(RP, metaclass=AbstractBaseClass):
12051253 __endpoint : str
12061254 __project : str
12071255 __step_reporter : StepReporter
1256+ _api_info_task : Optional [Task [Optional [dict ]]]
1257+ _api_info_cache : Optional [dict ]
1258+ _use_microseconds : Optional [bool ]
12081259
12091260 @property
12101261 def client (self ) -> Client :
@@ -1252,7 +1303,7 @@ def __init__(
12521303 project : str ,
12531304 * ,
12541305 client : Optional [Client ] = None ,
1255- launch_uuid : Optional [Task [Optional [ str ] ]] = None ,
1306+ launch_uuid : Optional [Task [str ]] = None ,
12561307 log_batch_size : int = 20 ,
12571308 log_batch_payload_limit : int = MAX_LOG_BATCH_PAYLOAD_SIZE ,
12581309 log_batcher : Optional [LogBatcher ] = None ,
@@ -1308,6 +1359,9 @@ def __init__(
13081359 else :
13091360 self .own_launch = True
13101361
1362+ self ._api_info_task = None
1363+ self ._api_info_cache = None
1364+ self ._use_microseconds = None
13111365 set_current (self )
13121366
13131367 @abstractmethod
@@ -1354,6 +1408,49 @@ async def __empty_dict(self) -> dict:
13541408 async def __int_value (self ) -> int :
13551409 return - 1
13561410
1411+ async def _return_value (self , value : _T ) -> _T :
1412+ return value
1413+
1414+ async def _prefetch_api_info (self ) -> Optional [dict ]:
1415+ try :
1416+ api_info = await self .__client .get_api_info ()
1417+ self .__cache_api_info (api_info )
1418+ return api_info
1419+ except Exception as exc :
1420+ logger .warning ("Unable to prefetch API info in background: %s" , exc )
1421+ return None
1422+
1423+ def __cache_api_info (self , api_info : Optional [dict ]) -> None :
1424+ if not isinstance (api_info , dict ):
1425+ return
1426+ self ._api_info_cache = api_info
1427+ version = extract_server_version (api_info )
1428+ self ._use_microseconds = bool (version and compare_semantic_versions (version , MICROSECONDS_MIN_VERSION ) >= 0 )
1429+
1430+ async def __resolve_use_microseconds (self ) -> bool :
1431+ if self ._use_microseconds is not None :
1432+ return self ._use_microseconds
1433+
1434+ if self ._api_info_task :
1435+ try :
1436+ api_info = await self ._api_info_task
1437+ self .__cache_api_info (api_info )
1438+ except Exception as exc :
1439+ logger .warning ("Unable to await API info prefetch: %s" , exc )
1440+
1441+ if self ._use_microseconds is not None :
1442+ return self ._use_microseconds or False
1443+
1444+ if self ._api_info_cache is None :
1445+ try :
1446+ self .__cache_api_info (await self .__client .get_api_info ())
1447+ except Exception as exc :
1448+ logger .warning ("Unable to fetch API info for microseconds check: %s" , exc )
1449+
1450+ if self ._use_microseconds is None :
1451+ self ._use_microseconds = False
1452+ return self ._use_microseconds or False
1453+
13571454 def start_launch (
13581455 self ,
13591456 name : str ,
@@ -1586,9 +1683,19 @@ def get_api_info(self) -> Task[Optional[dict]]:
15861683
15871684 :return: server information.
15881685 """
1589- result_coro = self .__client .get_api_info ()
1590- result_task = self .create_task (result_coro )
1591- return result_task
1686+ if self ._api_info_cache is not None :
1687+ return self .create_task (self ._return_value (self ._api_info_cache ))
1688+ if self ._api_info_task :
1689+ return self ._api_info_task
1690+ api_task = self .create_task (self ._prefetch_api_info ())
1691+ self ._api_info_task = api_task
1692+ return api_task
1693+
1694+ def use_microseconds (self ) -> Task [bool ]:
1695+ """Return if current server version supports microseconds."""
1696+ if self ._use_microseconds is not None :
1697+ return self .create_task (self ._return_value (self ._use_microseconds ))
1698+ return self .create_task (self .__resolve_use_microseconds ())
15921699
15931700 async def _log_batch (self , log_rq : Optional [list [AsyncRPRequestLog ]]) -> Optional [tuple [str , ...]]:
15941701 return await self .__client .log_batch (log_rq )
@@ -1689,9 +1796,6 @@ def __init_loop(self, loop: Optional[asyncio.AbstractEventLoop] = None):
16891796 thread .start ()
16901797 self ._thread = thread
16911798
1692- async def __return_value (self , value ):
1693- return value
1694-
16951799 def __init__ (
16961800 self ,
16971801 endpoint : str ,
@@ -1753,19 +1857,29 @@ def __init__(
17531857 self .__init_task_list (task_list , task_mutex )
17541858 self .__init_loop (loop )
17551859 if type (launch_uuid ) is str :
1860+ my_launch_uuid = str (launch_uuid )
17561861 super ().__init__ (
1757- endpoint , project , launch_uuid = self .create_task (self .__return_value ( launch_uuid )), ** kwargs
1862+ endpoint , project , launch_uuid = self .create_task (self ._return_value ( my_launch_uuid )), ** kwargs
17581863 )
17591864 else :
1760- super ().__init__ (endpoint , project , launch_uuid = launch_uuid , ** kwargs )
1865+ my_launch_uuid_task = cast (Task [str ], launch_uuid )
1866+ super ().__init__ (endpoint , project , launch_uuid = my_launch_uuid_task , ** kwargs )
1867+ self .__init_api_info_prefetch ()
1868+
1869+ def __init_api_info_prefetch (self ) -> None :
1870+ if self ._use_microseconds is not None or self ._api_info_cache is not None :
1871+ return
1872+ if self ._loop is None :
1873+ return
1874+ self ._api_info_task = self ._loop .create_task (self ._prefetch_api_info ())
17611875
17621876 def create_task (self , coro : Coroutine [Any , Any , _T ]) -> Task [_T ]:
17631877 """Create a Task from given Coroutine.
17641878
17651879 :param coro: Coroutine which will be used for the Task creation.
17661880 :return: Task instance.
17671881 """
1768- if not getattr ( self , " _loop" , None ) :
1882+ if self . _loop is None :
17691883 return EmptyTask ()
17701884 result = self ._loop .create_task (coro )
17711885 with self ._task_mutex :
@@ -1825,6 +1939,7 @@ def __getstate__(self) -> dict[str, Any]:
18251939 del state ["_task_mutex" ]
18261940 del state ["_loop" ]
18271941 del state ["_thread" ]
1942+ del state ["_api_info_task" ]
18281943 return state
18291944
18301945 def __setstate__ (self , state : dict [str , Any ]) -> None :
@@ -1835,6 +1950,8 @@ def __setstate__(self, state: dict[str, Any]) -> None:
18351950 self .__dict__ .update (state )
18361951 self .__init_task_list (self ._task_list , threading .RLock ())
18371952 self .__init_loop ()
1953+ self ._api_info_task = None
1954+ self .__init_api_info_prefetch ()
18381955
18391956
18401957class BatchedRPClient (_RPClient ):
@@ -1875,9 +1992,6 @@ def __init_loop(self, loop: Optional[asyncio.AbstractEventLoop] = None):
18751992 self ._loop = asyncio .new_event_loop ()
18761993 self ._loop .set_task_factory (BatchedTaskFactory ())
18771994
1878- async def __return_value (self , value ):
1879- return value
1880-
18811995 def __init__ (
18821996 self ,
18831997 endpoint : str ,
@@ -1946,11 +2060,18 @@ def __init__(
19462060 self .__last_run_time = time .time ()
19472061 self .__init_loop (loop )
19482062 if type (launch_uuid ) is str :
2063+ my_launch_uuid = str (launch_uuid )
19492064 super ().__init__ (
1950- endpoint , project , launch_uuid = self .create_task (self .__return_value ( launch_uuid )), ** kwargs
2065+ endpoint , project , launch_uuid = self .create_task (self ._return_value ( my_launch_uuid )), ** kwargs
19512066 )
19522067 else :
1953- super ().__init__ (endpoint , project , launch_uuid = launch_uuid , ** kwargs )
2068+ my_launch_uuid_task = cast (Task [str ], launch_uuid )
2069+ super ().__init__ (endpoint , project , launch_uuid = my_launch_uuid_task , ** kwargs )
2070+ self .__init_api_info_prefetch ()
2071+
2072+ def __init_api_info_prefetch (self ) -> None :
2073+ # Batched client loop runs on demand, so prefetch starts lazily.
2074+ self ._api_info_task = None
19542075
19552076 def create_task (self , coro : Coroutine [Any , Any , _T ]) -> Task [_T ]:
19562077 """Create a Task from given Coroutine.
@@ -2016,6 +2137,7 @@ def __getstate__(self) -> dict[str, Any]:
20162137 # Don't pickle 'session' field, since it contains unpickling 'socket'
20172138 del state ["_task_mutex" ]
20182139 del state ["_loop" ]
2140+ del state ["_api_info_task" ]
20192141 return state
20202142
20212143 def __setstate__ (self , state : dict [str , Any ]) -> None :
@@ -2026,3 +2148,5 @@ def __setstate__(self, state: dict[str, Any]) -> None:
20262148 self .__dict__ .update (state )
20272149 self .__init_task_list (self ._task_list , threading .RLock ())
20282150 self .__init_loop ()
2151+ self ._api_info_task = None
2152+ self .__init_api_info_prefetch ()
0 commit comments