Skip to content

セッション管理機構の実装

セッションストアを設定する

repository.rsに以下を追加しましょう。

rs
use async_sqlx_session::MySqlSessionStore; 
use sqlx::mysql::MySqlConnectOptions;
use sqlx::mysql::MySqlPool;
use std::env;

pub mod country;
pub mod users;

#[derive(Clone)]
pub struct Repository {
    pool: MySqlPool,
    session_store: MySqlSessionStore, 
}

impl Repository {
    pub async fn connect() -> anyhow::Result<Self> {
        let options = get_options()?;
        let pool = sqlx::MySqlPool::connect_with(options).await?;

        let session_store =
            MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions"); 

        Ok(Self {
            pool,
            session_store, 
        })
    }

    pub async fn migrate(&self) -> anyhow::Result<()> {
        sqlx::migrate!("./migrations").run(&self.pool).await?;
        self.session_store.migrate().await?;
        Ok(())
    }
}
...(省略)

これらはセッションストアの設定です。 セッションの情報を記憶するための場所をデータベース上に設定して、session_store からアクセスできるようにしています。

login ハンドラの実装

続いて、login ハンドラを handler/auth.rs に実装していきましょう。

rs
pub async fn login( 
    State(state): State<Repository>, 
    Json(body): Json<Login>, 
) -> Result<impl IntoResponse, StatusCode> { 
} 

login ハンドラの外に以下の構造体を追加します。

rs
#[derive(Deserialize)] 
pub struct Login { 
    pub username: String, 
    pub password: String, 
} 

login ハンドラの中身を実装する前に、必要になるデータベース操作のメソッドを追加します。ここで必要になるのは以下の 2 つです。

  • username から id を取得するメソッド
  • idpassword の組が登録されているものと一致するかを確認するメソッド

この 2 つを repository/users.rs に追加します。

rs
use super::Repository;

impl Repository {
    pub async fn is_exist_username(&self, username: String) -> sqlx::Result<bool> {
        ...(省略)
    }

    pub async fn create_user(&self, username: String) -> sqlx::Result<u64> {
        ...(省略)
    }

    pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result<i32> { 
        let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?") 
            .bind(&username) 
            .fetch_one(&self.pool) 
            .await?; 
        Ok(result) 
    } 

    pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> {
        ...(省略)
    }

    pub async fn verify_user_password(&self, id: i32, password: String) -> anyhow::Result<bool> { 
        let hash =
            sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?") 
                .bind(id) 
                .fetch_one(&self.pool) 
                .await?; 

        Ok(bcrypt::verify(password, &hash)?) 
    } 
}

データベースに保存されているパスワードはハッシュ化されています。

ハッシュ化は不可逆な処理なので、ハッシュ化されたものから原文を調べることはできません。確認する際はもらったパスワードをハッシュ化することで行います。 bcrypt::verify によってパスワードの検証ができます。

handler/auth.rs に戻り、login ハンドラを実装していきます。

rs
pub async fn login( 
    State(state): State<Repository>, 
    Json(body): Json<Login>, 
) -> Result<impl IntoResponse, StatusCode> { 
    // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
    if body.username.is_empty() || body.password.is_empty() { 
        return Err(StatusCode::BAD_REQUEST); 
    } 

    // データベースからユーザーを取得する
    let id = state 
        .get_user_id_by_name(body.username.clone()) 
        .await
        .map_err(|e| match e { 
            sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED, 
            _ => StatusCode::INTERNAL_SERVER_ERROR, 
        })?; 
} 

ユーザーが存在しなかった場合は sqlx::Error::RowNotFound というエラーが返ってきます。 もしそのエラーなら 401 (Unauthorized)、そうでなければ 500 (Internal Server Error) です。 もし 404 (Not Found) とすると、「このユーザーはパスワードが違うのではなく存在しないんだ」という事がわかってしまい(このユーザーは存在していてパスワードは違う事も分かります)、セキュリティ上のリスクに繋がります。

rs
pub async fn login(
    State(state): State<Repository>,
    Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
    // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
    if body.username.is_empty() || body.password.is_empty() {
        return Err(StatusCode::BAD_REQUEST);
    }

    // データベースからユーザーを取得する
    let id = state
        .get_user_id_by_name(body.username.clone())
        .await
        .map_err(|e| match e {
            sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        })?;

    // パスワードが一致しているかを確かめる
    if !state 
        .verify_user_password(id, body.password.clone()) 
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    { 
        return Err(StatusCode::UNAUTHORIZED); 
    } 
}

データベースでエラーが起きた場合や、検証の操作に失敗した場合は 500 (Internal Server Error), パスワードが間違っていた場合 401 (Unauthorized) を返却しています。

