fix(handlers): add handlers

This commit is contained in:
2026-01-30 16:31:53 +08:00
parent bb82c75834
commit ce12b997f4
38 changed files with 3746 additions and 317 deletions

View File

@@ -8,13 +8,15 @@ pub struct AppConfig {
pub log_dir: String,
pub log_file_name: String,
pub database_url: String,
pub db_max_connections: u32,
pub db_min_connections: u32,
pub jwt_secret: String,
pub port: u16,
}
impl AppConfig {
pub fn from_env() -> Self {
Self {
pub fn from_env() -> Result<Self, String> {
Ok(Self {
service_name: env::var("SERVICE_NAME").unwrap_or_else(|_| "iam-service".into()),
log_level: env::var("LOG_LEVEL").unwrap_or_else(|_| "info".into()),
log_to_file: env::var("LOG_TO_FILE")
@@ -22,13 +24,14 @@ impl AppConfig {
.unwrap_or(false),
log_dir: env::var("LOG_DIR").unwrap_or_else(|_| "./log".into()),
log_file_name: env::var("LOG_FILE_NAME").unwrap_or_else(|_| "iam.log".into()),
database_url: env::var("DATABASE_URL").expect("DATABASE_URL required"),
jwt_secret: env::var("JWT_SECRET")
.expect("JWT_SECRET required, generate by run 'openssl rand -base64 32'"),
database_url: env::var("DATABASE_URL").map_err(|_| "DATABASE_URL environment variable is required")?,
db_max_connections: env::var("DB_MAX_CONNECTIONS").unwrap_or("20".into()).parse().map_err(|_| "DB_MAX_CONNECTIONS must be a number")?,
db_min_connections: env::var("DB_MIN_CONNECTIONS").unwrap_or("5".into()).parse().map_err(|_| "DB_MIN_CONNECTIONS must be a number")?,
jwt_secret: env::var("JWT_SECRET").map_err(|_| "JWT_SECRET environment variable is required")?,
port: env::var("PORT")
.unwrap_or_else(|_| "3000".to_string())
.parse()
.unwrap(),
}
.map_err(|_| "PORT must be a valid number")?,
})
}
}

View File

@@ -1,13 +1,14 @@
use crate::config::AppConfig;
use sqlx::postgres::{PgPool, PgPoolOptions};
use std::time::Duration;
/// 初始化数据库连接池
pub async fn init_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
pub async fn init_pool(config: &AppConfig) -> Result<PgPool, sqlx::Error> {
PgPoolOptions::new()
.max_connections(20) // 根据服务器规格调整IAM服务通常并发高
.min_connections(5)
.max_connections(config.db_max_connections)
.min_connections(config.db_min_connections)
.acquire_timeout(Duration::from_secs(3)) // 获取连接超时时间
.connect(database_url)
.connect(&config.database_url)
.await
}

105
src/docs.rs Normal file
View File

@@ -0,0 +1,105 @@
use crate::handlers;
use crate::models::{
CreateRoleRequest, CreateTenantRequest, CreateUserRequest, LoginRequest, LoginResponse, Role,
RoleResponse, Tenant, TenantResponse, UpdateTenantRequest, UpdateTenantStatusRequest,
UpdateUserRequest, User, UserResponse,
};
use utoipa::openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme};
use utoipa::{Modify, OpenApi};
struct SecurityAddon;
impl Modify for SecurityAddon {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
let components = openapi
.components
.get_or_insert_with(utoipa::openapi::Components::new);
components.add_security_scheme(
"bearer_auth",
SecurityScheme::Http(
HttpBuilder::new()
.scheme(HttpAuthScheme::Bearer)
.bearer_format("JWT")
.build(),
),
);
}
}
#[derive(OpenApi)]
#[openapi(
modifiers(&SecurityAddon),
info(
title = "IAM Service API",
version = "0.1.0",
description = include_str!("../docs/SCALAR_GUIDE.md")
),
paths(
handlers::auth::register_handler,
handlers::auth::login_handler,
handlers::authorization::my_permissions_handler,
handlers::tenant::create_tenant_handler,
handlers::tenant::get_tenant_handler,
handlers::tenant::update_tenant_handler,
handlers::tenant::update_tenant_status_handler,
handlers::tenant::delete_tenant_handler,
handlers::role::create_role_handler,
handlers::role::list_roles_handler,
handlers::user::list_users_handler,
handlers::user::get_user_handler,
handlers::user::update_user_handler,
handlers::user::delete_user_handler,
// Add other handlers here as you implement them
),
components(
schemas(
User,
UserResponse,
CreateUserRequest,
UpdateUserRequest,
LoginRequest,
LoginResponse,
Role,
CreateRoleRequest,
RoleResponse,
Tenant,
TenantResponse,
CreateTenantRequest,
UpdateTenantRequest,
UpdateTenantStatusRequest
)
),
tags(
(name = "Auth", description = "认证:注册/登录/令牌"),
(name = "Tenant", description = "租户:创建/查询/更新/状态/删除"),
(name = "User", description = "用户:查询/列表/更新/删除(需权限)"),
(name = "Role", description = "角色:创建/列表(需权限)"),
(name = "Me", description = "当前用户:权限自查等"),
(name = "Policy", description = "策略预留ABAC/策略引擎后续扩展)")
)
)]
pub struct ApiDoc;
#[cfg(test)]
mod tests {
use super::ApiDoc;
use utoipa::OpenApi;
#[test]
fn openapi_schema_contains_defaults() {
let doc = ApiDoc::openapi();
let json = serde_json::to_value(&doc).unwrap();
let token_type_default = json
.pointer("/components/schemas/LoginResponse/properties/token_type/default")
.and_then(|v| v.as_str())
.unwrap_or_default();
assert_eq!(token_type_default, "Bearer");
let tenant_status_default = json
.pointer("/components/schemas/Tenant/properties/status/default")
.and_then(|v| v.as_str())
.unwrap_or_default();
assert_eq!(tenant_status_default, "active");
}
}

67
src/handlers/auth.rs Normal file
View File

@@ -0,0 +1,67 @@
use crate::handlers::AppState;
use crate::middleware::TenantId;
use crate::models::{CreateUserRequest, LoginRequest, LoginResponse, UserResponse};
use axum::{Json, extract::State};
use common_telemetry::{AppError, AppResponse};
use tracing::instrument;
/// 注册接口
#[utoipa::path(
post,
path = "/auth/register",
tag = "Auth",
request_body = CreateUserRequest,
responses(
(status = 201, description = "User created", body = UserResponse),
(status = 400, description = "Bad request"),
(status = 429, description = "Too many requests")
),
params(
("X-Tenant-ID" = String, Header, description = "Tenant UUID")
)
)]
#[instrument(skip(state, payload))]
pub async fn register_handler(
// 1. 自动注入 TenantId (由中间件解析)
TenantId(tenant_id): TenantId,
// 2. 获取全局状态中的 Service
State(state): State<AppState>,
// 3. 获取 Body
Json(payload): Json<CreateUserRequest>,
) -> Result<AppResponse<UserResponse>, AppError> {
let user = state.auth_service.register(tenant_id, payload).await?;
// 转换为 Response DTO (隐藏密码等敏感信息)
let response = UserResponse {
id: user.id,
email: user.email.clone(),
};
Ok(AppResponse::created(response))
}
/// 登录接口
#[utoipa::path(
post,
path = "/auth/login",
tag = "Auth",
request_body = LoginRequest,
responses(
(status = 200, description = "Login successful", body = LoginResponse),
(status = 401, description = "Unauthorized"),
(status = 429, description = "Too many requests")
),
params(
("X-Tenant-ID" = String, Header, description = "Tenant UUID")
)
)]
#[instrument(skip(state, payload))]
pub async fn login_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
Json(payload): Json<LoginRequest>,
) -> Result<AppResponse<LoginResponse>, AppError> {
let response = state.auth_service.login(tenant_id, payload).await?;
Ok(AppResponse::ok(response))
}

View File

@@ -0,0 +1,59 @@
use crate::handlers::AppState;
use crate::middleware::TenantId;
use crate::middleware::auth::AuthContext;
use axum::extract::State;
use common_telemetry::{AppError, AppResponse};
use tracing::instrument;
#[utoipa::path(
get,
path = "/me/permissions",
tag = "Me",
security(
("bearer_auth" = [])
),
responses(
(status = 200, description = "当前用户权限列表", body = [String]),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)")
)
)]
#[instrument(skip(state))]
/// 查询当前登录用户在当前租户下的权限编码列表。
///
/// 用途:
/// - 快速自查当前令牌是否携带期望的权限(便于联调与排障)。
///
/// 输入:
/// - Header `Authorization: Bearer <access_token>`(必填)
/// - Header `X-Tenant-ID`(可选;若提供需与 Token 中 tenant_id 一致,否则返回 403
///
/// 输出:
/// - `200`:权限字符串数组(如 `user:read`
///
/// 异常:
/// - `401`:未携带或无法解析访问令牌
/// - `403`:租户不匹配或无权访问
pub async fn my_permissions_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
) -> Result<AppResponse<Vec<String>>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
let permissions = state
.authorization_service
.list_permissions_for_user(tenant_id, user_id)
.await?;
Ok(AppResponse::ok(permissions))
}

View File

@@ -1,48 +1,28 @@
use crate::middleware::TenantId;
use crate::models::{CreateUserRequest, UserResponse};
use crate::services::AuthService;
use axum::{Json, extract::State};
use common_telemetry::AppError; // 引入刚刚写的中间件类型
pub mod authorization;
pub mod auth;
pub mod role;
pub mod tenant;
pub mod user;
use crate::services::{AuthService, AuthorizationService, RoleService, TenantService, UserService};
pub use auth::{login_handler, register_handler};
pub use authorization::my_permissions_handler;
pub use role::{create_role_handler, list_roles_handler};
pub use tenant::{
create_tenant_handler, delete_tenant_handler, get_tenant_handler, update_tenant_handler,
update_tenant_status_handler,
};
pub use user::{
delete_user_handler, get_user_handler, list_users_handler, update_user_handler,
};
// 状态对象,包含 Service
#[derive(Clone)]
pub struct AppState {
pub auth_service: AuthService,
}
/// 注册接口
#[utoipa::path(
post,
path = "/auth/register",
request_body = CreateUserRequest,
responses(
(status = 201, description = "User created", body = UserResponse),
(status = 400, description = "Bad request")
),
params(
("X-Tenant-ID" = String, Header, description = "Tenant UUID")
)
)]
pub async fn register_handler(
// 1. 自动注入 TenantId (由中间件解析)
TenantId(tenant_id): TenantId,
// 2. 获取全局状态中的 Service
State(state): State<AppState>,
// 3. 获取 Body
Json(payload): Json<CreateUserRequest>,
) -> Result<Json<UserResponse>, AppError> {
let user = state
.auth_service
.register(tenant_id, payload)
.await
.map_err(AppError::BadRequest)?;
// 转换为 Response DTO (隐藏密码等敏感信息)
let response = UserResponse {
id: user.id,
email: user.email.clone(),
// ...
};
Ok(Json(response))
pub user_service: UserService,
pub role_service: RoleService,
pub tenant_service: TenantService,
pub authorization_service: AuthorizationService,
}

132
src/handlers/role.rs Normal file
View File

@@ -0,0 +1,132 @@
use crate::handlers::AppState;
use crate::middleware::TenantId;
use crate::middleware::auth::AuthContext;
use crate::models::{CreateRoleRequest, RoleResponse};
use axum::{Json, extract::State};
use common_telemetry::{AppError, AppResponse};
use tracing::instrument;
#[utoipa::path(
post,
path = "/roles",
tag = "Role",
security(
("bearer_auth" = [])
),
request_body = CreateRoleRequest,
responses(
(status = 201, description = "角色创建成功", body = RoleResponse),
(status = 400, description = "请求参数错误"),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)")
)
)]
#[instrument(skip(state, payload))]
/// 在当前租户下创建角色。
///
/// 业务规则:
/// - 角色归属到当前租户(由 `TenantId` 决定),禁止跨租户写入。
/// - 需要具备 `role:write` 权限,否则返回 403。
///
/// 输入:
/// - Header `Authorization: Bearer <access_token>`(必填)
/// - Body `CreateRoleRequest`(必填)
///
/// 输出:
/// - `201`:返回新建角色信息(含 `id`
///
/// 异常:
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
/// - `400`:请求参数错误
pub async fn create_role_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
Json(payload): Json<CreateRoleRequest>,
) -> Result<AppResponse<RoleResponse>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "role:write")
.await?;
let role = state.role_service.create_role(tenant_id, payload).await?;
Ok(AppResponse::created(RoleResponse {
id: role.id,
name: role.name,
description: role.description,
}))
}
#[utoipa::path(
get,
path = "/roles",
tag = "Role",
security(
("bearer_auth" = [])
),
responses(
(status = 200, description = "角色列表", body = [RoleResponse]),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)")
)
)]
#[instrument(skip(state))]
/// 查询当前租户下的角色列表。
///
/// 业务规则:
/// - 仅返回当前租户角色;若 `X-Tenant-ID` 与 Token 不一致则返回 403。
/// - 需要具备 `role:read` 权限。
///
/// 输入:
/// - Header `Authorization: Bearer <access_token>`(必填)
///
/// 输出:
/// - `200`:角色列表
///
/// 异常:
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
pub async fn list_roles_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
) -> Result<AppResponse<Vec<RoleResponse>>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "role:read")
.await?;
let roles = state.role_service.list_roles(tenant_id).await?;
let response = roles
.into_iter()
.map(|r| RoleResponse {
id: r.id,
name: r.name,
description: r.description,
})
.collect();
Ok(AppResponse::ok(response))
}

294
src/handlers/tenant.rs Normal file
View File

@@ -0,0 +1,294 @@
use crate::handlers::AppState;
use crate::middleware::TenantId;
use crate::middleware::auth::AuthContext;
use crate::models::{
CreateTenantRequest, TenantResponse, UpdateTenantRequest, UpdateTenantStatusRequest,
};
use axum::{Json, extract::State};
use common_telemetry::{AppError, AppResponse};
use tracing::instrument;
#[utoipa::path(
post,
path = "/tenants/register",
tag = "Tenant",
request_body = CreateTenantRequest,
responses(
(status = 201, description = "租户创建成功", body = TenantResponse),
(status = 400, description = "请求参数错误")
)
)]
#[instrument(skip(state, payload))]
/// 创建租户(公开接口)。
///
/// 业务规则:
/// - 新租户默认 `status=active`。
/// - `config` 未提供时默认 `{}`。
///
/// 输入:
/// - Body `CreateTenantRequest`(必填)
///
/// 输出:
/// - `201`:返回新建租户信息(含 `id`
///
/// 异常:
/// - `400`:请求参数错误
pub async fn create_tenant_handler(
State(state): State<AppState>,
Json(payload): Json<CreateTenantRequest>,
) -> Result<AppResponse<TenantResponse>, AppError> {
let tenant = state.tenant_service.create_tenant(payload).await?;
let response = TenantResponse {
id: tenant.id,
name: tenant.name,
status: tenant.status,
config: tenant.config,
};
Ok(AppResponse::created(response))
}
#[utoipa::path(
get,
path = "/tenants/me",
tag = "Tenant",
security(
("bearer_auth" = [])
),
responses(
(status = 200, description = "获取当前租户信息", body = TenantResponse),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)")
)
)]
#[instrument(skip(state))]
/// 获取当前登录用户所属租户的信息。
///
/// 业务规则:
/// - 若同时提供 `X-Tenant-ID` 与 Token 中租户不一致,返回 403tenant:mismatch
/// - 需要具备 `tenant:read` 权限。
///
/// 输入:
/// - Header `Authorization: Bearer <access_token>`(必填)
/// - Header `X-Tenant-ID`(可选;若提供需与 Token 一致)
///
/// 输出:
/// - `200`:租户信息
///
/// 异常:
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
pub async fn get_tenant_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
) -> Result<AppResponse<TenantResponse>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "tenant:read")
.await?;
let tenant = state.tenant_service.get_tenant(tenant_id).await?;
let response = TenantResponse {
id: tenant.id,
name: tenant.name,
status: tenant.status,
config: tenant.config,
};
Ok(AppResponse::ok(response))
}
#[utoipa::path(
patch,
path = "/tenants/me",
tag = "Tenant",
security(
("bearer_auth" = [])
),
request_body = UpdateTenantRequest,
responses(
(status = 200, description = "租户更新成功", body = TenantResponse),
(status = 400, description = "请求参数错误"),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)")
)
)]
#[instrument(skip(state, payload))]
/// 更新当前租户的基础信息(名称 / 配置)。
///
/// 业务规则:
/// - 只允许更新当前登录租户;租户不一致返回 403。
/// - 需要具备 `tenant:write` 权限。
///
/// 输入:
/// - Header `Authorization: Bearer <access_token>`(必填)
/// - Body `UpdateTenantRequest``name` / `config` 为可选字段,未提供则保持不变
///
/// 输出:
/// - `200`:更新后的租户信息
///
/// 异常:
/// - `400`:请求参数错误
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
pub async fn update_tenant_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
Json(payload): Json<UpdateTenantRequest>,
) -> Result<AppResponse<TenantResponse>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "tenant:write")
.await?;
let tenant = state
.tenant_service
.update_tenant(tenant_id, payload)
.await?;
let response = TenantResponse {
id: tenant.id,
name: tenant.name,
status: tenant.status,
config: tenant.config,
};
Ok(AppResponse::ok(response))
}
#[utoipa::path(
post,
path = "/tenants/me/status",
tag = "Tenant",
security(
("bearer_auth" = [])
),
request_body = UpdateTenantStatusRequest,
responses(
(status = 200, description = "租户状态更新成功", body = TenantResponse),
(status = 400, description = "请求参数错误"),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)")
)
)]
#[instrument(skip(state, payload))]
/// 更新当前租户状态(如 active / disabled
///
/// 业务规则:
/// - 只允许更新当前登录租户;租户不一致返回 403。
/// - 需要具备 `tenant:write` 权限。
///
/// 输入:
/// - Header `Authorization: Bearer <access_token>`(必填)
/// - Body `UpdateTenantStatusRequest.status`(必填)
///
/// 输出:
/// - `200`:更新后的租户信息
///
/// 异常:
/// - `400`:请求参数错误
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
pub async fn update_tenant_status_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
Json(payload): Json<UpdateTenantStatusRequest>,
) -> Result<AppResponse<TenantResponse>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "tenant:write")
.await?;
let tenant = state
.tenant_service
.update_tenant_status(tenant_id, payload)
.await?;
let response = TenantResponse {
id: tenant.id,
name: tenant.name,
status: tenant.status,
config: tenant.config,
};
Ok(AppResponse::ok(response))
}
#[utoipa::path(
delete,
path = "/tenants/me",
tag = "Tenant",
security(
("bearer_auth" = [])
),
responses(
(status = 200, description = "租户删除成功"),
(status = 401, description = "未认证"),
(status = 403, description = "无权限"),
(status = 404, description = "未找到")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)")
)
)]
#[instrument(skip(state))]
/// 删除当前租户。
///
/// 业务规则:
/// - 只允许删除当前登录租户;租户不一致返回 403。
/// - 需要具备 `tenant:write` 权限。
///
/// 输出:
/// - `200`:删除成功(空响应)
///
/// 异常:
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
/// - `404`:租户不存在
pub async fn delete_tenant_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
) -> Result<AppResponse<()>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "tenant:write")
.await?;
state.tenant_service.delete_tenant(tenant_id).await?;
Ok(AppResponse::ok_empty())
}

