import json
import logging
import time
from typing import Tuple
from fastapi import Request
from opentelemetry.trace import format_trace_id, get_tracer
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response
from starlette.routing import Match
from starlette.types import ASGIApp
from lomas_core.error_handler import KNOWN_EXCEPTIONS
from lomas_server.constants import SERVER_SERVICE_NAME
from lomas_server.utils.metrics import (
FAST_API_EXCEPTION_COUNTER,
FAST_API_REQUESTS_COUNTER,
FAST_API_REQUESTS_IN_PROGRESS_GAUGE,
FAST_API_REQUESTS_PROCESSING_HISTOGRAM,
FAST_API_RESPONSES_COUNTER,
)
class LoggingAndTracingMiddleware(BaseHTTPMiddleware):
    """
    Middleware for logging and tracing incoming HTTP requests.

    This middleware logs every incoming request, including the user name,
    the route being accessed, and the parameters sent in the JSON body.
    Additionally, it creates a trace span around the user's request and
    adds attributes to the span for the user name and those parameters.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """
        Handles the request and performs logging and tracing.

        Logs the user name, the route and the query parameters. Then runs the
        rest of the request inside a "user_request_span" trace span, which
        carries the same information as span attributes.

        The "query parameters" are the top-level keys of the JSON request
        body. A missing or malformed body, or one that is not a JSON object,
        is treated as having no parameters.

        Args:
            request (Request): The incoming request object.
            call_next (RequestResponseEndpoint): A function that, when called,
                passes the request to the next middleware or request handler.

        Returns:
            Response: The HTTP response generated by calling `call_next(request)`.
        """
        user_name = request.headers.get("user-name")
        route = request.url.path
        query_params = await self._read_query_params(request)

        tracer = get_tracer(__name__)
        with tracer.start_as_current_span("user_request_span") as span:
            # OpenTelemetry rejects None attribute values (they are dropped
            # with a warning), so only set the user name when the header exists.
            if user_name is not None:
                span.set_attribute("user_name", user_name)
            for param, value in query_params.items():
                span.set_attribute(f"query_param.{param}", value)
            logging.info(
                "User '%s' is making a request to route '%s' "
                "with query params: %s. trace_id=%s",
                user_name,
                route,
                query_params,
                format_trace_id(span.get_span_context().trace_id),
            )
            # Downstream processing happens inside the span so it is traced too.
            return await call_next(request)

    @staticmethod
    async def _read_query_params(request: Request) -> dict:
        """
        Return the JSON-object body of `request` as span-safe parameters.

        None values become "" and nested objects are serialized to JSON
        strings, because neither is a valid OpenTelemetry attribute value.
        A body that is absent, not valid JSON, or not a JSON object yields {}.
        """
        try:
            body = await request.json()
        except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
            return {}
        # Only a JSON object has named parameters; a JSON array or scalar
        # would fail on `.items()` below, so treat it as "no parameters".
        if not isinstance(body, dict):
            return {}
        params = {}
        for param, value in body.items():
            if value is None:
                params[param] = ""
            elif isinstance(value, dict):
                params[param] = json.dumps(value)
            else:
                params[param] = value
        return params
class FastAPIMetricMiddleware(BaseHTTPMiddleware):
    """
    Middleware to collect and expose Prometheus metrics for a FastAPI application.

    This middleware tracks various metrics related to HTTP requests, including:
    - Total requests (`fastapi_requests_total`)
    - Total responses (`fastapi_responses_total`)
    - Exceptions raised (`fastapi_exceptions_total`)
    - Request processing duration (`fastapi_requests_duration_seconds`)
    - Current requests in progress (`fastapi_requests_in_progress`)

    It also supports integration with an OpenTelemetry exporter for exporting metrics
    to a metrics collector (e.g., Prometheus or any other OTLP-compatible collector).
    """

    # Status code recorded in the responses counter when `call_next` raises
    # instead of returning a response.
    # NOTE(review): assumes an outer error handler turns such errors into a 500.
    _EXCEPTION_STATUS_CODE = 500

    def __init__(self, app: ASGIApp, app_name: str = SERVER_SERVICE_NAME) -> None:
        """
        Initializes the FastAPIMetricMiddleware.

        Args:
            app (ASGIApp): The FastAPI application instance.
            app_name (str): The name of the application used for metric labeling.
        """
        super().__init__(app)
        self.app_name = app_name

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """
        Processes HTTP request, records metrics and returns the HTTP response.

        Requests whose path matches no application route are forwarded
        without recording any metric. For the others, this method:
        1. Increments the `fastapi_requests_in_progress` gauge.
        2. Records the request count with the `fastapi_requests_total` counter.
        3. On success, records the processing time in the
           `fastapi_requests_duration_seconds` histogram.
        4. If `call_next` raises one of `KNOWN_EXCEPTIONS`, records its type
           in the `fastapi_exceptions_total` counter.
        5. Always records the status code with the `fastapi_responses_total`
           counter (`_EXCEPTION_STATUS_CODE` when no response was produced).
        6. Always decrements the in-progress gauge, even on failure.

        Args:
            request (Request): The incoming HTTP request to be processed.
            call_next (RequestResponseEndpoint): The endpoint function that processes
                the request and returns a response.

        Returns:
            Response: The HTTP response after processing the request.

        Raises:
            BaseException: Any exception raised by `call_next` is re-raised
                unchanged after the metrics above have been recorded.
        """
        method = request.method
        path, is_handled_path = self.get_path(request)
        if not is_handled_path:
            return await call_next(request)

        labels = {"method": method, "path": path, "app_name": self.app_name}

        # Track requests being processed
        FAST_API_REQUESTS_IN_PROGRESS_GAUGE.add(1, labels)
        FAST_API_REQUESTS_COUNTER.add(1, labels)

        # Must be bound before `try`: if `call_next` raises, the `finally`
        # clause still reads it, and an unbound name there would mask the
        # real exception and skip the gauge decrement.
        status_code = self._EXCEPTION_STATUS_CODE
        before_time = time.perf_counter()
        try:
            response = await call_next(request)
        except KNOWN_EXCEPTIONS as e:
            FAST_API_EXCEPTION_COUNTER.add(
                1, {**labels, "exception_type": type(e).__name__}
            )
            raise
        else:
            status_code = response.status_code
            # Record request processing time (successful requests only)
            FAST_API_REQUESTS_PROCESSING_HISTOGRAM.record(
                time.perf_counter() - before_time, labels
            )
        finally:
            FAST_API_RESPONSES_COUNTER.add(1, {**labels, "status_code": status_code})
            FAST_API_REQUESTS_IN_PROGRESS_GAUGE.add(-1, labels)
        return response

    @staticmethod
    def get_path(request: Request) -> Tuple[str, bool]:
        """
        Attempts to match the request's route to a defined route.

        Args:
            request (Request): The HTTP request to check for a matching path.

        Returns:
            Tuple[str, bool]: A tuple containing:
                - The matched route's path template, or the raw URL path when
                  no route matches (or the matching route has no `path`).
                - Boolean (True if the path was handled by one of the routes).
        """
        for route in request.app.routes:
            match, _ = route.matches(request.scope)
            if match == Match.FULL:
                # Not every route type defines `path` (e.g. starlette's `Host`);
                # fall back to the raw URL path rather than raising.
                return getattr(route, "path", request.url.path), True
        return request.url.path, False