- add: logout grpc function

- add: logout packet handler
- add: connection state and service for storing connection data
- add: session service calls to auth-service
- fix: compile error on database service due to moved redis cache
This commit is contained in:
2024-12-20 14:46:00 -05:00
parent 3c1f8c40d6
commit 18afa71d74
22 changed files with 265 additions and 46 deletions

View File

@@ -1,10 +1,13 @@
[workspace] [workspace]
members = [ members = [
"api-service",
"auth-service", "auth-service",
"character-service", "character-service",
"database-service", "database-service",
"packet-service", "packet-service",
"session-service",
"world-service", "world-service",
"utils", "launcher", "api-service", "utils",
"launcher"
] ]
resolver = "2" resolver = "2"

View File

@@ -17,7 +17,7 @@ tracing = "0.1.41"
tracing-subscriber = "0.3.19" tracing-subscriber = "0.3.19"
utils = { path = "../utils" } utils = { path = "../utils" }
dotenv = "0.15" dotenv = "0.15"
tower-http = { version = "0.6.2", features = ["cors"] } tower-http = { version = "0.6.2", features = ["cors", "trace"] }
[build-dependencies] [build-dependencies]
tonic-build = "0.12.3" tonic-build = "0.12.3"

View File

@@ -1,15 +1,19 @@
use axum::extract::State; use axum::extract::{ConnectInfo, State};
use axum::{routing::post, Json, Router}; use axum::{routing::post, Json, Router};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::env; use std::env;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use axum::http::Method; use axum::http::Method;
use tonic::transport::Channel; use tonic::transport::Channel;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tower::ServiceBuilder;
use auth::auth_service_client::AuthServiceClient; use auth::auth_service_client::AuthServiceClient;
use auth::{LoginRequest, RegisterRequest}; use auth::{LoginRequest, RegisterRequest};
use log::error; use log::{error, info};
use tokio::sync::Mutex; use tokio::sync::Mutex;
pub mod auth { pub mod auth {
@@ -25,6 +29,7 @@ struct RestLoginRequest {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct RestLoginResponse { struct RestLoginResponse {
token: String, token: String,
session_id: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@@ -41,19 +46,27 @@ struct RestRegisterResponse {
} }
async fn login_handler( async fn login_handler(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(grpc_client): State<Arc<Mutex<AuthServiceClient<Channel>>>>, State(grpc_client): State<Arc<Mutex<AuthServiceClient<Channel>>>>,
Json(payload): Json<RestLoginRequest>, Json(payload): Json<RestLoginRequest>,
) -> Result<Json<RestLoginResponse>, axum::http::StatusCode> { ) -> Result<Json<RestLoginResponse>, axum::http::StatusCode> {
let ip_address = addr.ip().to_string();
info!("Client IP Address: {}", ip_address);
let request = tonic::Request::new(LoginRequest { let request = tonic::Request::new(LoginRequest {
username: payload.username.clone(), username: payload.username.clone(),
password: payload.password.clone(), password: payload.password.clone(),
ip_address
}); });
let mut client = grpc_client.lock().await; // Lock the mutex to get mutable access let mut client = grpc_client.lock().await; // Lock the mutex to get mutable access
match client.login(request).await { match client.login(request).await {
Ok(response) => { Ok(response) => {
let token = response.into_inner().token; let resp = response.into_inner();
Ok(Json(RestLoginResponse { token })) let token = resp.token;
let session_id = resp.session_id;
Ok(Json(RestLoginResponse { token, session_id }))
} }
Err(e) => { Err(e) => {
error!("gRPC Login call failed: {}", e); error!("gRPC Login call failed: {}", e);
@@ -63,6 +76,7 @@ async fn login_handler(
} }
async fn register_handler( async fn register_handler(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(grpc_client): State<Arc<Mutex<AuthServiceClient<Channel>>>>, State(grpc_client): State<Arc<Mutex<AuthServiceClient<Channel>>>>,
Json(payload): Json<RestRegisterRequest>, Json(payload): Json<RestRegisterRequest>,
) -> Result<Json<RestRegisterResponse>, axum::http::StatusCode> { ) -> Result<Json<RestRegisterResponse>, axum::http::StatusCode> {
@@ -106,7 +120,7 @@ pub async fn serve_rest_api(
let listener = tokio::net::TcpListener::bind(format!("{}:{}", addr, port)) let listener = tokio::net::TcpListener::bind(format!("{}:{}", addr, port))
.await .await
.unwrap(); .unwrap();
axum::serve(listener, app.into_make_service()) axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
.await .await
.unwrap(); .unwrap();

View File

@@ -25,6 +25,7 @@ rand = "0.8.5"
warp = "0.3.7" warp = "0.3.7"
reqwest = { version = "0.12.9", features = ["json"] } reqwest = { version = "0.12.9", features = ["json"] }
utils = { path = "../utils" } utils = { path = "../utils" }
uuid = "1.11.0"
[build-dependencies] [build-dependencies]
tonic-build = "0.12.3" tonic-build = "0.12.3"

View File

@@ -11,6 +11,6 @@ fn main() {
tonic_build::configure() tonic_build::configure()
.build_server(false) // Generate gRPC client code .build_server(false) // Generate gRPC client code
.compile_well_known_types(true) .compile_well_known_types(true)
.compile_protos(&["../proto/user_db_api.proto"], &["../proto"]) .compile_protos(&["../proto/user_db_api.proto", "../proto/session_service_api.proto"], &["../proto"])
.unwrap_or_else(|e| panic!("Failed to compile protos {:?}", e)); .unwrap_or_else(|e| panic!("Failed to compile protos {:?}", e));
} }

View File

@@ -1,6 +1,9 @@
use std::sync::Arc;
use crate::auth::auth_service_server::AuthService; use crate::auth::auth_service_server::AuthService;
use crate::auth::{LoginRequest, LoginResponse, PasswordResetRequest, PasswordResetResponse, RegisterRequest, RegisterResponse, ResetPasswordRequest, ResetPasswordResponse, ValidateTokenRequest, ValidateTokenResponse}; use crate::auth::{LoginRequest, LoginResponse, PasswordResetRequest, PasswordResetResponse, RegisterRequest, RegisterResponse, ResetPasswordRequest, ResetPasswordResponse, ValidateTokenRequest, ValidateTokenResponse, ValidateSessionRequest, ValidateSessionResponse, Empty, LogoutRequest};
use crate::database_client::{DatabaseClientTrait}; use crate::database_client::{DatabaseClient, DatabaseClientTrait};
use crate::session::session_service_client::SessionServiceClient;
use crate::session::{CreateSessionRequest, GetSessionRequest, DeleteSessionRequest};
use crate::jwt::{generate_token, validate_token}; use crate::jwt::{generate_token, validate_token};
use crate::users::{hash_password, verify_user}; use crate::users::{hash_password, verify_user};
use chrono::{Duration, Utc}; use chrono::{Duration, Utc};
@@ -8,12 +11,13 @@ use rand::Rng;
use tonic::{Request, Response, Status}; use tonic::{Request, Response, Status};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
pub struct MyAuthService<T: DatabaseClientTrait + Clone> { pub struct MyAuthService {
pub db_client: T, pub db_client: Arc<DatabaseClient>,
pub session_client: Arc<SessionServiceClient<tonic::transport::Channel>>,
} }
#[tonic::async_trait] #[tonic::async_trait]
impl<T: DatabaseClientTrait + Send + Sync + Clone + 'static> AuthService for MyAuthService<T> { impl AuthService for MyAuthService {
async fn login( async fn login(
&self, &self,
request: Request<LoginRequest>, request: Request<LoginRequest>,
@@ -22,18 +26,50 @@ impl<T: DatabaseClientTrait + Send + Sync + Clone + 'static> AuthService for MyA
info!("Login attempt for username: {}", req.username); info!("Login attempt for username: {}", req.username);
if let Some(user_id) = verify_user(self.db_client.clone(), &req.username, &req.password).await { if let Some(user_id) = verify_user(self.db_client.as_ref().clone(), &req.username, &req.password).await {
let token = generate_token(&user_id, vec!["user".to_string()]) let token = generate_token(&user_id, vec!["user".to_string()])
.map_err(|_| Status::internal("Token generation failed"))?; .map_err(|_| Status::internal("Token generation failed"))?;
let session_id = uuid::Uuid::new_v4().to_string();
let response = self
.session_client.as_ref().clone()
.create_session(CreateSessionRequest {
session_id: session_id.clone(),
user_id: user_id.parse().unwrap(),
username: req.username.to_string(),
character_id: 0,
ip_address: req.ip_address.to_string(),
})
.await;
let session = match response {
Ok(session) => session,
Err(_) => return Err(Status::internal("Session creation failed")),
};
let session_id = session.into_inner().session_id;
info!("Login successful for username: {}", req.username); info!("Login successful for username: {}", req.username);
Ok(Response::new(LoginResponse { token, user_id })) Ok(Response::new(LoginResponse { token, user_id, session_id }))
} else { } else {
warn!("Invalid login attempt for username: {}", req.username); warn!("Invalid login attempt for username: {}", req.username);
Err(Status::unauthenticated("Invalid credentials")) Err(Status::unauthenticated("Invalid credentials"))
} }
} }
async fn logout(
&self,
request: Request<LogoutRequest>,
) -> Result<Response<Empty>, Status> {
let req = request.into_inner();
self.session_client.as_ref().clone()
.delete_session(DeleteSessionRequest {
session_id: req.session_id.clone(),
})
.await?;
Ok(Response::new(Empty {}))
}
async fn validate_token( async fn validate_token(
&self, &self,
request: Request<ValidateTokenRequest>, request: Request<ValidateTokenRequest>,
@@ -52,6 +88,30 @@ impl<T: DatabaseClientTrait + Send + Sync + Clone + 'static> AuthService for MyA
} }
} }
async fn validate_session(
&self,
request: Request<ValidateSessionRequest>,
) -> Result<Response<ValidateSessionResponse>, Status> {
let req = request.into_inner();
let response = self
.session_client.as_ref().clone()
.get_session(GetSessionRequest {
session_id: req.session_id,
})
.await;
match response {
Ok(res) => {
println!("Session valid: {:?}", res.into_inner());
Ok(Response::new(ValidateSessionResponse { valid: true }))
}
Err(_) => {
println!("Session invalid or not found");
Ok(Response::new(ValidateSessionResponse { valid: false }))
}
}
}
async fn register( async fn register(
&self, &self,
request: Request<RegisterRequest>, request: Request<RegisterRequest>,
@@ -62,7 +122,7 @@ impl<T: DatabaseClientTrait + Send + Sync + Clone + 'static> AuthService for MyA
let hashed_password = hash_password(&req.password); let hashed_password = hash_password(&req.password);
// Create user in the database // Create user in the database
let result = self.db_client.clone().create_user(&req.username, &req.email, &hashed_password) let result = self.db_client.as_ref().clone().create_user(&req.username, &req.email, &hashed_password)
.await; .await;
match result { match result {
@@ -83,7 +143,7 @@ impl<T: DatabaseClientTrait + Send + Sync + Clone + 'static> AuthService for MyA
) -> Result<Response<PasswordResetResponse>, Status> { ) -> Result<Response<PasswordResetResponse>, Status> {
let email = request.into_inner().email; let email = request.into_inner().email;
let user = self.db_client.clone().get_user_by_email(&email).await; let user = self.db_client.as_ref().clone().get_user_by_email(&email).await;
// Check if the email exists // Check if the email exists
if user.ok().is_some() { if user.ok().is_some() {
@@ -98,14 +158,14 @@ impl<T: DatabaseClientTrait + Send + Sync + Clone + 'static> AuthService for MyA
let expires_at = Utc::now() + Duration::hours(1); let expires_at = Utc::now() + Duration::hours(1);
// Store the reset token in the database // Store the reset token in the database
self.db_client.clone() self.db_client.as_ref().clone()
.store_password_reset(&email, &reset_token, expires_at) .store_password_reset(&email, &reset_token, expires_at)
.await .await
.map_err(|e| Status::internal(format!("Database error: {}", e)))?; .map_err(|e| Status::internal(format!("Database error: {}", e)))?;
// Send the reset email // Send the reset email
// send_email(&email, "Password Reset Request", &format!( // send_email(&email, "Password Reset Request", &format!(
// "Click the link to reset your password: https://example.com/reset?token={}", // "Click the link to reset your password: https://azgstudio.com/reset?token={}",
// reset_token // reset_token
// )) // ))
// .map_err(|e| Status::internal(format!("Email error: {}", e)))?; // .map_err(|e| Status::internal(format!("Email error: {}", e)))?;

View File

@@ -4,10 +4,14 @@ pub mod database_client;
pub mod users; pub mod users;
pub mod auth { pub mod auth {
tonic::include_proto!("auth"); // Path matches the package name in auth.proto tonic::include_proto!("auth");
} }
pub mod database { pub mod database {
tonic::include_proto!("user_db_api"); // Matches package name in user_db_api.proto tonic::include_proto!("user_db_api");
}
pub mod session {
tonic::include_proto!("session_service_api");
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -1,19 +1,18 @@
use auth_service::auth::auth_service_server::AuthServiceServer; use auth_service::auth::auth_service_server::AuthServiceServer;
use auth_service::database_client::DatabaseClient; use auth_service::database_client::DatabaseClient;
use auth_service::database_client::DatabaseClientTrait; use auth_service::database_client::DatabaseClientTrait;
use auth_service::session::session_service_client::SessionServiceClient;
use auth_service::grpc::MyAuthService; use auth_service::grpc::MyAuthService;
use dotenv::dotenv; use dotenv::dotenv;
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::net::ToSocketAddrs;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use tokio::{select, signal}; use tokio::{select, signal};
use tonic::transport::Server; use tonic::transport::Server;
use tracing::log::debug;
use tracing::{info, Level}; use tracing::{info, Level};
use utils::consul_registration; use utils::consul_registration;
use utils::service_discovery::get_service_address; use utils::service_discovery::get_service_address;
use warp::Filter;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -36,6 +35,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let health_check_url = format!("http://{}:{}/health", service_address, health_port); let health_check_url = format!("http://{}:{}/health", service_address, health_port);
let health_check_endpoint_addr = format!("{}:{}", service_address, health_port); let health_check_endpoint_addr = format!("{}:{}", service_address, health_port);
let db_nodes = get_service_address(&consul_url, "database-service").await?; let db_nodes = get_service_address(&consul_url, "database-service").await?;
let session_nodes = get_service_address(&consul_url, "session-service").await?;
// Register service with Consul // Register service with Consul
let service_id = consul_registration::get_or_generate_service_id(); let service_id = consul_registration::get_or_generate_service_id();
@@ -58,12 +58,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db_address = db_nodes.get(0).unwrap(); let db_address = db_nodes.get(0).unwrap();
let db_url = format!("http://{}:{}", db_address.ServiceAddress, db_address.ServicePort); let db_url = format!("http://{}:{}", db_address.ServiceAddress, db_address.ServicePort);
let database_client = DatabaseClient::connect(&db_url).await?; let db_client = Arc::new(DatabaseClient::connect(&db_url).await?);
let session_address = session_nodes.get(0).unwrap();
let session_address = format!("http://{}:{}", session_address.ServiceAddress, session_address.ServicePort);
let session_client = Arc::new(SessionServiceClient::connect(session_address).await?);
let full_addr = format!("{}:{}", &addr, port); let full_addr = format!("{}:{}", &addr, port);
let address = full_addr.parse().expect("Invalid address"); let address = full_addr.parse().expect("Invalid address");
let auth_service = MyAuthService { let auth_service = MyAuthService {
db_client: database_client, db_client,
session_client
}; };
println!("Authentication Service running on {}", addr); println!("Authentication Service running on {}", addr);

View File

@@ -1,6 +1,6 @@
use crate::users::UserRepository; use crate::users::UserRepository;
use crate::characters::CharacterRepository; use crate::characters::CharacterRepository;
use crate::redis_cache::RedisCache; use utils::redis_cache::RedisCache;
use sqlx::PgPool; use sqlx::PgPool;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;

View File

@@ -1,5 +1,4 @@
pub mod users; pub mod users;
pub mod characters; pub mod characters;
pub mod redis_cache;
pub mod db; pub mod db;
pub mod grpc; pub mod grpc;

View File

@@ -24,6 +24,7 @@
environment: environment:
- HEALTH_CHECK_PORT=8080 - HEALTH_CHECK_PORT=8080
depends_on: depends_on:
- session-service
- database-service - database-service
- consul - consul
@@ -93,6 +94,18 @@
- auth-service - auth-service
- consul - consul
session-service:
build:
context: ./
dockerfile: ./session-service/Dockerfile
ports:
- "50055:50055"
env_file:
- ./session-service/.env
- .env
depends_on:
- consul
db: db:
image: postgres:17 image: postgres:17
env_file: env_file:

View File

@@ -34,12 +34,12 @@ const LoginPage = () => {
const response = await axios.post(apiUrl, { username, password }); const response = await axios.post(apiUrl, { username, password });
// Extract token and server info from response // Extract token and server info from response
const { token, serverIp } = response.data; const { token, session_id } = response.data;
setMessage("Login successful! Launching game..."); setMessage("Login successful! Launching game...");
const { ServiceAddress, ServicePort } = await getServiceAddress("packet-service"); const { ServiceAddress, ServicePort } = await getServiceAddress("packet-service");
window.location.href = `osirose-launcher://launch?otp=${encodeURIComponent(token)}&ip=${encodeURIComponent(ServiceAddress)}&port=${encodeURIComponent(ServicePort)}&username=${encodeURIComponent(username)}`; window.location.href = `osirose-launcher://launch?otp=${encodeURIComponent(token)}&session=${encodeURIComponent(session_id)}&ip=${encodeURIComponent(ServiceAddress)}&port=${encodeURIComponent(ServicePort)}&username=${encodeURIComponent(username)}`;
} catch (error) { } catch (error) {
setMessage("Login failed: " + error.response?.data?.error || error.message); setMessage("Login failed: " + error.response?.data?.error || error.message);
} }

View File

@@ -50,6 +50,9 @@ pub(crate) fn launch_game(url: String) {
is_direct = true; is_direct = true;
command.arg("_direct").arg("_otp").arg(value.to_string()); command.arg("_direct").arg("_otp").arg(value.to_string());
} }
Cow::Borrowed("session") => {
command.arg("_session").arg(value.to_string());
}
Cow::Borrowed("username") => { Cow::Borrowed("username") => {
command.arg("_userid").arg(value.to_string()); command.arg("_userid").arg(value.to_string());
} }

View File

@@ -23,6 +23,8 @@ tonic = "0.12.3"
prost = "0.13.4" prost = "0.13.4"
utils = { path = "../utils" } utils = { path = "../utils" }
warp = "0.3.7" warp = "0.3.7"
dashmap = "6.1.0"
uuid = { version = "1.11.0", features = ["v4"] }
[build-dependencies] [build-dependencies]
tonic-build = "0.12.3" tonic-build = "0.12.3"

View File

@@ -1,5 +1,5 @@
use crate::auth::auth_service_client::AuthServiceClient; use crate::auth::auth_service_client::AuthServiceClient;
use crate::auth::{LoginRequest, LoginResponse, ValidateTokenRequest, ValidateTokenResponse}; use crate::auth::{Empty, LoginRequest, LoginResponse, LogoutRequest, ValidateTokenRequest, ValidateTokenResponse};
use tonic::transport::Channel; use tonic::transport::Channel;
pub struct AuthClient { pub struct AuthClient {
@@ -12,10 +12,11 @@ impl AuthClient {
Ok(AuthClient { client }) Ok(AuthClient { client })
} }
pub async fn login(&mut self, username: &str, password: &str) -> Result<LoginResponse, Box<dyn std::error::Error + Send + Sync>> { pub async fn login(&mut self, username: &str, password: &str, ip_address: &str) -> Result<LoginResponse, Box<dyn std::error::Error + Send + Sync>> {
let request = LoginRequest { let request = LoginRequest {
username: username.to_string(), username: username.to_string(),
password: password.to_string(), password: password.to_string(),
ip_address: ip_address.to_string(),
}; };
let response = self.client.login(request).await?; let response = self.client.login(request).await?;
@@ -24,10 +25,19 @@ impl AuthClient {
pub async fn login_token(&mut self, token: &str) -> Result<ValidateTokenResponse, Box<dyn std::error::Error + Send + Sync>> { pub async fn login_token(&mut self, token: &str) -> Result<ValidateTokenResponse, Box<dyn std::error::Error + Send + Sync>> {
let request = ValidateTokenRequest { let request = ValidateTokenRequest {
token: token.to_string(), token: token.to_string()
}; };
let response = self.client.validate_token(request).await?; let response = self.client.validate_token(request).await?;
Ok(response.into_inner()) Ok(response.into_inner())
} }
pub async fn logout(&mut self, session_id: &str) -> Result<Empty, Box<dyn std::error::Error + Send + Sync>> {
let request = LogoutRequest {
session_id: session_id.to_string(),
};
let response = self.client.logout(request).await?;
Ok(response.into_inner())
}
} }

View File

@@ -0,0 +1,30 @@
use dashmap::DashMap;
use std::sync::Arc;
use uuid::Uuid;
use crate::connection_state::ConnectionState;
pub struct ConnectionService {
pub connections: Arc<DashMap<String, ConnectionState>>, // Map connection ID to state
}
impl ConnectionService {
pub fn new() -> Self {
Self {
connections: Arc::new(DashMap::new()),
}
}
pub fn add_connection(&self) -> String {
let connection_id = Uuid::new_v4().to_string();
self.connections.insert(connection_id.clone(), ConnectionState::new());
connection_id
}
pub fn get_connection(&self, connection_id: &str) -> Option<ConnectionState> {
self.connections.get(connection_id).map(|entry| entry.clone())
}
pub fn remove_connection(&self, connection_id: &str) {
self.connections.remove(connection_id);
}
}

View File

@@ -0,0 +1,20 @@
use std::collections::HashMap;
#[derive(Clone, Debug)]
pub struct ConnectionState {
pub user_id: Option<i32>,
pub session_id: Option<String>,
pub character_id: Option<i32>,
pub additional_data: HashMap<String, String>, // Flexible data storage
}
impl ConnectionState {
pub fn new() -> Self {
Self {
user_id: None,
session_id: None,
character_id: None,
additional_data: HashMap::new(),
}
}
}

View File

@@ -15,17 +15,17 @@ use crate::packets::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tonic::{Code, Status}; use tonic::{Code, Status};
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use utils::service_discovery; use utils::service_discovery;
use crate::connection_service::ConnectionService;
use crate::packets::cli_logout_req::CliLogoutReq;
pub(crate) async fn handle_accept_req(stream: &mut TcpStream, packet: Packet) -> Result<(), Box<dyn Error + Send + Sync>> { pub(crate) async fn handle_accept_req(stream: &mut TcpStream, packet: Packet) -> Result<(), Box<dyn Error + Send + Sync>> {
// let request = CliAcceptReq::decode(packet.payload.as_slice());
// We need to do reply to this packet
let data = SrvAcceptReply { result: srv_accept_reply::Result::Accepted, rand_value: 0 }; let data = SrvAcceptReply { result: srv_accept_reply::Result::Accepted, rand_value: 0 };
let response_packet = Packet::new(PacketType::PakssAcceptReply, &data)?; let response_packet = Packet::new(PacketType::PakssAcceptReply, &data)?;
@@ -39,7 +39,19 @@ pub(crate) async fn handle_join_server_req(stream: &mut TcpStream, packet: Packe
Ok(()) Ok(())
} }
pub(crate) async fn handle_login_req(stream: &mut TcpStream, packet: Packet, auth_client: Arc<Mutex<AuthClient>>) -> Result<(), Box<dyn Error + Send + Sync>> { pub(crate) async fn handle_logout_req(stream: &mut TcpStream, packet: Packet, auth_client: Arc<Mutex<AuthClient>>, connection_service: Arc<ConnectionService>, connection_id: String) -> Result<(), Box<dyn Error + Send + Sync>> {
let request = CliLogoutReq::decode(packet.payload.as_slice());
let mut auth_client = auth_client.lock().await;
if let Some(mut state) = connection_service.get_connection(&connection_id) {
let session_id = state.session_id.clone().unwrap();
auth_client.logout(&session_id).await?;
Ok(())
} else {
Err("Unable to find connection state".into())
}
}
pub(crate) async fn handle_login_req(stream: &mut TcpStream, packet: Packet, auth_client: Arc<Mutex<AuthClient>>, connection_service: Arc<ConnectionService>, connection_id: String, addr: SocketAddr) -> Result<(), Box<dyn Error + Send + Sync>> {
debug!("decoding packet payload of size {}", packet.payload.as_slice().len()); debug!("decoding packet payload of size {}", packet.payload.as_slice().len());
let data = CliLoginTokenReq::decode(packet.payload.as_slice())?; let data = CliLoginTokenReq::decode(packet.payload.as_slice())?;
debug!("{:?}", data); debug!("{:?}", data);
@@ -56,6 +68,11 @@ pub(crate) async fn handle_login_req(stream: &mut TcpStream, packet: Packet, aut
} else { } else {
debug!("Successfully logged in"); debug!("Successfully logged in");
if let Some(mut state) = connection_service.get_connection(&connection_id) {
state.user_id = Some(response.user_id.parse().unwrap());
// auth_client.logout(&session_id).await?;
}
let consul_url = env::var("CONSUL_URL").unwrap_or_else(|_| "http://127.0.0.1:8500".to_string()); let consul_url = env::var("CONSUL_URL").unwrap_or_else(|_| "http://127.0.0.1:8500".to_string());
let servers = service_discovery::get_service_address(&consul_url, "character-service").await.unwrap_or_else(|err| { let servers = service_discovery::get_service_address(&consul_url, "character-service").await.unwrap_or_else(|err| {
warn!(err); warn!(err);
@@ -125,10 +142,10 @@ pub(crate) async fn handle_server_select_req(stream: &mut TcpStream, packet: Pac
let data = SrvSrvSelectReply { let data = SrvSrvSelectReply {
result: srv_srv_select_reply::Result::Failed, result: srv_srv_select_reply::Result::Failed,
session_id: 0, session_id: 0, // Client should already have this value
crypt_val: 0, crypt_val: 0, // This is only for the old encryption
ip: NullTerminatedString::new(""), ip: NullTerminatedString::new(""), // If this is empty, the client should stay connected (requires client change)
port: 0, port: 0, // See comment about ip above
}; };
let response_packet = Packet::new(PacketType::PaklcSrvSelectReply, &data)?; let response_packet = Packet::new(PacketType::PaklcSrvSelectReply, &data)?;

View File

@@ -18,6 +18,7 @@ use tracing::{debug, error, info, warn};
use utils::consul_registration; use utils::consul_registration;
use utils::service_discovery::get_service_address; use utils::service_discovery::get_service_address;
use warp::Filter; use warp::Filter;
use crate::connection_service::ConnectionService;
mod packet_type; mod packet_type;
mod packet; mod packet;
@@ -30,6 +31,9 @@ mod handlers;
mod bufferpool; mod bufferpool;
mod metrics; mod metrics;
mod auth_client; mod auth_client;
mod connection_state;
mod connection_service;
pub mod auth { pub mod auth {
tonic::include_proto!("auth"); // Path matches the package name in auth.proto tonic::include_proto!("auth"); // Path matches the package name in auth.proto
} }
@@ -38,7 +42,7 @@ const BUFFER_POOL_SIZE: usize = 1000;
const MAX_CONCURRENT_CONNECTIONS: usize = 100; const MAX_CONCURRENT_CONNECTIONS: usize = 100;
async fn handle_connection(stream: &mut TcpStream, pool: Arc<BufferPool>, auth_client: Arc<Mutex<AuthClient>>) -> Result<(), Box<dyn Error + Send + Sync>> { async fn handle_connection(stream: &mut TcpStream, pool: Arc<BufferPool>, auth_client: Arc<Mutex<AuthClient>>, connection_service: Arc<ConnectionService>, connection_id: String) -> Result<(), Box<dyn Error + Send + Sync>> {
ACTIVE_CONNECTIONS.inc(); ACTIVE_CONNECTIONS.inc();
while let Some(mut buffer) = pool.acquire().await { while let Some(mut buffer) = pool.acquire().await {
// Read data into the buffer // Read data into the buffer
@@ -53,13 +57,19 @@ async fn handle_connection(stream: &mut TcpStream, pool: Arc<BufferPool>, auth_c
Ok(packet) => { Ok(packet) => {
debug!("Parsed Packet: {:?}", packet); debug!("Parsed Packet: {:?}", packet);
// Handle the parsed packet (route it, process it, etc.) // Handle the parsed packet (route it, process it, etc.)
router::route_packet(stream, packet, auth_client.clone()).await?; router::route_packet(stream, packet, auth_client.clone(), connection_service.clone(), connection_id.clone()).await?;
} }
Err(e) => warn!("Failed to parse packet: {}", e), Err(e) => warn!("Failed to parse packet: {}", e),
} }
pool.release(buffer).await; pool.release(buffer).await;
} }
if let Some(state) = connection_service.get_connection(&connection_id) {
let session_id = state.session_id.unwrap();
let mut auth_client = auth_client.lock().await;
auth_client.logout(&session_id).await?;
}
ACTIVE_CONNECTIONS.dec(); ACTIVE_CONNECTIONS.dec();
Ok(()) Ok(())
} }
@@ -113,12 +123,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_CONNECTIONS)); let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_CONNECTIONS));
let listener = TcpListener::bind(full_addr.clone()).await.unwrap(); let listener = TcpListener::bind(full_addr.clone()).await.unwrap();
let buffer_pool = BufferPool::new(BUFFER_POOL_SIZE); let buffer_pool = BufferPool::new(BUFFER_POOL_SIZE);
let connection_service = Arc::new(ConnectionService::new());
info!("Packet service listening on {}", full_addr); info!("Packet service listening on {}", full_addr);
loop { loop {
let (mut socket, addr) = listener.accept().await.unwrap(); let (mut socket, addr) = listener.accept().await.unwrap();
let auth_client = auth_client.clone(); let auth_client = auth_client.clone();
let connection_service = connection_service.clone();
info!("New connection from {}", addr); info!("New connection from {}", addr);
let pool = buffer_pool.clone(); let pool = buffer_pool.clone();
@@ -127,9 +139,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Spawn a new task for each connection // Spawn a new task for each connection
tokio::spawn(async move { tokio::spawn(async move {
let _permit = permit; let _permit = permit;
if let Err(e) = handle_connection(&mut socket, pool, auth_client).await { let connection_id = connection_service.add_connection();
if let Err(e) = handle_connection(&mut socket, pool, auth_client, connection_service.clone(), connection_id.clone()).await {
error!("Error handling connection: {}", e); error!("Error handling connection: {}", e);
} }
connection_service.remove_connection(&connection_id);
}); });
} }
}); });

View File

@@ -7,15 +7,17 @@ use std::sync::Arc;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::connection_service::ConnectionService;
pub async fn route_packet(stream: &mut TcpStream, packet: Packet, auth_client: Arc<Mutex<AuthClient>>) -> Result<(), Box<dyn Error + Send + Sync>> { pub async fn route_packet(stream: &mut TcpStream, packet: Packet, auth_client: Arc<Mutex<AuthClient>>, connection_service: Arc<ConnectionService>, connection_id: String) -> Result<(), Box<dyn Error + Send + Sync>> {
debug!("Routing packet: {:?}", packet); debug!("Routing packet: {:?}", packet);
match packet.packet_type { match packet.packet_type {
PacketType::PakcsAlive => Ok(()), PacketType::PakcsAlive => Ok(()),
PacketType::PakcsAcceptReq => auth::handle_accept_req(stream, packet).await, PacketType::PakcsAcceptReq => auth::handle_accept_req(stream, packet).await,
PacketType::PakcsJoinServerTokenReq => auth::handle_join_server_req(stream, packet).await, PacketType::PakcsJoinServerTokenReq => auth::handle_join_server_req(stream, packet).await,
// Login Stuff // Login Stuff
PacketType::PakcsLoginTokenReq => auth::handle_login_req(stream, packet, auth_client).await, PacketType::PakcsLoginTokenReq => auth::handle_login_req(stream, packet, auth_client, connection_service, connection_id, stream.peer_addr()?).await,
PacketType::PakcsLogoutReq => auth::handle_logout_req(stream, packet, auth_client, connection_service, connection_id).await,
PacketType::PakcsSrvSelectReq => auth::handle_server_select_req(stream, packet).await, PacketType::PakcsSrvSelectReq => auth::handle_server_select_req(stream, packet).await,
PacketType::PakcsChannelListReq => auth::handle_channel_list_req(stream, packet).await, PacketType::PakcsChannelListReq => auth::handle_channel_list_req(stream, packet).await,

View File

@@ -4,7 +4,9 @@ package auth;
service AuthService { service AuthService {
rpc Login(LoginRequest) returns (LoginResponse); rpc Login(LoginRequest) returns (LoginResponse);
rpc Logout(LogoutRequest) returns (Empty);
rpc ValidateToken(ValidateTokenRequest) returns (ValidateTokenResponse); rpc ValidateToken(ValidateTokenRequest) returns (ValidateTokenResponse);
rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse);
rpc Register (RegisterRequest) returns (RegisterResponse); rpc Register (RegisterRequest) returns (RegisterResponse);
rpc RequestPasswordReset (PasswordResetRequest) returns (PasswordResetResponse); rpc RequestPasswordReset (PasswordResetRequest) returns (PasswordResetResponse);
rpc ResetPassword (ResetPasswordRequest) returns (ResetPasswordResponse); rpc ResetPassword (ResetPasswordRequest) returns (ResetPasswordResponse);
@@ -13,11 +15,17 @@ service AuthService {
message LoginRequest { message LoginRequest {
string username = 1; string username = 1;
string password = 2; string password = 2;
string ip_address = 3;
} }
message LoginResponse { message LoginResponse {
string token = 1; string token = 1;
string user_id = 2; string user_id = 2;
string session_id = 3;
}
message LogoutRequest {
string session_id = 1;
} }
message ValidateTokenRequest { message ValidateTokenRequest {
@@ -29,6 +37,14 @@ message ValidateTokenResponse {
string user_id = 2; string user_id = 2;
} }
message ValidateSessionRequest {
string session_id = 1;
}
message ValidateSessionResponse {
bool valid = 1;
}
message RegisterRequest { message RegisterRequest {
string username = 1; string username = 1;
string email = 2; string email = 2;
@@ -56,3 +72,5 @@ message ResetPasswordRequest {
message ResetPasswordResponse { message ResetPasswordResponse {
string message = 1; string message = 1;
} }
message Empty {}

View File

@@ -12,3 +12,7 @@ uuid = { version = "1.11.0", features = ["v4"] }
warp = "0.3.7" warp = "0.3.7"
tokio = "1.41.1" tokio = "1.41.1"
bincode = { version = "2.0.0-rc.3", features = ["derive", "serde"] } bincode = { version = "2.0.0-rc.3", features = ["derive", "serde"] }
redis = "0.27.5"
deadpool-redis = "0.18.0"
async-trait = "0.1.83"
serde_json = "1.0.133"