fix(auth): iam check

This commit is contained in:
2026-02-11 10:56:04 +08:00
parent 583fd521a2
commit 909d9a6da2
18 changed files with 646 additions and 202 deletions

34
Cargo.lock generated
View File

@@ -69,6 +69,7 @@ name = "auth-kit"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"axum-extra",
"base64", "base64",
"common-telemetry", "common-telemetry",
"dashmap", "dashmap",
@@ -162,6 +163,28 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-extra"
version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fef252edff26ddba56bbcdf2ee3307b8129acb86f5749b68990c168a6fcc9c76"
dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.22.1" version = "0.22.1"
@@ -367,6 +390,17 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
] ]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"percent-encoding",
"time",
"version_check",
]
[[package]] [[package]]
name = "core-foundation" name = "core-foundation"
version = "0.9.4" version = "0.9.4"

View File

@@ -17,33 +17,33 @@ CMS 对外暴露 RESTful API并提供 Scalar 文档:
### 栏目Column ### 栏目Column
- `POST /v1/columns``cms:column:write` - `POST /api/v1/columns``cms:column:write`
- `GET /v1/columns``cms:column:read`,分页/搜索) - `GET /api/v1/columns``cms:column:read`,分页/搜索)
- `GET /v1/columns/{id}``cms:column:read` - `GET /api/v1/columns/{id}``cms:column:read`
- `PATCH /v1/columns/{id}``cms:column:write` - `PATCH /api/v1/columns/{id}``cms:column:write`
- `DELETE /v1/columns/{id}``cms:column:write` - `DELETE /api/v1/columns/{id}``cms:column:write`
### 标签/分类Tag ### 标签/分类Tag
- `POST /v1/tags``cms:tag:write``kind` 支持 `tag|category` - `POST /api/v1/tags``cms:tag:write``kind` 支持 `tag|category`
- `GET /v1/tags``cms:tag:read`,分页/搜索/按 kind 过滤) - `GET /api/v1/tags``cms:tag:read`,分页/搜索/按 kind 过滤)
- `GET /v1/tags/{id}``cms:tag:read` - `GET /api/v1/tags/{id}``cms:tag:read`
- `PATCH /v1/tags/{id}``cms:tag:write` - `PATCH /api/v1/tags/{id}``cms:tag:write`
- `DELETE /v1/tags/{id}``cms:tag:write` - `DELETE /api/v1/tags/{id}``cms:tag:write`
### 媒体库Media ### 媒体库Media
- `POST /v1/media``cms:media:manage`,登记 URL/元数据) - `POST /api/v1/media``cms:media:manage`,登记 URL/元数据)
- `GET /v1/media``cms:media:read`,分页/搜索) - `GET /api/v1/media``cms:media:read`,分页/搜索)
- `GET /v1/media/{id}``cms:media:read` - `GET /api/v1/media/{id}``cms:media:read`
- `DELETE /v1/media/{id}``cms:media:manage` - `DELETE /api/v1/media/{id}``cms:media:manage`
### 文章Article ### 文章Article
- `POST /v1/articles``cms:article:write`,创建草稿) - `POST /api/v1/articles``cms:article:edit`,创建草稿)
- `GET /v1/articles``cms:article:read`,分页/搜索/按状态/栏目/标签过滤) - `GET /api/v1/articles``cms:article:edit`,分页/搜索/按状态/栏目/标签过滤)
- `GET /v1/articles/{id}``cms:article:read` - `GET /api/v1/articles/{id}``cms:article:edit`
- `PATCH /v1/articles/{id}``cms:article:write` - `PATCH /api/v1/articles/{id}``cms:article:edit`
- `POST /v1/articles/{id}/publish``cms:article:publish`,发布并生成版本) - `POST /api/v1/articles/{id}/publish``cms:article:publish`,发布并生成版本)
- `POST /v1/articles/{id}/rollback``cms:article:rollback`,回滚到指定版本并生成新版本) - `POST /api/v1/articles/{id}/rollback``cms:article:rollback`,回滚到指定版本并生成新版本)
- `GET /v1/articles/{id}/versions``cms:article:read`,版本列表分页) - `GET /api/v1/articles/{id}/versions``cms:article:edit`,版本列表分页)

View File

@@ -28,6 +28,9 @@ impl Modify for SecurityAddon {
version = "0.1.0", version = "0.1.0",
description = include_str!("../../docs/API.md") description = include_str!("../../docs/API.md")
), ),
servers(
(url = "/api/v1", description = "Canonical API base")
),
paths( paths(
crate::api::handlers::column::create_column_handler, crate::api::handlers::column::create_column_handler,
crate::api::handlers::column::list_columns_handler, crate::api::handlers::column::list_columns_handler,
@@ -66,6 +69,7 @@ impl Modify for SecurityAddon {
crate::domain::models::Media, crate::domain::models::Media,
crate::domain::models::Article, crate::domain::models::Article,
crate::domain::models::ArticleVersion, crate::domain::models::ArticleVersion,
crate::domain::models::Paged<crate::domain::models::Column>,
crate::infrastructure::repositories::article::ArticleWithTags crate::infrastructure::repositories::article::ArticleWithTags
) )
), ),

View File

