From cdf7bb3f159b5fedcaf8c572096e88d92b0c421f7c16b052fcee4452475a0978 Mon Sep 17 00:00:00 2001 From: raven <7156279+RavenX8@users.noreply.github.com> Date: Wed, 12 Mar 2025 15:39:45 -0400 Subject: [PATCH] - update: tell consul to use docker dns to resolve CNAME addresses - add: load balancer for consul services - update: dns lookup to now return the service address - update: docker consul to the latest version --- auth-service/src/main.rs | 46 ++++---- docker-compose.yml | 2 +- scripts/consul.json | 3 + utils/Cargo.toml | 1 + utils/src/lib.rs | 1 + utils/src/multi_service_load_balancer.rs | 135 +++++++++++++++++++++++ utils/src/service_discovery.rs | 9 +- 7 files changed, 173 insertions(+), 24 deletions(-) create mode 100644 utils/src/multi_service_load_balancer.rs diff --git a/auth-service/src/main.rs b/auth-service/src/main.rs index 741bac4..16cbb5e 100644 --- a/auth-service/src/main.rs +++ b/auth-service/src/main.rs @@ -11,6 +11,7 @@ use std::sync::Arc; use tonic::transport::Server; use tracing::{debug, info, Level}; use utils::consul_registration; +use utils::multi_service_load_balancer::{LoadBalancingStrategy, MultiServiceLoadBalancer}; use utils::service_discovery::{get_service_address, get_service_endpoints_by_dns}; #[tokio::main] @@ -33,14 +34,35 @@ async fn main() -> Result<(), Box> { let consul_port = env::var("CONSUL_PORT").unwrap_or_else(|_| "8500".to_string()); let consul_dns_port = env::var("CONSUL_DNS_PORT").unwrap_or_else(|_| "8600".to_string()); let consul_url = format!("http://{}:{}", consul_address, consul_port); - // let consul_url = env::var("CONSUL_URL").unwrap_or_else(|_| "http://127.0.0.1:8500".to_string()); + let consul_dns_url = format!("{}:{}", consul_address, consul_dns_port); let service_name = env::var("SERVICE_NAME").unwrap_or_else(|_| "auth-service".to_string()); let service_address = env::var("AUTH_SERVICE_ADDR").unwrap_or_else(|_| "127.0.0.1".to_string()); let service_port = port.clone(); - let db_nodes = get_service_address(&consul_url, "database-service").await?; - let session_nodes = get_service_address(&consul_url, "session-service").await?; - let temp_session_nodes = get_service_endpoints_by_dns(format!("{}:{}", consul_address, consul_dns_port).as_str(), "grpc", "session-service").await?; - debug!("{:?}", temp_session_nodes); + + let lb = MultiServiceLoadBalancer::new(&consul_dns_url, LoadBalancingStrategy::RoundRobin); + + let mut db_url = "".to_string(); + match lb.get_endpoint("database-service", "grpc").await? { + Some(endpoint) => { + db_url = format!("http://{}", endpoint); + }, + None => { + println!("No endpoints available for database-service"); + } + } + + let mut session_service_address = "".to_string(); + match lb.get_endpoint("session-service", "grpc").await? { + Some(endpoint) => { + session_service_address = format!("http://{}", endpoint); + }, + None => { + println!("No endpoints available for session-service"); + } + } + + let db_client = Arc::new(DatabaseClient::connect(&db_url).await?); + let session_client = Arc::new(SessionServiceClient::connect(session_service_address).await?); // Register service with Consul let service_id = consul_registration::get_or_generate_service_id(env!("CARGO_PKG_NAME")); @@ -60,20 +82,6 @@ async fn main() -> Result<(), Box> { ) .await?; - let db_address = db_nodes.get(0).unwrap(); - let db_url = format!( - "http://{}:{}", - db_address.ServiceAddress, db_address.ServicePort - ); - let db_client = Arc::new(DatabaseClient::connect(&db_url).await?); - - let session_address = session_nodes.get(0).unwrap(); - let session_address = format!( - "http://{}:{}", - session_address.ServiceAddress, session_address.ServicePort - ); - let session_client = Arc::new(SessionServiceClient::connect(session_address).await?); - let full_addr = format!("{}:{}", &addr, port); let address = full_addr.parse().expect("Invalid address"); let auth_service = MyAuthService { diff --git a/docker-compose.yml b/docker-compose.yml index 242dbd5..2b395b4 100755 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -175,7 +175,7 @@ - ./sql/schema.sql:/docker-entrypoint-initdb.d/schema.sql:ro consul: - image: consul:1.15.4 + image: hashicorp/consul:latest command: [ "agent", "-dev", diff --git a/scripts/consul.json b/scripts/consul.json index c36b4d0..f01ca0a 100755 --- a/scripts/consul.json +++ b/scripts/consul.json @@ -1,4 +1,7 @@ { + "recursors": [ + "127.0.0.11" + ], "http_config": { "response_headers": { "Access-Control-Allow-Origin": "*" diff --git a/utils/Cargo.toml b/utils/Cargo.toml index 90d6284..72dbab9 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -16,3 +16,4 @@ deadpool-redis = "0.20.0" async-trait = "0.1.87" serde_json = "1.0.140" hickory-resolver = "0.24.4" +rand = "0.8.5" diff --git a/utils/src/lib.rs b/utils/src/lib.rs index ff75744..06e1115 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -3,3 +3,4 @@ pub mod null_string; pub mod redis_cache; pub mod service_discovery; pub mod signal_handler; +pub mod multi_service_load_balancer; diff --git a/utils/src/multi_service_load_balancer.rs b/utils/src/multi_service_load_balancer.rs new file mode 100644 index 0000000..81d49aa --- /dev/null +++ b/utils/src/multi_service_load_balancer.rs @@ -0,0 +1,135 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; +use rand::seq::SliceRandom; +use crate::service_discovery::get_service_endpoints_by_dns; + +pub enum LoadBalancingStrategy { + Random, + RoundRobin, +} + +// Service identifier +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct ServiceId { + pub name: String, + pub protocol: String, +} + +impl ServiceId { + pub fn new(name: &str, protocol: &str) -> Self { + ServiceId { + name: name.to_string(), + protocol: protocol.to_string(), + } + } +} + +// Per-service state +struct ServiceState { + endpoints: Vec, + current_index: usize, +} + +impl ServiceState { + fn new(endpoints: Vec) -> Self { + ServiceState { + endpoints, + current_index: 0, + } + } + + fn get_endpoint(&mut self, strategy: &LoadBalancingStrategy) -> Option { + if self.endpoints.is_empty() { + return None; + } + + match strategy { + LoadBalancingStrategy::Random => { + let mut rng = rand::thread_rng(); + self.endpoints.choose(&mut rng).copied() + } + LoadBalancingStrategy::RoundRobin => { + let endpoint = self.endpoints[self.current_index].clone(); + self.current_index = (self.current_index + 1) % self.endpoints.len(); + Some(endpoint) + } + } + } +} + +pub struct MultiServiceLoadBalancer { + consul_url: String, + strategy: LoadBalancingStrategy, + services: Arc>>, +} + +impl MultiServiceLoadBalancer { + pub fn new(consul_url: &str, strategy: LoadBalancingStrategy) -> Self { + MultiServiceLoadBalancer { + consul_url: consul_url.to_string(), + strategy, + services: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub async fn get_endpoint( + &self, + service_name: &str, + service_protocol: &str, + ) -> Result, Box> { + let service_id = ServiceId::new(service_name, service_protocol); + + // Try to get an endpoint from the cache first + { + let mut services = self.services.lock().unwrap(); + if let Some(service_state) = services.get_mut(&service_id) { + if let Some(endpoint) = service_state.get_endpoint(&self.strategy) { + return Ok(Some(endpoint)); + } + } + } + + // If we don't have endpoints or they're all unavailable, refresh them + self.refresh_service_endpoints(service_name, service_protocol).await?; + + // Try again after refresh + let mut services = self.services.lock().unwrap(); + if let Some(service_state) = services.get_mut(&service_id) { + return Ok(service_state.get_endpoint(&self.strategy)); + } + + Ok(None) + } + + pub async fn refresh_service_endpoints( + &self, + service_name: &str, + service_protocol: &str, + ) -> Result<(), Box> { + let endpoints = get_service_endpoints_by_dns( + &self.consul_url, + service_protocol, + service_name, + ).await?; + + let service_id = ServiceId::new(service_name, service_protocol); + let mut services = self.services.lock().unwrap(); + + services.insert(service_id, ServiceState::new(endpoints)); + Ok(()) + } + + pub async fn refresh_all_services(&self) -> Result<(), Box> { + let service_ids: Vec = { + let services = self.services.lock().unwrap(); + services.keys().cloned().collect() + }; + + for service_id in service_ids { + self.refresh_service_endpoints(&service_id.name, &service_id.protocol).await?; + } + + Ok(()) + } +} \ No newline at end of file diff --git a/utils/src/service_discovery.rs b/utils/src/service_discovery.rs index fe44d48..82471e2 100644 --- a/utils/src/service_discovery.rs +++ b/utils/src/service_discovery.rs @@ -3,6 +3,7 @@ use hickory_resolver::{Resolver, TokioAsyncResolver}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::SocketAddr; +use std::str::FromStr; use tokio::runtime::Runtime; use tracing::log::debug; @@ -17,12 +18,12 @@ pub async fn get_service_endpoints_by_dns(consul_url: &str, service_protocol: &s let srv_record = resolver.srv_lookup(&srv_name).await?; let mut endpoints = Vec::new(); - debug!("service records: {:?}", srv_record); for record in srv_record { let hostname = record.target(); - debug!("hostname: {:?}", hostname); - - // endpoints.push(SocketAddr::new(, record.port())); + let lookup_responses = resolver.lookup_ip(hostname.to_string()).await?; + for response in lookup_responses { + endpoints.push(SocketAddr::from_str(&format!("{}:{}", &response.to_string(), record.port()))?); + } } Ok(endpoints)