feat(callback): add callback
This commit is contained in:
213
src/api/handlers/auth.rs
Normal file
213
src/api/handlers/auth.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use axum::{
|
||||
Router,
|
||||
extract::Query,
|
||||
http::{HeaderValue, header},
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::get,
|
||||
};
|
||||
use common_telemetry::AppError;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::api::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CallbackQuery {
|
||||
pub code: String,
|
||||
pub next: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, serde::Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct Code2TokenRequest {
|
||||
code: String,
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct Code2TokenData {
|
||||
access_token: String,
|
||||
refresh_token: String,
|
||||
expires_in: usize,
|
||||
token_type: String,
|
||||
tenant_id: String,
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AppResponse<T> {
|
||||
code: i32,
|
||||
message: String,
|
||||
data: Option<T>,
|
||||
}
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new().route("/callback", get(sso_callback_handler))
|
||||
}
|
||||
|
||||
fn is_https(headers: &axum::http::HeaderMap) -> bool {
|
||||
headers
|
||||
.get("x-forwarded-proto")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.eq_ignore_ascii_case("https"))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn cookie_header(
|
||||
name: &str,
|
||||
value: &str,
|
||||
secure: bool,
|
||||
http_only: bool,
|
||||
max_age: Option<u64>,
|
||||
) -> String {
|
||||
let mut s = format!("{}={}; Path=/; SameSite=Strict", name, value);
|
||||
if secure {
|
||||
s.push_str("; Secure");
|
||||
}
|
||||
if http_only {
|
||||
s.push_str("; HttpOnly");
|
||||
}
|
||||
if let Some(v) = max_age {
|
||||
s.push_str(&format!("; Max-Age={}", v));
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
fn resolve_front_redirect(next: Option<String>) -> String {
|
||||
let base = std::env::var("CMS_FRONT_BASE_URL").ok();
|
||||
|
||||
let Some(raw) = next else {
|
||||
return base.unwrap_or_else(|| "/".to_string());
|
||||
};
|
||||
|
||||
if raw.starts_with('/') {
|
||||
return raw;
|
||||
}
|
||||
|
||||
if raw.starts_with("https://") {
|
||||
if let Some(base) = base {
|
||||
if raw.starts_with(&base) {
|
||||
return raw;
|
||||
}
|
||||
return base;
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
base.unwrap_or_else(|| "/".to_string())
|
||||
}
|
||||
|
||||
fn resolve_front_error_redirect(message: &str) -> String {
|
||||
let base = std::env::var("CMS_FRONT_BASE_URL").ok();
|
||||
let encoded = urlencoding::encode(message);
|
||||
if let Some(base) = base {
|
||||
format!(
|
||||
"{}/auth-error?message={}",
|
||||
base.trim_end_matches('/'),
|
||||
encoded
|
||||
)
|
||||
} else {
|
||||
format!("/auth-error?message={}", encoded)
|
||||
}
|
||||
}
|
||||
|
||||
async fn sso_callback_handler(
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(q): Query<CallbackQuery>,
|
||||
) -> Result<axum::response::Response, AppError> {
|
||||
if q.code.trim().is_empty() {
|
||||
let target = resolve_front_error_redirect("missing code");
|
||||
return Ok(Redirect::temporary(&target).into_response());
|
||||
}
|
||||
|
||||
let iam_base = std::env::var("IAM_BASE_URL")
|
||||
.or_else(|_| std::env::var("IAM_SERVICE_BASE_URL"))
|
||||
.map_err(|_| AppError::ConfigError("IAM_BASE_URL is required".into()))?;
|
||||
let client_id = std::env::var("IAM_CLIENT_ID")
|
||||
.map_err(|_| AppError::ConfigError("IAM_CLIENT_ID is required".into()))?;
|
||||
let client_secret = std::env::var("IAM_CLIENT_SECRET")
|
||||
.map_err(|_| AppError::ConfigError("IAM_CLIENT_SECRET is required".into()))?;
|
||||
|
||||
let http = reqwest::Client::new();
|
||||
let resp = http
|
||||
.post(format!(
|
||||
"{}/iam/api/v1/auth/code2token",
|
||||
iam_base.trim_end_matches('/')
|
||||
))
|
||||
.json(&Code2TokenRequest {
|
||||
code: q.code,
|
||||
client_id,
|
||||
client_secret,
|
||||
})
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?;
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp
|
||||
.json::<AppResponse<Code2TokenData>>()
|
||||
.await
|
||||
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?;
|
||||
|
||||
if !status.is_success() || body.code != 0 {
|
||||
let target = resolve_front_error_redirect(&body.message);
|
||||
return Ok(Redirect::temporary(&target).into_response());
|
||||
}
|
||||
|
||||
let Some(data) = body.data else {
|
||||
let target = resolve_front_error_redirect("invalid code2token response");
|
||||
return Ok(Redirect::temporary(&target).into_response());
|
||||
};
|
||||
|
||||
let target = resolve_front_redirect(q.next);
|
||||
let secure = is_https(&headers);
|
||||
let mut res = Redirect::temporary(&target).into_response();
|
||||
let refresh_max_age = 30_u64 * 24 * 60 * 60;
|
||||
res.headers_mut().append(
|
||||
header::SET_COOKIE,
|
||||
HeaderValue::from_str(&cookie_header(
|
||||
"accessToken",
|
||||
&data.access_token,
|
||||
secure,
|
||||
true,
|
||||
Some(data.expires_in as u64),
|
||||
))
|
||||
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?,
|
||||
);
|
||||
res.headers_mut().append(
|
||||
header::SET_COOKIE,
|
||||
HeaderValue::from_str(&cookie_header(
|
||||
"refreshToken",
|
||||
&data.refresh_token,
|
||||
secure,
|
||||
true,
|
||||
Some(refresh_max_age),
|
||||
))
|
||||
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?,
|
||||
);
|
||||
res.headers_mut().append(
|
||||
header::SET_COOKIE,
|
||||
HeaderValue::from_str(&cookie_header(
|
||||
"tenantId",
|
||||
&data.tenant_id,
|
||||
secure,
|
||||
true,
|
||||
Some(refresh_max_age),
|
||||
))
|
||||
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?,
|
||||
);
|
||||
res.headers_mut().append(
|
||||
header::SET_COOKIE,
|
||||
HeaderValue::from_str(&cookie_header(
|
||||
"userId",
|
||||
&data.user_id,
|
||||
secure,
|
||||
true,
|
||||
Some(refresh_max_age),
|
||||
))
|
||||
.map_err(|e| AppError::AnyhowError(anyhow::anyhow!(e)))?,
|
||||
);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod article;
|
||||
pub mod auth;
|
||||
pub mod column;
|
||||
pub mod common;
|
||||
pub mod media;
|
||||
|
||||
@@ -21,6 +21,8 @@ pub struct AppState {
|
||||
pub fn build_router(state: AppState) -> Router {
|
||||
let health = Router::new().route("/healthz", get(|| async { axum::http::StatusCode::OK }));
|
||||
|
||||
let auth = Router::new().nest("/auth", handlers::auth::router());
|
||||
|
||||
let v1 = Router::new()
|
||||
.nest("/columns", handlers::column::router())
|
||||
.nest("/tags", handlers::tag::router())
|
||||
@@ -31,6 +33,7 @@ pub fn build_router(state: AppState) -> Router {
|
||||
.route("/favicon.ico", get(|| async { axum::http::StatusCode::NO_CONTENT }))
|
||||
.merge(Scalar::with_url("/scalar", ApiDoc::openapi()))
|
||||
.merge(health)
|
||||
.merge(auth)
|
||||
.nest("/v1", v1)
|
||||
.layer(axum::middleware::from_fn(catch_panic))
|
||||
.layer(axum::middleware::from_fn(request_logger))
|
||||
|
||||
16
src/main.rs
16
src/main.rs
@@ -24,9 +24,15 @@ async fn main() {
|
||||
});
|
||||
|
||||
let pool = db::init_pool(&config).await.expect("failed to init db pool");
|
||||
db::run_migrations(&pool)
|
||||
.await
|
||||
.expect("failed to run migrations");
|
||||
let run_migrations = std::env::var("RUN_MIGRATIONS")
|
||||
.ok()
|
||||
.map(|v| matches!(v.as_str(), "1" | "true" | "TRUE"))
|
||||
.unwrap_or(false);
|
||||
if run_migrations {
|
||||
db::run_migrations(&pool)
|
||||
.await
|
||||
.expect("failed to run migrations");
|
||||
}
|
||||
|
||||
let state = AppState {
|
||||
services: CmsServices::new(pool),
|
||||
@@ -40,7 +46,7 @@ async fn main() {
|
||||
};
|
||||
|
||||
let auth_cfg = AuthMiddlewareConfig {
|
||||
skip_exact_paths: vec!["/healthz".to_string()],
|
||||
skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()],
|
||||
skip_path_prefixes: vec!["/scalar".to_string()],
|
||||
jwt: match &config.jwt_public_key_pem {
|
||||
Some(pem) => auth_kit::jwt::JwtVerifyConfig::rs256_from_pem("iam-service", pem)
|
||||
@@ -58,7 +64,7 @@ async fn main() {
|
||||
},
|
||||
};
|
||||
let tenant_cfg = TenantMiddlewareConfig {
|
||||
skip_exact_paths: vec!["/healthz".to_string()],
|
||||
skip_exact_paths: vec!["/healthz".to_string(), "/auth/callback".to_string()],
|
||||
skip_path_prefixes: vec!["/scalar".to_string()],
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user