@@ -8,7 +8,7 @@ use utoipa::IntoParams;
use uuid::Uuid; use uuid::Uuid;
use crate::api::{AppState, handlers::common::extract_bearer_token}; use crate::api::{AppState, handlers::common::extract_bearer_token};
use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; use auth_kit::middleware::{auth::AuthContext, tenant::TenantId};
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] #[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
pub struct CreateArticleRequest { pub struct CreateArticleRequest {
@@ -65,7 +65,7 @@ pub fn router() -> Router<AppState> {
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/articles", path = "/articles",
tag = "Article", tag = "Article",
request_body = CreateArticleRequest, request_body = CreateArticleRequest,
security( security(
@@ -85,7 +85,12 @@ pub async fn create_article_handler(
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
.require_permission(tenant_id, user_id, "cms:article:write", &token) .require_any_permissions(
tenant_id,
user_id,
&["cms:article:edit", "cms:article:create"],
&token,
)
.await?; .await?;
let article = state let article = state
@@ -106,7 +111,7 @@ pub async fn create_article_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/articles", path = "/articles",
tag = "Article", tag = "Article",
params(ListArticlesQuery), params(ListArticlesQuery),
security( security(
@@ -122,12 +127,14 @@ pub async fn list_articles_handler(
State(state): State<AppState>, State(state): State<AppState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Query(query): Query<ListArticlesQuery>, Query(query): Query<ListArticlesQuery>,
) -> Result<AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Article>>, AppError> ) -> Result<
{ AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Article>>,
AppError,
> {
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
.require_permission(tenant_id, user_id, "cms:article:read", &token) .require_permission(tenant_id, user_id, "cms:article:edit", &token)
.await?; .await?;
let result = state let result = state
@@ -149,7 +156,7 @@ pub async fn list_articles_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/articles/{id}", path = "/articles/{id}",
tag = "Article", tag = "Article",
params( params(
("id" = String, Path, description = "文章ID") ("id" = String, Path, description = "文章ID")
@@ -171,7 +178,7 @@ pub async fn get_article_handler(
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
.require_permission(tenant_id, user_id, "cms:article:read", &token) .require_permission(tenant_id, user_id, "cms:article:edit", &token)
.await?; .await?;
let article = state.services.get_article(tenant_id, id).await?; let article = state.services.get_article(tenant_id, id).await?;
@@ -180,7 +187,7 @@ pub async fn get_article_handler(
#[utoipa::path( #[utoipa::path(
patch, patch,
path = "/v1/articles/{id}", path = "/articles/{id}",
tag = "Article", tag = "Article",
request_body = UpdateArticleRequest, request_body = UpdateArticleRequest,
params( params(
@@ -204,7 +211,7 @@ pub async fn update_article_handler(
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
.require_permission(tenant_id, user_id, "cms:article:write", &token) .require_permission(tenant_id, user_id, "cms:article:edit", &token)
.await?; .await?;
let article = state let article = state
@@ -226,7 +233,7 @@ pub async fn update_article_handler(
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/articles/{id}/publish", path = "/articles/{id}/publish",
tag = "Article", tag = "Article",
params( params(
("id" = String, Path, description = "文章ID") ("id" = String, Path, description = "文章ID")
@@ -251,13 +258,16 @@ pub async fn publish_article_handler(
.require_permission(tenant_id, user_id, "cms:article:publish", &token) .require_permission(tenant_id, user_id, "cms:article:publish", &token)
.await?; .await?;
let article = state.services.publish_article(tenant_id, id, Some(user_id)).await?; let article = state
.services
.publish_article(tenant_id, id, Some(user_id))
.await?;
Ok(AppResponse::ok(article)) Ok(AppResponse::ok(article))
} }
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/articles/{id}/rollback", path = "/articles/{id}/rollback",
tag = "Version", tag = "Version",
request_body = RollbackRequest, request_body = RollbackRequest,
params( params(
@@ -293,7 +303,7 @@ pub async fn rollback_article_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/articles/{id}/versions", path = "/articles/{id}/versions",
tag = "Version", tag = "Version",
params( params(
("id" = String, Path, description = "文章ID"), ("id" = String, Path, description = "文章ID"),
@@ -303,7 +313,10 @@ pub async fn rollback_article_handler(
("bearer_auth" = []) ("bearer_auth" = [])
), ),
responses( responses(
(status = 200, description = "版本列表", body = crate::infrastructure::repositories::column::Paged<crate::domain::models::ArticleVersion>) (status = 200, description = "版本列表", body = crate::domain::models::Paged<crate::domain::models::ArticleVersion>),
(status = 401, description = "未认证"),
(status = 403, description = "无权限"),
(status = 404, description = "不存在")
) )
)] )]
pub async fn list_versions_handler( pub async fn list_versions_handler(
@@ -313,12 +326,16 @@ pub async fn list_versions_handler(
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Path(id): Path<Uuid>, Path(id): Path<Uuid>,
Query(query): Query<ListVersionsQuery>, Query(query): Query<ListVersionsQuery>,
) -> Result<AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::ArticleVersion>>, AppError> ) -> Result<
{ AppResponse<
crate::infrastructure::repositories::column::Paged<crate::domain::models::ArticleVersion>,
>,
AppError,
> {
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
.require_permission(tenant_id, user_id, "cms:article:read", &token) .require_permission(tenant_id, user_id, "cms:article:edit", &token)
.await?; .await?;
let versions = state let versions = state

View File

@@ -1,14 +1,16 @@
use axum::{ use axum::{
Router, Router,
extract::Query, extract::Query,
http::{HeaderValue, header}, http::{HeaderMap, HeaderValue, header},
response::{IntoResponse, Redirect}, response::{IntoResponse, Redirect},
routing::get, routing::{get, post},
}; };
use common_telemetry::AppError; use common_telemetry::AppError;
use serde::Deserialize; use serde::Deserialize;
use crate::api::AppState; use crate::api::AppState;
use crate::api::handlers::common::extract_bearer_token;
use auth_kit::middleware::{auth::AuthContext, tenant::TenantId};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct CallbackQuery { pub struct CallbackQuery {
@@ -18,7 +20,6 @@ pub struct CallbackQuery {
} }
#[derive(Debug, Deserialize, serde::Serialize)] #[derive(Debug, Deserialize, serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct Code2TokenRequest { struct Code2TokenRequest {
code: String, code: String,
client_id: String, client_id: String,
@@ -31,10 +32,10 @@ struct RefreshTokenRequest {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Code2TokenData { struct Code2TokenData {
access_token: String, access_token: String,
refresh_token: String, refresh_token: String,
token_type: Option<String>,
expires_in: usize, expires_in: usize,
tenant_id: String, tenant_id: String,
user_id: String, user_id: String,
@@ -51,6 +52,7 @@ pub fn router() -> Router<AppState> {
Router::new() Router::new()
.route("/callback", get(sso_callback_handler)) .route("/callback", get(sso_callback_handler))
.route("/refresh", get(refresh_token_handler)) .route("/refresh", get(refresh_token_handler))
.route("/logout", post(logout_handler))
} }
fn is_https(headers: &axum::http::HeaderMap) -> bool { fn is_https(headers: &axum::http::HeaderMap) -> bool {
@@ -81,6 +83,30 @@ fn cookie_header(
s s
} }
pub async fn logout_handler(
TenantId(tenant_id): TenantId,
AuthContext { user_id: _, .. }: AuthContext,
axum::extract::State(state): axum::extract::State<AppState>,
headers: HeaderMap,
) -> Result<axum::response::Response, AppError> {
let secure = is_https(&headers);
let token = extract_bearer_token(&headers)?;
let _ = state.iam_client.logout(tenant_id, &token).await;
let mut res = axum::Json(serde_json::json!({})).into_response();
for name in ["accessToken", "refreshToken", "tenantId", "userId"] {
res.headers_mut().append(
header::SET_COOKIE,
HeaderValue::from_str(&cookie_header(name, "", secure, true, Some(0)))
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?,
);
}
Ok(res)
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct RefreshTokenQuery { pub struct RefreshTokenQuery {
pub token: String, pub token: String,
@@ -91,8 +117,37 @@ pub async fn refresh_token_handler(
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Query(q): Query<RefreshTokenQuery>, Query(q): Query<RefreshTokenQuery>,
) -> Result<axum::response::Response, AppError> { ) -> Result<axum::response::Response, AppError> {
fn clear_cookie(
res: &mut axum::response::Response,
name: &str,
secure: bool,
) -> Result<(), AppError> {
res.headers_mut().append(
header::SET_COOKIE,
HeaderValue::from_str(&cookie_header(name, "", secure, true, Some(0)))
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?,
);
Ok(())
}
fn clear_auth_cookies(
res: &mut axum::response::Response,
secure: bool,
) -> Result<(), AppError> {
clear_cookie(res, "accessToken", secure)?;
clear_cookie(res, "refreshToken", secure)?;
clear_cookie(res, "tenantId", secure)?;
clear_cookie(res, "userId", secure)?;
Ok(())
}
let secure = is_https(&headers);
let target = resolve_front_redirect(q.next.clone());
if q.token.trim().is_empty() { if q.token.trim().is_empty() {
return Ok(Redirect::temporary("/auth-error?message=missing_token").into_response()); let mut res = Redirect::temporary(&target).into_response();
clear_auth_cookies(&mut res, secure)?;
return Ok(res);
} }
let iam_base = std::env::var("IAM_BASE_URL") let iam_base = std::env::var("IAM_BASE_URL")
@@ -117,46 +172,32 @@ pub async fn refresh_token_handler(
// But LoginResponse structure is: access_token, refresh_token, token_type, expires_in. // But LoginResponse structure is: access_token, refresh_token, token_type, expires_in.
// Code2TokenData has tenant_id, user_id extra? // Code2TokenData has tenant_id, user_id extra?
// Let's check IAM service LoginResponse definition. // Let's check IAM service LoginResponse definition.
// IAM Service LoginResponse: access_token, refresh_token, token_type, expires_in.
// Wait, Code2TokenData expects tenant_id and user_id.
// Does IAM refresh endpoint return tenant_id and user_id?
// IAM Service LoginResponse struct in src/models.rs (iam-service) DOES NOT have tenant_id/user_id.
// So we cannot reuse Code2TokenData for refresh response parsing if we expect those fields.
// But usually refresh token response just updates access_token (and maybe refresh_token).
// TenantId and UserId should not change. We can keep existing cookies for them if we don't have them.
// But wait, we are setting cookies. If we don't get tenant_id/user_id, we can't set them (or we re-set them if we knew them).
// The previous cookies are still there. We just need to update access_token and refresh_token.
// Let's define a separate struct for Refresh Response if needed.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RefreshResponseData { struct RefreshResponseData {
access_token: String, access_token: String,
refresh_token: String, refresh_token: String,
token_type: Option<String>,
expires_in: usize, expires_in: usize,
} }
let body = resp let bytes = resp
.json::<AppResponse<RefreshResponseData>>() .bytes()
.await .await
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?; .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?;
let body = serde_json::from_slice::<AppResponse<RefreshResponseData>>(&bytes).ok();
if !status.is_success() || body.code != 0 { if !status.is_success() || body.as_ref().map(|b| b.code).unwrap_or(1) != 0 {
// Refresh failed, redirect to login let mut res = Redirect::temporary(&target).into_response();
let login_url = resolve_front_redirect(q.next); // Actually redirect to front login page or handle error clear_auth_cookies(&mut res, secure)?;
// If refresh fails, we probably want to redirect to the original requested page so it can trigger login flow, return Ok(res);
// OR redirect to auth-error.
// But the middleware calls this. If this returns redirect, the middleware will return redirect.
// If middleware sees error, it should redirect to login.
return Ok(Redirect::temporary("/auth-error?message=refresh_failed").into_response());
} }
let Some(data) = body.data else { let Some(data) = body.and_then(|b| b.data) else {
return Ok(Redirect::temporary("/auth-error?message=invalid_refresh_response").into_response()); let mut res = Redirect::temporary(&target).into_response();
clear_auth_cookies(&mut res, secure)?;
return Ok(res);
}; };
let target = resolve_front_redirect(q.next);
let secure = is_https(&headers);
let mut res = Redirect::temporary(&target).into_response(); let mut res = Redirect::temporary(&target).into_response();
let refresh_max_age = 30_u64 * 24 * 60 * 60; let refresh_max_age = 30_u64 * 24 * 60 * 60;
@@ -211,6 +252,26 @@ fn resolve_front_redirect(next: Option<String>) -> String {
return raw; return raw;
} }
if raw.starts_with("http://") {
if let Some(base) = base {
if raw.starts_with(&base) {
return raw;
}
if cfg!(debug_assertions)
&& (raw.starts_with("http://localhost") || raw.starts_with("http://127.0.0.1"))
{
return raw;
}
return base;
}
if cfg!(debug_assertions)
&& (raw.starts_with("http://localhost") || raw.starts_with("http://127.0.0.1"))
{
return raw;
}
return "/".to_string();
}
base.unwrap_or_else(|| "/".to_string()) base.unwrap_or_else(|| "/".to_string())
} }
@@ -237,12 +298,7 @@ async fn sso_callback_handler(
return Ok(Redirect::temporary(&target).into_response()); return Ok(Redirect::temporary(&target).into_response());
} }
let tenant_id = q let tenant_id = q.tenant_id.as_deref().unwrap_or("").trim().to_string();
.tenant_id
.as_deref()
.unwrap_or("")
.trim()
.to_string();
if uuid::Uuid::parse_str(&tenant_id).is_err() { if uuid::Uuid::parse_str(&tenant_id).is_err() {
let target = resolve_front_error_redirect("missing or invalid tenant_id"); let target = resolve_front_error_redirect("missing or invalid tenant_id");
return Ok(Redirect::temporary(&target).into_response()); return Ok(Redirect::temporary(&target).into_response());

View File

@@ -8,7 +8,7 @@ use utoipa::IntoParams;
use uuid::Uuid; use uuid::Uuid;
use crate::api::{AppState, handlers::common::extract_bearer_token}; use crate::api::{AppState, handlers::common::extract_bearer_token};
use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; use auth_kit::middleware::{auth::AuthContext, tenant::TenantId};
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] #[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
pub struct CreateColumnRequest { pub struct CreateColumnRequest {
@@ -49,7 +49,7 @@ pub fn router() -> Router<AppState> {
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/columns", path = "/columns",
tag = "Column", tag = "Column",
request_body = CreateColumnRequest, request_body = CreateColumnRequest,
security( security(
@@ -90,14 +90,14 @@ pub async fn create_column_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/columns", path = "/columns",
tag = "Column", tag = "Column",
params(ListColumnsQuery), params(ListColumnsQuery),
security( security(
("bearer_auth" = []) ("bearer_auth" = [])
), ),
responses( responses(
(status = 200, description = "栏目列表", body = crate::infrastructure::repositories::column::Paged<crate::domain::models::Column>), (status = 200, description = "栏目列表", body = crate::domain::models::Paged<crate::domain::models::Column>),
(status = 401, description = "未认证"), (status = 401, description = "未认证"),
(status = 403, description = "无权限") (status = 403, description = "无权限")
) )
@@ -108,8 +108,10 @@ pub async fn list_columns_handler(
State(state): State<AppState>, State(state): State<AppState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Query(query): Query<ListColumnsQuery>, Query(query): Query<ListColumnsQuery>,
) -> Result<AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Column>>, AppError> ) -> Result<
{ AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Column>>,
AppError,
> {
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
@@ -133,7 +135,7 @@ pub async fn list_columns_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/columns/{id}", path = "/columns/{id}",
tag = "Column", tag = "Column",
params( params(
("id" = String, Path, description = "栏目ID") ("id" = String, Path, description = "栏目ID")
@@ -167,7 +169,7 @@ pub async fn get_column_handler(
#[utoipa::path( #[utoipa::path(
patch, patch,
path = "/v1/columns/{id}", path = "/columns/{id}",
tag = "Column", tag = "Column",
request_body = UpdateColumnRequest, request_body = UpdateColumnRequest,
params( params(
@@ -214,7 +216,7 @@ pub async fn update_column_handler(
#[utoipa::path( #[utoipa::path(
delete, delete,
path = "/v1/columns/{id}", path = "/columns/{id}",
tag = "Column", tag = "Column",
params( params(
("id" = String, Path, description = "栏目ID") ("id" = String, Path, description = "栏目ID")

View File

@@ -2,10 +2,31 @@ use axum::http::HeaderMap;
use common_telemetry::AppError; use common_telemetry::AppError;
pub fn extract_bearer_token(headers: &HeaderMap) -> Result<String, AppError> { pub fn extract_bearer_token(headers: &HeaderMap) -> Result<String, AppError> {
let token = headers if let Some(token) = headers
.get(axum::http::header::AUTHORIZATION) .get(axum::http::header::AUTHORIZATION)
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer ")) .and_then(|v| v.strip_prefix("Bearer "))
.ok_or(AppError::MissingAuthHeader)?; {
Ok(token.to_string()) return Ok(token.to_string());
}
let cookie_header = headers
.get(axum::http::header::COOKIE)
.and_then(|h| h.to_str().ok())
.unwrap_or("");
for part in cookie_header.split(';') {
let part = part.trim();
let Some((name, value)) = part.split_once('=') else {
continue;
};
if name.trim() != "accessToken" {
continue;
}
let raw = value.trim();
let decoded = urlencoding::decode(raw).ok().map(|s| s.into_owned());
return Ok(decoded.unwrap_or_else(|| raw.to_string()));
}
Err(AppError::MissingAuthHeader)
} }

View File

@@ -8,7 +8,7 @@ use utoipa::IntoParams;
use uuid::Uuid; use uuid::Uuid;
use crate::api::{AppState, handlers::common::extract_bearer_token}; use crate::api::{AppState, handlers::common::extract_bearer_token};
use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; use auth_kit::middleware::{auth::AuthContext, tenant::TenantId};
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] #[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
pub struct CreateMediaRequest { pub struct CreateMediaRequest {
@@ -34,7 +34,7 @@ pub fn router() -> Router<AppState> {
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/media", path = "/media",
tag = "Media", tag = "Media",
request_body = CreateMediaRequest, request_body = CreateMediaRequest,
security( security(
@@ -74,14 +74,16 @@ pub async fn create_media_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/media", path = "/media",
tag = "Media", tag = "Media",
params(ListMediaQuery), params(ListMediaQuery),
security( security(
("bearer_auth" = []) ("bearer_auth" = [])
), ),
responses( responses(
(status = 200, description = "媒体列表", body = crate::infrastructure::repositories::column::Paged<crate::domain::models::Media>) (status = 200, description = "媒体列表", body = crate::domain::models::Paged<crate::domain::models::Media>),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
) )
)] )]
pub async fn list_media_handler( pub async fn list_media_handler(
@@ -90,8 +92,10 @@ pub async fn list_media_handler(
State(state): State<AppState>, State(state): State<AppState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Query(query): Query<ListMediaQuery>, Query(query): Query<ListMediaQuery>,
) -> Result<AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Media>>, AppError> ) -> Result<
{ AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Media>>,
AppError,
> {
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
@@ -114,7 +118,7 @@ pub async fn list_media_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/media/{id}", path = "/media/{id}",
tag = "Media", tag = "Media",
params( params(
("id" = String, Path, description = "媒体ID") ("id" = String, Path, description = "媒体ID")
@@ -145,7 +149,7 @@ pub async fn get_media_handler(
#[utoipa::path( #[utoipa::path(
delete, delete,
path = "/v1/media/{id}", path = "/media/{id}",
tag = "Media", tag = "Media",
params( params(
("id" = String, Path, description = "媒体ID") ("id" = String, Path, description = "媒体ID")

View File

@@ -8,7 +8,7 @@ use utoipa::IntoParams;
use uuid::Uuid; use uuid::Uuid;
use crate::api::{AppState, handlers::common::extract_bearer_token}; use crate::api::{AppState, handlers::common::extract_bearer_token};
use auth_kit::middleware::{tenant::TenantId, auth::AuthContext}; use auth_kit::middleware::{auth::AuthContext, tenant::TenantId};
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] #[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
pub struct CreateTagRequest { pub struct CreateTagRequest {
@@ -44,7 +44,7 @@ pub fn router() -> Router<AppState> {
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/tags", path = "/tags",
tag = "Tag", tag = "Tag",
request_body = CreateTagRequest, request_body = CreateTagRequest,
security( security(
@@ -76,14 +76,16 @@ pub async fn create_tag_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/tags", path = "/tags",
tag = "Tag", tag = "Tag",
params(ListTagsQuery), params(ListTagsQuery),
security( security(
("bearer_auth" = []) ("bearer_auth" = [])
), ),
responses( responses(
(status = 200, description = "标签/分类列表", body = crate::infrastructure::repositories::column::Paged<crate::domain::models::Tag>) (status = 200, description = "标签/分类列表", body = crate::domain::models::Paged<crate::domain::models::Tag>),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
) )
)] )]
pub async fn list_tags_handler( pub async fn list_tags_handler(
@@ -92,8 +94,10 @@ pub async fn list_tags_handler(
State(state): State<AppState>, State(state): State<AppState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Query(query): Query<ListTagsQuery>, Query(query): Query<ListTagsQuery>,
) -> Result<AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Tag>>, AppError> ) -> Result<
{ AppResponse<crate::infrastructure::repositories::column::Paged<crate::domain::models::Tag>>,
AppError,
> {
let token = extract_bearer_token(&headers)?; let token = extract_bearer_token(&headers)?;
state state
.iam_client .iam_client
@@ -117,7 +121,7 @@ pub async fn list_tags_handler(
#[utoipa::path( #[utoipa::path(
get, get,
path = "/v1/tags/{id}", path = "/tags/{id}",
tag = "Tag", tag = "Tag",
params( params(
("id" = String, Path, description = "标签/分类ID") ("id" = String, Path, description = "标签/分类ID")
@@ -148,7 +152,7 @@ pub async fn get_tag_handler(
#[utoipa::path( #[utoipa::path(
patch, patch,
path = "/v1/tags/{id}", path = "/tags/{id}",
tag = "Tag", tag = "Tag",
request_body = UpdateTagRequest, request_body = UpdateTagRequest,
params( params(
@@ -184,7 +188,7 @@ pub async fn update_tag_handler(
#[utoipa::path( #[utoipa::path(
delete, delete,
path = "/v1/tags/{id}", path = "/tags/{id}",
tag = "Tag", tag = "Tag",
params( params(
("id" = String, Path, description = "标签/分类ID") ("id" = String, Path, description = "标签/分类ID")

View File

@@ -2,14 +2,15 @@ pub mod docs;
pub mod handlers; pub mod handlers;
pub mod middleware; pub mod middleware;
use axum::routing::get;
use axum::Router; use axum::Router;
use axum::routing::get;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_scalar::{Scalar, Servable}; use utoipa_scalar::{Scalar, Servable};
use crate::api::docs::ApiDoc; use crate::api::docs::ApiDoc;
use crate::api::middleware::{catch_panic, request_logger}; use crate::api::middleware::{catch_panic, request_logger};
use crate::application::services::CmsServices; use crate::application::services::CmsServices;
use crate::constants::CANONICAL_BASE;
use crate::infrastructure::iam_client::IamClient; use crate::infrastructure::iam_client::IamClient;
#[derive(Clone)] #[derive(Clone)]
@@ -21,20 +22,25 @@ pub struct AppState {
pub fn build_router(state: AppState) -> Router { pub fn build_router(state: AppState) -> Router {
let health = Router::new().route("/healthz", get(|| async { axum::http::StatusCode::OK })); let health = Router::new().route("/healthz", get(|| async { axum::http::StatusCode::OK }));
let auth = Router::new().nest("/auth", handlers::auth::router());
let v1 = Router::new() let v1 = Router::new()
.nest("/auth", handlers::auth::router())
.nest("/columns", handlers::column::router()) .nest("/columns", handlers::column::router())
.nest("/tags", handlers::tag::router()) .nest("/tags", handlers::tag::router())
.nest("/media", handlers::media::router()) .nest("/media", handlers::media::router())
.nest("/articles", handlers::article::router()); .nest("/articles", handlers::article::router());
Router::new() Router::new()
.route("/favicon.ico", get(|| async { axum::http::StatusCode::NO_CONTENT })) .route(
"/favicon.ico",
get(|| async { axum::http::StatusCode::NO_CONTENT }),
)
.merge(Scalar::with_url("/scalar", ApiDoc::openapi())) .merge(Scalar::with_url("/scalar", ApiDoc::openapi()))
.route(
"/scalar/openapi.json",
get(|| async { axum::Json(ApiDoc::openapi()) }),
)
.merge(health) .merge(health)
.merge(auth) .nest(CANONICAL_BASE, v1)
.nest("/v1", v1)
.layer(axum::middleware::from_fn(catch_panic)) .layer(axum::middleware::from_fn(catch_panic))
.layer(axum::middleware::from_fn(request_logger)) .layer(axum::middleware::from_fn(request_logger))
.with_state(state) .with_state(state)

1
src/constants.rs Normal file
View File

@@ -0,0 +1 @@
pub const CANONICAL_BASE: &str = "/api/v1";

View File

@@ -3,6 +3,15 @@ use sqlx::FromRow;
use utoipa::ToSchema; use utoipa::ToSchema;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Paged<T: ToSchema + Sized> {
pub items: Vec<T>,
pub total: i64,
pub page: i32,
pub page_size: i32,
pub total_pages: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, FromRow)] #[derive(Debug, Clone, Serialize, Deserialize, ToSchema, FromRow)]
pub struct Column { pub struct Column {
pub tenant_id: Uuid, pub tenant_id: Uuid,

View File

@@ -9,6 +9,10 @@ use dashmap::DashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use crate::constants::CANONICAL_BASE;
const CMS_ADMIN_PERMISSION: &str = "cms:admin";
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct IamClientConfig { pub struct IamClientConfig {
pub base_url: String, pub base_url: String,
@@ -70,6 +74,42 @@ struct AuthorizationCheckResponse {
allowed: bool, allowed: bool,
} }
#[derive(Debug, Serialize)]
struct AuthorizationExprCheckRequest {
expr: PermissionExpr,
}
#[derive(Debug, Deserialize)]
struct AuthorizationExprCheckResponse {
allowed: bool,
}
#[derive(Debug, Serialize)]
#[serde(deny_unknown_fields)]
struct AnyExpr {
any: Vec<PermissionExprItem>,
}
#[derive(Debug, Serialize)]
#[serde(deny_unknown_fields)]
struct AllExpr {
all: Vec<PermissionExprItem>,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum PermissionExpr {
Any(AnyExpr),
All(AllExpr),
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum PermissionExprItem {
Permission(String),
Expr(Box<PermissionExpr>),
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ApiSuccessResponse<T> { struct ApiSuccessResponse<T> {
#[allow(dead_code)] #[allow(dead_code)]
@@ -111,6 +151,146 @@ impl IamClient {
} }
} }
pub async fn logout(&self, tenant_id: Uuid, access_token: &str) -> Result<(), AppError> {
let base = self.inner.cfg.base_url.trim_end_matches('/');
let api_base = if base.ends_with(CANONICAL_BASE) {
base.to_string()
} else {
format!("{}{}", base, CANONICAL_BASE)
};
let url = format!("{}/auth/logout", api_base);
let resp = self
.inner
.http
.post(url)
.bearer_auth(access_token)
.header("X-Tenant-ID", tenant_id.to_string())
.send()
.await
.map_err(|e| AppError::ExternalReqError(format!("iam:request_failed:{}", e)))?;
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED {
return Err(AppError::AuthError("iam:unauthorized".into()));
}
if !status.is_success() {
return Err(AppError::ExternalReqError(format!(
"iam:unexpected_status:{}",
status.as_u16()
)));
}
Ok(())
}
async fn check_permission_expr(
&self,
tenant_id: Uuid,
access_token: &str,
expr: PermissionExpr,
) -> Result<bool, AppError> {
let expr = with_admin_override(expr);
let base = self.inner.cfg.base_url.trim_end_matches('/');
let api_base = if base.ends_with(CANONICAL_BASE) {
base.to_string()
} else {
format!("{}{}", base, CANONICAL_BASE)
};
let url = format!("{}/authorize/check-expr", api_base);
let resp = self
.inner
.http
.post(url)
.bearer_auth(access_token)
.header("X-Tenant-ID", tenant_id.to_string())
.json(&AuthorizationExprCheckRequest { expr })
.send()
.await
.map_err(|e| AppError::ExternalReqError(format!("iam:request_failed:{}", e)))?;
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED {
return Err(AppError::AuthError("iam:unauthorized".into()));
}
if status == reqwest::StatusCode::FORBIDDEN {
return Err(AppError::PermissionDenied("iam:forbidden".into()));
}
if !status.is_success() {
return Err(AppError::ExternalReqError(format!(
"iam:unexpected_status:{}",
status.as_u16()
)));
}
let body: ApiSuccessResponse<AuthorizationExprCheckResponse> = resp
.json()
.await
.map_err(|e| AppError::ExternalReqError(format!("iam:decode_failed:{}", e)))?;
let allowed = body
.data
.map(|d| d.allowed)
.ok_or_else(|| AppError::ExternalReqError("iam:missing_data".into()))?;
Ok(allowed)
}
async fn require_permission_expr(
&self,
tenant_id: Uuid,
user_id: Uuid,
expr: PermissionExpr,
access_token: &str,
) -> Result<(), AppError> {
let allowed = self
.check_permission_expr(tenant_id, access_token, expr)
.await?;
if allowed {
Ok(())
} else {
Err(AppError::PermissionDenied(format!(
"iam:expr_denied:{}:{}",
tenant_id, user_id
)))
}
}
pub async fn require_any_permissions(
&self,
tenant_id: Uuid,
user_id: Uuid,
permissions: &[&str],
access_token: &str,
) -> Result<(), AppError> {
let expr = PermissionExpr::Any(AnyExpr {
any: permissions
.iter()
.map(|p| PermissionExprItem::Permission((*p).to_string()))
.collect(),
});
self.require_permission_expr(tenant_id, user_id, expr, access_token)
.await
}
pub async fn require_all_permissions(
&self,
tenant_id: Uuid,
user_id: Uuid,
permissions: &[&str],
access_token: &str,
) -> Result<(), AppError> {
let expr = PermissionExpr::All(AllExpr {
all: permissions
.iter()
.map(|p| PermissionExprItem::Permission((*p).to_string()))
.collect(),
});
self.require_permission_expr(tenant_id, user_id, expr, access_token)
.await
}
async fn check_permission( async fn check_permission(
&self, &self,
tenant_id: Uuid, tenant_id: Uuid,
@@ -177,48 +357,28 @@ impl IamClient {
permission: &str, permission: &str,
access_token: &str, access_token: &str,
) -> Result<bool, AppError> { ) -> Result<bool, AppError> {
let url = format!( let expr = PermissionExpr::Any(AnyExpr {
"{}/authorize/check", any: vec![PermissionExprItem::Permission(permission.to_string())],
self.inner.cfg.base_url.trim_end_matches('/') });
); self.check_permission_expr(tenant_id, access_token, expr)
let resp = self
.inner
.http
.post(url)
.bearer_auth(access_token)
.header("X-Tenant-ID", tenant_id.to_string())
.json(&AuthorizationCheckRequest {
permission: permission.to_string(),
})
.send()
.await .await
.map_err(|e| AppError::ExternalReqError(format!("iam:request_failed:{}", e)))?; }
}
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED { fn with_admin_override(expr: PermissionExpr) -> PermissionExpr {
return Err(AppError::AuthError("iam:unauthorized".into())); match expr {
} PermissionExpr::Any(mut x) => {
if status == reqwest::StatusCode::FORBIDDEN { x.any.insert(
return Err(AppError::PermissionDenied("iam:forbidden".into())); 0,
} PermissionExprItem::Permission(CMS_ADMIN_PERMISSION.to_string()),
if !status.is_success() { );
return Err(AppError::ExternalReqError(format!( PermissionExpr::Any(x)
"iam:unexpected_status:{}", }
status.as_u16() PermissionExpr::All(x) => PermissionExpr::Any(AnyExpr {
))); any: vec![
} PermissionExprItem::Permission(CMS_ADMIN_PERMISSION.to_string()),
PermissionExprItem::Expr(Box::new(PermissionExpr::All(x))),
let body: ApiSuccessResponse<AuthorizationCheckResponse> = resp ],
.json() }),
.await
.map_err(|e| AppError::ExternalReqError(format!("iam:decode_failed:{}", e)))?;
let allowed = body
.data
.map(|d| d.allowed)
.ok_or_else(|| AppError::ExternalReqError("iam:missing_data".into()))?;
Ok(allowed)
} }
} }

View File

@@ -1,5 +1,6 @@
pub mod api; pub mod api;
pub mod application; pub mod application;
pub mod config; pub mod config;
pub mod constants;
pub mod domain; pub mod domain;
pub mod infrastructure; pub mod infrastructure;

View File

@@ -1,12 +1,16 @@
use auth_kit::middleware::{auth::AuthMiddlewareConfig, tenant::TenantMiddlewareConfig};
use axum::middleware::{from_fn, from_fn_with_state}; use axum::middleware::{from_fn, from_fn_with_state};
use cms_service::{ use cms_service::{
api::{self, AppState}, api::{self, AppState},
application::services::CmsServices, application::services::CmsServices,
config::AppConfig, config::AppConfig,
infrastructure::{db, iam_client::{IamClient, IamClientConfig}}, constants::CANONICAL_BASE,
infrastructure::{
db,
iam_client::{IamClient, IamClientConfig},
},
}; };
use common_telemetry::telemetry::{self, TelemetryConfig}; use common_telemetry::telemetry::{self, TelemetryConfig};
use auth_kit::middleware::{tenant::TenantMiddlewareConfig, auth::AuthMiddlewareConfig};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
@@ -23,7 +27,9 @@ async fn main() {
log_file: Some(config.log_file_name.clone()), log_file: Some(config.log_file_name.clone()),
}); });
let pool = db::init_pool(&config).await.expect("failed to init db pool"); let pool = db::init_pool(&config)
.await
.expect("failed to init db pool");
let run_migrations = std::env::var("RUN_MIGRATIONS") let run_migrations = std::env::var("RUN_MIGRATIONS")
.ok() .ok()
.map(|v| matches!(v.as_str(), "1" | "true" | "TRUE")) .map(|v| matches!(v.as_str(), "1" | "true" | "TRUE"))
@@ -46,17 +52,23 @@ async fn main() {
}; };
let auth_cfg = AuthMiddlewareConfig { let auth_cfg = AuthMiddlewareConfig {
skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()], skip_exact_paths: vec![
"/healthz".to_string(),
format!("{}/auth/callback", CANONICAL_BASE),
format!("{}/auth/refresh", CANONICAL_BASE),
],
skip_path_prefixes: vec!["/scalar".to_string()], skip_path_prefixes: vec!["/scalar".to_string()],
jwt: match &config.jwt_public_key_pem { jwt: match &config.jwt_public_key_pem {
Some(pem) => auth_kit::jwt::JwtVerifyConfig::rs256_from_pem("iam-service", pem) Some(pem) => auth_kit::jwt::JwtVerifyConfig::rs256_from_pem("iam-service", pem)
.expect("invalid JWT_PUBLIC_KEY_PEM"), .expect("invalid JWT_PUBLIC_KEY_PEM"),
None => { None => {
let jwks_url = config.iam_jwks_url.clone().unwrap_or_else(|| { let jwks_url = config.iam_jwks_url.clone().unwrap_or_else(|| {
format!( let base = config.iam_base_url.trim_end_matches('/');
"{}/.well-known/jwks.json", if base.ends_with(CANONICAL_BASE) {
config.iam_base_url.trim_end_matches('/') format!("{}/.well-known/jwks.json", base)
) } else {
format!("{}{}/.well-known/jwks.json", base, CANONICAL_BASE)
}
}); });
auth_kit::jwt::JwtVerifyConfig::rs256_from_jwks("iam-service", &jwks_url) auth_kit::jwt::JwtVerifyConfig::rs256_from_jwks("iam-service", &jwks_url)
.expect("invalid IAM_JWKS_URL") .expect("invalid IAM_JWKS_URL")
@@ -64,7 +76,11 @@ async fn main() {
}, },
}; };
let tenant_cfg = TenantMiddlewareConfig { let tenant_cfg = TenantMiddlewareConfig {
skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()], skip_exact_paths: vec![
"/healthz".to_string(),
format!("{}/auth/callback", CANONICAL_BASE),
format!("{}/auth/refresh", CANONICAL_BASE),
],
skip_path_prefixes: vec!["/scalar".to_string()], skip_path_prefixes: vec!["/scalar".to_string()],
}; };
@@ -77,7 +93,9 @@ async fn main() {
auth_cfg, auth_cfg,
auth_kit::middleware::auth::authenticate_with_config, auth_kit::middleware::auth::authenticate_with_config,
)) ))
.layer(from_fn(common_telemetry::axum_middleware::trace_http_request)) .layer(from_fn(
common_telemetry::axum_middleware::trace_http_request,
))
.layer(from_fn(cms_service::api::middleware::ensure_request_id)); .layer(from_fn(cms_service::api::middleware::ensure_request_id));
let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
@@ -85,7 +103,10 @@ async fn main() {
tracing::info!("Docs available at http://{}/scalar", addr); tracing::info!("Docs available at http://{}/scalar", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()) axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await .await
.unwrap(); .unwrap();
} }

View File

@@ -4,15 +4,11 @@ use std::sync::{
}; };
use std::time::Duration; use std::time::Duration;
use axum::{Json, Router, routing::post};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::{Json, Router, routing::post};
use cms_service::infrastructure::iam_client::{IamClient, IamClientConfig}; use cms_service::infrastructure::iam_client::{IamClient, IamClientConfig};
use serde::{Deserialize, Serialize}; use serde::Serialize;
use serde_json::Value;
#[derive(Debug, Deserialize)]
struct AuthorizationCheckRequest {
permission: String,
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct AuthorizationCheckResponse { struct AuthorizationCheckResponse {
@@ -31,9 +27,7 @@ async fn start_mock_iam(
call_count: Arc<AtomicUsize>, call_count: Arc<AtomicUsize>,
fail: Arc<AtomicBool>, fail: Arc<AtomicBool>,
) -> (String, tokio::task::JoinHandle<()>) { ) -> (String, tokio::task::JoinHandle<()>) {
let app = Router::new().route( let handler = move |Json(_body): Json<Value>| {
"/authorize/check",
post(move |Json(body): Json<AuthorizationCheckRequest>| {
let call_count = call_count.clone(); let call_count = call_count.clone();
let fail = fail.clone(); let fail = fail.clone();
async move { async move {
@@ -42,17 +36,19 @@ async fn start_mock_iam(
return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "fail").into_response(); return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "fail").into_response();
} }
let allowed = body.permission == "cms:article:read";
let resp = ApiSuccessResponse { let resp = ApiSuccessResponse {
code: 0, code: 0,
message: "ok".to_string(), message: "ok".to_string(),
data: AuthorizationCheckResponse { allowed }, data: AuthorizationCheckResponse { allowed: true },
trace_id: None, trace_id: None,
}; };
(axum::http::StatusCode::OK, Json(resp)).into_response() (axum::http::StatusCode::OK, Json(resp)).into_response()
} }
}), };
);
let app = Router::new()
.route("/authorize/check-expr", post(handler.clone()))
.route("/api/v1/authorize/check-expr", post(handler));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap(); let addr = listener.local_addr().unwrap();
@@ -81,11 +77,11 @@ async fn iam_client_caches_decisions() {
let user_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4();
client client
.require_permission(tenant_id, user_id, "cms:article:read", "token") .require_permission(tenant_id, user_id, "cms:article:edit", "token")
.await .await
.unwrap(); .unwrap();
client client
.require_permission(tenant_id, user_id, "cms:article:read", "token") .require_permission(tenant_id, user_id, "cms:article:edit", "token")
.await .await
.unwrap(); .unwrap();
@@ -111,7 +107,7 @@ async fn iam_client_uses_stale_cache_on_error() {
let user_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4();
client client
.require_permission(tenant_id, user_id, "cms:article:read", "token") .require_permission(tenant_id, user_id, "cms:article:edit", "token")
.await .await
.unwrap(); .unwrap();
@@ -119,7 +115,7 @@ async fn iam_client_uses_stale_cache_on_error() {
fail.store(true, Ordering::SeqCst); fail.store(true, Ordering::SeqCst);
client client
.require_permission(tenant_id, user_id, "cms:article:read", "token") .require_permission(tenant_id, user_id, "cms:article:edit", "token")
.await .await
.unwrap(); .unwrap();

68
tests/iam_client_expr.rs Normal file
View File

@@ -0,0 +1,68 @@
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use std::time::Duration;
use axum::{Json, Router, routing::post};
use axum::response::IntoResponse;
use cms_service::infrastructure::iam_client::{IamClient, IamClientConfig};
use serde_json::Value;
async fn start_mock_iam(call_count: Arc<AtomicUsize>) -> (String, tokio::task::JoinHandle<()>) {
let app = Router::new().route(
"/api/v1/authorize/check-expr",
post(move |Json(body): Json<Value>| {
let call_count = call_count.clone();
async move {
call_count.fetch_add(1, Ordering::SeqCst);
let allowed = body.get("expr").is_some();
let resp = serde_json::json!({
"code": 0,
"message": "ok",
"data": { "allowed": allowed }
});
(axum::http::StatusCode::OK, Json(resp)).into_response()
}
}),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let base_url = format!("http://{}", addr);
let handle = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
(base_url, handle)
}
#[tokio::test]
async fn iam_client_check_expr_hits_endpoint() {
let call_count = Arc::new(AtomicUsize::new(0));
let (base_url, handle) = start_mock_iam(call_count.clone()).await;
let client = IamClient::new(IamClientConfig {
base_url,
timeout: Duration::from_millis(500),
cache_ttl: Duration::from_secs(5),
cache_stale_if_error: Duration::from_secs(30),
cache_max_entries: 1000,
});
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
client
.require_any_permissions(
tenant_id,
user_id,
&["cms:article:edit", "cms:article:create"],
"token",
)
.await
.unwrap();
assert_eq!(call_count.load(Ordering::SeqCst), 1);
handle.abort();
}

View File

@@ -0,0 +1,40 @@
use serde::Deserialize;
#[derive(Debug, Deserialize)]
struct AppResponse<T> {
code: i32,
message: String,
data: Option<T>,
}
#[derive(Debug, Deserialize)]
struct RefreshResponseData {
access_token: String,
refresh_token: String,
token_type: Option<String>,
expires_in: usize,
}
#[test]
fn can_decode_iam_refresh_response_snake_case() {
let json = r#"
{
"code": 0,
"message": "ok",
"data": {
"access_token": "a",
"refresh_token": "r",
"token_type": "Bearer",
"expires_in": 7200
}
}
"#;
let parsed: AppResponse<RefreshResponseData> = serde_json::from_str(json).unwrap();
let data = parsed.data.unwrap();
assert_eq!(data.access_token, "a");
assert_eq!(data.refresh_token, "r");
assert_eq!(data.token_type.as_deref(), Some("Bearer"));
assert_eq!(data.expires_in, 7200);
}