Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions aligned/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import click
import json
from pytz import utc # type: ignore
from pytz import utc

from aligned.compiler.repo_reader import RepoReader, RepoReference
from aligned.feature_store import ContractStore
Expand Down Expand Up @@ -125,40 +125,58 @@ def cli() -> None:
@click.option("--n-workers", default=1)
@click.option("--expose-tag")
@click.option("--log-level", default="info")
@click.option("--run-as-module", default=True)
async def start_proxy_server(
contracts: str,
host: str,
port: int,
n_workers: int,
log_level: str,
expose_tag: str | None,
run_as_module: bool,
) -> None:
"""
Starts a server that exposes the data that exists in the contracts.
"""
from fastapi import FastAPI
import uvicorn
from uvicorn.main import LOG_LEVELS # type: ignore
import sys

if run_as_module:
cwd = Path().cwd().as_posix()
if cwd not in sys.path:
sys.path.append(cwd)

@asynccontextmanager
async def lifespan(app: FastAPI):
import logging
from aligned.proxy_api import router_for_store

logging.basicConfig(level=logging.INFO)
logging.basicConfig(
level=LOG_LEVELS[log_level.lower()],
# format="%(asctime)s | %(levelname)-8s | %(module)s:%(funcName)s:%(lineno)d - %(message)s"
)

logger = logging.getLogger(__name__)

logger.info(f"Loading contract from: '{contracts}'")

if contracts.startswith("http"):
store = await AlignedCloudSource(contracts).as_contract_store()
else:
store = await store_from_reference(contracts)
try:
store = await store_from_reference(contracts)
except Exception:
logger.info(f"Unable to read '{contracts}'. Will try to compile")
store = await ContractStore.from_dir()

router_for_store(store, expose_tag=expose_tag, app=app)
yield

app = FastAPI(lifespan=lifespan)

