Source code for lomas_client.utils

import warnings
from typing import Any, NoReturn

import requests
from fastapi import status
from lomas_core.constants import SSynthGanSynthesizer, SSynthMarginalSynthesizer
from lomas_core.error_handler import (
    ExternalLibraryException,
    InternalServerException,
    InvalidQueryException,
    UnauthorizedAccessException,
)


def raise_error(response: requests.Response) -> NoReturn:
    """Raise the lomas exception that corresponds to an HTTP error response.

    The error body is expected to be a JSON object keyed by the exception
    name. HTTP 422 bodies additionally carry a ``"library"`` key.

    In the following cases an ``InternalServerException`` is raised instead,
    carrying the status code and the raw body:

    - the status code is not one of the handled ones;
    - the body is not valid JSON;
    - the body does not contain the expected key(s).

    Args:
        response (requests.Response): The response object from an HTTP request.

    Raises:
        InvalidQueryException: For HTTP 400 responses.
        ExternalLibraryException: For HTTP 422 responses.
        UnauthorizedAccessException: For HTTP 403 responses.
        InternalServerException: For HTTP 500 responses and for any response
            that does not match the cases above.
    """
    status_code = response.status_code
    try:
        error_message = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page from a proxy): surface the
        # raw text instead of leaking a JSON decoding error to the caller.
        raise InternalServerException(
            f"Unknown error (HTTP {status_code}): {response.text}"
        ) from None

    # Only a JSON object can carry the expected keys; treating anything else
    # as empty keeps the lookups below from raising TypeError/KeyError.
    body = error_message if isinstance(error_message, dict) else {}

    if (
        status_code == status.HTTP_400_BAD_REQUEST
        and "InvalidQueryException" in body
    ):
        raise InvalidQueryException(body["InvalidQueryException"])
    if (
        status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
        and "library" in body
        and "ExternalLibraryException" in body
    ):
        raise ExternalLibraryException(
            body["library"], body["ExternalLibraryException"]
        )
    if (
        status_code == status.HTTP_403_FORBIDDEN
        and "UnauthorizedAccessException" in body
    ):
        raise UnauthorizedAccessException(body["UnauthorizedAccessException"])
    if (
        status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
        and "InternalServerException" in body
    ):
        raise InternalServerException(body["InternalServerException"])

    # Report the actual status code and payload so unexpected responses can
    # be diagnosed (rather than formatting the exception class itself).
    raise InternalServerException(
        f"Unknown error (HTTP {status_code}): {error_message}"
    )


def validate_synthesizer(synth_name: str, return_model: bool = False) -> None:
    """Validate a smartnoise synthesizer choice (some models are not accepted).

    Emits a warning for the GAN-based synthesizers (DP_CTGAN, DP_GAN). Raises
    for MST when a model is requested and for PAC_SYNTH in all cases.

    Args:
        synth_name (str): name of the Synthesizer model to use.
        return_model (bool): True to get Synthesizer model, False to get samples.

    Raises:
        ValueError: if a synthesizer or its parameters are not valid.
    """
    if synth_name in (SSynthGanSynthesizer.DP_CTGAN, SSynthGanSynthesizer.DP_GAN):
        # stacklevel=2 attributes the warning to the line that called this
        # helper rather than to this module, which makes its origin traceable.
        warnings.warn(
            f"Warning:{synth_name} synthesizer random generator for noise and "
            "shuffling is not cryptographically secure. "
            "(pseudo-rng in vanilla PyTorch).",
            stacklevel=2,
        )
    if synth_name == SSynthMarginalSynthesizer.MST and return_model:
        raise ValueError(
            f"{synth_name} synthesizer cannot be returned, only samples. "
            "Please, change synthesizer or set `return_model=False`."
        )
    if synth_name == SSynthMarginalSynthesizer.PAC_SYNTH:
        raise ValueError(
            f"{synth_name} synthesizer not supported. "
            "Please choose another synthesizer."
        )


def validate_model_response(response: requests.Response, response_model: Any) -> Any:
    """Parse a successful HTTP response into ``response_model``.

    Args:
        response (requests.Response): The response object from an HTTP request.
        response_model (Any): Model class exposing ``model_validate_json``
            (presumably a pydantic model — confirm against callers), used to
            parse the response body.

    Returns:
        For an HTTP 200 response, the object produced by
        ``response_model.model_validate_json`` from the UTF-8 decoded body.
        Any other status code is handed to ``raise_error``.
    """
    if response.status_code != status.HTTP_200_OK:
        # Non-200: raise_error turns the error body into an exception; the
        # explicit return is only a fallback should it ever not raise.
        raise_error(response)
        return None

    body_text = response.content.decode("utf8")
    return response_model.model_validate_json(body_text)