44
55import sqlparse
66from httpx import Response
7- from sqlparse .sql import Parenthesis , Where
8- from sqlparse .tokens import DML , Keyword , Wildcard
7+ from sqlparse .sql import Function , Identifier , IdentifierList , Parenthesis , Where
8+ from sqlparse .tokens import DML , Keyword , Whitespace , Wildcard
99from uipath .core .tracing import traced
1010
1111from ..common ._base_service import BaseService
4949 "GROUPING" ,
5050 "PARTITION" ,
5151]
52+ _AGGREGATE_FUNCTIONS = ("COUNT" , "SUM" , "AVG" , "MIN" , "MAX" )
5253
5354
5455class EntitiesService (BaseService ):
@@ -177,6 +178,7 @@ def retrieve_by_name(
177178 spec = self ._retrieve_by_name_spec (entity_name )
178179 headers = self ._folder_key_headers (folder_key )
179180 response = self .request (spec .method , spec .endpoint , headers = headers )
181+
180182 return Entity .model_validate (response .json ())
181183
182184 @traced (name = "entity_retrieve_by_name" , run_type = "uipath" )
@@ -196,6 +198,7 @@ async def retrieve_by_name_async(
196198 spec = self ._retrieve_by_name_spec (entity_name )
197199 headers = self ._folder_key_headers (folder_key )
198200 response = await self .request_async (spec .method , spec .endpoint , headers = headers )
201+
199202 return Entity .model_validate (response .json ())
200203
201204 @traced (name = "list_entities" , run_type = "uipath" )
@@ -1333,18 +1336,102 @@ def _validate_sql_query(self, sql_query: str) -> None:
13331336
13341337 has_where = any (isinstance (t , Where ) for t in stmt .tokens )
13351338 has_limit = "LIMIT" in keywords
1336- if not has_where and not has_limit :
1337- raise ValueError ("Queries without WHERE must include a LIMIT clause." )
1339+ has_from = "FROM" in keywords
1340+
1341+ if not has_from :
1342+ raise ValueError ("Queries must include a FROM clause." )
13381343
13391344 projection = self ._projection_tokens (stmt )
1340- has_wildcard = any (t .ttype is Wildcard for t in projection )
1341- if has_wildcard and not has_where :
1342- raise ValueError ("SELECT * without filtering is not allowed." )
1345+
1346+ if self ._projection_has_count_star (projection ):
1347+ raise ValueError (
1348+ "COUNT(*) is not supported. Use COUNT(column_name) instead."
1349+ )
1350+
1351+ has_aggregate = self ._projection_has_aggregate (projection )
1352+
1353+ if not has_where and not has_limit and not has_aggregate :
1354+ raise ValueError ("Queries without WHERE must include a LIMIT clause." )
1355+
1356+ has_bare_wildcard = self ._projection_has_bare_wildcard (projection )
1357+ if has_bare_wildcard :
1358+ raise ValueError ("SELECT * is not allowed. Specify column names instead." )
13431359 if not has_where and self ._projection_column_count (projection ) > 4 :
13441360 raise ValueError (
13451361 "Selecting more than 4 columns without filtering is not allowed."
13461362 )
13471363
1364+ @staticmethod
1365+ def _projection_has_aggregate (
1366+ projection : list [sqlparse .sql .Token ],
1367+ ) -> bool :
1368+ """Check whether the projection contains an aggregate function call."""
1369+
1370+ def _has_agg (token : sqlparse .sql .Token ) -> bool :
1371+ if isinstance (token , Function ):
1372+ return token .get_name ().upper () in _AGGREGATE_FUNCTIONS
1373+ if isinstance (token , Identifier ):
1374+ return any (_has_agg (child ) for child in token .tokens )
1375+ return False
1376+
1377+ for node in projection :
1378+ if _has_agg (node ):
1379+ return True
1380+ if isinstance (node , IdentifierList ):
1381+ if any (_has_agg (child ) for child in node .tokens ):
1382+ return True
1383+ return False
1384+
1385+ @staticmethod
1386+ def _projection_has_count_star (
1387+ projection : list [sqlparse .sql .Token ],
1388+ ) -> bool :
1389+ """Check whether projection contains COUNT(*)."""
1390+
1391+ def _is_count_star (func : Function ) -> bool :
1392+ if func .get_name ().upper () != "COUNT" :
1393+ return False
1394+ return any (t .ttype is Wildcard for t in func .flatten ())
1395+
1396+ def _has_count_star (token : sqlparse .sql .Token ) -> bool :
1397+ if isinstance (token , Function ):
1398+ return _is_count_star (token )
1399+ if isinstance (token , Identifier ):
1400+ return any (_has_count_star (child ) for child in token .tokens )
1401+ return False
1402+
1403+ for node in projection :
1404+ if _has_count_star (node ):
1405+ return True
1406+ if isinstance (node , IdentifierList ):
1407+ if any (_has_count_star (child ) for child in node .tokens ):
1408+ return True
1409+ return False
1410+
1411+ @staticmethod
1412+ def _projection_has_bare_wildcard (
1413+ projection : list [sqlparse .sql .Token ],
1414+ ) -> bool :
1415+ """Check for a bare ``*`` or qualified ``table.*`` outside a function."""
1416+
1417+ def _identifier_has_wildcard (ident : Identifier ) -> bool :
1418+ return any (t .ttype is Wildcard for t in ident .tokens )
1419+
1420+ for node in projection :
1421+ if node .ttype is Wildcard :
1422+ return True
1423+ if isinstance (node , Identifier ) and _identifier_has_wildcard (node ):
1424+ return True
1425+ if isinstance (node , IdentifierList ):
1426+ for child in node .tokens :
1427+ if child .ttype is Wildcard :
1428+ return True
1429+ if isinstance (child , Identifier ) and _identifier_has_wildcard (
1430+ child
1431+ ):
1432+ return True
1433+ return False
1434+
13481435 @staticmethod
13491436 def _has_subquery (stmt : sqlparse .sql .Statement ) -> bool :
13501437 """Recursively walk the AST looking for SELECT inside parentheses."""
@@ -1369,27 +1456,33 @@ def _walk(token: sqlparse.sql.Token) -> bool:
13691456 def _projection_tokens (
13701457 stmt : sqlparse .sql .Statement ,
13711458 ) -> list [sqlparse .sql .Token ]:
1372- """Extract tokens between the first SELECT and FROM."""
1459+ """Extract non-flattened AST nodes between the first SELECT and FROM."""
13731460 tokens : list [sqlparse .sql .Token ] = []
13741461 collecting = False
1375- for token in stmt .flatten () :
1462+ for token in stmt .tokens :
13761463 if token .ttype is DML and token .normalized == "SELECT" :
13771464 collecting = True
13781465 continue
1379- if token .ttype is Keyword and token .normalized == "FROM" :
1466+ if token .ttype is Keyword and token .normalized in ( "FROM" , "INTO" ) :
13801467 break
1381- if collecting :
1468+ if token .ttype is Keyword and token .normalized == "DISTINCT" :
1469+ continue
1470+ if collecting and token .ttype is not Whitespace :
13821471 tokens .append (token )
13831472 return tokens
13841473
13851474 @staticmethod
13861475 def _projection_column_count (
13871476 projection : list [sqlparse .sql .Token ],
13881477 ) -> int :
1389- text = "" .join (t .value for t in projection ).strip ()
1390- if not text :
1391- return 0
1392- return len ([part for part in text .split ("," ) if part .strip ()])
1478+ for node in projection :
1479+ if isinstance (node , IdentifierList ):
1480+ return len (list (node .get_identifiers ()))
1481+ if isinstance (node , (Identifier , Function )):
1482+ return 1
1483+ if node .ttype is Wildcard :
1484+ return 1
1485+ return 0
13931486
13941487
13951488# Resolve the forward reference to EntitiesService in EntitySetResolution.
0 commit comments