"""Source code for lomas_server.dp_queries.dp_libraries.smartnoise_synth."""

import re
from datetime import datetime
from typing import Dict, List, Optional, TypeAlias, TypeGuard, Union

import pandas as pd
from smartnoise_synth_logger import deserialise_constraints
from snsynth import Synthesizer
from snsynth.transform import (
    AnonymizationTransformer,
    BinTransformer,
    ChainTransformer,
    LabelTransformer,
    MinMaxTransformer,
    OneHotEncoder,
)
from snsynth.transform.datetime import DateTimeTransformer
from snsynth.transform.table import TableTransformer

from lomas_server.admin_database.admin_database import AdminDatabase
from lomas_server.constants import (
    SECONDS_IN_A_DAY,
    SSYNTH_DEFAULT_BINS,
    SSYNTH_MIN_ROWS_PATE_GAN,
    SSYNTH_PRIVATE_COLUMN,
    DPLibraries,
    SSynthGanSynthesizer,
    SSynthMarginalSynthesizer,
    SSynthTableTransStyle,
)
from lomas_server.data_connector.data_connector import DataConnector
from lomas_server.dp_queries.dp_libraries.utils import serialise_model
from lomas_server.dp_queries.dp_querier import DPQuerier
from lomas_server.utils.collection_models import (
    BooleanMetadata,
    ColumnMetadata,
    DatetimeMetadata,
    FloatMetadata,
    IntCategoricalMetadata,
    IntMetadata,
    Metadata,
    StrCategoricalMetadata,
    StrMetadata,
)
from lomas_server.utils.error_handler import (
    ExternalLibraryException,
    InternalServerException,
    InvalidQueryException,
)
from lomas_server.utils.query_models import (
    SmartnoiseSynthQueryModel,
    SmartnoiseSynthRequestModel,
)


def datetime_to_float(upper: datetime, lower: datetime) -> float:
    """Convert the distance between the upper date and the lower date
    to a float number of days.

    Args:
        upper (datetime): date to convert
        lower (datetime): start date to convert from

    Returns:
        float: number of days between upper and lower
    """
    distance = upper - lower
    return float(distance.total_seconds() / SECONDS_IN_A_DAY)

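# Illustrative arithmetic (assuming SECONDS_IN_A_DAY == 86400):
#
#   >>> datetime_to_float(datetime(2024, 1, 11), datetime(2024, 1, 1))
#   10.0
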
# TODO maybe a better place to put this? See issue #336
SSynthColumnType: TypeAlias = Union[
    StrMetadata,
    StrCategoricalMetadata,
    BooleanMetadata,
    IntCategoricalMetadata,
    IntMetadata,
    FloatMetadata,
    DatetimeMetadata,
]

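# Note: isinstance() accepts a typing.Union at runtime on Python 3.10+, which
# is what the check in _get_and_check_valid_column_types below relies on;
# mypy still rejects it, hence the `type: ignore` at that call site.
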
class SmartnoiseSynthQuerier(
    DPQuerier[SmartnoiseSynthRequestModel, SmartnoiseSynthQueryModel]
):
    """
    Concrete implementation of the DPQuerier ABC for the
    SmartNoiseSynth library.
    """

    def __init__(
        self,
        data_connector: DataConnector,
        admin_database: AdminDatabase,
    ) -> None:
        super().__init__(data_connector, admin_database)
        self.model: Optional[Synthesizer] = None

    def _is_categorical(
        self, col_metadata: ColumnMetadata
    ) -> TypeGuard[
        StrMetadata
        | StrCategoricalMetadata
        | BooleanMetadata
        | IntCategoricalMetadata
    ]:
        """Checks if the column type is categorical.

        Args:
            col_metadata (ColumnMetadata): The column metadata

        Returns:
            TypeGuard[StrMetadata | StrCategoricalMetadata |
                BooleanMetadata | IntCategoricalMetadata]:
                TypeGuard for categorical column metadata.
        """
        return isinstance(
            col_metadata,
            (
                StrMetadata,
                StrCategoricalMetadata,
                BooleanMetadata,
                IntCategoricalMetadata,
            ),
        )

    def _is_continuous(
        self, col_metadata: ColumnMetadata
    ) -> TypeGuard[IntMetadata | FloatMetadata]:
        """Checks if the column type is continuous.

        Args:
            col_metadata (ColumnMetadata): The column metadata

        Returns:
            TypeGuard[IntMetadata | FloatMetadata]:
                TypeGuard for continuous column metadata.
        """
        return isinstance(col_metadata, (IntMetadata, FloatMetadata))

    def _is_datetime(
        self, col_metadata: ColumnMetadata
    ) -> TypeGuard[DatetimeMetadata]:
        """Checks if the column type is datetime.

        Args:
            col_metadata (ColumnMetadata): The column metadata

        Returns:
            TypeGuard[DatetimeMetadata]: TypeGuard for datetime metadata.
        """
        return isinstance(col_metadata, DatetimeMetadata)

    def _get_and_check_valid_column_types(
        self, metadata: Metadata, select_cols: List[str]
    ) -> Dict[str, SSynthColumnType]:
        """
        Ensures the types of the selected columns can be handled by
        SmartnoiseSynth and returns the dict of column metadata for
        the selected columns.

        Args:
            metadata (Metadata): Dataset metadata
            select_cols (List[str]): List of selected columns

        Raises:
            InternalServerException: If one of the column types cannot
                be handled by SmartnoiseSynth.

        Returns:
            Dict[str, SSynthColumnType]: The filtered dict of selected
                columns.
        """
        columns: Dict[str, SSynthColumnType] = {}
        for col_name, data in metadata.columns.items():
            if select_cols and col_name not in select_cols:
                continue
            if not isinstance(data, SSynthColumnType):  # type: ignore[misc, arg-type]
                raise InternalServerException(
                    f"Column type {data.type} not supported for SmartnoiseSynth"
                )
            columns[col_name] = data
        return columns

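    # Why the TypeGuard helpers above are useful (illustrative sketch, not
    # from the original source): in a branch such as
    #
    #   if self._is_continuous(col_metadata):
    #       col_metadata.lower  # narrowed to IntMetadata | FloatMetadata
    #
    # mypy narrows col_metadata in the True branch, so bound attributes like
    # .lower and .upper can be accessed without casts.
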
    def _get_default_constraints(
        self,
        metadata: Metadata,
        query_json: SmartnoiseSynthRequestModel,
        table_transformer_style: str,
    ) -> Dict[str, object]:
        """
        Get the default table transformer constraints based on the metadata.

        See https://docs.smartnoise.org/synth/transforms/index.html for
        documentation and
        https://github.com/opendp/smartnoise-sdk/blob/main/synth/snsynth/transform/type_map.py#L40
        for the get_transformer() method taken as basis.

        Args:
            metadata (Metadata): Metadata of the dataset
            query_json (SmartnoiseSynthRequestModel): JSON request object
                for the query, providing select_cols (columns to select)
                and nullable (True if the data can have null values,
                False otherwise).
            table_transformer_style (str): 'gan' or 'cube'

        Returns:
            Dict[str, object]: Mapping from column name to its default
                transformer, passed as constraints to the TableTransformer
                that pre- and post-processes the data.
        """
        columns = self._get_and_check_valid_column_types(
            metadata, query_json.select_cols
        )

        constraints: Dict[str, object] = {}
        nullable = query_json.nullable
        for col, col_metadata in columns.items():
            if col_metadata.private_id:
                constraints[col] = AnonymizationTransformer(
                    SSYNTH_PRIVATE_COLUMN
                )

            if table_transformer_style == SSynthTableTransStyle.GAN:  # gan
                if self._is_categorical(col_metadata):
                    # TODO any way of specifying cardinality? See issue #337
                    constraints[col] = ChainTransformer(
                        [LabelTransformer(nullable=nullable), OneHotEncoder()]
                    )
                elif self._is_continuous(col_metadata):
                    constraints[col] = MinMaxTransformer(
                        lower=col_metadata.lower,
                        upper=col_metadata.upper,
                        nullable=nullable,
                    )
                elif self._is_datetime(col_metadata):
                    constraints[col] = ChainTransformer(
                        [
                            DateTimeTransformer(epoch=col_metadata.lower),
                            MinMaxTransformer(
                                lower=0.0,  # epoch starts at the lower bound
                                upper=datetime_to_float(
                                    col_metadata.upper,
                                    col_metadata.lower,
                                ),
                                nullable=nullable,
                            ),
                        ]
                    )
            else:  # cube
                if self._is_categorical(col_metadata):
                    # TODO any way of specifying cardinality? See issue #337
                    constraints[col] = LabelTransformer(nullable=nullable)
                elif self._is_continuous(col_metadata):
                    constraints[col] = BinTransformer(
                        lower=col_metadata.lower,
                        upper=col_metadata.upper,
                        bins=SSYNTH_DEFAULT_BINS,
                        nullable=nullable,
                    )
                elif self._is_datetime(col_metadata):
                    constraints[col] = ChainTransformer(
                        [
                            DateTimeTransformer(epoch=col_metadata.lower),
                            BinTransformer(
                                lower=0.0,  # epoch starts at the lower bound
                                upper=datetime_to_float(
                                    col_metadata.upper,
                                    col_metadata.lower,
                                ),
                                bins=SSYNTH_DEFAULT_BINS,
                                nullable=nullable,
                            ),
                        ]
                    )

        return constraints

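    # Illustrative defaults (hypothetical schema, not from the original
    # source): for a categorical column "job" and a float column "income"
    # with bounds [0.0, 1e5], the method above would produce roughly:
    #
    #   'gan':  {"job": ChainTransformer([LabelTransformer(), OneHotEncoder()]),
    #            "income": MinMaxTransformer(lower=0.0, upper=1e5)}
    #   'cube': {"job": LabelTransformer(),
    #            "income": BinTransformer(lower=0.0, upper=1e5,
    #                                     bins=SSYNTH_DEFAULT_BINS)}
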
    def _get_fit_model(
        self,
        private_data: pd.DataFrame,
        transformer: TableTransformer,
        query_json: SmartnoiseSynthRequestModel,
    ) -> Synthesizer:
        """
        Create and fit the synthesizer model.

        Args:
            private_data (pd.DataFrame): Private data for fitting the model
            transformer (TableTransformer): Transformer to pre/postprocess
                data
            query_json (SmartnoiseSynthRequestModel): JSON request object
                for the query, providing synth_name (name of the synthesizer
                model to use), epsilon (epsilon budget value), nullable
                (True if some data cells may be null) and synth_params
                (keyword arguments to pass to the synthesizer constructor).

        Returns:
            Synthesizer: Fitted synthesizer model
        """
        if query_json.delta is not None:
            query_json.synth_params["delta"] = query_json.delta

        if query_json.synth_name == SSynthGanSynthesizer.DP_CTGAN:
            query_json.synth_params["disabled_dp"] = False

        try:
            model = Synthesizer.create(
                synth=query_json.synth_name,
                epsilon=query_json.epsilon,
                **query_json.synth_params,
            )
        except Exception as e:
            raise ExternalLibraryException(
                DPLibraries.SMARTNOISE_SYNTH, "Error creating model: " + str(e)
            ) from e

        try:
            model.fit(
                data=private_data,
                transformer=transformer,
                preprocessor_eps=0.0,  # will error if not 0.0
                nullable=query_json.nullable,
            )
        except ValueError as e:
            # Improve snsynth error messages
            pattern = (
                r"sample_rate=[\d\.]+ is not a valid value\. "
                r"Please provide a float between 0 and 1\."
            )
            if (
                query_json.synth_name == SSynthGanSynthesizer.DP_CTGAN
                and re.match(pattern, str(e))
            ):
                raise ExternalLibraryException(
                    DPLibraries.SMARTNOISE_SYNTH,
                    f"Error fitting model: {e} Try decreasing batch_size in "
                    + "synth_params (default batch_size=500).",
                ) from e
            raise ExternalLibraryException(
                DPLibraries.SMARTNOISE_SYNTH, "Error fitting model: " + str(e)
            ) from e
        except Exception as e:
            raise ExternalLibraryException(
                DPLibraries.SMARTNOISE_SYNTH, "Error fitting model: " + str(e)
            ) from e

        return model

    def _model_pipeline(
        self, query_json: SmartnoiseSynthRequestModel
    ) -> Synthesizer:
        """Return a trained Synthesizer model based on query_json.

        Args:
            query_json (SmartnoiseSynthRequestModel): JSON request object
                for the query.

        Returns:
            model: SmartNoise Synthesizer
        """
        if (
            query_json.synth_name == SSynthMarginalSynthesizer.MST
            and query_json.return_model
        ):
            raise InvalidQueryException(
                "mst synthesizer cannot be returned, only samples. "
                + "Please, change model or set `return_model=False`"
            )
        if query_json.synth_name == SSynthMarginalSynthesizer.PAC_SYNTH:
            raise InvalidQueryException(
                "pacsynth synthesizer not supported due to Rust panic. "
                + "Please select another Synthesizer."
            )

        # Table transformation depends on the type of synthesizer
        if query_json.synth_name in [
            s.value for s in SSynthMarginalSynthesizer
        ]:
            table_transformer_style = SSynthTableTransStyle.CUBE
        else:
            table_transformer_style = SSynthTableTransStyle.GAN

        # Preprocessing information from metadata
        metadata = self.data_connector.get_metadata()

        if query_json.synth_name == SSynthGanSynthesizer.PATE_GAN:
            if metadata.rows < SSYNTH_MIN_ROWS_PATE_GAN:
                raise ExternalLibraryException(
                    DPLibraries.SMARTNOISE_SYNTH,
                    f"{SSynthGanSynthesizer.PATE_GAN} not reliable "
                    + "with this dataset.",
                )

        constraints = self._get_default_constraints(
            metadata, query_json, table_transformer_style
        )

        # Overwrite default constraints with custom constraints (if any)
        constraints_json = query_json.constraints
        if constraints_json:
            custom_constraints = deserialise_constraints(constraints_json)
            custom_constraints = {
                key: custom_constraints[key]
                for key in query_json.select_cols
                if key in custom_constraints
            }
            constraints.update(custom_constraints)

        # Prepare private data
        private_data = self.data_connector.get_pandas_df()
        if query_json.select_cols:
            try:
                private_data = private_data[query_json.select_cols]
            except KeyError as e:
                raise InvalidQueryException(
                    "Error while selecting provided select_cols: " + str(e)
                ) from e

        # Get transformer
        transformer = TableTransformer.create(
            data=private_data,
            style=table_transformer_style,
            nullable=query_json.nullable,
            constraints=constraints,
        )

        # Create and fit synthesizer
        model = self._get_fit_model(private_data, transformer, query_json)

        return model

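    # Custom constraints sketch (an assumption: serialise_constraints is the
    # client-side counterpart of the deserialise_constraints import above in
    # smartnoise_synth_logger; shown for illustration only):
    #
    #   from smartnoise_synth_logger import serialise_constraints
    #   request.constraints = serialise_constraints(
    #       {"income": MinMaxTransformer(lower=0.0, upper=1e5)}
    #   )
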
    def cost(
        self, query_json: SmartnoiseSynthRequestModel
    ) -> tuple[float, float]:
        """Return the cost of query_json.

        Trains the synthesizer as a side effect and caches it on
        self.model for the subsequent query() call.

        Args:
            query_json (SmartnoiseSynthRequestModel): JSON request object
                for the query.

        Returns:
            tuple[float, float]: The tuple of costs: the first value is
                the epsilon cost, the second value is the delta cost.
        """
        # TODO: verify and model.rho
        self.model = self._model_pipeline(query_json)
        if query_json.synth_name == SSynthMarginalSynthesizer.MWEM:
            # MWEM is pure epsilon-DP, hence delta is 0
            epsilon, delta = self.model.epsilon, 0
        elif query_json.synth_name == SSynthGanSynthesizer.DP_CTGAN:
            epsilon, delta = self.model.epsilon_list[-1], self.model.delta
        else:
            epsilon, delta = self.model.epsilon, self.model.delta
        return epsilon, delta

    def query(
        self,
        query_json: SmartnoiseSynthQueryModel,
    ) -> Union[List[dict], str]:
        """Perform the query and return the response.

        Args:
            query_json (SmartnoiseSynthQueryModel): The request object for
                the query.

        Raises:
            InternalServerException: If query() is called before cost()
                has fitted the model.
            ExternalLibraryException: For exceptions from libraries
                external to this package.
            InvalidQueryException: If the budget values are too small to
                perform the query.

        Returns:
            Union[List[dict], str]: The resulting samples as a list of
                records, or the serialised model if return_model is True.
        """
        if self.model is None:
            raise InternalServerException(
                "Smartnoise Synth `query` method called before `cost` method"
            )

        if not query_json.return_model:
            # Sample
            df_samples = (
                self.model.sample_conditional(
                    query_json.nb_samples, query_json.condition
                )
                if query_json.condition
                else self.model.sample(query_json.nb_samples)
            )

            # Ensure serialisable
            df_samples = df_samples.fillna("")
            return df_samples.to_dict(orient="records")

        return serialise_model(self.model)
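
# Illustrative call sequence (a sketch; in lomas the querier is driven by the
# server's request handlers rather than called directly, and `request`/`query`
# stand in for validated SmartnoiseSynth request models):
#
#   querier = SmartnoiseSynthQuerier(data_connector, admin_database)
#   epsilon, delta = querier.cost(request)  # fits the synthesizer as a side effect
#   result = querier.query(query)           # samples (list of records) or model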