feat(project): init
This commit is contained in:
5
.cargo/config.toml
Normal file
5
.cargo/config.toml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
[registries.kellnr]
|
||||||
|
index = "sparse+https://kellnr.shay7sev.site/api/v1/crates/"
|
||||||
|
|
||||||
|
[net]
|
||||||
|
git-fetch-with-cli = true
|
||||||
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
/target
|
||||||
|
Cargo.lock
|
||||||
30
Cargo.toml
Normal file
30
Cargo.toml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
[package]
|
||||||
|
name = "auth-kit"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
common-telemetry = { version = "0.1.5", registry = "kellnr", default-features = false, features = [
|
||||||
|
"response",
|
||||||
|
"telemetry",
|
||||||
|
"with-anyhow",
|
||||||
|
"with-sqlx",
|
||||||
|
] }
|
||||||
|
|
||||||
|
axum = "0.8.8"
|
||||||
|
http = "1.4.0"
|
||||||
|
jsonwebtoken = { version = "10.3.0", features = ["aws_lc_rs"] }
|
||||||
|
dashmap = "6.1.0"
|
||||||
|
reqwest = { version = "0.12.28", default-features = false, features = [
|
||||||
|
"json",
|
||||||
|
"rustls-tls",
|
||||||
|
] }
|
||||||
|
base64 = "0.22.1"
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
tracing = "0.1"
|
||||||
|
uuid = { version = "1", features = ["serde", "v4"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rsa = "0.9.10"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
124
README.md
Normal file
124
README.md
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# auth-kit
|
||||||
|
|
||||||
|
`auth-kit` 是一个可复用的认证/多租户中间件库(Rust + Axum),用于在多个微服务中统一实现:
|
||||||
|
|
||||||
|
- JWT 认证:从 `Authorization: Bearer <token>` 解析并验证 token,注入 `AuthContext`
|
||||||
|
- 租户隔离:从 `X-Tenant-ID` 解析租户并注入 `TenantId`,并在同时存在 token 与 header 时强制一致性
|
||||||
|
|
||||||
|
该 crate 不包含任何业务逻辑、数据库访问或 RBAC 聚合,仅提供“请求上下文注入 + 类型提取器 + JWT 验证能力”。
|
||||||
|
|
||||||
|
## 依赖与约定
|
||||||
|
|
||||||
|
- Web 框架:Axum
|
||||||
|
- 统一错误类型:`common_telemetry::AppError`
|
||||||
|
- 头部约定:
|
||||||
|
- `Authorization: Bearer <access_token>`
|
||||||
|
- `X-Tenant-ID: <tenant_uuid>`
|
||||||
|
|
||||||
|
## 快速开始(在 Axum 中挂载)
|
||||||
|
|
||||||
|
推荐链路顺序:
|
||||||
|
|
||||||
|
1. `authenticate_with_config`(注入 `AuthContext`)
|
||||||
|
2. `resolve_tenant_with_config`(注入 `TenantId`,并校验 header tenant 与 token tenant 一致)
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use axum::{Router, middleware::from_fn_with_state};
|
||||||
|
use auth_kit::middleware::{
|
||||||
|
auth::{self, AuthMiddlewareConfig},
|
||||||
|
tenant::{self, TenantMiddlewareConfig},
|
||||||
|
};
|
||||||
|
|
||||||
|
let auth_cfg = AuthMiddlewareConfig {
|
||||||
|
skip_exact_paths: vec!["/healthz".to_string()],
|
||||||
|
skip_path_prefixes: vec!["/scalar".to_string()],
|
||||||
|
jwt: auth_kit::jwt::JwtVerifyConfig::rs256_from_jwks(
|
||||||
|
"iam-service",
|
||||||
|
"http://127.0.0.1:3000/.well-known/jwks.json",
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let tenant_cfg = TenantMiddlewareConfig {
|
||||||
|
skip_exact_paths: vec!["/healthz".to_string()],
|
||||||
|
skip_path_prefixes: vec!["/scalar".to_string()],
|
||||||
|
};
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.layer(from_fn_with_state(tenant_cfg, tenant::resolve_tenant_with_config))
|
||||||
|
.layer(from_fn_with_state(auth_cfg, auth::authenticate_with_config));
|
||||||
|
```
|
||||||
|
|
||||||
|
## Handler 中如何使用(类型注入)
|
||||||
|
|
||||||
|
### AuthContext
|
||||||
|
|
||||||
|
`AuthContext` 会从 request extensions 中提取,如果中间件未运行或缺失认证信息,会返回 `AppError::MissingAuthHeader`。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use auth_kit::middleware::auth::AuthContext;
|
||||||
|
|
||||||
|
pub async fn handler(AuthContext { user_id, tenant_id, .. }: AuthContext) {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### TenantId
|
||||||
|
|
||||||
|
`TenantId` 可从 request extensions 中提取;若不存在,会尝试从 `X-Tenant-ID` 解析;缺失则返回 `AppError::BadRequest("Missing X-Tenant-ID header")`。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use auth_kit::middleware::tenant::TenantId;
|
||||||
|
|
||||||
|
pub async fn handler(TenantId(tenant_id): TenantId) {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## JWT 验证配置(JwtVerifyConfig)
|
||||||
|
|
||||||
|
`JwtVerifyConfig` 定义在 [jwt.rs](file:///home/shay/project/backend/auth-kit/src/jwt.rs)。
|
||||||
|
|
||||||
|
- **静态公钥**:`rs256_from_pem(issuer, public_key_pem)`
|
||||||
|
- **JWKS 拉取**:`rs256_from_jwks(issuer, jwks_url)`(带缓存与降级)
|
||||||
|
- **对称密钥(调试/兼容)**:`hs256_from_secret(issuer, secret)`
|
||||||
|
|
||||||
|
注意点:
|
||||||
|
|
||||||
|
- RS256 模式要求 token header 中包含 `kid`(用于从 JWKS 中选 key)
|
||||||
|
- issuer 必须与 token 中 `iss` 一致
|
||||||
|
- JWKS 目前只支持 RSA + `use=sig` + `alg=RS256`
|
||||||
|
|
||||||
|
### JWKS 缓存策略(当前实现)
|
||||||
|
|
||||||
|
- HTTP 超时:1500ms
|
||||||
|
- 缓存 TTL:300s
|
||||||
|
- stale-if-error:3600s(JWKS 拉取失败时可在窗口内使用缓存 key 继续验签)
|
||||||
|
|
||||||
|
## Skip 规则(免鉴权路径)
|
||||||
|
|
||||||
|
`AuthMiddlewareConfig` / `TenantMiddlewareConfig` 都支持:
|
||||||
|
|
||||||
|
- `skip_exact_paths`:精确路径匹配
|
||||||
|
- `skip_path_prefixes`:前缀匹配
|
||||||
|
|
||||||
|
用于跳过如 `/healthz`、`/scalar`、`/.well-known/jwks.json` 等公开端点。
|
||||||
|
|
||||||
|
## 端到端验证(推荐)
|
||||||
|
|
||||||
|
本库自带一个集成测试,用于验证 RS256 + JWKS 的验签路径:
|
||||||
|
|
||||||
|
- 测试文件: [jwks_verify.rs](file:///home/shay/project/backend/auth-kit/tests/jwks_verify.rs)
|
||||||
|
- 运行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /home/shay/project/backend/auth-kit
|
||||||
|
cargo test
|
||||||
|
```
|
||||||
|
|
||||||
|
## 已知限制
|
||||||
|
|
||||||
|
- 暂未提供“从环境变量自动组装配置”的辅助函数(建议业务服务自行装配 `AuthMiddlewareConfig`)
|
||||||
|
- 暂未实现 ECDSA(ES256)/EdDSA(Ed25519)JWKS 解析
|
||||||
|
|
||||||
233
src/jwt.rs
Normal file
233
src/jwt.rs
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||||
|
use common_telemetry::AppError;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{
|
||||||
|
sync::Arc,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct Claims {
|
||||||
|
pub sub: String,
|
||||||
|
pub tenant_id: String,
|
||||||
|
pub exp: usize,
|
||||||
|
pub iat: usize,
|
||||||
|
pub iss: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub roles: Vec<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub permissions: Vec<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub apps: Vec<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub apps_version: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum JwtVerifyConfig {
|
||||||
|
Static(StaticVerifyConfig),
|
||||||
|
Jwks(JwksVerifyConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct StaticVerifyConfig {
|
||||||
|
decoding_key: DecodingKey,
|
||||||
|
validation: Validation,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct JwksVerifyConfig {
|
||||||
|
issuer: String,
|
||||||
|
jwks_url: String,
|
||||||
|
client: reqwest::Client,
|
||||||
|
cache: Arc<DashMap<String, CachedJwk>>,
|
||||||
|
cache_ttl: Duration,
|
||||||
|
stale_if_error: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JwtVerifyConfig {
|
||||||
|
pub fn rs256_from_pem(issuer: &str, public_key_pem: &str) -> Result<Self, AppError> {
|
||||||
|
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem.as_bytes())
|
||||||
|
.map_err(|e| AppError::ConfigError(format!("Invalid JWT public key pem: {}", e)))?;
|
||||||
|
let mut validation = Validation::new(Algorithm::RS256);
|
||||||
|
validation.set_issuer(&[issuer]);
|
||||||
|
Ok(Self::Static(StaticVerifyConfig {
|
||||||
|
decoding_key,
|
||||||
|
validation,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn hs256_from_secret(issuer: &str, secret: &str) -> Self {
|
||||||
|
let decoding_key = DecodingKey::from_secret(secret.as_bytes());
|
||||||
|
let mut validation = Validation::new(Algorithm::HS256);
|
||||||
|
validation.set_issuer(&[issuer]);
|
||||||
|
Self::Static(StaticVerifyConfig {
|
||||||
|
decoding_key,
|
||||||
|
validation,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rs256_from_jwks(issuer: &str, jwks_url: &str) -> Result<Self, AppError> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_millis(1500))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| AppError::ConfigError(format!("Failed to build http client: {}", e)))?;
|
||||||
|
Ok(Self::Jwks(JwksVerifyConfig {
|
||||||
|
issuer: issuer.to_string(),
|
||||||
|
jwks_url: jwks_url.to_string(),
|
||||||
|
client,
|
||||||
|
cache: Arc::new(DashMap::new()),
|
||||||
|
cache_ttl: Duration::from_secs(300),
|
||||||
|
stale_if_error: Duration::from_secs(3600),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CachedJwk {
|
||||||
|
n: String,
|
||||||
|
e: String,
|
||||||
|
expires_at: Instant,
|
||||||
|
stale_until: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct Jwks {
|
||||||
|
keys: Vec<Jwk>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct Jwk {
|
||||||
|
kty: String,
|
||||||
|
kid: Option<String>,
|
||||||
|
#[serde(rename = "use")]
|
||||||
|
use_field: Option<String>,
|
||||||
|
alg: Option<String>,
|
||||||
|
n: Option<String>,
|
||||||
|
e: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn verify(token: &str, cfg: &JwtVerifyConfig) -> Result<Claims, AppError> {
|
||||||
|
match cfg {
|
||||||
|
JwtVerifyConfig::Static(static_cfg) => {
|
||||||
|
let token_data =
|
||||||
|
decode::<Claims>(token, &static_cfg.decoding_key, &static_cfg.validation)
|
||||||
|
.map_err(|e| AppError::AuthError(e.to_string()))?;
|
||||||
|
Ok(token_data.claims)
|
||||||
|
}
|
||||||
|
JwtVerifyConfig::Jwks(jwks_cfg) => {
|
||||||
|
let header = decode_header(token).map_err(|e| AppError::AuthError(e.to_string()))?;
|
||||||
|
let kid = header
|
||||||
|
.kid
|
||||||
|
.ok_or_else(|| AppError::AuthError("Missing kid in JWT header".into()))?;
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) {
|
||||||
|
if entry.expires_at > now {
|
||||||
|
return verify_with_rsa_components(token, &jwks_cfg.issuer, &entry.n, &entry.e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let fetched = fetch_jwk_by_kid(&jwks_cfg.client, &jwks_cfg.jwks_url, &kid).await;
|
||||||
|
match fetched {
|
||||||
|
Ok((n, e)) => {
|
||||||
|
let entry = CachedJwk {
|
||||||
|
n: n.clone(),
|
||||||
|
e: e.clone(),
|
||||||
|
expires_at: now + jwks_cfg.cache_ttl,
|
||||||
|
stale_until: now + jwks_cfg.cache_ttl + jwks_cfg.stale_if_error,
|
||||||
|
};
|
||||||
|
jwks_cfg.cache.insert(kid.clone(), entry);
|
||||||
|
verify_with_rsa_components(token, &jwks_cfg.issuer, &n, &e)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) {
|
||||||
|
if entry.stale_until > now {
|
||||||
|
return verify_with_rsa_components(
|
||||||
|
token,
|
||||||
|
&jwks_cfg.issuer,
|
||||||
|
&entry.n,
|
||||||
|
&entry.e,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_with_rsa_components(
|
||||||
|
token: &str,
|
||||||
|
issuer: &str,
|
||||||
|
n: &str,
|
||||||
|
e: &str,
|
||||||
|
) -> Result<Claims, AppError> {
|
||||||
|
let decoding_key = DecodingKey::from_rsa_components(n, e)
|
||||||
|
.map_err(|e| AppError::AuthError(format!("Invalid jwk rsa components: {}", e)))?;
|
||||||
|
let mut validation = Validation::new(Algorithm::RS256);
|
||||||
|
validation.set_issuer(&[issuer]);
|
||||||
|
let token_data = decode::<Claims>(token, &decoding_key, &validation)
|
||||||
|
.map_err(|e| AppError::AuthError(e.to_string()))?;
|
||||||
|
Ok(token_data.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_jwk_by_kid(
|
||||||
|
client: &reqwest::Client,
|
||||||
|
jwks_url: &str,
|
||||||
|
kid: &str,
|
||||||
|
) -> Result<(String, String), AppError> {
|
||||||
|
let resp = client
|
||||||
|
.get(jwks_url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ExternalReqError(format!("jwks:request_failed:{}", e)))?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(AppError::ExternalReqError(format!(
|
||||||
|
"jwks:unexpected_status:{}",
|
||||||
|
resp.status().as_u16()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let jwks: Jwks = resp
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ExternalReqError(format!("jwks:decode_failed:{}", e)))?;
|
||||||
|
|
||||||
|
let key = jwks
|
||||||
|
.keys
|
||||||
|
.into_iter()
|
||||||
|
.find(|k| k.kid.as_deref() == Some(kid))
|
||||||
|
.ok_or_else(|| AppError::ExternalReqError("jwks:kid_not_found".into()))?;
|
||||||
|
|
||||||
|
if key.kty != "RSA" {
|
||||||
|
return Err(AppError::ExternalReqError("jwks:unsupported_kty".into()));
|
||||||
|
}
|
||||||
|
if key.use_field.as_deref() != Some("sig") {
|
||||||
|
return Err(AppError::ExternalReqError("jwks:unsupported_use".into()));
|
||||||
|
}
|
||||||
|
if let Some(alg) = &key.alg {
|
||||||
|
if alg != "RS256" {
|
||||||
|
return Err(AppError::ExternalReqError("jwks:unsupported_alg".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let n = key
|
||||||
|
.n
|
||||||
|
.ok_or_else(|| AppError::ExternalReqError("jwks:missing_n".into()))?;
|
||||||
|
let e = key
|
||||||
|
.e
|
||||||
|
.ok_or_else(|| AppError::ExternalReqError("jwks:missing_e".into()))?;
|
||||||
|
|
||||||
|
if URL_SAFE_NO_PAD.decode(n.as_bytes()).is_err()
|
||||||
|
|| URL_SAFE_NO_PAD.decode(e.as_bytes()).is_err()
|
||||||
|
{
|
||||||
|
return Err(AppError::ExternalReqError("jwks:invalid_base64url".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((n, e))
|
||||||
|
}
|
||||||
2
src/lib.rs
Normal file
2
src/lib.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
pub mod jwt;
|
||||||
|
pub mod middleware;
|
||||||
104
src/middleware/auth.rs
Normal file
104
src/middleware/auth.rs
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{FromRequestParts, Request, State},
|
||||||
|
http::request::Parts,
|
||||||
|
middleware::Next,
|
||||||
|
response::Response,
|
||||||
|
};
|
||||||
|
use common_telemetry::AppError;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::jwt::{Claims, JwtVerifyConfig};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AuthMiddlewareConfig {
|
||||||
|
pub skip_exact_paths: Vec<String>,
|
||||||
|
pub skip_path_prefixes: Vec<String>,
|
||||||
|
pub jwt: JwtVerifyConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthMiddlewareConfig {
|
||||||
|
pub fn should_skip(&self, path: &str) -> bool {
|
||||||
|
self.skip_exact_paths.iter().any(|p| p == path)
|
||||||
|
|| self
|
||||||
|
.skip_path_prefixes
|
||||||
|
.iter()
|
||||||
|
.any(|prefix| path.starts_with(prefix))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AuthContext {
|
||||||
|
pub tenant_id: Uuid,
|
||||||
|
pub user_id: Uuid,
|
||||||
|
pub roles: Vec<String>,
|
||||||
|
pub permissions: Vec<String>,
|
||||||
|
pub apps: Vec<String>,
|
||||||
|
pub apps_version: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn claims_to_context(claims: Claims) -> Result<AuthContext, AppError> {
|
||||||
|
let tenant_id = Uuid::parse_str(&claims.tenant_id)
|
||||||
|
.map_err(|_| AppError::AuthError("Invalid tenant_id claim".into()))?;
|
||||||
|
let user_id = Uuid::parse_str(&claims.sub)
|
||||||
|
.map_err(|_| AppError::AuthError("Invalid sub claim".into()))?;
|
||||||
|
|
||||||
|
Ok(AuthContext {
|
||||||
|
tenant_id,
|
||||||
|
user_id,
|
||||||
|
roles: claims.roles,
|
||||||
|
permissions: claims.permissions,
|
||||||
|
apps: claims.apps,
|
||||||
|
apps_version: claims.apps_version,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn authenticate_inner(
|
||||||
|
cfg: &AuthMiddlewareConfig,
|
||||||
|
mut req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, AppError> {
|
||||||
|
let path = req.uri().path();
|
||||||
|
if cfg.should_skip(path) {
|
||||||
|
return Ok(next.run(req).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let token = req
|
||||||
|
.headers()
|
||||||
|
.get(axum::http::header::AUTHORIZATION)
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
.and_then(|v| v.strip_prefix("Bearer "))
|
||||||
|
.ok_or(AppError::MissingAuthHeader)?;
|
||||||
|
|
||||||
|
let claims = crate::jwt::verify(token, &cfg.jwt).await?;
|
||||||
|
let ctx = claims_to_context(claims)?;
|
||||||
|
|
||||||
|
tracing::Span::current().record("tenant_id", tracing::field::display(ctx.tenant_id));
|
||||||
|
tracing::Span::current().record("user_id", tracing::field::display(ctx.user_id));
|
||||||
|
|
||||||
|
req.extensions_mut().insert(ctx);
|
||||||
|
|
||||||
|
Ok(next.run(req).await)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn authenticate_with_config(
|
||||||
|
State(cfg): State<AuthMiddlewareConfig>,
|
||||||
|
req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, AppError> {
|
||||||
|
authenticate_inner(&cfg, req, next).await
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> FromRequestParts<S> for AuthContext
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
{
|
||||||
|
type Rejection = AppError;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
parts
|
||||||
|
.extensions
|
||||||
|
.get::<AuthContext>()
|
||||||
|
.cloned()
|
||||||
|
.ok_or(AppError::MissingAuthHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
2
src/middleware/mod.rs
Normal file
2
src/middleware/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
pub mod auth;
|
||||||
|
pub mod tenant;
|
||||||
105
src/middleware/tenant.rs
Normal file
105
src/middleware/tenant.rs
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{FromRequestParts, Request, State},
|
||||||
|
middleware::Next,
|
||||||
|
response::Response,
|
||||||
|
};
|
||||||
|
use common_telemetry::AppError;
|
||||||
|
use http::request::Parts;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::middleware::auth::AuthContext;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct TenantId(pub Uuid);
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TenantMiddlewareConfig {
|
||||||
|
pub skip_exact_paths: Vec<String>,
|
||||||
|
pub skip_path_prefixes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TenantMiddlewareConfig {
|
||||||
|
pub fn should_skip(&self, path: &str) -> bool {
|
||||||
|
self.skip_exact_paths.iter().any(|p| p == path)
|
||||||
|
|| self
|
||||||
|
.skip_path_prefixes
|
||||||
|
.iter()
|
||||||
|
.any(|prefix| path.starts_with(prefix))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn resolve_tenant_inner(
|
||||||
|
cfg: &TenantMiddlewareConfig,
|
||||||
|
mut req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, AppError> {
|
||||||
|
let path = req.uri().path();
|
||||||
|
if cfg.should_skip(path) {
|
||||||
|
return Ok(next.run(req).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(auth_tenant_id) = req.extensions().get::<AuthContext>().map(|c| c.tenant_id) {
|
||||||
|
if let Some(header_value) = req
|
||||||
|
.headers()
|
||||||
|
.get("X-Tenant-ID")
|
||||||
|
.and_then(|val| val.to_str().ok())
|
||||||
|
{
|
||||||
|
let header_tenant_id = Uuid::parse_str(header_value)
|
||||||
|
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?;
|
||||||
|
if header_tenant_id != auth_tenant_id {
|
||||||
|
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::Span::current().record("tenant_id", tracing::field::display(auth_tenant_id));
|
||||||
|
req.extensions_mut().insert(TenantId(auth_tenant_id));
|
||||||
|
return Ok(next.run(req).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let tenant_id_str = req
|
||||||
|
.headers()
|
||||||
|
.get("X-Tenant-ID")
|
||||||
|
.and_then(|val| val.to_str().ok());
|
||||||
|
|
||||||
|
match tenant_id_str {
|
||||||
|
Some(id_str) => {
|
||||||
|
let uuid = Uuid::parse_str(id_str)
|
||||||
|
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?;
|
||||||
|
tracing::Span::current().record("tenant_id", tracing::field::display(uuid));
|
||||||
|
req.extensions_mut().insert(TenantId(uuid));
|
||||||
|
Ok(next.run(req).await)
|
||||||
|
}
|
||||||
|
None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn resolve_tenant_with_config(
|
||||||
|
State(cfg): State<TenantMiddlewareConfig>,
|
||||||
|
req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, AppError> {
|
||||||
|
resolve_tenant_inner(&cfg, req, next).await
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> FromRequestParts<S> for TenantId
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
{
|
||||||
|
type Rejection = AppError;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
if let Some(tid) = parts.extensions.get::<TenantId>().cloned() {
|
||||||
|
return Ok(tid);
|
||||||
|
}
|
||||||
|
let tenant_id_str = parts
|
||||||
|
.headers
|
||||||
|
.get("X-Tenant-ID")
|
||||||
|
.and_then(|val| val.to_str().ok());
|
||||||
|
|
||||||
|
match tenant_id_str {
|
||||||
|
Some(id_str) => uuid::Uuid::parse_str(id_str)
|
||||||
|
.map(TenantId)
|
||||||
|
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into())),
|
||||||
|
None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
111
tests/jwks_verify.rs
Normal file
111
tests/jwks_verify.rs
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
use auth_kit::jwt::{Claims, JwtVerifyConfig};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum::{Json, Router, routing::get};
|
||||||
|
use base64::Engine as _;
|
||||||
|
use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
|
||||||
|
use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey};
|
||||||
|
use rsa::rand_core::OsRng;
|
||||||
|
use rsa::traits::PublicKeyParts;
|
||||||
|
use rsa::{RsaPrivateKey, pkcs1::LineEnding};
|
||||||
|
use serde::Serialize;
|
||||||
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
#[derive(Serialize, Clone)]
|
||||||
|
struct Jwks {
|
||||||
|
keys: Vec<Jwk>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Clone)]
|
||||||
|
struct Jwk {
|
||||||
|
kty: &'static str,
|
||||||
|
kid: &'static str,
|
||||||
|
#[serde(rename = "use")]
|
||||||
|
use_field: &'static str,
|
||||||
|
alg: &'static str,
|
||||||
|
n: String,
|
||||||
|
e: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base64url_no_pad(data: &[u8]) -> String {
|
||||||
|
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn verify_rs256_via_jwks() {
|
||||||
|
let private = RsaPrivateKey::new(&mut OsRng, 2048).unwrap();
|
||||||
|
let public = private.to_public_key();
|
||||||
|
|
||||||
|
let private_pem = private.to_pkcs1_pem(LineEnding::LF).unwrap().to_string();
|
||||||
|
let public_pem = public.to_pkcs1_pem(LineEnding::LF).unwrap().to_string();
|
||||||
|
|
||||||
|
let n = base64url_no_pad(&public.n().to_bytes_be());
|
||||||
|
let e = base64url_no_pad(&public.e().to_bytes_be());
|
||||||
|
|
||||||
|
let kid = "test-kid";
|
||||||
|
let jwks = Jwks {
|
||||||
|
keys: vec![Jwk {
|
||||||
|
kty: "RSA",
|
||||||
|
kid,
|
||||||
|
use_field: "sig",
|
||||||
|
alg: "RS256",
|
||||||
|
n,
|
||||||
|
e,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
|
let app = Router::new().route(
|
||||||
|
"/.well-known/jwks.json",
|
||||||
|
get(move || {
|
||||||
|
let jwks = jwks;
|
||||||
|
async move { (axum::http::StatusCode::OK, Json(jwks)).into_response() }
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{}", addr);
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() as usize;
|
||||||
|
|
||||||
|
let claims = Claims {
|
||||||
|
sub: uuid::Uuid::new_v4().to_string(),
|
||||||
|
tenant_id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
exp: now + 60,
|
||||||
|
iat: now,
|
||||||
|
iss: "iam-service".to_string(),
|
||||||
|
roles: vec![],
|
||||||
|
permissions: vec![],
|
||||||
|
apps: vec![],
|
||||||
|
apps_version: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut header = Header::new(Algorithm::RS256);
|
||||||
|
header.kid = Some(kid.to_string());
|
||||||
|
let token = encode(
|
||||||
|
&header,
|
||||||
|
&claims,
|
||||||
|
&EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cfg = JwtVerifyConfig::rs256_from_jwks(
|
||||||
|
"iam-service",
|
||||||
|
&format!("{}/.well-known/jwks.json", base_url),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let verified = auth_kit::jwt::verify(&token, &cfg).await.unwrap();
|
||||||
|
assert_eq!(verified.tenant_id, claims.tenant_id);
|
||||||
|
|
||||||
|
let cfg2 = JwtVerifyConfig::rs256_from_pem("iam-service", &public_pem).unwrap();
|
||||||
|
let verified2 = auth_kit::jwt::verify(&token, &cfg2).await.unwrap();
|
||||||
|
assert_eq!(verified2.sub, claims.sub);
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user