feat(project): init
This commit is contained in:
34
src/config/mod.rs
Normal file
34
src/config/mod.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use std::env;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppConfig {
|
||||
pub service_name: String,
|
||||
pub log_level: String,
|
||||
pub log_to_file: bool,
|
||||
pub log_dir: String,
|
||||
pub log_file_name: String,
|
||||
pub database_url: String,
|
||||
pub jwt_secret: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
pub fn from_env() -> Self {
|
||||
Self {
|
||||
service_name: env::var("SERVICE_NAME").unwrap_or_else(|_| "iam-service".into()),
|
||||
log_level: env::var("LOG_LEVEL").unwrap_or_else(|_| "info".into()),
|
||||
log_to_file: env::var("LOG_TO_FILE")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false),
|
||||
log_dir: env::var("LOG_DIR").unwrap_or_else(|_| "./log".into()),
|
||||
log_file_name: env::var("LOG_FILE_NAME").unwrap_or_else(|_| "iam.log".into()),
|
||||
database_url: env::var("DATABASE_URL").expect("DATABASE_URL required"),
|
||||
jwt_secret: env::var("JWT_SECRET")
|
||||
.expect("JWT_SECRET required, generate by run 'openssl rand -base64 32'"),
|
||||
port: env::var("PORT")
|
||||
.unwrap_or_else(|_| "3000".to_string())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
18
src/db/mod.rs
Normal file
18
src/db/mod.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use sqlx::postgres::{PgPool, PgPoolOptions};
|
||||
use std::time::Duration;
|
||||
|
||||
/// 初始化数据库连接池
|
||||
pub async fn init_pool(database_url: &str) -> Result<PgPool, sqlx::Error> {
|
||||
PgPoolOptions::new()
|
||||
.max_connections(20) // 根据服务器规格调整,IAM服务通常并发高
|
||||
.min_connections(5)
|
||||
.acquire_timeout(Duration::from_secs(3)) // 获取连接超时时间
|
||||
.connect(database_url)
|
||||
.await
|
||||
}
|
||||
|
||||
// (可选) 可以在应用启动时自动运行迁移
|
||||
// pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::migrate::MigrateError> {
|
||||
// // 这要求你在项目根目录有 `migrations/` 文件夹
|
||||
// sqlx::migrate!("./migrations").run(pool).await
|
||||
// }
|
||||
48
src/handlers/mod.rs
Normal file
48
src/handlers/mod.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use crate::middleware::TenantId;
|
||||
use crate::models::{CreateUserRequest, UserResponse};
|
||||
use crate::services::AuthService;
|
||||
use axum::{Json, extract::State};
|
||||
use common_telemetry::AppError; // 引入刚刚写的中间件类型
|
||||
|
||||
// 状态对象,包含 Service
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub auth_service: AuthService,
|
||||
}
|
||||
|
||||
/// 注册接口
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/auth/register",
|
||||
request_body = CreateUserRequest,
|
||||
responses(
|
||||
(status = 201, description = "User created", body = UserResponse),
|
||||
(status = 400, description = "Bad request")
|
||||
),
|
||||
params(
|
||||
("X-Tenant-ID" = String, Header, description = "Tenant UUID")
|
||||
)
|
||||
)]
|
||||
pub async fn register_handler(
|
||||
// 1. 自动注入 TenantId (由中间件解析)
|
||||
TenantId(tenant_id): TenantId,
|
||||
// 2. 获取全局状态中的 Service
|
||||
State(state): State<AppState>,
|
||||
// 3. 获取 Body
|
||||
Json(payload): Json<CreateUserRequest>,
|
||||
) -> Result<Json<UserResponse>, AppError> {
|
||||
let user = state
|
||||
.auth_service
|
||||
.register(tenant_id, payload)
|
||||
.await
|
||||
.map_err(AppError::BadRequest)?;
|
||||
|
||||
// 转换为 Response DTO (隐藏密码等敏感信息)
|
||||
let response = UserResponse {
|
||||
id: user.id,
|
||||
email: user.email.clone(),
|
||||
// ...
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
78
src/main.rs
Normal file
78
src/main.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
mod config;
|
||||
mod db; // 声明 db 模块
|
||||
mod handlers;
|
||||
mod middleware;
|
||||
mod models;
|
||||
mod services;
|
||||
mod utils;
|
||||
|
||||
use axum::{Router, middleware::from_fn, routing::post};
|
||||
use config::AppConfig;
|
||||
use handlers::{AppState, register_handler};
|
||||
use services::AuthService;
|
||||
use std::net::SocketAddr;
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_scalar::{Scalar, Servable};
|
||||
// 引入 models 下的所有结构体以生成文档
|
||||
use common_telemetry::telemetry::{self, TelemetryConfig};
|
||||
use models::*;
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(handlers::register_handler),
|
||||
components(schemas(CreateUserRequest, UserResponse)),
|
||||
tags((name = "auth", description = "Authentication API"))
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// 1. 加载配置
|
||||
dotenvy::dotenv().ok();
|
||||
let config = AppConfig::from_env();
|
||||
let telemetry_config = TelemetryConfig {
|
||||
service_name: config.service_name,
|
||||
log_level: config.log_level,
|
||||
log_to_file: config.log_to_file,
|
||||
log_dir: Some(config.log_dir),
|
||||
log_file: Some(config.log_file_name),
|
||||
};
|
||||
// 2. 初始化 Tracing
|
||||
let _guard = telemetry::init(telemetry_config);
|
||||
|
||||
// 3. 初始化数据库 (使用 db 模块)
|
||||
let pool = match db::init_pool(&config.database_url).await {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
// 记录到日志文件和控制台
|
||||
tracing::error!(%e, "Fatal error: Failed to connect to database!");
|
||||
// 退出程序 (或者 panic)
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
// (可选) 运行迁移
|
||||
// tracing::info!("🔄 Running migrations...");
|
||||
// db::run_migrations(&pool).await.expect("Failed to run migrations");
|
||||
|
||||
// 4. 初始化 Service 和 AppState
|
||||
let auth_service = AuthService::new(pool.clone(), config.jwt_secret.clone());
|
||||
let state = AppState { auth_service };
|
||||
|
||||
// 5. 构建路由
|
||||
let app = Router::new()
|
||||
.route("/auth/register", post(register_handler))
|
||||
// 挂载多租户中间件
|
||||
.layer(from_fn(middleware::resolve_tenant))
|
||||
// 挂载 Scalar 文档
|
||||
.merge(Scalar::with_url("/scalar", ApiDoc::openapi()))
|
||||
.with_state(state);
|
||||
|
||||
// 6. 启动服务器
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
|
||||
tracing::info!("🚀 Server started at http://{}", addr);
|
||||
tracing::info!("📄 Docs available at http://{}/scalar", addr);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
56
src/middleware/mod.rs
Normal file
56
src/middleware/mod.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use axum::extract::FromRequestParts;
|
||||
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||
use http::request::Parts;
|
||||
use uuid::Uuid;
|
||||
|
||||
// --- 1. 租户 ID 提取器 ---
|
||||
|
||||
#[derive(Clone, Debug)] // 这是一个类型安全的 Wrapper,用于在 Handler 中注入
|
||||
pub struct TenantId(pub Uuid);
|
||||
|
||||
pub async fn resolve_tenant(mut req: Request, next: Next) -> Result<Response, StatusCode> {
|
||||
// 尝试从 Header 获取 X-Tenant-ID
|
||||
let tenant_id_str = req
|
||||
.headers()
|
||||
.get("X-Tenant-ID")
|
||||
.and_then(|val| val.to_str().ok());
|
||||
|
||||
match tenant_id_str {
|
||||
Some(id_str) => {
|
||||
if let Ok(uuid) = Uuid::parse_str(id_str) {
|
||||
// 验证成功,注入到 Extension 中
|
||||
req.extensions_mut().insert(TenantId(uuid));
|
||||
Ok(next.run(req).await)
|
||||
} else {
|
||||
Err(StatusCode::BAD_REQUEST) // ID 格式错误
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// 如果是公开接口(如登录注册),可能不需要 TenantID,视业务而定
|
||||
// 这里假设严格模式,必须带 TenantID
|
||||
Err(StatusCode::BAD_REQUEST)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 实现 FromRequestParts 让 Handler 可以直接写 `tid: TenantId`
|
||||
impl<S> FromRequestParts<S> for TenantId
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let tenant_id_str = parts
|
||||
.headers
|
||||
.get("X-Tenant-ID")
|
||||
.and_then(|val| val.to_str().ok());
|
||||
|
||||
match tenant_id_str {
|
||||
Some(id_str) => uuid::Uuid::parse_str(id_str)
|
||||
.map(TenantId)
|
||||
.map_err(|_| StatusCode::BAD_REQUEST),
|
||||
None => Err(StatusCode::BAD_REQUEST),
|
||||
}
|
||||
}
|
||||
}
|
||||
30
src/models.rs
Normal file
30
src/models.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use utoipa::ToSchema;
|
||||
use uuid::Uuid; // 关键引入
|
||||
|
||||
#[derive(Debug, Serialize, FromRow, ToSchema)]
|
||||
pub struct User {
|
||||
#[schema(example = "550e8400-e29b-41d4-a716-446655440000")]
|
||||
pub id: Uuid,
|
||||
#[schema(example = "11111111-1111-1111-1111-111111111111")]
|
||||
pub tenant_id: Uuid,
|
||||
#[schema(example = "user@example.com")]
|
||||
pub email: String,
|
||||
#[schema(ignore)] // 不在文档中显示密码哈希
|
||||
pub password_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
pub struct CreateUserRequest {
|
||||
#[schema(example = "user@example.com")]
|
||||
pub email: String,
|
||||
#[schema(example = "securePassword123")]
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct UserResponse {
|
||||
pub id: Uuid,
|
||||
pub email: String,
|
||||
}
|
||||
69
src/services/mod.rs
Normal file
69
src/services/mod.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use crate::models::{CreateUserRequest, User}; // 假设你在 models 定义了这些
|
||||
use crate::utils::{create_jwt, hash_password, verify_password};
|
||||
use axum::Json;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthService {
|
||||
pool: PgPool,
|
||||
jwt_secret: String,
|
||||
}
|
||||
|
||||
impl AuthService {
|
||||
pub fn new(pool: PgPool, jwt_secret: String) -> Self {
|
||||
Self { pool, jwt_secret }
|
||||
}
|
||||
|
||||
// 注册业务
|
||||
pub async fn register(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
req: CreateUserRequest,
|
||||
) -> Result<Json<User>, String> {
|
||||
// 1. 哈希密码
|
||||
let hashed = hash_password(&req.password)?;
|
||||
|
||||
// 2. 存入数据库 (带上 tenant_id)
|
||||
let query = r#"
|
||||
INSERT INTO users (tenant_id, email, password_hash)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, tenant_id, email, password_hash, created_at
|
||||
"#;
|
||||
let user = sqlx::query_as::<_, User>(query)
|
||||
.bind(tenant_id)
|
||||
.bind(&req.email)
|
||||
.bind(&hashed)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
Ok(Json(user))
|
||||
}
|
||||
|
||||
// 登录业务
|
||||
pub async fn login(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
email: &str,
|
||||
password: &str,
|
||||
) -> Result<String, String> {
|
||||
// 1. 查找用户 (带 tenant_id 防止跨租户登录)
|
||||
let query = "SELECT * FROM users WHERE tenant_id = $1 AND email = $2";
|
||||
let user = sqlx::query_as::<_, User>(query)
|
||||
.bind(tenant_id)
|
||||
.bind(email)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.ok_or("User not found")?;
|
||||
|
||||
// 2. 验证密码
|
||||
if !verify_password(password, &user.password_hash) {
|
||||
return Err("Invalid password".to_string());
|
||||
}
|
||||
|
||||
// 3. 签发 Token
|
||||
create_jwt(user.id, user.tenant_id, &self.jwt_secret)
|
||||
}
|
||||
}
|
||||
60
src/utils/mod.rs
Normal file
60
src/utils/mod.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use argon2::{
|
||||
Argon2,
|
||||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng},
|
||||
};
|
||||
use jsonwebtoken::{EncodingKey, Header, encode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use uuid::Uuid;
|
||||
|
||||
// --- 密码部分 ---
|
||||
|
||||
pub fn hash_password(password: &str) -> Result<String, String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| e.to_string())?
|
||||
.to_string();
|
||||
Ok(password_hash)
|
||||
}
|
||||
|
||||
pub fn verify_password(password: &str, password_hash: &str) -> bool {
|
||||
let parsed_hash = match PasswordHash::new(password_hash) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return false,
|
||||
};
|
||||
Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
// --- JWT 部分 ---
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String, // 用户ID
|
||||
pub tenant_id: String, // 租户ID (关键!)
|
||||
pub exp: usize, // 过期时间
|
||||
}
|
||||
|
||||
pub fn create_jwt(user_id: Uuid, tenant_id: Uuid, secret: &str) -> Result<String, String> {
|
||||
let expiration = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as usize
|
||||
+ 24 * 3600; // 24小时过期
|
||||
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
tenant_id: tenant_id.to_string(),
|
||||
exp: expiration,
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(secret.as_ref()),
|
||||
)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
Reference in New Issue
Block a user