293
src/handlers/user.rs Normal file
View File

@@ -0,0 +1,293 @@
use crate::handlers::AppState;
use crate::middleware::TenantId;
use crate::middleware::auth::AuthContext;
use crate::models::{UpdateUserRequest, UserResponse};
use axum::{
Json,
extract::{Path, Query, State},
};
use common_telemetry::{AppError, AppResponse};
use serde::Deserialize;
use tracing::instrument;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
pub struct ListUsersQuery {
pub page: Option<u32>,
pub page_size: Option<u32>,
}
#[utoipa::path(
get,
path = "/users",
tag = "User",
security(
("bearer_auth" = [])
),
responses(
(status = 200, description = "用户列表", body = [UserResponse]),
(status = 400, description = "请求参数错误"),
(status = 401, description = "未认证"),
(status = 403, description = "无权限")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)"),
("page" = Option<u32>, Query, description = "页码,默认 1"),
("page_size" = Option<u32>, Query, description = "每页数量,默认 20最大 200")
)
)]
#[instrument(skip(state))]
/// 分页查询当前租户下的用户列表。
///
/// 业务规则:
/// - 仅返回当前租户用户;租户不一致返回 403。
/// - 需要具备 `user:read` 权限。
/// - 分页参数约束:`page>=1``page_size` 范围 `1..=200`。
///
/// 输入:
/// - Header `Authorization: Bearer <access_token>`(必填)
/// - Query `page` / `page_size`(可选)
///
/// 输出:
/// - `200`:用户列表
///
/// 异常:
/// - `400`:分页参数非法
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
pub async fn list_users_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
Query(query): Query<ListUsersQuery>,
) -> Result<AppResponse<Vec<UserResponse>>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "user:read")
.await?;
let page = query.page.unwrap_or(1);
let page_size = query.page_size.unwrap_or(20);
if page == 0 || page_size == 0 || page_size > 200 {
return Err(AppError::BadRequest("Invalid pagination parameters".into()));
}
let users = state
.user_service
.list_users(tenant_id, page, page_size)
.await?;
let response = users
.into_iter()
.map(|u| UserResponse {
id: u.id,
email: u.email,
})
.collect();
Ok(AppResponse::ok(response))
}
#[utoipa::path(
get,
path = "/users/{id}",
tag = "User",
security(
("bearer_auth" = [])
),
responses(
(status = 200, description = "用户详情", body = UserResponse),
(status = 401, description = "未认证"),
(status = 403, description = "无权限"),
(status = 404, description = "未找到")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)"),
("id" = String, Path, description = "用户 UUID")
)
)]
#[instrument(skip(state))]
/// 根据用户 ID 查询用户详情。
///
/// 业务规则:
/// - 仅允许查询当前租户用户;租户不一致返回 403。
/// - 需要具备 `user:read` 权限。
///
/// 输入:
/// - Path `id`:用户 UUID
/// - Header `Authorization: Bearer <access_token>`(必填)
///
/// 输出:
/// - `200`:用户信息
///
/// 异常:
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
/// - `404`:用户不存在
pub async fn get_user_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
Path(target_user_id): Path<Uuid>,
) -> Result<AppResponse<UserResponse>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "user:read")
.await?;
let user = state
.user_service
.get_user_by_id(tenant_id, target_user_id)
.await?;
Ok(AppResponse::ok(UserResponse {
id: user.id,
email: user.email,
}))
}
#[utoipa::path(
patch,
path = "/users/{id}",
tag = "User",
security(
("bearer_auth" = [])
),
request_body = UpdateUserRequest,
responses(
(status = 200, description = "用户更新成功", body = UserResponse),
(status = 400, description = "请求参数错误"),
(status = 401, description = "未认证"),
(status = 403, description = "无权限"),
(status = 404, description = "未找到")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)"),
("id" = String, Path, description = "用户 UUID")
)
)]
#[instrument(skip(state, payload))]
/// 更新指定用户信息(目前支持更新邮箱)。
///
/// 业务规则:
/// - 仅允许更新当前租户用户;租户不一致返回 403。
/// - 需要具备 `user:write` 权限。
/// - `UpdateUserRequest` 中未提供的字段保持不变。
///
/// 输入:
/// - Path `id`:用户 UUID
/// - Header `Authorization: Bearer <access_token>`(必填)
/// - Body `UpdateUserRequest`(必填)
///
/// 输出:
/// - `200`:更新后的用户信息
///
/// 异常:
/// - `400`:请求参数错误
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
/// - `404`:用户不存在
pub async fn update_user_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
Path(target_user_id): Path<Uuid>,
Json(payload): Json<UpdateUserRequest>,
) -> Result<AppResponse<UserResponse>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "user:write")
.await?;
let user = state
.user_service
.update_user(tenant_id, target_user_id, payload)
.await?;
Ok(AppResponse::ok(UserResponse {
id: user.id,
email: user.email,
}))
}
#[utoipa::path(
delete,
path = "/users/{id}",
tag = "User",
security(
("bearer_auth" = [])
),
responses(
(status = 200, description = "用户删除成功"),
(status = 401, description = "未认证"),
(status = 403, description = "无权限"),
(status = 404, description = "未找到")
),
params(
("Authorization" = String, Header, description = "Bearer <access_token>(访问令牌)"),
("X-Tenant-ID" = String, Header, description = "租户 UUID可选若提供需与 Token 中 tenant_id 一致)"),
("id" = String, Path, description = "用户 UUID")
)
)]
#[instrument(skip(state))]
/// 删除指定用户。
///
/// 业务规则:
/// - 仅允许删除当前租户用户;租户不一致返回 403。
/// - 需要具备 `user:write` 权限。
///
/// 输入:
/// - Path `id`:用户 UUID
/// - Header `Authorization: Bearer <access_token>`(必填)
///
/// 输出:
/// - `200`:删除成功(空响应)
///
/// 异常:
/// - `401`:未认证
/// - `403`:租户不匹配或无权限
/// - `404`:用户不存在
pub async fn delete_user_handler(
TenantId(tenant_id): TenantId,
State(state): State<AppState>,
AuthContext {
tenant_id: auth_tenant_id,
user_id,
..
}: AuthContext,
Path(target_user_id): Path<Uuid>,
) -> Result<AppResponse<()>, AppError> {
if auth_tenant_id != tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
state
.authorization_service
.require_permission(tenant_id, user_id, "user:write")
.await?;
state
.user_service
.delete_user(tenant_id, target_user_id)
.await?;
Ok(AppResponse::ok_empty())
}

View File

