検証
完成形
まずは完成形です。
完成形
rs
use tower_http::trace::TraceLayer;
use tracing_subscriber::EnvFilter;
mod handler;
mod repository;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::try_from_default_env().unwrap_or("info".into()))
.init();
let app_state = repository::Repository::connect().await?;
app_state.migrate().await?;
let app = handler::make_router(app_state).layer(TraceLayer::new_for_http());
let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
tracing::info!("listening on {}", listener.local_addr()?);
axum::serve(listener, app).await.unwrap();
Ok(())
}
rs
use axum::{
middleware::from_fn_with_state,
routing::{get, post},
Router,
};
use crate::repository::Repository;
mod auth;
mod country;
pub fn make_router(app_state: Repository) -> Router {
let with_auth_router = Router::new()
.route("/cities/:city_name", get(country::get_city_handler))
.route("/cities", post(country::post_city_handler))
.route("/me", get(auth::me))
.route_layer(from_fn_with_state(app_state.clone(), auth::auth_middleware));
let auth_router = Router::new()
.route("/signup", post(auth::sign_up))
.route("/login", post(auth::login))
.route("/logout", post(auth::logout));
let ping_router = Router::new().route("/ping", get(|| async { "pong" }));
Router::new()
.nest("/", with_auth_router)
.nest("/", auth_router)
.nest("/", ping_router)
.with_state(app_state)
}
rs
use async_sqlx_session::MySqlSessionStore;
use sqlx::mysql::MySqlConnectOptions;
use sqlx::mysql::MySqlPool;
use std::env;
pub mod country;
pub mod users;
pub mod users_session;
#[derive(Clone)]
pub struct Repository {
pool: MySqlPool,
session_store: MySqlSessionStore,
}
impl Repository {
pub async fn connect() -> anyhow::Result<Self> {
let options = get_options()?;
let pool = sqlx::MySqlPool::connect_with(options).await?;
let session_store =
MySqlSessionStore::from_client(pool.clone()).with_table_name("user_sessions");
Ok(Self {
pool,
session_store,
})
}
pub async fn migrate(&self) -> anyhow::Result<()> {
sqlx::migrate!("./migrations").run(&self.pool).await?;
self.session_store.migrate().await?;
Ok(())
}
}
fn get_options() -> anyhow::Result<MySqlConnectOptions> {
let host = env::var("DB_HOSTNAME")?;
let port = env::var("DB_PORT")?.parse()?;
let username = env::var("DB_USERNAME")?;
let password = env::var("DB_PASSWORD")?;
let database = env::var("DB_DATABASE")?;
let timezone = Some(String::from("Asia/Tokyo"));
let collation = String::from("utf8mb4_unicode_ci");
Ok(MySqlConnectOptions::new()
.host(&host)
.port(port)
.username(&username)
.password(&password)
.database(&database)
.timezone(timezone)
.collation(&collation))
}
rs
use crate::repository::{country::City, Repository};
use axum::{
extract::rejection::JsonRejection,
extract::{Path, State},
http::StatusCode,
Json,
};
pub async fn get_city_handler(
State(state): State<Repository>,
Path(city_name): Path<String>,
) -> Result<Json<City>, StatusCode> {
let city = Repository::get_city_by_name(&state, city_name).await;
match city {
Ok(city) => Ok(Json(city)),
Err(sqlx::Error::RowNotFound) => Err(StatusCode::NOT_FOUND),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
pub async fn post_city_handler(
State(state): State<Repository>,
query: Result<Json<City>, JsonRejection>,
) -> Result<Json<City>, StatusCode> {
match query {
Ok(Json(city)) => {
let result = Repository::create_city(&state, city).await;
match result {
Ok(city) => Ok(Json(city)),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
Err(_) => Err(StatusCode::BAD_REQUEST),
}
}
rs
use super::Repository;
#[derive(sqlx::FromRow, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct City {
#[sqlx(rename = "ID")]
pub id: Option<i32>,
#[sqlx(rename = "Name")]
pub name: String,
#[sqlx(rename = "CountryCode")]
pub country_code: String,
#[sqlx(rename = "District")]
pub district: String,
#[sqlx(rename = "Population")]
pub population: i32,
}
impl Repository {
pub async fn get_city_by_name(&self, city_name: String) -> sqlx::Result<City> {
sqlx::query_as::<_, City>("SELECT * FROM city WHERE Name = ?")
.bind(&city_name)
.fetch_one(&self.pool)
.await
}
pub async fn create_city(&self, city: City) -> sqlx::Result<City> {
let result = sqlx::query(
"INSERT INTO city (Name, CountryCode, District, Population) VALUES (?, ?, ?, ?)",
)
.bind(&city.name)
.bind(&city.country_code)
.bind(&city.district)
.bind(city.population)
.execute(&self.pool)
.await?;
let id = result.last_insert_id() as i32;
Ok(City {
id: Some(id),
..city
})
}
}
rs
use axum::{
extract::{Request, State},
http::{header, StatusCode},
middleware::Next,
response::IntoResponse,
Json,
};
use axum_extra::{headers::Cookie, TypedHeader};
use serde::{Deserialize, Serialize};
use crate::repository::Repository;
pub async fn auth_middleware(
State(state): State<Repository>,
TypedHeader(cookie): TypedHeader<Cookie>,
mut req: Request,
next: Next,
) -> Result<impl IntoResponse, StatusCode> {
// セッションIDを取得する
let session_id = cookie
.get("session_id")
.ok_or(StatusCode::UNAUTHORIZED)?
.to_string();
// セッションストアからユーザーIDを取得する
let user_id = state
.get_user_id_by_session_id(&session_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::UNAUTHORIZED)?;
// リクエストにユーザーIDを追加する
req.extensions_mut().insert(user_id);
// 次のミドルウェアを呼び出す
Ok(next.run(req).await)
}
#[derive(Deserialize)]
pub struct SignUp {
pub username: String,
pub password: String,
}
pub async fn sign_up(
State(state): State<Repository>,
Json(body): Json<SignUp>,
) -> Result<StatusCode, StatusCode> {
// バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
if body.username.is_empty() || body.password.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
// 登録しようとしているユーザーが既にデータベース内に存在したら409 Conflictを返す
if let Ok(true) = state.is_exist_username(body.username.clone()).await {
return Err(StatusCode::CONFLICT);
}
// ユーザーを作成する
let id = state
.create_user(body.username.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// パスワードを保存する
state
.save_user_password(id as i32, body.password.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::CREATED)
}
#[derive(Deserialize)]
pub struct Login {
pub username: String,
pub password: String,
}
pub async fn login(
State(state): State<Repository>,
Json(body): Json<Login>,
) -> Result<impl IntoResponse, StatusCode> {
// バリデーションする(PasswordかUsernameが空文字列の場合は400 BadRequestを返す)
if body.username.is_empty() || body.password.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
// データベースからユーザーを取得する
let id = state
.get_user_id_by_name(body.username.clone())
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
})?;
// パスワードが一致しているかを確かめる
if !state
.verify_user_password(id, body.password.clone())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
return Err(StatusCode::UNAUTHORIZED);
}
// セッションストアに登録する
let session_id = state
.create_user_session(id.to_string())
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// クッキーをセットする
let mut headers = header::HeaderMap::new();
headers.insert(
header::SET_COOKIE,
format!("session_id={}; HttpOnly; SameSite=Strict", session_id)
.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
);
Ok((StatusCode::OK, headers))
}
pub async fn logout(
State(state): State<Repository>,
TypedHeader(cookie): TypedHeader<Cookie>,
) -> Result<impl IntoResponse, StatusCode> {
// セッションIDを取得する
let session_id = cookie
.get("session_id")
.ok_or(StatusCode::UNAUTHORIZED)?
.to_string();
// セッションストアから削除する
state
.delete_user_session(session_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// クッキーを削除する
let mut headers = header::HeaderMap::new();
headers.insert(
header::SET_COOKIE,
"session_id=; HttpOnly; SameSite=Strict; Max-Age=0"
.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
);
Ok((StatusCode::OK, headers))
}
#[derive(Serialize)]
pub struct Me {
pub username: String,
}
pub async fn me(State(state): State<Repository>, req: Request) -> Result<Json<Me>, StatusCode> {
// リクエストからユーザーIDを取得する
let user_id = req
.extensions()
.get::<String>()
.ok_or(StatusCode::UNAUTHORIZED)?
.to_string();
// データベースからユーザー名を取得する
let username = state
.get_user_name_by_id(
user_id
.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(Me { username }))
}
rs
use super::Repository;
impl Repository {
pub async fn is_exist_username(&self, username: String) -> sqlx::Result<bool> {
let result = sqlx::query("SELECT * FROM users WHERE username = ?")
.bind(&username)
.fetch_optional(&self.pool)
.await?;
Ok(result.is_some())
}
pub async fn create_user(&self, username: String) -> sqlx::Result<u64> {
let result = sqlx::query("INSERT INTO users (username) VALUES (?)")
.bind(&username)
.execute(&self.pool)
.await?;
Ok(result.last_insert_id())
}
pub async fn get_user_id_by_name(&self, username: String) -> sqlx::Result<i32> {
let result = sqlx::query_scalar("SELECT id FROM users WHERE username = ?")
.bind(&username)
.fetch_one(&self.pool)
.await?;
Ok(result)
}
pub async fn get_user_name_by_id(&self, id: i32) -> sqlx::Result<String> {
let result = sqlx::query_scalar("SELECT username FROM users WHERE id = ?")
.bind(id)
.fetch_one(&self.pool)
.await?;
Ok(result)
}
pub async fn save_user_password(&self, id: i32, password: String) -> anyhow::Result<()> {
let hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;
sqlx::query("INSERT INTO user_passwords (id, hashed_pass) VALUES (?, ?)")
.bind(id)
.bind(hash)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn verify_user_password(&self, id: i32, password: String) -> anyhow::Result<bool> {
let hash =
sqlx::query_scalar::<_, String>("SELECT hashed_pass FROM user_passwords WHERE id = ?")
.bind(id)
.fetch_one(&self.pool)
.await?;
Ok(bcrypt::verify(password, &hash)?)
}
}
rs
use anyhow::Context;
use async_session::{Session, SessionStore};
use super::Repository;
impl Repository {
pub async fn create_user_session(&self, user_id: String) -> anyhow::Result<String> {
let mut session = Session::new();
session
.insert("user_id", user_id)
.with_context(|| "Failed to insert user_id")?;
let session_id = self
.session_store
.store_session(session)
.await
.with_context(|| "Failed to store session")?
.with_context(|| "Failed to create session")?;
Ok(session_id)
}
pub async fn delete_user_session(&self, session_id: String) -> anyhow::Result<()> {
let session = self
.session_store
.load_session(session_id.clone())
.await
.with_context(|| "Failed to load session")?
.with_context(|| "Failed to find session")?;
self.session_store
.destroy_session(session)
.await
.with_context(|| "Failed to destroy session")?;
Ok(())
}
pub async fn get_user_id_by_session_id(
&self,
session_id: &String,
) -> anyhow::Result<Option<String>> {
let session = self
.session_store
.load_session(session_id.clone())
.await
.with_context(|| "Failed to load session")?;
Ok(session.and_then(|s| s.get::<String>("user_id")))
}
}
検証
自分の実装が正しく動くか検証しましょう。
WARNING
全て Postman での検証です。cargo run
でサーバーを起動した状態で行ってください。
また、GET
とPOST
を間違えないようにして下さい。
localhost:8080/cities/Tokyoにアクセスすると、ログインしていないため401 Unauthorized
が返ってきます。そのため、情報を入手できません。
ユーザーを作成します。 上手く作成できれば Status 201 が返ってくるはずです。
(注意:POST
です)
そのままパスを変えてログインリクエストを送ります。
ログインに成功したら、レスポンスの方の Cookies を開いて value の中身をコピーします
リクエストの方の Headers で Cookie をセットします。
Key にCookie
を Value にsession_id={コピーした値};
をセットします(既に自動で入っている場合もあります、その場合は追加しなくて大丈夫です)。
もう一度 localhost:8080/cities/Tokyo にアクセスすると正常に API が取れるようになりました。
(注意:GET
です)
ここで、作成されたユーザーがデータベースに保存されていることを確認してみましょう。
bash
SELECT * FROM users;
SELECT * FROM user_passwords;
SELECT * FROM user_sessions;
ユーザー名とハッシュ化されたパスワードが確認できますね。
ちょっと分かりにくい表示ですが、セッションもしっかり確認できます。