diff --git a/Cargo.lock b/Cargo.lock index 8ed6375..273fd91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,6 +69,7 @@ name = "auth-kit" version = "0.1.0" dependencies = [ "axum", + "axum-extra", "base64", "common-telemetry", "dashmap", @@ -162,6 +163,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fef252edff26ddba56bbcdf2ee3307b8129acb86f5749b68990c168a6fcc9c76" +dependencies = [ + "axum", + "axum-core", + "bytes", + "cookie", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.22.1" @@ -367,6 +390,17 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "core-foundation" version = "0.9.4" diff --git a/docs/API.md b/docs/API.md index b4d9665..797e48e 100644 --- a/docs/API.md +++ b/docs/API.md @@ -17,33 +17,33 @@ CMS 对外暴露 RESTful API,并提供 Scalar 文档: ### 栏目(Column) -- `POST /v1/columns`(`cms:column:write`) -- `GET /v1/columns`(`cms:column:read`,分页/搜索) -- `GET /v1/columns/{id}`(`cms:column:read`) -- `PATCH /v1/columns/{id}`(`cms:column:write`) -- `DELETE /v1/columns/{id}`(`cms:column:write`) +- `POST /api/v1/columns`(`cms:column:write`) +- `GET /api/v1/columns`(`cms:column:read`,分页/搜索) +- `GET /api/v1/columns/{id}`(`cms:column:read`) +- `PATCH /api/v1/columns/{id}`(`cms:column:write`) +- `DELETE /api/v1/columns/{id}`(`cms:column:write`) ### 标签/分类(Tag) -- `POST /v1/tags`(`cms:tag:write`,`kind` 支持 `tag|category`) -- `GET /v1/tags`(`cms:tag:read`,分页/搜索/按 kind 过滤) -- `GET /v1/tags/{id}`(`cms:tag:read`) -- `PATCH /v1/tags/{id}`(`cms:tag:write`) -- `DELETE /v1/tags/{id}`(`cms:tag:write`) +- `POST /api/v1/tags`(`cms:tag:write`,`kind` 支持 `tag|category`) +- `GET /api/v1/tags`(`cms:tag:read`,分页/搜索/按 kind 过滤) +- `GET /api/v1/tags/{id}`(`cms:tag:read`) +- `PATCH /api/v1/tags/{id}`(`cms:tag:write`) +- `DELETE /api/v1/tags/{id}`(`cms:tag:write`) ### 媒体库(Media) -- `POST /v1/media`(`cms:media:manage`,登记 URL/元数据) -- `GET /v1/media`(`cms:media:read`,分页/搜索) -- `GET /v1/media/{id}`(`cms:media:read`) -- `DELETE /v1/media/{id}`(`cms:media:manage`) +- `POST /api/v1/media`(`cms:media:manage`,登记 URL/元数据) +- `GET /api/v1/media`(`cms:media:read`,分页/搜索) +- `GET /api/v1/media/{id}`(`cms:media:read`) +- `DELETE /api/v1/media/{id}`(`cms:media:manage`) ### 文章(Article) -- `POST /v1/articles`(`cms:article:write`,创建草稿) -- `GET /v1/articles`(`cms:article:read`,分页/搜索/按状态/栏目/标签过滤) -- `GET /v1/articles/{id}`(`cms:article:read`) -- `PATCH /v1/articles/{id}`(`cms:article:write`) -- `POST /v1/articles/{id}/publish`(`cms:article:publish`,发布并生成版本) -- `POST /v1/articles/{id}/rollback`(`cms:article:rollback`,回滚到指定版本并生成新版本) -- `GET /v1/articles/{id}/versions`(`cms:article:read`,版本列表分页) +- `POST /api/v1/articles`(`cms:article:edit`,创建草稿) +- `GET /api/v1/articles`(`cms:article:edit`,分页/搜索/按状态/栏目/标签过滤) +- `GET /api/v1/articles/{id}`(`cms:article:edit`) +- `PATCH /api/v1/articles/{id}`(`cms:article:edit`) +- `POST /api/v1/articles/{id}/publish`(`cms:article:publish`,发布并生成版本) +- `POST /api/v1/articles/{id}/rollback`(`cms:article:rollback`,回滚到指定版本并生成新版本) +- `GET /api/v1/articles/{id}/versions`(`cms:article:edit`,版本列表分页) diff --git a/src/api/docs.rs b/src/api/docs.rs index 55d1ffd..e403016 100644 --- a/src/api/docs.rs +++ b/src/api/docs.rs @@ -28,6 +28,9 @@ impl Modify for SecurityAddon { version = "0.1.0", description = include_str!("../../docs/API.md") ), + servers( + (url = "/api/v1", description = "Canonical API base") + ), paths( crate::api::handlers::column::create_column_handler, crate::api::handlers::column::list_columns_handler, @@ -66,6 +69,7 @@ impl Modify for SecurityAddon { crate::domain::models::Media, crate::domain::models::Article, crate::domain::models::ArticleVersion, + crate::domain::models::Paged, crate::infrastructure::repositories::article::ArticleWithTags ) ), diff --git a/src/api/handlers/article.rs b/src/api/handlers/article.rs index e332409..489cffa 100644 --- a/src/api/handlers/article.rs +++ b/src/api/handlers/article.rs @@ -8,7 +8,7 @@ use utoipa::IntoParams; use uuid::Uuid; use crate::api::{AppState, handlers::common::extract_bearer_token}; -use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; +use auth_kit::middleware::{auth::AuthContext, tenant::TenantId}; #[derive(Debug, serde::Deserialize, utoipa::ToSchema)] pub struct CreateArticleRequest { @@ -65,7 +65,7 @@ pub fn router() -> Router { #[utoipa::path( post, - path = "/v1/articles", + path = "/articles", tag = "Article", request_body = CreateArticleRequest, security( @@ -85,7 +85,12 @@ pub async fn create_article_handler( let token = extract_bearer_token(&headers)?; state .iam_client - .require_permission(tenant_id, user_id, "cms:article:write", &token) + .require_any_permissions( + tenant_id, + user_id, + &["cms:article:edit", "cms:article:create"], + &token, + ) .await?; let article = state @@ -106,7 +111,7 @@ pub async fn create_article_handler( #[utoipa::path( get, - path = "/v1/articles", + path = "/articles", tag = "Article", params(ListArticlesQuery), security( @@ -122,12 +127,14 @@ pub async fn list_articles_handler( State(state): State, headers: axum::http::HeaderMap, Query(query): Query, -) -> Result>, AppError> -{ +) -> Result< + AppResponse>, + AppError, +> { let token = extract_bearer_token(&headers)?; state .iam_client - .require_permission(tenant_id, user_id, "cms:article:read", &token) + .require_permission(tenant_id, user_id, "cms:article:edit", &token) .await?; let result = state @@ -149,7 +156,7 @@ pub async fn list_articles_handler( #[utoipa::path( get, - path = "/v1/articles/{id}", + path = "/articles/{id}", tag = "Article", params( ("id" = String, Path, description = "文章ID") @@ -171,7 +178,7 @@ pub async fn get_article_handler( let token = extract_bearer_token(&headers)?; state .iam_client - .require_permission(tenant_id, user_id, "cms:article:read", &token) + .require_permission(tenant_id, user_id, "cms:article:edit", &token) .await?; let article = state.services.get_article(tenant_id, id).await?; @@ -180,7 +187,7 @@ pub async fn get_article_handler( #[utoipa::path( patch, - path = "/v1/articles/{id}", + path = "/articles/{id}", tag = "Article", request_body = UpdateArticleRequest, params( @@ -204,7 +211,7 @@ pub async fn update_article_handler( let token = extract_bearer_token(&headers)?; state .iam_client - .require_permission(tenant_id, user_id, "cms:article:write", &token) + .require_permission(tenant_id, user_id, "cms:article:edit", &token) .await?; let article = state @@ -226,7 +233,7 @@ pub async fn update_article_handler( #[utoipa::path( post, - path = "/v1/articles/{id}/publish", + path = "/articles/{id}/publish", tag = "Article", params( ("id" = String, Path, description = "文章ID") @@ -251,13 +258,16 @@ pub async fn publish_article_handler( .require_permission(tenant_id, user_id, "cms:article:publish", &token) .await?; - let article = state.services.publish_article(tenant_id, id, Some(user_id)).await?; + let article = state + .services + .publish_article(tenant_id, id, Some(user_id)) + .await?; Ok(AppResponse::ok(article)) } #[utoipa::path( post, - path = "/v1/articles/{id}/rollback", + path = "/articles/{id}/rollback", tag = "Version", request_body = RollbackRequest, params( @@ -293,7 +303,7 @@ pub async fn rollback_article_handler( #[utoipa::path( get, - path = "/v1/articles/{id}/versions", + path = "/articles/{id}/versions", tag = "Version", params( ("id" = String, Path, description = "文章ID"), @@ -303,7 +313,10 @@ pub async fn rollback_article_handler( ("bearer_auth" = []) ), responses( - (status = 200, description = "版本列表", body = crate::infrastructure::repositories::column::Paged) + (status = 200, description = "版本列表", body = crate::domain::models::Paged), + (status = 401, description = "未认证"), + (status = 403, description = "无权限"), + (status = 404, description = "不存在") ) )] pub async fn list_versions_handler( @@ -313,12 +326,16 @@ pub async fn list_versions_handler( headers: axum::http::HeaderMap, Path(id): Path, Query(query): Query, -) -> Result>, AppError> -{ +) -> Result< + AppResponse< + crate::infrastructure::repositories::column::Paged, + >, + AppError, +> { let token = extract_bearer_token(&headers)?; state .iam_client - .require_permission(tenant_id, user_id, "cms:article:read", &token) + .require_permission(tenant_id, user_id, "cms:article:edit", &token) .await?; let versions = state diff --git a/src/api/handlers/auth.rs b/src/api/handlers/auth.rs index 068e593..6567d46 100644 --- a/src/api/handlers/auth.rs +++ b/src/api/handlers/auth.rs @@ -1,14 +1,16 @@ use axum::{ Router, extract::Query, - http::{HeaderValue, header}, + http::{HeaderMap, HeaderValue, header}, response::{IntoResponse, Redirect}, - routing::get, + routing::{get, post}, }; use common_telemetry::AppError; use serde::Deserialize; use crate::api::AppState; +use crate::api::handlers::common::extract_bearer_token; +use auth_kit::middleware::{auth::AuthContext, tenant::TenantId}; #[derive(Debug, Deserialize)] pub struct CallbackQuery { @@ -18,7 +20,6 @@ pub struct CallbackQuery { } #[derive(Debug, Deserialize, serde::Serialize)] -#[serde(rename_all = "camelCase")] struct Code2TokenRequest { code: String, client_id: String, @@ -31,10 +32,10 @@ struct RefreshTokenRequest { } #[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] struct Code2TokenData { access_token: String, refresh_token: String, + token_type: Option, expires_in: usize, tenant_id: String, user_id: String, @@ -51,6 +52,7 @@ pub fn router() -> Router { Router::new() .route("/callback", get(sso_callback_handler)) .route("/refresh", get(refresh_token_handler)) + .route("/logout", post(logout_handler)) } fn is_https(headers: &axum::http::HeaderMap) -> bool { @@ -81,6 +83,30 @@ fn cookie_header( s } +pub async fn logout_handler( + TenantId(tenant_id): TenantId, + AuthContext { user_id: _, .. }: AuthContext, + axum::extract::State(state): axum::extract::State, + headers: HeaderMap, +) -> Result { + let secure = is_https(&headers); + let token = extract_bearer_token(&headers)?; + + let _ = state.iam_client.logout(tenant_id, &token).await; + + let mut res = axum::Json(serde_json::json!({})).into_response(); + + for name in ["accessToken", "refreshToken", "tenantId", "userId"] { + res.headers_mut().append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie_header(name, "", secure, true, Some(0))) + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?, + ); + } + + Ok(res) +} + #[derive(Debug, Deserialize)] pub struct RefreshTokenQuery { pub token: String, @@ -91,8 +117,37 @@ pub async fn refresh_token_handler( headers: axum::http::HeaderMap, Query(q): Query, ) -> Result { + fn clear_cookie( + res: &mut axum::response::Response, + name: &str, + secure: bool, + ) -> Result<(), AppError> { + res.headers_mut().append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie_header(name, "", secure, true, Some(0))) + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?, + ); + Ok(()) + } + + fn clear_auth_cookies( + res: &mut axum::response::Response, + secure: bool, + ) -> Result<(), AppError> { + clear_cookie(res, "accessToken", secure)?; + clear_cookie(res, "refreshToken", secure)?; + clear_cookie(res, "tenantId", secure)?; + clear_cookie(res, "userId", secure)?; + Ok(()) + } + + let secure = is_https(&headers); + let target = resolve_front_redirect(q.next.clone()); + if q.token.trim().is_empty() { - return Ok(Redirect::temporary("/auth-error?message=missing_token").into_response()); + let mut res = Redirect::temporary(&target).into_response(); + clear_auth_cookies(&mut res, secure)?; + return Ok(res); } let iam_base = std::env::var("IAM_BASE_URL") @@ -117,50 +172,36 @@ pub async fn refresh_token_handler( // But LoginResponse structure is: access_token, refresh_token, token_type, expires_in. // Code2TokenData has tenant_id, user_id extra? // Let's check IAM service LoginResponse definition. - // IAM Service LoginResponse: access_token, refresh_token, token_type, expires_in. - // Wait, Code2TokenData expects tenant_id and user_id. - // Does IAM refresh endpoint return tenant_id and user_id? - // IAM Service LoginResponse struct in src/models.rs (iam-service) DOES NOT have tenant_id/user_id. - // So we cannot reuse Code2TokenData for refresh response parsing if we expect those fields. - // But usually refresh token response just updates access_token (and maybe refresh_token). - // TenantId and UserId should not change. We can keep existing cookies for them if we don't have them. - // But wait, we are setting cookies. If we don't get tenant_id/user_id, we can't set them (or we re-set them if we knew them). - // The previous cookies are still there. We just need to update access_token and refresh_token. - - // Let's define a separate struct for Refresh Response if needed. #[derive(Debug, Deserialize)] - #[serde(rename_all = "camelCase")] struct RefreshResponseData { access_token: String, refresh_token: String, + token_type: Option, expires_in: usize, } - let body = resp - .json::>() + let bytes = resp + .bytes() .await .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?; + let body = serde_json::from_slice::>(&bytes).ok(); - if !status.is_success() || body.code != 0 { - // Refresh failed, redirect to login - let login_url = resolve_front_redirect(q.next); // Actually redirect to front login page or handle error - // If refresh fails, we probably want to redirect to the original requested page so it can trigger login flow, - // OR redirect to auth-error. - // But the middleware calls this. If this returns redirect, the middleware will return redirect. - // If middleware sees error, it should redirect to login. - return Ok(Redirect::temporary("/auth-error?message=refresh_failed").into_response()); + if !status.is_success() || body.as_ref().map(|b| b.code).unwrap_or(1) != 0 { + let mut res = Redirect::temporary(&target).into_response(); + clear_auth_cookies(&mut res, secure)?; + return Ok(res); } - let Some(data) = body.data else { - return Ok(Redirect::temporary("/auth-error?message=invalid_refresh_response").into_response()); + let Some(data) = body.and_then(|b| b.data) else { + let mut res = Redirect::temporary(&target).into_response(); + clear_auth_cookies(&mut res, secure)?; + return Ok(res); }; - let target = resolve_front_redirect(q.next); - let secure = is_https(&headers); let mut res = Redirect::temporary(&target).into_response(); - + let refresh_max_age = 30_u64 * 24 * 60 * 60; - + res.headers_mut().append( header::SET_COOKIE, HeaderValue::from_str(&cookie_header( @@ -183,7 +224,7 @@ pub async fn refresh_token_handler( )) .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?, ); - + // We don't update tenantId/userId as we don't get them from refresh endpoint usually. // They should persist. @@ -211,6 +252,26 @@ fn resolve_front_redirect(next: Option) -> String { return raw; } + if raw.starts_with("http://") { + if let Some(base) = base { + if raw.starts_with(&base) { + return raw; + } + if cfg!(debug_assertions) + && (raw.starts_with("http://localhost") || raw.starts_with("http://127.0.0.1")) + { + return raw; + } + return base; + } + if cfg!(debug_assertions) + && (raw.starts_with("http://localhost") || raw.starts_with("http://127.0.0.1")) + { + return raw; + } + return "/".to_string(); + } + base.unwrap_or_else(|| "/".to_string()) } @@ -237,12 +298,7 @@ async fn sso_callback_handler( return Ok(Redirect::temporary(&target).into_response()); } - let tenant_id = q - .tenant_id - .as_deref() - .unwrap_or("") - .trim() - .to_string(); + let tenant_id = q.tenant_id.as_deref().unwrap_or("").trim().to_string(); if uuid::Uuid::parse_str(&tenant_id).is_err() { let target = resolve_front_error_redirect("missing or invalid tenant_id"); return Ok(Redirect::temporary(&target).into_response()); diff --git a/src/api/handlers/column.rs b/src/api/handlers/column.rs index 289382b..5157654 100644 --- a/src/api/handlers/column.rs +++ b/src/api/handlers/column.rs @@ -8,7 +8,7 @@ use utoipa::IntoParams; use uuid::Uuid; use crate::api::{AppState, handlers::common::extract_bearer_token}; -use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; +use auth_kit::middleware::{auth::AuthContext, tenant::TenantId}; #[derive(Debug, serde::Deserialize, utoipa::ToSchema)] pub struct CreateColumnRequest { @@ -49,7 +49,7 @@ pub fn router() -> Router { #[utoipa::path( post, - path = "/v1/columns", + path = "/columns", tag = "Column", request_body = CreateColumnRequest, security( @@ -90,14 +90,14 @@ pub async fn create_column_handler( #[utoipa::path( get, - path = "/v1/columns", + path = "/columns", tag = "Column", params(ListColumnsQuery), security( ("bearer_auth" = []) ), responses( - (status = 200, description = "栏目列表", body = crate::infrastructure::repositories::column::Paged), + (status = 200, description = "栏目列表", body = crate::domain::models::Paged), (status = 401, description = "未认证"), (status = 403, description = "无权限") ) @@ -108,8 +108,10 @@ pub async fn list_columns_handler( State(state): State, headers: axum::http::HeaderMap, Query(query): Query, -) -> Result>, AppError> -{ +) -> Result< + AppResponse>, + AppError, +> { let token = extract_bearer_token(&headers)?; state .iam_client @@ -133,7 +135,7 @@ pub async fn list_columns_handler( #[utoipa::path( get, - path = "/v1/columns/{id}", + path = "/columns/{id}", tag = "Column", params( ("id" = String, Path, description = "栏目ID") @@ -167,7 +169,7 @@ pub async fn get_column_handler( #[utoipa::path( patch, - path = "/v1/columns/{id}", + path = "/columns/{id}", tag = "Column", request_body = UpdateColumnRequest, params( @@ -214,7 +216,7 @@ pub async fn update_column_handler( #[utoipa::path( delete, - path = "/v1/columns/{id}", + path = "/columns/{id}", tag = "Column", params( ("id" = String, Path, description = "栏目ID") diff --git a/src/api/handlers/common.rs b/src/api/handlers/common.rs index e9c2296..0469d47 100644 --- a/src/api/handlers/common.rs +++ b/src/api/handlers/common.rs @@ -2,10 +2,31 @@ use axum::http::HeaderMap; use common_telemetry::AppError; pub fn extract_bearer_token(headers: &HeaderMap) -> Result { - let token = headers + if let Some(token) = headers .get(axum::http::header::AUTHORIZATION) .and_then(|h| h.to_str().ok()) .and_then(|v| v.strip_prefix("Bearer ")) - .ok_or(AppError::MissingAuthHeader)?; - Ok(token.to_string()) + { + return Ok(token.to_string()); + } + + let cookie_header = headers + .get(axum::http::header::COOKIE) + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); + + for part in cookie_header.split(';') { + let part = part.trim(); + let Some((name, value)) = part.split_once('=') else { + continue; + }; + if name.trim() != "accessToken" { + continue; + } + let raw = value.trim(); + let decoded = urlencoding::decode(raw).ok().map(|s| s.into_owned()); + return Ok(decoded.unwrap_or_else(|| raw.to_string())); + } + + Err(AppError::MissingAuthHeader) } diff --git a/src/api/handlers/media.rs b/src/api/handlers/media.rs index ffb751d..6ccf99f 100644 --- a/src/api/handlers/media.rs +++ b/src/api/handlers/media.rs @@ -8,7 +8,7 @@ use utoipa::IntoParams; use uuid::Uuid; use crate::api::{AppState, handlers::common::extract_bearer_token}; -use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; +use auth_kit::middleware::{auth::AuthContext, tenant::TenantId}; #[derive(Debug, serde::Deserialize, utoipa::ToSchema)] pub struct CreateMediaRequest { @@ -34,7 +34,7 @@ pub fn router() -> Router { #[utoipa::path( post, - path = "/v1/media", + path = "/media", tag = "Media", request_body = CreateMediaRequest, security( @@ -74,14 +74,16 @@ pub async fn create_media_handler( #[utoipa::path( get, - path = "/v1/media", + path = "/media", tag = "Media", params(ListMediaQuery), security( ("bearer_auth" = []) ), responses( - (status = 200, description = "媒体列表", body = crate::infrastructure::repositories::column::Paged) + (status = 200, description = "媒体列表", body = crate::domain::models::Paged), + (status = 401, description = "未认证"), + (status = 403, description = "无权限") ) )] pub async fn list_media_handler( @@ -90,8 +92,10 @@ pub async fn list_media_handler( State(state): State, headers: axum::http::HeaderMap, Query(query): Query, -) -> Result>, AppError> -{ +) -> Result< + AppResponse>, + AppError, +> { let token = extract_bearer_token(&headers)?; state .iam_client @@ -114,7 +118,7 @@ pub async fn list_media_handler( #[utoipa::path( get, - path = "/v1/media/{id}", + path = "/media/{id}", tag = "Media", params( ("id" = String, Path, description = "媒体ID") @@ -145,7 +149,7 @@ pub async fn get_media_handler( #[utoipa::path( delete, - path = "/v1/media/{id}", + path = "/media/{id}", tag = "Media", params( ("id" = String, Path, description = "媒体ID") diff --git a/src/api/handlers/tag.rs b/src/api/handlers/tag.rs index 67ffa33..0969e3e 100644 --- a/src/api/handlers/tag.rs +++ b/src/api/handlers/tag.rs @@ -8,7 +8,7 @@ use utoipa::IntoParams; use uuid::Uuid; use crate::api::{AppState, handlers::common::extract_bearer_token}; -use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; +use auth_kit::middleware::{auth::AuthContext, tenant::TenantId}; #[derive(Debug, serde::Deserialize, utoipa::ToSchema)] pub struct CreateTagRequest { @@ -44,7 +44,7 @@ pub fn router() -> Router { #[utoipa::path( post, - path = "/v1/tags", + path = "/tags", tag = "Tag", request_body = CreateTagRequest, security( @@ -76,14 +76,16 @@ pub async fn create_tag_handler( #[utoipa::path( get, - path = "/v1/tags", + path = "/tags", tag = "Tag", params(ListTagsQuery), security( ("bearer_auth" = []) ), responses( - (status = 200, description = "标签/分类列表", body = crate::infrastructure::repositories::column::Paged) + (status = 200, description = "标签/分类列表", body = crate::domain::models::Paged), + (status = 401, description = "未认证"), + (status = 403, description = "无权限") ) )] pub async fn list_tags_handler( @@ -92,8 +94,10 @@ pub async fn list_tags_handler( State(state): State, headers: axum::http::HeaderMap, Query(query): Query, -) -> Result>, AppError> -{ +) -> Result< + AppResponse>, + AppError, +> { let token = extract_bearer_token(&headers)?; state .iam_client @@ -117,7 +121,7 @@ pub async fn list_tags_handler( #[utoipa::path( get, - path = "/v1/tags/{id}", + path = "/tags/{id}", tag = "Tag", params( ("id" = String, Path, description = "标签/分类ID") @@ -148,7 +152,7 @@ pub async fn get_tag_handler( #[utoipa::path( patch, - path = "/v1/tags/{id}", + path = "/tags/{id}", tag = "Tag", request_body = UpdateTagRequest, params( @@ -184,7 +188,7 @@ pub async fn update_tag_handler( #[utoipa::path( delete, - path = "/v1/tags/{id}", + path = "/tags/{id}", tag = "Tag", params( ("id" = String, Path, description = "标签/分类ID") diff --git a/src/api/mod.rs b/src/api/mod.rs index 2a9b587..2800e1d 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,14 +2,15 @@ pub mod docs; pub mod handlers; pub mod middleware; -use axum::routing::get; use axum::Router; +use axum::routing::get; use utoipa::OpenApi; use utoipa_scalar::{Scalar, Servable}; use crate::api::docs::ApiDoc; use crate::api::middleware::{catch_panic, request_logger}; use crate::application::services::CmsServices; +use crate::constants::CANONICAL_BASE; use crate::infrastructure::iam_client::IamClient; #[derive(Clone)] @@ -21,20 +22,25 @@ pub struct AppState { pub fn build_router(state: AppState) -> Router { let health = Router::new().route("/healthz", get(|| async { axum::http::StatusCode::OK })); - let auth = Router::new().nest("/auth", handlers::auth::router()); - let v1 = Router::new() + .nest("/auth", handlers::auth::router()) .nest("/columns", handlers::column::router()) .nest("/tags", handlers::tag::router()) .nest("/media", handlers::media::router()) .nest("/articles", handlers::article::router()); Router::new() - .route("/favicon.ico", get(|| async { axum::http::StatusCode::NO_CONTENT })) + .route( + "/favicon.ico", + get(|| async { axum::http::StatusCode::NO_CONTENT }), + ) .merge(Scalar::with_url("/scalar", ApiDoc::openapi())) + .route( + "/scalar/openapi.json", + get(|| async { axum::Json(ApiDoc::openapi()) }), + ) .merge(health) - .merge(auth) - .nest("/v1", v1) + .nest(CANONICAL_BASE, v1) .layer(axum::middleware::from_fn(catch_panic)) .layer(axum::middleware::from_fn(request_logger)) .with_state(state) diff --git a/src/constants.rs b/src/constants.rs new file mode 100644 index 0000000..13432fa --- /dev/null +++ b/src/constants.rs @@ -0,0 +1 @@ +pub const CANONICAL_BASE: &str = "/api/v1"; diff --git a/src/domain/models.rs b/src/domain/models.rs index 39ea1f7..62c2cc3 100644 --- a/src/domain/models.rs +++ b/src/domain/models.rs @@ -3,6 +3,15 @@ use sqlx::FromRow; use utoipa::ToSchema; use uuid::Uuid; +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct Paged { + pub items: Vec, + pub total: i64, + pub page: i32, + pub page_size: i32, + pub total_pages: i32, +} + #[derive(Debug, Clone, Serialize, Deserialize, ToSchema, FromRow)] pub struct Column { pub tenant_id: Uuid, diff --git a/src/infrastructure/iam_client/mod.rs b/src/infrastructure/iam_client/mod.rs index b3a765f..aaceb99 100644 --- a/src/infrastructure/iam_client/mod.rs +++ b/src/infrastructure/iam_client/mod.rs @@ -9,6 +9,10 @@ use dashmap::DashMap; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use crate::constants::CANONICAL_BASE; + +const CMS_ADMIN_PERMISSION: &str = "cms:admin"; + #[derive(Clone, Debug)] pub struct IamClientConfig { pub base_url: String, @@ -70,6 +74,42 @@ struct AuthorizationCheckResponse { allowed: bool, } +#[derive(Debug, Serialize)] +struct AuthorizationExprCheckRequest { + expr: PermissionExpr, +} + +#[derive(Debug, Deserialize)] +struct AuthorizationExprCheckResponse { + allowed: bool, +} + +#[derive(Debug, Serialize)] +#[serde(deny_unknown_fields)] +struct AnyExpr { + any: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(deny_unknown_fields)] +struct AllExpr { + all: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum PermissionExpr { + Any(AnyExpr), + All(AllExpr), +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum PermissionExprItem { + Permission(String), + Expr(Box), +} + #[derive(Debug, Deserialize)] struct ApiSuccessResponse { #[allow(dead_code)] @@ -111,6 +151,146 @@ impl IamClient { } } + pub async fn logout(&self, tenant_id: Uuid, access_token: &str) -> Result<(), AppError> { + let base = self.inner.cfg.base_url.trim_end_matches('/'); + let api_base = if base.ends_with(CANONICAL_BASE) { + base.to_string() + } else { + format!("{}{}", base, CANONICAL_BASE) + }; + let url = format!("{}/auth/logout", api_base); + + let resp = self + .inner + .http + .post(url) + .bearer_auth(access_token) + .header("X-Tenant-ID", tenant_id.to_string()) + .send() + .await + .map_err(|e| AppError::ExternalReqError(format!("iam:request_failed:{}", e)))?; + + let status = resp.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + return Err(AppError::AuthError("iam:unauthorized".into())); + } + if !status.is_success() { + return Err(AppError::ExternalReqError(format!( + "iam:unexpected_status:{}", + status.as_u16() + ))); + } + + Ok(()) + } + + async fn check_permission_expr( + &self, + tenant_id: Uuid, + access_token: &str, + expr: PermissionExpr, + ) -> Result { + let expr = with_admin_override(expr); + let base = self.inner.cfg.base_url.trim_end_matches('/'); + let api_base = if base.ends_with(CANONICAL_BASE) { + base.to_string() + } else { + format!("{}{}", base, CANONICAL_BASE) + }; + let url = format!("{}/authorize/check-expr", api_base); + + let resp = self + .inner + .http + .post(url) + .bearer_auth(access_token) + .header("X-Tenant-ID", tenant_id.to_string()) + .json(&AuthorizationExprCheckRequest { expr }) + .send() + .await + .map_err(|e| AppError::ExternalReqError(format!("iam:request_failed:{}", e)))?; + + let status = resp.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + return Err(AppError::AuthError("iam:unauthorized".into())); + } + if status == reqwest::StatusCode::FORBIDDEN { + return Err(AppError::PermissionDenied("iam:forbidden".into())); + } + if !status.is_success() { + return Err(AppError::ExternalReqError(format!( + "iam:unexpected_status:{}", + status.as_u16() + ))); + } + + let body: ApiSuccessResponse = resp + .json() + .await + .map_err(|e| AppError::ExternalReqError(format!("iam:decode_failed:{}", e)))?; + + let allowed = body + .data + .map(|d| d.allowed) + .ok_or_else(|| AppError::ExternalReqError("iam:missing_data".into()))?; + + Ok(allowed) + } + + async fn require_permission_expr( + &self, + tenant_id: Uuid, + user_id: Uuid, + expr: PermissionExpr, + access_token: &str, + ) -> Result<(), AppError> { + let allowed = self + .check_permission_expr(tenant_id, access_token, expr) + .await?; + if allowed { + Ok(()) + } else { + Err(AppError::PermissionDenied(format!( + "iam:expr_denied:{}:{}", + tenant_id, user_id + ))) + } + } + + pub async fn require_any_permissions( + &self, + tenant_id: Uuid, + user_id: Uuid, + permissions: &[&str], + access_token: &str, + ) -> Result<(), AppError> { + let expr = PermissionExpr::Any(AnyExpr { + any: permissions + .iter() + .map(|p| PermissionExprItem::Permission((*p).to_string())) + .collect(), + }); + self.require_permission_expr(tenant_id, user_id, expr, access_token) + .await + } + + pub async fn require_all_permissions( + &self, + tenant_id: Uuid, + user_id: Uuid, + permissions: &[&str], + access_token: &str, + ) -> Result<(), AppError> { + let expr = PermissionExpr::All(AllExpr { + all: permissions + .iter() + .map(|p| PermissionExprItem::Permission((*p).to_string())) + .collect(), + }); + self.require_permission_expr(tenant_id, user_id, expr, access_token) + .await + } + async fn check_permission( &self, tenant_id: Uuid, @@ -177,48 +357,28 @@ impl IamClient { permission: &str, access_token: &str, ) -> Result { - let url = format!( - "{}/authorize/check", - self.inner.cfg.base_url.trim_end_matches('/') - ); - - let resp = self - .inner - .http - .post(url) - .bearer_auth(access_token) - .header("X-Tenant-ID", tenant_id.to_string()) - .json(&AuthorizationCheckRequest { - permission: permission.to_string(), - }) - .send() + let expr = PermissionExpr::Any(AnyExpr { + any: vec![PermissionExprItem::Permission(permission.to_string())], + }); + self.check_permission_expr(tenant_id, access_token, expr) .await - .map_err(|e| AppError::ExternalReqError(format!("iam:request_failed:{}", e)))?; - - let status = resp.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - return Err(AppError::AuthError("iam:unauthorized".into())); - } - if status == reqwest::StatusCode::FORBIDDEN { - return Err(AppError::PermissionDenied("iam:forbidden".into())); - } - if !status.is_success() { - return Err(AppError::ExternalReqError(format!( - "iam:unexpected_status:{}", - status.as_u16() - ))); - } - - let body: ApiSuccessResponse = resp - .json() - .await - .map_err(|e| AppError::ExternalReqError(format!("iam:decode_failed:{}", e)))?; - - let allowed = body - .data - .map(|d| d.allowed) - .ok_or_else(|| AppError::ExternalReqError("iam:missing_data".into()))?; - - Ok(allowed) + } +} + +fn with_admin_override(expr: PermissionExpr) -> PermissionExpr { + match expr { + PermissionExpr::Any(mut x) => { + x.any.insert( + 0, + PermissionExprItem::Permission(CMS_ADMIN_PERMISSION.to_string()), + ); + PermissionExpr::Any(x) + } + PermissionExpr::All(x) => PermissionExpr::Any(AnyExpr { + any: vec![ + PermissionExprItem::Permission(CMS_ADMIN_PERMISSION.to_string()), + PermissionExprItem::Expr(Box::new(PermissionExpr::All(x))), + ], + }), } } diff --git a/src/lib.rs b/src/lib.rs index edd3285..5221358 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod api; pub mod application; pub mod config; +pub mod constants; pub mod domain; pub mod infrastructure; diff --git a/src/main.rs b/src/main.rs index de16f95..b0c9f81 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,16 @@ +use auth_kit::middleware::{auth::AuthMiddlewareConfig, tenant::TenantMiddlewareConfig}; use axum::middleware::{from_fn, from_fn_with_state}; use cms_service::{ api::{self, AppState}, application::services::CmsServices, config::AppConfig, - infrastructure::{db, iam_client::{IamClient, IamClientConfig}}, + constants::CANONICAL_BASE, + infrastructure::{ + db, + iam_client::{IamClient, IamClientConfig}, + }, }; use common_telemetry::telemetry::{self, TelemetryConfig}; -use auth_kit::middleware::{tenant::TenantMiddlewareConfig, auth::AuthMiddlewareConfig}; use std::net::SocketAddr; use std::time::Duration; @@ -23,7 +27,9 @@ async fn main() { log_file: Some(config.log_file_name.clone()), }); - let pool = db::init_pool(&config).await.expect("failed to init db pool"); + let pool = db::init_pool(&config) + .await + .expect("failed to init db pool"); let run_migrations = std::env::var("RUN_MIGRATIONS") .ok() .map(|v| matches!(v.as_str(), "1" | "true" | "TRUE")) @@ -46,17 +52,23 @@ async fn main() { }; let auth_cfg = AuthMiddlewareConfig { - skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()], + skip_exact_paths: vec![ + "/healthz".to_string(), + format!("{}/auth/callback", CANONICAL_BASE), + format!("{}/auth/refresh", CANONICAL_BASE), + ], skip_path_prefixes: vec!["/scalar".to_string()], jwt: match &config.jwt_public_key_pem { Some(pem) => auth_kit::jwt::JwtVerifyConfig::rs256_from_pem("iam-service", pem) .expect("invalid JWT_PUBLIC_KEY_PEM"), None => { let jwks_url = config.iam_jwks_url.clone().unwrap_or_else(|| { - format!( - "{}/.well-known/jwks.json", - config.iam_base_url.trim_end_matches('/') - ) + let base = config.iam_base_url.trim_end_matches('/'); + if base.ends_with(CANONICAL_BASE) { + format!("{}/.well-known/jwks.json", base) + } else { + format!("{}{}/.well-known/jwks.json", base, CANONICAL_BASE) + } }); auth_kit::jwt::JwtVerifyConfig::rs256_from_jwks("iam-service", &jwks_url) .expect("invalid IAM_JWKS_URL") @@ -64,7 +76,11 @@ async fn main() { }, }; let tenant_cfg = TenantMiddlewareConfig { - skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()], + skip_exact_paths: vec![ + "/healthz".to_string(), + format!("{}/auth/callback", CANONICAL_BASE), + format!("{}/auth/refresh", CANONICAL_BASE), + ], skip_path_prefixes: vec!["/scalar".to_string()], }; @@ -77,7 +93,9 @@ async fn main() { auth_cfg, auth_kit::middleware::auth::authenticate_with_config, )) - .layer(from_fn(common_telemetry::axum_middleware::trace_http_request)) + .layer(from_fn( + common_telemetry::axum_middleware::trace_http_request, + )) .layer(from_fn(cms_service::api::middleware::ensure_request_id)); let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); @@ -85,7 +103,10 @@ async fn main() { tracing::info!("Docs available at http://{}/scalar", addr); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); - axum::serve(listener, app.into_make_service_with_connect_info::()) - .await - .unwrap(); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await + .unwrap(); } diff --git a/tests/iam_client_cache.rs b/tests/iam_client_cache.rs index 22b3b21..9cb58e0 100644 --- a/tests/iam_client_cache.rs +++ b/tests/iam_client_cache.rs @@ -4,15 +4,11 @@ use std::sync::{ }; use std::time::Duration; -use axum::{Json, Router, routing::post}; use axum::response::IntoResponse; +use axum::{Json, Router, routing::post}; use cms_service::infrastructure::iam_client::{IamClient, IamClientConfig}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Deserialize)] -struct AuthorizationCheckRequest { - permission: String, -} +use serde::Serialize; +use serde_json::Value; #[derive(Debug, Serialize)] struct AuthorizationCheckResponse { @@ -31,28 +27,28 @@ async fn start_mock_iam( call_count: Arc, fail: Arc, ) -> (String, tokio::task::JoinHandle<()>) { - let app = Router::new().route( - "/authorize/check", - post(move |Json(body): Json| { - let call_count = call_count.clone(); - let fail = fail.clone(); - async move { - call_count.fetch_add(1, Ordering::SeqCst); - if fail.load(Ordering::SeqCst) { - return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "fail").into_response(); - } - - let allowed = body.permission == "cms:article:read"; - let resp = ApiSuccessResponse { - code: 0, - message: "ok".to_string(), - data: AuthorizationCheckResponse { allowed }, - trace_id: None, - }; - (axum::http::StatusCode::OK, Json(resp)).into_response() + let handler = move |Json(_body): Json| { + let call_count = call_count.clone(); + let fail = fail.clone(); + async move { + call_count.fetch_add(1, Ordering::SeqCst); + if fail.load(Ordering::SeqCst) { + return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "fail").into_response(); } - }), - ); + + let resp = ApiSuccessResponse { + code: 0, + message: "ok".to_string(), + data: AuthorizationCheckResponse { allowed: true }, + trace_id: None, + }; + (axum::http::StatusCode::OK, Json(resp)).into_response() + } + }; + + let app = Router::new() + .route("/authorize/check-expr", post(handler.clone())) + .route("/api/v1/authorize/check-expr", post(handler)); let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -81,11 +77,11 @@ async fn iam_client_caches_decisions() { let user_id = uuid::Uuid::new_v4(); client - .require_permission(tenant_id, user_id, "cms:article:read", "token") + .require_permission(tenant_id, user_id, "cms:article:edit", "token") .await .unwrap(); client - .require_permission(tenant_id, user_id, "cms:article:read", "token") + .require_permission(tenant_id, user_id, "cms:article:edit", "token") .await .unwrap(); @@ -111,7 +107,7 @@ async fn iam_client_uses_stale_cache_on_error() { let user_id = uuid::Uuid::new_v4(); client - .require_permission(tenant_id, user_id, "cms:article:read", "token") + .require_permission(tenant_id, user_id, "cms:article:edit", "token") .await .unwrap(); @@ -119,7 +115,7 @@ async fn iam_client_uses_stale_cache_on_error() { fail.store(true, Ordering::SeqCst); client - .require_permission(tenant_id, user_id, "cms:article:read", "token") + .require_permission(tenant_id, user_id, "cms:article:edit", "token") .await .unwrap(); diff --git a/tests/iam_client_expr.rs b/tests/iam_client_expr.rs new file mode 100644 index 0000000..3af7d31 --- /dev/null +++ b/tests/iam_client_expr.rs @@ -0,0 +1,68 @@ +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; +use std::time::Duration; + +use axum::{Json, Router, routing::post}; +use axum::response::IntoResponse; +use cms_service::infrastructure::iam_client::{IamClient, IamClientConfig}; +use serde_json::Value; + +async fn start_mock_iam(call_count: Arc) -> (String, tokio::task::JoinHandle<()>) { + let app = Router::new().route( + "/api/v1/authorize/check-expr", + post(move |Json(body): Json| { + let call_count = call_count.clone(); + async move { + call_count.fetch_add(1, Ordering::SeqCst); + let allowed = body.get("expr").is_some(); + let resp = serde_json::json!({ + "code": 0, + "message": "ok", + "data": { "allowed": allowed } + }); + (axum::http::StatusCode::OK, Json(resp)).into_response() + } + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let base_url = format!("http://{}", addr); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + (base_url, handle) +} + +#[tokio::test] +async fn iam_client_check_expr_hits_endpoint() { + let call_count = Arc::new(AtomicUsize::new(0)); + let (base_url, handle) = start_mock_iam(call_count.clone()).await; + + let client = IamClient::new(IamClientConfig { + base_url, + timeout: Duration::from_millis(500), + cache_ttl: Duration::from_secs(5), + cache_stale_if_error: Duration::from_secs(30), + cache_max_entries: 1000, + }); + + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + + client + .require_any_permissions( + tenant_id, + user_id, + &["cms:article:edit", "cms:article:create"], + "token", + ) + .await + .unwrap(); + + assert_eq!(call_count.load(Ordering::SeqCst), 1); + handle.abort(); +} + diff --git a/tests/iam_refresh_decode.rs b/tests/iam_refresh_decode.rs new file mode 100644 index 0000000..6663e62 --- /dev/null +++ b/tests/iam_refresh_decode.rs @@ -0,0 +1,40 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +struct AppResponse { + code: i32, + message: String, + data: Option, +} + +#[derive(Debug, Deserialize)] +struct RefreshResponseData { + access_token: String, + refresh_token: String, + token_type: Option, + expires_in: usize, +} + +#[test] +fn can_decode_iam_refresh_response_snake_case() { + let json = r#" + { + "code": 0, + "message": "ok", + "data": { + "access_token": "a", + "refresh_token": "r", + "token_type": "Bearer", + "expires_in": 7200 + } + } + "#; + + let parsed: AppResponse = serde_json::from_str(json).unwrap(); + let data = parsed.data.unwrap(); + assert_eq!(data.access_token, "a"); + assert_eq!(data.refresh_token, "r"); + assert_eq!(data.token_type.as_deref(), Some("Bearer")); + assert_eq!(data.expires_in, 7200); +} +