feat(project): init

This commit is contained in:
2026-01-29 18:14:47 +08:00
commit bb82c75834
15 changed files with 3715 additions and 0 deletions

34
src/config/mod.rs Normal file
View File

@@ -0,0 +1,34 @@
use std::env;
#[derive(Clone, Debug)]
pub struct AppConfig {
pub service_name: String,
pub log_level: String,
pub log_to_file: bool,
pub log_dir: String,
pub log_file_name: String,
pub database_url: String,
pub jwt_secret: String,
pub port: u16,
}
impl AppConfig {
pub fn from_env() -> Self {
Self {
service_name: env::var("SERVICE_NAME").unwrap_or_else(|_| "iam-service".into()),
log_level: env::var("LOG_LEVEL").unwrap_or_else(|_| "info".into()),
log_to_file: env::var("LOG_TO_FILE")
.map(|v| v == "true" || v == "1")
.unwrap_or(false),
log_dir: env::var("LOG_DIR").unwrap_or_else(|_| "./log".into()),
log_file_name: env::var("LOG_FILE_NAME").unwrap_or_else(|_| "iam.log".into()),
database_url: env::var("DATABASE_URL").expect("DATABASE_URL required"),
jwt_secret: env::var("JWT_SECRET")
.expect("JWT_SECRET required, generate by run 'openssl rand -base64 32'"),
port: env::var("PORT")
.unwrap_or_else(|_| "3000".to_string())
.parse()
.unwrap(),
}
}
}

18
src/db/mod.rs Normal file
View File

@@ -0,0 +1,18 @@
use sqlx::postgres::{PgPool, PgPoolOptions};
use std::time::Duration;
/// 初始化数据库连接池
pub async fn init_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
PgPoolOptions::new()
.max_connections(20) // 根据服务器规格调整IAM服务通常并发高
.min_connections(5)
.acquire_timeout(Duration::from_secs(3)) // 获取连接超时时间
.connect(database_url)
.await
}
// (可选) 可以在应用启动时自动运行迁移
// pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::migrate::MigrateError> {
// // 这要求你在项目根目录有 `migrations/` 文件夹
// sqlx::migrate!("./migrations").run(pool).await
// }

48
src/handlers/mod.rs Normal file
View File

@@ -0,0 +1,48 @@
use crate::middleware::TenantId;
use crate::models::{CreateUserRequest, UserResponse};
use crate::services::AuthService;
use axum::{Json, extract::State};
use common_telemetry::AppError; // 引入刚刚写的中间件类型
// 状态对象,包含 Service
#[derive(Clone)]
pub struct AppState {
pub auth_service: AuthService,
}
/// 注册接口
#[utoipa::path(
post,
path = "/auth/register",
request_body = CreateUserRequest,
responses(
(status = 201, description = "User created", body = UserResponse),
(status = 400, description = "Bad request")
),
params(
("X-Tenant-ID" = String, Header, description = "Tenant UUID")
)
)]
pub async fn register_handler(
// 1. 自动注入 TenantId (由中间件解析)
TenantId(tenant_id): TenantId,
// 2. 获取全局状态中的 Service
State(state): State<AppState>,
// 3. 获取 Body
Json(payload): Json<CreateUserRequest>,
) -> Result<Json<UserResponse>, AppError> {
let user = state
.auth_service
.register(tenant_id, payload)
.await
.map_err(AppError::BadRequest)?;
// 转换为 Response DTO (隐藏密码等敏感信息)
let response = UserResponse {
id: user.id,
email: user.email.clone(),
// ...
};
Ok(Json(response))
}

78
src/main.rs Normal file
View File

@@ -0,0 +1,78 @@
mod config;
mod db; // 声明 db 模块
mod handlers;
mod middleware;
mod models;
mod services;
mod utils;
use axum::{Router, middleware::from_fn, routing::post};
use config::AppConfig;
use handlers::{AppState, register_handler};
use services::AuthService;
use std::net::SocketAddr;
use utoipa::OpenApi;
use utoipa_scalar::{Scalar, Servable};
// 引入 models 下的所有结构体以生成文档
use common_telemetry::telemetry::{self, TelemetryConfig};
use models::*;
#[derive(OpenApi)]
#[openapi(
paths(handlers::register_handler),
components(schemas(CreateUserRequest, UserResponse)),
tags((name = "auth", description = "Authentication API"))
)]
struct ApiDoc;
#[tokio::main]
async fn main() {
// 1. 加载配置
dotenvy::dotenv().ok();
let config = AppConfig::from_env();
let telemetry_config = TelemetryConfig {
service_name: config.service_name,
log_level: config.log_level,
log_to_file: config.log_to_file,
log_dir: Some(config.log_dir),
log_file: Some(config.log_file_name),
};
// 2. 初始化 Tracing
let _guard = telemetry::init(telemetry_config);
// 3. 初始化数据库 (使用 db 模块)
let pool = match db::init_pool(&config.database_url).await {
Ok(p) => p,
Err(e) => {
// 记录到日志文件和控制台
tracing::error!(%e, "Fatal error: Failed to connect to database!");
// 退出程序 (或者 panic)
std::process::exit(1);
}
};
// (可选) 运行迁移
// tracing::info!("🔄 Running migrations...");
// db::run_migrations(&pool).await.expect("Failed to run migrations");
// 4. 初始化 Service 和 AppState
let auth_service = AuthService::new(pool.clone(), config.jwt_secret.clone());
let state = AppState { auth_service };
// 5. 构建路由
let app = Router::new()
.route("/auth/register", post(register_handler))
// 挂载多租户中间件
.layer(from_fn(middleware::resolve_tenant))
// 挂载 Scalar 文档
.merge(Scalar::with_url("/scalar", ApiDoc::openapi()))
.with_state(state);
// 6. 启动服务器
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
tracing::info!("🚀 Server started at http://{}", addr);
tracing::info!("📄 Docs available at http://{}/scalar", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}

56
src/middleware/mod.rs Normal file
View File

@@ -0,0 +1,56 @@
use axum::extract::FromRequestParts;
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
use http::request::Parts;
use uuid::Uuid;
// --- 1. 租户 ID 提取器 ---
#[derive(Clone, Debug)] // 这是一个类型安全的 Wrapper用于在 Handler 中注入
pub struct TenantId(pub Uuid);
pub async fn resolve_tenant(mut req: Request, next: Next) -> Result<Response, StatusCode> {
// 尝试从 Header 获取 X-Tenant-ID
let tenant_id_str = req
.headers()
.get("X-Tenant-ID")
.and_then(|val| val.to_str().ok());
match tenant_id_str {
Some(id_str) => {
if let Ok(uuid) = Uuid::parse_str(id_str) {
// 验证成功,注入到 Extension 中
req.extensions_mut().insert(TenantId(uuid));
Ok(next.run(req).await)
} else {
Err(StatusCode::BAD_REQUEST) // ID 格式错误
}
}
None => {
// 如果是公开接口(如登录注册),可能不需要 TenantID视业务而定
// 这里假设严格模式,必须带 TenantID
Err(StatusCode::BAD_REQUEST)
}
}
}
// 实现 FromRequestParts 让 Handler 可以直接写 `tid: TenantId`
impl<S> FromRequestParts<S> for TenantId
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let tenant_id_str = parts
.headers
.get("X-Tenant-ID")
.and_then(|val| val.to_str().ok());
match tenant_id_str {
Some(id_str) => uuid::Uuid::parse_str(id_str)
.map(TenantId)
.map_err(|_| StatusCode::BAD_REQUEST),
None => Err(StatusCode::BAD_REQUEST),
}
}
}

