import operator as op
import shelve
import sys
from pathlib import Path
from typing import Any, Self
import boto3
import yaml
from pydantic import HttpUrl
from lomas_core.error_handler import InternalServerException
from lomas_core.models.collections import (
DatasetOfUser,
DatasetsCollection,
DSInfo,
DSPathAccess,
DSS3Access,
Metadata,
User,
UserCollection,
UserId,
)
from lomas_core.models.constants import PrivateDatabaseType, init_logging
from lomas_core.models.requests import LomasRequestModel
from lomas_core.models.responses import QueryResponse
from lomas_server.admin_database.admin_database import (
AdminDatabase,
dataset_must_exist,
user_must_exist,
user_must_have_access_to_dataset,
)
from lomas_server.admin_database.constants import BudgetDBKey, TopDBKey as TK
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = init_logging(__name__)
class LocalAdminDatabase(AdminDatabase):
"""Local Admin database in a single file."""
path: Path
"""Database accepts existing path or new (creatable) path."""
def model_post_init(self, _: Any, /) -> None:
        # create the file if it doesn't exist yet (makes open with flag='r' safe)
shelve.open(self.path).close()
@override
def wipe(self) -> None:
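        """Delete the database file from disk, if it exists."""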
if (p := Path(self.path)).exists():
p.unlink()
def load_users_collection(self, users: list[User]) -> None:
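        """Load the given users into the user collection, replacing any existing collection."""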
with shelve.open(self.path, writeback=True) as db:
db[TK.USERS] = {user.id.name: user.model_dump() for user in users}
def users(self) -> list[User]:
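        """Return all users stored in the user collection."""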
with shelve.open(self.path, flag="r") as db:
return list(map(User.model_validate, db.get(TK.USERS, {}).values()))
def load_dataset_collection(self, datasets: list[DSInfo], path_prefix: str) -> None:
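        """Load datasets and their metadata into the database.

        Local dataset and metadata paths are prefixed with path_prefix; metadata is
        then read from the local file or from S3 and stored per dataset name.
        """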
with shelve.open(self.path, writeback=True) as db:
# Step 1: add datasets
new_datasets = []
for ds in datasets:
# Overwrite path
if isinstance(ds.dataset_access, DSPathAccess):
match ds.dataset_access.path:
case HttpUrl():
pass
case Path():
ds.dataset_access.path = Path(path_prefix) / ds.dataset_access.path
if isinstance(ds.metadata_access, DSPathAccess):
match ds.metadata_access.path:
case HttpUrl():
pass
case Path():
ds.metadata_access.path = Path(path_prefix) / ds.metadata_access.path
# Fill datasets_list
new_datasets.append(ds)
# Add dataset collection
if new_datasets:
new_datasets_dicts = [ds.model_dump() for ds in new_datasets]
db[TK.DATASETS] = new_datasets_dicts
db[TK.METADATA] = {}
# Step 2: add metadata collections (one metadata per dataset)
for ds in datasets:
dataset_name = ds.dataset_name
metadata_access = ds.metadata_access
match metadata_access:
case DSPathAccess():
                        with metadata_access.path.open(encoding="utf-8") as f:
                            metadata_dict = yaml.safe_load(f)
case DSS3Access():
client = boto3.client(
"s3",
endpoint_url=str(metadata_access.endpoint_url),
aws_access_key_id=metadata_access.access_key_id,
aws_secret_access_key=metadata_access.secret_access_key,
)
response = client.get_object(
Bucket=metadata_access.bucket,
Key=metadata_access.key,
)
try:
metadata_dict = yaml.safe_load(response["Body"])
except yaml.YAMLError as e:
                            raise e
case _:
raise InternalServerException(
f"Unknown metadata_db_type PrivateDatabaseType: {metadata_access.database_type}"
)
db[TK.METADATA][dataset_name] = metadata_dict
logger.debug(f"Added metadata of {dataset_name} dataset. ")
def datasets(self) -> list[DSInfo]:
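        """Return all datasets stored in the dataset collection."""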
with shelve.open(self.path, flag="r") as db:
return list(map(DSInfo.model_validate, db.get(TK.DATASETS, [])))
def add_datasets_via_yaml(
self,
yaml_file: Path,
clean: bool,
path_prefix: str = "",
) -> None:
"""Set all database types to datasets in dataset collection based.
on yaml file.
Args:
yaml_file Path: path to the YAML file location
clean (bool): Whether to clean the collection before adding.
path_prefix (str, optional): Prefix to add to all file paths. Defaults to "".
Raises:
ValueError: If there are errors in the YAML file format.
Returns:
None
"""
if clean:
self.drop_collection("datasets")
        with yaml_file.resolve().open(encoding="utf-8") as f:
            yaml_dict = yaml.safe_load(f)
self.load_dataset_collection(DatasetsCollection(**yaml_dict).datasets, path_prefix)
def add_dataset(
self,
dataset_name: str,
database_type: str,
metadata_database_type: str,
dataset_path: str | None = "",
metadata_path: str = "",
bucket: str | None = "",
key: str | None = "",
endpoint_url: str | None = "",
credentials_name: str | None = "",
metadata_bucket: str | None = "",
metadata_key: str | None = "",
metadata_endpoint_url: str | None = "",
metadata_access_key_id: str | None = "",
metadata_secret_access_key: str | None = "",
metadata_credentials_name: str | None = "",
) -> None:
"""Set a database type to a dataset in dataset collection.
Args:
dataset_name (str): Dataset name
database_type (str): Type of the database
metadata_database_type (str): Metadata database type
dataset_path (str): Path to the dataset (for local db type)
metadata_path (str): Path to metadata (for local db type)
bucket (str): S3 bucket name
key (str): S3 key
endpoint_url (str): S3 endpoint URL
credentials_name (str): The name of the credentials in the\
server config to retrieve the dataset from S3 storage.
metadata_bucket (str): Metadata S3 bucket name
metadata_key (str): Metadata S3 key
metadata_endpoint_url (str): Metadata S3 endpoint URL
metadata_access_key_id (str): Metadata AWS access key ID
metadata_secret_access_key (str): Metadata AWS secret access key
metadata_credentials_name (str): The name of the credentials in the\
server config for retrieving the metadata.
Raises:
ValueError: If the dataset already exists
or if the database type is unknown.
Returns:
None
"""
# Step 1: Build dataset
dataset: dict[str, Any] = {"dataset_name": dataset_name}
dataset_access: dict[str, Any] = {
"database_type": database_type,
}
if database_type == PrivateDatabaseType.PATH:
if dataset_path is None:
raise ValueError("Dataset path not set.")
dataset_access["path"] = dataset_path
elif database_type == PrivateDatabaseType.S3:
dataset_access["bucket"] = bucket
dataset_access["key"] = key
dataset_access["endpoint_url"] = endpoint_url
dataset_access["credentials_name"] = credentials_name
else:
raise ValueError(f"Unknown database type {database_type}")
dataset["dataset_access"] = dataset_access
# Step 2: Build metadata
metadata_access: dict[str, Any] = {"database_type": metadata_database_type}
if metadata_database_type == PrivateDatabaseType.PATH:
# Store metadata from yaml to metadata collection
with Path(metadata_path).resolve().open(encoding="utf-8") as f:
metadata_dict = yaml.safe_load(f)
metadata_access["path"] = metadata_path
elif metadata_database_type == PrivateDatabaseType.S3:
client = boto3.client(
"s3",
endpoint_url=metadata_endpoint_url,
aws_access_key_id=metadata_access_key_id,
aws_secret_access_key=metadata_secret_access_key,
)
response = client.get_object(Bucket=metadata_bucket, Key=metadata_key)
try:
metadata_dict = yaml.safe_load(response["Body"])
except yaml.YAMLError as e:
raise e
metadata_access["bucket"] = metadata_bucket
metadata_access["key"] = metadata_key
metadata_access["endpoint_url"] = metadata_endpoint_url
metadata_access["credentials_name"] = metadata_credentials_name
else:
raise ValueError(f"Unknown database type {metadata_database_type}")
dataset["metadata_access"] = metadata_access
# Step 3: Validate
ds_info = DSInfo.model_validate(dataset)
validated_dataset = ds_info.model_dump()
validated_metadata = Metadata.model_validate(metadata_dict).model_dump()
# Step 4: Insert into db
with shelve.open(self.path, writeback=True) as db:
db[TK.DATASETS] = [*db.get(TK.DATASETS, []), validated_dataset]
db[TK.METADATA] = db.get(TK.METADATA, {}) | {dataset_name: validated_metadata}
def del_dataset(self, dataset_name: str) -> None:
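        """Remove the dataset with the given name from the dataset collection."""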
with shelve.open(self.path, writeback=True) as db:
            # Rebuild the list rather than removing items while iterating over it
            db[TK.DATASETS] = [ds for ds in db[TK.DATASETS] if ds["dataset_name"] != dataset_name]
def add_dataset_to_user(self, username: str, dataset_name: str, epsilon: float, delta: float) -> None:
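        """Give the user access to the dataset with the given initial epsilon and delta budgets."""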
with shelve.open(self.path, writeback=True) as db:
user = User.model_validate(db[TK.USERS][username])
ds = DatasetOfUser(dataset_name=dataset_name, initial_epsilon=epsilon, initial_delta=delta)
user_updated = User(
id=user.id,
may_query=user.may_query,
datasets_list=[*user.datasets_list, ds],
)
db[TK.USERS][username] = user_updated.model_dump()
def del_dataset_to_user(self, username: str, dataset_name: str) -> None:
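        """Remove the dataset from the user's list of accessible datasets."""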
with shelve.open(self.path, writeback=True) as db:
user = User.model_validate(db[TK.USERS][username])
user_updated = User(
id=user.id,
may_query=user.may_query,
datasets_list=[dsu for dsu in user.datasets_list if dsu.dataset_name != dataset_name],
)
db[TK.USERS][username] = user_updated.model_dump()
def add_users_via_yaml(self, yaml_file: Path, clean: bool) -> None:
"""Add all users from yaml file to the user collection.
Args:
yaml_file (Path): a path to the YAML file location
clean (bool): boolean flag
True if drop current user collection
False if keep current user collection
Returns:
None
"""
if clean:
self.drop_collection("users")
# Load yaml data and insert it
        with yaml_file.resolve().open(encoding="utf-8") as f:
            yaml_dict = yaml.safe_load(f)
self.load_users_collection(UserCollection(**yaml_dict).users)
def add_user(
self,
username: str,
email: str,
dataset_name: str | None = None,
epsilon: float = 0.0,
delta: float = 0.0,
) -> None:
"""Add new user in users collection with default values for all fields.
Args:
username (str): username to be added
email (str): email to be added
Raises:
ValueError: If the username already exists.
WriteConcernError: If the result is not acknowledged.
Returns:
None
"""
validated_user = User(
id=UserId(name=username, email=email),
may_query=True,
datasets_list=(
[]
if dataset_name is None
else [DatasetOfUser(dataset_name=dataset_name, initial_epsilon=epsilon, initial_delta=delta)]
),
).model_dump()
with shelve.open(self.path, writeback=True) as db:
if "users" not in db:
db[TK.USERS] = {}
db[TK.USERS][username] = validated_user
@user_must_exist
def del_user(self, username: str) -> None:
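        """Delete the given user from the user collection."""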
with shelve.open(self.path, writeback=True) as db:
del db[TK.USERS][username]
@override
def does_user_exist(self, user_name: str) -> bool:
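        """Return True if a user with the given name exists in the user collection."""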
return user_name in map(lambda user: user.id.name, self.users())
@override
def does_dataset_exist(self, dataset_name: str) -> bool:
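        """Return True if a dataset with the given name exists in the dataset collection."""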
return dataset_name in map(lambda ds: ds.dataset_name, self.datasets())
@override
@dataset_must_exist
def get_dataset(self, dataset_name: str) -> DSInfo:
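        """Return the DSInfo of the dataset with the given name."""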
with shelve.open(self.path, flag="r") as db:
dataset = next(filter(lambda ds: ds["dataset_name"] == dataset_name, db[TK.DATASETS]))
return DSInfo.model_validate(dataset)
@override
@user_must_exist
def get_and_set_may_user_query(self, user_name: str, may_query: bool) -> bool:
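        """Set the user's may_query flag and return its previous value."""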
with shelve.open(self.path, writeback=True) as db:
previous_may_query = db[TK.USERS][user_name]["may_query"]
db[TK.USERS][user_name]["may_query"] = may_query
return previous_may_query
@override
@user_must_exist
def has_user_access_to_dataset(self, user_name: str, dataset_name: str) -> bool:
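        """Return True if the user has access to the given dataset."""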
@dataset_must_exist
def has_access_to_dataset(self: Self, dataset_name: str) -> bool:
with shelve.open(self.path, flag="r") as db:
return bool(
[
ds
for ds in db[TK.USERS][user_name]["datasets_list"]
if ds["dataset_name"] == dataset_name
]
)
return has_access_to_dataset(self, dataset_name)
@override
def get_epsilon_or_delta(self, user_name: str, dataset_name: str, parameter: BudgetDBKey) -> float:
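        """Return the user's epsilon or delta value (selected by parameter) for the dataset."""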
with shelve.open(self.path, flag="r") as db:
return sum(
map(
op.itemgetter(parameter),
filter(
lambda ds: ds["dataset_name"] == dataset_name,
db[TK.USERS][user_name]["datasets_list"],
),
)
)
@override
def update_epsilon_or_delta(
self,
user_name: str,
dataset_name: str,
parameter: BudgetDBKey,
spent_value: float,
) -> None:
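        """Add spent_value to the given budget parameter of the user's dataset entry."""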
with shelve.open(self.path, writeback=True) as db:
datasets = db[TK.USERS][user_name]["datasets_list"]
for ds in datasets:
if ds["dataset_name"] == dataset_name:
ds[parameter] += spent_value
@override
@user_must_have_access_to_dataset
def get_user_previous_queries(
self,
user_name: str,
dataset_name: str,
) -> list[dict]:
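        """Return all archived queries of the user on the given dataset."""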
def match(archive: dict[str, str]) -> bool:
return (user_name, dataset_name) == op.itemgetter("user_name", "dataset_name")(archive)
with shelve.open(self.path, flag="r") as db:
return list(filter(match, db.get(TK.ARCHIVE, [])))
@override
def save_query(self, user_name: str, query: LomasRequestModel, response: QueryResponse) -> None:
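        """Archive the query and its response for the given user."""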
with shelve.open(self.path, writeback=True) as db:
to_archive = self.prepare_save_query(user_name, query, response)
if TK.ARCHIVE not in db:
db[TK.ARCHIVE] = [to_archive]
else:
db[TK.ARCHIVE].append(to_archive)
def get_archives_of_user(self, username: str) -> list[dict]:
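        """Return all archived queries of the given user."""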
with shelve.open(self.path, flag="r") as db:
return [archive for archive in db.get(TK.ARCHIVE, []) if archive["user_name"] == username]
def drop_archive(self) -> None:
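        """Drop the query archive collection."""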
self.drop_collection(TK.ARCHIVE)
def get_collection(self, collection: str) -> dict[str, Any]:
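        """Return the raw content of the given top-level collection, or an empty dict if absent."""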
with shelve.open(self.path, flag="r") as db:
return db.get(collection, {})
def drop_collection(self, collection: str) -> None:
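        """Delete the given top-level collection from the database, if present."""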
with shelve.open(self.path, writeback=True) as db:
if collection in db:
del db[collection]