Source code for lomas_server.routes.utils

import asyncio
import inspect
import os
import random
import time
from contextlib import asynccontextmanager
from functools import wraps

import aio_pika
from fastapi import Request
from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor

from lomas_core.constants import DPLibraries
from lomas_core.error_handler import UnauthorizedAccessException
from lomas_core.models.constants import TimeAttackMethod
from lomas_core.models.requests import (
    DummyQueryModel,
    LomasRequestModel,
    QueryModel,
)
from lomas_core.models.responses import CostResponse, Job, QueryResponse
from lomas_server.utils.config import get_config

# Turn on OpenTelemetry instrumentation for aio_pika (runs at import time).
AioPikaInstrumentor().instrument()

# TODO: merge in pydantic-settings
# RabbitMQ credentials, resolved once at import time; each falls back to
# "guest" when its environment variable is unset.
amqp_user, amqp_pass = (
    os.getenv("LOMAS_AMQP_USER", "guest"),
    os.getenv("LOMAS_AMQP_PASS", "guest"),
)


[docs] async def process_response(queue, cls, jobs_var): """Process responses queue into Jobs.""" async with queue.iterator() as queue_iter: async for message in queue_iter: async with message.process(): jobs = jobs_var.get() match message.headers: case {"type": "exception", "status_code": status_code}: jobs[message.correlation_id].error = message.body.decode() jobs[message.correlation_id].status = "failed" jobs[message.correlation_id].result = None jobs[message.correlation_id].status_code = status_code case _: jobs[message.correlation_id].result = cls.model_validate_json(message.body.decode()) jobs[message.correlation_id].status = "complete" jobs_var.set(jobs)
@asynccontextmanager
async def rabbitmq_ctx(app):
    """RabbitMQ queue context to connect and register callbacks.

    Connects to the broker, declares the request queues (``task_queue``,
    ``cost_queue``, ``dummy_queue``) and their response queues, and starts one
    ``process_response`` consumer task per response queue. The single channel
    is exposed on ``app.state`` as ``task_queue_channel``,
    ``cost_queue_channel`` and ``dummy_queue_channel``.

    On exit, normal or exceptional, the consumer tasks are cancelled and
    awaited, and then the connection is closed.

    The broker host is read from the ``LOMAS_AMQP_HOST`` environment variable
    (default ``127.0.0.1``). Credentials come from the module-level
    ``amqp_user`` / ``amqp_pass``.

    Args:
        app: The application. ``app.state.jobs_var`` must already be set,
            because the consumers are handed it.
    """
    # Pass credentials as keywords instead of interpolating them into an
    # "amqp://user:pass@host/" string. In such a string, a password
    # containing "@", ":" or "/" would corrupt the URL.
    connection = await aio_pika.connect_robust(
        host=os.environ.get("LOMAS_AMQP_HOST", "127.0.0.1"),
        login=amqp_user,
        password=amqp_pass,
    )
    # asyncio keeps only weak references to running tasks, so hold the
    # consumers here; otherwise they may be garbage-collected mid-execution.
    consumers: list[asyncio.Task] = []
    try:
        channel = await connection.channel()
        for request_queue, response_queue, response_cls, state_attr in (
            ("task_queue", "task_response", QueryResponse, "task_queue_channel"),
            ("cost_queue", "cost_response", CostResponse, "cost_queue_channel"),
            ("dummy_queue", "dummy_response", QueryResponse, "dummy_queue_channel"),
        ):
            await channel.declare_queue(request_queue, auto_delete=True)
            setattr(app.state, state_attr, channel)
            queue = await channel.declare_queue(response_queue, auto_delete=True)
            consumers.append(
                asyncio.create_task(process_response(queue, response_cls, app.state.jobs_var))
            )
        yield  # app is handling requests
    finally:
        for consumer in consumers:
            consumer.cancel()
        # Let the consumers finish unwinding while the connection is still open.
        await asyncio.gather(*consumers, return_exceptions=True)
        await connection.close()
[docs] def timing_protection(func): """Adds delays to requests response to protect against timing attack.""" @wraps(func) def wrapper(*args, **kwargs): config = get_config() start_time = time.time() response = func(*args, **kwargs) process_time = time.time() - start_time match config.server.time_attack.method: case TimeAttackMethod.STALL: # Slows to a minimum response time defined by magnitude if process_time < config.server.time_attack.magnitude: time.sleep(config.server.time_attack.magnitude - process_time) case TimeAttackMethod.JITTER: # Adds some time between 0 and magnitude secs time.sleep(config.server.time_attack.magnitude * random.uniform(0, 1)) return response return wrapper
[docs] @timing_protection async def handle_query_to_job( request: Request, query: DummyQueryModel | QueryModel | LomasRequestModel, user_name: str, dp_library: DPLibraries, ) -> Job: """ Submit Job to handles queries on private, dummy and cost datasets on a worker. Args: request (Request): Raw request object query (DummyQueryModel|QueryModel|LomasRequestModel): A Request or Query to be scheduled user_name (str): The user name dp_library (DPLibraries): Name of the DP library to use for the request Raises: UnauthorizedAccessException: A query is already ongoing for this user, the user does not exist or does not have access to the dataset. Returns: Job: A scheduled Job resulting in a QueryResponse containing the result of the query (specific to the library) as well as the cost of the query. or a CostResponse containing the epsilon, delta and privacy-loss budget cost for the request. """ app = request.app dataset_name = query.dataset_name if not app.state.admin_database.has_user_access_to_dataset(user_name, dataset_name): raise UnauthorizedAccessException(f"{user_name} does not have access to {dataset_name}.") match query: case DummyQueryModel(): queue_name = "dummy_queue" case QueryModel(): queue_name = "task_queue" case LomasRequestModel(): queue_name = "cost_queue" new_task = Job() jobs = app.state.jobs_var.get() jobs[str(new_task.uid)] = new_task app.state.jobs_var.set(jobs) await app.state.cost_queue_channel.default_exchange.publish( aio_pika.Message( body=f"{user_name}:{dp_library}:{query.json()}".encode(), correlation_id=new_task.uid ), routing_key=queue_name, ) return new_task