rs
pub async fn login( 
    State(state): State<Repository>, 
    Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
    ...(省略)
    
    // パスワードが一致しているかを確かめる
    if !state
        .verify_user_password(id, body.password.clone())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    {
        return Err(StatusCode::UNAUTHORIZED);
    } 

    // セッションストアに登録する
    let session_id = state 
        .create_user_session(id.to_string()) 
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 
    
}

id をセッションストアに登録して、セッション id を取得します。

ここで用いる、セッションストアに登録するメソッド create_user_session を実装していきます。

ファイル repository/users_session.rs を作成し、以下を記述してください。

rs
use anyhow::Context; 
use async_session::{Session, SessionStore}; 

use super::Repository; 

impl Repository { 
    pub async fn create_user_session(&self, user_id: String) -> anyhow::Result<String> { 
        let mut session = Session::new(); 

        session 
            .insert("user_id", user_id) 
            .with_context(|| "Failed to insert user_id")?; 

        let session_id = self
            .session_store 
            .store_session(session) 
            .await
            .with_context(|| "Failed to store session")?
            .with_context(|| "Failed to create session")?; 

        Ok(session_id) 
    } 
} 

セッションに user_id を登録し、セッションストアに保存します。 セッション id を返却します。

handler/auth.rs に戻り、ヘッダーにセッション id を設定する処理を追加します。

rs
pub async fn login(
    State(state): State<Repository>,
    Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
    ...(省略)

    // セッションストアに登録する
    let session_id = state
        .create_user_session(id.to_string())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // クッキーをセットする
    let mut headers = header::HeaderMap::new(); 

    headers.insert( 
        header::SET_COOKIE, 
        format!("session_id={}; HttpOnly; SameSite=Strict", session_id) 
            .parse() 
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, 
    ); 

    Ok((StatusCode::OK, headers)) 
}

ここまで書いたら、 login ハンドラを使えるようにしましょう。 handler.rs に以下を追加してください。

rs
pub fn make_router(app_state: Repository) -> Router {
    let city_router = Router::new()
        .route("/cities/:city_name", get(country::get_city_handler))
        .route("/cities", post(country::post_city_handler));

    let auth_router = Router::new()
        .route("/signup", post(auth::sign_up))
        .route("/login", post(auth::login)); 

    Router::new()
        .nest("/", city_router)
        .nest("/", auth_router)
        .with_state(app_state)
}
ここまでの全体像
rs
use axum::{
    extract::State,
    http::{header, StatusCode},
    response::IntoResponse,
    Json,
};
use serde::Deserialize;

use crate::repository::Repository;

#[derive(Deserialize)]
pub struct SignUp {
    pub username: String,
    pub password: String,
}

pub async fn sign_up(
    State(state): State<Repository>,
    Json(body): Json<SignUp>,
) -> Result<StatusCode, StatusCode> {
    // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
    if body.username.is_empty() || body.password.is_empty() {
        return Err(StatusCode::BAD_REQUEST);
    }

    // 登録しようとしているユーザーが既にデータベース内に存在したら409 Conflictを返す
    if let Ok(true) = state.is_exist_username(body.username.clone()).await {
        return Err(StatusCode::CONFLICT);
    }

    // ユーザーを作成する
    let id = state
        .create_user(body.username.clone())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // パスワードを保存する
    state
        .save_user_password(id as i32, body.password.clone())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(StatusCode::CREATED)
}

#[derive(Deserialize)]
pub struct Login {
    pub username: String,
    pub password: String,
}

pub async fn login(
    State(state): State<Repository>,
    Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
    // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
    if body.username.is_empty() || body.password.is_empty() {
        return Err(StatusCode::BAD_REQUEST);
    }

    // データベースからユーザーを取得する
    let id = state
        .get_user_id_by_name(body.username.clone())
        .await
        .map_err(|e| match e {
            sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        })?;

    // パスワードが一致しているかを確かめる
    if !state
        .verify_user_password(id, body.password.clone())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    {
        return Err(StatusCode::UNAUTHORIZED);
    }

    // セッションストアに登録する
    let session_id = state
        .create_user_session(id.to_string())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // クッキーをセットする
    let mut headers = header::HeaderMap::new();

    headers.insert(
        header::SET_COOKIE,
        format!("session_id={}; HttpOnly; SameSite=Strict", session_id)
            .parse()
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
    );

    Ok((StatusCode::OK, headers))
}
rs
use super::Repository;

impl Repository {
    pub async fn is_exist_username(&self, username: String) -> sqlx::Result<bool> {
        let result = sqlx::query("SELECT * FROM users WHERE username = ?")
            .bind(&username)
            .fetch_optional(&self.pool)
            .await?;
        Ok(result.is_some())
    }

