# Source code for lomas_core.models.collections

from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Discriminator, Field, Tag, model_validator

from lomas_core.models.constants import (
    CARDINALITY_FIELD,
    CATEGORICAL_TYPE_PREFIX,
    DB_TYPE_FIELD,
    TYPE_FIELD,
    MetadataColumnType,
    Precision,
    PrivateDatabaseType,
)

# Dataset of User
# -----------------------------------------------------------------------------


class DatasetOfUser(BaseModel):
    """BaseModel for information about a user on a dataset."""

    # Name of the dataset this entry refers to.
    dataset_name: str
    # Epsilon / delta values initially allocated to the user for this dataset.
    initial_epsilon: float
    initial_delta: float
    # Cumulative epsilon / delta the user has spent so far on this dataset.
    total_spent_epsilon: float
    total_spent_delta: float
# User
# -----------------------------------------------------------------------------
class User(BaseModel):
    """BaseModel for a user in a user collection."""

    # Unique-looking identifier of the user (uniqueness not enforced here).
    user_name: str
    # Flag controlling whether the user is allowed to issue queries.
    may_query: bool
    # Per-dataset entries for this user (see DatasetOfUser).
    datasets_list: List[DatasetOfUser]
class UserCollection(BaseModel):
    """BaseModel for users collection."""

    # Every user in the collection, each validated as a User model.
    users: List[User]
# Dataset Access Data
# -----------------------------------------------------------------------------
class DSAccess(BaseModel):
    """BaseModel for access info to a private dataset."""

    # Kind of storage backend; subclasses narrow this to a single Literal
    # value so it can serve as the union discriminator in DSInfo.
    database_type: str
class DSPathAccess(DSAccess):
    """BaseModel for a local dataset."""

    # Narrowed discriminator value; the ignore silences the type checker's
    # complaint about overriding `str` with a Literal in a subclass.
    database_type: Literal[PrivateDatabaseType.PATH]  # type: ignore
    # Location of the dataset (presumably a filesystem path or URL — confirm).
    path: str
class DSS3Access(DSAccess):
    """BaseModel for a dataset on S3."""

    # Narrowed discriminator value (see DSPathAccess for the ignore rationale).
    database_type: Literal[PrivateDatabaseType.S3]  # type: ignore
    # Object location: endpoint, bucket and object key.
    endpoint_url: str
    bucket: str
    key: str
    # Inline credentials are optional and default to None.
    access_key_id: Optional[str] = None
    secret_access_key: Optional[str] = None
    # Required name of a credentials entry; NOTE(review): presumably used to
    # look up credentials elsewhere when the inline keys are absent — confirm.
    credentials_name: str
class DSInfo(BaseModel):
    """BaseModel for a dataset."""

    dataset_name: str
    # Both access descriptors are discriminated on the DB_TYPE_FIELD key, so
    # pydantic picks DSPathAccess or DSS3Access from that single field.
    dataset_access: Annotated[
        Union[DSPathAccess, DSS3Access],
        Field(discriminator=DB_TYPE_FIELD),
    ]
    metadata_access: Annotated[
        Union[DSPathAccess, DSS3Access],
        Field(discriminator=DB_TYPE_FIELD),
    ]
class DatasetsCollection(BaseModel):
    """BaseModel for datasets collection."""

    # Every dataset in the collection, each validated as a DSInfo model.
    datasets: List[DSInfo]
# Metadata
# -----------------------------------------------------------------------------
class ColumnMetadata(BaseModel):
    """Base model for column metadata."""

    # Flags shared by every column kind; both default to False.
    private_id: bool = False
    nullable: bool = False
    # See issue #323 for checking this and validating.
    # Optional contribution limits; when given, each must be strictly > 0.
    max_partition_length: Optional[Annotated[int, Field(gt=0)]] = None
    max_influenced_partitions: Optional[Annotated[int, Field(gt=0)]] = None
    max_partition_contributions: Optional[Annotated[int, Field(gt=0)]] = None
class StrMetadata(ColumnMetadata):
    """Model for string metadata."""

    # Fixed type tag for plain (non-categorical) string columns.
    type: Literal[MetadataColumnType.STRING]
class CategoricalColumnMetadata(ColumnMetadata):
    """Model for categorical column metadata."""

    # This base declares no fields of its own: subclasses must define
    # `categories` and `cardinality`, which the validator below reads.

    @model_validator(mode="after")
    def validate_categories(self):
        """Check that the declared cardinality equals the category count.

        Raises:
            ValueError: If ``len(categories)`` differs from ``cardinality``.

        Returns:
            The validated model instance, unchanged.
        """
        category_count = len(self.categories)
        if category_count == self.cardinality:
            return self
        raise ValueError("Number of categories should be equal to cardinality.")
class StrCategoricalMetadata(CategoricalColumnMetadata):
    """Model for categorical string metadata."""

    type: Literal[MetadataColumnType.STRING]
    # Must equal len(categories) (enforced by validate_categories).
    cardinality: int
    categories: List[str]