config = uvicorn.Config(
app, host=host, port=port, workers=n_workers, log_level=log_level
app, host=host, port=port, workers=n_workers, log_level=log_level.lower()
)
server = uvicorn.Server(config)
await server.serve()
Expand Down
49 changes: 34 additions & 15 deletions aligned/compiler/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
FeatureViewMetadata,
FeatureViewWrapper,
)
from aligned.exposed_model.interface import ExposedModel
from aligned.exposed_model.interface import CallableModelFunction, ExposedModel
from aligned.retrieval_job import ConvertableToRetrievalJob, PredictionJob, RetrievalJob
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import (
Expand Down Expand Up @@ -363,11 +363,17 @@ class FeatureInputVersions:
default_version: str
versions: dict[str, list[FeatureReferencable]]

def compile(self) -> FeatureVersionSchema:
def compile(
self, labels: set[FeatureReference] | None = None
) -> FeatureVersionSchema:
return FeatureVersionSchema(
default_version=self.default_version,
versions={
version: [feature.feature_reference() for feature in features]
version: [
feature.feature_reference()
for feature in features
if feature.feature_reference() not in (labels or set())
]
for version, features in self.versions.items()
},
)
Expand All @@ -380,7 +386,8 @@ def model_contract(
| ModelContractWrapper
| Sequence[FeatureReferencable]
]
| FeatureInputVersions,
| FeatureInputVersions
| None = None,
name: str | None = None,
contacts: list[Contact] | list[str] | None = None,
tags: list[str] | None = None,
Expand All @@ -390,19 +397,24 @@ def model_contract(
application_source: CodableBatchDataSource | None = None,
dataset_store: DatasetStore | StorageFileReference | None = None,
exposed_at_url: str | None = None,
exposed_model: ExposedModel | None = None,
exposed_model: ExposedModel | CallableModelFunction | None = None,
acceptable_freshness: timedelta | None = None,
unacceptable_freshness: timedelta | None = None,
) -> Callable[[Type[T]], ModelContractWrapper[T]]:
def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
from aligned.sources.renamer import camel_to_snake_case

if input_features is None and output_source is None:
source = RandomDataSource(fill_mode="random_samples")
else:
source = output_source

if isinstance(input_features, FeatureInputVersions):
features_versions = input_features
else:
unwrapped_input_features: list[FeatureReferencable] = []

for feature in input_features:
for feature in input_features or []:
if isinstance(feature, FeatureViewWrapper):
compiled_view = feature.compile()
request = compiled_view.request_all
Expand Down Expand Up @@ -439,9 +451,14 @@ def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
elif cls.__doc__:
used_description = str(cls.__doc__)

if exposed_model is not None and not isinstance(exposed_model, ExposedModel):
model = ExposedModel.polars_predictor(exposed_model)
else:
model = exposed_model

used_exposed_at_url = exposed_at_url
if exposed_model:
used_exposed_at_url = exposed_model.exposed_at_url or exposed_at_url
if model:
used_exposed_at_url = model.exposed_at_url or exposed_at_url

conts = [
Contact(cont) if isinstance(cont, str) else cont for cont in contacts or []
Expand All @@ -453,14 +470,14 @@ def decorator(cls: Type[T]) -> ModelContractWrapper[T]:
contacts=conts,
tags=tags,
description=used_description,
output_source=output_source,
output_source=source,
output_stream=output_stream,
application_source=application_source,
dataset_store=resolve_dataset_store(dataset_store)
if dataset_store
else None,
exposed_at_url=used_exposed_at_url,
exposed_model=exposed_model,
exposed_model=model,
acceptable_freshness=acceptable_freshness,
unacceptable_freshness=unacceptable_freshness,
)
Expand Down Expand Up @@ -514,7 +531,9 @@ class MyModel(ModelContract):
for var_name in var_names:
feature = getattr(model, var_name)
if isinstance(feature, FeatureFactory):
assert feature._name, f"Expected name but found none in model: {metadata.name} for feature {var_name}"
assert feature._name, (
f"Expected name but found none in model: {metadata.name} for feature {var_name}"
)
feature._location = FeatureLocation.model(metadata.name)

if isinstance(feature, FeatureView):
Expand Down Expand Up @@ -547,9 +566,9 @@ class MyModel(ModelContract):
feature_name = feature.target._name
assert feature_name
assert feature._name
assert (
feature.target._name in classification_targets
), "Target must be a classification target."
assert feature.target._name in classification_targets, (
"Target must be a classification target."
)

target = classification_targets[feature.target._name]
target.class_probabilities.add(feature.compile())
Expand Down Expand Up @@ -632,7 +651,7 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
inference_view.features.add(feature.feature())

# Needs to run after the feature views have compiled
features = metadata.features.compile()
features = metadata.features.compile(inference_view.labels_estimates_refs())

for target, probabilities in probability_features.items():
from aligned.schemas.transformation import MapArgMax
Expand Down
63 changes: 44 additions & 19 deletions aligned/exposed_model/interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations

import polars as pl
from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable, Callable, Coroutine
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
Awaitable,
Callable,
Coroutine,
Union,
)
from dataclasses import dataclass
from aligned.compiler.feature_factory import FeatureReferencable
from aligned.config_value import ConfigValue, LiteralValue
Expand All @@ -24,6 +32,18 @@
from aligned.exposed_model.mlflow import MlflowConfig
from aligned.exposed_model.openai import OpenAiConfig

CallableModelFunction = Union[
Callable[[pl.DataFrame, "ContractStore"], Awaitable[pl.DataFrame]],
Callable[[pl.DataFrame, "ContractStore"], Awaitable[pl.Series]],
Callable[[pl.DataFrame], Awaitable[pl.DataFrame]],
Callable[[pl.DataFrame], Awaitable[pl.Series]],
Callable[[pl.DataFrame, "ContractStore"], pl.DataFrame],
Callable[[pl.DataFrame, "ContractStore"], pl.Series],
Callable[[pl.DataFrame], pl.DataFrame],
Callable[[pl.DataFrame], pl.Series],
]


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -140,9 +160,9 @@ async def potential_drift_from_model(self, old_model: ExposedModel) -> str | Non
raise NotImplementedError(type(self))

def _serialize(self) -> dict:
assert (
self.model_type in PredictorFactory.shared().supported_predictors
), f"Unknown predictor_type: {self.model_type}"
assert self.model_type in PredictorFactory.shared().supported_predictors, (
f"Unknown predictor_type: {self.model_type}"
)
return self.to_dict()

def with_shadow(self, shadow_model: ExposedModel) -> ShadowModel:
Expand All @@ -163,9 +183,7 @@ def _deserialize(cls, value: dict) -> ExposedModel:

