diff options
Diffstat (limited to 'crates/utils/tcp_connection/src')
| -rw-r--r-- | crates/utils/tcp_connection/src/error.rs | 15 | ||||
| -rw-r--r-- | crates/utils/tcp_connection/src/instance.rs | 146 | ||||
| -rw-r--r-- | crates/utils/tcp_connection/src/target_connection.rs | 9 |
3 files changed, 146 insertions, 24 deletions
diff --git a/crates/utils/tcp_connection/src/error.rs b/crates/utils/tcp_connection/src/error.rs index 171e23d..ffcce6f 100644 --- a/crates/utils/tcp_connection/src/error.rs +++ b/crates/utils/tcp_connection/src/error.rs @@ -32,6 +32,9 @@ pub enum TcpTargetError { #[error("Unsupported operation: {0}")] Unsupported(String), + + #[error("Pool already exists: {0}")] + PoolAlreadyExists(String), } impl From<io::Error> for TcpTargetError { @@ -87,3 +90,15 @@ impl From<pem::PemError> for TcpTargetError { TcpTargetError::Crypto(error.to_string()) } } + +impl From<rmp_serde::encode::Error> for TcpTargetError { + fn from(error: rmp_serde::encode::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} + +impl From<rmp_serde::decode::Error> for TcpTargetError { + fn from(error: rmp_serde::decode::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} diff --git a/crates/utils/tcp_connection/src/instance.rs b/crates/utils/tcp_connection/src/instance.rs index 217b10a..fd620e2 100644 --- a/crates/utils/tcp_connection/src/instance.rs +++ b/crates/utils/tcp_connection/src/instance.rs @@ -48,7 +48,7 @@ impl Default for ConnectionConfig { } pub struct ConnectionInstance { - stream: TcpStream, + pub(crate) stream: TcpStream, config: ConnectionConfig, } @@ -90,6 +90,35 @@ impl ConnectionInstance { Ok(()) } + /// Serialize data to MessagePack and write to the target machine + pub async fn write_msgpack<Data>(&mut self, data: Data) -> Result<(), TcpTargetError> + where + Data: Serialize, + { + let msgpack_data = rmp_serde::to_vec(&data)?; + let len = msgpack_data.len() as u32; + + self.stream.write_all(&len.to_be_bytes()).await?; + self.stream.write_all(&msgpack_data).await?; + Ok(()) + } + + /// Read data from target machine and deserialize from MessagePack + pub async fn read_msgpack<Data>(&mut self) -> Result<Data, TcpTargetError> + where + Data: serde::de::DeserializeOwned, + { + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf) as usize; + + let mut buffer = vec![0; len]; + self.stream.read_exact(&mut buffer).await?; + + let data = rmp_serde::from_slice(&buffer)?; + Ok(data) + } + /// Read data from target machine and deserialize pub async fn read<Data>(&mut self) -> Result<Data, TcpTargetError> where @@ -213,6 +242,73 @@ impl ConnectionInstance { Ok(String::from_utf8_lossy(&buffer).to_string()) } + /// Write large MessagePack data to the target machine (chunked) + pub async fn write_large_msgpack<Data>( + &mut self, + data: Data, + chunk_size: impl Into<u32>, + ) -> Result<(), TcpTargetError> + where + Data: Serialize, + { + let msgpack_data = rmp_serde::to_vec(&data)?; + let chunk_size = chunk_size.into() as usize; + let len = msgpack_data.len() as u32; + + // Write total length first + self.stream.write_all(&len.to_be_bytes()).await?; + + // Write data in chunks + let mut offset = 0; + while offset < msgpack_data.len() { + let end = std::cmp::min(offset + chunk_size, msgpack_data.len()); + let chunk = &msgpack_data[offset..end]; + match self.stream.write(chunk).await { + Ok(n) => offset += n, + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + Ok(()) + } + + /// Read large MessagePack data from the target machine (chunked) + pub async fn read_large_msgpack<Data>( + &mut self, + chunk_size: impl Into<u32>, + ) -> Result<Data, TcpTargetError> + where + Data: serde::de::DeserializeOwned, + { + let chunk_size = chunk_size.into() as usize; + + // Read total length first + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let total_len = u32::from_be_bytes(len_buf) as usize; + + // Read data in chunks + let mut buffer = Vec::with_capacity(total_len); + let mut remaining = total_len; + let mut chunk_buf = vec![0; chunk_size]; + + while remaining > 0 { + let read_size = std::cmp::min(chunk_size, remaining); + let chunk = &mut chunk_buf[..read_size]; + + match self.stream.read_exact(chunk).await { + Ok(_) => { + buffer.extend_from_slice(chunk); + remaining -= read_size; + } + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + let data = rmp_serde::from_slice(&buffer)?; + Ok(data) + } + /// Write file to target machine. pub async fn write_file(&mut self, file_path: impl AsRef<Path>) -> Result<(), TcpTargetError> { let path = file_path.as_ref(); @@ -319,9 +415,10 @@ impl ConnectionInstance { // Make sure parent directory exists if let Some(parent) = path.parent() - && !parent.exists() { - tokio::fs::create_dir_all(parent).await?; - } + && !parent.exists() + { + tokio::fs::create_dir_all(parent).await?; + } // Read file header (version + size + crc) let mut version_buf = [0u8; 8]; @@ -398,15 +495,16 @@ impl ConnectionInstance { // Validate CRC if enabled if self.config.enable_crc_validation - && let Some(crc_calculator) = crc_calculator { - let actual_crc = crc_calculator.finalize(); - if actual_crc != expected_crc && expected_crc != 0 { - return Err(TcpTargetError::File(format!( - "CRC validation failed: expected {:08x}, got {:08x}", - expected_crc, actual_crc - ))); - } + && let Some(crc_calculator) = crc_calculator + { + let actual_crc = crc_calculator.finalize(); + if actual_crc != expected_crc && expected_crc != 0 { + return Err(TcpTargetError::File(format!( + "CRC validation failed: expected {:08x}, got {:08x}", + expected_crc, actual_crc + ))); } + } // Final flush and sync writer.flush().await?; @@ -576,22 +674,26 @@ fn parse_ed25519_public_key(pem: &str) -> [u8; 32] { let mut key_bytes = [0u8; 32]; if let Ok(pem_data) = pem::parse(pem) - && pem_data.tag() == "PUBLIC KEY" && pem_data.contents().len() >= 32 { - let contents = pem_data.contents(); - key_bytes.copy_from_slice(&contents[contents.len() - 32..]); - } + && pem_data.tag() == "PUBLIC KEY" + && pem_data.contents().len() >= 32 + { + let contents = pem_data.contents(); + key_bytes.copy_from_slice(&contents[contents.len() - 32..]); + } key_bytes } /// Parse Ed25519 private key from PEM format fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> { if let Ok(pem_data) = pem::parse(pem) - && pem_data.tag() == "PRIVATE KEY" && pem_data.contents().len() >= 32 { - let contents = pem_data.contents(); - let mut seed = [0u8; 32]; - seed.copy_from_slice(&contents[contents.len() - 32..]); - return Ok(SigningKey::from_bytes(&seed)); - } + && pem_data.tag() == "PRIVATE KEY" + && pem_data.contents().len() >= 32 + { + let contents = pem_data.contents(); + let mut seed = [0u8; 32]; + seed.copy_from_slice(&contents[contents.len() - 32..]); + return Ok(SigningKey::from_bytes(&seed)); + } Err(TcpTargetError::Crypto( "Invalid Ed25519 private key format".to_string(), )) diff --git a/crates/utils/tcp_connection/src/target_connection.rs b/crates/utils/tcp_connection/src/target_connection.rs index 87fd1ab..0462f7b 100644 --- a/crates/utils/tcp_connection/src/target_connection.rs +++ b/crates/utils/tcp_connection/src/target_connection.rs @@ -17,7 +17,10 @@ where Server: ServerHandle<Client>, { /// Attempts to establish a connection to the TCP server. - /// This function initiates a connection to the server address specified in the target configuration. + /// + /// This function initiates a connection to the server address + /// specified in the target configuration. + /// /// This is a Block operation. pub async fn connect(&self) -> Result<(), TcpTargetError> { let addr = self.get_addr(); @@ -37,7 +40,9 @@ where } /// Attempts to establish a connection to the TCP server. - /// This function initiates a connection to the server address specified in the target configuration. + /// + /// This function initiates a connection to the server address + /// specified in the target configuration. pub async fn listen(&self) -> Result<(), TcpTargetError> { let addr = self.get_addr(); let listener = match TcpListener::bind(addr).await { |
