use crate::service_discovery::get_service_endpoints_by_dns; use rand::seq::SliceRandom; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; pub enum LoadBalancingStrategy { Random, RoundRobin, } // Service identifier #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ServiceId { pub name: String, pub protocol: String, } impl ServiceId { pub fn new(name: &str, protocol: &str) -> Self { ServiceId { name: name.to_string(), protocol: protocol.to_string(), } } } // Per-service state struct ServiceState { endpoints: Vec, current_index: usize, } impl ServiceState { fn new(endpoints: Vec) -> Self { ServiceState { endpoints, current_index: 0, } } fn get_endpoint(&mut self, strategy: &LoadBalancingStrategy) -> Option { if self.endpoints.is_empty() { return None; } match strategy { LoadBalancingStrategy::Random => { let mut rng = rand::thread_rng(); self.endpoints.choose(&mut rng).copied() } LoadBalancingStrategy::RoundRobin => { let endpoint = self.endpoints[self.current_index].clone(); self.current_index = (self.current_index + 1) % self.endpoints.len(); Some(endpoint) } } } } pub struct MultiServiceLoadBalancer { consul_url: String, strategy: LoadBalancingStrategy, services: Arc>>, } impl MultiServiceLoadBalancer { pub fn new(consul_url: &str, strategy: LoadBalancingStrategy) -> Self { MultiServiceLoadBalancer { consul_url: consul_url.to_string(), strategy, services: Arc::new(Mutex::new(HashMap::new())), } } pub async fn get_endpoint( &self, service_name: &str, service_protocol: &str, ) -> Result, Box> { let service_id = ServiceId::new(service_name, service_protocol); // Try to get an endpoint from the cache first { let mut services = self.services.lock().unwrap(); if let Some(service_state) = services.get_mut(&service_id) { if let Some(endpoint) = service_state.get_endpoint(&self.strategy) { return Ok(Some(endpoint)); } } } // If we don't have endpoints or they're all unavailable, refresh them self.refresh_service_endpoints(service_name, service_protocol).await?; // Try again after refresh let mut services = self.services.lock().unwrap(); if let Some(service_state) = services.get_mut(&service_id) { return Ok(service_state.get_endpoint(&self.strategy)); } Ok(None) } pub async fn refresh_service_endpoints( &self, service_name: &str, service_protocol: &str, ) -> Result<(), Box> { let endpoints = get_service_endpoints_by_dns(&self.consul_url, service_protocol, service_name).await?; let service_id = ServiceId::new(service_name, service_protocol); let mut services = self.services.lock().unwrap(); services.insert(service_id, ServiceState::new(endpoints)); Ok(()) } pub async fn refresh_all_services(&self) -> Result<(), Box> { let service_ids: Vec = { let services = self.services.lock().unwrap(); services.keys().cloned().collect() }; for service_id in service_ids { self.refresh_service_endpoints(&service_id.name, &service_id.protocol) .await?; } Ok(()) } }