@@ -1,47 +1,57 @@
mod config;
mod db; // 声明 db 模块
mod docs;
mod handlers;
mod middleware;
mod models;
mod services;
mod utils;
use axum::{Router, middleware::from_fn, routing::post};
use axum::{
Router,
http::StatusCode,
middleware::from_fn,
routing::{get, post},
};
use config::AppConfig;
use handlers::{AppState, register_handler};
use services::AuthService;
use handlers::{
AppState, create_role_handler, create_tenant_handler, delete_tenant_handler,
delete_user_handler, get_tenant_handler, get_user_handler, list_roles_handler,
list_users_handler, login_handler, my_permissions_handler, register_handler,
update_tenant_handler, update_tenant_status_handler, update_user_handler,
};
use services::{AuthService, AuthorizationService, RoleService, TenantService, UserService};
use std::net::SocketAddr;
use utoipa::OpenApi;
use utoipa_scalar::{Scalar, Servable};
// 引入 models 下的所有结构体以生成文档
use common_telemetry::telemetry::{self, TelemetryConfig};
use models::*;
#[derive(OpenApi)]
#[openapi(
paths(handlers::register_handler),
components(schemas(CreateUserRequest, UserResponse)),
tags((name = "auth", description = "Authentication API"))
)]
struct ApiDoc;
use docs::ApiDoc;
#[tokio::main]
async fn main() {
// 1. 加载配置
dotenvy::dotenv().ok();
let config = AppConfig::from_env();
let config = match AppConfig::from_env() {
Ok(c) => c,
Err(e) => {
eprintln!("Failed to load configuration: {}", e);
std::process::exit(1);
}
};
let telemetry_config = TelemetryConfig {
service_name: config.service_name,
log_level: config.log_level,
service_name: config.service_name.clone(),
log_level: config.log_level.clone(),
log_to_file: config.log_to_file,
log_dir: Some(config.log_dir),
log_file: Some(config.log_file_name),
log_dir: Some(config.log_dir.clone()),
log_file: Some(config.log_file_name.clone()),
};
// 2. 初始化 Tracing
let _guard = telemetry::init(telemetry_config);
// 3. 初始化数据库 (使用 db 模块)
let pool = match db::init_pool(&config.database_url).await {
let pool = match db::init_pool(&config).await {
Ok(p) => p,
Err(e) => {
// 记录到日志文件和控制台
@@ -57,15 +67,60 @@ async fn main() {
// 4. 初始化 Service 和 AppState
let auth_service = AuthService::new(pool.clone(), config.jwt_secret.clone());
let state = AppState { auth_service };
let user_service = UserService::new(pool.clone());
let role_service = RoleService::new(pool.clone());
let tenant_service = TenantService::new(pool.clone());
let authorization_service = AuthorizationService::new(pool.clone());
let state = AppState {
auth_service,
user_service,
role_service,
tenant_service,
authorization_service,
};
// 5. 构建路由
let app = Router::new()
.route("/auth/register", post(register_handler))
// 挂载多租户中间件
let api = Router::new()
.route("/tenants/register", post(create_tenant_handler))
.route(
"/tenants/me",
get(get_tenant_handler)
.patch(update_tenant_handler)
.delete(delete_tenant_handler),
)
.route("/tenants/me/status", post(update_tenant_status_handler))
.route(
"/auth/register",
post(register_handler)
.layer(middleware::rate_limit::register_rate_limiter())
.layer(from_fn(middleware::rate_limit::log_rate_limit_register)),
)
.route(
"/auth/login",
post(login_handler)
.layer(middleware::rate_limit::login_rate_limiter())
.layer(from_fn(middleware::rate_limit::log_rate_limit_login)),
)
.route("/me/permissions", get(my_permissions_handler))
.route("/users", get(list_users_handler))
.route(
"/users/{id}",
get(get_user_handler)
.patch(update_user_handler)
.delete(delete_user_handler),
)
.route("/roles", get(list_roles_handler).post(create_role_handler))
.layer(from_fn(middleware::resolve_tenant))
// 挂载 Scalar 文档
.layer(from_fn(middleware::auth::authenticate))
.layer(from_fn(
common_telemetry::axum_middleware::trace_http_request,
));
let app = Router::new()
.route("/favicon.ico", get(|| async { StatusCode::NO_CONTENT }))
.merge(Scalar::with_url("/scalar", ApiDoc::openapi()))
.merge(api)
.with_state(state);
// 6. 启动服务器
@@ -74,5 +129,10 @@ async fn main() {
tracing::info!("📄 Docs available at http://{}/scalar", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
}

67
src/middleware/auth.rs Normal file
View File

@@ -0,0 +1,67 @@
use axum::{
extract::{FromRequestParts, Request},
http::request::Parts,
middleware::Next,
response::Response,
};
use common_telemetry::AppError;
use uuid::Uuid;
#[derive(Clone, Debug)]
pub struct AuthContext {
pub tenant_id: Uuid,
pub user_id: Uuid,
pub roles: Vec<String>,
pub permissions: Vec<String>,
}
pub async fn authenticate(mut req: Request, next: Next) -> Result<Response, AppError> {
let path = req.uri().path();
if path.starts_with("/scalar")
|| path == "/tenants/register"
|| path == "/auth/register"
|| path == "/auth/login"
{
return Ok(next.run(req).await);
}
let token = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.ok_or(AppError::MissingAuthHeader)?;
let claims = crate::utils::verify(token)?;
let tenant_id = Uuid::parse_str(&claims.tenant_id)
.map_err(|_| AppError::AuthError("Invalid tenant_id claim".into()))?;
let user_id = Uuid::parse_str(&claims.sub)
.map_err(|_| AppError::AuthError("Invalid sub claim".into()))?;
tracing::Span::current().record("tenant_id", tracing::field::display(tenant_id));
tracing::Span::current().record("user_id", tracing::field::display(user_id));
req.extensions_mut().insert(AuthContext {
tenant_id,
user_id,
roles: claims.roles,
permissions: claims.permissions,
});
Ok(next.run(req).await)
}
impl<S> FromRequestParts<S> for AuthContext
where
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<AuthContext>()
.cloned()
.ok_or(AppError::MissingAuthHeader)
}
}

View File

@@ -1,5 +1,9 @@
pub mod auth;
pub mod rate_limit;
use axum::extract::FromRequestParts;
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
use axum::{extract::Request, middleware::Next, response::Response};
use common_telemetry::AppError;
use http::request::Parts;
use uuid::Uuid;
@@ -8,7 +12,32 @@ use uuid::Uuid;
#[derive(Clone, Debug)] // 这是一个类型安全的 Wrapper用于在 Handler 中注入
pub struct TenantId(pub Uuid);
pub async fn resolve_tenant(mut req: Request, next: Next) -> Result<Response, StatusCode> {
pub async fn resolve_tenant(mut req: Request, next: Next) -> Result<Response, AppError> {
let path = req.uri().path();
if path.starts_with("/scalar") || path == "/tenants/register" {
return Ok(next.run(req).await);
}
if let Some(auth_tenant_id) = req
.extensions()
.get::<auth::AuthContext>()
.map(|ctx| ctx.tenant_id)
{
if let Some(header_value) = req
.headers()
.get("X-Tenant-ID")
.and_then(|val| val.to_str().ok())
{
let header_tenant_id = Uuid::parse_str(header_value)
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into()))?;
if header_tenant_id != auth_tenant_id {
return Err(AppError::PermissionDenied("tenant:mismatch".into()));
}
}
tracing::Span::current().record("tenant_id", tracing::field::display(auth_tenant_id));
req.extensions_mut().insert(TenantId(auth_tenant_id));
return Ok(next.run(req).await);
}
// 尝试从 Header 获取 X-Tenant-ID
let tenant_id_str = req
.headers()
@@ -18,17 +47,18 @@ pub async fn resolve_tenant(mut req: Request, next: Next) -> Result<Response, St
match tenant_id_str {
Some(id_str) => {
if let Ok(uuid) = Uuid::parse_str(id_str) {
tracing::Span::current().record("tenant_id", tracing::field::display(uuid));
// 验证成功,注入到 Extension 中
req.extensions_mut().insert(TenantId(uuid));
Ok(next.run(req).await)
} else {
Err(StatusCode::BAD_REQUEST) // ID 格式错误
Err(AppError::BadRequest("Invalid X-Tenant-ID format".into()))
}
}
None => {
// 如果是公开接口(如登录注册),可能不需要 TenantID视业务而定
// 这里假设严格模式,必须带 TenantID
Err(StatusCode::BAD_REQUEST)
Err(AppError::BadRequest("Missing X-Tenant-ID header".into()))
}
}
}
@@ -38,9 +68,12 @@ impl<S> FromRequestParts<S> for TenantId
where
S: Send + Sync,
{
type Rejection = StatusCode;
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(tid) = parts.extensions.get::<TenantId>().cloned() {
return Ok(tid);
}
let tenant_id_str = parts
.headers
.get("X-Tenant-ID")
@@ -49,8 +82,8 @@ where
match tenant_id_str {
Some(id_str) => uuid::Uuid::parse_str(id_str)
.map(TenantId)
.map_err(|_| StatusCode::BAD_REQUEST),
None => Err(StatusCode::BAD_REQUEST),
.map_err(|_| AppError::BadRequest("Invalid X-Tenant-ID format".into())),
None => Err(AppError::BadRequest("Missing X-Tenant-ID header".into())),
}
}
}

View File

@@ -0,0 +1,369 @@
use axum::{
extract::Request,
http::{HeaderValue, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use common_telemetry::AppError;
use governor::middleware::StateInformationMiddleware;
use ipnet::IpNet;
use std::net::IpAddr;
use std::sync::OnceLock;
use std::time::Duration;
use tower_governor::GovernorLayer;
use tower_governor::errors::GovernorError;
use tower_governor::governor::GovernorConfigBuilder;
use tower_governor::key_extractor::{KeyExtractor, PeerIpKeyExtractor, SmartIpKeyExtractor};
use tracing::Instrument;
#[derive(Clone, Debug)]
pub(crate) struct TrustedProxySmartIpKeyExtractor {
trusted_proxies: Vec<IpNet>,
}
impl TrustedProxySmartIpKeyExtractor {
fn from_env() -> Self {
static TRUSTED: OnceLock<Vec<IpNet>> = OnceLock::new();
let trusted_proxies = TRUSTED
.get_or_init(|| {
let raw = std::env::var("TRUSTED_PROXY_CIDRS").unwrap_or_default();
if raw.trim().is_empty() {
return Vec::new();
}
raw.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| {
s.parse::<IpNet>()
.unwrap_or_else(|_| panic!("Invalid TRUSTED_PROXY_CIDRS entry: {s}"))
})
.collect::<Vec<_>>()
})
.clone();
Self { trusted_proxies }
}
fn is_trusted_proxy(&self, peer_ip: IpAddr) -> bool {
self.trusted_proxies
.iter()
.any(|cidr| cidr.contains(&peer_ip))
}
}
impl KeyExtractor for TrustedProxySmartIpKeyExtractor {
type Key = IpAddr;
fn extract<T>(&self, req: &http::Request<T>) -> Result<Self::Key, GovernorError> {
let peer_ip = PeerIpKeyExtractor.extract(req)?;
if self.is_trusted_proxy(peer_ip) {
SmartIpKeyExtractor.extract(req)
} else {
Ok(peer_ip)
}
}
}
fn login_policy() -> (&'static str, Duration, u32) {
("auth.login", Duration::from_millis(500), 10)
}
fn register_policy() -> (&'static str, Duration, u32) {
("auth.register", Duration::from_secs(1), 5)
}
fn governor_headers(err: &GovernorError) -> Option<http::HeaderMap> {
match err {
GovernorError::TooManyRequests { headers, .. } => headers.clone(),
GovernorError::Other { headers, .. } => headers.clone(),
GovernorError::UnableToExtractKey => None,
}
}
fn governor_wait_time_seconds(err: &GovernorError) -> Option<u64> {
match err {
GovernorError::TooManyRequests { wait_time, .. } => Some(*wait_time),
_ => None,
}
}
fn rate_limit_error_response(err: GovernorError) -> Response {
let mut resp = AppError::RateLimitExceeded.into_response();
if let Some(headers) = governor_headers(&err) {
resp.headers_mut().extend(headers);
}
if let Some(wait_time) = governor_wait_time_seconds(&err) {
resp.headers_mut().insert(
http::header::RETRY_AFTER,
HeaderValue::from_str(&wait_time.to_string()).unwrap_or(HeaderValue::from_static("1")),
);
}
resp
}
fn header_u64(resp: &Response, name: &'static str) -> Option<u64> {
resp.headers()
.get(name)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
}
fn ip_for_log(req: &Request) -> Option<IpAddr> {
TrustedProxySmartIpKeyExtractor::from_env()
.extract(req)
.ok()
}
async fn log_rate_limited(
policy: (&'static str, Duration, u32),
req: Request,
next: Next,
) -> Response {
let method = req.method().clone();
let path = req.uri().path().to_string();
let client_ip = ip_for_log(&req);
let auth = req
.extensions()
.get::<crate::middleware::auth::AuthContext>()
.cloned();
let resp = next.run(req).await;
if resp.status() != StatusCode::TOO_MANY_REQUESTS {
return resp;
}
let (policy_name, period, burst_size) = policy;
let retry_after = header_u64(&resp, "retry-after");
let limit = header_u64(&resp, "x-ratelimit-limit");
let remaining = header_u64(&resp, "x-ratelimit-remaining");
let wait = header_u64(&resp, "x-ratelimit-after").or(retry_after);
let used = match (limit, remaining) {
(Some(l), Some(r)) if l >= r => Some(l - r),
_ => None,
};
let span = tracing::Span::current();
async move {
tracing::error!(
event = "rate_limit_triggered",
policy = policy_name,
method = %method,
path = %path,
client_ip = %client_ip.map(|ip| ip.to_string()).unwrap_or_else(|| "unknown".into()),
tenant_id = %auth.as_ref().map(|a| a.tenant_id.to_string()).unwrap_or_else(|| "unknown".into()),
user_id = %auth.as_ref().map(|a| a.user_id.to_string()).unwrap_or_else(|| "anonymous".into()),
burst_size = burst_size,
period_ms = period.as_millis() as u64,
retry_after_s = retry_after,
limit = limit,
remaining = remaining,
used = used,
wait_s = wait
);
}
.instrument(span)
.await;
resp
}
pub async fn log_rate_limit_login(req: Request, next: Next) -> Response {
log_rate_limited(login_policy(), req, next).await
}
pub async fn log_rate_limit_register(req: Request, next: Next) -> Response {
log_rate_limited(register_policy(), req, next).await
}
pub fn login_rate_limiter()
-> GovernorLayer<TrustedProxySmartIpKeyExtractor, StateInformationMiddleware, axum::body::Body> {
let (policy, period, burst) = login_policy();
let mut config = GovernorConfigBuilder::default();
config.period(period);
config.burst_size(burst);
let mut config = config.use_headers();
let config = config
.key_extractor(TrustedProxySmartIpKeyExtractor::from_env())
.finish()
.unwrap_or_else(|| panic!("failed to build rate limiter config: {policy}"));
GovernorLayer::new(config).error_handler(rate_limit_error_response)
}
pub fn register_rate_limiter()
-> GovernorLayer<TrustedProxySmartIpKeyExtractor, StateInformationMiddleware, axum::body::Body> {
let (policy, period, burst) = register_policy();
let mut config = GovernorConfigBuilder::default();
config.period(period);
config.burst_size(burst);
let mut config = config.use_headers();
let config = config
.key_extractor(TrustedProxySmartIpKeyExtractor::from_env())
.finish()
.unwrap_or_else(|| panic!("failed to build rate limiter config: {policy}"));
GovernorLayer::new(config).error_handler(rate_limit_error_response)
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{
Router,
extract::ConnectInfo,
http::{Request, StatusCode},
middleware::from_fn,
routing::post,
};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use tower::ServiceExt;
async fn ok_handler() -> StatusCode {
StatusCode::OK
}
#[tokio::test]
async fn login_rate_limiter_eventually_returns_429() {
let app = Router::new().route("/auth/login", post(ok_handler).layer(login_rate_limiter()));
let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
let mut saw_429 = false;
for _ in 0..32 {
let mut req = Request::builder()
.method("POST")
.uri("/auth/login")
.body(axum::body::Body::empty())
.unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
let resp = app.clone().oneshot(req).await.unwrap();
if resp.status() == StatusCode::TOO_MANY_REQUESTS {
saw_429 = true;
break;
}
}
assert!(saw_429);
}
#[tokio::test]
async fn register_rate_limiter_allows_burst_then_limits() {
let app = Router::new().route(
"/auth/register",
post(ok_handler).layer(register_rate_limiter()),
);
let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
for _ in 0..5 {
let mut req = Request::builder()
.method("POST")
.uri("/auth/register")
.body(axum::body::Body::empty())
.unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
let mut req = Request::builder()
.method("POST")
.uri("/auth/register")
.body(axum::body::Body::empty())
.unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
}
#[tokio::test]
async fn register_rate_limiter_recovers_after_wait() {
let app = Router::new().route(
"/auth/register",
post(ok_handler).layer(register_rate_limiter()),
);
let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
for _ in 0..6 {
let mut req = Request::builder()
.method("POST")
.uri("/auth/register")
.body(axum::body::Body::empty())
.unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
let _ = app.clone().oneshot(req).await.unwrap();
}
tokio::time::sleep(Duration::from_millis(1100)).await;
let mut req = Request::builder()
.method("POST")
.uri("/auth/register")
.body(axum::body::Body::empty())
.unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn rate_limit_log_contains_request_context() {
struct BufferWriter(Arc<Mutex<Vec<u8>>>);
impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for BufferWriter {
type Writer = BufferGuard;
fn make_writer(&'a self) -> Self::Writer {
BufferGuard(self.0.clone())
}
}
struct BufferGuard(Arc<Mutex<Vec<u8>>>);
impl std::io::Write for BufferGuard {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.lock().unwrap().extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
let buf = Arc::new(Mutex::new(Vec::<u8>::new()));
let subscriber = tracing_subscriber::fmt()
.with_writer(BufferWriter(buf.clone()))
.with_ansi(false)
.json()
.finish();
let app = Router::new().route(
"/auth/register",
post(ok_handler)
.layer(register_rate_limiter())
.layer(from_fn(log_rate_limit_register)),
);
let _guard = tracing::subscriber::set_default(subscriber);
{
let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
for _ in 0..6 {
let mut req = Request::builder()
.method("POST")
.uri("/auth/register")
.body(axum::body::Body::empty())
.unwrap();
req.extensions_mut().insert(ConnectInfo(addr));
let _ = app.clone().oneshot(req).await.unwrap();
}
}
let s = String::from_utf8(buf.lock().unwrap().clone()).unwrap();
assert!(s.contains("\"event\":\"rate_limit_triggered\""));
assert!(s.contains("\"path\":\"/auth/register\""));
assert!(s.contains("\"method\":\"POST\""));
assert!(s.contains("\"policy\":\"auth.register\""));
assert!(s.contains("\"client_ip\":\"127.0.0.1\""));
}
}

View File

@@ -1,30 +1,194 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::FromRow;
use utoipa::ToSchema;
use utoipa::{IntoParams, ToSchema};
use uuid::Uuid; // 关键引入
#[derive(Debug, Serialize, FromRow, ToSchema)]
pub struct User {
#[schema(example = "550e8400-e29b-41d4-a716-446655440000")]
pub id: Uuid,
#[schema(example = "11111111-1111-1111-1111-111111111111")]
pub tenant_id: Uuid,
#[schema(example = "user@example.com")]
pub email: String,
#[schema(ignore)] // 不在文档中显示密码哈希
pub password_hash: String,
#[allow(dead_code)]
fn default_uuid() -> Uuid {
Uuid::nil()
}
#[derive(Debug, Deserialize, ToSchema)]
#[allow(dead_code)]
fn default_json_object() -> Value {
Value::Object(Default::default())
}
#[allow(dead_code)]
fn default_token_type() -> String {
"Bearer".to_string()
}
#[derive(Debug, Serialize, FromRow, ToSchema, IntoParams)]
pub struct User {
#[schema(example = "550e8400-e29b-41d4-a716-446655440000")]
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub id: Uuid,
#[schema(example = "11111111-1111-1111-1111-111111111111")]
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub tenant_id: Uuid,
#[schema(example = "user@example.com")]
#[schema(default = "")]
#[serde(default)]
pub email: String,
#[schema(ignore)] // 不在文档中显示密码哈希
#[serde(default)]
pub password_hash: String,
// created_at, updated_at, status etc. could be added later
}
#[derive(Debug, Deserialize, ToSchema, IntoParams)]
pub struct CreateUserRequest {
#[schema(example = "user@example.com")]
#[schema(default = "")]
#[serde(default)]
pub email: String,
#[schema(example = "securePassword123")]
#[schema(default = "")]
#[serde(default)]
pub password: String,
}
#[derive(Debug, Serialize, ToSchema)]
#[derive(Debug, Deserialize, ToSchema, IntoParams)]
pub struct UpdateUserRequest {
#[schema(example = "new_email@example.com")]
#[serde(default)]
pub email: Option<String>,
// Add other fields like name, phone, etc.
}
#[derive(Debug, Serialize, ToSchema, IntoParams)]
pub struct UserResponse {
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub id: Uuid,
#[schema(default = "")]
#[serde(default)]
pub email: String,
}
// --- Auth Related Models ---
#[derive(Debug, Deserialize, ToSchema, IntoParams)]
pub struct LoginRequest {
#[schema(example = "user@example.com")]
#[schema(default = "")]
#[serde(default)]
pub email: String,
#[schema(example = "securePassword123")]
#[schema(default = "")]
#[serde(default)]
pub password: String,
}
#[derive(Debug, Serialize, ToSchema, IntoParams)]
pub struct LoginResponse {
#[schema(default = "")]
#[serde(default)]
pub access_token: String,
#[schema(default = "")]
#[serde(default)]
pub refresh_token: String,
#[schema(default = "Bearer", example = "Bearer")]
#[serde(default = "default_token_type")]
pub token_type: String,
#[schema(default = 0)]
#[serde(default)]
pub expires_in: usize,
}
// --- Role Related Models ---
#[derive(Debug, Serialize, FromRow, ToSchema, IntoParams)]
pub struct Role {
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub id: Uuid,
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub tenant_id: Uuid,
#[schema(default = "")]
#[serde(default)]
pub name: String,
#[serde(default)]
pub description: Option<String>,
}
#[derive(Debug, Deserialize, ToSchema, IntoParams)]
pub struct CreateRoleRequest {
#[serde(default)]
pub name: String,
#[serde(default)]
pub description: Option<String>,
}
#[derive(Debug, Serialize, ToSchema, IntoParams)]
pub struct RoleResponse {
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub id: Uuid,
#[schema(default = "")]
#[serde(default)]
pub name: String,
#[serde(default)]
pub description: Option<String>,
}
// --- Tenant Related Models ---
#[derive(Debug, Serialize, FromRow, ToSchema, IntoParams)]
pub struct Tenant {
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub id: Uuid,
#[schema(default = "")]
#[serde(default)]
pub name: String,
#[schema(default = "active", example = "active")]
#[serde(default)]
pub status: String,
#[schema(default = default_json_object, example = default_json_object)]
#[serde(default = "default_json_object")]
pub config: Value,
}
#[derive(Debug, Deserialize, ToSchema, IntoParams)]
pub struct CreateTenantRequest {
#[serde(default)]
pub name: String,
#[serde(default)]
pub config: Option<Value>,
}
#[derive(Debug, Deserialize, ToSchema, IntoParams)]
pub struct UpdateTenantRequest {
#[serde(default)]
pub name: Option<String>,
#[serde(default)]
pub config: Option<Value>,
}
#[derive(Debug, Deserialize, ToSchema, IntoParams)]
pub struct UpdateTenantStatusRequest {
#[schema(default = "active", example = "active")]
#[serde(default)]
pub status: String,
}
#[derive(Debug, Serialize, ToSchema, IntoParams)]
pub struct TenantResponse {
#[schema(default = "00000000-0000-0000-0000-000000000000")]
#[serde(default = "default_uuid")]
pub id: Uuid,
#[schema(default = "")]
#[serde(default)]
pub name: String,
#[schema(default = "active", example = "active")]
#[serde(default)]
pub status: String,
#[schema(default = default_json_object, example = default_json_object)]
#[serde(default = "default_json_object")]
pub config: Value,
}

222
src/services/auth.rs Normal file
View File

@@ -0,0 +1,222 @@
use crate::models::{CreateUserRequest, LoginRequest, LoginResponse, User};
use crate::utils::{hash_password, sign, verify_password};
use common_telemetry::AppError;
use rand::RngCore;
use sqlx::PgPool;
use tracing::instrument;
use uuid::Uuid;
#[derive(Clone)]
pub struct AuthService {
pool: PgPool,
// jwt_secret removed, using RS256 keys
}
impl AuthService {
/// 创建认证服务实例。
///
/// 说明:
/// - 当前实现使用 RS256 密钥对进行 JWT 签发与校验,因此 `_jwt_secret` 参数仅为兼容保留。
pub fn new(pool: PgPool, _jwt_secret: String) -> Self {
Self { pool }
}
// 注册业务
#[instrument(skip(self, req))]
/// 在指定租户下注册新用户,并在首次注册时自动引导初始化租户管理员权限。
///
/// 业务规则:
/// - 用户必须绑定到 `tenant_id`,禁止跨租户注册写入。
/// - 密码以安全哈希形式存储,不回传明文。
/// - 若该租户用户数为 1首个用户自动创建/获取 `Admin` 系统角色并授予全量权限,同时绑定到该用户。
///
/// 输入:
/// - `tenant_id`:目标租户
/// - `req.email` / `req.password`:注册信息
///
/// 输出:
/// - 返回创建后的 `User` 记录(包含 `id/tenant_id/email` 等字段)
///
/// 异常:
/// - 数据库写入失败(如唯一约束冲突、连接错误等)
/// - 密码哈希失败
pub async fn register(
&self,
tenant_id: Uuid,
req: CreateUserRequest,
) -> Result<User, AppError> {
let mut tx = self.pool.begin().await?;
// 1. 哈希密码
let hashed =
hash_password(&req.password).map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?;
// 2. 存入数据库 (带上 tenant_id)
let query = r#"
INSERT INTO users (tenant_id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING id, tenant_id, email, password_hash
"#;
let user = sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(&req.email)
.bind(&hashed)
.fetch_one(&mut *tx)
.await?;
let user_count: i64 = sqlx::query_scalar("SELECT COUNT(1) FROM users WHERE tenant_id = $1")
.bind(tenant_id)
.fetch_one(&mut *tx)
.await?;
if user_count == 1 {
self.bootstrap_tenant_admin(&mut tx, tenant_id, user.id)
.await?;
}
tx.commit().await?;
Ok(user)
}
// 登录业务
#[instrument(skip(self, req))]
/// 在指定租户内完成用户认证并签发访问令牌与刷新令牌。
///
/// 业务规则:
/// - 仅在当前租户内按 `email` 查找用户,防止跨租户登录。
/// - 密码校验失败返回 `InvalidCredentials`。
/// - 登录成功后生成:
/// - `access_token`JWT包含租户、用户、角色与权限
/// - `refresh_token`:随机生成并哈希后入库(默认 30 天过期)
///
/// 输出:
/// - `LoginResponse`token_type 固定为 `Bearer``expires_in` 当前为 15 分钟)
///
/// 异常:
/// - 用户不存在404
/// - 密码错误401
/// - Token 签发失败或数据库写入失败
pub async fn login(
&self,
tenant_id: Uuid,
req: LoginRequest,
) -> Result<LoginResponse, AppError> {
// 1. 查找用户 (带 tenant_id 防止跨租户登录)
let query = "SELECT * FROM users WHERE tenant_id = $1 AND email = $2";
let user = sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(&req.email)
.fetch_optional(&self.pool)
.await?
.ok_or(AppError::NotFound("User not found".into()))?;
// 2. 验证密码
if !verify_password(&req.password, &user.password_hash) {
return Err(AppError::InvalidCredentials);
}
let roles = sqlx::query_scalar::<_, String>(
r#"
SELECT r.name
FROM roles r
JOIN user_roles ur ON ur.role_id = r.id
WHERE r.tenant_id = $1 AND ur.user_id = $2
"#,
)
.bind(user.tenant_id)
.bind(user.id)
.fetch_all(&self.pool)
.await?;
let permissions = sqlx::query_scalar::<_, String>(
r#"
SELECT DISTINCT p.code
FROM permissions p
JOIN role_permissions rp ON rp.permission_id = p.id
JOIN user_roles ur ON ur.role_id = rp.role_id
JOIN roles r ON r.id = ur.role_id
WHERE r.tenant_id = $1 AND ur.user_id = $2
"#,
)
.bind(user.tenant_id)
.bind(user.id)
.fetch_all(&self.pool)
.await?;
// 3. 签发 Access Token
let access_token = sign(user.id, user.tenant_id, roles, permissions)?;
// 4. 生成 Refresh Token
let mut refresh_bytes = [0u8; 32];
rand::rng().fill_bytes(&mut refresh_bytes);
let refresh_token = hex::encode(refresh_bytes);
// Hash refresh token for storage
let refresh_token_hash =
hash_password(&refresh_token).map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?;
// 5. 存储 Refresh Token (30天过期)
let expires_at = chrono::Utc::now() + chrono::Duration::days(30);
sqlx::query(
"INSERT INTO refresh_tokens (user_id, token_hash, expires_at) VALUES ($1, $2, $3)",
)
.bind(user.id)
.bind(refresh_token_hash)
.bind(expires_at)
.execute(&self.pool)
.await?;
Ok(LoginResponse {
access_token,
refresh_token,
token_type: "Bearer".to_string(),
expires_in: 15 * 60, // 15 mins
})
}
async fn bootstrap_tenant_admin(
&self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
tenant_id: Uuid,
user_id: Uuid,
) -> Result<(), AppError> {
let role_id: Uuid = sqlx::query_scalar(
r#"
INSERT INTO roles (tenant_id, name, description, is_system)
VALUES ($1, 'Admin', 'Tenant administrator', TRUE)
ON CONFLICT (tenant_id, name)
DO UPDATE SET name = EXCLUDED.name
RETURNING id
"#,
)
.bind(tenant_id)
.fetch_one(&mut **tx)
.await?;
sqlx::query(
r#"
INSERT INTO role_permissions (role_id, permission_id)
SELECT $1, p.id FROM permissions p
ON CONFLICT DO NOTHING
"#,
)
.bind(role_id)
.execute(&mut **tx)
.await?;
sqlx::query(
r#"
INSERT INTO user_roles (user_id, role_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING
"#,
)
.bind(user_id)
.bind(role_id)
.execute(&mut **tx)
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,79 @@
use common_telemetry::AppError;
use sqlx::PgPool;
use tracing::instrument;
use uuid::Uuid;
#[derive(Clone)]
pub struct AuthorizationService {
pool: PgPool,
}
impl AuthorizationService {
/// 创建权限服务实例。
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
#[instrument(skip(self))]
/// 获取用户在指定租户下的权限编码集合(去重)。
///
/// 说明:
/// - 权限来源于用户所属角色user_roles → roles及角色绑定权限role_permissions → permissions
///
/// 输入:
/// - `tenant_id`:租户 ID
/// - `user_id`:用户 ID
///
/// 输出:
/// - 权限编码数组(如 `tenant:read` / `user:write`
///
/// 异常:
/// - 数据库查询失败
pub async fn list_permissions_for_user(
&self,
tenant_id: Uuid,
user_id: Uuid,
) -> Result<Vec<String>, AppError> {
let query = r#"
SELECT DISTINCT p.code
FROM permissions p
JOIN role_permissions rp ON rp.permission_id = p.id
JOIN user_roles ur ON ur.role_id = rp.role_id
JOIN roles r ON r.id = ur.role_id
WHERE r.tenant_id = $1 AND ur.user_id = $2
"#;
let rows = sqlx::query_scalar::<_, String>(query)
.bind(tenant_id)
.bind(user_id)
.fetch_all(&self.pool)
.await?;
Ok(rows)
}
#[instrument(skip(self))]
/// 校验用户是否具备指定权限,不满足则直接返回权限拒绝错误。
///
/// 业务规则:
/// - 若用户权限集合中不包含 `permission_code`,返回 `PermissionDenied(permission_code)`。
///
/// 输入:
/// - `tenant_id`:租户 ID
/// - `user_id`:用户 ID
/// - `permission_code`:权限编码
///
/// 输出:
/// - 成功返回 `()`;失败返回权限拒绝错误
pub async fn require_permission(
&self,
tenant_id: Uuid,
user_id: Uuid,
permission_code: &str,
) -> Result<(), AppError> {
let permissions = self.list_permissions_for_user(tenant_id, user_id).await?;
if permissions.iter().any(|p| p == permission_code) {
Ok(())
} else {
Err(AppError::PermissionDenied(permission_code.to_string()))
}
}
}

View File

@@ -1,69 +1,11 @@
use crate::models::{CreateUserRequest, User}; // 假设你在 models 定义了这些
use crate::utils::{create_jwt, hash_password, verify_password};
use axum::Json;
use sqlx::PgPool;
use uuid::Uuid;
pub mod auth;
pub mod authorization;
pub mod role;
pub mod tenant;
pub mod user;
#[derive(Clone)]
pub struct AuthService {
pool: PgPool,
jwt_secret: String,
}
impl AuthService {
pub fn new(pool: PgPool, jwt_secret: String) -> Self {
Self { pool, jwt_secret }
}
// 注册业务
pub async fn register(
&self,
tenant_id: Uuid,
req: CreateUserRequest,
) -> Result<Json<User>, String> {
// 1. 哈希密码
let hashed = hash_password(&req.password)?;
// 2. 存入数据库 (带上 tenant_id)
let query = r#"
INSERT INTO users (tenant_id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING id, tenant_id, email, password_hash, created_at
"#;
let user = sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(&req.email)
.bind(&hashed)
.fetch_one(&self.pool)
.await
.map_err(|e| e.to_string())?;
Ok(Json(user))
}
// 登录业务
pub async fn login(
&self,
tenant_id: Uuid,
email: &str,
password: &str,
) -> Result<String, String> {
// 1. 查找用户 (带 tenant_id 防止跨租户登录)
let query = "SELECT * FROM users WHERE tenant_id = $1 AND email = $2";
let user = sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(email)
.fetch_optional(&self.pool)
.await
.map_err(|e| e.to_string())?
.ok_or("User not found")?;
// 2. 验证密码
if !verify_password(password, &user.password_hash) {
return Err("Invalid password".to_string());
}
// 3. 签发 Token
create_jwt(user.id, user.tenant_id, &self.jwt_secret)
}
}
pub use auth::AuthService;
pub use authorization::AuthorizationService;
pub use role::RoleService;
pub use tenant::TenantService;
pub use user::UserService;

59
src/services/role.rs Normal file
View File

@@ -0,0 +1,59 @@
use crate::models::{CreateRoleRequest, Role};
use common_telemetry::AppError;
use sqlx::PgPool;
use tracing::instrument;
use uuid::Uuid;
#[derive(Clone)]
pub struct RoleService {
pool: PgPool,
}
impl RoleService {
/// 创建角色服务实例。
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
#[instrument(skip(self))]
/// 在指定租户下创建角色记录。
///
/// 业务规则:
/// - 角色与租户强绑定(写入时携带 `tenant_id`)。
///
/// 异常:
/// - 数据库写入失败(如约束冲突、连接错误等)
pub async fn create_role(
&self,
tenant_id: Uuid,
req: CreateRoleRequest,
) -> Result<Role, AppError> {
let query = r#"
INSERT INTO roles (tenant_id, name, description)
VALUES ($1, $2, $3)
RETURNING id, tenant_id, name, description
"#;
// Note: 'roles' table needs to be created in DB
sqlx::query_as::<_, Role>(query)
.bind(tenant_id)
.bind(req.name)
.bind(req.description)
.fetch_one(&self.pool)
.await
.map_err(|e| AppError::DbError(e))
}
#[instrument(skip(self))]
/// 查询指定租户下的角色列表。
///
/// 异常:
/// - 数据库查询失败
pub async fn list_roles(&self, tenant_id: Uuid) -> Result<Vec<Role>, AppError> {
let query = "SELECT * FROM roles WHERE tenant_id = $1";
sqlx::query_as::<_, Role>(query)
.bind(tenant_id)
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::DbError(e))
}
}

134
src/services/tenant.rs Normal file
View File

@@ -0,0 +1,134 @@
use crate::models::{
CreateTenantRequest, Tenant, UpdateTenantRequest, UpdateTenantStatusRequest,
};
use common_telemetry::AppError;
use serde_json::Value;
use sqlx::PgPool;
use tracing::instrument;
use uuid::Uuid;
#[derive(Clone)]
pub struct TenantService {
pool: PgPool,
}
impl TenantService {
/// 创建租户服务实例。
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
#[instrument(skip(self, req))]
/// 创建新租户并初始化默认状态与配置。
///
/// 业务规则:
/// - 默认 `status=active`。
/// - `config` 未提供时默认 `{}`。
///
/// 输出:
/// - 返回新建租户记录(含 `id`
///
/// 异常:
/// - 数据库写入失败(如连接异常、约束失败等)
pub async fn create_tenant(&self, req: CreateTenantRequest) -> Result<Tenant, AppError> {
let config = req.config.unwrap_or_else(|| Value::Object(Default::default()));
let query = r#"
INSERT INTO tenants (name, status, config)
VALUES ($1, 'active', $2)
RETURNING id, name, status, config
"#;
let tenant = sqlx::query_as::<_, Tenant>(query)
.bind(req.name)
.bind(config)
.fetch_one(&self.pool)
.await?;
Ok(tenant)
}
#[instrument(skip(self))]
/// 根据租户 ID 查询租户信息。
///
/// 异常:
/// - 若租户不存在,返回 `NotFound("Tenant not found")`。
pub async fn get_tenant(&self, tenant_id: Uuid) -> Result<Tenant, AppError> {
let query = "SELECT id, name, status, config FROM tenants WHERE id = $1";
sqlx::query_as::<_, Tenant>(query)
.bind(tenant_id)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| AppError::NotFound("Tenant not found".into()))
}
#[instrument(skip(self, req))]
/// 更新租户基础信息(名称 / 配置)。
///
/// 说明:
/// - 仅更新 `UpdateTenantRequest` 中提供的字段,未提供字段保持不变。
///
/// 异常:
/// - 若租户不存在,返回 `NotFound("Tenant not found")`。
pub async fn update_tenant(
&self,
tenant_id: Uuid,
req: UpdateTenantRequest,
) -> Result<Tenant, AppError> {
let query = r#"
UPDATE tenants
SET
name = COALESCE($1, name),
config = COALESCE($2, config),
updated_at = NOW()
WHERE id = $3
RETURNING id, name, status, config
"#;
sqlx::query_as::<_, Tenant>(query)
.bind(req.name)
.bind(req.config)
.bind(tenant_id)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| AppError::NotFound("Tenant not found".into()))
}
#[instrument(skip(self, req))]
/// 更新租户状态字段(如 active / disabled
///
/// 异常:
/// - 若租户不存在,返回 `NotFound("Tenant not found")`。
pub async fn update_tenant_status(
&self,
tenant_id: Uuid,
req: UpdateTenantStatusRequest,
) -> Result<Tenant, AppError> {
let query = r#"
UPDATE tenants
SET
status = $1,
updated_at = NOW()
WHERE id = $2
RETURNING id, name, status, config
"#;
sqlx::query_as::<_, Tenant>(query)
.bind(req.status)
.bind(tenant_id)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| AppError::NotFound("Tenant not found".into()))
}
#[instrument(skip(self))]
/// 删除指定租户。
///
/// 异常:
/// - 若租户不存在,返回 `NotFound("Tenant not found")`。
pub async fn delete_tenant(&self, tenant_id: Uuid) -> Result<(), AppError> {
let result = sqlx::query("DELETE FROM tenants WHERE id = $1")
.bind(tenant_id)
.execute(&self.pool)
.await?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("Tenant not found".into()));
}
Ok(())
}
}

109
src/services/user.rs Normal file
View File

@@ -0,0 +1,109 @@
use crate::models::{UpdateUserRequest, User};
use common_telemetry::AppError;
use sqlx::PgPool;
use tracing::instrument;
use uuid::Uuid;
#[derive(Clone)]
pub struct UserService {
pool: PgPool,
}
impl UserService {
/// 创建用户服务实例。
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
#[instrument(skip(self))]
/// 根据用户 ID 查询用户记录(限定在指定租户内)。
///
/// 业务规则:
/// - 查询条件同时包含 `tenant_id` 与 `user_id`,避免跨租户读取。
///
/// 异常:
/// - 用户不存在返回 `NotFound("User not found")`
pub async fn get_user_by_id(&self, tenant_id: Uuid, user_id: Uuid) -> Result<User, AppError> {
let query = "SELECT * FROM users WHERE tenant_id = $1 AND id = $2";
sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(user_id)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::DbError(e))?
.ok_or_else(|| AppError::NotFound("User not found".into()))
}
#[instrument(skip(self))]
/// 分页查询租户下用户列表。
///
/// 说明:
/// - `offset = (page - 1) * page_size`,由上层负责保证 `page>=1`。
///
/// 异常:
/// - 数据库查询失败
pub async fn list_users(
&self,
tenant_id: Uuid,
page: u32,
page_size: u32,
) -> Result<Vec<User>, AppError> {
let offset = (page - 1) * page_size;
let query = "SELECT * FROM users WHERE tenant_id = $1 LIMIT $2 OFFSET $3";
sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(page_size as i64)
.bind(offset as i64)
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::DbError(e))
}
#[instrument(skip(self))]
/// 更新指定用户信息(目前仅支持邮箱字段)。
///
/// 业务规则:
/// - 查询条件同时包含 `tenant_id` 与 `user_id`,避免跨租户更新。
/// - `UpdateUserRequest` 中未提供字段保持不变。
///
/// 异常:
/// - 用户不存在返回 `NotFound("User not found")`
pub async fn update_user(
&self,
tenant_id: Uuid,
user_id: Uuid,
req: UpdateUserRequest,
) -> Result<User, AppError> {
// Simple update implementation
// In a real app, you'd build the query dynamically based on Option fields
let query = "UPDATE users SET email = COALESCE($1, email) WHERE tenant_id = $2 AND id = $3 RETURNING *";
sqlx::query_as::<_, User>(query)
.bind(req.email)
.bind(tenant_id)
.bind(user_id)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::DbError(e))?
.ok_or_else(|| AppError::NotFound("User not found".into()))
}
#[instrument(skip(self))]
/// 删除指定用户(限定在指定租户内)。
///
/// 异常:
/// - 用户不存在返回 `NotFound("User not found")`
pub async fn delete_user(&self, tenant_id: Uuid, user_id: Uuid) -> Result<(), AppError> {
let query = "DELETE FROM users WHERE tenant_id = $1 AND id = $2";
let result = sqlx::query(query)
.bind(tenant_id)
.bind(user_id)
.execute(&self.pool)
.await
.map_err(|e| AppError::DbError(e))?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("User not found".into()));
}
Ok(())
}
}

58
src/utils/jwt.rs Normal file
View File

@@ -0,0 +1,58 @@
use crate::utils::keys::get_keys;
use common_telemetry::AppError;
use jsonwebtoken::{Algorithm, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: String, // User ID
pub tenant_id: String, // Tenant ID
pub exp: usize, // Expiration
pub iat: usize, // Issued At
pub iss: String, // Issuer
#[serde(default)]
pub roles: Vec<String>,
#[serde(default)]
pub permissions: Vec<String>,
}
pub fn sign(
user_id: Uuid,
tenant_id: Uuid,
roles: Vec<String>,
permissions: Vec<String>,
) -> Result<String, AppError> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as usize;
let expiration = now + 15 * 60; // 15 minutes access token
let claims = Claims {
sub: user_id.to_string(),
tenant_id: tenant_id.to_string(),
exp: expiration,
iat: now,
iss: "iam-service".to_string(),
roles,
permissions,
};
let keys = get_keys();
encode(&Header::new(Algorithm::RS256), &claims, &keys.encoding_key)
.map_err(|e| AppError::AuthError(e.to_string()))
}
pub fn verify(token: &str) -> Result<Claims, AppError> {
let keys = get_keys();
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&["iam-service"]);
let token_data = decode::<Claims>(token, &keys.decoding_key, &validation)
.map_err(|e| AppError::AuthError(e.to_string()))?;
Ok(token_data.claims)
}

