use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use utils::multi_service_load_balancer::{LoadBalancingStrategy, MultiServiceLoadBalancer, ServiceId}; // Mock implementation for testing without actual service discovery mod mock { use super::*; use std::collections::HashMap; use std::sync::{Arc, Mutex}; // Mock version of the load balancer for testing pub struct MockMultiServiceLoadBalancer { strategy: LoadBalancingStrategy, services: Arc>>>, } impl MockMultiServiceLoadBalancer { pub fn new(strategy: LoadBalancingStrategy) -> Self { MockMultiServiceLoadBalancer { strategy, services: Arc::new(Mutex::new(HashMap::new())), } } pub fn add_service(&self, service_name: &str, service_protocol: &str, endpoints: Vec) { let service_id = ServiceId::new(service_name, service_protocol); let mut services = self.services.lock().unwrap(); services.insert(service_id, endpoints); } pub fn get_endpoint(&self, service_name: &str, service_protocol: &str) -> Option { let service_id = ServiceId::new(service_name, service_protocol); let services = self.services.lock().unwrap(); if let Some(endpoints) = services.get(&service_id) { if endpoints.is_empty() { return None; } match self.strategy { LoadBalancingStrategy::Random => { let index = rand::random::() % endpoints.len(); Some(endpoints[index]) }, LoadBalancingStrategy::RoundRobin => { // For simplicity in tests, just return the first endpoint Some(endpoints[0]) } } } else { None } } } } #[test] fn test_service_id() { let service_id1 = ServiceId::new("service1", "http"); let service_id2 = ServiceId::new("service1", "http"); let service_id3 = ServiceId::new("service2", "http"); let service_id4 = ServiceId::new("service1", "https"); // Test equality assert_eq!(service_id1, service_id2); assert_ne!(service_id1, service_id3); assert_ne!(service_id1, service_id4); // Test hash implementation let mut set = HashSet::new(); set.insert(service_id1); assert!(set.contains(&service_id2)); assert!(!set.contains(&service_id3)); assert!(!set.contains(&service_id4)); } #[test] fn test_mock_load_balancer_random() { let lb = mock::MockMultiServiceLoadBalancer::new(LoadBalancingStrategy::Random); // Add a service with multiple endpoints let endpoints = vec![ SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)), 8080), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 3)), 8080), ]; lb.add_service("test-service", "http", endpoints.clone()); // Get an endpoint let endpoint = lb.get_endpoint("test-service", "http"); assert!(endpoint.is_some()); assert!(endpoints.contains(&endpoint.unwrap())); // Test non-existent service let endpoint = lb.get_endpoint("non-existent", "http"); assert!(endpoint.is_none()); } #[test] fn test_mock_load_balancer_round_robin() { let lb = mock::MockMultiServiceLoadBalancer::new(LoadBalancingStrategy::RoundRobin); // Add a service with multiple endpoints let endpoints = vec![ SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)), 8080), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 3)), 8080), ]; lb.add_service("test-service", "http", endpoints); // Get an endpoint let endpoint = lb.get_endpoint("test-service", "http"); assert!(endpoint.is_some()); // Test empty service lb.add_service("empty-service", "http", vec![]); let endpoint = lb.get_endpoint("empty-service", "http"); assert!(endpoint.is_none()); } // Integration test with the actual MultiServiceLoadBalancer // This test is disabled by default as it requires a Consul server #[tokio::test] async fn test_multi_service_load_balancer() { use std::env; // Skip test if CONSUL_TEST_ENABLED is not set to true if env::var("CONSUL_TEST_ENABLED").unwrap_or_else(|_| "false".to_string()) != "true" { println!("Skipping MultiServiceLoadBalancer test. Set CONSUL_TEST_ENABLED=true to run."); return; } let consul_url = env::var("TEST_CONSUL_URL").unwrap_or_else(|_| "http://localhost:8500".to_string()); let service_name = env::var("TEST_CONSUL_SERVICE_NAME").unwrap_or_else(|_| "database-service".to_string()); let protocol = "tcp"; let lb = MultiServiceLoadBalancer::new(&consul_url, LoadBalancingStrategy::Random); // Refresh service endpoints let result = lb.refresh_service_endpoints(&service_name, protocol).await; assert!(result.is_ok(), "Failed to refresh service endpoints: {:?}", result.err()); // Get an endpoint let result = lb.get_endpoint(&service_name, protocol).await; assert!(result.is_ok(), "Failed to get endpoint: {:?}", result.err()); let endpoint = result.unwrap(); assert!(endpoint.is_some(), "No endpoint found for service {}", service_name); println!("Found endpoint for service {}: {:?}", service_name, endpoint); }