summaryrefslogtreecommitdiff
path: root/crates/utils/tcp_connection/src
diff options
context:
space:
mode:
Diffstat (limited to 'crates/utils/tcp_connection/src')
-rw-r--r--crates/utils/tcp_connection/src/error.rs15
-rw-r--r--crates/utils/tcp_connection/src/handle.rs10
-rw-r--r--crates/utils/tcp_connection/src/instance.rs146
-rw-r--r--crates/utils/tcp_connection/src/lib.rs6
-rw-r--r--crates/utils/tcp_connection/src/target.rs198
-rw-r--r--crates/utils/tcp_connection/src/target_configure.rs53
-rw-r--r--crates/utils/tcp_connection/src/target_connection.rs85
7 files changed, 139 insertions, 374 deletions
diff --git a/crates/utils/tcp_connection/src/error.rs b/crates/utils/tcp_connection/src/error.rs
index 171e23d..ffcce6f 100644
--- a/crates/utils/tcp_connection/src/error.rs
+++ b/crates/utils/tcp_connection/src/error.rs
@@ -32,6 +32,9 @@ pub enum TcpTargetError {
#[error("Unsupported operation: {0}")]
Unsupported(String),
+
+ #[error("Pool already exists: {0}")]
+ PoolAlreadyExists(String),
}
impl From<io::Error> for TcpTargetError {
@@ -87,3 +90,15 @@ impl From<pem::PemError> for TcpTargetError {
TcpTargetError::Crypto(error.to_string())
}
}
+
+impl From<rmp_serde::encode::Error> for TcpTargetError {
+ fn from(error: rmp_serde::encode::Error) -> Self {
+ TcpTargetError::Serialization(error.to_string())
+ }
+}
+
+impl From<rmp_serde::decode::Error> for TcpTargetError {
+ fn from(error: rmp_serde::decode::Error) -> Self {
+ TcpTargetError::Serialization(error.to_string())
+ }
+}
diff --git a/crates/utils/tcp_connection/src/handle.rs b/crates/utils/tcp_connection/src/handle.rs
deleted file mode 100644
index ee77b43..0000000
--- a/crates/utils/tcp_connection/src/handle.rs
+++ /dev/null
@@ -1,10 +0,0 @@
-use crate::instance::ConnectionInstance;
-use std::future::Future;
-
-pub trait ClientHandle<RequestServer> {
- fn process(instance: ConnectionInstance) -> impl Future<Output = ()> + Send;
-}
-
-pub trait ServerHandle<RequestClient> {
- fn process(instance: ConnectionInstance) -> impl Future<Output = ()> + Send;
-}
diff --git a/crates/utils/tcp_connection/src/instance.rs b/crates/utils/tcp_connection/src/instance.rs
index 217b10a..fd620e2 100644
--- a/crates/utils/tcp_connection/src/instance.rs
+++ b/crates/utils/tcp_connection/src/instance.rs
@@ -48,7 +48,7 @@ impl Default for ConnectionConfig {
}
pub struct ConnectionInstance {
- stream: TcpStream,
+ pub(crate) stream: TcpStream,
config: ConnectionConfig,
}
@@ -90,6 +90,35 @@ impl ConnectionInstance {
Ok(())
}
+ /// Serialize data to MessagePack and write to the target machine
+ pub async fn write_msgpack<Data>(&mut self, data: Data) -> Result<(), TcpTargetError>
+ where
+ Data: Serialize,
+ {
+ let msgpack_data = rmp_serde::to_vec(&data)?;
+ let len = msgpack_data.len() as u32;
+
+ self.stream.write_all(&len.to_be_bytes()).await?;
+ self.stream.write_all(&msgpack_data).await?;
+ Ok(())
+ }
+
+ /// Read data from target machine and deserialize from MessagePack
+ pub async fn read_msgpack<Data>(&mut self) -> Result<Data, TcpTargetError>
+ where
+ Data: serde::de::DeserializeOwned,
+ {
+ let mut len_buf = [0u8; 4];
+ self.stream.read_exact(&mut len_buf).await?;
+ let len = u32::from_be_bytes(len_buf) as usize;
+
+ let mut buffer = vec![0; len];
+ self.stream.read_exact(&mut buffer).await?;
+
+ let data = rmp_serde::from_slice(&buffer)?;
+ Ok(data)
+ }
+
/// Read data from target machine and deserialize
pub async fn read<Data>(&mut self) -> Result<Data, TcpTargetError>
where
@@ -213,6 +242,73 @@ impl ConnectionInstance {
Ok(String::from_utf8_lossy(&buffer).to_string())
}
+ /// Write large MessagePack data to the target machine (chunked)
+ pub async fn write_large_msgpack<Data>(
+ &mut self,
+ data: Data,
+ chunk_size: impl Into<u32>,
+ ) -> Result<(), TcpTargetError>
+ where
+ Data: Serialize,
+ {
+ let msgpack_data = rmp_serde::to_vec(&data)?;
+ let chunk_size = chunk_size.into() as usize;
+ let len = msgpack_data.len() as u32;
+
+ // Write total length first
+ self.stream.write_all(&len.to_be_bytes()).await?;
+
+ // Write data in chunks
+ let mut offset = 0;
+ while offset < msgpack_data.len() {
+ let end = std::cmp::min(offset + chunk_size, msgpack_data.len());
+ let chunk = &msgpack_data[offset..end];
+ match self.stream.write(chunk).await {
+ Ok(n) => offset += n,
+ Err(err) => return Err(TcpTargetError::Io(err.to_string())),
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Read large MessagePack data from the target machine (chunked)
+ pub async fn read_large_msgpack<Data>(
+ &mut self,
+ chunk_size: impl Into<u32>,
+ ) -> Result<Data, TcpTargetError>
+ where
+ Data: serde::de::DeserializeOwned,
+ {
+ let chunk_size = chunk_size.into() as usize;
+
+ // Read total length first
+ let mut len_buf = [0u8; 4];
+ self.stream.read_exact(&mut len_buf).await?;
+ let total_len = u32::from_be_bytes(len_buf) as usize;
+
+ // Read data in chunks
+ let mut buffer = Vec::with_capacity(total_len);
+ let mut remaining = total_len;
+ let mut chunk_buf = vec![0; chunk_size];
+
+ while remaining > 0 {
+ let read_size = std::cmp::min(chunk_size, remaining);
+ let chunk = &mut chunk_buf[..read_size];
+
+ match self.stream.read_exact(chunk).await {
+ Ok(_) => {
+ buffer.extend_from_slice(chunk);
+ remaining -= read_size;
+ }
+ Err(err) => return Err(TcpTargetError::Io(err.to_string())),
+ }
+ }
+
+ let data = rmp_serde::from_slice(&buffer)?;
+ Ok(data)
+ }
+
/// Write file to target machine.
pub async fn write_file(&mut self, file_path: impl AsRef<Path>) -> Result<(), TcpTargetError> {
let path = file_path.as_ref();
@@ -319,9 +415,10 @@ impl ConnectionInstance {
// Make sure parent directory exists
if let Some(parent) = path.parent()
- && !parent.exists() {
- tokio::fs::create_dir_all(parent).await?;
- }
+ && !parent.exists()
+ {
+ tokio::fs::create_dir_all(parent).await?;
+ }
// Read file header (version + size + crc)
let mut version_buf = [0u8; 8];
@@ -398,15 +495,16 @@ impl ConnectionInstance {
// Validate CRC if enabled
if self.config.enable_crc_validation
- && let Some(crc_calculator) = crc_calculator {
- let actual_crc = crc_calculator.finalize();
- if actual_crc != expected_crc && expected_crc != 0 {
- return Err(TcpTargetError::File(format!(
- "CRC validation failed: expected {:08x}, got {:08x}",
- expected_crc, actual_crc
- )));
- }
+ && let Some(crc_calculator) = crc_calculator
+ {
+ let actual_crc = crc_calculator.finalize();
+ if actual_crc != expected_crc && expected_crc != 0 {
+ return Err(TcpTargetError::File(format!(
+ "CRC validation failed: expected {:08x}, got {:08x}",
+ expected_crc, actual_crc
+ )));
}
+ }
// Final flush and sync
writer.flush().await?;
@@ -576,22 +674,26 @@ fn parse_ed25519_public_key(pem: &str) -> [u8; 32] {
let mut key_bytes = [0u8; 32];
if let Ok(pem_data) = pem::parse(pem)
- && pem_data.tag() == "PUBLIC KEY" && pem_data.contents().len() >= 32 {
- let contents = pem_data.contents();
- key_bytes.copy_from_slice(&contents[contents.len() - 32..]);
- }
+ && pem_data.tag() == "PUBLIC KEY"
+ && pem_data.contents().len() >= 32
+ {
+ let contents = pem_data.contents();
+ key_bytes.copy_from_slice(&contents[contents.len() - 32..]);
+ }
key_bytes
}
/// Parse Ed25519 private key from PEM format
fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> {
if let Ok(pem_data) = pem::parse(pem)
- && pem_data.tag() == "PRIVATE KEY" && pem_data.contents().len() >= 32 {
- let contents = pem_data.contents();
- let mut seed = [0u8; 32];
- seed.copy_from_slice(&contents[contents.len() - 32..]);
- return Ok(SigningKey::from_bytes(&seed));
- }
+ && pem_data.tag() == "PRIVATE KEY"
+ && pem_data.contents().len() >= 32
+ {
+ let contents = pem_data.contents();
+ let mut seed = [0u8; 32];
+ seed.copy_from_slice(&contents[contents.len() - 32..]);
+ return Ok(SigningKey::from_bytes(&seed));
+ }
Err(TcpTargetError::Crypto(
"Invalid Ed25519 private key format".to_string(),
))
diff --git a/crates/utils/tcp_connection/src/lib.rs b/crates/utils/tcp_connection/src/lib.rs
index 928457b..a5b5c20 100644
--- a/crates/utils/tcp_connection/src/lib.rs
+++ b/crates/utils/tcp_connection/src/lib.rs
@@ -1,10 +1,4 @@
-pub mod target;
-pub mod target_configure;
-pub mod target_connection;
-
#[allow(dead_code)]
pub mod instance;
-pub mod handle;
-
pub mod error;
diff --git a/crates/utils/tcp_connection/src/target.rs b/crates/utils/tcp_connection/src/target.rs
deleted file mode 100644
index 88b931a..0000000
--- a/crates/utils/tcp_connection/src/target.rs
+++ /dev/null
@@ -1,198 +0,0 @@
-use crate::handle::{ClientHandle, ServerHandle};
-use crate::target_configure::{ClientTargetConfig, ServerTargetConfig};
-use serde::{Deserialize, Serialize};
-use std::{
- fmt::{Display, Formatter},
- marker::PhantomData,
- net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr},
- str::FromStr,
-};
-use tokio::net::lookup_host;
-
-const DEFAULT_PORT: u16 = 8080;
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct TcpServerTarget<Client, Server>
-where
- Client: ClientHandle<Server>,
- Server: ServerHandle<Client>,
-{
- /// Client Config
- client_cfg: Option<ClientTargetConfig>,
-
- /// Server Config
- server_cfg: Option<ServerTargetConfig>,
-
- /// Server port
- port: u16,
-
- /// Bind addr
- bind_addr: IpAddr,
-
- /// Client Phantom Data
- _client: PhantomData<Client>,
-
- /// Server Phantom Data
- _server: PhantomData<Server>,
-}
-
-impl<Client, Server> Default for TcpServerTarget<Client, Server>
-where
- Client: ClientHandle<Server>,
- Server: ServerHandle<Client>,
-{
- fn default() -> Self {
- Self {
- client_cfg: None,
- server_cfg: None,
- port: DEFAULT_PORT,
- bind_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
- _client: PhantomData,
- _server: PhantomData,
- }
- }
-}
-
-impl<Client, Server> From<SocketAddr> for TcpServerTarget<Client, Server>
-where
- Client: ClientHandle<Server>,
- Server: ServerHandle<Client>,
-{
- /// Convert SocketAddr to TcpServerTarget
- fn from(value: SocketAddr) -> Self {
- Self {
- port: value.port(),
- bind_addr: value.ip(),
- ..Self::default()
- }
- }
-}
-
-impl<Client, Server> From<TcpServerTarget<Client, Server>> for SocketAddr
-where
- Client: ClientHandle<Server>,
- Server: ServerHandle<Client>,
-{
- /// Convert TcpServerTarget to SocketAddr
- fn from(val: TcpServerTarget<Client, Server>) -> Self {
- SocketAddr::new(val.bind_addr, val.port)
- }
-}
-
-impl<Client, Server> Display for TcpServerTarget<Client, Server>
-where
- Client: ClientHandle<Server>,
- Server: ServerHandle<Client>,
-{
- fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}:{}", self.bind_addr, self.port)
- }
-}
-
-impl<Client, Server> TcpServerTarget<Client, Server>
-where
- Client: ClientHandle<Server>,
- Server: ServerHandle<Client>,
-{
- /// Create target by address
- pub fn from_addr(addr: impl Into<IpAddr>, port: impl Into<u16>) -> Self {
- Self {
- port: port.into(),
- bind_addr: addr.into(),
- ..Self::default()
- }
- }
-
- /// Try to create target by string
- pub fn from_address_str<'a>(addr_str: impl Into<&'a str>) -> Result<Self, AddrParseError> {
- let socket_addr = SocketAddr::from_str(addr_str.into());
- match socket_addr {
- Ok(socket_addr) => Ok(Self::from_addr(socket_addr.ip(), socket_addr.port())),
- Err(err) => Err(err),
- }
- }
-
- /// Try to create target by domain name
- pub async fn from_domain<'a>(domain: impl Into<&'a str>) -> Result<Self, std::io::Error> {
- match domain_to_addr(domain).await {
- Ok(domain_addr) => Ok(Self::from(domain_addr)),
- Err(e) => Err(e),
- }
- }
-
- /// Set client config
- pub fn client_cfg(mut self, config: ClientTargetConfig) -> Self {
- self.client_cfg = Some(config);
- self
- }
-
- /// Set server config
- pub fn server_cfg(mut self, config: ServerTargetConfig) -> Self {
- self.server_cfg = Some(config);
- self
- }
-
- /// Add client config
- pub fn add_client_cfg(&mut self, config: ClientTargetConfig) {
- self.client_cfg = Some(config);
- }
-
- /// Add server config
- pub fn add_server_cfg(&mut self, config: ServerTargetConfig) {
- self.server_cfg = Some(config);
- }
-
- /// Get client config ref
- pub fn get_client_cfg(&self) -> Option<&ClientTargetConfig> {
- self.client_cfg.as_ref()
- }
-
- /// Get server config ref
- pub fn get_server_cfg(&self) -> Option<&ServerTargetConfig> {
- self.server_cfg.as_ref()
- }
-
- /// Get SocketAddr of TcpServerTarget
- pub fn get_addr(&self) -> SocketAddr {
- SocketAddr::new(self.bind_addr, self.port)
- }
-}
-
-/// Parse Domain Name to IpAddr via DNS
-async fn domain_to_addr<'a>(domain: impl Into<&'a str>) -> Result<SocketAddr, std::io::Error> {
- let domain = domain.into();
- let default_port: u16 = DEFAULT_PORT;
-
- if let Ok(socket_addr) = domain.parse::<SocketAddr>() {
- return Ok(match socket_addr.ip() {
- IpAddr::V4(_) => socket_addr,
- IpAddr::V6(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), socket_addr.port()),
- });
- }
-
- if let Ok(_v6_addr) = domain.parse::<std::net::Ipv6Addr>() {
- return Ok(SocketAddr::new(
- IpAddr::V4(Ipv4Addr::LOCALHOST),
- default_port,
- ));
- }
-
- let (host, port_str) = if let Some((host, port)) = domain.rsplit_once(':') {
- (host.trim_matches(|c| c == '[' || c == ']'), Some(port))
- } else {
- (domain, None)
- };
-
- let port = port_str
- .and_then(|p| p.parse::<u16>().ok())
- .map(|p| p.clamp(0, u16::MAX))
- .unwrap_or(default_port);
-
- let mut socket_iter = lookup_host((host, 0)).await?;
-
- if let Some(addr) = socket_iter.find(|addr| addr.is_ipv4()) {
- return Ok(SocketAddr::new(addr.ip(), port));
- }
-
- Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port))
-}
diff --git a/crates/utils/tcp_connection/src/target_configure.rs b/crates/utils/tcp_connection/src/target_configure.rs
deleted file mode 100644
index d739ac9..0000000
--- a/crates/utils/tcp_connection/src/target_configure.rs
+++ /dev/null
@@ -1,53 +0,0 @@
-use serde::{Deserialize, Serialize};
-
-#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct ServerTargetConfig {
- /// Only process a single connection, then shut down the server.
- once: bool,
-
- /// Timeout duration in milliseconds. (0 is Closed)
- timeout: u64,
-}
-
-impl ServerTargetConfig {
- /// Set `once` to True
- /// This method configures the `once` field of `ServerTargetConfig`.
- pub fn once(mut self) -> Self {
- self.once = true;
- self
- }
-
- /// Set `timeout` to the given value
- /// This method configures the `timeout` field of `ServerTargetConfig`.
- pub fn timeout(mut self, timeout: u64) -> Self {
- self.timeout = timeout;
- self
- }
-
- /// Set `once` to the given value
- /// This method configures the `once` field of `ServerTargetConfig`.
- pub fn set_once(&mut self, enable: bool) {
- self.once = enable;
- }
-
- /// Set `timeout` to the given value
- /// This method configures the `timeout` field of `ServerTargetConfig`.
- pub fn set_timeout(&mut self, timeout: u64) {
- self.timeout = timeout;
- }
-
- /// Check if the server is configured to process only a single connection.
- /// Returns `true` if the server will shut down after processing one connection.
- pub fn is_once(&self) -> bool {
- self.once
- }
-
- /// Get the current timeout value in milliseconds.
- /// Returns the timeout duration. A value of 0 indicates the connection is closed.
- pub fn get_timeout(&self) -> u64 {
- self.timeout
- }
-}
-
-#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct ClientTargetConfig {}
diff --git a/crates/utils/tcp_connection/src/target_connection.rs b/crates/utils/tcp_connection/src/target_connection.rs
deleted file mode 100644
index 87fd1ab..0000000
--- a/crates/utils/tcp_connection/src/target_connection.rs
+++ /dev/null
@@ -1,85 +0,0 @@
-use tokio::{
- net::{TcpListener, TcpSocket},
- spawn,
-};
-
-use crate::{
- error::TcpTargetError,
- handle::{ClientHandle, ServerHandle},
- instance::ConnectionInstance,
- target::TcpServerTarget,
- target_configure::ServerTargetConfig,
-};
-
-impl<Client, Server> TcpServerTarget<Client, Server>
-where
- Client: ClientHandle<Server>,
- Server: ServerHandle<Client>,
-{
- /// Attempts to establish a connection to the TCP server.
- /// This function initiates a connection to the server address specified in the target configuration.
- /// This is a Block operation.
- pub async fn connect(&self) -> Result<(), TcpTargetError> {
- let addr = self.get_addr();
- let Ok(socket) = TcpSocket::new_v4() else {
- return Err(TcpTargetError::from("Create tcp socket failed!"));
- };
- let stream = match socket.connect(addr).await {
- Ok(stream) => stream,
- Err(e) => {
- let err = format!("Connect to `{}` failed: {}", addr, e);
- return Err(TcpTargetError::from(err));
- }
- };
- let instance = ConnectionInstance::from(stream);
- Client::process(instance).await;
- Ok(())
- }
-
- /// Attempts to establish a connection to the TCP server.
- /// This function initiates a connection to the server address specified in the target configuration.
- pub async fn listen(&self) -> Result<(), TcpTargetError> {
- let addr = self.get_addr();
- let listener = match TcpListener::bind(addr).await {
- Ok(listener) => listener,
- Err(_) => {
- let err = format!("Bind to `{}` failed", addr);
- return Err(TcpTargetError::from(err));
- }
- };
-
- let cfg: ServerTargetConfig = match self.get_server_cfg() {
- Some(cfg) => *cfg,
- None => ServerTargetConfig::default(),
- };
-
- if cfg.is_once() {
- // Process once (Blocked)
- let (stream, _) = match listener.accept().await {
- Ok(result) => result,
- Err(e) => {
- let err = format!("Accept connection failed: {}", e);
- return Err(TcpTargetError::from(err));
- }
- };
- let instance = ConnectionInstance::from(stream);
- Server::process(instance).await;
- } else {
- loop {
- // Process multiple times (Concurrent)
- let (stream, _) = match listener.accept().await {
- Ok(result) => result,
- Err(e) => {
- let err = format!("Accept connection failed: {}", e);
- return Err(TcpTargetError::from(err));
- }
- };
- let instance = ConnectionInstance::from(stream);
- spawn(async move {
- Server::process(instance).await;
- });
- }
- }
- Ok(())
- }
-}