セッション管理機構の実装
セッションストアを設定する
repository.rs
に以下を追加しましょう。
use async_sqlx_session::MySqlSessionStore;
use sqlx::mysql::MySqlConnectOptions;
use sqlx::mysql::MySqlPool;
use std::env;
pub mod country;
pub mod users;
#[derive(Clone)]
pub struct Repository {
pool: MySqlPool,
session_store: MySqlSessionStore,
}
impl Repository {
pub async fn connect() -> anyhow::Result<Self> {
let options = get_options()?;
let pool = sqlx::MySqlPool::connect_with(options).await?;
let session_store =
MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions");
Ok(Self {
pool,
session_store,
})
}
pub async fn migrate(&self) -> anyhow::Result<()> {
sqlx::migrate!("./migrations").run(&self.pool).await?;
self.session_store.migrate().await?;
Ok(())
}
}
...(省略)
これらはセッションストアの設定です。 セッションの情報を記憶するための場所をデータベース上に設定して、session_store
からアクセスできるようにしています。
login
ハンドラの実装
続いて、login
ハンドラを handler/auth.rs
に実装していきましょう。
pub async fn login(
State(state): State<Repository>,
Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
}
login
ハンドラの外に以下の構造体を追加します。
#[derive(Deserialize)]
pub struct Login {
pub username: String,
pub password: String,
}
login
ハンドラの中身を実装する前に、必要になるデータベース操作のメソッドを追加します。ここで必要になるのは以下の 2 つです。
username
からid
を取得するメソッドid
とpassword
の組が登録されているものと一致するかを確認するメソッド
この 2 つを repository/users.rs
に追加します。
use super::Repository;
impl Repository {
pub async fn is_exist_username(&self, username: String) -> sqlx::Result<bool> {
...(省略)
}
pub async fn create_user(&self, username: String) -> sqlx::Result<u64> {
...(省略)
}
pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result<i32> {
let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?")
.bind(&username)
.fetch_one(&self.pool)
.await?;
Ok(result)
}
pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> {
...(省略)
}
pub async fn verify_user_password(&self, id: i32, password: String) -> anyhow::Result<bool> {
let hash =
sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?")
.bind(id)
.fetch_one(&self.pool)
.await?;
Ok(bcrypt::verify(password, &hash)?)
}
}
データベースに保存されているパスワードはハッシュ化されています。
ハッシュ化は不可逆な処理なので、ハッシュ化されたものから原文を調べることはできません。確認する際はもらったパスワードをハッシュ化することで行います。 bcrypt::verify
によってパスワードの検証ができます。
handler/auth.rs
に戻り、login
ハンドラを実装していきます。
pub async fn login(
State(state): State<Repository>,
Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
// バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
if body.username.is_empty() || body.password.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
// データベースからユーザーを取得する
let id = state
.get_user_id_by_name(body.username.clone())
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
})?;
}
ユーザーが存在しなかった場合は sqlx::Error::RowNotFound
というエラーが返ってきます。 もしそのエラーなら 401 (Unauthorized)、そうでなければ 500 (Internal Server Error) です。 もし 404 (Not Found) とすると、「このユーザーはパスワードが違うのではなく存在しないんだ」という事がわかってしまい(このユーザーは存在していてパスワードは違う事も分かります)、セキュリティ上のリスクに繋がります。
pub async fn login(
State(state): State<Repository>,
Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
// バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
if body.username.is_empty() || body.password.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
// データベースからユーザーを取得する
let id = state
.get_user_id_by_name(body.username.clone())
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
})?;
// パスワードが一致しているかを確かめる
if !state
.verify_user_password(id, body.password.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
return Err(StatusCode::UNAUTHORIZED);
}
}
データベースでエラーが起きた場合や、検証の操作に失敗した場合は 500 (Internal Server Error), パスワードが間違っていた場合 401 (Unauthorized) を返却しています。
pub async fn login(
State(state): State<Repository>,
Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
...(省略)
// パスワードが一致しているかを確かめる
if !state
.verify_user_password(id, body.password.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
return Err(StatusCode::UNAUTHORIZED);
}
// セッションストアに登録する
let session_id = state
.create_user_session(id.to_string())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
id
をセッションストアに登録して、セッション id を取得します。
ここで用いる、セッションストアに登録するメソッド create_user_session
を実装していきます。
ファイル repository/users_session.rs
を作成し、以下を記述してください。
use anyhow::Context;
use async_session::{Session, SessionStore};
use super::Repository;
impl Repository {
pub async fn create_user_session(&self, user_id: String) -> anyhow::Result<String> {
let mut session = Session::new();
session
.insert("user_id", user_id)
.with_context(|| "Failed to insert user_id")?;
let session_id = self
.session_store
.store_session(session)
.await
.with_context(|| "Failed to store session")?
.with_context(|| "Failed to create session")?;
Ok(session_id)
}
}
セッションに user_id
を登録し、セッションストアに保存します。 セッション id を返却します。
handler/auth.rs
に戻り、ヘッダーにセッション id を設定する処理を追加します。
pub async fn login(
State(state): State<Repository>,
Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
...(省略)
// セッションストアに登録する
let session_id = state
.create_user_session(id.to_string())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// クッキーをセットする
let mut headers = header::HeaderMap::new();
headers.insert(
header::SET_COOKIE,
format!("session_id={}; HttpOnly; SameSite=Strict", session_id)
.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
);
Ok((StatusCode::OK, headers))
}
ここまで書いたら、 login
ハンドラを使えるようにしましょう。 handler.rs
に以下を追加してください。
pub fn make_router(app_state: Repository) -> Router {
let city_router = Router::new()
.route("/cities/:city_name", get(country::get_city_handler))
.route("/cities", post(country::post_city_handler));
let auth_router = Router::new()
.route("/signup", post(auth::sign_up))
.route("/login", post(auth::login));
Router::new()
.nest("/", city_router)
.nest("/", auth_router)
.with_state(app_state)
}
ここまでの全体像
use axum::{
extract::State,
http::{header, StatusCode},
response::IntoResponse,
Json,
};
use serde::Deserialize;
use crate::repository::Repository;
#[derive(Deserialize)]
pub struct SignUp {
pub username: String,
pub password: String,
}
pub async fn sign_up(
State(state): State<Repository>,
Json(body): Json<SignUp>,
) -> Result<StatusCode, StatusCode> {
// バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
if body.username.is_empty() || body.password.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
// 登録しようとしているユーザーが既にデータベース内に存在したら409 Conflictを返す
if let Ok(true) = state.is_exist_username(body.username.clone()).await {
return Err(StatusCode::CONFLICT);
}
// ユーザーを作成する
let id = state
.create_user(body.username.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// パスワードを保存する
state
.save_user_password(id as i32, body.password.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::CREATED)
}
#[derive(Deserialize)]
pub struct Login {
pub username: String,
pub password: String,
}
pub async fn login(
State(state): State<Repository>,
Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
// バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
if body.username.is_empty() || body.password.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
// データベースからユーザーを取得する
let id = state
.get_user_id_by_name(body.username.clone())
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
})?;
// パスワードが一致しているかを確かめる
if !state
.verify_user_password(id, body.password.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
return Err(StatusCode::UNAUTHORIZED);
}
// セッションストアに登録する
let session_id = state
.create_user_session(id.to_string())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// クッキーをセットする
let mut headers = header::HeaderMap::new();
headers.insert(
header::SET_COOKIE,
format!("session_id={}; HttpOnly; SameSite=Strict", session_id)
.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
);
Ok((StatusCode::OK, headers))
}
use super::Repository;
impl Repository {
pub async fn is_exist_username(&self, username: String) -> sqlx::Result<bool> {
let result = sqlx::query("SELECT * FROM users WHERE username = ?")
.bind(&username)
.fetch_optional(&self.pool)
.await?;
Ok(result.is_some())
}
pub async fn create_user(&self, username: String) -> sqlx::Result<u64> {
let result = sqlx::query("INSERT INTO users (username) VALUES (?)")
.bind(&username)
.execute(&self.pool)
.await?;
Ok(result.last_insert_id())
}
pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result<i32> {
let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?")
.bind(&username)
.fetch_one(&self.pool)
.await?;
Ok(result)
}
pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> {
let hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;
sqlx::query("INSERT INTO user_passwords (id, hashed_pass) VALUES (?, ?)")
.bind(id)
.bind(hash)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn verify_user_password(&self, id: i32, password: String) -> anyhow::Result<bool> {
let hash =
sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?")
.bind(id)
.fetch_one(&self.pool)
.await?;
Ok(bcrypt::verify(password, &hash)?)
}
}
use anyhow::Context;
use async_session::{Session, SessionStore};
use super::Repository;
impl Repository {
pub async fn create_user_session(&self, user_id: String) -> anyhow::Result<String> {
let mut session = Session::new();
session
.insert("user_id", user_id)
.with_context(|| "Failed to insert user_id")?;
let session_id = self
.session_store
.store_session(session)
.await
.with_context(|| "Failed to store session")?
.with_context(|| "Failed to create session")?;
Ok(session_id)
}
}
use async_sqlx_session::MySqlSessionStore;
use sqlx::mysql::MySqlConnectOptions;
use sqlx::mysql::MySqlPool;
use std::env;
pub mod country;
pub mod users;
pub mod users_session;
#[derive(Clone)]
pub struct Repository {
pool: MySqlPool,
session_store: MySqlSessionStore,
}
impl Repository {
pub async fn connect() -> anyhow::Result<Self> {
let options = get_options()?;
let pool = sqlx::MySqlPool::connect_with(options).await?;
let session_store =
MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions");
Ok(Self {
pool,
session_store,
})
}
pub async fn migrate(&self) -> anyhow::Result<()> {
sqlx::migrate!("./migrations").run(&self.pool).await?;
self.session_store.migrate().await?;
Ok(())
}
}
fn get_options() -> anyhow::Result<MySqlConnectOptions> {
let host = env::var("DB_HOSTNAME")?;
let port = env::var("DB_PORT")?.parse()?;
let username = env::var("DB_USERNAME")?;
let password = env::var("DB_PASSWORD")?;
let database = env::var("DB_DATABASE")?;
let timezone = Some(String::from("Asia/Tokyo"));
let collation = String::from("utf8mb4_unicode_ci");
Ok(MySqlConnectOptions::new()
.host(&host)
.port(port)
.username(&username)
.password(&password)
.database(&database)
.timezone(timezone)
.collation(&collation))
}
Middleware の実装
続いて、auth_middleware
を実装します。 まず、これは Handler ではなく Middleware と呼ばれます。
送られてくるリクエストは、Middleware を経由して、 Handler に流れていきます。
Middleware から次の Middleware/Handler を呼び出す際は next.run(req)
と記述します。
以下をhandler/auth.rs
に追加してください。
pub async fn auth_middleware(
State(state): State<Repository>,
TypedHeader(cookie): TypedHeader<Cookie>,
mut req: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
// セッションIDを取得する
let session_id = cookie
.get("session_id")
.ok_or(StatusCode::UNAUTHORIZED)?
.to_string();
// セッションストアからユーザーIDを取得する
let user_id = state
.get_user_id_by_session_id(&session_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::UNAUTHORIZED)?;
// リクエストにユーザーIDを追加する
req.extensions_mut().insert(user_id);
// 次のミドルウェアを呼び出す
Ok(next.run(req).await)
}
この Middleware はリクエストを送ったユーザーがログインしているのかをチェックし、 ログインしているならリクエスト(req
) に user_id
を追加します。
Cookie からセッション id を取得し、セッションストアからユーザー id を取得します。 ここで、セッション id がなかった場合や、不正なセッション id だった場合は 401 (Unauthorized) を返却します。 正しくログインされていれば、次の Middleware/Handler を呼び出します。
ここで使用した、 get_user_id_by_session_id
メソッドを repository/users_session.rs
に追加します。
pub async fn get_user_id_by_session_id(
&self,
session_id: &String,
) -> anyhow::Result<Option<String>> {
let session = self
.session_store
.load_session(session_id.clone())
.await
.with_context(|| "Failed to load session")?;
Ok(session.and_then(|s| s.get::<String>("user_id")))
}
最後に、Middleware を設定しましょう。 ログインが必要なエンドポイントを with_auth_router
でまとめ、Middleware を適用します。
handler.rs
に以下を追加してください。
use axum::{
middleware::from_fn_with_state,
routing::{get, post},
Router,
};
use crate::repository::Repository;
mod auth;
mod country;
pub fn make_router(app_state: Repository) -> Router {
let city_router = Router::new()
let with_auth_router = Router::new()
.route("/cities/:city_name", get(country::get_city_handler))
.route("/cities", post(country::post_city_handler));
.route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware));
let auth_router = Router::new()
.route("/signup", post(auth::sign_up))
.route("/login", post(auth::login));
Router::new()
.nest("/", city_router)
.nest("/", with_auth_router)
.nest("/", auth_router)
.nest("/", ping_router)
.with_state(app_state)
}
これで、この章の目標である「ログインしないと利用できないようにする」が達成されました。
logout ハンドラの実装
ログアウト機能をまだ実装していなかったので、 logout
ハンドラを実装していきます。
まず、handler/auth.rs
に以下を追加してください。
pub async fn logout(
State(state): State<Repository>,
TypedHeader(cookie): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> {
// セッションIDを取得する
let session_id = cookie
.get("session_id")
.ok_or(StatusCode::UNAUTHORIZED)?
.to_string();
// セッションストアから削除する
state
.delete_user_session(session_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// クッキーを削除する
let mut headers = header::HeaderMap::new();
headers.insert(
header::SET_COOKIE,
"session_id=; HttpOnly; SameSite=Strict; Max-Age=0"
.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
);
Ok((StatusCode::OK, headers))
}
ログアウトするときは、ログインするときとは逆にセッションと Cookie を削除します。
ここで呼び出す delete_user_session
メソッドを repository/users_session.rs
に追加します。
pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> {
let session = self
.session_store
.load_session(session_id.clone())
.await
.with_context(|| "Failed to load session")?
.with_context(|| "Failed to find session")?;
self.session_store
.destroy_session(session)
.await
.with_context(|| "Failed to destroy session")?;
Ok(())
}
セッション ID からセッションを取得し、セッションストアから削除します。
最後に、handler.rs
に logout
ハンドラを追加します。
let auth_router = Router::new()
.route("/signup", post(auth::sign_up))
.route("/login", post(auth::login))
.route("/logout", post(auth::logout));
me ハンドラの実装
最後に、 me
ハンドラを実装します。叩いたときに自分の情報が返ってくるエンドポイントです。
以下を handler/auth.rs
に追加してください。
#[derive(Serialize)]
pub struct Me {
pub username: String,
}
pub async fn me(State(state): State<Repository>, req: Request) -> Result<Json<Me>, StatusCode> {
// リクエストからユーザーIDを取得する
let user_id = req
.extensions()
.get::<String>()
.ok_or(StatusCode::UNAUTHORIZED)?
.to_string();
// データベースからユーザー名を取得する
let username = state
.get_user_name_by_id(
user_id
.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(Me { username }))
}
リクエストからユーザー ID を取得し、データベースからユーザー名を取得します。
ここで呼び出す get_user_name_by_id
メソッドを repository/users.rs
に追加します。
impl Repository {
...(省略)
pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> {
let session = self
.session_store
.load_session(session_id.clone())
.await
.with_context(|| "Failed to load session")?
.with_context(|| "Failed to find session")?;
self.session_store
.destroy_session(session)
.await
.with_context(|| "Failed to destroy session")?;
Ok(())
}
...(省略)
}
最後に、handler.rs
に me
ハンドラを追加します。
let with_auth_router = Router::new()
.route("/cities/:city_name", get(country::get_city_handler))
.route("/cities", post(country::post_city_handler))
.route("/me", get(auth::me))
.route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware));