summaryrefslogtreecommitdiff
path: root/crates/utils/tcp_connection
diff options
context:
space:
mode:
Diffstat (limited to 'crates/utils/tcp_connection')
-rw-r--r--crates/utils/tcp_connection/Cargo.toml1
-rw-r--r--crates/utils/tcp_connection/src/error.rs15
-rw-r--r--crates/utils/tcp_connection/src/instance.rs146
-rw-r--r--crates/utils/tcp_connection/src/target_connection.rs9
-rw-r--r--crates/utils/tcp_connection/tcp_connection_test/Cargo.toml1
-rw-r--r--crates/utils/tcp_connection/tcp_connection_test/src/lib.rs3
-rw-r--r--crates/utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs111
7 files changed, 262 insertions, 24 deletions
diff --git a/crates/utils/tcp_connection/Cargo.toml b/crates/utils/tcp_connection/Cargo.toml
index e70baf0..22466c8 100644
--- a/crates/utils/tcp_connection/Cargo.toml
+++ b/crates/utils/tcp_connection/Cargo.toml
@@ -9,6 +9,7 @@ tokio = { version = "1.46.1", features = ["full"] }
# Serialization
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
+rmp-serde = "1.3.0"
# Error handling
thiserror = "1.0.69"
diff --git a/crates/utils/tcp_connection/src/error.rs b/crates/utils/tcp_connection/src/error.rs
index 171e23d..ffcce6f 100644
--- a/crates/utils/tcp_connection/src/error.rs
+++ b/crates/utils/tcp_connection/src/error.rs
@@ -32,6 +32,9 @@ pub enum TcpTargetError {
#[error("Unsupported operation: {0}")]
Unsupported(String),
+
+ #[error("Pool already exists: {0}")]
+ PoolAlreadyExists(String),
}
impl From<io::Error> for TcpTargetError {
@@ -87,3 +90,15 @@ impl From<pem::PemError> for TcpTargetError {
TcpTargetError::Crypto(error.to_string())
}
}
+
+impl From<rmp_serde::encode::Error> for TcpTargetError {
+ fn from(error: rmp_serde::encode::Error) -> Self {
+ TcpTargetError::Serialization(error.to_string())
+ }
+}
+
+impl From<rmp_serde::decode::Error> for TcpTargetError {
+ fn from(error: rmp_serde::decode::Error) -> Self {
+ TcpTargetError::Serialization(error.to_string())
+ }
+}
diff --git a/crates/utils/tcp_connection/src/instance.rs b/crates/utils/tcp_connection/src/instance.rs
index 217b10a..fd620e2 100644
--- a/crates/utils/tcp_connection/src/instance.rs
+++ b/crates/utils/tcp_connection/src/instance.rs
@@ -48,7 +48,7 @@ impl Default for ConnectionConfig {
}
pub struct ConnectionInstance {
- stream: TcpStream,
+ pub(crate) stream: TcpStream,
config: ConnectionConfig,
}
@@ -90,6 +90,35 @@ impl ConnectionInstance {
Ok(())
}
+ /// Serialize data to MessagePack and write to the target machine
+ pub async fn write_msgpack<Data>(&mut self, data: Data) -> Result<(), TcpTargetError>
+ where
+ Data: Serialize,
+ {
+ let msgpack_data = rmp_serde::to_vec(&data)?;
+ let len = msgpack_data.len() as u32;
+
+ self.stream.write_all(&len.to_be_bytes()).await?;
+ self.stream.write_all(&msgpack_data).await?;
+ Ok(())
+ }
+
+ /// Read data from target machine and deserialize from MessagePack
+ pub async fn read_msgpack<Data>(&mut self) -> Result<Data, TcpTargetError>
+ where
+ Data: serde::de::DeserializeOwned,
+ {
+ let mut len_buf = [0u8; 4];
+ self.stream.read_exact(&mut len_buf).await?;
+ let len = u32::from_be_bytes(len_buf) as usize;
+
+ let mut buffer = vec![0; len];
+ self.stream.read_exact(&mut buffer).await?;
+
+ let data = rmp_serde::from_slice(&buffer)?;
+ Ok(data)
+ }
+
/// Read data from target machine and deserialize
pub async fn read<Data>(&mut self) -> Result<Data, TcpTargetError>
where
@@ -213,6 +242,73 @@ impl ConnectionInstance {
Ok(String::from_utf8_lossy(&buffer).to_string())
}
+ /// Write large MessagePack data to the target machine (chunked)
+ pub async fn write_large_msgpack<Data>(
+ &mut self,
+ data: Data,
+ chunk_size: impl Into<u32>,
+ ) -> Result<(), TcpTargetError>
+ where
+ Data: Serialize,
+ {
+ let msgpack_data = rmp_serde::to_vec(&data)?;
+ let chunk_size = chunk_size.into() as usize;
+ let len = msgpack_data.len() as u32;
+
+ // Write total length first
+ self.stream.write_all(&len.to_be_bytes()).await?;
+
+ // Write data in chunks
+ let mut offset = 0;
+ while offset < msgpack_data.len() {
+ let end = std::cmp::min(offset + chunk_size, msgpack_data.len());
+ let chunk = &msgpack_data[offset..end];
+ match self.stream.write(chunk).await {
+ Ok(n) => offset += n,
+ Err(err) => return Err(TcpTargetError::Io(err.to_string())),
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Read large MessagePack data from the target machine (chunked)
+ pub async fn read_large_msgpack<Data>(
+ &mut self,
+ chunk_size: impl Into<u32>,
+ ) -> Result<Data, TcpTargetError>
+ where
+ Data: serde::de::DeserializeOwned,
+ {
+ let chunk_size = chunk_size.into() as usize;
+
+ // Read total length first
+ let mut len_buf = [0u8; 4];
+ self.stream.read_exact(&mut len_buf).await?;
+ let total_len = u32::from_be_bytes(len_buf) as usize;
+
+ // Read data in chunks
+ let mut buffer = Vec::with_capacity(total_len);
+ let mut remaining = total_len;
+ let mut chunk_buf = vec![0; chunk_size];
+
+ while remaining > 0 {
+ let read_size = std::cmp::min(chunk_size, remaining);
+ let chunk = &mut chunk_buf[..read_size];
+
+ match self.stream.read_exact(chunk).await {
+ Ok(_) => {
+ buffer.extend_from_slice(chunk);
+ remaining -= read_size;
+ }
+ Err(err) => return Err(TcpTargetError::Io(err.to_string())),
+ }
+ }
+
+ let data = rmp_serde::from_slice(&buffer)?;
+ Ok(data)
+ }
+
/// Write file to target machine.
pub async fn write_file(&mut self, file_path: impl AsRef<Path>) -> Result<(), TcpTargetError> {
let path = file_path.as_ref();
@@ -319,9 +415,10 @@ impl ConnectionInstance {
// Make sure parent directory exists
if let Some(parent) = path.parent()
- && !parent.exists() {
- tokio::fs::create_dir_all(parent).await?;
- }
+ && !parent.exists()
+ {
+ tokio::fs::create_dir_all(parent).await?;
+ }
// Read file header (version + size + crc)
let mut version_buf = [0u8; 8];
@@ -398,15 +495,16 @@ impl ConnectionInstance {
// Validate CRC if enabled
if self.config.enable_crc_validation
- && let Some(crc_calculator) = crc_calculator {
- let actual_crc = crc_calculator.finalize();
- if actual_crc != expected_crc && expected_crc != 0 {
- return Err(TcpTargetError::File(format!(
- "CRC validation failed: expected {:08x}, got {:08x}",
- expected_crc, actual_crc
- )));
- }
+ && let Some(crc_calculator) = crc_calculator
+ {
+ let actual_crc = crc_calculator.finalize();
+ if actual_crc != expected_crc && expected_crc != 0 {
+ return Err(TcpTargetError::File(format!(
+ "CRC validation failed: expected {:08x}, got {:08x}",
+ expected_crc, actual_crc
+ )));
}
+ }
// Final flush and sync
writer.flush().await?;
@@ -576,22 +674,26 @@ fn parse_ed25519_public_key(pem: &str) -> [u8; 32] {
let mut key_bytes = [0u8; 32];
if let Ok(pem_data) = pem::parse(pem)
- && pem_data.tag() == "PUBLIC KEY" && pem_data.contents().len() >= 32 {
- let contents = pem_data.contents();
- key_bytes.copy_from_slice(&contents[contents.len() - 32..]);
- }
+ && pem_data.tag() == "PUBLIC KEY"
+ && pem_data.contents().len() >= 32
+ {
+ let contents = pem_data.contents();
+ key_bytes.copy_from_slice(&contents[contents.len() - 32..]);
+ }
key_bytes
}
/// Parse Ed25519 private key from PEM format
fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> {
if let Ok(pem_data) = pem::parse(pem)
- && pem_data.tag() == "PRIVATE KEY" && pem_data.contents().len() >= 32 {
- let contents = pem_data.contents();
- let mut seed = [0u8; 32];
- seed.copy_from_slice(&contents[contents.len() - 32..]);
- return Ok(SigningKey::from_bytes(&seed));
- }
+ && pem_data.tag() == "PRIVATE KEY"
+ && pem_data.contents().len() >= 32
+ {
+ let contents = pem_data.contents();
+ let mut seed = [0u8; 32];
+ seed.copy_from_slice(&contents[contents.len() - 32..]);
+ return Ok(SigningKey::from_bytes(&seed));
+ }
Err(TcpTargetError::Crypto(
"Invalid Ed25519 private key format".to_string(),
))
diff --git a/crates/utils/tcp_connection/src/target_connection.rs b/crates/utils/tcp_connection/src/target_connection.rs
index 87fd1ab..0462f7b 100644
--- a/crates/utils/tcp_connection/src/target_connection.rs
+++ b/crates/utils/tcp_connection/src/target_connection.rs
@@ -17,7 +17,10 @@ where
Server: ServerHandle<Client>,
{
/// Attempts to establish a connection to the TCP server.
- /// This function initiates a connection to the server address specified in the target configuration.
+ ///
+ /// This function initiates a connection to the server address
+ /// specified in the target configuration.
+ ///
/// This is a Block operation.
pub async fn connect(&self) -> Result<(), TcpTargetError> {
let addr = self.get_addr();
@@ -37,7 +40,9 @@ where
}
/// Attempts to establish a connection to the TCP server.
- /// This function initiates a connection to the server address specified in the target configuration.
+ ///
+ /// This function initiates a connection to the server address
+ /// specified in the target configuration.
pub async fn listen(&self) -> Result<(), TcpTargetError> {
let addr = self.get_addr();
let listener = match TcpListener::bind(addr).await {
diff --git a/crates/utils/tcp_connection/tcp_connection_test/Cargo.toml b/crates/utils/tcp_connection/tcp_connection_test/Cargo.toml
index e4cba71..397f13a 100644
--- a/crates/utils/tcp_connection/tcp_connection_test/Cargo.toml
+++ b/crates/utils/tcp_connection/tcp_connection_test/Cargo.toml
@@ -6,3 +6,4 @@ version.workspace = true
[dependencies]
tcp_connection = { path = "../../tcp_connection" }
tokio = { version = "1.46.1", features = ["full"] }
+serde = { version = "1.0.219", features = ["derive"] }
diff --git a/crates/utils/tcp_connection/tcp_connection_test/src/lib.rs b/crates/utils/tcp_connection/tcp_connection_test/src/lib.rs
index f0eb66e..beba25b 100644
--- a/crates/utils/tcp_connection/tcp_connection_test/src/lib.rs
+++ b/crates/utils/tcp_connection/tcp_connection_test/src/lib.rs
@@ -9,3 +9,6 @@ pub mod test_challenge;
#[cfg(test)]
pub mod test_file_transfer;
+
+#[cfg(test)]
+pub mod test_msgpack;
diff --git a/crates/utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs b/crates/utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs
new file mode 100644
index 0000000..7344d64
--- /dev/null
+++ b/crates/utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs
@@ -0,0 +1,111 @@
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+use tcp_connection::{
+ handle::{ClientHandle, ServerHandle},
+ instance::ConnectionInstance,
+ target::TcpServerTarget,
+ target_configure::ServerTargetConfig,
+};
+use tokio::{join, time::sleep};
+
+#[derive(Debug, PartialEq, Serialize, Deserialize)]
+struct TestData {
+ id: u32,
+ name: String,
+}
+
+impl Default for TestData {
+ fn default() -> Self {
+ Self {
+ id: 0,
+ name: String::new(),
+ }
+ }
+}
+
+pub(crate) struct MsgPackClientHandle;
+
+impl ClientHandle<MsgPackServerHandle> for MsgPackClientHandle {
+ async fn process(mut instance: ConnectionInstance) {
+ // Test basic MessagePack serialization
+ let test_data = TestData {
+ id: 42,
+ name: "Test MessagePack".to_string(),
+ };
+
+ // Write MessagePack data
+ if let Err(e) = instance.write_msgpack(&test_data).await {
+ panic!("Write MessagePack failed: {}", e);
+ }
+
+ // Read response
+ let response: TestData = match instance.read_msgpack().await {
+ Ok(data) => data,
+ Err(e) => panic!("Read MessagePack response failed: {}", e),
+ };
+
+ // Verify response
+ assert_eq!(response.id, test_data.id * 2);
+ assert_eq!(response.name, format!("Processed: {}", test_data.name));
+ }
+}
+
+pub(crate) struct MsgPackServerHandle;
+
+impl ServerHandle<MsgPackClientHandle> for MsgPackServerHandle {
+ async fn process(mut instance: ConnectionInstance) {
+ // Read MessagePack data
+ let received_data: TestData = match instance.read_msgpack().await {
+ Ok(data) => data,
+ Err(_) => return,
+ };
+
+ // Process data
+ let response = TestData {
+ id: received_data.id * 2,
+ name: format!("Processed: {}", received_data.name),
+ };
+
+ // Write response as MessagePack
+ if let Err(e) = instance.write_msgpack(&response).await {
+ panic!("Write MessagePack response failed: {}", e);
+ }
+ }
+}
+
+#[tokio::test]
+async fn test_msgpack_basic() {
+ let host = "localhost:5013";
+
+ // Server setup
+ let Ok(server_target) =
+ TcpServerTarget::<MsgPackClientHandle, MsgPackServerHandle>::from_domain(host).await
+ else {
+ panic!("Test target built failed from a domain named `{}`", host);
+ };
+
+ // Client setup
+ let Ok(client_target) =
+ TcpServerTarget::<MsgPackClientHandle, MsgPackServerHandle>::from_domain(host).await
+ else {
+ panic!("Test target built failed from a domain named `{}`", host);
+ };
+
+ let future_server = async move {
+ // Only process once
+ let configured_server = server_target.server_cfg(ServerTargetConfig::default().once());
+
+ // Listen here
+ let _ = configured_server.listen().await;
+ };
+
+ let future_client = async move {
+ // Wait for server start
+ let _ = sleep(Duration::from_secs_f32(1.5)).await;
+
+ // Connect here
+ let _ = client_target.connect().await;
+ };
+
+ let _ = async { join!(future_client, future_server) }.await;
+}