class BoundedColumnMetadata(ColumnMetadata):
    """Model for columns with bounded data."""

    # This base declares no fields of its own: subclasses must define
    # `lower` and `upper`, which the validator below reads.

    @model_validator(mode="after")
    def validate_bounds(self):
        """Reject models whose lower bound exceeds their upper bound.

        The comparison is skipped when either bound is None.

        Raises:
            ValueError: If both bounds are set and ``lower > upper``.

        Returns:
            The validated model instance, unchanged.
        """
        low = self.lower
        if low is None:
            return self
        high = self.upper
        if high is None:
            return self
        if low > high:
            raise ValueError("Lower bound cannot be larger than upper bound.")
        return self
class IntMetadata(BoundedColumnMetadata):
    """Model for integer column metadata."""

    type: Literal[MetadataColumnType.INT]
    precision: Precision
    # Inclusive-looking range; only lower <= upper is enforced here.
    lower: int
    upper: int
class IntCategoricalMetadata(CategoricalColumnMetadata):
    """Model for integer categorical column metadata."""

    type: Literal[MetadataColumnType.INT]
    precision: Precision
    # Must equal len(categories) (enforced by validate_categories).
    cardinality: int
    categories: List[int]
class FloatMetadata(BoundedColumnMetadata):
    """Model for float column metadata."""

    type: Literal[MetadataColumnType.FLOAT]
    precision: Precision
    # Only lower <= upper is enforced (by validate_bounds).
    lower: float
    upper: float
class BooleanMetadata(ColumnMetadata):
    """Model for boolean column metadata."""

    # Fixed type tag; boolean columns carry no bounds or categories.
    type: Literal[MetadataColumnType.BOOLEAN]
class DatetimeMetadata(BoundedColumnMetadata):
    """Model for datetime column metadata."""

    type: Literal[MetadataColumnType.DATETIME]
    # Only lower <= upper is enforced (by validate_bounds); NOTE(review):
    # comparing a naive with an aware datetime would raise TypeError there.
    lower: datetime
    upper: datetime
def get_column_metadata_discriminator(v: Any) -> str:
    """Discriminator function for determining the type of column metadata.

    Reads the type field from ``v`` and, for string and integer columns that
    also carry a cardinality field, prepends ``CATEGORICAL_TYPE_PREFIX`` so
    the categorical variant is selected.

    Args:
        v (Any): The unparsed column metadata (either a dict or a model
            instance).

    Raises:
        ValueError: If the column type cannot be found, i.e. ``v`` has no
            type entry/attribute or its value is not a string.

    Returns:
        str: The metadata string type.
    """
    is_dict = isinstance(v, dict)
    if is_dict:
        col_type = v.get(TYPE_FIELD)
    else:
        # Default to None so an object without the type attribute reaches the
        # documented ValueError below instead of escaping as AttributeError.
        col_type = getattr(v, TYPE_FIELD, None)

    if col_type in (MetadataColumnType.STRING, MetadataColumnType.INT):
        if is_dict:
            has_cardinality = CARDINALITY_FIELD in v
        else:
            has_cardinality = hasattr(v, CARDINALITY_FIELD)
        if has_cardinality:
            # NOTE(review): for model instances col_type is an enum member;
            # if MetadataColumnType is a (str, Enum) mixin rather than a
            # StrEnum, f-string formatting of members changed in Python 3.12 —
            # confirm this still yields the CAT_* tag value.
            col_type = f"{CATEGORICAL_TYPE_PREFIX}{col_type}"

    if not isinstance(col_type, str):
        raise ValueError("Could not find column type.")
    return col_type
class Metadata(BaseModel):
    """BaseModel for a metadata format."""

    # Both counts must be strictly positive (Field(gt=0)).
    max_ids: Annotated[int, Field(gt=0)]
    rows: Annotated[int, Field(gt=0)]
    row_privacy: bool
    censor_dims: Optional[bool] = False
    # Column name -> column metadata. Each value is parsed into exactly one of
    # the tagged models below; the tag is computed by
    # get_column_metadata_discriminator (string/int columns that carry a
    # cardinality get the categorical prefix).
    columns: Dict[
        str,
        Annotated[
            Union[
                Annotated[StrMetadata, Tag(MetadataColumnType.STRING)],
                Annotated[StrCategoricalMetadata, Tag(MetadataColumnType.CAT_STRING)],
                Annotated[IntMetadata, Tag(MetadataColumnType.INT)],
                Annotated[IntCategoricalMetadata, Tag(MetadataColumnType.CAT_INT)],
                Annotated[FloatMetadata, Tag(MetadataColumnType.FLOAT)],
                Annotated[BooleanMetadata, Tag(MetadataColumnType.BOOLEAN)],
                Annotated[DatetimeMetadata, Tag(MetadataColumnType.DATETIME)],
            ],
            Discriminator(get_column_metadata_discriminator),
        ],
    ]