feat(project): init

This commit is contained in:
2026-02-02 14:30:53 +08:00
commit ccead7e9f6
10 changed files with 718 additions and 0 deletions

233
src/jwt.rs Normal file
View File

@@ -0,0 +1,233 @@
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use common_telemetry::AppError;
use dashmap::DashMap;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use serde::{Deserialize, Serialize};
use std::{
sync::Arc,
time::{Duration, Instant},
};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
pub sub: String,
pub tenant_id: String,
pub exp: usize,
pub iat: usize,
pub iss: String,
#[serde(default)]
pub roles: Vec<String>,
#[serde(default)]
pub permissions: Vec<String>,
#[serde(default)]
pub apps: Vec<String>,
#[serde(default)]
pub apps_version: i32,
}
#[derive(Clone)]
pub enum JwtVerifyConfig {
Static(StaticVerifyConfig),
Jwks(JwksVerifyConfig),
}
#[derive(Clone)]
pub struct StaticVerifyConfig {
decoding_key: DecodingKey,
validation: Validation,
}
#[derive(Clone)]
pub struct JwksVerifyConfig {
issuer: String,
jwks_url: String,
client: reqwest::Client,
cache: Arc<DashMap<String, CachedJwk>>,
cache_ttl: Duration,
stale_if_error: Duration,
}
impl JwtVerifyConfig {
pub fn rs256_from_pem(issuer: &str, public_key_pem: &str) -> Result<Self, AppError> {
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem.as_bytes())
.map_err(|e| AppError::ConfigError(format!("Invalid JWT public key pem: {}", e)))?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[issuer]);
Ok(Self::Static(StaticVerifyConfig {
decoding_key,
validation,
}))
}
pub fn hs256_from_secret(issuer: &str, secret: &str) -> Self {
let decoding_key = DecodingKey::from_secret(secret.as_bytes());
let mut validation = Validation::new(Algorithm::HS256);
validation.set_issuer(&[issuer]);
Self::Static(StaticVerifyConfig {
decoding_key,
validation,
})
}
pub fn rs256_from_jwks(issuer: &str, jwks_url: &str) -> Result<Self, AppError> {
let client = reqwest::Client::builder()
.timeout(Duration::from_millis(1500))
.build()
.map_err(|e| AppError::ConfigError(format!("Failed to build http client: {}", e)))?;
Ok(Self::Jwks(JwksVerifyConfig {
issuer: issuer.to_string(),
jwks_url: jwks_url.to_string(),
client,
cache: Arc::new(DashMap::new()),
cache_ttl: Duration::from_secs(300),
stale_if_error: Duration::from_secs(3600),
}))
}
}
#[derive(Clone)]
struct CachedJwk {
n: String,
e: String,
expires_at: Instant,
stale_until: Instant,
}
#[derive(Debug, Deserialize)]
struct Jwks {
keys: Vec<Jwk>,
}
#[derive(Debug, Deserialize)]
struct Jwk {
kty: String,
kid: Option<String>,
#[serde(rename = "use")]
use_field: Option<String>,
alg: Option<String>,
n: Option<String>,
e: Option<String>,
}
pub async fn verify(token: &str, cfg: &JwtVerifyConfig) -> Result<Claims, AppError> {
match cfg {
JwtVerifyConfig::Static(static_cfg) => {
let token_data =
decode::<Claims>(token, &static_cfg.decoding_key, &static_cfg.validation)
.map_err(|e| AppError::AuthError(e.to_string()))?;
Ok(token_data.claims)
}
JwtVerifyConfig::Jwks(jwks_cfg) => {
let header = decode_header(token).map_err(|e| AppError::AuthError(e.to_string()))?;
let kid = header
.kid
.ok_or_else(|| AppError::AuthError("Missing kid in JWT header".into()))?;
let now = Instant::now();
if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) {
if entry.expires_at > now {
return verify_with_rsa_components(token, &jwks_cfg.issuer, &entry.n, &entry.e);
}
}
let fetched = fetch_jwk_by_kid(&jwks_cfg.client, &jwks_cfg.jwks_url, &kid).await;
match fetched {
Ok((n, e)) => {
let entry = CachedJwk {
n: n.clone(),
e: e.clone(),
expires_at: now + jwks_cfg.cache_ttl,
stale_until: now + jwks_cfg.cache_ttl + jwks_cfg.stale_if_error,
};
jwks_cfg.cache.insert(kid.clone(), entry);
verify_with_rsa_components(token, &jwks_cfg.issuer, &n, &e)
}
Err(err) => {
if let Some(entry) = jwks_cfg.cache.get(&kid).map(|e| e.clone()) {
if entry.stale_until > now {
return verify_with_rsa_components(
token,
&jwks_cfg.issuer,
&entry.n,
&entry.e,
);
}
}
Err(err)
}
}
}
}
}
fn verify_with_rsa_components(
token: &str,
issuer: &str,
n: &str,
e: &str,
) -> Result<Claims, AppError> {
let decoding_key = DecodingKey::from_rsa_components(n, e)
.map_err(|e| AppError::AuthError(format!("Invalid jwk rsa components: {}", e)))?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[issuer]);
let token_data = decode::<Claims>(token, &decoding_key, &validation)
.map_err(|e| AppError::AuthError(e.to_string()))?;
Ok(token_data.claims)
}
async fn fetch_jwk_by_kid(
client: &reqwest::Client,
jwks_url: &str,
kid: &str,
) -> Result<(String, String), AppError> {
let resp = client
.get(jwks_url)
.send()
.await
.map_err(|e| AppError::ExternalReqError(format!("jwks:request_failed:{}", e)))?;
if !resp.status().is_success() {
return Err(AppError::ExternalReqError(format!(
"jwks:unexpected_status:{}",
resp.status().as_u16()
)));
}
let jwks: Jwks = resp
.json()
.await
.map_err(|e| AppError::ExternalReqError(format!("jwks:decode_failed:{}", e)))?;
let key = jwks
.keys
.into_iter()
.find(|k| k.kid.as_deref() == Some(kid))
.ok_or_else(|| AppError::ExternalReqError("jwks:kid_not_found".into()))?;
if key.kty != "RSA" {
return Err(AppError::ExternalReqError("jwks:unsupported_kty".into()));
}
if key.use_field.as_deref() != Some("sig") {
return Err(AppError::ExternalReqError("jwks:unsupported_use".into()));
}
if let Some(alg) = &key.alg {
if alg != "RS256" {
return Err(AppError::ExternalReqError("jwks:unsupported_alg".into()));
}
}
let n = key
.n
.ok_or_else(|| AppError::ExternalReqError("jwks:missing_n".into()))?;
let e = key
.e
.ok_or_else(|| AppError::ExternalReqError("jwks:missing_e".into()))?;
if URL_SAFE_NO_PAD.decode(n.as_bytes()).is_err()
|| URL_SAFE_NO_PAD.decode(e.as_bytes()).is_err()
{
return Err(AppError::ExternalReqError("jwks:invalid_base64url".into()));
}
Ok((n, e))
}

