commit ccead7e9f61fe78110c36240fac38df73d2e4e8b Author: shay7sev Date: Mon Feb 2 14:30:53 2026 +0800 feat(project): init diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..b7c4f6d --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,5 @@ +[registries.kellnr] +index = "sparse+https://kellnr.shay7sev.site/api/v1/crates/" + +[net] +git-fetch-with-cli = true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..869df07 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..7d9cf48 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "auth-kit" +version = "0.1.0" +edition = "2024" + +[dependencies] +common-telemetry = { version = "0.1.5", registry = "kellnr", default-features = false, features = [ + "response", + "telemetry", + "with-anyhow", + "with-sqlx", +] } + +axum = "0.8.8" +http = "1.4.0" +jsonwebtoken = { version = "10.3.0", features = ["aws_lc_rs"] } +dashmap = "6.1.0" +reqwest = { version = "0.12.28", default-features = false, features = [ + "json", + "rustls-tls", +] } +base64 = "0.22.1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tracing = "0.1" +uuid = { version = "1", features = ["serde", "v4"] } + +[dev-dependencies] +rsa = "0.9.10" +tokio = { version = "1", features = ["full"] } diff --git a/README.md b/README.md new file mode 100644 index 0000000..b18aee2 --- /dev/null +++ b/README.md @@ -0,0 +1,124 @@ +# auth-kit + +`auth-kit` 是一个可复用的认证/多租户中间件库(Rust + Axum),用于在多个微服务中统一实现: + +- JWT 认证:从 `Authorization: Bearer ` 解析并验证 token,注入 `AuthContext` +- 租户隔离:从 `X-Tenant-ID` 解析租户并注入 `TenantId`,并在同时存在 token 与 header 时强制一致性 + +该 crate 不包含任何业务逻辑、数据库访问或 RBAC 聚合,仅提供“请求上下文注入 + 类型提取器 + JWT 验证能力”。 + +## 依赖与约定 + +- Web 框架:Axum +- 统一错误类型:`common_telemetry::AppError` +- 头部约定: + - `Authorization: Bearer ` + - `X-Tenant-ID: ` + +## 快速开始(在 Axum 中挂载) + +推荐链路顺序: + +1. `authenticate_with_config`(注入 `AuthContext`) +2. `resolve_tenant_with_config`(注入 `TenantId`,并校验 header tenant 与 token tenant 一致) + +示例: + +```rust +use axum::{Router, middleware::from_fn_with_state}; +use auth_kit::middleware::{ + auth::{self, AuthMiddlewareConfig}, + tenant::{self, TenantMiddlewareConfig}, +}; + +let auth_cfg = AuthMiddlewareConfig { + skip_exact_paths: vec!["/healthz".to_string()], + skip_path_prefixes: vec!["/scalar".to_string()], + jwt: auth_kit::jwt::JwtVerifyConfig::rs256_from_jwks( + "iam-service", + "http://127.0.0.1:3000/.well-known/jwks.json", + )?, +}; + +let tenant_cfg = TenantMiddlewareConfig { + skip_exact_paths: vec!["/healthz".to_string()], + skip_path_prefixes: vec!["/scalar".to_string()], +}; + +let app = Router::new() + .layer(from_fn_with_state(tenant_cfg, tenant::resolve_tenant_with_config)) + .layer(from_fn_with_state(auth_cfg, auth::authenticate_with_config)); +``` + +## Handler 中如何使用(类型注入) + +### AuthContext + +`AuthContext` 会从 request extensions 中提取,如果中间件未运行或缺失认证信息,会返回 `AppError::MissingAuthHeader`。 + +```rust +use auth_kit::middleware::auth::AuthContext; + +pub async fn handler(AuthContext { user_id, tenant_id, .. }: AuthContext) { + // ... +} +``` + +### TenantId + +`TenantId` 可从 request extensions 中提取;若不存在,会尝试从 `X-Tenant-ID` 解析;缺失则返回 `AppError::BadRequest("Missing X-Tenant-ID header")`。 + +```rust +use auth_kit::middleware::tenant::TenantId; + +pub async fn handler(TenantId(tenant_id): TenantId) { + // ... +} +``` + +## JWT 验证配置(JwtVerifyConfig) + +`JwtVerifyConfig` 定义在 [jwt.rs](file:///home/shay/project/backend/auth-kit/src/jwt.rs)。 + +- **静态公钥**:`rs256_from_pem(issuer, public_key_pem)` +- **JWKS 拉取**:`rs256_from_jwks(issuer, jwks_url)`(带缓存与降级) +- **对称密钥(调试/兼容)**:`hs256_from_secret(issuer, secret)` + +注意点: + +- RS256 模式要求 token header 中包含 `kid`(用于从 JWKS 中选 key) +- issuer 必须与 token 中 `iss` 一致 +- JWKS 目前只支持 RSA + `use=sig` + `alg=RS256` + +### JWKS 缓存策略(当前实现) + +- HTTP 超时:1500ms +- 缓存 TTL:300s +- stale-if-error:3600s(JWKS 拉取失败时可在窗口内使用缓存 key 继续验签) + +## Skip 规则(免鉴权路径) + +`AuthMiddlewareConfig` / `TenantMiddlewareConfig` 都支持: + +- `skip_exact_paths`:精确路径匹配 +- `skip_path_prefixes`:前缀匹配 + +用于跳过如 `/healthz`、`/scalar`、`/.well-known/jwks.json` 等公开端点。 + +## 端到端验证(推荐) + +本库自带一个集成测试,用于验证 RS256 + JWKS 的验签路径: + +- 测试文件: [jwks_verify.rs](file:///home/shay/project/backend/auth-kit/tests/jwks_verify.rs) +- 运行: + +```bash +cd /home/shay/project/backend/auth-kit +cargo test +``` + +## 已知限制 + +- 暂未提供“从环境变量自动组装配置”的辅助函数(建议业务服务自行装配 `AuthMiddlewareConfig`) +- 暂未实现 ECDSA(ES256)/EdDSA(Ed25519)JWKS 解析 + diff --git a/src/jwt.rs b/src/jwt.rs new file mode 100644 index 0000000..f460311 --- /dev/null +++ b/src/jwt.rs @@ -0,0 +1,233 @@ +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; +use common_telemetry::AppError; +use dashmap::DashMap; +use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header}; +use serde::{Deserialize, Serialize}; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Claims { + pub sub: String, + pub tenant_id: String, + pub exp: usize, + pub iat: usize, + pub iss: String, + #[serde(default)] + pub roles: Vec, + #[serde(default)] + pub permissions: Vec, + #[serde(default)] + pub apps: Vec, + #[serde(default)] + pub apps_version: i32, +} + +#[derive(Clone)] +pub enum JwtVerifyConfig { + Static(StaticVerifyConfig), + Jwks(JwksVerifyConfig), +} + +#[derive(Clone)] +pub struct StaticVerifyConfig { + decoding_key: DecodingKey, + validation: Validation, +} + +#[derive(Clone)] +pub struct JwksVerifyConfig { + issuer: String, + jwks_url: String, + client: reqwest::Client, + cache: Arc>, + cache_ttl: Duration, + stale_if_error: Duration, +} + +impl JwtVerifyConfig { + pub fn rs256_from_pem(issuer: &str, public_key_pem: &str) -> Result { + let decoding_key = DecodingKey::from_rsa_pem(public_key_pem.as_bytes()) + .map_err(|e| AppError::ConfigError(format!("Invalid JWT public key pem: {}", e)))?; + let mut validation = Validation::new(Algorithm::RS256); + validation.set_issuer(&[issuer]); + Ok(Self::Static(StaticVerifyConfig { + decoding_key, + validation, + })) + } + + pub fn hs256_from_secret(issuer: &str, secret: &str) -> Self { + let decoding_key = DecodingKey::from_secret(secret.as_bytes()); + let mut validation = Validation::new(Algorithm::HS256); + validation.set_issuer(&[issuer]); + Self::Static(StaticVerifyConfig { + decoding_key, + validation, + }) + } + + pub fn rs256_from_jwks(issuer: &str, jwks_url: &str) -> Result { + let client = reqwest::Client::builder() + .timeout(Duration::from_millis(1500)) + .build() + .map_err(|e| AppError::ConfigError(format!("Failed to build http client: {}", e)))?; + Ok(Self::Jwks(JwksVerifyConfig { + issuer: issuer.to_string(), + jwks_url: jwks_url.to_string(), + client, + cache: Arc::new(DashMap::new()), + cache_ttl: Duration::from_secs(300), + stale_if_error: Duration::from_secs(3600), + })) + } +} + +#[derive(Clone)] +struct CachedJwk { + n: String, + e: String, + expires_at: Instant, + stale_until: Instant, +} + +#[derive(Debug, Deserialize)] +struct Jwks { + keys: Vec, +} + +#[derive(Debug, Deserialize)] +struct Jwk { + kty: String, + kid: Option, + #[serde(rename = "use")] + use_field: Option, + alg: Option, + n: Option, + e: Option, +} + +pub async fn verify(token: &str, cfg: &JwtVerifyConfig) -> Result { + match cfg { + JwtVerifyConfig::Static(static_cfg) => { + let token_data = + decode::(token, &static_cfg.decoding_key, &static_cfg.validation) + .map_err(|e| AppError::AuthError(e.to_string()))?; + Ok(token_data.claims) + } + JwtVerifyConfig::Jwks(jwks_cfg) => { + let header = decode_header(token).map_err(|e| AppError::AuthError(e.to_string()))?; + let kid = header + .kid + .ok_or_else(|| AppError::AuthError("Missing kid in JWT header".into()))?; + + let now = Instant::now(); + if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) { + if entry.expires_at > now { + return verify_with_rsa_components(token, &jwks_cfg.issuer, &entry.n, &entry.e); + } + } + + let fetched = fetch_jwk_by_kid(&jwks_cfg.client, &jwks_cfg.jwks_url, &kid).await; + match fetched { + Ok((n, e)) => { + let entry = CachedJwk { + n: n.clone(), + e: e.clone(), + expires_at: now + jwks_cfg.cache_ttl, + stale_until: now + jwks_cfg.cache_ttl + jwks_cfg.stale_if_error, + }; + jwks_cfg.cache.insert(kid.clone(), entry); + verify_with_rsa_components(token, &jwks_cfg.issuer, &n, &e) + } + Err(err) => { + if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) { + if entry.stale_until > now { + return verify_with_rsa_components( + token, + &jwks_cfg.issuer, + &entry.n, + &entry.e, + ); + } + } + Err(err) + } + } + } + } +} + +fn verify_with_rsa_components( + token: &str, + issuer: &str, + n: &str, + e: &str, +) -> Result { + let decoding_key = DecodingKey::from_rsa_components(n, e) + .map_err(|e| AppError::AuthError(format!("Invalid jwk rsa components: {}", e)))?; + let mut validation = Validation::new(Algorithm::RS256); + validation.set_issuer(&[issuer]); + let token_data = decode::(token, &decoding_key, &validation) + .map_err(|e| AppError::AuthError(e.to_string()))?; + Ok(token_data.claims) +} + +async fn fetch_jwk_by_kid( + client: &reqwest::Client, + jwks_url: &str, + kid: &str, +) -> Result<(String, String), AppError> { + let resp = client + .get(jwks_url) + .send() + .await + .map_err(|e| AppError::ExternalReqError(format!("jwks:request_failed:{}", e)))?; + + if !resp.status().is_success() { + return Err(AppError::ExternalReqError(format!( + "jwks:unexpected_status:{}", + resp.status().as_u16() + ))); + } + + let jwks: Jwks = resp + .json() + .await + .map_err(|e| AppError::ExternalReqError(format!("jwks:decode_failed:{}", e)))?; + + let key = jwks + .keys + .into_iter() + .find(|k| k.kid.as_deref() == Some(kid)) + .ok_or_else(|| AppError::ExternalReqError("jwks:kid_not_found".into()))?; + + if key.kty != "RSA" { + return Err(AppError::ExternalReqError("jwks:unsupported_kty".into())); + } + if key.use_field.as_deref() != Some("sig") { + return Err(AppError::ExternalReqError("jwks:unsupported_use".into())); + } + if let Some(alg) = &key.alg { + if alg != "RS256" { + return Err(AppError::ExternalReqError("jwks:unsupported_alg".into())); + } + } + + let n = key + .n + .ok_or_else(|| AppError::ExternalReqError("jwks:missing_n".into()))?; + let e = key + .e + .ok_or_else(|| AppError::ExternalReqError("jwks:missing_e".into()))?; + + if URL_SAFE_NO_PAD.decode(n.as_bytes()).is_err() + || URL_SAFE_NO_PAD.decode(e.as_bytes()).is_err() + { + return Err(AppError::ExternalReqError("jwks:invalid_base64url".into())); + } + + Ok((n, e)) +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0f2afcb --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +pub mod jwt; +pub mod middleware; diff --git a/src/middleware/auth.rs b/src/middleware/auth.rs new file mode 100644 index 0000000..fbe2167 --- /dev/null +++ b/src/middleware/auth.rs @@ -0,0 +1,104 @@ +use axum::{ + extract::{FromRequestParts, Request, State}, + http::request::Parts, + middleware::Next, + response::Response, +}; +use common_telemetry::AppError; +use uuid::Uuid; + +use crate::jwt::{Claims, JwtVerifyConfig}; + +#[derive(Clone)] +pub struct AuthMiddlewareConfig { + pub skip_exact_paths: Vec, + pub skip_path_prefixes: Vec, + pub jwt: JwtVerifyConfig, +} + +impl AuthMiddlewareConfig { + pub fn should_skip(&self, path: &str) -> bool { + self.skip_exact_paths.iter().any(|p| p == path) + || self + .skip_path_prefixes + .iter() + .any(|prefix| path.starts_with(prefix)) + } +} + +#[derive(Clone, Debug)] +pub struct AuthContext { + pub tenant_id: Uuid, + pub user_id: Uuid, + pub roles: Vec, + pub permissions: Vec, + pub apps: Vec, + pub apps_version: i32, +} + +fn claims_to_context(claims: Claims) -> Result { + let tenant_id = Uuid::parse_str(&claims.tenant_id) + .map_err(|_| AppError::AuthError("Invalid tenant_id claim".into()))?; + let user_id = Uuid::parse_str(&claims.sub) + .map_err(|_| AppError::AuthError("Invalid sub claim".into()))?; + + Ok(AuthContext { + tenant_id, + user_id, + roles: claims.roles, + permissions: claims.permissions, + apps: claims.apps, + apps_version: claims.apps_version, + }) +} + +async fn authenticate_inner( + cfg: &AuthMiddlewareConfig, + mut req: Request, + next: Next, +) -> Result { + let path = req.uri().path(); + if cfg.should_skip(path) { + return Ok(next.run(req).await); + } + + let token = req + .headers() + .get(axum::http::header::AUTHORIZATION) + .and_then(|h| h.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + .ok_or(AppError::MissingAuthHeader)?; + + let claims = crate::jwt::verify(token, &cfg.jwt).await?; + let ctx = claims_to_context(claims)?; + + tracing::Span::current().record("tenant_id", tracing::field::display(ctx.tenant_id)); + tracing::Span::current().record("user_id", tracing::field::display(ctx.user_id)); + + req.extensions_mut().insert(ctx); + + Ok(next.run(req).await) +} + +pub async fn authenticate_with_config( + State(cfg): State, + req: Request, + next: Next, +) -> Result { + authenticate_inner(&cfg, req, next).await +} + +impl FromRequestParts for AuthContext +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(AppError::MissingAuthHeader) + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs new file mode 100644 index 0000000..c10527d --- /dev/null +++ b/src/middleware/mod.rs @@ -0,0 +1,2 @@ +pub mod auth; +pub mod tenant; diff --git a/src/middleware/tenant.rs b/src/middleware/tenant.rs new file mode 100644 index 0000000..83f5fe2 --- /dev/null +++ b/src/middleware/tenant.rs @@ -0,0 +1,105 @@ +use axum::{ + extract::{FromRequestParts, Request, State}, + middleware::Next, + response::Response, +}; +use common_telemetry::AppError; +use http::request::Parts; +use uuid::Uuid; + +use crate::middleware::auth::AuthContext; + +#[derive(Clone, Debug)] +pub struct TenantId(pub Uuid); + +#[derive(Clone)] +pub struct TenantMiddlewareConfig { + pub skip_exact_paths: Vec, + pub skip_path_prefixes: Vec, +} + +impl TenantMiddlewareConfig { + pub fn should_skip(&self, path: &str) -> bool { + self.skip_exact_paths.iter().any(|p| p == path) + || self + .skip_path_prefixes + .iter() + .any(|prefix| path.starts_with(prefix)) + } +} + +async fn resolve_tenant_inner( + cfg: &TenantMiddlewareConfig, + mut req: Request, + next: Next, +) -> Result { + let path = req.uri().path(); + if cfg.should_skip(path) { + return Ok(next.run(req).await); + } + + if let Some(auth_tenant_id) = req.extensions().get::().map(|c| c.tenant_id) { + if let Some(header_value) = req + .headers() + .get("X-Tenant-ID") + .and_then(|val| val.to_str().ok()) + { + let header_tenant_id = Uuid::parse_str(header_value) + .map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?; + if header_tenant_id != auth_tenant_id { + return Err(AppError::PermissionDenied("tenant:mismatch".into())); + } + } + tracing::Span::current().record("tenant_id", tracing::field::display(auth_tenant_id)); + req.extensions_mut().insert(TenantId(auth_tenant_id)); + return Ok(next.run(req).await); + } + + let tenant_id_str = req + .headers() + .get("X-Tenant-ID") + .and_then(|val| val.to_str().ok()); + + match tenant_id_str { + Some(id_str) => { + let uuid = Uuid::parse_str(id_str) + .map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?; + tracing::Span::current().record("tenant_id", tracing::field::display(uuid)); + req.extensions_mut().insert(TenantId(uuid)); + Ok(next.run(req).await) + } + None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())), + } +} + +pub async fn resolve_tenant_with_config( + State(cfg): State, + req: Request, + next: Next, +) -> Result { + resolve_tenant_inner(&cfg, req, next).await +} + +impl FromRequestParts for TenantId +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + if let Some(tid) = parts.extensions.get::().cloned() { + return Ok(tid); + } + let tenant_id_str = parts + .headers + .get("X-Tenant-ID") + .and_then(|val| val.to_str().ok()); + + match tenant_id_str { + Some(id_str) => uuid::Uuid::parse_str(id_str) + .map(TenantId) + .map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into())), + None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())), + } + } +} diff --git a/tests/jwks_verify.rs b/tests/jwks_verify.rs new file mode 100644 index 0000000..235d02d --- /dev/null +++ b/tests/jwks_verify.rs @@ -0,0 +1,111 @@ +use auth_kit::jwt::{Claims, JwtVerifyConfig}; +use axum::response::IntoResponse; +use axum::{Json, Router, routing::get}; +use base64::Engine as _; +use jsonwebtoken::{Algorithm, EncodingKey, Header, encode}; +use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey}; +use rsa::rand_core::OsRng; +use rsa::traits::PublicKeyParts; +use rsa::{RsaPrivateKey, pkcs1::LineEnding}; +use serde::Serialize; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +#[derive(Serialize, Clone)] +struct Jwks { + keys: Vec, +} + +#[derive(Serialize, Clone)] +struct Jwk { + kty: &'static str, + kid: &'static str, + #[serde(rename = "use")] + use_field: &'static str, + alg: &'static str, + n: String, + e: String, +} + +fn base64url_no_pad(data: &[u8]) -> String { + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data) +} + +#[tokio::test] +async fn verify_rs256_via_jwks() { + let private = RsaPrivateKey::new(&mut OsRng, 2048).unwrap(); + let public = private.to_public_key(); + + let private_pem = private.to_pkcs1_pem(LineEnding::LF).unwrap().to_string(); + let public_pem = public.to_pkcs1_pem(LineEnding::LF).unwrap().to_string(); + + let n = base64url_no_pad(&public.n().to_bytes_be()); + let e = base64url_no_pad(&public.e().to_bytes_be()); + + let kid = "test-kid"; + let jwks = Jwks { + keys: vec![Jwk { + kty: "RSA", + kid, + use_field: "sig", + alg: "RS256", + n, + e, + }], + }; + + let app = Router::new().route( + "/.well-known/jwks.json", + get(move || { + let jwks = jwks; + async move { (axum::http::StatusCode::OK, Json(jwks)).into_response() } + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let base_url = format!("http://{}", addr); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as usize; + + let claims = Claims { + sub: uuid::Uuid::new_v4().to_string(), + tenant_id: uuid::Uuid::new_v4().to_string(), + exp: now + 60, + iat: now, + iss: "iam-service".to_string(), + roles: vec![], + permissions: vec![], + apps: vec![], + apps_version: 0, + }; + + let mut header = Header::new(Algorithm::RS256); + header.kid = Some(kid.to_string()); + let token = encode( + &header, + &claims, + &EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap(), + ) + .unwrap(); + + let cfg = JwtVerifyConfig::rs256_from_jwks( + "iam-service", + &format!("{}/.well-known/jwks.json", base_url), + ) + .unwrap(); + let verified = auth_kit::jwt::verify(&token, &cfg).await.unwrap(); + assert_eq!(verified.tenant_id, claims.tenant_id); + + let cfg2 = JwtVerifyConfig::rs256_from_pem("iam-service", &public_pem).unwrap(); + let verified2 = auth_kit::jwt::verify(&token, &cfg2).await.unwrap(); + assert_eq!(verified2.sub, claims.sub); + + tokio::time::sleep(Duration::from_millis(10)).await; + handle.abort(); +}