30
src/models.rs Normal file
View File

@@ -0,0 +1,30 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use utoipa::ToSchema;
use uuid::Uuid; // 关键引入
#[derive(Debug, Serialize, FromRow, ToSchema)]
pub struct User {
#[schema(example = "550e8400-e29b-41d4-a716-446655440000")]
pub id: Uuid,
#[schema(example = "11111111-1111-1111-1111-111111111111")]
pub tenant_id: Uuid,
#[schema(example = "user@example.com")]
pub email: String,
#[schema(ignore)] // 不在文档中显示密码哈希
pub password_hash: String,
}
#[derive(Debug, Deserialize, ToSchema)]
pub struct CreateUserRequest {
#[schema(example = "user@example.com")]
pub email: String,
#[schema(example = "securePassword123")]
pub password: String,
}
#[derive(Debug, Serialize, ToSchema)]
pub struct UserResponse {
pub id: Uuid,
pub email: String,
}

69
src/services/mod.rs Normal file
View File

@@ -0,0 +1,69 @@
use crate::models::{CreateUserRequest, User}; // 假设你在 models 定义了这些
use crate::utils::{create_jwt, hash_password, verify_password};
use axum::Json;
use sqlx::PgPool;
use uuid::Uuid;
#[derive(Clone)]
pub struct AuthService {
pool: PgPool,
jwt_secret: String,
}
impl AuthService {
pub fn new(pool: PgPool, jwt_secret: String) -> Self {
Self { pool, jwt_secret }
}
// 注册业务
pub async fn register(
&self,
tenant_id: Uuid,
req: CreateUserRequest,
) -> Result<Json<User>, String> {
// 1. 哈希密码
let hashed = hash_password(&req.password)?;
// 2. 存入数据库 (带上 tenant_id)
let query = r#"
INSERT INTO users (tenant_id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING id, tenant_id, email, password_hash, created_at
"#;
let user = sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(&req.email)
.bind(&hashed)
.fetch_one(&self.pool)
.await
.map_err(|e| e.to_string())?;
Ok(Json(user))
}
// 登录业务
pub async fn login(
&self,
tenant_id: Uuid,
email: &str,
password: &str,
) -> Result<String, String> {
// 1. 查找用户 (带 tenant_id 防止跨租户登录)
let query = "SELECT * FROM users WHERE tenant_id = $1 AND email = $2";
let user = sqlx::query_as::<_, User>(query)
.bind(tenant_id)
.bind(email)
.fetch_optional(&self.pool)
.await
.map_err(|e| e.to_string())?
.ok_or("User not found")?;
// 2. 验证密码
if !verify_password(password, &user.password_hash) {
return Err("Invalid password".to_string());
}
// 3. 签发 Token
create_jwt(user.id, user.tenant_id, &self.jwt_secret)
}
}

60
src/utils/mod.rs Normal file
View File

@@ -0,0 +1,60 @@
use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
// --- 密码部分 ---
pub fn hash_password(password: &str) -> Result<String, String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| e.to_string())?
.to_string();
Ok(password_hash)
}
pub fn verify_password(password: &str, password_hash: &str) -> bool {
let parsed_hash = match PasswordHash::new(password_hash) {
Ok(h) => h,
Err(_) => return false,
};
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok()
}
// --- JWT 部分 ---
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: String, // 用户ID
pub tenant_id: String, // 租户ID (关键!)
pub exp: usize, // 过期时间
}
pub fn create_jwt(user_id: Uuid, tenant_id: Uuid, secret: &str) -> Result<String, String> {
let expiration = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as usize
+ 24 * 3600; // 24小时过期
let claims = Claims {
sub: user_id.to_string(),
tenant_id: tenant_id.to_string(),
exp: expiration,
};
encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_ref()),
)
.map_err(|e| e.to_string())
}