use async_trait::async_trait; use redis::Script; use uuid::Uuid; use crate::domain::DomainError; use crate::models::Code2TokenRequest; use crate::application::services::AuthService; #[derive(Clone)] pub struct ExchangeCodeUseCase { pub auth_service: AuthService, pub redis: redis::aio::ConnectionManager, pub auth_code_jwt_secret: String, } pub struct ExchangeCodeResult { pub tenant_id: Uuid, pub user_id: Uuid, pub access_token: String, pub refresh_token: String, pub token_type: String, pub expires_in: usize, } #[async_trait] pub trait Execute { async fn execute(&self, req: Code2TokenRequest) -> Result; } #[derive(serde::Deserialize)] struct AuthCodeClaims { sub: String, tenant_id: String, client_id: Option, #[allow(dead_code)] exp: usize, #[allow(dead_code)] iat: usize, #[allow(dead_code)] iss: String, jti: String, } #[derive(serde::Deserialize)] struct AuthCodeRedisValue { user_id: String, tenant_id: String, client_id: Option, } fn redis_key(jti: &str) -> String { format!("iam:auth_code:{}", jti) } #[async_trait] impl Execute for ExchangeCodeUseCase { async fn execute(&self, req: Code2TokenRequest) -> Result { if req.code.trim().is_empty() { return Err(DomainError::InvalidArgument("code".into())); } let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); validation.set_issuer(&["iam-front", "iam-service"]); let token_data = jsonwebtoken::decode::( req.code.trim(), &jsonwebtoken::DecodingKey::from_secret(self.auth_code_jwt_secret.as_bytes()), &validation, ) .map_err(|_| DomainError::Unauthorized)?; let claims = token_data.claims; if let Some(cid) = &claims.client_id && cid != req.client_id.trim() { return Err(DomainError::Unauthorized); } let jti = claims.jti.trim(); if jti.is_empty() { return Err(DomainError::Unauthorized); } let script = Script::new( r#" local v = redis.call('GET', KEYS[1]) if v then redis.call('DEL', KEYS[1]) end return v "#, ); let key = redis_key(jti); let mut conn = self.redis.clone(); let val: Option = script .key(key) .invoke_async(&mut conn) .await .map_err(|_| DomainError::Unexpected)?; let Some(val) = val else { return Err(DomainError::Unauthorized); }; let stored: AuthCodeRedisValue = serde_json::from_str(&val).map_err(|_| DomainError::Unauthorized)?; if let Some(cid) = stored.client_id.as_deref() && cid != req.client_id.trim() { return Err(DomainError::Unauthorized); } if stored.user_id != claims.sub || stored.tenant_id != claims.tenant_id { return Err(DomainError::Unauthorized); } let user_id = Uuid::parse_str(&stored.user_id).map_err(|_| DomainError::Unauthorized)?; let tenant_id = Uuid::parse_str(&stored.tenant_id).map_err(|_| DomainError::Unauthorized)?; let tokens = self .auth_service .issue_tokens_for_user(tenant_id, user_id, 7200) .await .map_err(|_| DomainError::Unexpected)?; Ok(ExchangeCodeResult { tenant_id, user_id, access_token: tokens.access_token, refresh_token: tokens.refresh_token, token_type: tokens.token_type, expires_in: tokens.expires_in, }) } }