Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import logging
import uuid
from datetime import date, datetime
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -37,6 +39,8 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage

logger = logging.getLogger(__name__)


class BasicAuthModel(FeastConfigBaseModel):
username: StrictStr
Expand Down Expand Up @@ -177,14 +181,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
class TrinoRetrievalJob(RetrievalJob):
def __init__(
self,
query: str,
query: Union[str, Callable[[], ContextManager[str]]],
client: Trino,
config: RepoConfig,
full_feature_names: bool,
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
metadata: Optional[RetrievalMetadata] = None,
):
self._query = query
if not isinstance(query, str):
self._query_generator = query
else:

@contextlib.contextmanager
def query_generator() -> Iterator[str]:
assert isinstance(query, str)
yield query

self._query_generator = query_generator
self._client = client
self._config = config
self._full_feature_names = full_feature_names
Expand All @@ -201,17 +214,19 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:

def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
    """Run the retrieval query in Trino and return the result as a pandas DataFrame.

    The SQL text is obtained by entering the context manager produced by
    ``self._query_generator``. The query is executed and converted to a
    DataFrame while that context is still open, so whatever the generator
    does on exit runs only after ``to_dataframe()`` has returned.

    Args:
        timeout: Accepted for interface compatibility; not used by this method.

    Returns:
        The query results as a pandas DataFrame.
    """
    # Always go through the generator: the raw constructor argument may be a
    # callable factory rather than a SQL string, so it must not be passed to
    # the client directly.
    with self._query_generator() as query:
        results = self._client.execute_query(query_text=query)
        # Keep the schema of the latest result available on the job object.
        self.pyarrow_schema = results.pyarrow_schema
        return results.to_dataframe()

def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
    """Return the dataset as a pyarrow Table.

    The table is built from the pandas DataFrame produced by
    ``_to_df_internal``; ``timeout`` is forwarded to that call unchanged.
    """
    frame = self._to_df_internal(timeout=timeout)
    return pyarrow.Table.from_pandas(frame)

def to_sql(self) -> str:
    """Return the SQL query that will be executed in Trino to build the historical feature table.

    The text is read by entering, and immediately leaving, the context
    manager produced by ``self._query_generator``. The raw constructor
    argument is never returned, because it may be a callable factory rather
    than a string.

    NOTE(review): leaving the context runs the generator's exit logic on
    every call. A generator that drops temporary tables on exit will do so
    here as well -- confirm that is intended for a read-only accessor.

    Returns:
        The SQL query string yielded by the query generator.
    """
    with self._query_generator() as query:
        return query

def to_trino(
self,
Expand All @@ -234,8 +249,9 @@ def to_trino(
destination_table = f"{self._client.catalog}.{self._config.offline_store.dataset}.historical_{today}_{rand_id}"

# TODO: Implement the timeout logic
query = f"CREATE TABLE {destination_table} AS ({self._query})"
self._client.execute_query(query_text=query)
with self._query_generator() as query:
create_query = f"CREATE TABLE {destination_table} AS ({query})"
self._client.execute_query(query_text=create_query)
return destination_table

def persist(
Expand Down Expand Up @@ -372,19 +388,36 @@ def get_historical_features(
)

# Generate the Trino SQL query from the query context
entity_table_ref = table_reference
if type(entity_df) is str:
table_reference = f"({entity_df})"
entity_table_ref = f"({entity_df})"
query = offline_utils.build_point_in_time_query(
query_context,
left_table_query_string=table_reference,
left_table_query_string=entity_table_ref,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
)

@contextlib.contextmanager
def query_generator() -> Iterator[str]:
try:
yield query
finally:
if isinstance(entity_df, pd.DataFrame):
try:
client.execute_query(
f"DROP TABLE IF EXISTS {table_reference}"
)
except Exception:
logger.exception(
"Failed to drop temporary entity table %s",
table_reference,
)

return TrinoRetrievalJob(
query=query,
query=query_generator,
client=client,
config=config,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -483,8 +516,6 @@ def _upload_entity_df_and_get_entity_schema(
else:
raise InvalidEntityType(type(entity_df))

# TODO: Ensure that the table expires after some time


def _get_trino_client(config: RepoConfig) -> Trino:
auth = None
Expand Down
Loading