    pub async fn create_user(&self, username: String) -> sqlx::Result<u64> {
        let result = sqlx::query("INSERT INTO users (username) VALUES (?)")
            .bind(&username)
            .execute(&self.pool)
            .await?;
        Ok(result.last_insert_id())
    }

    pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result<i32> {
        let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?")
            .bind(&username)
            .fetch_one(&self.pool)
            .await?;
        Ok(result)
    }

    pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> {
        let hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;

        sqlx::query("INSERT INTO user_passwords (id, hashed_pass) VALUES (?, ?)")
            .bind(id)
            .bind(hash)
            .execute(&self.pool)
            .await?;

        Ok(())
    }

    pub async fn verify_user_password(&self, id: i32, password: String) -> anyhow::Result<bool> {
        let hash =
            sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?")
                .bind(id)
                .fetch_one(&self.pool)
                .await?;

        Ok(bcrypt::verify(password, &hash)?)
    }
}
rs
use anyhow::Context;
use async_session::{Session, SessionStore};

use super::Repository;

impl Repository {
    pub async fn create_user_session(&self, user_id: String) -> anyhow::Result<String> {
        let mut session = Session::new();

        session
            .insert("user_id", user_id)
            .with_context(|| "Failed to insert user_id")?;

        let session_id = self
            .session_store
            .store_session(session)
            .await
            .with_context(|| "Failed to store session")?
            .with_context(|| "Failed to create session")?;

        Ok(session_id)
    }
}
rs
use async_sqlx_session::MySqlSessionStore;
use sqlx::mysql::MySqlConnectOptions;
use sqlx::mysql::MySqlPool;
use std::env;

pub mod country;
pub mod users;
pub mod users_session;

#[derive(Clone)]
pub struct Repository {
    pool: MySqlPool,
    session_store: MySqlSessionStore,
}

impl Repository {
    pub async fn connect() -> anyhow::Result<Self> {
        let options = get_options()?;
        let pool = sqlx::MySqlPool::connect_with(options).await?;

        let session_store =
            MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions");

        Ok(Self {
            pool,
            session_store,
        })
    }

    pub async fn migrate(&self) -> anyhow::Result<()> {
        sqlx::migrate!("./migrations").run(&self.pool).await?;
        self.session_store.migrate().await?;
        Ok(())
    }
}

fn get_options() -> anyhow::Result<MySqlConnectOptions> {
    let host = env::var("DB_HOSTNAME")?;
    let port = env::var("DB_PORT")?.parse()?;
    let username = env::var("DB_USERNAME")?;
    let password = env::var("DB_PASSWORD")?;
    let database = env::var("DB_DATABASE")?;
    let timezone = Some(String::from("Asia/Tokyo"));
    let collation = String::from("utf8mb4_unicode_ci");

    Ok(MySqlConnectOptions::new()
        .host(&host)
        .port(port)
        .username(&username)
        .password(&password)
        .database(&database)
        .timezone(timezone)
        .collation(&collation))
}

Middleware の実装

続いて、auth_middleware を実装します。 まず、これは Handler ではなく Middleware と呼ばれます。

送られてくるリクエストは、Middleware を経由して、 Handler に流れていきます。

Middleware から次の Middleware/Handler を呼び出す際は next.run(req) と記述します。

以下をhandler/auth.rsに追加してください。

rs
pub async fn auth_middleware( 
    State(state): State<Repository>, 
    TypedHeader(cookie): TypedHeader<Cookie>, 
    mut req: Request, 
    next: Next, 
) -> Result<impl IntoResponse, StatusCode> { 

    // セッションIDを取得する
    let session_id = cookie 
        .get("session_id") 
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_string(); 

    // セッションストアからユーザーIDを取得する
    let user_id = state 
        .get_user_id_by_session_id(&session_id) 
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .ok_or(StatusCode::UNAUTHORIZED)?; 

    // リクエストにユーザーIDを追加する
    req.extensions_mut().insert(user_id); 

    // 次のミドルウェアを呼び出す
    Ok(next.run(req).await) 
} 

この Middleware はリクエストを送ったユーザーがログインしているのかをチェックし、 ログインしているならリクエスト(req) に user_id を追加します。

Cookie からセッション id を取得し、セッションストアからユーザー id を取得します。 ここで、セッション id がなかった場合や、不正なセッション id だった場合は 401 (Unauthorized) を返却します。 正しくログインされていれば、次の Middleware/Handler を呼び出します。

ここで使用した、 get_user_id_by_session_id メソッドを repository/users_session.rs に追加します。