2
src/lib.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod jwt;
pub mod middleware;

104
src/middleware/auth.rs Normal file
View File

@@ -0,0 +1,104 @@
use axum::{
extract::{FromRequestParts, Request, State},
http::request::Parts,
middleware::Next,
response::Response,
};
use common_telemetry::AppError;
use uuid::Uuid;
use crate::jwt::{Claims, JwtVerifyConfig};
#[derive(Clone)]
pub struct AuthMiddlewareConfig {
pub skip_exact_paths: Vec<String>,
pub skip_path_prefixes: Vec<String>,
pub jwt: JwtVerifyConfig,
}
impl AuthMiddlewareConfig {
pub fn should_skip(&self, path: &str) -> bool {
self.skip_exact_paths.iter().any(|p| p == path)
|| self
.skip_path_prefixes
.iter()
.any(|prefix| path.starts_with(prefix))
}
}
#[derive(Clone, Debug)]
pub struct AuthContext {
pub tenant_id: Uuid,
pub user_id: Uuid,
pub roles: Vec<String>,
pub permissions: Vec<String>,
pub apps: Vec<String>,
pub apps_version: i32,
}
fn claims_to_context(claims: Claims) -> Result<AuthContext, AppError> {
let tenant_id = Uuid::parse_str(&claims.tenant_id)
.map_err(|_| AppError::AuthError("Invalid tenant_id claim".into()))?;
let user_id = Uuid::parse_str(&claims.sub)
.map_err(|_| AppError::AuthError("Invalid sub claim".into()))?;
Ok(AuthContext {
tenant_id,
user_id,
roles: claims.roles,
permissions: claims.permissions,
apps: claims.apps,
apps_version: claims.apps_version,
})
}
async fn authenticate_inner(
cfg: &AuthMiddlewareConfig,
mut req: Request,
next: Next,
) -> Result<Response, AppError> {
let path = req.uri().path();
if cfg.should_skip(path) {
return Ok(next.run(req).await);
}
let token = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.ok_or(AppError::MissingAuthHeader)?;
let claims = crate::jwt::verify(token, &cfg.jwt).await?;
let ctx = claims_to_context(claims)?;
tracing::Span::current().record("tenant_id", tracing::field::display(ctx.tenant_id));
tracing::Span::current().record("user_id", tracing::field::display(ctx.user_id));
req.extensions_mut().insert(ctx);
Ok(next.run(req).await)
}
pub async fn authenticate_with_config(
State(cfg): State<AuthMiddlewareConfig>,
req: Request,
next: Next,
) -> Result<Response, AppError> {
authenticate_inner(&cfg, req, next).await
}
impl<S> FromRequestParts<S> for AuthContext
where
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<AuthContext>()
.cloned()
.ok_or(AppError::MissingAuthHeader)
}
}

