diff --git a/.env.example b/.env.example index dea9cfa..f04694f 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,8 @@ DATABASE_URL=postgres://cms_service_user:cms_service_password@127.0.0.1:5432/cms DB_MAX_CONNECTIONS=20 DB_MIN_CONNECTIONS=5 +RUN_MIGRATIONS=0 + IAM_BASE_URL=http://127.0.0.1:3000 IAM_JWKS_URL= JWT_PUBLIC_KEY_PEM= @@ -17,3 +19,10 @@ IAM_TIMEOUT_MS=2000 IAM_CACHE_TTL_SECONDS=10 IAM_STALE_IF_ERROR_SECONDS=60 IAM_CACHE_MAX_ENTRIES=50000 + +# SSO:业务服务端回调(/auth/callback)用到的 IAM client 凭证 +IAM_CLIENT_ID=cms +IAM_CLIENT_SECRET=please_replace_with_client_secret + +# SSO:回调处理完成后跳回 CMS 前端的基准地址(用于限制 next 的 open redirect) +CMS_FRONT_BASE_URL=https://cms.example.com diff --git a/Cargo.lock b/Cargo.lock index f691e64..8ed6375 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -280,6 +280,7 @@ dependencies = [ "tower", "tracing", "tracing-subscriber", + "urlencoding", "utoipa", "utoipa-scalar", "uuid", @@ -2832,6 +2833,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 299b141..a8c9fe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,3 +50,4 @@ reqwest = { version = "0.12", default-features = false, features = [ "json", "rustls-tls", ] } +urlencoding = "2.1.3" diff --git a/README.md b/README.md index 32c34e3..d5b9e27 100644 --- a/README.md +++ b/README.md @@ -36,13 +36,17 @@ DDD 分层目录: 1. 复制并修改环境变量: - `cp .env.example .env` 2. 准备 PostgreSQL 并配置 `DATABASE_URL` -3. 启动服务(会自动运行 migrations): +3. 执行数据库迁移(推荐使用脚本体系): + - `./scripts/db/migrate.sh` + - `./scripts/db/verify.sh` +4. 启动服务: - `cargo run` ## 文档 - Scalar:`GET /scalar` - 健康检查:`GET /healthz` +- SSO 回调(code → token → Set-Cookie):`GET /auth/callback?code=...&next=...` ## API(v1) @@ -95,7 +99,8 @@ CMS 运行时依赖 IAM 提供以下能力: ## 数据库迁移 - 迁移文件目录: [migrations](file:///home/shay/project/backend/cms-service/migrations) -- 启动时自动执行:见 [db::run_migrations](file:///home/shay/project/backend/cms-service/src/infrastructure/db/mod.rs#L14-L16) +- 启动时默认不执行迁移;如需在本地启动时自动执行,设置 `RUN_MIGRATIONS=1` +- 服务内迁移入口:见 [db::run_migrations](file:///home/shay/project/backend/cms-service/src/infrastructure/db/mod.rs#L14-L16) - 运维脚本(migrate/verify/rollback):见 [scripts/db/README.md](file:///home/shay/project/backend/cms-service/scripts/db/README.md) ## 测试 diff --git a/src/api/handlers/auth.rs b/src/api/handlers/auth.rs new file mode 100644 index 0000000..0d0eb2c --- /dev/null +++ b/src/api/handlers/auth.rs @@ -0,0 +1,213 @@ +use axum::{ + Router, + extract::Query, + http::{HeaderValue, header}, + response::{IntoResponse, Redirect}, + routing::get, +}; +use common_telemetry::AppError; +use serde::Deserialize; + +use crate::api::AppState; + +#[derive(Debug, Deserialize)] +pub struct CallbackQuery { + pub code: String, + pub next: Option, +} + +#[derive(Debug, Deserialize, serde::Serialize)] +#[serde(rename_all = "camelCase")] +struct Code2TokenRequest { + code: String, + client_id: String, + client_secret: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct Code2TokenData { + access_token: String, + refresh_token: String, + expires_in: usize, + token_type: String, + tenant_id: String, + user_id: String, +} + +#[derive(Debug, Deserialize)] +struct AppResponse { + code: i32, + message: String, + data: Option, +} + +pub fn router() -> Router { + Router::new().route("/callback", get(sso_callback_handler)) +} + +fn is_https(headers: &axum::http::HeaderMap) -> bool { + headers + .get("x-forwarded-proto") + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("https")) + .unwrap_or(false) +} + +fn cookie_header( + name: &str, + value: &str, + secure: bool, + http_only: bool, + max_age: Option, +) -> String { + let mut s = format!("{}={}; Path=/; SameSite=Strict", name, value); + if secure { + s.push_str("; Secure"); + } + if http_only { + s.push_str("; HttpOnly"); + } + if let Some(v) = max_age { + s.push_str(&format!("; Max-Age={}", v)); + } + s +} + +fn resolve_front_redirect(next: Option) -> String { + let base = std::env::var("CMS_FRONT_BASE_URL").ok(); + + let Some(raw) = next else { + return base.unwrap_or_else(|| "/".to_string()); + }; + + if raw.starts_with('/') { + return raw; + } + + if raw.starts_with("https://") { + if let Some(base) = base { + if raw.starts_with(&base) { + return raw; + } + return base; + } + return raw; + } + + base.unwrap_or_else(|| "/".to_string()) +} + +fn resolve_front_error_redirect(message: &str) -> String { + let base = std::env::var("CMS_FRONT_BASE_URL").ok(); + let encoded = urlencoding::encode(message); + if let Some(base) = base { + format!( + "{}/auth-error?message={}", + base.trim_end_matches('/'), + encoded + ) + } else { + format!("/auth-error?message={}", encoded) + } +} + +async fn sso_callback_handler( + headers: axum::http::HeaderMap, + Query(q): Query, +) -> Result { + if q.code.trim().is_empty() { + let target = resolve_front_error_redirect("missing code"); + return Ok(Redirect::temporary(&target).into_response()); + } + + let iam_base = std::env::var("IAM_BASE_URL") + .or_else(|_| std::env::var("IAM_SERVICE_BASE_URL")) + .map_err(|_| AppError::ConfigError("IAM_BASE_URL is required".into()))?; + let client_id = std::env::var("IAM_CLIENT_ID") + .map_err(|_| AppError::ConfigError("IAM_CLIENT_ID is required".into()))?; + let client_secret = std::env::var("IAM_CLIENT_SECRET") + .map_err(|_| AppError::ConfigError("IAM_CLIENT_SECRET is required".into()))?; + + let http = reqwest::Client::new(); + let resp = http + .post(format!( + "{}/iam/api/v1/auth/code2token", + iam_base.trim_end_matches('/') + )) + .json(&Code2TokenRequest { + code: q.code, + client_id, + client_secret, + }) + .send() + .await + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?; + + let status = resp.status(); + let body = resp + .json::>() + .await + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?; + + if !status.is_success() || body.code != 0 { + let target = resolve_front_error_redirect(&body.message); + return Ok(Redirect::temporary(&target).into_response()); + } + + let Some(data) = body.data else { + let target = resolve_front_error_redirect("invalid code2token response"); + return Ok(Redirect::temporary(&target).into_response()); + }; + + let target = resolve_front_redirect(q.next); + let secure = is_https(&headers); + let mut res = Redirect::temporary(&target).into_response(); + let refresh_max_age = 30_u64 * 24 * 60 * 60; + res.headers_mut().append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie_header( + "accessToken", + &data.access_token, + secure, + true, + Some(data.expires_in as u64), + )) + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?, + ); + res.headers_mut().append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie_header( + "refreshToken", + &data.refresh_token, + secure, + true, + Some(refresh_max_age), + )) + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?, + ); + res.headers_mut().append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie_header( + "tenantId", + &data.tenant_id, + secure, + true, + Some(refresh_max_age), + )) + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?, + ); + res.headers_mut().append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie_header( + "userId", + &data.user_id, + secure, + true, + Some(refresh_max_age), + )) + .map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?, + ); + + Ok(res) +} diff --git a/src/api/handlers/mod.rs b/src/api/handlers/mod.rs index 3f9d4bd..ce63085 100644 --- a/src/api/handlers/mod.rs +++ b/src/api/handlers/mod.rs @@ -1,4 +1,5 @@ pub mod article; +pub mod auth; pub mod column; pub mod common; pub mod media; diff --git a/src/api/mod.rs b/src/api/mod.rs index 1968205..9c61478 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -21,6 +21,8 @@ pub struct AppState { pub fn build_router(state: AppState) -> Router { let health = Router::new().route("/healthz", get(|| async { axum::http::StatusCode::OK })); + let auth = Router::new().nest("/auth", handlers::auth::router()); + let v1 = Router::new() .nest("/columns", handlers::column::router()) .nest("/tags", handlers::tag::router()) @@ -31,6 +33,7 @@ pub fn build_router(state: AppState) -> Router { .route("/favicon.ico", get(|| async { axum::http::StatusCode::NO_CONTENT })) .merge(Scalar::with_url("/scalar", ApiDoc::openapi())) .merge(health) + .merge(auth) .nest("/v1", v1) .layer(axum::middleware::from_fn(catch_panic)) .layer(axum::middleware::from_fn(request_logger)) diff --git a/src/main.rs b/src/main.rs index 9f0279e..de16f95 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,9 +24,15 @@ async fn main() { }); let pool = db::init_pool(&config).await.expect("failed to init db pool"); - db::run_migrations(&pool) - .await - .expect("failed to run migrations"); + let run_migrations = std::env::var("RUN_MIGRATIONS") + .ok() + .map(|v| matches!(v.as_str(), "1" | "true" | "TRUE")) + .unwrap_or(false); + if run_migrations { + db::run_migrations(&pool) + .await + .expect("failed to run migrations"); + } let state = AppState { services: CmsServices::new(pool), @@ -40,7 +46,7 @@ async fn main() { }; let auth_cfg = AuthMiddlewareConfig { - skip_exact_paths: vec!["/healthz".to_string()], + skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()], skip_path_prefixes: vec!["/scalar".to_string()], jwt: match &config.jwt_public_key_pem { Some(pem) => auth_kit::jwt::JwtVerifyConfig::rs256_from_pem("iam-service", pem) @@ -58,7 +64,7 @@ async fn main() { }, }; let tenant_cfg = TenantMiddlewareConfig { - skip_exact_paths: vec!["/healthz".to_string()], + skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()], skip_path_prefixes: vec!["/scalar".to_string()], };