diff --git a/Cargo.lock b/Cargo.lock index 273fd91..aa60031 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,10 +66,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "auth-kit" -version = "0.1.0" +version = "0.1.1" dependencies = [ "axum", - "axum-extra", "base64", "common-telemetry", "dashmap", @@ -163,28 +162,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum-extra" -version = "0.12.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fef252edff26ddba56bbcdf2ee3307b8129acb86f5749b68990c168a6fcc9c76" -dependencies = [ - "axum", - "axum-core", - "bytes", - "cookie", - "futures-core", - "futures-util", - "http", - "http-body", - "http-body-util", - "mime", - "pin-project-lite", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "base64" version = "0.22.1" @@ -390,17 +367,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "cookie" -version = "0.18.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" -dependencies = [ - "percent-encoding", - "time", - "version_check", -] - [[package]] name = "core-foundation" version = "0.9.4" diff --git a/src/api/handlers/auth.rs b/src/api/handlers/auth.rs index 6567d46..30a07e5 100644 --- a/src/api/handlers/auth.rs +++ b/src/api/handlers/auth.rs @@ -10,7 +10,6 @@ use serde::Deserialize; use crate::api::AppState; use crate::api::handlers::common::extract_bearer_token; -use auth_kit::middleware::{auth::AuthContext, tenant::TenantId}; #[derive(Debug, Deserialize)] pub struct CallbackQuery { @@ -84,15 +83,42 @@ fn cookie_header( } pub async fn logout_handler( - TenantId(tenant_id): TenantId, - AuthContext { user_id: _, .. }: AuthContext, axum::extract::State(state): axum::extract::State, headers: HeaderMap, ) -> Result { let secure = is_https(&headers); - let token = extract_bearer_token(&headers)?; - let _ = state.iam_client.logout(tenant_id, &token).await; + let token = extract_bearer_token(&headers).ok(); + let tenant_id = headers + .get("X-Tenant-ID") + .and_then(|v| v.to_str().ok()) + .and_then(|s| uuid::Uuid::parse_str(s).ok()) + .or_else(|| { + let cookie_header = headers + .get(header::COOKIE) + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); + for part in cookie_header.split(';') { + let part = part.trim(); + let Some((name, value)) = part.split_once('=') else { + continue; + }; + if name.trim() != "tenantId" { + continue; + } + let raw = value.trim(); + let decoded = urlencoding::decode(raw).ok().map(|s| s.into_owned()); + let v = decoded.unwrap_or_else(|| raw.to_string()); + if let Ok(id) = uuid::Uuid::parse_str(&v) { + return Some(id); + } + } + None + }); + + if let (Some(tenant_id), Some(token)) = (tenant_id, token) { + let _ = state.iam_client.logout(tenant_id, &token).await; + } let mut res = axum::Json(serde_json::json!({})).into_response(); diff --git a/src/api/middleware/mod.rs b/src/api/middleware/mod.rs index f4f57f0..9aee045 100644 --- a/src/api/middleware/mod.rs +++ b/src/api/middleware/mod.rs @@ -5,9 +5,40 @@ use axum::{ }; use common_telemetry::AppError; use futures_util::FutureExt; -use http::HeaderValue; +use http::{HeaderValue, header}; use std::{panic::AssertUnwindSafe, time::Instant}; +pub async fn inject_auth_header_from_cookie(mut req: Request, next: Next) -> Response { + if req.headers().contains_key(header::AUTHORIZATION) { + return next.run(req).await; + } + + let cookie_header = req + .headers() + .get(header::COOKIE) + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); + + for part in cookie_header.split(';') { + let part = part.trim(); + let Some((name, value)) = part.split_once('=') else { + continue; + }; + if name.trim() != "accessToken" { + continue; + } + let raw = value.trim(); + let decoded = urlencoding::decode(raw).ok().map(|s| s.into_owned()); + let token = decoded.unwrap_or_else(|| raw.to_string()); + if let Ok(v) = HeaderValue::from_str(&format!("Bearer {}", token)) { + req.headers_mut().insert(header::AUTHORIZATION, v); + } + break; + } + + next.run(req).await +} + pub async fn ensure_request_id(mut req: Request, next: Next) -> Response { let request_id = req .headers() diff --git a/src/main.rs b/src/main.rs index b0c9f81..8819d38 100644 --- a/src/main.rs +++ b/src/main.rs @@ -56,6 +56,7 @@ async fn main() { "/healthz".to_string(), format!("{}/auth/callback", CANONICAL_BASE), format!("{}/auth/refresh", CANONICAL_BASE), + format!("{}/auth/logout", CANONICAL_BASE), ], skip_path_prefixes: vec!["/scalar".to_string()], jwt: match &config.jwt_public_key_pem { @@ -80,6 +81,7 @@ async fn main() { "/healthz".to_string(), format!("{}/auth/callback", CANONICAL_BASE), format!("{}/auth/refresh", CANONICAL_BASE), + format!("{}/auth/logout", CANONICAL_BASE), ], skip_path_prefixes: vec!["/scalar".to_string()], }; @@ -93,6 +95,9 @@ async fn main() { auth_cfg, auth_kit::middleware::auth::authenticate_with_config, )) + .layer(from_fn( + cms_service::api::middleware::inject_auth_header_from_cookie, + )) .layer(from_fn( common_telemetry::axum_middleware::trace_http_request, ))