feat(telemetry): add axum_middleware

This commit is contained in:
2026-01-30 16:28:29 +08:00
parent 4db955113c
commit 9465892cc6
5 changed files with 139 additions and 1 deletions

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "common-telemetry" name = "common-telemetry"
version = "0.1.4" version = "0.1.5"
edition = "2024" edition = "2024"
description = "Microservice infrastructure library" description = "Microservice infrastructure library"

72
src/axum_middleware.rs Normal file
View File

@@ -0,0 +1,72 @@
#[cfg(all(feature = "response", feature = "telemetry"))]
use axum::{
extract::{ConnectInfo, MatchedPath, Request},
middleware::Next,
response::Response,
};
#[cfg(all(feature = "response", feature = "telemetry"))]
use std::net::SocketAddr;
#[cfg(all(feature = "response", feature = "telemetry"))]
use tracing::field;
#[cfg(all(feature = "response", feature = "telemetry"))]
fn first_forwarded_for(req: &Request) -> Option<String> {
let raw = req.headers().get("x-forwarded-for")?.to_str().ok()?;
raw.split(',')
.next()
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| s.to_string())
}
#[cfg(all(feature = "response", feature = "telemetry"))]
fn connect_info_ip(req: &Request) -> Option<String> {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip().to_string())
}
#[cfg(all(feature = "response", feature = "telemetry"))]
fn matched_route(req: &Request) -> Option<String> {
req.extensions()
.get::<MatchedPath>()
.map(|m| m.as_str().to_string())
}
#[cfg(all(feature = "response", feature = "telemetry"))]
fn request_id(req: &Request) -> Option<String> {
req.headers()
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
}
#[cfg(all(feature = "response", feature = "telemetry"))]
pub async fn trace_http_request(req: Request, next: Next) -> Response {
let method = req.method().to_string();
let path = req.uri().path().to_string();
let route = matched_route(&req).unwrap_or_else(|| path.clone());
let client_ip = first_forwarded_for(&req).or_else(|| connect_info_ip(&req));
let req_id = request_id(&req);
let span = tracing::info_span!(
"http.request",
http.method = %method,
http.path = %path,
http.route = %route,
client.ip = %client_ip.clone().unwrap_or_else(|| "unknown".into()),
request_id = %req_id.clone().unwrap_or_else(|| "unknown".into()),
tenant_id = field::Empty,
user_id = field::Empty,
status = field::Empty,
);
let _guard = span.enter();
let resp = next.run(req).await;
let status = resp.status().as_u16();
tracing::Span::current().record("status", status);
resp
}

View File

@@ -9,6 +9,9 @@ pub mod response;
#[cfg(feature = "telemetry")] #[cfg(feature = "telemetry")]
pub mod telemetry; pub mod telemetry;
#[cfg(all(feature = "response", feature = "telemetry"))]
pub mod axum_middleware;
// 方便外部直接 use common_lib::AppError; // 方便外部直接 use common_lib::AppError;
#[cfg(feature = "response")] #[cfg(feature = "response")]
pub use error::{AppError, BizCode}; pub use error::{AppError, BizCode};

View File

@@ -37,6 +37,8 @@ pub fn init(config: TelemetryConfig) -> Option<WorkerGuard> {
.json() // 生产环境通常用 json .json() // 生产环境通常用 json
.with_file(true) .with_file(true)
.with_line_number(true) .with_line_number(true)
.with_current_span(true)
.with_span_list(true)
.with_writer(std::io::stdout); .with_writer(std::io::stdout);
// 3. 准备注册表 // 3. 准备注册表
@@ -57,6 +59,8 @@ pub fn init(config: TelemetryConfig) -> Option<WorkerGuard> {
.json() .json()
.with_file(true) .with_file(true)
.with_line_number(true) .with_line_number(true)
.with_current_span(true)
.with_span_list(true)
.with_writer(non_blocking) .with_writer(non_blocking)
.with_ansi(false); // 文件不需要颜色 .with_ansi(false); // 文件不需要颜色

View File

@@ -2,6 +2,7 @@ use axum::{Router, body::Body, routing::get};
use common_telemetry::{ use common_telemetry::{
error::AppError, error::AppError,
response::AppResponse, response::AppResponse,
axum_middleware::trace_http_request,
telemetry::{self, TelemetryConfig}, telemetry::{self, TelemetryConfig},
}; };
use http::{Request, StatusCode}; use http::{Request, StatusCode};
@@ -237,3 +238,61 @@ async fn test_full_flow_error_and_logging() {
// temp_dir 会在作用域结束时自动删除清理 // temp_dir 会在作用域结束时自动删除清理
} }
#[tokio::test]
async fn error_log_includes_http_context_span() {
use std::sync::{Arc, Mutex};
use tracing_subscriber::fmt::MakeWriter;
struct BufferWriter(Arc<Mutex<Vec<u8>>>);
impl<'a> MakeWriter<'a> for BufferWriter {
type Writer = BufferGuard;
fn make_writer(&'a self) -> Self::Writer {
BufferGuard(self.0.clone())
}
}
struct BufferGuard(Arc<Mutex<Vec<u8>>>);
impl std::io::Write for BufferGuard {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.lock().unwrap().extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
async fn handler() -> Result<String, AppError> {
Err(AppError::MissingAuthHeader)
}
let buf = Arc::new(Mutex::new(Vec::<u8>::new()));
let subscriber = tracing_subscriber::fmt()
.with_writer(BufferWriter(buf.clone()))
.with_ansi(false)
.json()
.with_current_span(true)
.with_span_list(true)
.finish();
let _guard = tracing::subscriber::set_default(subscriber);
let app = Router::new()
.route("/needs-auth", get(handler))
.layer(axum::middleware::from_fn(trace_http_request));
let resp = app
.oneshot(
Request::builder()
.uri("/needs-auth")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
let logs = String::from_utf8(buf.lock().unwrap().clone()).unwrap();
assert!(logs.contains("\"message\":\"Request failed\""));
assert!(logs.contains("/needs-auth"));
assert!(logs.contains("http.method") || logs.contains("GET"));
}