diff --git a/aligned/cli.py b/aligned/cli.py index dbf458ce..124b2440 100644 --- a/aligned/cli.py +++ b/aligned/cli.py @@ -9,7 +9,7 @@ import click import json -from pytz import utc # type: ignore +from pytz import utc from aligned.compiler.repo_reader import RepoReader, RepoReference from aligned.feature_store import ContractStore @@ -125,6 +125,7 @@ def cli() -> None: @click.option("--n-workers", default=1) @click.option("--expose-tag") @click.option("--log-level", default="info") +@click.option("--run-as-module", default=True) async def start_proxy_server( contracts: str, host: str, @@ -132,25 +133,42 @@ async def start_proxy_server( n_workers: int, log_level: str, expose_tag: str | None, + run_as_module: bool, ) -> None: + """ + Starts a server that exposes the data that exists in the contracts. + """ from fastapi import FastAPI import uvicorn + from uvicorn.main import LOG_LEVELS # type: ignore + import sys + + if run_as_module: + cwd = Path().cwd().as_posix() + if cwd not in sys.path: + sys.path.append(cwd) @asynccontextmanager async def lifespan(app: FastAPI): import logging from aligned.proxy_api import router_for_store - logging.basicConfig(level=logging.INFO) + logging.basicConfig( + level=LOG_LEVELS[log_level.lower()], + # format="%(asctime)s | %(levelname)-8s | %(module)s:%(funcName)s:%(lineno)d - %(message)s" + ) logger = logging.getLogger(__name__) - logger.info(f"Loading contract from: '{contracts}'") if contracts.startswith("http"): store = await AlignedCloudSource(contracts).as_contract_store() else: - store = await store_from_reference(contracts) + try: + store = await store_from_reference(contracts) + except Exception: + logger.info(f"Unable to read '{contracts}'. Will try to compile") + store = await ContractStore.from_dir() router_for_store(store, expose_tag=expose_tag, app=app) yield @@ -158,7 +176,7 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) config = uvicorn.Config( - app, host=host, port=port, workers=n_workers, log_level=log_level + app, host=host, port=port, workers=n_workers, log_level=log_level.lower() ) server = uvicorn.Server(config) await server.serve() diff --git a/aligned/compiler/model.py b/aligned/compiler/model.py index 82381d66..9cb39208 100644 --- a/aligned/compiler/model.py +++ b/aligned/compiler/model.py @@ -35,7 +35,7 @@ FeatureViewMetadata, FeatureViewWrapper, ) -from aligned.exposed_model.interface import ExposedModel +from aligned.exposed_model.interface import CallableModelFunction, ExposedModel from aligned.retrieval_job import ConvertableToRetrievalJob, PredictionJob, RetrievalJob from aligned.schemas.derivied_feature import DerivedFeature from aligned.schemas.feature import ( @@ -363,11 +363,17 @@ class FeatureInputVersions: default_version: str versions: dict[str, list[FeatureReferencable]] - def compile(self) -> FeatureVersionSchema: + def compile( + self, labels: set[FeatureReference] | None = None + ) -> FeatureVersionSchema: return FeatureVersionSchema( default_version=self.default_version, versions={ - version: [feature.feature_reference() for feature in features] + version: [ + feature.feature_reference() + for feature in features + if feature.feature_reference() not in (labels or set()) + ] for version, features in self.versions.items() }, ) @@ -380,7 +386,8 @@ def model_contract( | ModelContractWrapper | Sequence[FeatureReferencable] ] - | FeatureInputVersions, + | FeatureInputVersions + | None = None, name: str | None = None, contacts: list[Contact] | list[str] | None = None, tags: list[str] | None = None, @@ -390,19 +397,24 @@ def model_contract( application_source: CodableBatchDataSource | None = None, dataset_store: DatasetStore | StorageFileReference | None = None, exposed_at_url: str | None = None, - exposed_model: ExposedModel | None = None, + exposed_model: ExposedModel | CallableModelFunction | None = None, acceptable_freshness: timedelta | None = None, unacceptable_freshness: timedelta | None = None, ) -> Callable[[Type[T]], ModelContractWrapper[T]]: def decorator(cls: Type[T]) -> ModelContractWrapper[T]: from aligned.sources.renamer import camel_to_snake_case + if input_features is None and output_source is None: + source = RandomDataSource(fill_mode="random_samples") + else: + source = output_source + if isinstance(input_features, FeatureInputVersions): features_versions = input_features else: unwrapped_input_features: list[FeatureReferencable] = [] - for feature in input_features: + for feature in input_features or []: if isinstance(feature, FeatureViewWrapper): compiled_view = feature.compile() request = compiled_view.request_all @@ -439,9 +451,14 @@ def decorator(cls: Type[T]) -> ModelContractWrapper[T]: elif cls.__doc__: used_description = str(cls.__doc__) + if exposed_model is not None and not isinstance(exposed_model, ExposedModel): + model = ExposedModel.polars_predictor(exposed_model) + else: + model = exposed_model + used_exposed_at_url = exposed_at_url - if exposed_model: - used_exposed_at_url = exposed_model.exposed_at_url or exposed_at_url + if model: + used_exposed_at_url = model.exposed_at_url or exposed_at_url conts = [ Contact(cont) if isinstance(cont, str) else cont for cont in contacts or [] @@ -453,14 +470,14 @@ def decorator(cls: Type[T]) -> ModelContractWrapper[T]: contacts=conts, tags=tags, description=used_description, - output_source=output_source, + output_source=source, output_stream=output_stream, application_source=application_source, dataset_store=resolve_dataset_store(dataset_store) if dataset_store else None, exposed_at_url=used_exposed_at_url, - exposed_model=exposed_model, + exposed_model=model, acceptable_freshness=acceptable_freshness, unacceptable_freshness=unacceptable_freshness, ) @@ -514,7 +531,9 @@ class MyModel(ModelContract): for var_name in var_names: feature = getattr(model, var_name) if isinstance(feature, FeatureFactory): - assert feature._name, f"Expected name but found none in model: {metadata.name} for feature {var_name}" + assert feature._name, ( + f"Expected name but found none in model: {metadata.name} for feature {var_name}" + ) feature._location = FeatureLocation.model(metadata.name) if isinstance(feature, FeatureView): @@ -547,9 +566,9 @@ class MyModel(ModelContract): feature_name = feature.target._name assert feature_name assert feature._name - assert ( - feature.target._name in classification_targets - ), "Target must be a classification target." + assert feature.target._name in classification_targets, ( + "Target must be a classification target." + ) target = classification_targets[feature.target._name] target.class_probabilities.add(feature.compile()) @@ -632,7 +651,7 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int: inference_view.features.add(feature.feature()) # Needs to run after the feature views have compiled - features = metadata.features.compile() + features = metadata.features.compile(inference_view.labels_estimates_refs()) for target, probabilities in probability_features.items(): from aligned.schemas.transformation import MapArgMax diff --git a/aligned/exposed_model/interface.py b/aligned/exposed_model/interface.py index 96cef1c2..949b446e 100644 --- a/aligned/exposed_model/interface.py +++ b/aligned/exposed_model/interface.py @@ -1,7 +1,15 @@ from __future__ import annotations import polars as pl -from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable, Callable, Coroutine +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + Awaitable, + Callable, + Coroutine, + Union, +) from dataclasses import dataclass from aligned.compiler.feature_factory import FeatureReferencable from aligned.config_value import ConfigValue, LiteralValue @@ -24,6 +32,18 @@ from aligned.exposed_model.mlflow import MlflowConfig from aligned.exposed_model.openai import OpenAiConfig +CallableModelFunction = Union[ + Callable[[pl.DataFrame, "ContractStore"], Awaitable[pl.DataFrame]], + Callable[[pl.DataFrame, "ContractStore"], Awaitable[pl.Series]], + Callable[[pl.DataFrame], Awaitable[pl.DataFrame]], + Callable[[pl.DataFrame], Awaitable[pl.Series]], + Callable[[pl.DataFrame, "ContractStore"], pl.DataFrame], + Callable[[pl.DataFrame, "ContractStore"], pl.Series], + Callable[[pl.DataFrame], pl.DataFrame], + Callable[[pl.DataFrame], pl.Series], +] + + logger = logging.getLogger(__name__) @@ -140,9 +160,9 @@ async def potential_drift_from_model(self, old_model: ExposedModel) -> str | Non raise NotImplementedError(type(self)) def _serialize(self) -> dict: - assert ( - self.model_type in PredictorFactory.shared().supported_predictors - ), f"Unknown predictor_type: {self.model_type}" + assert self.model_type in PredictorFactory.shared().supported_predictors, ( + f"Unknown predictor_type: {self.model_type}" + ) return self.to_dict() def with_shadow(self, shadow_model: ExposedModel) -> ShadowModel: @@ -163,9 +183,7 @@ def _deserialize(cls, value: dict) -> ExposedModel: @staticmethod def polars_predictor( - callable: Callable[ - [pl.DataFrame, ModelFeatureStore], Coroutine[None, None, pl.DataFrame] - ], + callable: CallableModelFunction, features: list[FeatureReferencable] | None = None, ) -> "ExposedModel": refs = [feat.feature_reference() for feat in features or []] @@ -331,20 +349,29 @@ async def run_polars( exec(self.code) function = locals()[self.function_name] - args = inspect.signature(function).parameters + sign = inspect.signature(function) + + args = sign.parameters if len(args) == 1: + out = sign.return_annotation + if self.feature_refs: - features = await store.store.features_for( - values, self.feature_refs - ).to_polars() + job = store.store.features_for(values, self.feature_refs) + features = await job.to_polars() + else: + job = store.input_features_for(values) + features = await job.to_polars() + + if (out == pl.Series) or (out == "pl.Series"): + input_features = features[job.request_result.feature_columns] else: - features = await store.input_features_for(values).to_polars() + input_features = features if inspect.iscoroutinefunction(function): - pred = await function(features) + pred = await function(input_features) else: - pred = function(features) + pred = function(input_features) else: features = await values.to_polars() if inspect.iscoroutinefunction(function): @@ -355,7 +382,8 @@ async def run_polars( if isinstance(pred, pl.DataFrame): return pred - assert isinstance(pred, pl.Series), f"Expected a Series but got {type(pred)}" + if not isinstance(pred, pl.Series): + pred = pl.Series(pred) pred_view = store.model.predictions_view pred_columns = pred_view.labels() @@ -376,10 +404,7 @@ async def run_polars( @staticmethod def from_function( - function: Callable[[pl.DataFrame], pl.Series] - | Callable[[pl.DataFrame], Awaitable[pl.Series]] - | Callable[[pl.DataFrame, ContractStore], pl.Series] - | Callable[[pl.DataFrame, ContractStore], Awaitable[pl.Series]], + function: CallableModelFunction, feature_refs: list[FeatureReference] | None = None, ) -> CodePredictor: import dill diff --git a/aligned/exposed_model/mlflow.py b/aligned/exposed_model/mlflow.py index 0cff5dde..07de7139 100644 --- a/aligned/exposed_model/mlflow.py +++ b/aligned/exposed_model/mlflow.py @@ -36,6 +36,10 @@ def mlflow_spec( return ColSpec("float", name=feature.name) elif dtype.name == "string": return ColSpec("string", name=feature.name) + elif dtype.name == "bool": + return ColSpec("boolean", name=feature.name) + elif dtype.name in ["int64"]: + return ColSpec("long", name=feature.name) elif dtype.is_numeric: return ColSpec("integer", name=feature.name) elif dtype.is_datetime: diff --git a/aligned/feature_view/feature_view.py b/aligned/feature_view/feature_view.py index 4c3036a6..549ab0b1 100644 --- a/aligned/feature_view/feature_view.py +++ b/aligned/feature_view/feature_view.py @@ -220,9 +220,10 @@ def vstack( source_column=source_column, ) - def feature_references( + def references( self, - exclude: Sequence[str] + exclude: Sequence[str | FeatureReferencable] + | FeatureFactory | Callable[[T], Sequence[FeatureReferencable] | FeatureReferencable] | None = None, include: Sequence[str] @@ -232,17 +233,36 @@ def feature_references( req = self.request if exclude is not None: + all_features = req.all_features + + if isinstance(exclude, FeatureFactory): + exclude = [exclude] + if callable(exclude): refs = exclude(self.view) if isinstance(refs, FeatureReferencable): refs = [refs] feature_names = [feat.feature_reference().name for feat in refs] else: - feature_names = list(exclude) + feature_names = [] + + for feat in exclude: + if isinstance(feat, str): + feature_names.append(feat) + elif isinstance(feat, FeatureFactory) and feat._name is None: + feature_names.extend( + [ + exfeat.name + for exfeat in all_features + if exfeat.dtype == feat.dtype + ] + ) + else: + feature_names.append(feat.feature_reference().name) return [ feat.as_reference(req.location) - for feat in req.all_features + for feat in all_features if feat.name not in feature_names ] @@ -1064,7 +1084,7 @@ class MyView: return f""" from aligned import feature_view, {all_types} -{imports or ''} +{imports or ""} @feature_view( name="{view_name}", diff --git a/aligned/proxy_api.py b/aligned/proxy_api.py index 000ab56f..c5d92872 100644 --- a/aligned/proxy_api.py +++ b/aligned/proxy_api.py @@ -173,9 +173,9 @@ async def read_entities(entities: dict[str, list[Any]]) -> list[dict[str, Any]]: ) async def read_entity(entity: dict[str, Any]) -> dict: first_key = next(iter(entity.keys())) - assert not isinstance( - entity[first_key], list - ), "Expects only one entity. Consider using the /entities request instead." + assert not isinstance(entity[first_key], list), ( + "Expects only one entity. Consider using the /entities request instead." + ) start_time = monotonic() if location.location_type == "model": @@ -202,9 +202,9 @@ def add_infer_route( assert location.location_type == "model" model = store.model(location.name) - assert ( - model.has_exposed_model() - ), f"Model '{location.name}' needs to have an exposed model to infer." + assert model.has_exposed_model(), ( + f"Model '{location.name}' needs to have an exposed model to infer." + ) route_name = location.name.replace("_", "-") @@ -264,7 +264,7 @@ def add_infer_route( ) async def infer(entities: dict[str, list[Any]]) -> dict[str, list[Any]]: output = await model.predict_over(entities).to_polars() - return output.to_dict(as_series=False) + return output.select(view_request.all_returned_columns).to_dict(as_series=False) @overload diff --git a/aligned/schemas/feature_view.py b/aligned/schemas/feature_view.py index 2604e5e7..08899470 100644 --- a/aligned/schemas/feature_view.py +++ b/aligned/schemas/feature_view.py @@ -33,7 +33,9 @@ class ViewTags: class Contact(Codable): name: str email: str | None = field(default=None) + slack_member_id: str | None = field(default=None) + discord_member_id: str | None = field(default=None) @dataclass diff --git a/aligned/sources/local.py b/aligned/sources/local.py index 08172899..f2e1e4c6 100644 --- a/aligned/sources/local.py +++ b/aligned/sources/local.py @@ -24,6 +24,7 @@ from aligned.s3.storage import FileStorage, HttpStorage from aligned.schemas.codable import Codable from aligned.schemas.feature import FeatureType, Feature +from aligned.sources.s3_config import S3Config from aligned.storage import Storage from aligned.feature_source import WritableFeatureSource from aligned.schemas.date_formatter import DateFormatter @@ -739,6 +740,9 @@ class ParquetFileSource( config: ParquetConfig = field(default_factory=ParquetConfig) date_formatter: DateFormatter = field(default_factory=lambda: DateFormatter.noop()) + s3_config: S3Config | None = field(default=None) + azure_config: AzureBlobConfig | None = field(default=None) + type_name: str = "parquet" @property @@ -759,6 +763,8 @@ def with_view(self, view: CompiledFeatureView) -> ParquetFileSource: mapping_keys=self.mapping_keys, config=self.config, date_formatter=self.date_formatter, + s3_config=self.s3_config, + azure_config=self.azure_config, ) def job_group_key(self) -> str: @@ -767,6 +773,20 @@ def job_group_key(self) -> str: def __hash__(self) -> int: return hash(self.job_group_key()) + def storage_options(self) -> dict[str, Any] | None: + if self.azure_config: + return self.azure_config.read_creds() + if self.s3_config: + return self.s3_config.storage_options() + return None + + def with_protocol(self, path: str) -> str: + if self.azure_config: + return f"adfs://{path}" + if self.s3_config: + return f"s3://{path}" + return path + async def delete(self, predicate: Expression | None = None) -> None: if not predicate: delete_path(self.path.as_posix()) @@ -778,9 +798,9 @@ async def delete(self, predicate: Expression | None = None) -> None: await self.write_polars(filtered_df) async def read_pandas(self) -> pd.DataFrame: - path = self.path.as_posix() + path = self.with_protocol(self.path.as_posix()) try: - return pd.read_parquet(path) + return pd.read_parquet(path, storage_options=self.storage_options()) except FileNotFoundError: raise UnableToFindFileException(path) except HTTPStatusError: @@ -789,26 +809,38 @@ async def read_pandas(self) -> pd.DataFrame: async def write_pandas(self, df: pd.DataFrame) -> None: create_parent_dir(self.path.as_posix()) df.to_parquet( - self.path.as_posix(), + self.with_protocol(self.path.as_posix()), engine=self.config.engine, compression=self.config.compression, index=False, + storage_options=self.storage_options(), ) async def to_lazy_polars(self) -> pl.LazyFrame: - path = self.path.as_posix() - if (not path.startswith("http")) and (not do_file_exist(path)): + path = self.with_protocol(self.path.as_posix()) + storage_options = self.storage_options() + if ( + storage_options is None + and (not path.startswith("http")) + and (not do_file_exist(path)) + ): raise UnableToFindFileException(path) try: - return pl.scan_parquet(path) + return pl.scan_parquet(path, storage_options=storage_options) except OSError: raise UnableToFindFileException(path) async def write_polars(self, df: pl.LazyFrame) -> None: - path = self.path.as_posix() - create_parent_dir(path) - df.collect().write_parquet(path, compression=self.config.compression) + path = self.with_protocol(self.path.as_posix()) + if ":" not in path: + create_parent_dir(path) + logger.info(f"Writing to {path}") + df.collect().write_parquet( + path, + compression=self.config.compression, + storage_options=self.storage_options(), + ) def all_data(self, request: RetrievalRequest, limit: int | None) -> RetrievalJob: return FileFullJob(self, request, limit, date_formatter=self.date_formatter) @@ -901,12 +933,15 @@ class DeltaFileSource( date_formatter: DateFormatter = field(default_factory=lambda: DateFormatter.noop()) azure_config: AzureBlobConfig | None = field(default=None) + s3_config: S3Config | None = field(default=None) type_name: str = "delta" def storage_options(self) -> dict[str, Any] | None: if self.azure_config: return self.azure_config.read_creds() + if self.s3_config: + return self.s3_config.storage_options() return None def resolved_path(self) -> str: diff --git a/aligned/sources/s3.py b/aligned/sources/s3.py index 3baa1611..646d1e59 100644 --- a/aligned/sources/s3.py +++ b/aligned/sources/s3.py @@ -6,7 +6,9 @@ import polars as pl from httpx import HTTPStatusError -from aligned.config_value import ConfigValue +from aligned.config_value import ( + ConfigValue, +) from aligned.lazy_imports import pandas as pd from aligned.data_source.batch_data_source import ( CodableBatchDataSource, diff --git a/aligned/sources/s3_config.py b/aligned/sources/s3_config.py new file mode 100644 index 00000000..7d58ca81 --- /dev/null +++ b/aligned/sources/s3_config.py @@ -0,0 +1,59 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import TYPE_CHECKING +from aligned.schemas.codable import Codable +from aligned.config_value import ConfigValue, EnvironmentValue, PathResolver + +if TYPE_CHECKING: + from aligned.sources.local import ParquetConfig, DateFormatter, ParquetFileSource + + +@dataclass +class S3Config(Codable): + access_key: ConfigValue = field( + default_factory=lambda: EnvironmentValue("AWS_ACCESS_KEY_ID") + ) + secret_token: ConfigValue = field( + default_factory=lambda: EnvironmentValue("AWS_SECRET_ACCESS_KEY") + ) + region: ConfigValue = field(default_factory=lambda: EnvironmentValue("AWS_REGION")) + + bucket: ConfigValue | None = field(default=None) + endpoint: ConfigValue | None = field(default=None) + + def storage_options(self) -> dict[str, str]: + vals = { + "aws_access_key_id": self.access_key.read(), + "aws_secret_access_key": self.secret_token.read(), + "aws_region": self.region.read(), + } + if self.bucket: + vals["aws_bucket"] = self.bucket.read() + if self.endpoint: + vals["aws_endpoint"] = self.endpoint.read() + return vals + + def parquet_at( + self, + path: str, + mapping_keys: dict[str, str] | None = None, + config: ParquetConfig | None = None, + date_formatter: DateFormatter | None = None, + ) -> ParquetFileSource: + from aligned.sources.local import ( + ParquetFileSource, + ParquetConfig, + DateFormatter, + ) + + if "/" not in path: + # Need to add a base folder if not existing + path = f"default/{path}" + + return ParquetFileSource( + PathResolver.from_value(path), + mapping_keys=mapping_keys or {}, + config=config or ParquetConfig(), + date_formatter=date_formatter or DateFormatter.noop(), + s3_config=self, + )