summaryrefslogtreecommitdiff
path: root/utils/src/socket_addr_helper.rs
diff options
context:
space:
mode:
author魏曹先生 <1992414357@qq.com>2026-01-27 06:02:59 +0800
committer魏曹先生 <1992414357@qq.com>2026-01-27 06:02:59 +0800
commit4eef9ce364bb660421a96052a3fb126a33b22c63 (patch)
treea36947411d83205dc743881cd2a30d8c907d4b57 /utils/src/socket_addr_helper.rs
parent243d521fd19af169910506529e737a797e9bc583 (diff)
Extract CLI utilities into a separate crate
Diffstat (limited to 'utils/src/socket_addr_helper.rs')
-rw-r--r--utils/src/socket_addr_helper.rs194
1 files changed, 194 insertions, 0 deletions
diff --git a/utils/src/socket_addr_helper.rs b/utils/src/socket_addr_helper.rs
new file mode 100644
index 0000000..29ccd9f
--- /dev/null
+++ b/utils/src/socket_addr_helper.rs
@@ -0,0 +1,194 @@
+use std::net::SocketAddr;
+use tokio::net::lookup_host;
+
+/// Helper function to parse a string into a SocketAddr with optional default port
+pub async fn get_socket_addr(
+ address_str: impl AsRef<str>,
+ default_port: u16,
+) -> Result<SocketAddr, std::io::Error> {
+ let address = address_str.as_ref().trim();
+
+ // Return error if input is empty after trimming
+ if address.is_empty() {
+ return Err(std::io::Error::new(
+ std::io::ErrorKind::InvalidInput,
+ "Empty address string",
+ ));
+ }
+
+ // Check if the address contains a port
+ if let Some((host, port_str)) = parse_host_and_port(address) {
+ let port = port_str.parse::<u16>().map_err(|e| {
+ std::io::Error::new(
+ std::io::ErrorKind::InvalidInput,
+ format!("Invalid port number '{}': {}", port_str, e),
+ )
+ })?;
+
+ return resolve_to_socket_addr(host, port).await;
+ }
+
+ // No port specified, use default port
+ resolve_to_socket_addr(address, default_port).await
+}
+
+/// Parse host and port from address string
+fn parse_host_and_port(address: &str) -> Option<(&str, &str)> {
+ if address.starts_with('[')
+ && let Some(close_bracket) = address.find(']')
+ && close_bracket + 1 < address.len()
+ && address.as_bytes()[close_bracket + 1] == b':'
+ {
+ let host = &address[1..close_bracket];
+ let port = &address[close_bracket + 2..];
+ return Some((host, port));
+ }
+
+ // Handle IPv4 addresses and hostnames with ports
+ if let Some(colon_pos) = address.rfind(':') {
+ // Check if this is not part of an IPv6 address without brackets
+ if !address.contains('[') && !address.contains(']') {
+ let host = &address[..colon_pos];
+ let port = &address[colon_pos + 1..];
+
+ // Basic validation to avoid false positives
+ if !host.is_empty() && !port.is_empty() && port.chars().all(|c| c.is_ascii_digit()) {
+ return Some((host, port));
+ }
+ }
+ }
+
+ None
+}
+
+/// Resolve host to SocketAddr, handling both IP addresses and domain names
+async fn resolve_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, std::io::Error> {
+ // First try to parse as IP address (IPv4 or IPv6)
+ if let Ok(ip_addr) = host.parse() {
+ return Ok(SocketAddr::new(ip_addr, port));
+ }
+
+ // If it's not a valid IP address, treat it as a domain name and perform DNS lookup
+ let lookup_addr = format!("{}:{}", host, port);
+ let mut addrs = lookup_host(&lookup_addr).await?;
+
+ if let Some(addr) = addrs.next() {
+ Ok(addr)
+ } else {
+ Err(std::io::Error::new(
+ std::io::ErrorKind::NotFound,
+ format!("Could not resolve host '{}'", host),
+ ))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[tokio::test]
+ async fn test_ipv4_with_port() {
+ let result = get_socket_addr("127.0.0.1:8080", 80).await;
+ assert!(result.is_ok());
+ let addr = result.unwrap();
+ assert_eq!(addr.port(), 8080);
+ assert_eq!(addr.ip().to_string(), "127.0.0.1");
+ }
+
+ #[tokio::test]
+ async fn test_ipv4_without_port() {
+ let result = get_socket_addr("192.168.1.1", 443).await;
+ assert!(result.is_ok());
+ let addr = result.unwrap();
+ assert_eq!(addr.port(), 443);
+ assert_eq!(addr.ip().to_string(), "192.168.1.1");
+ }
+
+ #[tokio::test]
+ async fn test_ipv6_with_port() {
+ let result = get_socket_addr("[::1]:8080", 80).await;
+ assert!(result.is_ok());
+ let addr = result.unwrap();
+ assert_eq!(addr.port(), 8080);
+ assert_eq!(addr.ip().to_string(), "::1");
+ }
+
+ #[tokio::test]
+ async fn test_ipv6_without_port() {
+ let result = get_socket_addr("[::1]", 443).await;
+ assert!(result.is_ok());
+ let addr = result.unwrap();
+ assert_eq!(addr.port(), 443);
+ assert_eq!(addr.ip().to_string(), "::1");
+ }
+
+ #[tokio::test]
+ async fn test_invalid_port() {
+ let result = get_socket_addr("127.0.0.1:99999", 80).await;
+ assert!(result.is_err());
+ }
+
+ #[tokio::test]
+ async fn test_empty_string() {
+ let result = get_socket_addr("", 80).await;
+ assert!(result.is_err());
+ }
+
+ #[tokio::test]
+ async fn test_whitespace_trimming() {
+ let result = get_socket_addr(" 127.0.0.1:8080 ", 80).await;
+ assert!(result.is_ok());
+ let addr = result.unwrap();
+ assert_eq!(addr.port(), 8080);
+ }
+
+ #[tokio::test]
+ async fn test_domain_name_with_port() {
+ // This test will only pass if localhost resolves
+ let result = get_socket_addr("localhost:8080", 80).await;
+ if result.is_ok() {
+ let addr = result.unwrap();
+ assert_eq!(addr.port(), 8080);
+ // localhost should resolve to 127.0.0.1 or ::1
+ assert!(addr.ip().is_loopback());
+ }
+ }
+
+ #[tokio::test]
+ async fn test_domain_name_without_port() {
+ // This test will only pass if localhost resolves
+ let result = get_socket_addr("localhost", 443).await;
+ if result.is_ok() {
+ let addr = result.unwrap();
+ assert_eq!(addr.port(), 443);
+ // localhost should resolve to 127.0.0.1 or ::1
+ assert!(addr.ip().is_loopback());
+ }
+ }
+
+ #[tokio::test]
+ async fn test_parse_host_and_port() {
+ // IPv4 with port
+ assert_eq!(
+ parse_host_and_port("192.168.1.1:8080"),
+ Some(("192.168.1.1", "8080"))
+ );
+
+ // IPv6 with port
+ assert_eq!(parse_host_and_port("[::1]:8080"), Some(("::1", "8080")));
+
+ // Hostname with port
+ assert_eq!(
+ parse_host_and_port("example.com:443"),
+ Some(("example.com", "443"))
+ );
+
+ // No port
+ assert_eq!(parse_host_and_port("192.168.1.1"), None);
+ assert_eq!(parse_host_and_port("example.com"), None);
+
+ // Invalid cases
+ assert_eq!(parse_host_and_port(":"), None);
+ assert_eq!(parse_host_and_port("192.168.1.1:"), None);
+ }
+}