Source code for lomas_core.models.collections

from datetime import datetime
from typing import Annotated, Any, Literal, Self

from pydantic import BaseModel, Discriminator, Field, Tag, model_validator

from lomas_core.models.constants import (
    CARDINALITY_FIELD,
    CATEGORICAL_TYPE_PREFIX,
    DB_TYPE_FIELD,
    TYPE_FIELD,
    MetadataColumnType,
    Precision,
    PrivateDatabaseType,
)

# Dataset of User
# -----------------------------------------------------------------------------


class DatasetOfUser(BaseModel):
    """BaseModel for the information of a user on a single dataset.

    Records the initial epsilon/delta values granted to the user on the
    dataset and the totals spent so far (presumably the user's differential
    privacy budget — confirm against callers).
    """

    # Name of the dataset this record refers to.
    dataset_name: str
    # Values the user started with on this dataset.
    initial_epsilon: float
    initial_delta: float
    # Cumulative amounts consumed so far.
    total_spent_epsilon: float
    total_spent_delta: float
# User
# -----------------------------------------------------------------------------
class UserId(BaseModel):
    """BaseModel identifying a user by name and email (plus optional secret)."""

    name: str
    email: str
    # exclude=True drops this field from model_dump()/model_dump_json()
    # output, so the secret is never serialized (security reasons).
    client_secret: str | None = Field(default=None, exclude=True)
class User(BaseModel):
    """BaseModel describing one entry of a user collection."""

    # Identity of the user.
    id: UserId
    # Whether this user is allowed to send queries.
    may_query: bool
    # One record per dataset associated with this user.
    datasets_list: list[DatasetOfUser]
class UserCollection(BaseModel):
    """BaseModel wrapping the full list of users."""

    users: list[User]
# Dataset Access Data
# -----------------------------------------------------------------------------
class DSAccess(BaseModel):
    """Common base model for access information to a private dataset.

    Concrete subclasses narrow ``database_type`` to a single ``Literal`` value.
    """

    database_type: str
class DSPathAccess(DSAccess):
    """Access information for a local dataset reachable through a path."""

    database_type: Literal[PrivateDatabaseType.PATH]
    # Path to the dataset.
    path: str
class DSS3Access(DSAccess):
    """BaseModel for a dataset on S3.

    ``access_key_id`` and ``secret_access_key`` are optional; the model also
    carries a ``credentials_name``.
    """

    database_type: Literal[PrivateDatabaseType.S3]
    endpoint_url: str
    bucket: str
    key: str
    access_key_id: str | None = None
    # repr=False keeps the secret out of repr()/str() of the model, so it does
    # not leak into logs or tracebacks that print the object. Serialization is
    # deliberately left unchanged.
    secret_access_key: Annotated[str | None, Field(default=None, repr=False)]
    credentials_name: str
class DSInfo(BaseModel):
    """BaseModel describing a dataset: its name and where its data and metadata live."""

    dataset_name: str
    # Tagged unions: the value found under DB_TYPE_FIELD picks path or S3 access.
    dataset_access: Annotated[
        DSPathAccess | DSS3Access,
        Field(discriminator=DB_TYPE_FIELD),
    ]
    metadata_access: Annotated[
        DSPathAccess | DSS3Access,
        Field(discriminator=DB_TYPE_FIELD),
    ]
class DatasetsCollection(BaseModel):
    """BaseModel wrapping the full list of dataset descriptions."""

    datasets: list[DSInfo]
# Metadata
# -----------------------------------------------------------------------------
class ColumnMetadata(BaseModel):
    """Base model shared by every column metadata type."""

    private_id: bool = False
    # Must lie in [0, 1). See issue #323 for checking this and validating.
    nullable_proportion: Annotated[float, Field(ge=0, lt=1)] = 0.0
    # Optional limits; when given they must be strictly positive.
    max_partition_length: Annotated[int, Field(gt=0)] | None = None
    max_influenced_partitions: Annotated[int, Field(gt=0)] | None = None
    max_partition_contributions: Annotated[int, Field(gt=0)] | None = None
class StrMetadata(ColumnMetadata):
    """Metadata model for a (non-categorical) string column."""

    type: Literal[MetadataColumnType.STRING]
class CategoricalColumnMetadata(ColumnMetadata):
    """Shared base model for categorical column metadata.

    Subclasses are expected to declare ``cardinality`` and ``categories``;
    the validator below reads both attributes.
    """

    @model_validator(mode="after")
    def validate_categories(self) -> Self:
        """Check that the number of listed categories equals ``cardinality``.

        Returns:
            Self: The validated model, unchanged.

        Raises:
            ValueError: If ``len(categories)`` differs from ``cardinality``.
        """
        category_count = len(self.categories)
        if category_count == self.cardinality:
            return self
        raise ValueError("Number of categories should be equal to cardinality.")
class StrCategoricalMetadata(CategoricalColumnMetadata):
    """Metadata model for a categorical string column."""

    # Keep the type as plain string (NOT categorical_string): other code,
    # e.g. the data_connector building a pandas dataframe, relies on it.
    type: Literal[MetadataColumnType.STRING]
    cardinality: int
    categories: list[str]