rs
pub async fn get_user_id_by_session_id( 
    &self, 
    session_id: &String, 
) -> anyhow::Result<Option<String>> { 
    let session = self
        .session_store 
        .load_session(session_id.clone()) 
        .await
        .with_context(|| "Failed to load session")?; 

    Ok(session.and_then(|s| s.get::<String>("user_id"))) 
} 

最後に、Middleware を設定しましょう。 ログインが必要なエンドポイントを with_auth_router でまとめ、Middleware を適用します。

handler.rs に以下を追加してください。

rs
use axum::{
    middleware::from_fn_with_state, 
    routing::{get, post},
    Router,
};

use crate::repository::Repository;

mod auth;
mod country;

pub fn make_router(app_state: Repository) -> Router {
    let city_router = Router::new() 
    let with_auth_router = Router::new() 
        .route("/cities/:city_name", get(country::get_city_handler))
        .route("/cities", post(country::post_city_handler));
        .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware)); 

    let auth_router = Router::new()
        .route("/signup", post(auth::sign_up))
        .route("/login", post(auth::login)); 

    Router::new()
        .nest("/", city_router) 
        .nest("/", with_auth_router) 
        .nest("/", auth_router)
        .nest("/", ping_router)
        .with_state(app_state)
}

これで、この章の目標である「ログインしないと利用できないようにする」が達成されました。

logout ハンドラの実装

ログアウト機能をまだ実装していなかったので、 logout ハンドラを実装していきます。

まず、handler/auth.rs に以下を追加してください。

rs
pub async fn logout( 
    State(state): State<Repository>, 
    TypedHeader(cookie): TypedHeader<Cookie>, 
) -> Result<impl IntoResponse, StatusCode> { 
    // セッションIDを取得する
    let session_id = cookie 
        .get("session_id") 
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_string(); 

    // セッションストアから削除する
    state 
        .delete_user_session(session_id) 
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 

    // クッキーを削除する
    let mut headers = header::HeaderMap::new(); 
    headers.insert( 
        header::SET_COOKIE, 
        "session_id=; HttpOnly; SameSite=Strict; Max-Age=0"
            .parse() 
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, 
    ); 

    Ok((StatusCode::OK, headers)) 
} 

ログアウトするときは、ログインするときとは逆にセッションと Cookie を削除します。

ここで呼び出す delete_user_session メソッドを repository/users_session.rs に追加します。

rs
pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> { 
    let session = self
        .session_store 
        .load_session(session_id.clone()) 
        .await
        .with_context(|| "Failed to load session")?
        .with_context(|| "Failed to find session")?; 

    self.session_store 
        .destroy_session(session) 
        .await
        .with_context(|| "Failed to destroy session")?; 

    Ok(()) 
} 

セッション ID からセッションを取得し、セッションストアから削除します。

最後に、handler.rslogout ハンドラを追加します。

rs
let auth_router = Router::new()
    .route("/signup", post(auth::sign_up))
    .route("/login", post(auth::login))
    .route("/logout", post(auth::logout)); 

me ハンドラの実装

最後に、 me ハンドラを実装します。叩いたときに自分の情報が返ってくるエンドポイントです。

以下を handler/auth.rs に追加してください。

rs
#[derive(Serialize)] 
pub struct Me { 
    pub username: String, 
}
rs
pub async fn me(State(state): State<Repository>, req: Request) -> Result<Json<Me>, StatusCode> { 
    // リクエストからユーザーIDを取得する
    let user_id = req 
        .extensions() 
        .get::<String>() 
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_string(); 

    // データベースからユーザー名を取得する
    let username = state 
        .get_user_name_by_id( 
            user_id 
                .parse() 
                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?, 
        ) 
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 

    Ok(Json(Me { username })) 
} 

リクエストからユーザー ID を取得し、データベースからユーザー名を取得します。

ここで呼び出す get_user_name_by_id メソッドを repository/users.rs に追加します。

rs
impl Repository {
    ...(省略)

    pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> { 
        let session = self
            .session_store
            .load_session(session_id.clone()) 
            .await
            .with_context(|| "Failed to load session")?
            .with_context(|| "Failed to find session")?; 

        self.session_store 
            .destroy_session(session) 
            .await
            .with_context(|| "Failed to destroy session")?; 

        Ok(()) 
    } 

    ...(省略)
}

最後に、handler.rsme ハンドラを追加します。

rs
let with_auth_router = Router::new()
        .route("/cities/:city_name", get(country::get_city_handler))
        .route("/cities", post(country::post_city_handler))
        .route("/me", get(auth::me)) 
        .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware));