Skip to content

Commit 9682896

Browse files
feat: improve SQL validation for aggregate functions and entity model resilience (#1576)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6d1ca97 commit 9682896

6 files changed

Lines changed: 183 additions & 33 deletions

File tree

packages/uipath-platform/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-platform"
3-
version = "0.1.31"
3+
version = "0.1.32"
44
description = "HTTP client library for programmatic access to UiPath Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

packages/uipath-platform/src/uipath/platform/entities/_entities_service.py

Lines changed: 108 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import sqlparse
66
from httpx import Response
7-
from sqlparse.sql import Parenthesis, Where
8-
from sqlparse.tokens import DML, Keyword, Wildcard
7+
from sqlparse.sql import Function, Identifier, IdentifierList, Parenthesis, Where
8+
from sqlparse.tokens import DML, Keyword, Whitespace, Wildcard
99
from uipath.core.tracing import traced
1010

1111
from ..common._base_service import BaseService
@@ -49,6 +49,7 @@
4949
"GROUPING",
5050
"PARTITION",
5151
]
52+
_AGGREGATE_FUNCTIONS = ("COUNT", "SUM", "AVG", "MIN", "MAX")
5253

5354

5455
class EntitiesService(BaseService):
@@ -177,6 +178,7 @@ def retrieve_by_name(
177178
spec = self._retrieve_by_name_spec(entity_name)
178179
headers = self._folder_key_headers(folder_key)
179180
response = self.request(spec.method, spec.endpoint, headers=headers)
181+
180182
return Entity.model_validate(response.json())
181183

182184
@traced(name="entity_retrieve_by_name", run_type="uipath")
@@ -196,6 +198,7 @@ async def retrieve_by_name_async(
196198
spec = self._retrieve_by_name_spec(entity_name)
197199
headers = self._folder_key_headers(folder_key)
198200
response = await self.request_async(spec.method, spec.endpoint, headers=headers)
201+
199202
return Entity.model_validate(response.json())
200203

201204
@traced(name="list_entities", run_type="uipath")
@@ -1333,18 +1336,102 @@ def _validate_sql_query(self, sql_query: str) -> None:
13331336

13341337
has_where = any(isinstance(t, Where) for t in stmt.tokens)
13351338
has_limit = "LIMIT" in keywords
1336-
if not has_where and not has_limit:
1337-
raise ValueError("Queries without WHERE must include a LIMIT clause.")
1339+
has_from = "FROM" in keywords
1340+
1341+
if not has_from:
1342+
raise ValueError("Queries must include a FROM clause.")
13381343

13391344
projection = self._projection_tokens(stmt)
1340-
has_wildcard = any(t.ttype is Wildcard for t in projection)
1341-
if has_wildcard and not has_where:
1342-
raise ValueError("SELECT * without filtering is not allowed.")
1345+
1346+
if self._projection_has_count_star(projection):
1347+
raise ValueError(
1348+
"COUNT(*) is not supported. Use COUNT(column_name) instead."
1349+
)
1350+
1351+
has_aggregate = self._projection_has_aggregate(projection)
1352+
1353+
if not has_where and not has_limit and not has_aggregate:
1354+
raise ValueError("Queries without WHERE must include a LIMIT clause.")
1355+
1356+
has_bare_wildcard = self._projection_has_bare_wildcard(projection)
1357+
if has_bare_wildcard:
1358+
raise ValueError("SELECT * is not allowed. Specify column names instead.")
13431359
if not has_where and self._projection_column_count(projection) > 4:
13441360
raise ValueError(
13451361
"Selecting more than 4 columns without filtering is not allowed."
13461362
)
13471363

1364+
@staticmethod
def _projection_has_aggregate(
    projection: list[sqlparse.sql.Token],
) -> bool:
    """Report whether any projected node is an aggregate function call.

    A node matches when it is a ``Function`` whose upper-cased name appears
    in ``_AGGREGATE_FUNCTIONS``. The search looks at each projection node,
    at the children of an ``Identifier`` (recursively), and at the direct
    members of an ``IdentifierList``.

    Args:
        projection: Nodes of the SELECT projection.

    Returns:
        True if at least one aggregate call is found, otherwise False.
    """

    def _is_aggregate(candidate: sqlparse.sql.Token) -> bool:
        if isinstance(candidate, Function):
            return candidate.get_name().upper() in _AGGREGATE_FUNCTIONS
        if not isinstance(candidate, Identifier):
            return False
        # An Identifier may wrap the call, so inspect its children too.
        return any(_is_aggregate(part) for part in candidate.tokens)

    for item in projection:
        # A column list is examined member by member; anything else is
        # examined as a single candidate.
        members = item.tokens if isinstance(item, IdentifierList) else (item,)
        if any(_is_aggregate(member) for member in members):
            return True
    return False
1384+
1385+
@staticmethod
def _projection_has_count_star(
    projection: list[sqlparse.sql.Token],
) -> bool:
    """Check whether the projection contains ``COUNT(*)``.

    Only a ``*`` used as an argument of ``COUNT`` is reported. sqlparse
    lexes every ``*`` as ``Wildcard`` -- including the multiplication sign,
    which it groups into an ``Operation`` -- so arithmetic such as
    ``COUNT(price * qty)`` must not be mistaken for ``COUNT(*)``.

    Args:
        projection: Nodes of the SELECT projection.

    Returns:
        True if a ``COUNT`` call with a wildcard argument is found at the
        top level, inside an ``Identifier`` or as a direct member of an
        ``IdentifierList``; otherwise False.
    """

    def _has_star_argument(token: sqlparse.sql.Token) -> bool:
        if token.ttype is Wildcard:
            return True
        # Leaf tokens have no children; a ``*`` inside an Operation is a
        # multiplication sign rather than a wildcard argument.
        if not token.is_group or isinstance(token, sqlparse.sql.Operation):
            return False
        return any(_has_star_argument(child) for child in token.tokens)

    def _is_count_star(func: Function) -> bool:
        if func.get_name().upper() != "COUNT":
            return False
        return _has_star_argument(func)

    def _has_count_star(token: sqlparse.sql.Token) -> bool:
        if isinstance(token, Function):
            return _is_count_star(token)
        if isinstance(token, Identifier):
            # An Identifier may wrap the call, so look at its children.
            return any(_has_count_star(child) for child in token.tokens)
        return False

    for node in projection:
        if _has_count_star(node):
            return True
        if isinstance(node, IdentifierList) and any(
            _has_count_star(child) for child in node.tokens
        ):
            return True
    return False
1410+
1411+
@staticmethod
def _projection_has_bare_wildcard(
    projection: list[sqlparse.sql.Token],
) -> bool:
    """Detect a bare ``*`` or qualified ``table.*`` outside a function.

    Args:
        projection: Nodes of the SELECT projection.

    Returns:
        True if a projection node, or a direct member of an
        ``IdentifierList``, is a ``Wildcard`` token or an ``Identifier``
        with a ``Wildcard`` among its direct children; otherwise False.
    """

    def _is_star(candidate: sqlparse.sql.Token) -> bool:
        if candidate.ttype is Wildcard:
            return True
        # Only the direct children of an Identifier are inspected.
        return isinstance(candidate, Identifier) and any(
            part.ttype is Wildcard for part in candidate.tokens
        )

    for item in projection:
        if _is_star(item):
            return True
        if isinstance(item, IdentifierList) and any(
            _is_star(member) for member in item.tokens
        ):
            return True
    return False
1434+
13481435
@staticmethod
13491436
def _has_subquery(stmt: sqlparse.sql.Statement) -> bool:
13501437
"""Recursively walk the AST looking for SELECT inside parentheses."""
@@ -1369,27 +1456,33 @@ def _walk(token: sqlparse.sql.Token) -> bool:
13691456
def _projection_tokens(
    stmt: sqlparse.sql.Statement,
) -> list[sqlparse.sql.Token]:
    """Extract non-flattened AST nodes between the first SELECT and FROM.

    Collection starts after a ``SELECT`` DML token and stops at the first
    top-level ``FROM`` or ``INTO`` keyword. ``DISTINCT`` keywords and all
    whitespace tokens are left out of the result.

    Args:
        stmt: Parsed statement whose top-level tokens are scanned.

    Returns:
        The projection nodes in source order; empty if nothing was
        collected.
    """
    tokens: list[sqlparse.sql.Token] = []
    collecting = False
    for token in stmt.tokens:
        if token.ttype is DML and token.normalized == "SELECT":
            collecting = True
            continue
        if token.ttype is Keyword and token.normalized in ("FROM", "INTO"):
            break
        if token.ttype is Keyword and token.normalized == "DISTINCT":
            continue
        # ``is_whitespace`` also covers Whitespace.Newline, which an identity
        # comparison against Whitespace lets through on multi-line queries.
        if collecting and not token.is_whitespace:
            tokens.append(token)
    return tokens
13841473

13851474
@staticmethod
def _projection_column_count(
    projection: list[sqlparse.sql.Token],
) -> int:
    """Count the columns described by the first recognised projection node.

    Args:
        projection: Nodes of the SELECT projection.

    Returns:
        The number of identifiers when the first recognised node is an
        ``IdentifierList``; 1 when it is an ``Identifier``, a ``Function``
        or a ``Wildcard`` token; 0 when no node is recognised.
    """
    for item in projection:
        if isinstance(item, IdentifierList):
            return sum(1 for _ in item.get_identifiers())
        if item.ttype is Wildcard or isinstance(item, (Identifier, Function)):
            return 1
    # No column-like node was found in the projection.
    return 0
13931486

13941487

13951488
# Resolve the forward reference to EntitiesService in EntitySetResolution.

packages/uipath-platform/src/uipath/platform/entities/entities.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
get_origin,
1717
)
1818

19-
from pydantic import BaseModel, ConfigDict, Field, create_model
19+
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, create_model
2020

2121
if TYPE_CHECKING:
2222
from ._entities_service import EntitiesService
@@ -140,7 +140,7 @@ class FieldMetadata(BaseModel):
140140
reference_field: Optional["EntityField"] = Field(
141141
default=None, alias="referenceField"
142142
)
143-
reference_type: ReferenceType = Field(alias="referenceType")
143+
reference_type: Optional[ReferenceType] = Field(default=None, alias="referenceType")
144144
sql_type: "FieldDataType" = Field(alias="sqlType")
145145
is_required: bool = Field(alias="isRequired")
146146
display_name: str = Field(alias="displayName")
@@ -212,14 +212,21 @@ class SourceJoinCriteria(BaseModel):
212212
model_config = ConfigDict(
213213
validate_by_name=True,
214214
validate_by_alias=True,
215+
extra="allow",
216+
)
217+
id: Optional[str] = None
218+
entity_id: Optional[str] = Field(default=None, alias="entityId")
219+
join_field_name: Optional[str] = Field(default=None, alias="joinFieldName")
220+
join_type: Optional[str] = Field(default=None, alias="joinType")
221+
related_source_object_id: Optional[str] = Field(
222+
default=None, alias="relatedSourceObjectId"
223+
)
224+
related_source_object_field_name: Optional[str] = Field(
225+
default=None, alias="relatedSourceObjectFieldName"
226+
)
227+
related_source_field_name: Optional[str] = Field(
228+
default=None, alias="relatedSourceFieldName"
215229
)
216-
id: str
217-
entity_id: str = Field(alias="entityId")
218-
join_field_name: str = Field(alias="joinFieldName")
219-
join_type: str = Field(alias="joinType")
220-
related_source_object_id: str = Field(alias="relatedSourceObjectId")
221-
related_source_object_field_name: str = Field(alias="relatedSourceObjectFieldName")
222-
related_source_field_name: str = Field(alias="relatedSourceFieldName")
223230

224231

225232
class ChoiceSetValue(BaseModel):
@@ -326,11 +333,16 @@ class Entity(BaseModel):
326333
entity_type: str = Field(alias="entityType")
327334
description: Optional[str] = Field(default=None, alias="description")
328335
fields: Optional[List[FieldMetadata]] = Field(default=None, alias="fields")
329-
external_fields: Optional[List[ExternalSourceFields]] = Field(
330-
default=None, alias="externalFields"
336+
external_fields: Optional[
337+
List[ExternalField | ExternalSourceFields | Dict[str, Any]]
338+
] = Field(
339+
default=None,
340+
alias="externalFields",
331341
)
332-
source_join_criteria: Optional[List[SourceJoinCriteria]] = Field(
333-
default=None, alias="sourceJoinCriteria"
342+
source_join_criteria: Optional[List[SourceJoinCriteria | Dict[str, Any]]] = Field(
343+
default=None,
344+
validation_alias=AliasChoices("sourceJoinCriteria", "sourceJoinCriterias"),
345+
alias="sourceJoinCriteria",
334346
)
335347
record_count: Optional[int] = Field(default=None, alias="recordCount")
336348
storage_size_in_mb: Optional[float] = Field(default=None, alias="storageSizeInMB")

packages/uipath-platform/tests/services/test_entities_service.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,12 @@ def test_retrieve_records_without_start_and_limit(
308308
[
309309
"SELECT id FROM Customers WHERE id = 1",
310310
"SELECT id, name FROM Customers LIMIT 10",
311-
"SELECT * FROM Customers WHERE status = 'Active'",
311+
"SELECT COUNT(id) FROM Customers",
312+
"SELECT SUM(amount) FROM Orders",
313+
"SELECT AVG(price) FROM Products",
314+
"SELECT MIN(created), MAX(created) FROM Events",
315+
"SELECT COUNT(id) AS total, SUM(amount) AS amt FROM Orders",
316+
"SELECT COUNT(id), name FROM Customers LIMIT 10",
312317
"SELECT id, name, email, phone FROM Customers LIMIT 5",
313318
"SELECT DISTINCT id FROM Customers WHERE id > 100",
314319
"SELECT id FROM Customers WHERE name = 'foo;bar'",
@@ -356,9 +361,49 @@ def test_validate_sql_query_allows_supported_select_queries(
356361
"SELECT id FROM Customers",
357362
"Queries without WHERE must include a LIMIT clause.",
358363
),
364+
(
365+
"SELECT UPPER(name) FROM Customers",
366+
"Queries without WHERE must include a LIMIT clause.",
367+
),
368+
(
369+
"SELECT COALESCE(name, 'N/A') FROM Customers",
370+
"Queries without WHERE must include a LIMIT clause.",
371+
),
372+
(
373+
"SELECT 1 LIMIT 1",
374+
"Queries must include a FROM clause.",
375+
),
376+
(
377+
"SELECT COUNT(*) FROM Customers",
378+
"COUNT(*) is not supported. Use COUNT(column_name) instead.",
379+
),
380+
(
381+
"SELECT COUNT(*), name FROM Customers LIMIT 10",
382+
"COUNT(*) is not supported. Use COUNT(column_name) instead.",
383+
),
384+
(
385+
"SELECT COUNT(*) AS total FROM Customers",
386+
"COUNT(*) is not supported. Use COUNT(column_name) instead.",
387+
),
359388
(
360389
"SELECT * FROM Customers LIMIT 10",
361-
"SELECT * without filtering is not allowed.",
390+
"SELECT * is not allowed. Specify column names instead.",
391+
),
392+
(
393+
"SELECT Customers.* FROM Customers LIMIT 10",
394+
"SELECT * is not allowed. Specify column names instead.",
395+
),
396+
(
397+
"SELECT t.* FROM Customers t LIMIT 10",
398+
"SELECT * is not allowed. Specify column names instead.",
399+
),
400+
(
401+
"SELECT * FROM Customers WHERE status = 'Active'",
402+
"SELECT * is not allowed. Specify column names instead.",
403+
),
404+
(
405+
"SELECT Customers.* FROM Customers WHERE status = 'Active'",
406+
"SELECT * is not allowed. Specify column names instead.",
362407
),
363408
(
364409
"SELECT id, name, email, phone, address FROM Customers LIMIT 10",

packages/uipath-platform/uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/uipath/uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)