40
src/utils/keys.rs Normal file
View File

@@ -0,0 +1,40 @@
use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey};
use rsa::rand_core::OsRng;
use rsa::{RsaPrivateKey, RsaPublicKey, pkcs1::LineEnding};
use std::sync::OnceLock;
pub struct KeyPair {
pub encoding_key: jsonwebtoken::EncodingKey,
pub decoding_key: jsonwebtoken::DecodingKey,
}
static KEYS: OnceLock<KeyPair> = OnceLock::new();
pub fn get_keys() -> &'static KeyPair {
KEYS.get_or_init(|| {
// In a real production app, you would load these from files or ENV variables
// defined in your AppConfig.
// For now, we generate a fresh key pair on startup.
let bits = 2048;
let private_key = RsaPrivateKey::new(&mut OsRng, bits).expect("failed to generate a key");
let public_key = RsaPublicKey::from(&private_key);
let private_pem = private_key
.to_pkcs1_pem(LineEnding::LF)
.expect("failed to encode private key");
let public_pem = public_key
.to_pkcs1_pem(LineEnding::LF)
.expect("failed to encode public key");
let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
.expect("failed to create encoding key");
let decoding_key = jsonwebtoken::DecodingKey::from_rsa_pem(public_pem.as_bytes())
.expect("failed to create decoding key");
KeyPair {
encoding_key,
decoding_key,
}
})
}

