Source code for lomas_server.routes.utils

import asyncio
import logging
import posix as Status
import random
import sys
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps
from typing import Annotated
from uuid import UUID

import aio_pika
from fastapi import Depends, FastAPI, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor

from lomas_core.constants import DPLibraries
from lomas_core.error_handler import UnauthorizedAccessException
from lomas_core.models.collections import UserId
from lomas_core.models.constants import TimeAttackMethod
from lomas_core.models.requests import (
    DummyQueryModel,
    LomasRequestModel,
    QueryModel,
)
from lomas_core.models.responses import CostResponse, Job, QueryResponse
from lomas_server.models.config import Config

AioPikaInstrumentor().instrument()


[docs] async def process_response( queue: aio_pika.Queue, cls: type[QueryResponse | CostResponse], jobs: dict[UUID, Job] ) -> None: """Process responses queue into Jobs.""" async with queue.iterator() as queue_iter: async for message in queue_iter: async with message.process(ignore_processed=True): if message.correlation_id not in jobs: await message.reject(requeue=True) else: await message.ack() match message.headers: case {"type": "exception", "status_code": status_code}: jobs[message.correlation_id].error = message.body.decode() jobs[message.correlation_id].status = "failed" jobs[message.correlation_id].result = None jobs[message.correlation_id].status_code = status_code case _: jobs[message.correlation_id].result = cls.model_validate_json( message.body.decode() ) jobs[message.correlation_id].status = "complete"
[docs] async def rabbitmq_connect_queue( config: Config, reconnect_interval: int = 10, timeout: int = 120 ) -> aio_pika.RobustConnection: """Attempt with retries to connect to the queue.""" try: async with asyncio.timeout(timeout): connection = await aio_pika.connect_robust( str(config.amqp.dsn), fail_fast=False, reconnect_interval=reconnect_interval, ) return connection except TimeoutError: logging.error(f"Couldn't connect to queue {config.amqp.base_url} in time") sys.exit(Status.EX_UNAVAILABLE)
[docs] @asynccontextmanager async def rabbitmq_ctx(app: FastAPI) -> AsyncIterator[None]: """RabbitMQ queue context to connect and register callbacks.""" config = Config() connection = await rabbitmq_connect_queue(config) channel = await connection.channel() background_tasks = set() # Avoid dangling asyncio.Task by storing them here await channel.declare_queue("task_queue", auto_delete=True) app.state.task_queue_channel = channel queue = await channel.declare_queue("task_response", auto_delete=True) tasks_response_task = asyncio.create_task(process_response(queue, QueryResponse, app.state.jobs)) background_tasks.add(tasks_response_task) tasks_response_task.add_done_callback(background_tasks.discard) await channel.declare_queue("cost_queue", auto_delete=True) app.state.cost_queue_channel = channel queue = await channel.declare_queue("cost_response", auto_delete=True) cost_response_task = asyncio.create_task(process_response(queue, CostResponse, app.state.jobs)) background_tasks.add(cost_response_task) cost_response_task.add_done_callback(background_tasks.discard) await channel.declare_queue("dummy_queue", auto_delete=True) app.state.dummy_queue_channel = channel queue = await channel.declare_queue("dummy_response", auto_delete=True) dummy_response_task = asyncio.create_task(process_response(queue, QueryResponse, app.state.jobs)) background_tasks.add(dummy_response_task) dummy_response_task.add_done_callback(background_tasks.discard) yield # app is handling requests await connection.close()
[docs] def timing_protection(func): # type: ignore[no-untyped-def] """Adds delays to requests response to protect against timing attack.""" @wraps(func) def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] config = Config() start_time = time.time() response = func(*args, **kwargs) process_time = time.time() - start_time match config.server.time_attack.method: case TimeAttackMethod.STALL: # Slows to a minimum response time defined by magnitude if process_time < config.server.time_attack.magnitude: time.sleep(config.server.time_attack.magnitude - process_time) case TimeAttackMethod.JITTER: # Adds some time between 0 and magnitude secs time.sleep(config.server.time_attack.magnitude * random.uniform(0, 1)) return response return wrapper
[docs] def get_user_id_from_authenticator( request: Request, security_scopes: SecurityScopes, auth_creds: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer())], ) -> UserId: """Extracts the authenticator from the app state and calls its get_user_id method. Also adds the user_name to the request state to annotate the telemetry request span. Args: request (Request): The request to access the app and state. security_scopes (SecurityScopes): The required scopes for the endpoint. auth_creds (Annotated[HTTPAuthorizationCredentials, Depends): The HTTP bearer token. Returns: UserId: A UserId instance extracted from the token. """ user_id = request.app.state.authenticator.get_user_id(security_scopes, auth_creds) request.state.user_name = user_id.name return user_id
[docs] @timing_protection async def handle_query_to_job( request: Request, query: DummyQueryModel | QueryModel | LomasRequestModel, user_name: str, dp_library: DPLibraries, ) -> Job: """ Submit Job to handles queries on private, dummy and cost datasets on a worker. Args: request (Request): Raw request object query (DummyQueryModel|QueryModel|LomasRequestModel): A Request or Query to be scheduled user_name (str): The user name dp_library (DPLibraries): Name of the DP library to use for the request Raises: UnauthorizedAccessException: A query is already ongoing for this user, the user does not exist or does not have access to the dataset. Returns: Job: A scheduled Job resulting in a QueryResponse containing the result of the query (specific to the library) as well as the cost of the query. or a CostResponse containing the epsilon, delta and privacy-loss budget cost for the request. """ app = request.app dataset_name = query.dataset_name if not app.state.admin_database.has_user_access_to_dataset(user_name, dataset_name): raise UnauthorizedAccessException(f"{user_name} does not have access to {dataset_name}.") match query: case DummyQueryModel(): queue_name = "dummy_queue" case QueryModel(): queue_name = "task_queue" case LomasRequestModel(): queue_name = "cost_queue" new_task = Job(requested_by=user_name) app.state.jobs[str(new_task.uid)] = new_task await app.state.cost_queue_channel.default_exchange.publish( aio_pika.Message( body=f"{user_name}:{dp_library}:{query.model_dump_json()}".encode(), correlation_id=new_task.uid ), routing_key=queue_name, ) return new_task