Compare commits

..

4 Commits

Author SHA256 Message Date
0777bd4605 - fix: user_id and session id were not being saved in the state correctly
- add: server and channel id to connection state
2024-12-20 17:48:02 -05:00
9d9e2bef05 - add: session_id to the validate token response
- add: session_id to the jwt generated token
2024-12-20 17:46:54 -05:00
e0114fd832 - add: get_connection_mut function to allow modifying the connection state 2024-12-20 17:44:22 -05:00
e3fb186a44 - fix: when shutting down a docker container, the services would not deregister from consul correctly 2024-12-20 17:42:50 -05:00
14 changed files with 87 additions and 29 deletions

View File

@@ -67,8 +67,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
info!("Starting REST API on {}:{}", addr, port); info!("Starting REST API on {}:{}", addr, port);
tokio::spawn(axum_gateway::serve_rest_api(grpc_client)); tokio::spawn(axum_gateway::serve_rest_api(grpc_client));
let mut sigterm_stream = signal::unix::signal(signal::unix::SignalKind::terminate())?;
select! { select! {
_ = signal::ctrl_c() => {}, _ = signal::ctrl_c() => {
info!("Received SIGINT (Ctrl+C), shutting down...");
},
_ = sigterm_stream.recv() => {
info!("Received SIGTERM, shutting down...");
},
} }
consul_registration::deregister_service(&consul_url, service_id.as_str()) consul_registration::deregister_service(&consul_url, service_id.as_str())

View File

