Skip to content

検証

完成形

まずは完成形です。

完成形
rs
use tower_http::trace::TraceLayer;
use tracing_subscriber::EnvFilter;

mod handler;
mod repository;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::try_from_default_env().unwrap_or("info".into()))
        .init();

    let app_state = repository::Repository::connect().await?;
    app_state.migrate().await?;
    let app = handler::make_router(app_state).layer(TraceLayer::new_for_http());
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;

    tracing::info!("listening on {}", listener.local_addr()?);
    axum::serve(listener, app).await.unwrap();
    Ok(())
}
rs
use axum::{
    middleware::from_fn_with_state,
    routing::{get, post},
    Router,
};

use crate::repository::Repository;

mod auth;
mod country;

pub fn make_router(app_state: Repository) -> Router {
    let with_auth_router = Router::new()
        .route("/cities/:city_name", get(country::get_city_handler))
        .route("/cities", post(country::post_city_handler))
        .route("/me", get(auth::me))
        .route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware));

    let auth_router = Router::new()
        .route("/signup", post(auth::sign_up))
        .route("/login", post(auth::login))
        .route("/logout", post(auth::logout));

    let ping_router = Router::new().route("/ping", get(|| async { "pong" }));

    Router::new()
        .nest("/", with_auth_router)
        .nest("/", auth_router)
        .nest("/", ping_router)
        .with_state(app_state)
}
rs
use async_sqlx_session::MySqlSessionStore;
use sqlx::mysql::MySqlConnectOptions;
use sqlx::mysql::MySqlPool;
use std::env;

pub mod country;
pub mod users;
pub mod users_session;

#[derive(Clone)]
pub struct Repository {
    pool: MySqlPool,
    session_store: MySqlSessionStore,
}

impl Repository {
    pub async fn connect() -> anyhow::Result<Self> {
        let options = get_options()?;
        let pool = sqlx::MySqlPool::connect_with(options).await?;

        let session_store =
            MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions");

        Ok(Self {
            pool,
            session_store,
        })
    }

    pub async fn migrate(&self) -> anyhow::Result<()> {
        sqlx::migrate!("./migrations").run(&self.pool).await?;
        self.session_store.migrate().await?;
        Ok(())
    }
}

fn get_options() -> anyhow::Result<MySqlConnectOptions> {
    let host = env::var("DB_HOSTNAME")?;
    let port = env::var("DB_PORT")?.parse()?;
    let username = env::var("DB_USERNAME")?;
    let password = env::var("DB_PASSWORD")?;
    let database = env::var("DB_DATABASE")?;
    let timezone = Some(String::from("Asia/Tokyo"));
    let collation = String::from("utf8mb4_unicode_ci");

    Ok(MySqlConnectOptions::new()
        .host(&host)
        .port(port)
        .username(&username)
        .password(&password)
        .database(&database)
        .timezone(timezone)
        .collation(&collation))
}
rs
use crate::repository::{country::City, Repository};
use axum::{
    extract::rejection::JsonRejection,
    extract::{Path, State},
    http::StatusCode,
    Json,
};

