From 113ab5a4ac3320cdc4329bb393bb817b83e7a409f556f4a523e56ad021c2e3ab Mon Sep 17 00:00:00 2001 From: raven <7156279+RavenX8@users.noreply.github.com> Date: Tue, 26 Nov 2024 01:58:26 -0500 Subject: [PATCH] - add: roles to user - add: register calls for auth server - add: user lookup by email - add: start of password reset - add: Cache trait to allow redis cache mocking --- auth-service/src/database_client.rs | 55 +++++++++- auth-service/src/grpc.rs | 103 +++++++++++++++++- .../src/mocks/database_client_mock.rs | 9 +- database-service/src/grpc.rs | 37 +++++-- database-service/src/redis_cache.rs | 29 +++-- database-service/src/users.rs | 60 ++++++++-- proto/auth.proto | 30 +++++ proto/database.proto | 6 + 8 files changed, 303 insertions(+), 26 deletions(-) diff --git a/auth-service/src/database_client.rs b/auth-service/src/database_client.rs index 20a4437..86cd768 100644 --- a/auth-service/src/database_client.rs +++ b/auth-service/src/database_client.rs @@ -1,20 +1,33 @@ use std::error::Error; use tonic::transport::Channel; -use crate::database::{database_service_client::DatabaseServiceClient, CreateUserRequest, CreateUserResponse, GetUserByUsernameRequest, GetUserRequest, GetUserResponse}; +use crate::database::{database_service_client::DatabaseServiceClient, CreateUserRequest, CreateUserResponse, GetUserByUsernameRequest, GetUserByEmailRequest, GetUserRequest, GetUserResponse}; use async_trait::async_trait; +use chrono::{DateTime, Utc}; #[async_trait] pub trait DatabaseClientTrait: Sized { async fn connect(endpoint: &str) -> Result>; async fn get_user_by_userid(&mut self, user_id: i32) -> Result>; async fn get_user_by_username(&mut self, user_id: &str) -> Result>; + async fn get_user_by_email(&mut self, email: &str) -> Result>; async fn create_user(&mut self, username: &str, email: &str, password: &str) -> Result>; + async fn store_password_reset(&mut self, email: &str, reset_token: &str, expires_at: DateTime) -> Result<(), Box>; + async fn get_password_reset(&self, reset_token: &str) -> Result, Box>; + async fn delete_password_reset(&self, reset_token: &str) -> Result<(), Box>; + async fn update_user_password(&self, email: &str, hashed_password: &str) -> Result<(), Box>; } #[derive(Clone)] pub struct DatabaseClient { client: DatabaseServiceClient, } +#[derive(Debug)] +pub struct PasswordReset { + pub email: String, + pub reset_token: String, + pub expires_at: DateTime, +} + #[async_trait] impl DatabaseClientTrait for DatabaseClient { async fn connect(endpoint: &str) -> Result> { @@ -44,6 +57,14 @@ impl DatabaseClientTrait for DatabaseClient { Ok(response.into_inner()) } + async fn get_user_by_email(&mut self, email: &str) -> Result> { + let request = tonic::Request::new(GetUserByEmailRequest { + email: email.to_string(), + }); + let response = self.client.get_user_by_email(request).await?; + Ok(response.into_inner()) + } + async fn create_user(&mut self, username: &str, email: &str, password: &str) -> Result> { let request = tonic::Request::new(CreateUserRequest { username: username.to_string(), @@ -53,4 +74,36 @@ impl DatabaseClientTrait for DatabaseClient { let response = self.client.create_user(request).await?; Ok(response.into_inner()) } + + async fn store_password_reset( + &mut self, + email: &str, + reset_token: &str, + expires_at: DateTime, + ) -> Result<(), Box> { + Ok(()) + } + + async fn get_password_reset( + &self, + reset_token: &str, + ) -> Result, Box> { + todo!() + } + + async fn delete_password_reset( + &self, + reset_token: &str, + ) -> Result<(), Box> { + Ok(()) + } + + async fn update_user_password( + &self, + email: &str, + hashed_password: &str, + ) -> Result<(), Box> { + Ok(()) + } + } diff --git a/auth-service/src/grpc.rs b/auth-service/src/grpc.rs index a6301b8..56aa1e7 100644 --- a/auth-service/src/grpc.rs +++ b/auth-service/src/grpc.rs @@ -1,9 +1,11 @@ +use chrono::{Duration, Utc}; +use rand::Rng; use tonic::{Request, Response, Status}; use crate::jwt::{generate_token, validate_token}; -use crate::users::verify_user; +use crate::users::{verify_user, hash_password}; use crate::database_client::{DatabaseClient, DatabaseClientTrait}; use crate::auth::auth_service_server::{AuthService}; -use crate::auth::{LoginRequest, LoginResponse, ValidateTokenRequest, ValidateTokenResponse}; +use crate::auth::{LoginRequest, LoginResponse, ValidateTokenRequest, ValidateTokenResponse, RegisterRequest, RegisterResponse, PasswordResetRequest, PasswordResetResponse, ResetPasswordRequest, ResetPasswordResponse}; use tracing::{info, warn}; pub struct MyAuthService { @@ -49,4 +51,101 @@ impl AuthService for MyA })), } } + + async fn register( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + + // Hash the password + let hashed_password = hash_password(&req.password); + + // Create user in the database + let user = self.db_client.clone().create_user(&req.username, &req.email, &hashed_password) + .await + .map_err(|e| Status::internal(format!("Database error: {}", e)))?; + + Ok(Response::new(RegisterResponse { user_id: user.user_id })) + } + + async fn request_password_reset( + &self, + request: Request, + ) -> Result, Status> { + let email = request.into_inner().email; + + let user = self.db_client.clone().get_user_by_email(&email).await; + + // Check if the email exists + if user.ok().is_some() { + // Generate a reset token + let reset_token: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(32) + .map(char::from) + .collect(); + + // Set token expiration (e.g., 1 hour) + let expires_at = Utc::now() + Duration::hours(1); + + // Store the reset token in the database + self.db_client.clone() + .store_password_reset(&email, &reset_token, expires_at) + .await + .map_err(|e| Status::internal(format!("Database error: {}", e)))?; + + // Send the reset email + // send_email(&email, "Password Reset Request", &format!( + // "Click the link to reset your password: https://example.com/reset?token={}", + // reset_token + // )) + // .map_err(|e| Status::internal(format!("Email error: {}", e)))?; + + Ok(Response::new(PasswordResetResponse { + message: "Password reset email sent".to_string(), + })) + } else { + // Respond with a generic message to avoid information leaks + Ok(Response::new(PasswordResetResponse { + message: "If the email exists, a reset link has been sent.".to_string(), + })) + } + } + + async fn reset_password( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + + // Validate the reset token + if let Some(password_reset) = self.db_client.clone().get_password_reset(&req.reset_token).await + .map_err(|e| Status::internal(format!("Database error: {}", e)))? { + if password_reset.expires_at < Utc::now() { + return Err(Status::unauthenticated("Token expired")); + } + + // Hash the new password + let hashed_password = hash_password(&req.new_password); + + // Update the user's password + self.db_client + .update_user_password(&password_reset.email, &hashed_password) + .await + .map_err(|e| Status::internal(format!("Database error: {}", e)))?; + + // Delete the reset token + self.db_client + .delete_password_reset(&req.reset_token) + .await + .map_err(|e| Status::internal(format!("Database error: {}", e)))?; + + Ok(Response::new(ResetPasswordResponse { + message: "Password successfully reset".to_string(), + })) + } else { + Err(Status::unauthenticated("Invalid reset token")) + } + } } diff --git a/auth-service/src/mocks/database_client_mock.rs b/auth-service/src/mocks/database_client_mock.rs index 29ac147..b53b84b 100644 --- a/auth-service/src/mocks/database_client_mock.rs +++ b/auth-service/src/mocks/database_client_mock.rs @@ -1,7 +1,9 @@ +use std::error::Error; use mockall::{mock, predicate::*}; use async_trait::async_trait; +use chrono::{DateTime, Utc}; use crate::database::{CreateUserResponse, GetUserResponse}; -use crate::database_client::{DatabaseClientTrait}; +use crate::database_client::{DatabaseClientTrait, PasswordReset}; #[cfg(test)] mock! { @@ -12,7 +14,12 @@ mock! { async fn connect(endpoint: &str) -> Result>; async fn get_user_by_userid(&mut self, user_id: i32) -> Result>; async fn get_user_by_username(&mut self, user_id: &str) -> Result>; + async fn get_user_by_email(&mut self, email: &str) -> Result>; async fn create_user(&mut self, username: &str, email: &str, password: &str) -> Result>; + async fn store_password_reset(&mut self, email: &str, reset_token: &str, expires_at: DateTime) -> Result<(), Box>; + async fn get_password_reset(&self, reset_token: &str) -> Result, Box>; + async fn delete_password_reset(&self, reset_token: &str) -> Result<(), Box>; + async fn update_user_password(&self, email: &str, hashed_password: &str) -> Result<(), Box>; } } diff --git a/database-service/src/grpc.rs b/database-service/src/grpc.rs index 0675036..b91c619 100644 --- a/database-service/src/grpc.rs +++ b/database-service/src/grpc.rs @@ -1,5 +1,5 @@ use crate::db::Database; -use crate::database::{CreateUserRequest, CreateUserResponse, GetUserRequest, GetUserByUsernameRequest, GetUserResponse}; +use crate::database::{CreateUserRequest, CreateUserResponse, GetUserRequest, GetUserByUsernameRequest, GetUserByEmailRequest, GetUserResponse}; use tonic::{Request, Response, Status}; use crate::database::database_service_server::{DatabaseService}; @@ -26,9 +26,24 @@ impl DatabaseService for MyDatabaseService { username: user.username, email: user.email, hashed_password: user.hashed_password, + roles: user.roles, })) } + async fn create_user( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + + let user_id = self.db.users_service.create_user(&req.username, &req.email, &req.hashed_password) + .await + .map_err(|_| Status::internal("Failed to create user"))?; + + // Return the newly created user ID + Ok(Response::new(CreateUserResponse { user_id })) + } + async fn get_user_by_username( &self, request: Request, @@ -44,20 +59,26 @@ impl DatabaseService for MyDatabaseService { username: user.username, email: user.email, hashed_password: user.hashed_password, + roles: user.roles, })) } - async fn create_user( + async fn get_user_by_email( &self, - request: Request, - ) -> Result, Status> { + request: Request, + ) -> Result, Status> { let req = request.into_inner(); - let user_id = self.db.users_service.create_user(&req.username, &req.email, &req.hashed_password) + let user = self.db.users_service.get_user_by_email(&req.email) .await - .map_err(|_| Status::internal("Failed to create user"))?; + .map_err(|_| Status::not_found("User not found"))?; - // Return the newly created user ID - Ok(Response::new(CreateUserResponse { user_id: user_id })) + Ok(Response::new(GetUserResponse { + user_id: user.id, + username: user.username, + email: user.email, + hashed_password: user.hashed_password, + roles: user.roles, + })) } } \ No newline at end of file diff --git a/database-service/src/redis_cache.rs b/database-service/src/redis_cache.rs index 828f8d2..6489776 100644 --- a/database-service/src/redis_cache.rs +++ b/database-service/src/redis_cache.rs @@ -1,24 +1,39 @@ use deadpool_redis::{Config, Pool, Runtime}; use redis::{AsyncCommands, RedisError}; -use serde::{de::DeserializeOwned, Serialize}; +use serde::{Deserialize, Serialize}; +use async_trait::async_trait; -#[async_trait::async_trait] +#[async_trait] pub trait Cache { - fn new(redis_url: &str) -> Self; + async fn set( + &self, + key: &String, + value: &T, + ttl: u64, + ) -> Result<(), redis::RedisError>; + + async fn get serde::Deserialize<'de> + Send + Sync>( + &self, + key: &String, + ) -> Result, redis::RedisError>; + } pub struct RedisCache { pub pool: Pool, } -impl Cache for RedisCache { - fn new(redis_url: &str) -> Self { +impl RedisCache { + pub fn new(redis_url: &str) -> Self { let cfg = Config::from_url(redis_url); let pool = cfg.create_pool(Some(Runtime::Tokio1)).expect("Failed to create Redis pool"); RedisCache { pool } } +} - async fn set( +#[async_trait] +impl Cache for RedisCache { + async fn set( &self, key: &String, value: &T, @@ -41,7 +56,7 @@ impl Cache for RedisCache { conn.set_ex(key, serialized_value, ttl).await } - async fn get(&self, key: &String) -> Result, redis::RedisError> { + async fn get Deserialize<'de> + Send + Sync>(&self, key: &String) -> Result, redis::RedisError> { let mut conn = self.pool.get().await .map_err(|err| { redis::RedisError::from(( diff --git a/database-service/src/users.rs b/database-service/src/users.rs index 0e910ef..600556f 100644 --- a/database-service/src/users.rs +++ b/database-service/src/users.rs @@ -2,7 +2,7 @@ use sqlx::Error; use sqlx::PgPool; use serde::{Serialize, Deserialize}; use std::sync::Arc; -use crate::redis_cache::RedisCache; +use crate::redis_cache::{Cache, RedisCache}; use tracing::{debug, error}; @@ -12,6 +12,7 @@ pub struct User { pub username: String, pub email: String, pub hashed_password: String, + pub roles: Vec, } pub struct UsersService { @@ -50,13 +51,20 @@ impl UsersService { } // Fetch from PostgreSQL - let user = sqlx::query_as!( - User, - "SELECT id, username, email, hashed_password FROM users WHERE id = $1", + let row = sqlx::query!( + "SELECT id, username, email, hashed_password, roles FROM users WHERE id = $1", user_id ) .fetch_one(&self.pool) .await?; + + let user = User { + id: row.id, + username: row.username, + email: row.email, + hashed_password: row.hashed_password, + roles: row.roles.unwrap_or_default(), + }; // Store result in Redis self.cache @@ -74,14 +82,21 @@ impl UsersService { } // Fetch from PostgreSQL - let user = sqlx::query_as!( - User, - "SELECT id, username, email, hashed_password FROM users WHERE username = $1", + let row = sqlx::query!( + "SELECT id, username, email, hashed_password, roles FROM users WHERE username = $1", username ) .fetch_one(&self.pool) .await?; + let user = User { + id: row.id, + username: row.username, + email: row.email, + hashed_password: row.hashed_password, + roles: row.roles.unwrap_or_default(), + }; + // Store result in Redis self.cache .set(&format!("user_by_username:{}", username), &user, 3600) @@ -91,6 +106,37 @@ impl UsersService { Ok(user) } + pub async fn get_user_by_email(&self, email: &str) -> Result { + // Check Redis cache first + if let Ok(Some(cached_user)) = self.cache.get::(&format!("user_by_email:{}", email)).await { + return Ok(cached_user); + } + + // Fetch from PostgreSQL + let row = sqlx::query!( + "SELECT id, username, email, hashed_password, roles FROM users WHERE email = $1", + email + ) + .fetch_one(&self.pool) + .await?; + + let user = User { + id: row.id, + username: row.username, + email: row.email, + hashed_password: row.hashed_password, + roles: row.roles.unwrap_or_default(), + }; + + // Store result in Redis + self.cache + .set(&format!("user_by_email:{}", email), &user, 3600) + .await + .unwrap_or_else(|err| eprintln!("Failed to cache user: {:?}", err)); + + Ok(user) + } + pub async fn update_user_email(&self, user_id: i32, new_email: &str) -> Result<(), Error> { sqlx::query!( r#" diff --git a/proto/auth.proto b/proto/auth.proto index 4e7126c..70ed14c 100644 --- a/proto/auth.proto +++ b/proto/auth.proto @@ -5,6 +5,9 @@ package auth; service AuthService { rpc Login(LoginRequest) returns (LoginResponse); rpc ValidateToken(ValidateTokenRequest) returns (ValidateTokenResponse); + rpc Register (RegisterRequest) returns (RegisterResponse); + rpc RequestPasswordReset (PasswordResetRequest) returns (PasswordResetResponse); + rpc ResetPassword (ResetPasswordRequest) returns (ResetPasswordResponse); } message LoginRequest { @@ -25,3 +28,30 @@ message ValidateTokenResponse { bool valid = 1; string user_id = 2; } + +message RegisterRequest { + string username = 1; + string email = 2; + string password = 3; +} + +message RegisterResponse { + int32 user_id = 1; +} + +message PasswordResetRequest { + string email = 1; +} + +message PasswordResetResponse { + string message = 1; +} + +message ResetPasswordRequest { + string reset_token = 1; + string new_password = 2; +} + +message ResetPasswordResponse { + string message = 1; +} diff --git a/proto/database.proto b/proto/database.proto index 0a43f55..081b5b2 100644 --- a/proto/database.proto +++ b/proto/database.proto @@ -6,6 +6,7 @@ service DatabaseService { rpc GetUser(GetUserRequest) returns (GetUserResponse); rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); rpc GetUserByUsername(GetUserByUsernameRequest) returns (GetUserResponse); + rpc GetUserByEmail(GetUserByEmailRequest) returns (GetUserResponse); } message GetUserRequest { @@ -16,11 +17,16 @@ message GetUserByUsernameRequest { string username = 1; } +message GetUserByEmailRequest { + string email = 1; +} + message GetUserResponse { int32 user_id = 1; string username = 2; string email = 3; string hashed_password = 4; + repeated string roles = 5; } message CreateUserRequest {