From 9465892cc6675f487754f1a641ee9d489d9289aa Mon Sep 17 00:00:00 2001 From: shay7sev Date: Fri, 30 Jan 2026 16:28:29 +0800 Subject: [PATCH] feat(telemetry): add axum_middleware --- Cargo.toml | 2 +- src/axum_middleware.rs | 72 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 ++ src/telemetry.rs | 4 +++ tests/integration_test.rs | 59 ++++++++++++++++++++++++++++++++ 5 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 src/axum_middleware.rs diff --git a/Cargo.toml b/Cargo.toml index be92542..e6943ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "common-telemetry" -version = "0.1.4" +version = "0.1.5" edition = "2024" description = "Microservice infrastructure library" diff --git a/src/axum_middleware.rs b/src/axum_middleware.rs new file mode 100644 index 0000000..8d1e7cf --- /dev/null +++ b/src/axum_middleware.rs @@ -0,0 +1,72 @@ +#[cfg(all(feature = "response", feature = "telemetry"))] +use axum::{ + extract::{ConnectInfo, MatchedPath, Request}, + middleware::Next, + response::Response, +}; + +#[cfg(all(feature = "response", feature = "telemetry"))] +use std::net::SocketAddr; + +#[cfg(all(feature = "response", feature = "telemetry"))] +use tracing::field; + +#[cfg(all(feature = "response", feature = "telemetry"))] +fn first_forwarded_for(req: &Request) -> Option { + let raw = req.headers().get("x-forwarded-for")?.to_str().ok()?; + raw.split(',') + .next() + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) +} + +#[cfg(all(feature = "response", feature = "telemetry"))] +fn connect_info_ip(req: &Request) -> Option { + req.extensions() + .get::>() + .map(|ci| ci.0.ip().to_string()) +} + +#[cfg(all(feature = "response", feature = "telemetry"))] +fn matched_route(req: &Request) -> Option { + req.extensions() + .get::() + .map(|m| m.as_str().to_string()) +} + +#[cfg(all(feature = "response", feature = "telemetry"))] +fn request_id(req: &Request) -> Option { + req.headers() + .get("x-request-id") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) +} + +#[cfg(all(feature = "response", feature = "telemetry"))] +pub async fn trace_http_request(req: Request, next: Next) -> Response { + let method = req.method().to_string(); + let path = req.uri().path().to_string(); + let route = matched_route(&req).unwrap_or_else(|| path.clone()); + let client_ip = first_forwarded_for(&req).or_else(|| connect_info_ip(&req)); + let req_id = request_id(&req); + + let span = tracing::info_span!( + "http.request", + http.method = %method, + http.path = %path, + http.route = %route, + client.ip = %client_ip.clone().unwrap_or_else(|| "unknown".into()), + request_id = %req_id.clone().unwrap_or_else(|| "unknown".into()), + tenant_id = field::Empty, + user_id = field::Empty, + status = field::Empty, + ); + + let _guard = span.enter(); + let resp = next.run(req).await; + let status = resp.status().as_u16(); + tracing::Span::current().record("status", status); + resp +} + diff --git a/src/lib.rs b/src/lib.rs index b6e0beb..0a91a10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,9 @@ pub mod response; #[cfg(feature = "telemetry")] pub mod telemetry; +#[cfg(all(feature = "response", feature = "telemetry"))] +pub mod axum_middleware; + // 方便外部直接 use common_lib::AppError; #[cfg(feature = "response")] pub use error::{AppError, BizCode}; diff --git a/src/telemetry.rs b/src/telemetry.rs index b74c2a3..26038fc 100644 --- a/src/telemetry.rs +++ b/src/telemetry.rs @@ -37,6 +37,8 @@ pub fn init(config: TelemetryConfig) -> Option { .json() // 生产环境通常用 json .with_file(true) .with_line_number(true) + .with_current_span(true) + .with_span_list(true) .with_writer(std::io::stdout); // 3. 准备注册表 @@ -57,6 +59,8 @@ pub fn init(config: TelemetryConfig) -> Option { .json() .with_file(true) .with_line_number(true) + .with_current_span(true) + .with_span_list(true) .with_writer(non_blocking) .with_ansi(false); // 文件不需要颜色 diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 3a49c1c..109dd25 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -2,6 +2,7 @@ use axum::{Router, body::Body, routing::get}; use common_telemetry::{ error::AppError, response::AppResponse, + axum_middleware::trace_http_request, telemetry::{self, TelemetryConfig}, }; use http::{Request, StatusCode}; @@ -237,3 +238,61 @@ async fn test_full_flow_error_and_logging() { // temp_dir 会在作用域结束时自动删除清理 } + +#[tokio::test] +async fn error_log_includes_http_context_span() { + use std::sync::{Arc, Mutex}; + use tracing_subscriber::fmt::MakeWriter; + + struct BufferWriter(Arc>>); + impl<'a> MakeWriter<'a> for BufferWriter { + type Writer = BufferGuard; + fn make_writer(&'a self) -> Self::Writer { + BufferGuard(self.0.clone()) + } + } + struct BufferGuard(Arc>>); + impl std::io::Write for BufferGuard { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.0.lock().unwrap().extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + } + + async fn handler() -> Result { + Err(AppError::MissingAuthHeader) + } + + let buf = Arc::new(Mutex::new(Vec::::new())); + let subscriber = tracing_subscriber::fmt() + .with_writer(BufferWriter(buf.clone())) + .with_ansi(false) + .json() + .with_current_span(true) + .with_span_list(true) + .finish(); + let _guard = tracing::subscriber::set_default(subscriber); + + let app = Router::new() + .route("/needs-auth", get(handler)) + .layer(axum::middleware::from_fn(trace_http_request)); + + let resp = app + .oneshot( + Request::builder() + .uri("/needs-auth") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + + let logs = String::from_utf8(buf.lock().unwrap().clone()).unwrap(); + assert!(logs.contains("\"message\":\"Request failed\"")); + assert!(logs.contains("/needs-auth")); + assert!(logs.contains("http.method") || logs.contains("GET")); +}