@@ -26,16 +26,14 @@ impl AuthService for MyAuthService {
info!("Login attempt for username: {}", req.username); info!("Login attempt for username: {}", req.username);
if let Some(user_id) = verify_user(self.db_client.as_ref().clone(), &req.username, &req.password).await { if let Some(user) = verify_user(self.db_client.as_ref().clone(), &req.username, &req.password).await {
let token = generate_token(&user_id, vec!["user".to_string()]) let user_id = user.user_id.to_string();
.map_err(|_| Status::internal("Token generation failed"))?;
let session_id = uuid::Uuid::new_v4().to_string(); let session_id = uuid::Uuid::new_v4().to_string();
let response = self let response = self
.session_client.as_ref().clone() .session_client.as_ref().clone()
.create_session(CreateSessionRequest { .create_session(CreateSessionRequest {
session_id: session_id.clone(), session_id: session_id.clone(),
user_id: user_id.parse().unwrap(), user_id: user.user_id,
username: req.username.to_string(), username: req.username.to_string(),
character_id: 0, character_id: 0,
ip_address: req.ip_address.to_string(), ip_address: req.ip_address.to_string(),
@@ -48,6 +46,9 @@ impl AuthService for MyAuthService {
}; };
let session_id = session.into_inner().session_id; let session_id = session.into_inner().session_id;
let token = generate_token(&user_id, &&session_id.clone(), user.roles)
.map_err(|_| Status::internal("Token generation failed"))?;
info!("Login successful for username: {}", req.username); info!("Login successful for username: {}", req.username);
Ok(Response::new(LoginResponse { token, user_id, session_id })) Ok(Response::new(LoginResponse { token, user_id, session_id }))
} else { } else {
@@ -77,13 +78,15 @@ impl AuthService for MyAuthService {
let req = request.into_inner(); let req = request.into_inner();
match validate_token(&req.token) { match validate_token(&req.token) {
Ok(user_id) => Ok(Response::new(ValidateTokenResponse { Ok(user_data) => Ok(Response::new(ValidateTokenResponse {
valid: true, valid: true,
user_id, user_id: user_data.0,
session_id: user_data.1,
})), })),
Err(_) => Ok(Response::new(ValidateTokenResponse { Err(_) => Ok(Response::new(ValidateTokenResponse {
valid: false, valid: false,
user_id: "".to_string(), user_id: "".to_string(),
session_id: "".to_string(),
})), })),
} }
} }

View File

@@ -5,11 +5,12 @@ use std::env;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct Claims { struct Claims {
sub: String, // Subject (user ID) sub: String, // Subject (user ID)
session_id: String, // Session ID
roles: Vec<String>, // Roles/permissions roles: Vec<String>, // Roles/permissions
exp: usize, // Expiration time exp: usize, // Expiration time
} }
pub fn generate_token(user_id: &str, roles: Vec<String>) -> Result<String, jsonwebtoken::errors::Error> { pub fn generate_token(user_id: &str, session_id: &str, roles: Vec<String>) -> Result<String, jsonwebtoken::errors::Error> {
let secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set"); let secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
let expiration = chrono::Utc::now() let expiration = chrono::Utc::now()
.checked_add_signed(chrono::Duration::days(1)) .checked_add_signed(chrono::Duration::days(1))
@@ -18,6 +19,7 @@ pub fn generate_token(user_id: &str, roles: Vec<String>) -> Result<String, jsonw
let claims = Claims { let claims = Claims {
sub: user_id.to_owned(), sub: user_id.to_owned(),
session_id: session_id.to_owned(),
roles, roles,
exp: expiration, exp: expiration,
}; };
@@ -25,12 +27,12 @@ pub fn generate_token(user_id: &str, roles: Vec<String>) -> Result<String, jsonw
encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_ref())) encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_ref()))
} }
pub fn validate_token(token: &str) -> Result<String, jsonwebtoken::errors::Error> { pub fn validate_token(token: &str) -> Result<(String, String), jsonwebtoken::errors::Error> {
let secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set"); let secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
let token_data = decode::<Claims>( let token_data = decode::<Claims>(
token, token,
&DecodingKey::from_secret(secret.as_ref()), &DecodingKey::from_secret(secret.as_ref()),
&Validation::default(), &Validation::default(),
)?; )?;
Ok(token_data.claims.sub) Ok((token_data.claims.sub, token_data.claims.session_id))
} }

View File

@@ -77,9 +77,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tokio::spawn(Server::builder() tokio::spawn(Server::builder()
.add_service(AuthServiceServer::new(auth_service)) .add_service(AuthServiceServer::new(auth_service))
.serve(address)); .serve(address));
let mut sigterm_stream = signal::unix::signal(signal::unix::SignalKind::terminate())?;
select! { select! {
_ = signal::ctrl_c() => {}, _ = signal::ctrl_c() => {
info!("Received SIGINT (Ctrl+C), shutting down...");
},
_ = sigterm_stream.recv() => {
info!("Received SIGTERM, shutting down...");
},
} }
consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect(""); consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect("");

View File

@@ -1,4 +1,5 @@
use crate::database_client::DatabaseClientTrait; use crate::database_client::DatabaseClientTrait;
use crate::database::GetUserResponse;
use argon2::{ use argon2::{
password_hash::{ password_hash::{
@@ -20,11 +21,11 @@ pub fn verify_password(password: &str, hash: &str) -> bool {
} }
pub async fn verify_user<T: DatabaseClientTrait>(mut db_client: T, pub async fn verify_user<T: DatabaseClientTrait>(mut db_client: T,
username: &str, password: &str) -> Option<String> { username: &str, password: &str) -> Option<GetUserResponse> {
let user = db_client.get_user_by_username(username).await.ok()?; let user = db_client.get_user_by_username(username).await.ok()?;
if verify_password(password, &user.hashed_password) { if verify_password(password, &user.hashed_password) {
Some(user.user_id.to_string()) Some(user)
} else { } else {
None None
} }

View File

@@ -3,7 +3,7 @@ use std::collections::HashMap;
use std::env; use std::env;
use std::str::FromStr; use std::str::FromStr;
use tokio::{select, signal}; use tokio::{select, signal};
use tracing::Level; use tracing::{info, Level};
use utils::consul_registration; use utils::consul_registration;
use utils::service_discovery::get_service_address; use utils::service_discovery::get_service_address;
@@ -50,8 +50,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db_address = db_nodes.get(0).unwrap(); let db_address = db_nodes.get(0).unwrap();
let db_url = format!("http://{}:{}", db_address.ServiceAddress, db_address.ServicePort); let db_url = format!("http://{}:{}", db_address.ServiceAddress, db_address.ServicePort);
let mut sigterm_stream = signal::unix::signal(signal::unix::SignalKind::terminate())?;
select! { select! {
_ = signal::ctrl_c() => {}, _ = signal::ctrl_c() => {
info!("Received SIGINT (Ctrl+C), shutting down...");
},
_ = sigterm_stream.recv() => {
info!("Received SIGTERM, shutting down...");
},
} }
consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect(""); consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect("");

View File

@@ -73,8 +73,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.add_service(CharacterServiceServer::new(my_service)) .add_service(CharacterServiceServer::new(my_service))
.serve(address)); .serve(address));
let mut sigterm_stream = signal::unix::signal(signal::unix::SignalKind::terminate())?;
select! { select! {
_ = signal::ctrl_c() => {}, _ = signal::ctrl_c() => {
info!("Received SIGINT (Ctrl+C), shutting down...");
},
_ = sigterm_stream.recv() => {
info!("Received SIGTERM, shutting down...");
},
} }
consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect(""); consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect("");

View File

@@ -24,6 +24,10 @@ impl ConnectionService {
self.connections.get(connection_id).map(|entry| entry.clone()) self.connections.get(connection_id).map(|entry| entry.clone())
} }
pub fn get_connection_mut(&self, connection_id: &str) -> Option<dashmap::mapref::one::RefMut<'_, String, ConnectionState>> {
self.connections.get_mut(connection_id)
}
pub fn remove_connection(&self, connection_id: &str) { pub fn remove_connection(&self, connection_id: &str) {
self.connections.remove(connection_id); self.connections.remove(connection_id);
} }

View File

@@ -68,9 +68,9 @@ pub(crate) async fn handle_login_req(stream: &mut TcpStream, packet: Packet, aut
} else { } else {
debug!("Successfully logged in"); debug!("Successfully logged in");
if let Some(mut state) = connection_service.get_connection(&connection_id) { if let Some(mut state) = connection_service.get_connection_mut(&connection_id) {
state.user_id = Some(response.user_id.parse().unwrap()); state.user_id = Some(response.user_id.parse().unwrap());
// auth_client.logout(&session_id).await?; state.session_id = Some(response.session_id);
} }
let consul_url = env::var("CONSUL_URL").unwrap_or_else(|_| "http://127.0.0.1:8500".to_string()); let consul_url = env::var("CONSUL_URL").unwrap_or_else(|_| "http://127.0.0.1:8500".to_string());
@@ -136,10 +136,15 @@ pub(crate) async fn handle_login_req(stream: &mut TcpStream, packet: Packet, aut
Ok(()) Ok(())
} }
pub(crate) async fn handle_server_select_req(stream: &mut TcpStream, packet: Packet) -> Result<(), Box<dyn Error + Send + Sync>> { pub(crate) async fn handle_server_select_req(stream: &mut TcpStream, packet: Packet, connection_service: Arc<ConnectionService>, connection_id: String) -> Result<(), Box<dyn Error + Send + Sync>> {
let request = CliSrvSelectReq::decode(packet.payload.as_slice()); let request = CliSrvSelectReq::decode(packet.payload.as_slice())?;
debug!("{:?}", request); debug!("{:?}", request);
if let Some(mut state) = connection_service.get_connection_mut(&connection_id) {
state.additional_data.insert("server".to_string(), request.server_id.to_string());
state.additional_data.insert("channel".to_string(), request.channel_id.to_string());
}
let data = SrvSrvSelectReply { let data = SrvSrvSelectReply {
result: srv_srv_select_reply::Result::Failed, result: srv_srv_select_reply::Result::Failed,
session_id: 0, // Client should already have this value session_id: 0, // Client should already have this value

View File

@@ -64,7 +64,7 @@ async fn handle_connection(stream: &mut TcpStream, pool: Arc<BufferPool>, auth_c
pool.release(buffer).await; pool.release(buffer).await;
} }
if let Some(state) = connection_service.get_connection(&connection_id) { if let Some(state) = connection_service.get_connection(&connection_id) {
let session_id = state.session_id.unwrap(); let session_id = state.session_id.unwrap();
let mut auth_client = auth_client.lock().await; let mut auth_client = auth_client.lock().await;
@@ -148,8 +148,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
}); });
let mut sigterm_stream = signal::unix::signal(signal::unix::SignalKind::terminate())?;
select! { select! {
_ = signal::ctrl_c() => {}, _ = signal::ctrl_c() => {
info!("Received SIGINT (Ctrl+C), shutting down...");
},
_ = sigterm_stream.recv() => {
info!("Received SIGTERM, shutting down...");
},
} }
consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect(""); consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect("");

View File

@@ -18,7 +18,7 @@ pub async fn route_packet(stream: &mut TcpStream, packet: Packet, auth_client: A
// Login Stuff // Login Stuff
PacketType::PakcsLoginTokenReq => auth::handle_login_req(stream, packet, auth_client, connection_service, connection_id, stream.peer_addr()?).await, PacketType::PakcsLoginTokenReq => auth::handle_login_req(stream, packet, auth_client, connection_service, connection_id, stream.peer_addr()?).await,
PacketType::PakcsLogoutReq => auth::handle_logout_req(stream, packet, auth_client, connection_service, connection_id).await, PacketType::PakcsLogoutReq => auth::handle_logout_req(stream, packet, auth_client, connection_service, connection_id).await,
PacketType::PakcsSrvSelectReq => auth::handle_server_select_req(stream, packet).await, PacketType::PakcsSrvSelectReq => auth::handle_server_select_req(stream, packet, connection_service, connection_id).await,
PacketType::PakcsChannelListReq => auth::handle_channel_list_req(stream, packet).await, PacketType::PakcsChannelListReq => auth::handle_channel_list_req(stream, packet).await,
// Character Stuff // Character Stuff

View File

@@ -35,6 +35,7 @@ message ValidateTokenRequest {
message ValidateTokenResponse { message ValidateTokenResponse {
bool valid = 1; bool valid = 1;
string user_id = 2; string user_id = 2;
string session_id = 3;
} }
message ValidateSessionRequest { message ValidateSessionRequest {

View File

@@ -8,7 +8,7 @@ use std::sync::Arc;
use tokio::{select, signal}; use tokio::{select, signal};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tonic::transport::Server; use tonic::transport::Server;
use tracing::Level; use tracing::{info, Level};
use utils::consul_registration; use utils::consul_registration;
use utils::redis_cache::RedisCache; use utils::redis_cache::RedisCache;
use utils::service_discovery::get_service_address; use utils::service_discovery::get_service_address;
@@ -68,8 +68,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.add_service(SessionServiceServer::new(session_service)) .add_service(SessionServiceServer::new(session_service))
.serve(address)); .serve(address));
let mut sigterm_stream = signal::unix::signal(signal::unix::SignalKind::terminate())?;
select! { select! {
_ = signal::ctrl_c() => {}, _ = signal::ctrl_c() => {
info!("Received SIGINT (Ctrl+C), shutting down...");
},
_ = sigterm_stream.recv() => {
info!("Received SIGTERM, shutting down...");
},
} }
consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect(""); consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect("");

View File

@@ -3,7 +3,7 @@ use std::collections::HashMap;
use std::env; use std::env;
use std::str::FromStr; use std::str::FromStr;
use tokio::{select, signal}; use tokio::{select, signal};
use tracing::Level; use tracing::{info, Level};
use utils::consul_registration; use utils::consul_registration;
use utils::service_discovery::get_service_address; use utils::service_discovery::get_service_address;
@@ -50,8 +50,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db_address = db_nodes.get(0).unwrap(); let db_address = db_nodes.get(0).unwrap();
let db_url = format!("http://{}:{}", db_address.ServiceAddress, db_address.ServicePort); let db_url = format!("http://{}:{}", db_address.ServiceAddress, db_address.ServicePort);
let mut sigterm_stream = signal::unix::signal(signal::unix::SignalKind::terminate())?;
select! { select! {
_ = signal::ctrl_c() => {}, _ = signal::ctrl_c() => {
info!("Received SIGINT (Ctrl+C), shutting down...");
},
_ = sigterm_stream.recv() => {
info!("Received SIGTERM, shutting down...");
},
} }
consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect(""); consul_registration::deregister_service(&consul_url, service_id.as_str()).await.expect("");