feat(project): init
This commit is contained in:
5
.cargo/config.toml
Normal file
5
.cargo/config.toml
Normal file
@@ -0,0 +1,5 @@
|
||||
[registries.kellnr]
|
||||
index = "sparse+https://kellnr.shay7sev.site/api/v1/crates/"
|
||||
|
||||
[net]
|
||||
git-fetch-with-cli = true
|
||||
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
/target
|
||||
Cargo.lock
|
||||
30
Cargo.toml
Normal file
30
Cargo.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[package]
|
||||
name = "auth-kit"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
common-telemetry = { version = "0.1.5", registry = "kellnr", default-features = false, features = [
|
||||
"response",
|
||||
"telemetry",
|
||||
"with-anyhow",
|
||||
"with-sqlx",
|
||||
] }
|
||||
|
||||
axum = "0.8.8"
|
||||
http = "1.4.0"
|
||||
jsonwebtoken = { version = "10.3.0", features = ["aws_lc_rs"] }
|
||||
dashmap = "6.1.0"
|
||||
reqwest = { version = "0.12.28", default-features = false, features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
] }
|
||||
base64 = "0.22.1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tracing = "0.1"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
rsa = "0.9.10"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
124
README.md
Normal file
124
README.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# auth-kit
|
||||
|
||||
`auth-kit` 是一个可复用的认证/多租户中间件库(Rust + Axum),用于在多个微服务中统一实现:
|
||||
|
||||
- JWT 认证:从 `Authorization: Bearer <token>` 解析并验证 token,注入 `AuthContext`
|
||||
- 租户隔离:从 `X-Tenant-ID` 解析租户并注入 `TenantId`,并在同时存在 token 与 header 时强制一致性
|
||||
|
||||
该 crate 不包含任何业务逻辑、数据库访问或 RBAC 聚合,仅提供“请求上下文注入 + 类型提取器 + JWT 验证能力”。
|
||||
|
||||
## 依赖与约定
|
||||
|
||||
- Web 框架:Axum
|
||||
- 统一错误类型:`common_telemetry::AppError`
|
||||
- 头部约定:
|
||||
- `Authorization: Bearer <access_token>`
|
||||
- `X-Tenant-ID: <tenant_uuid>`
|
||||
|
||||
## 快速开始(在 Axum 中挂载)
|
||||
|
||||
推荐链路顺序:
|
||||
|
||||
1. `authenticate_with_config`(注入 `AuthContext`)
|
||||
2. `resolve_tenant_with_config`(注入 `TenantId`,并校验 header tenant 与 token tenant 一致)
|
||||
|
||||
示例:
|
||||
|
||||
```rust
|
||||
use axum::{Router, middleware::from_fn_with_state};
|
||||
use auth_kit::middleware::{
|
||||
auth::{self, AuthMiddlewareConfig},
|
||||
tenant::{self, TenantMiddlewareConfig},
|
||||
};
|
||||
|
||||
let auth_cfg = AuthMiddlewareConfig {
|
||||
skip_exact_paths: vec!["/healthz".to_string()],
|
||||
skip_path_prefixes: vec!["/scalar".to_string()],
|
||||
jwt: auth_kit::jwt::JwtVerifyConfig::rs256_from_jwks(
|
||||
"iam-service",
|
||||
"http://127.0.0.1:3000/.well-known/jwks.json",
|
||||
)?,
|
||||
};
|
||||
|
||||
let tenant_cfg = TenantMiddlewareConfig {
|
||||
skip_exact_paths: vec!["/healthz".to_string()],
|
||||
skip_path_prefixes: vec!["/scalar".to_string()],
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.layer(from_fn_with_state(tenant_cfg, tenant::resolve_tenant_with_config))
|
||||
.layer(from_fn_with_state(auth_cfg, auth::authenticate_with_config));
|
||||
```
|
||||
|
||||
## Handler 中如何使用(类型注入)
|
||||
|
||||
### AuthContext
|
||||
|
||||
`AuthContext` 会从 request extensions 中提取,如果中间件未运行或缺失认证信息,会返回 `AppError::MissingAuthHeader`。
|
||||
|
||||
```rust
|
||||
use auth_kit::middleware::auth::AuthContext;
|
||||
|
||||
pub async fn handler(AuthContext { user_id, tenant_id, .. }: AuthContext) {
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### TenantId
|
||||
|
||||
`TenantId` 可从 request extensions 中提取;若不存在,会尝试从 `X-Tenant-ID` 解析;缺失则返回 `AppError::BadRequest("Missing X-Tenant-ID header")`。
|
||||
|
||||
```rust
|
||||
use auth_kit::middleware::tenant::TenantId;
|
||||
|
||||
pub async fn handler(TenantId(tenant_id): TenantId) {
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
## JWT 验证配置(JwtVerifyConfig)
|
||||
|
||||
`JwtVerifyConfig` 定义在 [jwt.rs](file:///home/shay/project/backend/auth-kit/src/jwt.rs)。
|
||||
|
||||
- **静态公钥**:`rs256_from_pem(issuer, public_key_pem)`
|
||||
- **JWKS 拉取**:`rs256_from_jwks(issuer, jwks_url)`(带缓存与降级)
|
||||
- **对称密钥(调试/兼容)**:`hs256_from_secret(issuer, secret)`
|
||||
|
||||
注意点:
|
||||
|
||||
- RS256 模式要求 token header 中包含 `kid`(用于从 JWKS 中选 key)
|
||||
- issuer 必须与 token 中 `iss` 一致
|
||||
- JWKS 目前只支持 RSA + `use=sig` + `alg=RS256`
|
||||
|
||||
### JWKS 缓存策略(当前实现)
|
||||
|
||||
- HTTP 超时:1500ms
|
||||
- 缓存 TTL:300s
|
||||
- stale-if-error:3600s(JWKS 拉取失败时可在窗口内使用缓存 key 继续验签)
|
||||
|
||||
## Skip 规则(免鉴权路径)
|
||||
|
||||
`AuthMiddlewareConfig` / `TenantMiddlewareConfig` 都支持:
|
||||
|
||||
- `skip_exact_paths`:精确路径匹配
|
||||
- `skip_path_prefixes`:前缀匹配
|
||||
|
||||
用于跳过如 `/healthz`、`/scalar`、`/.well-known/jwks.json` 等公开端点。
|
||||
|
||||
## 端到端验证(推荐)
|
||||
|
||||
本库自带一个集成测试,用于验证 RS256 + JWKS 的验签路径:
|
||||
|
||||
- 测试文件: [jwks_verify.rs](file:///home/shay/project/backend/auth-kit/tests/jwks_verify.rs)
|
||||
- 运行:
|
||||
|
||||
```bash
|
||||
cd /home/shay/project/backend/auth-kit
|
||||
cargo test
|
||||
```
|
||||
|
||||
## 已知限制
|
||||
|
||||
- 暂未提供“从环境变量自动组装配置”的辅助函数(建议业务服务自行装配 `AuthMiddlewareConfig`)
|
||||
- 暂未实现 ECDSA(ES256)/EdDSA(Ed25519)JWKS 解析
|
||||
|
||||
233
src/jwt.rs
Normal file
233
src/jwt.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use common_telemetry::AppError;
|
||||
use dashmap::DashMap;
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
pub tenant_id: String,
|
||||
pub exp: usize,
|
||||
pub iat: usize,
|
||||
pub iss: String,
|
||||
#[serde(default)]
|
||||
pub roles: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub permissions: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub apps: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub apps_version: i32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum JwtVerifyConfig {
|
||||
Static(StaticVerifyConfig),
|
||||
Jwks(JwksVerifyConfig),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StaticVerifyConfig {
|
||||
decoding_key: DecodingKey,
|
||||
validation: Validation,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JwksVerifyConfig {
|
||||
issuer: String,
|
||||
jwks_url: String,
|
||||
client: reqwest::Client,
|
||||
cache: Arc<DashMap<String, CachedJwk>>,
|
||||
cache_ttl: Duration,
|
||||
stale_if_error: Duration,
|
||||
}
|
||||
|
||||
impl JwtVerifyConfig {
|
||||
pub fn rs256_from_pem(issuer: &str, public_key_pem: &str) -> Result<Self, AppError> {
|
||||
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem.as_bytes())
|
||||
.map_err(|e| AppError::ConfigError(format!("Invalid JWT public key pem: {}", e)))?;
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
validation.set_issuer(&[issuer]);
|
||||
Ok(Self::Static(StaticVerifyConfig {
|
||||
decoding_key,
|
||||
validation,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn hs256_from_secret(issuer: &str, secret: &str) -> Self {
|
||||
let decoding_key = DecodingKey::from_secret(secret.as_bytes());
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
validation.set_issuer(&[issuer]);
|
||||
Self::Static(StaticVerifyConfig {
|
||||
decoding_key,
|
||||
validation,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn rs256_from_jwks(issuer: &str, jwks_url: &str) -> Result<Self, AppError> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(1500))
|
||||
.build()
|
||||
.map_err(|e| AppError::ConfigError(format!("Failed to build http client: {}", e)))?;
|
||||
Ok(Self::Jwks(JwksVerifyConfig {
|
||||
issuer: issuer.to_string(),
|
||||
jwks_url: jwks_url.to_string(),
|
||||
client,
|
||||
cache: Arc::new(DashMap::new()),
|
||||
cache_ttl: Duration::from_secs(300),
|
||||
stale_if_error: Duration::from_secs(3600),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CachedJwk {
|
||||
n: String,
|
||||
e: String,
|
||||
expires_at: Instant,
|
||||
stale_until: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Jwks {
|
||||
keys: Vec<Jwk>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Jwk {
|
||||
kty: String,
|
||||
kid: Option<String>,
|
||||
#[serde(rename = "use")]
|
||||
use_field: Option<String>,
|
||||
alg: Option<String>,
|
||||
n: Option<String>,
|
||||
e: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn verify(token: &str, cfg: &JwtVerifyConfig) -> Result<Claims, AppError> {
|
||||
match cfg {
|
||||
JwtVerifyConfig::Static(static_cfg) => {
|
||||
let token_data =
|
||||
decode::<Claims>(token, &static_cfg.decoding_key, &static_cfg.validation)
|
||||
.map_err(|e| AppError::AuthError(e.to_string()))?;
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
JwtVerifyConfig::Jwks(jwks_cfg) => {
|
||||
let header = decode_header(token).map_err(|e| AppError::AuthError(e.to_string()))?;
|
||||
let kid = header
|
||||
.kid
|
||||
.ok_or_else(|| AppError::AuthError("Missing kid in JWT header".into()))?;
|
||||
|
||||
let now = Instant::now();
|
||||
if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) {
|
||||
if entry.expires_at > now {
|
||||
return verify_with_rsa_components(token, &jwks_cfg.issuer, &entry.n, &entry.e);
|
||||
}
|
||||
}
|
||||
|
||||
let fetched = fetch_jwk_by_kid(&jwks_cfg.client, &jwks_cfg.jwks_url, &kid).await;
|
||||
match fetched {
|
||||
Ok((n, e)) => {
|
||||
let entry = CachedJwk {
|
||||
n: n.clone(),
|
||||
e: e.clone(),
|
||||
expires_at: now + jwks_cfg.cache_ttl,
|
||||
stale_until: now + jwks_cfg.cache_ttl + jwks_cfg.stale_if_error,
|
||||
};
|
||||
jwks_cfg.cache.insert(kid.clone(), entry);
|
||||
verify_with_rsa_components(token, &jwks_cfg.issuer, &n, &e)
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) {
|
||||
if entry.stale_until > now {
|
||||
return verify_with_rsa_components(
|
||||
token,
|
||||
&jwks_cfg.issuer,
|
||||
&entry.n,
|
||||
&entry.e,
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_with_rsa_components(
|
||||
token: &str,
|
||||
issuer: &str,
|
||||
n: &str,
|
||||
e: &str,
|
||||
) -> Result<Claims, AppError> {
|
||||
let decoding_key = DecodingKey::from_rsa_components(n, e)
|
||||
.map_err(|e| AppError::AuthError(format!("Invalid jwk rsa components: {}", e)))?;
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
validation.set_issuer(&[issuer]);
|
||||
let token_data = decode::<Claims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| AppError::AuthError(e.to_string()))?;
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
async fn fetch_jwk_by_kid(
|
||||
client: &reqwest::Client,
|
||||
jwks_url: &str,
|
||||
kid: &str,
|
||||
) -> Result<(String, String), AppError> {
|
||||
let resp = client
|
||||
.get(jwks_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::ExternalReqError(format!("jwks:request_failed:{}", e)))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(AppError::ExternalReqError(format!(
|
||||
"jwks:unexpected_status:{}",
|
||||
resp.status().as_u16()
|
||||
)));
|
||||
}
|
||||
|
||||
let jwks: Jwks = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::ExternalReqError(format!("jwks:decode_failed:{}", e)))?;
|
||||
|
||||
let key = jwks
|
||||
.keys
|
||||
.into_iter()
|
||||
.find(|k| k.kid.as_deref() == Some(kid))
|
||||
.ok_or_else(|| AppError::ExternalReqError("jwks:kid_not_found".into()))?;
|
||||
|
||||
if key.kty != "RSA" {
|
||||
return Err(AppError::ExternalReqError("jwks:unsupported_kty".into()));
|
||||
}
|
||||
if key.use_field.as_deref() != Some("sig") {
|
||||
return Err(AppError::ExternalReqError("jwks:unsupported_use".into()));
|
||||
}
|
||||
if let Some(alg) = &key.alg {
|
||||
if alg != "RS256" {
|
||||
return Err(AppError::ExternalReqError("jwks:unsupported_alg".into()));
|
||||
}
|
||||
}
|
||||
|
||||
let n = key
|
||||
.n
|
||||
.ok_or_else(|| AppError::ExternalReqError("jwks:missing_n".into()))?;
|
||||
let e = key
|
||||
.e
|
||||
.ok_or_else(|| AppError::ExternalReqError("jwks:missing_e".into()))?;
|
||||
|
||||
if URL_SAFE_NO_PAD.decode(n.as_bytes()).is_err()
|
||||
|| URL_SAFE_NO_PAD.decode(e.as_bytes()).is_err()
|
||||
{
|
||||
return Err(AppError::ExternalReqError("jwks:invalid_base64url".into()));
|
||||
}
|
||||
|
||||
Ok((n, e))
|
||||
}
|
||||
2
src/lib.rs
Normal file
2
src/lib.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod jwt;
|
||||
pub mod middleware;
|
||||
104
src/middleware/auth.rs
Normal file
104
src/middleware/auth.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use axum::{
|
||||
extract::{FromRequestParts, Request, State},
|
||||
http::request::Parts,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use common_telemetry::AppError;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::jwt::{Claims, JwtVerifyConfig};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthMiddlewareConfig {
|
||||
pub skip_exact_paths: Vec<String>,
|
||||
pub skip_path_prefixes: Vec<String>,
|
||||
pub jwt: JwtVerifyConfig,
|
||||
}
|
||||
|
||||
impl AuthMiddlewareConfig {
|
||||
pub fn should_skip(&self, path: &str) -> bool {
|
||||
self.skip_exact_paths.iter().any(|p| p == path)
|
||||
|| self
|
||||
.skip_path_prefixes
|
||||
.iter()
|
||||
.any(|prefix| path.starts_with(prefix))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AuthContext {
|
||||
pub tenant_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub roles: Vec<String>,
|
||||
pub permissions: Vec<String>,
|
||||
pub apps: Vec<String>,
|
||||
pub apps_version: i32,
|
||||
}
|
||||
|
||||
fn claims_to_context(claims: Claims) -> Result<AuthContext, AppError> {
|
||||
let tenant_id = Uuid::parse_str(&claims.tenant_id)
|
||||
.map_err(|_| AppError::AuthError("Invalid tenant_id claim".into()))?;
|
||||
let user_id = Uuid::parse_str(&claims.sub)
|
||||
.map_err(|_| AppError::AuthError("Invalid sub claim".into()))?;
|
||||
|
||||
Ok(AuthContext {
|
||||
tenant_id,
|
||||
user_id,
|
||||
roles: claims.roles,
|
||||
permissions: claims.permissions,
|
||||
apps: claims.apps,
|
||||
apps_version: claims.apps_version,
|
||||
})
|
||||
}
|
||||
|
||||
async fn authenticate_inner(
|
||||
cfg: &AuthMiddlewareConfig,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
let path = req.uri().path();
|
||||
if cfg.should_skip(path) {
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
let token = req
|
||||
.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "))
|
||||
.ok_or(AppError::MissingAuthHeader)?;
|
||||
|
||||
let claims = crate::jwt::verify(token, &cfg.jwt).await?;
|
||||
let ctx = claims_to_context(claims)?;
|
||||
|
||||
tracing::Span::current().record("tenant_id", tracing::field::display(ctx.tenant_id));
|
||||
tracing::Span::current().record("user_id", tracing::field::display(ctx.user_id));
|
||||
|
||||
req.extensions_mut().insert(ctx);
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
|
||||
pub async fn authenticate_with_config(
|
||||
State(cfg): State<AuthMiddlewareConfig>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
authenticate_inner(&cfg, req, next).await
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for AuthContext
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
parts
|
||||
.extensions
|
||||
.get::<AuthContext>()
|
||||
.cloned()
|
||||
.ok_or(AppError::MissingAuthHeader)
|
||||
}
|
||||
}
|
||||
2
src/middleware/mod.rs
Normal file
2
src/middleware/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod auth;
|
||||
pub mod tenant;
|
||||
105
src/middleware/tenant.rs
Normal file
105
src/middleware/tenant.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use axum::{
|
||||
extract::{FromRequestParts, Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use common_telemetry::AppError;
|
||||
use http::request::Parts;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::middleware::auth::AuthContext;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TenantId(pub Uuid);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TenantMiddlewareConfig {
|
||||
pub skip_exact_paths: Vec<String>,
|
||||
pub skip_path_prefixes: Vec<String>,
|
||||
}
|
||||
|
||||
impl TenantMiddlewareConfig {
|
||||
pub fn should_skip(&self, path: &str) -> bool {
|
||||
self.skip_exact_paths.iter().any(|p| p == path)
|
||||
|| self
|
||||
.skip_path_prefixes
|
||||
.iter()
|
||||
.any(|prefix| path.starts_with(prefix))
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_tenant_inner(
|
||||
cfg: &TenantMiddlewareConfig,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
let path = req.uri().path();
|
||||
if cfg.should_skip(path) {
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
if let Some(auth_tenant_id) = req.extensions().get::<AuthContext>().map(|c| c.tenant_id) {
|
||||
if let Some(header_value) = req
|
||||
.headers()
|
||||
.get("X-Tenant-ID")
|
||||
.and_then(|val| val.to_str().ok())
|
||||
{
|
||||
let header_tenant_id = Uuid::parse_str(header_value)
|
||||
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?;
|
||||
if header_tenant_id != auth_tenant_id {
|
||||
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
|
||||
}
|
||||
}
|
||||
tracing::Span::current().record("tenant_id", tracing::field::display(auth_tenant_id));
|
||||
req.extensions_mut().insert(TenantId(auth_tenant_id));
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
let tenant_id_str = req
|
||||
.headers()
|
||||
.get("X-Tenant-ID")
|
||||
.and_then(|val| val.to_str().ok());
|
||||
|
||||
match tenant_id_str {
|
||||
Some(id_str) => {
|
||||
let uuid = Uuid::parse_str(id_str)
|
||||
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?;
|
||||
tracing::Span::current().record("tenant_id", tracing::field::display(uuid));
|
||||
req.extensions_mut().insert(TenantId(uuid));
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn resolve_tenant_with_config(
|
||||
State(cfg): State<TenantMiddlewareConfig>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
resolve_tenant_inner(&cfg, req, next).await
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for TenantId
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
if let Some(tid) = parts.extensions.get::<TenantId>().cloned() {
|
||||
return Ok(tid);
|
||||
}
|
||||
let tenant_id_str = parts
|
||||
.headers
|
||||
.get("X-Tenant-ID")
|
||||
.and_then(|val| val.to_str().ok());
|
||||
|
||||
match tenant_id_str {
|
||||
Some(id_str) => uuid::Uuid::parse_str(id_str)
|
||||
.map(TenantId)
|
||||
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into())),
|
||||
None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
111
tests/jwks_verify.rs
Normal file
111
tests/jwks_verify.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use auth_kit::jwt::{Claims, JwtVerifyConfig};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::{Json, Router, routing::get};
|
||||
use base64::Engine as _;
|
||||
use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
|
||||
use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey};
|
||||
use rsa::rand_core::OsRng;
|
||||
use rsa::traits::PublicKeyParts;
|
||||
use rsa::{RsaPrivateKey, pkcs1::LineEnding};
|
||||
use serde::Serialize;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
struct Jwks {
|
||||
keys: Vec<Jwk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
struct Jwk {
|
||||
kty: &'static str,
|
||||
kid: &'static str,
|
||||
#[serde(rename = "use")]
|
||||
use_field: &'static str,
|
||||
alg: &'static str,
|
||||
n: String,
|
||||
e: String,
|
||||
}
|
||||
|
||||
fn base64url_no_pad(data: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_rs256_via_jwks() {
|
||||
let private = RsaPrivateKey::new(&mut OsRng, 2048).unwrap();
|
||||
let public = private.to_public_key();
|
||||
|
||||
let private_pem = private.to_pkcs1_pem(LineEnding::LF).unwrap().to_string();
|
||||
let public_pem = public.to_pkcs1_pem(LineEnding::LF).unwrap().to_string();
|
||||
|
||||
let n = base64url_no_pad(&public.n().to_bytes_be());
|
||||
let e = base64url_no_pad(&public.e().to_bytes_be());
|
||||
|
||||
let kid = "test-kid";
|
||||
let jwks = Jwks {
|
||||
keys: vec![Jwk {
|
||||
kty: "RSA",
|
||||
kid,
|
||||
use_field: "sig",
|
||||
alg: "RS256",
|
||||
n,
|
||||
e,
|
||||
}],
|
||||
};
|
||||
|
||||
let app = Router::new().route(
|
||||
"/.well-known/jwks.json",
|
||||
get(move || {
|
||||
let jwks = jwks;
|
||||
async move { (axum::http::StatusCode::OK, Json(jwks)).into_response() }
|
||||
}),
|
||||
);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let base_url = format!("http://{}", addr);
|
||||
let handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as usize;
|
||||
|
||||
let claims = Claims {
|
||||
sub: uuid::Uuid::new_v4().to_string(),
|
||||
tenant_id: uuid::Uuid::new_v4().to_string(),
|
||||
exp: now + 60,
|
||||
iat: now,
|
||||
iss: "iam-service".to_string(),
|
||||
roles: vec![],
|
||||
permissions: vec![],
|
||||
apps: vec![],
|
||||
apps_version: 0,
|
||||
};
|
||||
|
||||
let mut header = Header::new(Algorithm::RS256);
|
||||
header.kid = Some(kid.to_string());
|
||||
let token = encode(
|
||||
&header,
|
||||
&claims,
|
||||
&EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let cfg = JwtVerifyConfig::rs256_from_jwks(
|
||||
"iam-service",
|
||||
&format!("{}/.well-known/jwks.json", base_url),
|
||||
)
|
||||
.unwrap();
|
||||
let verified = auth_kit::jwt::verify(&token, &cfg).await.unwrap();
|
||||
assert_eq!(verified.tenant_id, claims.tenant_id);
|
||||
|
||||
let cfg2 = JwtVerifyConfig::rs256_from_pem("iam-service", &public_pem).unwrap();
|
||||
let verified2 = auth_kit::jwt::verify(&token, &cfg2).await.unwrap();
|
||||
assert_eq!(verified2.sub, claims.sub);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
handle.abort();
|
||||
}
|
||||
Reference in New Issue
Block a user