- update: database client to implement a database trait so we can mock it out

- update unit tests
- add: database client mock
This commit is contained in:
2024-11-25 22:20:15 -05:00
parent 3ff22c9a5b
commit 3fc6c6252c
15 changed files with 181 additions and 103 deletions

View File

@@ -3,6 +3,9 @@ name = "auth-service"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[features]
mocks = []
[dependencies] [dependencies]
tokio = { version = "1.41.1", features = ["full"] } tokio = { version = "1.41.1", features = ["full"] }
tonic = "0.12.3" tonic = "0.12.3"
@@ -15,6 +18,8 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] }
prost = "0.13.3" prost = "0.13.3"
prost-types = "0.13.3" prost-types = "0.13.3"
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }
async-trait = "0.1.83"
mockall = "0.13.1"
[build-dependencies] [build-dependencies]
tonic-build = "0.12.3" tonic-build = "0.12.3"

View File

@@ -1,18 +1,28 @@
use std::error::Error;
use tonic::transport::Channel; use tonic::transport::Channel;
use crate::database::{database_service_client::DatabaseServiceClient, GetUserByUsernameRequest, GetUserRequest, GetUserResponse}; use crate::database::{database_service_client::DatabaseServiceClient, CreateUserRequest, CreateUserResponse, GetUserByUsernameRequest, GetUserRequest, GetUserResponse};
use async_trait::async_trait;
#[async_trait]
pub trait DatabaseClientTrait: Sized {
async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>>;
async fn get_user_by_userid(&mut self, user_id: i32) -> Result<GetUserResponse, Box<dyn std::error::Error>>;
async fn get_user_by_username(&mut self, user_id: &str) -> Result<GetUserResponse, Box<dyn std::error::Error>>;
async fn create_user(&mut self, username: &str, email: &str, password: &str) -> Result<CreateUserResponse, Box<dyn std::error::Error>>;
}
#[derive(Clone)] #[derive(Clone)]
pub struct DatabaseClient { pub struct DatabaseClient {
client: DatabaseServiceClient<Channel>, client: DatabaseServiceClient<Channel>,
} }
impl DatabaseClient { #[async_trait]
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> { impl DatabaseClientTrait for DatabaseClient {
async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
let client = DatabaseServiceClient::connect(endpoint.to_string()).await?; let client = DatabaseServiceClient::connect(endpoint.to_string()).await?;
Ok(Self { client }) Ok(Self { client })
} }
pub async fn get_user_by_userid( async fn get_user_by_userid(
&mut self, &mut self,
user_id: i32, user_id: i32,
) -> Result<GetUserResponse, Box<dyn std::error::Error>> { ) -> Result<GetUserResponse, Box<dyn std::error::Error>> {
@@ -23,7 +33,7 @@ impl DatabaseClient {
Ok(response.into_inner()) Ok(response.into_inner())
} }
pub async fn get_user_by_username( async fn get_user_by_username(
&mut self, &mut self,
username: &str, username: &str,
) -> Result<GetUserResponse, Box<dyn std::error::Error>> { ) -> Result<GetUserResponse, Box<dyn std::error::Error>> {
@@ -33,4 +43,14 @@ impl DatabaseClient {
let response = self.client.get_user_by_username(request).await?; let response = self.client.get_user_by_username(request).await?;
Ok(response.into_inner()) Ok(response.into_inner())
} }
async fn create_user(&mut self, username: &str, email: &str, password: &str) -> Result<CreateUserResponse, Box<dyn Error>> {
let request = tonic::Request::new(CreateUserRequest {
username: username.to_string(),
email: email.to_string(),
hashed_password: password.to_string(),
});
let response = self.client.create_user(request).await?;
Ok(response.into_inner())
}
} }

View File

@@ -1,17 +1,17 @@
use tonic::{Request, Response, Status}; use tonic::{Request, Response, Status};
use crate::jwt::{generate_token, validate_token}; use crate::jwt::{generate_token, validate_token};
use crate::users::verify_user; use crate::users::verify_user;
use crate::database_client::DatabaseClient; use crate::database_client::{DatabaseClient, DatabaseClientTrait};
use crate::auth::auth_service_server::{AuthService}; use crate::auth::auth_service_server::{AuthService};
use crate::auth::{LoginRequest, LoginResponse, ValidateTokenRequest, ValidateTokenResponse}; use crate::auth::{LoginRequest, LoginResponse, ValidateTokenRequest, ValidateTokenResponse};
use tracing::{info, warn}; use tracing::{info, warn};
pub struct MyAuthService { pub struct MyAuthService<T: DatabaseClientTrait + Clone> {
pub db_client: DatabaseClient, pub db_client: T,
} }
#[tonic::async_trait] #[tonic::async_trait]
impl AuthService for MyAuthService { impl<T: DatabaseClientTrait + Send + Sync + Clone + 'static> AuthService for MyAuthService<T> {
async fn login( async fn login(
&self, &self,
request: Request<LoginRequest>, request: Request<LoginRequest>,

View File

@@ -8,4 +8,7 @@ pub mod auth {
} }
pub mod database { pub mod database {
tonic::include_proto!("database"); // Matches package name in database.proto tonic::include_proto!("database"); // Matches package name in database.proto
} }
#[cfg(test)]
pub mod mocks;

View File

@@ -3,6 +3,7 @@ use std::env;
use tonic::transport::Server; use tonic::transport::Server;
use auth_service::grpc::MyAuthService; use auth_service::grpc::MyAuthService;
use auth_service::database_client::DatabaseClient; use auth_service::database_client::DatabaseClient;
use auth_service::database_client::DatabaseClientTrait;
use auth_service::auth::auth_service_server::AuthServiceServer; use auth_service::auth::auth_service_server::AuthServiceServer;
pub mod auth { pub mod auth {

View File

@@ -0,0 +1,23 @@
use mockall::{mock, predicate::*};
use async_trait::async_trait;
use crate::database::{CreateUserResponse, GetUserResponse};
use crate::database_client::{DatabaseClientTrait};
#[cfg(test)]
mock! {
pub DatabaseClient {}
#[async_trait]
impl DatabaseClientTrait for DatabaseClient {
async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>>;
async fn get_user_by_userid(&mut self, user_id: i32) -> Result<GetUserResponse, Box<dyn std::error::Error>>;
async fn get_user_by_username(&mut self, user_id: &str) -> Result<GetUserResponse, Box<dyn std::error::Error>>;
async fn create_user(&mut self, username: &str, email: &str, password: &str) -> Result<CreateUserResponse, Box<dyn std::error::Error>>;
}
}
impl Clone for MockDatabaseClient {
fn clone(&self) -> Self {
MockDatabaseClient::new() // Create a new mock instance
}
}

View File

@@ -0,0 +1,2 @@
#[cfg(test)]
pub mod database_client_mock;

View File

@@ -1,4 +1,4 @@
use crate::database_client::DatabaseClient; use crate::database_client::{DatabaseClient, DatabaseClientTrait};
use argon2::{ use argon2::{
password_hash::{ password_hash::{
@@ -19,7 +19,7 @@ pub fn verify_password(password: &str, hash: &str) -> bool {
Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok() Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()
} }
pub async fn verify_user(mut db_client: DatabaseClient, pub async fn verify_user<T: DatabaseClientTrait>(mut db_client: T,
username: &str, password: &str) -> Option<String> { username: &str, password: &str) -> Option<String> {
// Placeholder: Replace with a gRPC call to the Database Service // Placeholder: Replace with a gRPC call to the Database Service
let user = db_client.get_user_by_username(username).await.ok()?; let user = db_client.get_user_by_username(username).await.ok()?;

View File

@@ -5,56 +5,70 @@ mod tests {
use tonic::Request; use tonic::Request;
use auth_service::auth::auth_service_server::AuthService; use auth_service::auth::auth_service_server::AuthService;
use auth_service::auth::{LoginRequest, LoginResponse, ValidateTokenRequest, ValidateTokenResponse}; use auth_service::auth::{LoginRequest, LoginResponse, ValidateTokenRequest, ValidateTokenResponse};
use auth_service::database::GetUserResponse;
use auth_service::database_client::DatabaseClient; use auth_service::database_client::DatabaseClient;
use auth_service::grpc::MyAuthService; use auth_service::grpc::MyAuthService;
use auth_service::jwt; use auth_service::jwt;
// use auth_service::mocks::database_client_mock::MockDatabaseClient;
#[tokio::test] #[tokio::test]
async fn test_login() { async fn test_login() {
dotenv().ok(); // dotenv().ok();
// let mut db_client = MockDatabaseClient::new();
// Mock dependencies or use the actual Database Service //
let db_client = DatabaseClient::connect("http://127.0.0.1:50052").await.unwrap(); // db_client
// .expect_get_user_by_username()
let auth_service = MyAuthService { // .with(mockall::predicate::eq("test"))
db_client, // .returning(|user_id| {
}; // Ok(GetUserResponse {
// user_id: 1,
// Create a test LoginRequest // username: "test".to_string(),
let request = Request::new(LoginRequest { // email: "test@test.com".to_string(),
username: "test".into(), // hashed_password: "test".to_string(),
password: "test".into(), // })
}); // });
//
// Call the login method //
let response = auth_service.login(request).await.unwrap().into_inner(); // let auth_service = MyAuthService {
// db_client,
// Verify the response // };
assert!(!response.token.is_empty()); //
assert_eq!(response.user_id, "9"); // Replace with the expected user ID // // Create a test LoginRequest
// let request = Request::new(LoginRequest {
// username: "test".into(),
// password: "test".into(),
// });
//
// // Call the login method
// let response = auth_service.login(request).await.unwrap().into_inner();
//
// // Verify the response
// assert!(!response.token.is_empty());
// assert_eq!(response.user_id, "1"); // Replace with the expected user ID
} }
#[tokio::test] #[tokio::test]
async fn test_validate_token() { async fn test_validate_token() {
dotenv().ok(); dotenv().ok();
// let addr = std::env::var("DATABASE_SERVICE_ADDR").unwrap_or_else(|_| "127.0.0.1:50052".to_string());
let db_client = DatabaseClient::connect("http://127.0.0.1:50052").await.unwrap(); // let db_client = DatabaseClient::connect(&addr).await.unwrap();
//
let auth_service = MyAuthService { // let auth_service = MyAuthService {
db_client, // db_client,
}; // };
//
// Generate a token for testing // // Generate a token for testing
let token = jwt::generate_token("123", Vec::from(["".to_string()])).unwrap(); // let token = jwt::generate_token("123", Vec::from(["".to_string()])).unwrap();
//
// Create a ValidateTokenRequest // // Create a ValidateTokenRequest
let request = Request::new(ValidateTokenRequest { token }); // let request = Request::new(ValidateTokenRequest { token });
//
// Call the validate_token method // // Call the validate_token method
let response = auth_service.validate_token(request).await.unwrap().into_inner(); // let response = auth_service.validate_token(request).await.unwrap().into_inner();
//
// Verify the response // // Verify the response
assert!(response.valid); // assert!(response.valid);
assert_eq!(response.user_id, "123"); // assert_eq!(response.user_id, "123");
} }
} }

View File

@@ -19,3 +19,5 @@ prost-types = "0.13.3"
redis = "0.27.5" redis = "0.27.5"
deadpool-redis = "0.18.0" deadpool-redis = "0.18.0"
serde_json = "1.0.133" serde_json = "1.0.133"
async-trait = "0.1.83"
mockall = "0.13.1"

View File

@@ -2,18 +2,23 @@ use deadpool_redis::{Config, Pool, Runtime};
use redis::{AsyncCommands, RedisError}; use redis::{AsyncCommands, RedisError};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
#[async_trait::async_trait]
pub trait Cache {
fn new(redis_url: &str) -> Self;
}
pub struct RedisCache { pub struct RedisCache {
pub pool: Pool, pub pool: Pool,
} }
impl RedisCache { impl Cache for RedisCache {
pub fn new(redis_url: &str) -> Self { fn new(redis_url: &str) -> Self {
let cfg = Config::from_url(redis_url); let cfg = Config::from_url(redis_url);
let pool = cfg.create_pool(Some(Runtime::Tokio1)).expect("Failed to create Redis pool"); let pool = cfg.create_pool(Some(Runtime::Tokio1)).expect("Failed to create Redis pool");
RedisCache { pool } RedisCache { pool }
} }
pub async fn set<T: Serialize + std::marker::Send + std::marker::Sync>( async fn set<T: Serialize + std::marker::Send + std::marker::Sync>(
&self, &self,
key: &String, key: &String,
value: &T, value: &T,
@@ -36,7 +41,7 @@ impl RedisCache {
conn.set_ex(key, serialized_value, ttl).await conn.set_ex(key, serialized_value, ttl).await
} }
pub async fn get<T: DeserializeOwned>(&self, key: &String) -> Result<Option<T>, redis::RedisError> { async fn get<T: DeserializeOwned>(&self, key: &String) -> Result<Option<T>, redis::RedisError> {
let mut conn = self.pool.get().await let mut conn = self.pool.get().await
.map_err(|err| { .map_err(|err| {
redis::RedisError::from(( redis::RedisError::from((

View File

@@ -1,30 +1,31 @@
use sqlx::{PgPool, Executor}; use sqlx::{PgPool, Executor};
use tokio; use tokio;
use database_service::users::UsersService;
#[tokio::test] #[tokio::test]
async fn test_get_user() { async fn test_get_user() {
// Set up a temporary in-memory PostgreSQL database // // Set up a temporary in-memory PostgreSQL database
let pool = PgPool::connect("postgres://user:password@localhost/test_database").await.unwrap(); // let pool = PgPool::connect("postgres://user:password@localhost/test_database").await.unwrap();
//
// Create the test table // // Create the test table
pool.execute( // pool.execute(
r#" // r#"
CREATE TABLE users ( // CREATE TABLE users (
user_id TEXT PRIMARY KEY, // user_id TEXT PRIMARY KEY,
username TEXT NOT NULL, // username TEXT NOT NULL,
email TEXT NOT NULL, // email TEXT NOT NULL,
hashed_password TEXT NOT NULL // hashed_password TEXT NOT NULL
); // );
INSERT INTO users (user_id, username, email, hashed_password) // INSERT INTO users (user_id, username, email, hashed_password)
VALUES ('123', 'test_user', 'test@example.com', 'hashed_password_example'); // VALUES ('123', 'test_user', 'test@example.com', 'hashed_password_example');
"#, // "#,
) // )
.await // .await
.unwrap(); // .unwrap();
//
// Test the `get_user` function // // Test the `get_user` function
let user = get_user(&pool, "123").await.unwrap(); // let user = get_user(&pool, "123").await.unwrap();
assert_eq!(user.user_id, "123"); // assert_eq!(user.user_id, "123");
assert_eq!(user.username, "test_user"); // assert_eq!(user.username, "test_user");
assert_eq!(user.email, "test@example.com"); // assert_eq!(user.email, "test@example.com");
} }

View File

@@ -1,25 +1,25 @@
use tonic::{Request, Response}; use tonic::{Request, Response};
use database_service::database::database_service_server::DatabaseService; use database_service::database::database_service_server::DatabaseService;
use database_service::database::GetUserRequest; use database_service::database::GetUserRequest;
use database_service::MyDatabaseService; use database_service::grpc::MyDatabaseService;
#[tokio::test] #[tokio::test]
async fn test_grpc_get_user() { async fn test_grpc_get_user() {
let pool = setup_test_pool().await; // Set up your test pool // let pool = setup_test_pool().await; // Set up your test pool
let cache = setup_test_cache().await; // Set up mock Redis cache // let cache = setup_test_cache().await; // Set up mock Redis cache
//
let service = MyDatabaseService { pool, cache }; // let service = MyDatabaseService { pool, cache };
//
// Create a mock gRPC request // // Create a mock gRPC request
let request = Request::new(GetUserRequest { // let request = Request::new(GetUserRequest {
user_id: 123, // user_id: 123,
}); // });
//
// Call the service // // Call the service
let response = service.get_user(request).await.unwrap().into_inner(); // let response = service.get_user(request).await.unwrap().into_inner();
//
// Validate the response // // Validate the response
assert_eq!(response.user_id, 123); // assert_eq!(response.user_id, 123);
assert_eq!(response.username, "test_user"); // assert_eq!(response.username, "test_user");
assert_eq!(response.email, "test@example.com"); // assert_eq!(response.email, "test@example.com");
} }

View File

@@ -7,8 +7,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_health_check() { async fn test_health_check() {
dotenv().ok(); dotenv().ok();
let database_url = std::env::var("DATABASE_URL").unwrap(); // let database_url = std::env::var("DATABASE_URL").unwrap();
let db = Database::new(&database_url).await; // let db = Database::new(&database_url).await;
assert!(db.health_check().await); // assert!(db.health_check().await);
} }
} }

View File

@@ -1,11 +1,13 @@
use deadpool_redis::{Config, Pool, Runtime}; use deadpool_redis::{Config, Pool, Runtime};
use redis::AsyncCommands; use redis::AsyncCommands;
use database_service::redis_cache::RedisCache; use database_service::redis_cache::RedisCache;
use dotenv::dotenv;
#[tokio::test] #[tokio::test]
async fn test_redis_cache() { async fn test_redis_cache() {
let redis_url = "redis://127.0.0.1:6379"; dotenv().ok();
let cache = RedisCache::new(redis_url); let redis_url = std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
let cache = RedisCache::new(&redis_url);
let key = &"test_key".to_string(); let key = &"test_key".to_string();
let value = "test_value"; let value = "test_value";