import base64
import copy
import json
import pickle

import pandas as pd
from fastapi import status
from fastapi.testclient import TestClient
from smartnoise_synth_logger import serialise_constraints
from snsynth.transform import (
    ChainTransformer,
    LabelTransformer,
    MinMaxTransformer,
    OneHotEncoder,
)

from lomas_server.app import app
from lomas_server.tests.constants import PENGUIN_COLUMNS, PUMS_COLUMNS
from lomas_server.tests.test_api import TestRootAPIEndpoint
from lomas_server.utils.query_examples import (
    example_dummy_smartnoise_synth_query,
    example_smartnoise_synth_cost,
    example_smartnoise_synth_query,
)

def get_model(query_response: str):
    """Unpickle a synthesizer model from an API response.

    Args:
        query_response: the ``query_response`` field of a smartnoise synth
            response body, holding a base64-encoded pickle of the model.

    Returns:
        The unpickled model object.
    """
    # NOTE: pickle.loads can execute arbitrary code. It is only acceptable
    # here because the payload comes from the app under test (TestClient),
    # never from an untrusted source.
    return pickle.loads(base64.b64decode(query_response))


class TestSmartnoiseSynthEndpoint(
    TestRootAPIEndpoint
):  # pylint: disable=R0904
    """
    Test Smartnoise Synth Endpoints with different Synthesizers.

    Request bodies are built with ``copy.deepcopy`` of the shared
    ``example_*`` query dicts. Some tests mutate nested values (e.g.
    ``synth_params``), and a shallow ``dict(...)`` copy would leak those
    mutations into the module-level examples and hence into other tests.
    """

    def test_smartnoise_synth_query(self) -> None:
        """Test smartnoise synth query returning a pickled model"""
        with TestClient(app, headers=self.headers) as client:
            # Expect to work
            response = client.post(
                "/smartnoise_synth_query",
                json=example_smartnoise_synth_query,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK

            response_dict = json.loads(response.content.decode("utf8"))
            assert response_dict["requested_by"] == self.user_name
            assert response_dict["spent_epsilon"] >= 0.1
            assert response_dict["spent_delta"] >= 1e-05

            model = get_model(response_dict["query_response"])
            assert model.__class__.__name__ == "DPCTGAN"
            df = model.sample(10)
            assert list(df.columns) == PENGUIN_COLUMNS

            # Expect to fail due to parameters: without synth_params the
            # server-side default batch_size yields an invalid sample_rate
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
            assert response.json() == {
                "ExternalLibraryException": "Error fitting model: "
                + "sample_rate=1.4534883720930232 is not a valid value. "
                + "Please provide a float between 0 and 1. "
                + "Try decreasing batch_size in "
                + "synth_params (default batch_size=500).",
                "library": "smartnoise_synth",
            }

    def test_smartnoise_synth_query_samples(self) -> None:
        """Test smartnoise synth query returning samples instead of a model"""
        with TestClient(app, headers=self.headers) as client:
            nb_samples = 100
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["return_model"] = False
            body["nb_samples"] = nb_samples

            # Expect to work - no condition
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            assert response_dict["requested_by"] == self.user_name
            df_0 = pd.DataFrame(response_dict["query_response"])
            assert df_0.shape[0] == nb_samples
            assert list(df_0.columns) == PENGUIN_COLUMNS

            # Expect to work - condition
            body["condition"] = "bill_length_mm < 40"
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            assert response_dict["requested_by"] == self.user_name
            df_1 = pd.DataFrame(response_dict["query_response"])
            assert df_1.shape[0] == nb_samples
            assert list(df_1.columns) == PENGUIN_COLUMNS

            # The condition should pull the mean bill length down
            assert (
                df_0["bill_length_mm"].mean() > df_1["bill_length_mm"].mean()
            )

    def test_smartnoise_synth_query_select_cols(self) -> None:
        """Test smartnoise synth query select_cols"""
        with TestClient(app, headers=self.headers) as client:
            # Expect to work
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["select_cols"] = ["species", "island"]
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == ["species", "island"]

            # Expect to fail: unknown column
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["select_cols"] = ["species", "idonotexist"]
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_400_BAD_REQUEST
            assert response.json()["InvalidQueryException"].startswith(
                "Error while selecting provided select_cols: "
            )

    def test_smartnoise_synth_query_constraints(self) -> None:
        """Test smartnoise synth query constraints"""
        with TestClient(app, headers=self.headers) as client:
            # One transformer per penguin column, serialised for the request
            constraints = {
                "species": ChainTransformer(
                    [LabelTransformer(nullable=True), OneHotEncoder()]
                ),
                "island": ChainTransformer(
                    [LabelTransformer(nullable=True), OneHotEncoder()]
                ),
                "bill_length_mm": MinMaxTransformer(
                    lower=30.0, upper=65.0, nullable=True
                ),
                "bill_depth_mm": MinMaxTransformer(
                    lower=13.0, upper=23.0, nullable=True
                ),
                "flipper_length_mm": MinMaxTransformer(
                    lower=150.0, upper=250.0, nullable=True
                ),
                "body_mass_g": MinMaxTransformer(
                    lower=2000.0, upper=7000.0, nullable=True
                ),
                "sex": ChainTransformer(
                    [LabelTransformer(nullable=True), OneHotEncoder()]
                ),
            }
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["constraints"] = serialise_constraints(constraints)

            # Expect to work
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == PENGUIN_COLUMNS

    def test_smartnoise_synth_query_private_id(self) -> None:
        """Test smartnoise synth query on other dataset for private id
        and categorical int columns
        """
        with TestClient(app, headers=self.headers) as client:
            # Expect to work
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["dataset_name"] = "PUMS"
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == PUMS_COLUMNS

    def test_smartnoise_synth_query_delta_none(self) -> None:
        """Test dummy smartnoise synth query on PUMS with delta None"""
        with TestClient(app, headers=self.headers) as client:
            # Expect to work
            body = copy.deepcopy(example_dummy_smartnoise_synth_query)
            body["dataset_name"] = "PUMS"
            body["delta"] = None
            body["synth_params"] = {"batch_size": 2, "epochs": 5}
            response = client.post(
                "/dummy_smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == PUMS_COLUMNS

    def test_dummy_smartnoise_synth_query(self) -> None:
        """Test dummy smartnoise synth query and dataset access control"""
        with TestClient(app) as client:
            # Expect to work
            response = client.post(
                "/dummy_smartnoise_synth_query",
                json=example_dummy_smartnoise_synth_query,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            assert model.__class__.__name__ == "DPCTGAN"

            # Expect to fail: user does not have access to dataset
            body = copy.deepcopy(example_dummy_smartnoise_synth_query)
            body["dataset_name"] = "IRIS"
            response = client.post(
                "/dummy_smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_403_FORBIDDEN
            assert response.json() == {
                "UnauthorizedAccessException": ""
                + f"{self.user_name} does not have access to IRIS."
            }

    def test_smartnoise_synth_cost(self) -> None:
        """Test smartnoise synth cost estimation and dataset access control"""
        with TestClient(app) as client:
            # Expect to work
            response = client.post(
                "/estimate_smartnoise_synth_cost",
                json=example_smartnoise_synth_cost,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            assert response_dict["epsilon_cost"] >= 0.1
            assert response_dict["delta_cost"] >= 1e-5

            # Expect to fail: user does not have access to dataset
            body = copy.deepcopy(example_smartnoise_synth_cost)
            body["dataset_name"] = "IRIS"
            response = client.post(
                "/estimate_smartnoise_synth_cost",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_403_FORBIDDEN
            assert response.json() == {
                "UnauthorizedAccessException": ""
                + f"{self.user_name} does not have access to IRIS."
            }

    def test_smartnoise_synth_query_datetime(self) -> None:
        """Test smartnoise synth query on other dataset for datetime columns"""
        with TestClient(app) as client:
            # Copy the headers so switching user does not mutate self.headers
            new_headers = dict(self.headers)
            new_headers["user-name"] = "BirthdayGirl"

            # Expect to work
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["dataset_name"] = "BIRTHDAYS"
            # Nested mutation: safe only because body is a deep copy
            body["synth_params"]["batch_size"] = 2  # type: ignore

            # With gan synthesizer
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=new_headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == ["birthday"]

            # With marginal synthesizer
            body["synth_name"] = "mwem"
            body["delta"] = None
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=new_headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == ["birthday"]

    def test_smartnoise_synth_query_aim(self) -> None:
        """Test smartnoise synth query AIM Synthesizer"""
        with TestClient(app) as client:
            # Expect to work
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_name"] = "aim"
            body["select_cols"] = [
                "bill_depth_mm",
                "species",
            ]  # too slow otherwise
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == body["select_cols"]

    def test_smartnoise_synth_query_mwem(self) -> None:
        """Test smartnoise synth query MWEM Synthesizer"""
        with TestClient(app) as client:
            # Expect to fail: MWEM does not accept a delta parameter
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_name"] = "mwem"
            body["synth_params"] = {}
            body["select_cols"] = ["species", "island"]
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
            assert response.json() == {
                "ExternalLibraryException": "Error creating model: "
                + "MWEMSynthesizer.__init__() got an "
                + "unexpected keyword argument 'delta'",
                "library": "smartnoise_synth",
            }

            # Expect to work: limited columns and delta None
            body["delta"] = None
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == ["species", "island"]

            # Expect to work: special parameters
            body["synth_params"] = {"split_factor": 2, "measure_only": False}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == ["species", "island"]

    def test_smartnoise_synth_query_mst(self) -> None:
        """Test smartnoise synth query MST Synthesizer"""
        with TestClient(app) as client:
            # Expect to work: samples only
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_name"] = "mst"
            body["return_model"] = False
            body["nb_samples"] = 10
            body["select_cols"] = ["bill_length_mm"]  # too slow otherwise
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            df = pd.DataFrame(response_dict["query_response"])
            assert df.shape[0] == body["nb_samples"]
            assert list(df.columns) == body["select_cols"]

            # Expect to fail: MST cannot return model
            body["return_model"] = True
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_400_BAD_REQUEST
            assert response.json()["InvalidQueryException"].startswith(
                "mst synthesizer cannot be returned, only samples. "
                + "Please, change model or set `return_model=False`"
            )

    def test_smartnoise_synth_query_pacsynth(self) -> None:
        """Test smartnoise synth query PAC-Synth Synthesizer.

        The server rejects pacsynth outright (it is too unstable because of
        a Rust panic), so only the rejection is checked.
        """
        with TestClient(app) as client:
            # Expect to fail: pacsynth is rejected with an InvalidQuery error
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_name"] = "pacsynth"
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_400_BAD_REQUEST
            assert response.json()["InvalidQueryException"].startswith(
                "pacsynth synthesizer not supported due to Rust panic. "
                + "Please select another Synthesizer."
            )

    def test_smartnoise_synth_query_patectgan(self) -> None:
        """Test smartnoise synth query PATE-CTGAN Synthesizer"""
        with TestClient(app) as client:
            # Expect to fail: epsilon too small
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_name"] = "patectgan"
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
            assert response.json() == {
                "ExternalLibraryException": "Error fitting model: "
                + "Inputted epsilon parameter is too small to create a private"
                + " dataset. Try increasing epsilon and rerunning.",
                "library": "smartnoise_synth",
            }

            # Expect to work with a larger epsilon
            body["epsilon"] = 1.0
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == PENGUIN_COLUMNS

    def test_smartnoise_synth_query_pategan(self) -> None:
        """Test smartnoise synth query pategan Synthesizer"""
        with TestClient(app) as client:
            # Expect to fail: penguin dataset is too small
            # (pategan needs > 1000 rows)
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_name"] = "pategan"
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
            assert response.json() == {
                "ExternalLibraryException": "pategan not reliable with this dataset.",
                "library": "smartnoise_synth",
            }

    def test_smartnoise_synth_query_dpgan(self) -> None:
        """Test smartnoise synth query dpgan Synthesizer"""
        with TestClient(app) as client:
            # Expect to fail: epsilon too small
            body = copy.deepcopy(example_smartnoise_synth_query)
            body["synth_name"] = "dpgan"
            body["synth_params"] = {}
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
            assert response.json() == {
                "ExternalLibraryException": "Error fitting model: "
                + "Inputted epsilon and sigma parameters "
                + "are too small to create a private dataset. "
                + "Try increasing either parameter and rerunning.",
                "library": "smartnoise_synth",
            }

            # Expect to work with a larger epsilon
            body["epsilon"] = 1.0
            response = client.post(
                "/smartnoise_synth_query",
                json=body,
                headers=self.headers,
            )
            assert response.status_code == status.HTTP_200_OK
            response_dict = json.loads(response.content.decode("utf8"))
            model = get_model(response_dict["query_response"])
            df = model.sample(1)
            assert list(df.columns) == PENGUIN_COLUMNS