diff --git a/auth-service/src/main.rs b/auth-service/src/main.rs index c4425c5..dc32c8b 100644 --- a/auth-service/src/main.rs +++ b/auth-service/src/main.rs @@ -9,6 +9,7 @@ use std::str::FromStr; use tokio::{select, signal}; use tonic::transport::Server; use tracing::{info, Level}; +use tracing::log::debug; use warp::Filter; use utils::consul_registration; use utils::service_discovery::get_service_address; @@ -33,7 +34,7 @@ async fn main() -> Result<(), Box> { let service_port = port.clone(); let health_check_url = format!("http://{}:{}/health", service_address, health_port); let health_check_endpoint_addr = format!("{}:{}", service_address, health_port); - let db_address = get_service_address(&consul_url, "database-service").await?; + let db_nodes = get_service_address(&consul_url, "database-service").await?; // Register service with Consul let service_id = consul_registration::generate_service_id(); @@ -53,7 +54,8 @@ async fn main() -> Result<(), Box> { tokio::spawn(warp::serve(health_route).run(health_check_endpoint_addr.to_socket_addrs()?.next().unwrap())); - let db_url = format!("http://{}:{}", db_address.Address, db_address.Port); + let db_address = db_nodes.get(0).unwrap(); + let db_url = format!("http://{}:{}", db_address.0, db_address.1); let database_client = DatabaseClient::connect(&db_url).await?; let full_addr = format!("{}:{}", &addr, port); diff --git a/packet-service/src/main.rs b/packet-service/src/main.rs index 86bffc6..0c353ff 100644 --- a/packet-service/src/main.rs +++ b/packet-service/src/main.rs @@ -81,7 +81,7 @@ async fn main() -> Result<(), Box> { let service_port = port.clone(); let health_check_url = format!("http://{}:{}/health", service_address, health_port); let health_check_endpoint_addr = format!("{}:{}", service_address, health_port); - let auth_address = get_service_address(&consul_url, "auth-service").await?; + let auth_node = get_service_address(&consul_url, "auth-service").await?; // Register service with Consul let service_id = consul_registration::generate_service_id(); @@ -101,14 +101,13 @@ async fn main() -> Result<(), Box> { tokio::spawn(warp::serve(health_route).run(health_check_endpoint_addr.to_socket_addrs()?.next().unwrap())); - let auth_url = format!("http://{}:{}", auth_address.Address, auth_address.Port); + let auth_address = auth_node.get(0).unwrap(); + let auth_url = format!("http://{}:{}", auth_address.0, auth_address.1); let auth_client = Arc::new(Mutex::new(AuthClient::connect(&auth_url).await?)); let full_addr = format!("{}:{}", &addr, port); // let address = full_addr.parse().expect("Invalid address"); - - tokio::spawn(async move { let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_CONNECTIONS)); let listener = TcpListener::bind(full_addr.clone()).await.unwrap(); diff --git a/utils/src/service_discovery.rs b/utils/src/service_discovery.rs index 09ab7f2..edea219 100644 --- a/utils/src/service_discovery.rs +++ b/utils/src/service_discovery.rs @@ -1,48 +1,40 @@ use serde::Deserialize; -#[derive(Deserialize)] -struct Address { - Address: String, - Port: u16, +#[derive(Debug, Deserialize)] +struct ServiceNode { + ServiceAddress: String, + ServicePort: u16, } -#[derive(Deserialize)] -struct TaggedAddresses { - lan_ipv4: Address, - wan_ipv4: Address, -} - -#[derive(Deserialize)] -struct Weights { - Passing: u8, - Warning: u8, -} - -#[derive(Deserialize)] -pub struct Service { - pub ID: String, - pub Service: String, - pub Tags: Vec, - pub Port: u16, - pub Address: String, - pub TaggedAddresses: TaggedAddresses, - pub Weights: Weights, - pub EnableTagOverride: bool, - pub ContentHash: String, - pub Datacenter: String, -} - -pub async fn get_service_address(consul_url: &str, service_name: &str) -> Result<(Service), Box> { +pub async fn get_service_address(consul_url: &str, service_name: &str) -> Result, Box> { let client = reqwest::Client::new(); - let consul_service_url = format!("{}/v1/agent/service/{}", consul_url, service_name); + let consul_service_url = format!("{}/v1/catalog/service/{}", consul_url, service_name); let response = client .get(&consul_service_url) .send() - .await? - .error_for_status()? - .json::() - .await?; // Ensure response is successful + .await?; - Ok(response) + if !response.status().is_success() { + return Err(format!( + "Failed to fetch service nodes for '{}': {}", + service_name, response.status() + ) + .into()); + } + + // Deserialize the response into a Vec + let nodes: Vec = response.json().await?; + + // Map the nodes to (address, port) tuples + let addresses: Vec<(String, u16)> = nodes + .into_iter() + .map(|node| (node.ServiceAddress, node.ServicePort)) + .collect(); + + if addresses.is_empty() { + Err(format!("No nodes found for service '{}'", service_name).into()) + } else { + Ok(addresses) + } }