Source code for lomas_server.routes.middlewares

import json
import logging
import time
from typing import Tuple

from fastapi import Request
from opentelemetry.trace import format_trace_id, get_tracer
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response
from starlette.routing import Match
from starlette.types import ASGIApp

from lomas_core.error_handler import KNOWN_EXCEPTIONS
from lomas_server.constants import SERVER_SERVICE_NAME
from lomas_server.utils.metrics import (
    FAST_API_EXCEPTION_COUNTER,
    FAST_API_REQUESTS_COUNTER,
    FAST_API_REQUESTS_IN_PROGRESS_GAUGE,
    FAST_API_REQUESTS_PROCESSING_HISTOGRAM,
    FAST_API_RESPONSES_COUNTER,
)


class LoggingAndTracingMiddleware(BaseHTTPMiddleware):
    """
    Middleware for logging and tracing incoming HTTP requests.

    This middleware logs the incoming requests, including the user name,
    the route being accessed, and the parameters found in the JSON request
    body (logged as "query params"). Additionally, it creates a trace span
    for the user's request and attaches the user name and those parameters
    to the span as attributes.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """
        Handles the request and performs logging and tracing.

        Logs the user name, the route and the query parameters. Creates a
        trace span to monitor the request and adds relevant attributes.

        Args:
            request (Request): The incoming request object.
            call_next (RequestResponseEndpoint): A function that, when called,
                passes the request to the next middleware or request handler.

        Returns:
            Response: The HTTP response generated by calling `call_next(request)`.
        """
        user_name = request.headers.get("user-name")
        route = request.url.path
        query_params = await self._read_query_params(request)

        tracer = get_tracer(__name__)
        with tracer.start_as_current_span("user_request_span") as span:
            # OpenTelemetry does not accept None as an attribute value, so the
            # user name is only attached when the header was actually sent.
            if user_name is not None:
                span.set_attribute("user_name", user_name)
            for param, value in query_params.items():
                span.set_attribute(f"query_param.{param}", value)

            logging.info(
                "User '%s' is making a request to route '%s' "
                "with query params: %s. trace_id=%s",
                user_name,
                route,
                query_params,
                format_trace_id(span.get_span_context().trace_id),
            )

            response = await call_next(request)
            return response

    @staticmethod
    async def _read_query_params(request: Request) -> dict:
        """
        Parse the request's JSON body into a new dict of span-friendly values.

        Args:
            request (Request): The incoming request whose body is read.

        Returns:
            dict: One entry per top-level key of the JSON object body, with
            values converted by `_to_attribute_value`. An empty dict is
            returned when the body is empty, is not valid JSON, or is valid
            JSON that is not an object (array or scalar).
        """
        # NOTE(review): reading the body inside a BaseHTTPMiddleware relies on
        # Starlette making it available again to the downstream app — confirm
        # the pinned Starlette version supports this.
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            return {}

        if not isinstance(body, dict):
            # A JSON array/scalar has no named parameters; treat it like an
            # unparsable body instead of failing on `.items()`.
            return {}

        return {
            param: LoggingAndTracingMiddleware._to_attribute_value(value)
            for param, value in body.items()
        }

    @staticmethod
    def _to_attribute_value(value: object) -> object:
        """
        Convert one body value into a form accepted as a span attribute.

        Args:
            value (object): A top-level value from the parsed JSON body.

        Returns:
            object: "" for None, a JSON string for nested objects, and the
            value unchanged otherwise.
        """
        if value is None:
            return ""
        if isinstance(value, dict):
            return json.dumps(value)
        return value
class FastAPIMetricMiddleware(BaseHTTPMiddleware):
    """
    Middleware to collect and expose Prometheus metrics for a FastAPI application.

    This middleware tracks various metrics related to HTTP requests, including:
    - Total requests (`fastapi_requests_total`)
    - Total responses (`fastapi_responses_total`)
    - Exceptions raised (`fastapi_exceptions_total`)
    - Request processing duration (`fastapi_requests_duration_seconds`)
    - Current requests in progress (`fastapi_requests_in_progress`)

    It also supports integration with an OpenTelemetry exporter for exporting
    metrics to a metrics collector (e.g., Prometheus or any other
    OTLP-compatible collector).
    """

    def __init__(self, app: ASGIApp, app_name: str = SERVER_SERVICE_NAME) -> None:
        """
        Initializes the MetricMiddleware.

        Args:
            app (ASGIApp): The FastAPI application instance.
            app_name (str): The name of the application used for metric labeling.
        """
        super().__init__(app)
        self.app_name = app_name

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """
        Processes HTTP request, records metrics and returns the HTTP response.

        Requests whose path does not fully match any application route are
        passed through without recording metrics. For matched routes:

        1. Increments the in-progress gauge and the request counter.
        2. Calls the next handler and, on success, records the processing
           time in the duration histogram.
        3. If one of `KNOWN_EXCEPTIONS` is raised, increments the exception
           counter (labelled with the exception class name) and re-raises it.
        4. Always records the response status code in the response counter
           (500 when no response was produced) and decrements the
           in-progress gauge.

        Args:
            request (Request): The incoming HTTP request to be processed.
            call_next (RequestResponseEndpoint): The endpoint function that
                processes the request and returns a response.

        Returns:
            Response: The HTTP response after processing the request.

        Raises:
            BaseException: Any exception raised by `call_next` is re-raised
                unchanged after the metrics above have been recorded; only
                `KNOWN_EXCEPTIONS` are counted in the exception counter.
        """
        method = request.method
        path, is_handled_path = self.get_path(request)

        if not is_handled_path:
            return await call_next(request)

        labels = {"method": method, "path": path, "app_name": self.app_name}

        # Track requests being processed
        FAST_API_REQUESTS_IN_PROGRESS_GAUGE.add(1, labels)
        FAST_API_REQUESTS_COUNTER.add(1, labels)

        # Must be bound before `try`: the `finally` clause reads it even when
        # `call_next` raises and no response exists. Without this default the
        # `finally` would raise UnboundLocalError, hiding the real exception
        # and skipping the in-progress gauge decrement.
        status_code = 500
        before_time = time.perf_counter()
        try:
            response = await call_next(request)
        except KNOWN_EXCEPTIONS as e:
            FAST_API_EXCEPTION_COUNTER.add(
                1, {**labels, "exception_type": type(e).__name__}
            )
            # Bare `raise` keeps the original traceback and exception context.
            raise
        else:
            status_code = response.status_code
            after_time = time.perf_counter()
            # Record request processing time (successful responses only)
            FAST_API_REQUESTS_PROCESSING_HISTOGRAM.record(after_time - before_time, labels)
        finally:
            FAST_API_RESPONSES_COUNTER.add(1, {**labels, "status_code": status_code})
            FAST_API_REQUESTS_IN_PROGRESS_GAUGE.add(-1, labels)

        return response

    @staticmethod
    def get_path(request: Request) -> Tuple[str, bool]:
        """
        Attempts to match the request's route to a defined route.

        Returning the route's declared path (e.g. a template such as
        "/items/{id}") rather than the concrete URL keeps metric label values
        bounded to the set of application routes.

        Args:
            request (Request): The HTTP request to check for a matching path.

        Returns:
            Tuple[str, bool]: A tuple containing:
                - The matched route's path, or the raw request URL path when
                  no route fully matches.
                - True if the path was handled by one of the routes, else False.
        """
        for route in request.app.routes:
            match, _ = route.matches(request.scope)
            if match == Match.FULL:
                return route.path, True
        return request.url.path, False