pub async fn get_city_handler(
    State(state): State<Repository>,
    Path(city_name): Path<String>,
) -> Result<Json<City>, StatusCode> {
    let city = Repository::get_city_by_name(&state, city_name).await;
    match city {
        Ok(city) => Ok(Json(city)),
        Err(sqlx::Error::RowNotFound) => Err(StatusCode::NOT_FOUND),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

pub async fn post_city_handler(
    State(state): State<Repository>,
    query: Result<Json<City>, JsonRejection>,
) -> Result<Json<City>, StatusCode> {
    match query {
        Ok(Json(city)) => {
            let result = Repository::create_city(&state, city).await;
            match result {
                Ok(city) => Ok(Json(city)),
                Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
            }
        }
        Err(_) => Err(StatusCode::BAD_REQUEST),
    }
}
rs
use super::Repository;

#[derive(sqlx::FromRow, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct City {
    #[sqlx(rename = "ID")]
    pub id: Option<i32>,
    #[sqlx(rename = "Name")]
    pub name: String,
    #[sqlx(rename = "CountryCode")]
    pub country_code: String,
    #[sqlx(rename = "District")]
    pub district: String,
    #[sqlx(rename = "Population")]
    pub population: i32,
}

impl Repository {
    pub async fn get_city_by_name(&self, city_name: String) -> sqlx::Result<City> {
        sqlx::query_as::<_, City>("SELECT * FROM city WHERE Name = ?")
            .bind(&city_name)
            .fetch_one(&self.pool)
            .await
    }

    pub async fn create_city(&self, city: City) -> sqlx::Result<City> {
        let result = sqlx::query(
            "INSERT INTO city (Name, CountryCode, District, Population) VALUES (?, ?, ?, ?)",
        )
        .bind(&city.name)
        .bind(&city.country_code)
        .bind(&city.district)
        .bind(city.population)
        .execute(&self.pool)
        .await?;

        let id = result.last_insert_id() as i32;
        Ok(City {
            id: Some(id),
            ..city
        })
    }
}
rs
use axum::{
    extract::{Request, State},
    http::{header, StatusCode},
    middleware::Next,
    response::IntoResponse,
    Json,
};
use axum_extra::{headers::Cookie, TypedHeader};
use serde::{Deserialize, Serialize};

use crate::repository::Repository;

pub async fn auth_middleware(
    State(state): State<Repository>,
    TypedHeader(cookie): TypedHeader<Cookie>,
    mut req: Request,
    next: Next,
) -> Result<impl IntoResponse, StatusCode> {
    // セッションIDを取得する
    let session_id = cookie
        .get("session_id")
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_string();

    // セッションストアからユーザーIDを取得する
    let user_id = state
        .get_user_id_by_session_id(&session_id)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .ok_or(StatusCode::UNAUTHORIZED)?;

    // リクエストにユーザーIDを追加する
    req.extensions_mut().insert(user_id);

    // 次のミドルウェアを呼び出す
    Ok(next.run(req).await)
}

#[derive(Deserialize)]
pub struct SignUp {
    pub username: String,
    pub password: String,
}

pub async fn sign_up(
    State(state): State<Repository>,
    Json(body): Json<SignUp>,
) -> Result<StatusCode, StatusCode> {
    // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
    if body.username.is_empty() || body.password.is_empty() {
        return Err(StatusCode::BAD_REQUEST);
    }

    // 登録しようとしているユーザーが既にデータベース内に存在したら409 Conflictを返す
    if let Ok(true) = state.is_exist_username(body.username.clone()).await {
        return Err(StatusCode::CONFLICT);
    }

    // ユーザーを作成する
    let id = state
        .create_user(body.username.clone())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // パスワードを保存する
    state
        .save_user_password(id as i32, body.password.clone())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(StatusCode::CREATED)
}

#[derive(Deserialize)]
pub struct Login {
    pub username: String,
    pub password: String,
}

pub async fn login(
    State(state): State<Repository>,
    Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
    // バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
    if body.username.is_empty() || body.password.is_empty() {
        return Err(StatusCode::BAD_REQUEST);
    }

    // データベースからユーザーを取得する
    let id = state
        .get_user_id_by_name(body.username.clone())
        .await
        .map_err(|e| match e {
            sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        })?;

    // パスワードが一致しているかを確かめる
    if !state
        .verify_user_password(id, body.password.clone())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    {
        return Err(StatusCode::UNAUTHORIZED);
    }

    // セッションストアに登録する
    let session_id = state
        .create_user_session(id.to_string())
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // クッキーをセットする
    let mut headers = header::HeaderMap::new();

    headers.insert(
        header::SET_COOKIE,
        format!("session_id={}; HttpOnly; SameSite=Strict", session_id)
            .parse()
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
    );

    Ok((StatusCode::OK, headers))
}

pub async fn logout(
    State(state): State<Repository>,
    TypedHeader(cookie): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> {
    // セッションIDを取得する
    let session_id = cookie
        .get("session_id")
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_string();

    // セッションストアから削除する
    state
        .delete_user_session(session_id)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // クッキーを削除する
    let mut headers = header::HeaderMap::new();
    headers.insert(
        header::SET_COOKIE,
        "session_id=; HttpOnly; SameSite=Strict; Max-Age=0"
            .parse()
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
    );

    Ok((StatusCode::OK, headers))
}

#[derive(Serialize)]
pub struct Me {
    pub username: String,
}

pub async fn me(State(state): State<Repository>, req: Request) -> Result<Json<Me>, StatusCode> {
    // リクエストからユーザーIDを取得する
    let user_id = req
        .extensions()
        .get::<String>()
        .ok_or(StatusCode::UNAUTHORIZED)?
        .to_string();

    // データベースからユーザー名を取得する
    let username = state
        .get_user_name_by_id(
            user_id
                .parse()
                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
        )
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(Me { username }))
}
rs
use super::Repository;

impl Repository {
    pub async fn is_exist_username(&self, username: String) -> sqlx::Result<bool> {
        let result = sqlx::query("SELECT * FROM users WHERE username = ?")
            .bind(&username)
            .fetch_optional(&self.pool)
            .await?;
        Ok(result.is_some())
    }

    pub async fn create_user(&self, username: String) -> sqlx::Result<u64> {
        let result = sqlx::query("INSERT INTO users (username) VALUES (?)")
            .bind(&username)
            .execute(&self.pool)
            .await?;
        Ok(result.last_insert_id())
    }

    pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result<i32> {
        let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?")
            .bind(&username)
            .fetch_one(&self.pool)
            .await?;
        Ok(result)
    }

    pub async fn get_user_name_by_id(&self, id: i32) -> sqlx::Result<String> {
        let result = sqlx::query_scalar("SELECT username FROM users WHERE id = ?")
            .bind(id)
            .fetch_one(&self.pool)
            .await?;
        Ok(result)
    }

    pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> {
        let hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;

        sqlx::query("INSERT INTO user_passwords (id, hashed_pass) VALUES (?, ?)")
            .bind(id)
            .bind(hash)
            .execute(&self.pool)
            .await?;

        Ok(())
    }

    pub async fn verify_user_password(&self, id: i32, password: String) -> anyhow::Result<bool> {
        let hash =
            sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?")
                .bind(id)
                .fetch_one(&self.pool)
                .await?;

        Ok(bcrypt::verify(password, &hash)?)
    }
}
rs
use anyhow::Context;
use async_session::{Session, SessionStore};

use super::Repository;

impl Repository {
    pub async fn create_user_session(&self, user_id: String) -> anyhow::Result<String> {
        let mut session = Session::new();

        session
            .insert("user_id", user_id)
            .with_context(|| "Failed to insert user_id")?;

        let session_id = self
            .session_store
            .store_session(session)
            .await
            .with_context(|| "Failed to store session")?
            .with_context(|| "Failed to create session")?;

        Ok(session_id)
    }

    pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> {
        let session = self
            .session_store
            .load_session(session_id.clone())
            .await
            .with_context(|| "Failed to load session")?
            .with_context(|| "Failed to find session")?;

        self.session_store
            .destroy_session(session)
            .await
            .with_context(|| "Failed to destroy session")?;

        Ok(())
    }

    pub async fn get_user_id_by_session_id(
        &self,
        session_id: &String,
    ) -> anyhow::Result<Option<String>> {
        let session = self
            .session_store
            .load_session(session_id.clone())
            .await
            .with_context(|| "Failed to load session")?;

        Ok(session.and_then(|s| s.get::<String>("user_id")))
    }
}

検証

自分の実装が正しく動くか検証しましょう。

WARNING

全て Postman での検証です。
cargo runでサーバーを起動した状態で行ってください。

また、GETPOSTを間違えないようにして下さい。

localhost:8080/cities/Tokyoにアクセスすると、ログインしていないため401 Unauthorizedが返ってきます。そのため、情報を入手できません。

ユーザーを作成します。 上手く作成できれば Status 201 が返ってくるはずです。
(注意:POSTです)

そのままパスを変えてログインリクエストを送ります。

ログインに成功したら、レスポンスの方の Cookies を開いて value の中身をコピーします

リクエストの方の Headers で Cookie をセットします。

Key にCookieを Value にsession_id={コピーした値};をセットします(既に自動で入っている場合もあります、その場合は追加しなくて大丈夫です)。

もう一度 localhost:8080/cities/Tokyo にアクセスすると正常に API が取れるようになりました。
(注意:GETです)

ここで、作成されたユーザーがデータベースに保存されていることを確認してみましょう。

bash
SELECT * FROM users;
SELECT * FROM user_passwords;
SELECT * FROM user_sessions;

ユーザー名とハッシュ化されたパスワードが確認できますね。

ちょっと分かりにくい表示ですが、セッションもしっかり確認できます。