class BoundedColumnMetadata(ColumnMetadata):
    """Model for columns with bounded data.

    Subclasses are expected to declare ``lower`` and ``upper`` fields; the
    validator below reads both attributes.
    """

    @model_validator(mode="after")
    def validate_bounds(self) -> Self:
        """Validates column bounds.

        Returns:
            Self: The validated model, unchanged.

        Raises:
            ValueError: If ``lower`` is larger than ``upper``, or if the two
                bounds cannot be compared with each other (for example a
                timezone-naive and a timezone-aware datetime).
        """
        lower, upper = self.lower, self.upper
        if lower is None or upper is None:
            return self
        try:
            inverted = lower > upper
        except TypeError as err:
            # Pydantic only converts ValueError/AssertionError raised inside
            # validators into a ValidationError; a raw TypeError would escape
            # model construction unwrapped.
            raise ValueError("Lower and upper bounds cannot be compared.") from err
        if inverted:
            raise ValueError("Lower bound cannot be larger than upper bound.")
        return self
class IntMetadata(BoundedColumnMetadata):
    """Metadata model for a bounded integer column."""

    type: Literal[MetadataColumnType.INT]
    precision: Precision
    # Bounds; the inherited validator rejects lower > upper.
    lower: int
    upper: int
class IntCategoricalMetadata(CategoricalColumnMetadata):
    """Metadata model for a categorical integer column."""

    # Keep the type as plain int (NOT categorical_int): other code, e.g. the
    # data_connector building a pandas dataframe, relies on it.
    type: Literal[MetadataColumnType.INT]
    precision: Precision
    cardinality: int
    categories: list[int]
class FloatMetadata(BoundedColumnMetadata):
    """Metadata model for a bounded float column."""

    type: Literal[MetadataColumnType.FLOAT]
    precision: Precision
    # Bounds; the inherited validator rejects lower > upper.
    lower: float
    upper: float
    int_with_nulls: bool = False
class BooleanMetadata(ColumnMetadata):
    """Metadata model for a boolean column."""

    type: Literal[MetadataColumnType.BOOLEAN]
class DatetimeMetadata(BoundedColumnMetadata):
    """Metadata model for a bounded datetime column."""

    type: Literal[MetadataColumnType.DATETIME]
    # Bounds; the inherited validator rejects lower > upper.
    lower: datetime
    upper: datetime
def get_column_metadata_discriminator(v: Any) -> str:
    """Discriminator function for determining the type of column metadata.

    String and integer columns that carry a cardinality field are reported
    with the categorical prefix, so they map onto the categorical models.

    Args:
        v (Any): The unparsed column metadata (either dict or class object)

    Raises:
        ValueError: If the column type cannot be found.

    Returns:
        str: The metadata string type.
    """
    if isinstance(v, dict):
        col_type = v.get(TYPE_FIELD)
        has_cardinality = CARDINALITY_FIELD in v
    else:
        # Default to None so a missing attribute ends in the documented
        # ValueError below rather than an AttributeError.
        col_type = getattr(v, TYPE_FIELD, None)
        has_cardinality = hasattr(v, CARDINALITY_FIELD)

    if has_cardinality and col_type in (
        MetadataColumnType.STRING,
        MetadataColumnType.INT,
    ):
        col_type = f"{CATEGORICAL_TYPE_PREFIX}{col_type}"

    if not isinstance(col_type, str):
        raise ValueError("Could not find column type.")
    return col_type
class Metadata(BaseModel):
    """
    BaseModel for a metadata format.

    See Smartnoise-SQL documentation https://docs.smartnoise.org/sql/metadata.html
    """

    max_ids: Annotated[int, Field(gt=0)]
    rows: Annotated[int, Field(gt=0)]
    row_privacy: bool
    censor_dims: bool = True
    clamp_counts: bool = True
    clamp_columns: bool = True
    use_dpsu: bool = False

    # Column parsing: pydantic first passes the raw column data to
    # get_column_metadata_discriminator, then builds the model whose Tag
    # matches the returned string. The discriminator is what tells int apart
    # from categorical_int (and string from categorical_string); the built
    # int/string models themselves keep a plain int/string type.
    columns: dict[
        str,
        Annotated[
            Annotated[StrMetadata, Tag(MetadataColumnType.STRING)]
            | Annotated[StrCategoricalMetadata, Tag(MetadataColumnType.CAT_STRING)]
            | Annotated[IntMetadata, Tag(MetadataColumnType.INT)]
            | Annotated[IntCategoricalMetadata, Tag(MetadataColumnType.CAT_INT)]
            | Annotated[FloatMetadata, Tag(MetadataColumnType.FLOAT)]
            | Annotated[BooleanMetadata, Tag(MetadataColumnType.BOOLEAN)]
            | Annotated[DatetimeMetadata, Tag(MetadataColumnType.DATETIME)],
            Discriminator(get_column_metadata_discriminator),
        ],
    ]