View File

@@ -1,60 +1,6 @@
use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
pub mod keys;
pub mod jwt;
pub mod password;
// --- 密码部分 ---
pub fn hash_password(password: &str) -> Result<String, String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| e.to_string())?
.to_string();
Ok(password_hash)
}
pub fn verify_password(password: &str, password_hash: &str) -> bool {
let parsed_hash = match PasswordHash::new(password_hash) {
Ok(h) => h,
Err(_) => return false,
};
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok()
}
// --- JWT 部分 ---
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: String, // 用户ID
pub tenant_id: String, // 租户ID (关键!)
pub exp: usize, // 过期时间
}
pub fn create_jwt(user_id: Uuid, tenant_id: Uuid, secret: &str) -> Result<String, String> {
let expiration = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as usize
+ 24 * 3600; // 24小时过期
let claims = Claims {
sub: user_id.to_string(),
tenant_id: tenant_id.to_string(),
exp: expiration,
};
encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_ref()),
)
.map_err(|e| e.to_string())
}
pub use password::{hash_password, verify_password};
pub use jwt::{sign, verify};

24
src/utils/password.rs Normal file
View File

@@ -0,0 +1,24 @@
use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng},
};
pub fn hash_password(password: &str) -> Result<String, String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| e.to_string())?
.to_string();
Ok(password_hash)
}
pub fn verify_password(password: &str, password_hash: &str) -> bool {
let parsed_hash = match PasswordHash::new(password_hash) {
Ok(h) => h,
Err(_) => return false,
};
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok()
}