@staticmethod
def polars_predictor(
callable: Callable[
[pl.DataFrame, ModelFeatureStore], Coroutine[None, None, pl.DataFrame]
],
callable: CallableModelFunction,
features: list[FeatureReferencable] | None = None,
) -> "ExposedModel":
refs = [feat.feature_reference() for feat in features or []]
Expand Down Expand Up @@ -331,20 +349,29 @@ async def run_polars(
exec(self.code)

function = locals()[self.function_name]
args = inspect.signature(function).parameters
sign = inspect.signature(function)

args = sign.parameters

if len(args) == 1:
out = sign.return_annotation

if self.feature_refs:
features = await store.store.features_for(
values, self.feature_refs
).to_polars()
job = store.store.features_for(values, self.feature_refs)
features = await job.to_polars()
else:
job = store.input_features_for(values)
features = await job.to_polars()

if (out == pl.Series) or (out == "pl.Series"):
input_features = features[job.request_result.feature_columns]
else:
features = await store.input_features_for(values).to_polars()
input_features = features

if inspect.iscoroutinefunction(function):
pred = await function(features)
pred = await function(input_features)
else:
pred = function(features)
pred = function(input_features)
else:
features = await values.to_polars()
if inspect.iscoroutinefunction(function):
Expand All @@ -355,7 +382,8 @@ async def run_polars(
if isinstance(pred, pl.DataFrame):
return pred

assert isinstance(pred, pl.Series), f"Expected a Series but got {type(pred)}"
if not isinstance(pred, pl.Series):
pred = pl.Series(pred)

pred_view = store.model.predictions_view
pred_columns = pred_view.labels()
Expand All @@ -376,10 +404,7 @@ async def run_polars(

@staticmethod
def from_function(
function: Callable[[pl.DataFrame], pl.Series]
| Callable[[pl.DataFrame], Awaitable[pl.Series]]
| Callable[[pl.DataFrame, ContractStore], pl.Series]
| Callable[[pl.DataFrame, ContractStore], Awaitable[pl.Series]],
function: CallableModelFunction,
feature_refs: list[FeatureReference] | None = None,
) -> CodePredictor:
import dill
Expand Down
4 changes: 4 additions & 0 deletions aligned/exposed_model/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def mlflow_spec(
return ColSpec("float", name=feature.name)
elif dtype.name == "string":
return ColSpec("string", name=feature.name)
elif dtype.name == "bool":
return ColSpec("boolean", name=feature.name)
elif dtype.name in ["int64"]:
return ColSpec("long", name=feature.name)
elif dtype.is_numeric:
return ColSpec("integer", name=feature.name)
elif dtype.is_datetime:
Expand Down
30 changes: 25 additions & 5 deletions aligned/feature_view/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,10 @@ def vstack(
source_column=source_column,
)

def feature_references(
def references(
self,
exclude: Sequence[str]
exclude: Sequence[str | FeatureReferencable]
| FeatureFactory
| Callable[[T], Sequence[FeatureReferencable] | FeatureReferencable]
| None = None,
include: Sequence[str]
Expand All @@ -232,17 +233,36 @@ def feature_references(
req = self.request

if exclude is not None:
all_features = req.all_features

if isinstance(exclude, FeatureFactory):
exclude = [exclude]

if callable(exclude):
refs = exclude(self.view)
if isinstance(refs, FeatureReferencable):
refs = [refs]
feature_names = [feat.feature_reference().name for feat in refs]
else:
feature_names = list(exclude)
feature_names = []

for feat in exclude:
if isinstance(feat, str):
feature_names.append(feat)
elif isinstance(feat, FeatureFactory) and feat._name is None:
feature_names.extend(
[
exfeat.name
for exfeat in all_features
if exfeat.dtype == feat.dtype
]
)
else:
feature_names.append(feat.feature_reference().name)

return [
feat.as_reference(req.location)
for feat in req.all_features
for feat in all_features
if feat.name not in feature_names
]

Expand Down Expand Up @@ -1064,7 +1084,7 @@ class MyView:

return f"""
from aligned import feature_view, {all_types}
{imports or ''}
{imports or ""}

@feature_view(
name="{view_name}",
Expand Down
Loading
Loading