from typing import Annotated, Dict, List, Literal, Union
import pandas as pd
from diffprivlib.validation import DiffprivlibMixin
from pydantic import (
BaseModel,
ConfigDict,
Discriminator,
PlainSerializer,
PlainValidator,
ValidationInfo,
field_validator,
)
from snsynth import Synthesizer
from lomas_core.constants import DPLibraries
from lomas_core.models.utils import (
dataframe_from_dict,
dataframe_to_dict,
deserialize_model,
serialize_model,
)
[docs]
class ResponseModel(BaseModel):
"""Base model for any response from the server."""
[docs]
class InitialBudgetResponse(ResponseModel):
"""Model for responses to initial budget queries."""
initial_epsilon: float
"""The initial epsilon privacy loss budget."""
initial_delta: float
"""The initial delta privacy loss budget."""
[docs]
class SpentBudgetResponse(ResponseModel):
"""Model for responses to spent budget queries."""
total_spent_epsilon: float
"""The total spent epsilon privacy loss budget."""
total_spent_delta: float
"""The total spent delta privacy loss budget."""
[docs]
class RemainingBudgetResponse(ResponseModel):
"""Model for responses to remaining budget queries."""
remaining_epsilon: float
"""The remaining epsilon privacy loss budget."""
remaining_delta: float
"""The remaining delta privacy loss budget."""
[docs]
class DummyDsResponse(ResponseModel):
"""Model for responses to dummy dataset requests."""
model_config = ConfigDict(arbitrary_types_allowed=True)
dtypes: Dict[str, str]
"""The dummy_df column data types."""
datetime_columns: List[str]
"""The list of columns with datetime type."""
dummy_df: Annotated[pd.DataFrame, PlainSerializer(dataframe_to_dict)]
"""The dummy dataframe."""
[docs]
@field_validator("dummy_df", mode="before")
@classmethod
def deserialize_dummy_df(cls, v: pd.DataFrame | dict, info: ValidationInfo) -> pd.DataFrame:
"""Decodes the dict representation of the dummy df with correct types.
Only does so if the input value is not already a dataframe.
Args:
v (pd.DataFrame | dict): The dataframe to decode.
info (ValidationInfo): Validation info to access other model fields.
Returns:
pd.DataFrame: The decoded dataframe.
"""
if isinstance(v, pd.DataFrame):
return v
dtypes = info.data["dtypes"]
datetime_columns = info.data["datetime_columns"]
dummy_df = dataframe_from_dict(v)
dummy_df = dummy_df.astype(dtypes)
for col in datetime_columns:
dummy_df[col] = pd.to_datetime(dummy_df[col])
return dummy_df
[docs]
class CostResponse(ResponseModel):
"""Model for responses to cost estimation requests or queries."""
model_config = ConfigDict(use_attribute_docstrings=True)
epsilon: float
"""The epsilon cost of the query."""
delta: float
"""The delta cost of the query."""
# Query Responses
# -----------------------------------------------------------------------------
# DiffPrivLib
[docs]
class DiffPrivLibQueryResult(BaseModel):
"""Model for diffprivlib query result."""
model_config = ConfigDict(arbitrary_types_allowed=True)
res_type: Literal[DPLibraries.DIFFPRIVLIB] = DPLibraries.DIFFPRIVLIB
"""Result type description."""
score: float
"""The trained model score."""
model: Annotated[
DiffprivlibMixin,
PlainSerializer(serialize_model),
PlainValidator(deserialize_model),
]
"""The trained model."""
# SmartnoiseSQL
[docs]
class SmartnoiseSQLQueryResult(BaseModel):
"""Type for smartnoise_sql result type."""
model_config = ConfigDict(arbitrary_types_allowed=True)
res_type: Literal[DPLibraries.SMARTNOISE_SQL] = DPLibraries.SMARTNOISE_SQL
"""Result type description."""
df: Annotated[
pd.DataFrame,
PlainSerializer(dataframe_to_dict),
PlainValidator(dataframe_from_dict),
]
"""Dataframe containing the query result."""
# SmartnoiseSynth
[docs]
class SmartnoiseSynthModel(BaseModel):
"""Type for smartnoise_synth result when it is a pickled model."""
model_config = ConfigDict(arbitrary_types_allowed=True)
res_type: Literal[DPLibraries.SMARTNOISE_SYNTH] = DPLibraries.SMARTNOISE_SYNTH
"""Result type description."""
model: Annotated[Synthesizer, PlainSerializer(serialize_model), PlainValidator(deserialize_model)]
"""Synthetic data generator model."""
[docs]
class SmartnoiseSynthSamples(BaseModel):
"""Type for smartnoise_synth result when it is a dataframe of samples."""
model_config = ConfigDict(arbitrary_types_allowed=True)
res_type: Literal["sn_synth_samples"] = "sn_synth_samples"
"""Result type description."""
df_samples: Annotated[
pd.DataFrame,
PlainSerializer(dataframe_to_dict),
PlainValidator(dataframe_from_dict),
]
"""Dataframe containing the generated synthetic samples."""
# OpenDP
[docs]
class OpenDPQueryResult(BaseModel):
"""Type for opendp result."""
res_type: Literal[DPLibraries.OPENDP] = DPLibraries.OPENDP
"""Result type description."""
value: Union[int, float, List[Union[int, float]]]
"""The result value of the query."""
# Response object
QueryResultTypeAlias = Union[
DiffPrivLibQueryResult,
SmartnoiseSQLQueryResult,
SmartnoiseSynthModel,
SmartnoiseSynthSamples,
OpenDPQueryResult,
]
[docs]
class QueryResponse(CostResponse):
"""Response to Lomas queries."""
requested_by: str
"""The user that triggered the query."""
result: Annotated[
QueryResultTypeAlias,
Discriminator("res_type"),
]
"""The query result object."""