Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 49 additions & 14 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def online_read(
# For single batch, no parallelization overhead needed
if len(batches) == 1:
batch_entity_ids = self._to_resource_batch_get_payload(
online_config, table_name, batches[0]
online_config, table_name, batches[0], requested_features
)
response = dynamodb_resource.batch_get_item(RequestItems=batch_entity_ids)
return self._process_batch_get_response(table_name, response, batches[0])
Expand All @@ -520,7 +520,7 @@ def online_read(

def fetch_batch(batch: List[str]) -> Dict[str, Any]:
batch_entity_ids = self._to_client_batch_get_payload(
online_config, table_name, batch
online_config, table_name, batch, requested_features
)
return dynamodb_client.batch_get_item(RequestItems=batch_entity_ids)

Expand Down Expand Up @@ -599,7 +599,7 @@ def to_tbl_resp(raw_client_response):
if not batch:
break
entity_id_batch = self._to_client_batch_get_payload(
online_config, table_name, batch
online_config, table_name, batch, requested_features
)
batches.append(batch)
entity_id_batches.append(entity_id_batch)
Expand Down Expand Up @@ -760,21 +760,56 @@ def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]):
]

@staticmethod
def _to_resource_batch_get_payload(online_config, table_name, batch):
return {
table_name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
def _to_resource_batch_get_payload(
    online_config, table_name, batch, requested_features=None
):
    """Build the ``RequestItems`` mapping for a resource-style ``batch_get_item``.

    Each key is emitted as a plain ``{"entity_id": <id>}`` dict. When
    ``_build_projection_expression`` yields a projection for
    ``requested_features``, its ``ProjectionExpression`` and
    ``ExpressionAttributeNames`` entries are copied into the request.

    Args:
        online_config: Object whose ``consistent_reads`` attribute is
            forwarded as ``ConsistentRead``.
        table_name: Key under which the request is returned.
        batch: Iterable of entity ids to look up.
        requested_features: Optional feature names used to build the
            projection; a falsy value means no projection is added.

    Returns:
        A one-entry dict ``{table_name: request}``.
    """
    keys = [{"entity_id": entity_id} for entity_id in batch]
    request: Dict[str, Any] = {
        "Keys": keys,
        "ConsistentRead": online_config.consistent_reads,
    }
    projection = DynamoDBOnlineStore._build_projection_expression(requested_features)
    if projection:
        # Copy only the two projection fields, in the same order every time.
        for field in ("ProjectionExpression", "ExpressionAttributeNames"):
            request[field] = projection[field]
    return {table_name: request}

@staticmethod
def _to_client_batch_get_payload(
    online_config, table_name, batch, requested_features=None
):
    """Build the ``RequestItems`` mapping for a client-style ``batch_get_item``.

    Unlike the resource variant, each key carries an explicit type
    descriptor: ``{"entity_id": {"S": <id>}}``. When
    ``_build_projection_expression`` yields a projection for
    ``requested_features``, its ``ProjectionExpression`` and
    ``ExpressionAttributeNames`` entries are copied into the request.

    Args:
        online_config: Object whose ``consistent_reads`` attribute is
            forwarded as ``ConsistentRead``.
        table_name: Key under which the request is returned.
        batch: Iterable of entity ids to look up.
        requested_features: Optional feature names used to build the
            projection; a falsy value means no projection is added.

    Returns:
        A one-entry dict ``{table_name: request}``.
    """
    typed_keys = [{"entity_id": {"S": entity_id}} for entity_id in batch]
    request: Dict[str, Any] = {
        "Keys": typed_keys,
        "ConsistentRead": online_config.consistent_reads,
    }
    projection = DynamoDBOnlineStore._build_projection_expression(requested_features)
    if projection:
        # Copy only the two projection fields, in the same order every time.
        for field in ("ProjectionExpression", "ExpressionAttributeNames"):
            request[field] = projection[field]
    return {table_name: request}

@staticmethod
def _to_client_batch_get_payload(online_config, table_name, batch):
def _build_projection_expression(
    requested_features: Optional[List[str]],
) -> Optional[Dict[str, Any]]:
    """Build the projection parameters that restrict a read to some features.

    The projection always contains ``entity_id`` and ``event_ts`` and adds
    one ``values.<feature>`` path per requested feature. Every attribute is
    referenced through a ``#alias`` placeholder so that the raw names never
    appear in the expression text itself.

    Repeated feature names are collapsed (first occurrence wins, order is
    preserved): listing the same document path twice makes DynamoDB reject
    the request because the paths overlap.

    Args:
        requested_features: Feature names to project, or ``None``/empty to
            request no projection at all.

    Returns:
        ``None`` when ``requested_features`` is falsy; otherwise a dict with
        ``"ProjectionExpression"`` (comma-separated aliased paths) and
        ``"ExpressionAttributeNames"`` (alias -> real attribute name).
    """
    if not requested_features:
        return None
    attr_names: Dict[str, str] = {
        "#entity_id": "entity_id",
        "#event_ts": "event_ts",
        "#vals": "values",
    }
    projections = ["#entity_id", "#event_ts"]
    # dict.fromkeys drops duplicates while keeping first-seen order, so the
    # aliases for unique inputs are numbered exactly as before.
    for i, feat in enumerate(dict.fromkeys(requested_features)):
        alias = f"#feat{i}"
        attr_names[alias] = feat
        projections.append(f"#vals.{alias}")
    return {
        "ProjectionExpression": ", ".join(projections),
        "ExpressionAttributeNames": attr_names,
    }

def update_online_store(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1050,3 +1050,107 @@ def tracking_client(*args, **kwargs):
f"Expected 1 shared client for thread-safety, "
f"got {len(set(dynamodb_clients))} unique clients"
)


@mock_dynamodb
def test_dynamodb_online_store_online_read_with_requested_features(
    repo_config, dynamodb_online_store
):
    """Reading with ``requested_features`` returns only those features."""
    wanted = ["name", "age"]
    n_samples = 5
    db_table_name = f"{TABLE_NAME}_requested_features"

    # Create the table and load the sample rows.
    create_test_table(PROJECT, db_table_name, REGION)
    data = create_n_customer_test_samples(n=n_samples)
    insert_data_test_table(data, PROJECT, db_table_name, REGION)

    # Only the entity keys (first column of each sample) are needed here.
    entity_keys, _features, *_rest = zip(*data)
    returned_items = dynamodb_online_store.online_read(
        config=repo_config,
        table=MockFeatureView(name=db_table_name),
        entity_keys=entity_keys,
        requested_features=wanted,
    )

    assert len(returned_items) == n_samples
    for _, feature_values in returned_items:
        assert feature_values is not None
        for feature_name in wanted:
            assert feature_name in feature_values
        assert "avg_orders_day" not in feature_values


@mock_dynamodb
def test_dynamodb_online_store_online_read_without_requested_features(
    repo_config, dynamodb_online_store
):
    """Reading with ``requested_features=None`` returns every stored feature."""
    expected_features = {"avg_orders_day", "name", "age"}
    n_samples = 5
    db_table_name = f"{TABLE_NAME}_all_features"

    # Create the table and load the sample rows.
    create_test_table(PROJECT, db_table_name, REGION)
    data = create_n_customer_test_samples(n=n_samples)
    insert_data_test_table(data, PROJECT, db_table_name, REGION)

    # Only the entity keys (first column of each sample) are needed here.
    entity_keys, _features, *_rest = zip(*data)
    returned_items = dynamodb_online_store.online_read(
        config=repo_config,
        table=MockFeatureView(name=db_table_name),
        entity_keys=entity_keys,
        requested_features=None,
    )

    assert len(returned_items) == n_samples
    for _, feature_values in returned_items:
        assert feature_values is not None
        assert set(feature_values.keys()) == expected_features


def test_build_projection_expression():
    """``_build_projection_expression`` aliases key, timestamp and features."""
    result = DynamoDBOnlineStore._build_projection_expression(["feat_a", "feat_b"])
    assert result is not None

    # Every aliased path must show up in the expression text.
    expression = result["ProjectionExpression"]
    for token in ("#entity_id", "#event_ts", "#vals.#feat0", "#vals.#feat1"):
        assert token in expression

    # Each alias must resolve to the real attribute name.
    attr_names = result["ExpressionAttributeNames"]
    expected_names = {
        "#vals": "values",
        "#feat0": "feat_a",
        "#feat1": "feat_b",
    }
    for alias, real_name in expected_names.items():
        assert attr_names[alias] == real_name


def test_build_projection_expression_none():
    """``_build_projection_expression`` yields ``None`` for ``None`` and ``[]``."""
    for no_features in (None, []):
        assert DynamoDBOnlineStore._build_projection_expression(no_features) is None


@mock_dynamodb
def test_dynamodb_online_store_online_read_requested_features_parallel(
    dynamodb_online_store,
):
    """``requested_features`` is honoured when the read spans several batches."""
    # batch_size=5 with 15 samples means the keys cannot fit in one batch.
    online_store_config = DynamoDBOnlineStoreConfig(region=REGION, batch_size=5)
    small_batch_config = RepoConfig(
        registry=REGISTRY,
        project=PROJECT,
        provider=PROVIDER,
        online_store=online_store_config,
        offline_store=DaskOfflineStoreConfig(),
        entity_key_serialization_version=3,
    )
    n_samples = 15
    db_table_name = f"{TABLE_NAME}_requested_parallel"

    # Create the table and load the sample rows.
    create_test_table(PROJECT, db_table_name, REGION)
    data = create_n_customer_test_samples(n=n_samples)
    insert_data_test_table(data, PROJECT, db_table_name, REGION)

    # Only the entity keys (first column of each sample) are needed here.
    entity_keys, _features, *_rest = zip(*data)
    returned_items = dynamodb_online_store.online_read(
        config=small_batch_config,
        table=MockFeatureView(name=db_table_name),
        entity_keys=entity_keys,
        requested_features=["age"],
    )

    assert len(returned_items) == n_samples
    for _, feature_values in returned_items:
        assert feature_values is not None
        assert "age" in feature_values
        for excluded in ("name", "avg_orders_day"):
            assert excluded not in feature_values
Loading