2
src/middleware/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod auth;
pub mod tenant;

105
src/middleware/tenant.rs Normal file
View File

@@ -0,0 +1,105 @@
use axum::{
extract::{FromRequestParts, Request, State},
middleware::Next,
response::Response,
};
use common_telemetry::AppError;
use http::request::Parts;
use uuid::Uuid;
use crate::middleware::auth::AuthContext;
#[derive(Clone, Debug)]
pub struct TenantId(pub Uuid);
#[derive(Clone)]
pub struct TenantMiddlewareConfig {
pub skip_exact_paths: Vec<String>,
pub skip_path_prefixes: Vec<String>,
}
impl TenantMiddlewareConfig {
pub fn should_skip(&self, path: &str) -> bool {
self.skip_exact_paths.iter().any(|p| p == path)
|| self
.skip_path_prefixes
.iter()
.any(|prefix| path.starts_with(prefix))
}
}
async fn resolve_tenant_inner(
cfg: &TenantMiddlewareConfig,
mut req: Request,
next: Next,
) -> Result<Response, AppError> {
let path = req.uri().path();
if cfg.should_skip(path) {
return Ok(next.run(req).await);
}
if let Some(auth_tenant_id) = req.extensions().get::<AuthContext>().map(|c| c.tenant_id) {
if let Some(header_value) = req
.headers()
.get("X-Tenant-ID")
.and_then(|val| val.to_str().ok())
{
let header_tenant_id = Uuid::parse_str(header_value)
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?;
if header_tenant_id != auth_tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
}
tracing::Span::current().record("tenant_id", tracing::field::display(auth_tenant_id));
req.extensions_mut().insert(TenantId(auth_tenant_id));
return Ok(next.run(req).await);
}
let tenant_id_str = req
.headers()
.get("X-Tenant-ID")
.and_then(|val| val.to_str().ok());
match tenant_id_str {
Some(id_str) => {
let uuid = Uuid::parse_str(id_str)
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?;
tracing::Span::current().record("tenant_id", tracing::field::display(uuid));
req.extensions_mut().insert(TenantId(uuid));
Ok(next.run(req).await)
}
None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())),
}
}
pub async fn resolve_tenant_with_config(
State(cfg): State<TenantMiddlewareConfig>,
req: Request,
next: Next,
) -> Result<Response, AppError> {
resolve_tenant_inner(&cfg, req, next).await
}
impl<S> FromRequestParts<S> for TenantId
where
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(tid) = parts.extensions.get::<TenantId>().cloned() {
return Ok(tid);
}
let tenant_id_str = parts
.headers
.get("X-Tenant-ID")
.and_then(|val| val.to_str().ok());
match tenant_id_str {
Some(id_str) => uuid::Uuid::parse_str(id_str)
.map(TenantId)
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into())),
None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())),
}
}
}