diff options
| -rw-r--r-- | crates/utils/tcp_connection/src/instance.rs | 488 | ||||
| -rw-r--r-- | crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs | 16 |
2 files changed, 296 insertions, 208 deletions
diff --git a/crates/utils/tcp_connection/src/instance.rs b/crates/utils/tcp_connection/src/instance.rs index db36b54..be7c956 100644 --- a/crates/utils/tcp_connection/src/instance.rs +++ b/crates/utils/tcp_connection/src/instance.rs @@ -1,7 +1,6 @@ use std::{path::Path, time::Duration}; -use base64::{engine::general_purpose::STANDARD, prelude::*}; -use rand::Rng; +use rand::{Rng, TryRngCore}; use rsa::{ RsaPrivateKey, RsaPublicKey, pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey}, @@ -16,46 +15,77 @@ use tokio::{ use uuid::Uuid; use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey}; +use ring::rand::SystemRandom; use ring::signature::{ - self, ECDSA_P256_SHA256_ASN1, ECDSA_P384_SHA384_ASN1, RSA_PKCS1_2048_8192_SHA256, + self, ECDSA_P256_SHA256_ASN1, ECDSA_P384_SHA384_ASN1, EcdsaKeyPair, RSA_PKCS1_2048_8192_SHA256, UnparsedPublicKey, }; use crate::error::TcpTargetError; -const CHUNK_SIZE: usize = 8 * 1024; +const DEFAULT_CHUNK_SIZE: usize = 4096; +const DEFAULT_TIMEOUT_SECS: u64 = 10; -pub struct ConnectionInstance { - stream: TcpStream, +const ECDSA_P256_SHA256_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P256_SHA256_ASN1_SIGNING; +const ECDSA_P384_SHA384_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P384_SHA384_ASN1_SIGNING; + +#[derive(Debug, Clone)] +pub struct ConnectionConfig { + pub chunk_size: usize, + pub timeout_secs: u64, + pub enable_crc_validation: bool, } -impl From<TcpStream> for ConnectionInstance { - fn from(value: TcpStream) -> Self { - Self { stream: value } +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + chunk_size: DEFAULT_CHUNK_SIZE, + timeout_secs: DEFAULT_TIMEOUT_SECS, + enable_crc_validation: false, + } } } -// Helper trait for reading u64 from TcpStream -trait ReadU64Ext { - async fn read_u64(&mut self) -> Result<u64, std::io::Error>; +pub struct ConnectionInstance { + stream: TcpStream, + config: ConnectionConfig, } -impl ReadU64Ext for TcpStream { - async fn read_u64(&mut self) -> Result<u64, std::io::Error> { - let mut buf = [0u8; 8]; - self.read_exact(&mut buf).await?; - Ok(u64::from_be_bytes(buf)) +impl From<TcpStream> for ConnectionInstance { + fn from(stream: TcpStream) -> Self { + Self { + stream, + config: ConnectionConfig::default(), + } } } impl ConnectionInstance { + /// Create a new ConnectionInstance with custom configuration + pub fn with_config(stream: TcpStream, config: ConnectionConfig) -> Self { + Self { stream, config } + } + + /// Get a reference to the current configuration + pub fn config(&self) -> &ConnectionConfig { + &self.config + } + + /// Get a mutable reference to the current configuration + pub fn config_mut(&mut self) -> &mut ConnectionConfig { + &mut self.config + } /// Serialize data and write to the target machine pub async fn write<Data>(&mut self, data: Data) -> Result<(), TcpTargetError> where Data: Default + Serialize, { let Ok(json_text) = serde_json::to_string(&data) else { - return Err(TcpTargetError::from("Serialize failed.")); + return Err(TcpTargetError::Serialization( + "Serialize failed.".to_string(), + )); }; Self::write_text(self, json_text).await?; Ok(()) @@ -67,10 +97,12 @@ impl ConnectionInstance { Data: Default + serde::de::DeserializeOwned, { let Ok(json_text) = Self::read_text(self, buffer_size).await else { - return Err(TcpTargetError::from("Read failed.")); + return Err(TcpTargetError::Io("Read failed.".to_string())); }; let Ok(deser_obj) = serde_json::from_str::<Data>(&json_text) else { - return Err(TcpTargetError::from("Deserialize failed.")); + return Err(TcpTargetError::Serialization( + "Deserialize failed.".to_string(), + )); }; Ok(deser_obj) } @@ -81,7 +113,9 @@ impl ConnectionInstance { Data: Default + Serialize, { let Ok(json_text) = serde_json::to_string(&data) else { - return Err(TcpTargetError::from("Serialize failed.")); + return Err(TcpTargetError::Serialization( + "Serialize failed.".to_string(), + )); }; Self::write_large_text(self, json_text).await?; Ok(()) @@ -96,10 +130,12 @@ impl ConnectionInstance { Data: Default + serde::de::DeserializeOwned, { let Ok(json_text) = Self::read_large_text(self, buffer_size).await else { - return Err(TcpTargetError::from("Read failed.")); + return Err(TcpTargetError::Io("Read failed.".to_string())); }; let Ok(deser_obj) = serde_json::from_str::<Data>(&json_text) else { - return Err(TcpTargetError::from("Deserialize failed.")); + return Err(TcpTargetError::Serialization( + "Deserialize failed.".to_string(), + )); }; Ok(deser_obj) } @@ -111,7 +147,7 @@ impl ConnectionInstance { // Write match self.stream.write_all(text.as_bytes()).await { Ok(_) => Ok(()), - Err(err) => Err(TcpTargetError::from(err.to_string())), + Err(err) => Err(TcpTargetError::Io(err.to_string())), } } @@ -128,7 +164,7 @@ impl ConnectionInstance { let text = String::from_utf8_lossy(&buffer[..n]).to_string(); Ok(text) } - Err(err) => Err(TcpTargetError::from(err.to_string())), + Err(err) => Err(TcpTargetError::Io(err.to_string())), } } @@ -145,7 +181,7 @@ impl ConnectionInstance { let chunk = &bytes[offset..]; let written = match self.stream.write(chunk).await { Ok(n) => n, - Err(err) => return Err(TcpTargetError::from(err.to_string())), + Err(err) => return Err(TcpTargetError::Io(err.to_string())), }; offset += written; } @@ -168,7 +204,7 @@ impl ConnectionInstance { Ok(n) => { buffer.extend_from_slice(&chunk_buf[..n]); } - Err(err) => return Err(TcpTargetError::from(err.to_string())), + Err(err) => return Err(TcpTargetError::Io(err.to_string())), } } @@ -181,59 +217,66 @@ impl ConnectionInstance { // Validate file if !path.exists() { - return Err(TcpTargetError::from(format!( + return Err(TcpTargetError::File(format!( "File not found: {}", path.display() ))); } if path.is_dir() { - return Err(TcpTargetError::from(format!( + return Err(TcpTargetError::File(format!( "Path is directory: {}", path.display() ))); } // Open file and get metadata - let mut file = File::open(path) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - let file_size = file - .metadata() - .await - .map_err(|e| TcpTargetError::from(e.to_string()))? - .len(); - if file_size == 0 { - return Err(TcpTargetError::from("Cannot send empty file")); - } + let mut file = File::open(path).await?; + let file_size = file.metadata().await?.len(); + // Allow empty files - just send the header with size 0 + + // Send file header (version + size + crc) + self.stream.write_all(&1u64.to_be_bytes()).await?; + self.stream.write_all(&file_size.to_be_bytes()).await?; + + // Calculate and send CRC32 if enabled + let file_crc = if self.config.enable_crc_validation { + let crc32 = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC); + let mut crc_calculator = crc32.digest(); + + let mut temp_reader = + BufReader::with_capacity(self.config.chunk_size, File::open(path).await?); + let mut temp_buffer = vec![0u8; self.config.chunk_size]; + let mut temp_bytes_read = 0; + + while temp_bytes_read < file_size { + let bytes_to_read = + (file_size - temp_bytes_read).min(self.config.chunk_size as u64) as usize; + temp_reader + .read_exact(&mut temp_buffer[..bytes_to_read]) + .await?; + crc_calculator.update(&temp_buffer[..bytes_to_read]); + temp_bytes_read += bytes_to_read as u64; + } + + crc_calculator.finalize() + } else { + 0 + }; - // Send file header (version + size) - self.stream - .write_all(&1u64.to_be_bytes()) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - self.stream - .write_all(&file_size.to_be_bytes()) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.write_all(&file_crc.to_be_bytes()).await?; // Transfer file content - let mut reader = BufReader::with_capacity(CHUNK_SIZE, &mut file); + let mut reader = BufReader::with_capacity(self.config.chunk_size, &mut file); let mut bytes_sent = 0; while bytes_sent < file_size { - let buffer = reader - .fill_buf() - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + let buffer = reader.fill_buf().await?; if buffer.is_empty() { break; } let chunk_size = buffer.len().min((file_size - bytes_sent) as usize); - self.stream - .write_all(&buffer[..chunk_size]) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.write_all(&buffer[..chunk_size]).await?; reader.consume(chunk_size); bytes_sent += chunk_size as u64; @@ -241,26 +284,27 @@ impl ConnectionInstance { // Verify transfer completion if bytes_sent != file_size { - return Err(TcpTargetError::from(format!( + return Err(TcpTargetError::File(format!( "Transfer incomplete: expected {} bytes, sent {} bytes", file_size, bytes_sent ))); } - self.stream - .flush() - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.flush().await?; // Wait for receiver confirmation let mut ack = [0u8; 1]; - tokio::time::timeout(Duration::from_secs(10), self.stream.read_exact(&mut ack)) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))? - .map_err(|e| TcpTargetError::from(e.to_string()))?; + tokio::time::timeout( + Duration::from_secs(self.config.timeout_secs), + self.stream.read_exact(&mut ack), + ) + .await + .map_err(|_| TcpTargetError::Timeout("Ack timeout".to_string()))??; if ack[0] != 1 { - return Err(TcpTargetError::from("Receiver verification failed")); + return Err(TcpTargetError::Protocol( + "Receiver verification failed".to_string(), + )); } Ok(()) @@ -269,35 +313,42 @@ impl ConnectionInstance { /// Read file from target machine pub async fn read_file(&mut self, save_path: impl AsRef<Path>) -> Result<(), TcpTargetError> { let path = save_path.as_ref(); + // Create CRC instance at function scope to ensure proper lifetime + let crc_instance = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC); // Make sure parent directory exists if let Some(parent) = path.parent() { if !parent.exists() { - tokio::fs::create_dir_all(parent) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + tokio::fs::create_dir_all(parent).await?; } } - // Read file header (version + size) + // Read file header (version + size + crc) let mut version_buf = [0u8; 8]; - self.stream - .read_exact(&mut version_buf) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.read_exact(&mut version_buf).await?; let version = u64::from_be_bytes(version_buf); if version != 1 { - return Err(TcpTargetError::from("Unsupported transfer version")); + return Err(TcpTargetError::Protocol( + "Unsupported transfer version".to_string(), + )); } let mut size_buf = [0u8; 8]; - self.stream - .read_exact(&mut size_buf) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.read_exact(&mut size_buf).await?; let file_size = u64::from_be_bytes(size_buf); + + let mut expected_crc_buf = [0u8; 4]; + self.stream.read_exact(&mut expected_crc_buf).await?; + let expected_crc = u32::from_be_bytes(expected_crc_buf); if file_size == 0 { - return Err(TcpTargetError::from("Cannot receive zero-length file")); + // Create empty file and return early + let _file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .await?; + return Ok(()); } // Prepare output file @@ -306,57 +357,74 @@ impl ConnectionInstance { .create(true) .truncate(true) .open(path) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - let mut writer = BufWriter::with_capacity(CHUNK_SIZE, file); + .await?; + let mut writer = BufWriter::with_capacity(self.config.chunk_size, file); - // Receive file content - let mut buffer = vec![0u8; CHUNK_SIZE]; + // Receive file content with CRC calculation if enabled let mut bytes_received = 0; + let mut buffer = vec![0u8; self.config.chunk_size]; + let mut crc_calculator = if self.config.enable_crc_validation { + Some(crc_instance.digest()) + } else { + None + }; while bytes_received < file_size { - let read_size = buffer.len().min((file_size - bytes_received) as usize); - self.stream - .read_exact(&mut buffer[..read_size]) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - - writer - .write_all(&buffer[..read_size]) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - bytes_received += read_size as u64; + let bytes_to_read = + (file_size - bytes_received).min(self.config.chunk_size as u64) as usize; + let chunk = &mut buffer[..bytes_to_read]; + + self.stream.read_exact(chunk).await?; + + writer.write_all(chunk).await?; + + // Update CRC if validation is enabled + if let Some(ref mut crc) = crc_calculator { + crc.update(chunk); + } + + bytes_received += bytes_to_read as u64; + } + + // Verify transfer completion + if bytes_received != file_size { + return Err(TcpTargetError::File(format!( + "Transfer incomplete: expected {} bytes, received {} bytes", + file_size, bytes_received + ))); + } + + writer.flush().await?; + + // Validate CRC if enabled + if self.config.enable_crc_validation { + if let Some(crc_calculator) = crc_calculator { + let actual_crc = crc_calculator.finalize(); + if actual_crc != expected_crc && expected_crc != 0 { + return Err(TcpTargetError::File(format!( + "CRC validation failed: expected {:08x}, got {:08x}", + expected_crc, actual_crc + ))); + } + } } // Final flush and sync - writer - .flush() - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - writer - .into_inner() - .sync_all() - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + writer.flush().await?; + writer.into_inner().sync_all().await?; // Verify completion if bytes_received != file_size { let _ = tokio::fs::remove_file(path).await; - return Err(TcpTargetError::from(format!( + return Err(TcpTargetError::File(format!( "Transfer incomplete: expected {} bytes, received {} bytes", file_size, bytes_received ))); } // Send confirmation - self.stream - .write_all(&[1]) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - self.stream - .flush() - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.write_all(&[1u8]).await?; + self.stream.flush().await?; Ok(()) } @@ -366,36 +434,28 @@ impl ConnectionInstance { public_key_dir: impl AsRef<Path>, ) -> Result<bool, TcpTargetError> { // Generate random challenge - let mut rng = rand::rng(); - let challenge: [u8; 32] = rng.random(); + let mut challenge = [0u8; 32]; + rand::rngs::OsRng + .try_fill_bytes(&mut challenge) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to generate random challenge: {}", e)) + })?; // Send challenge to target - self.stream - .write_all(&challenge) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.write_all(&challenge).await?; // Read signature from target let mut signature = Vec::new(); let mut signature_len_buf = [0u8; 4]; - self.stream - .read_exact(&mut signature_len_buf) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.read_exact(&mut signature_len_buf).await?; let signature_len = u32::from_be_bytes(signature_len_buf) as usize; signature.resize(signature_len, 0); - self.stream - .read_exact(&mut signature) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.read_exact(&mut signature).await?; // Read UUID from target to identify which public key to use let mut uuid_buf = [0u8; 16]; - self.stream - .read_exact(&mut uuid_buf) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.read_exact(&mut uuid_buf).await?; let user_uuid = Uuid::from_bytes(uuid_buf); // Load appropriate public key @@ -404,9 +464,7 @@ impl ConnectionInstance { return Ok(false); } - let public_key_pem = tokio::fs::read_to_string(&public_key_path) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + let public_key_pem = tokio::fs::read_to_string(&public_key_path).await?; // Try to verify with different key types let verified = if let Ok(rsa_key) = RsaPublicKey::from_pkcs1_pem(&public_key_pem) { @@ -415,9 +473,15 @@ impl ConnectionInstance { } else if let Ok(ed25519_key) = VerifyingKey::from_bytes(&parse_ed25519_public_key(&public_key_pem)) { - let sig_bytes: [u8; 64] = signature.as_slice().try_into().unwrap_or([0u8; 64]); - let sig = Signature::from_bytes(&sig_bytes); - ed25519_key.verify(&challenge, &sig).is_ok() + if signature.len() == 64 { + let sig_bytes: [u8; 64] = signature.as_slice().try_into().map_err(|_| { + TcpTargetError::Crypto("Invalid signature length for Ed25519".to_string()) + })?; + let sig = Signature::from_bytes(&sig_bytes); + ed25519_key.verify(&challenge, &sig).is_ok() + } else { + false + } } else if let Ok(dsa_key_info) = parse_dsa_public_key(&public_key_pem) { verify_dsa_signature(&dsa_key_info, &challenge, &signature) } else { @@ -430,50 +494,34 @@ impl ConnectionInstance { pub async fn accept_challenge( &mut self, private_key_file: impl AsRef<Path>, - verify_user_uuid: Uuid, + verify_public_key: &str, ) -> Result<bool, TcpTargetError> { // Read challenge from initiator let mut challenge = [0u8; 32]; - self.stream - .read_exact(&mut challenge) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.read_exact(&mut challenge).await?; // Load private key - let private_key_pem = tokio::fs::read_to_string(&private_key_file) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + let private_key_pem = tokio::fs::read_to_string(&private_key_file).await?; // Sign the challenge with supported key types let signature = if let Ok(rsa_key) = RsaPrivateKey::from_pkcs1_pem(&private_key_pem) { let padding = rsa::pkcs1v15::Pkcs1v15Sign::new::<sha2::Sha256>(); - rsa_key - .sign(padding, &challenge) - .map_err(|e| TcpTargetError::from(e.to_string()))? + rsa_key.sign(padding, &challenge)? } else if let Ok(ed25519_key) = parse_ed25519_private_key(&private_key_pem) { ed25519_key.sign(&challenge).to_bytes().to_vec() } else if let Ok(dsa_key_info) = parse_dsa_private_key(&private_key_pem) { - sign_with_dsa(&dsa_key_info, &challenge) + sign_with_dsa(&dsa_key_info, &challenge)? } else { return Ok(false); }; // Send signature length and signature let signature_len = signature.len() as u32; - self.stream - .write_all(&signature_len.to_be_bytes()) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; - self.stream - .write_all(&signature) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.write_all(&signature_len.to_be_bytes()).await?; + self.stream.write_all(&signature).await?; // Send UUID for public key identification - self.stream - .write_all(verify_user_uuid.as_bytes()) - .await - .map_err(|e| TcpTargetError::from(e.to_string()))?; + self.stream.write_all(verify_public_key.as_bytes()).await?; Ok(true) } @@ -483,15 +531,13 @@ impl ConnectionInstance { /// Parse Ed25519 public key from PEM format fn parse_ed25519_public_key(pem: &str) -> [u8; 32] { - // Simple parsing for Ed25519 public key (assuming raw 32-byte key) - let lines: Vec<&str> = pem.lines().collect(); + // Robust parsing for Ed25519 public key using pem crate let mut key_bytes = [0u8; 32]; - if lines.len() >= 2 && lines[0].contains("PUBLIC KEY") { - if let Ok(decoded) = STANDARD.decode(lines[1].trim()) { - if decoded.len() >= 32 { - key_bytes.copy_from_slice(&decoded[decoded.len() - 32..]); - } + if let Ok(pem_data) = pem::parse(pem) { + if pem_data.tag() == "PUBLIC KEY" && pem_data.contents().len() >= 32 { + let contents = pem_data.contents(); + key_bytes.copy_from_slice(&contents[contents.len() - 32..]); } } key_bytes @@ -499,18 +545,17 @@ fn parse_ed25519_public_key(pem: &str) -> [u8; 32] { /// Parse Ed25519 private key from PEM format fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> { - let lines: Vec<&str> = pem.lines().collect(); - - if lines.len() >= 2 && lines[0].contains("PRIVATE KEY") { - if let Ok(decoded) = STANDARD.decode(lines[1].trim()) { - if decoded.len() >= 32 { - let mut seed = [0u8; 32]; - seed.copy_from_slice(&decoded[decoded.len() - 32..]); - return Ok(SigningKey::from_bytes(&seed)); - } + if let Ok(pem_data) = pem::parse(pem) { + if pem_data.tag() == "PRIVATE KEY" && pem_data.contents().len() >= 32 { + let contents = pem_data.contents(); + let mut seed = [0u8; 32]; + seed.copy_from_slice(&contents[contents.len() - 32..]); + return Ok(SigningKey::from_bytes(&seed)); } } - Err(TcpTargetError::from("Invalid Ed25519 private key format")) + Err(TcpTargetError::Crypto( + "Invalid Ed25519 private key format".to_string(), + )) } // Helper functions for DSA support @@ -519,23 +564,30 @@ fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> { fn parse_dsa_public_key( pem: &str, ) -> Result<(&'static dyn signature::VerificationAlgorithm, Vec<u8>), TcpTargetError> { - let lines: Vec<&str> = pem.lines().collect(); + if let Ok(pem_data) = pem::parse(pem) { + let contents = pem_data.contents().to_vec(); - if lines.len() >= 2 { - if let Ok(decoded) = STANDARD.decode(lines[1].trim()) { - // Try different DSA algorithms - if pem.contains("ECDSA") || pem.contains("ecdsa") { + // Try different DSA algorithms based on PEM tag + match pem_data.tag() { + "EC PUBLIC KEY" | "PUBLIC KEY" if pem.contains("ECDSA") || pem.contains("ecdsa") => { if pem.contains("P-256") { - return Ok((&ECDSA_P256_SHA256_ASN1, decoded)); + return Ok((&ECDSA_P256_SHA256_ASN1, contents)); } else if pem.contains("P-384") { - return Ok((&ECDSA_P384_SHA384_ASN1, decoded)); + return Ok((&ECDSA_P384_SHA384_ASN1, contents)); } } - // Default to RSA if no specific algorithm detected - return Ok((&RSA_PKCS1_2048_8192_SHA256, decoded)); + "RSA PUBLIC KEY" | "PUBLIC KEY" => { + return Ok((&RSA_PKCS1_2048_8192_SHA256, contents)); + } + _ => {} } + + // Default to RSA for unknown types + return Ok((&RSA_PKCS1_2048_8192_SHA256, contents)); } - Err(TcpTargetError::from("Invalid DSA public key format")) + Err(TcpTargetError::Crypto( + "Invalid DSA public key format".to_string(), + )) } /// Parse DSA private key information from PEM @@ -559,13 +611,51 @@ fn verify_dsa_signature( /// Sign with DSA (simplified - in practice this would use proper private key operations) fn sign_with_dsa( - _algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec<u8>), + algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec<u8>), message: &[u8], -) -> Vec<u8> { - // Note: This is a simplified implementation. In a real scenario, - // you would use proper private key signing operations with ring or other crypto library. - // For now, we'll return a dummy signature for demonstration. - let mut signature = vec![0u8; 64]; - signature[..32].copy_from_slice(&message[..32.min(message.len())]); - signature +) -> Result<Vec<u8>, TcpTargetError> { + let (algorithm, key_bytes) = algorithm_and_key; + + // Handle different DSA/ECDSA algorithms by comparing algorithm identifiers + // Since we can't directly compare trait objects, we use pointer comparison + let algorithm_ptr = algorithm as *const _ as *const (); + let ecdsa_p256_ptr = &ECDSA_P256_SHA256_ASN1 as *const _ as *const (); + let ecdsa_p384_ptr = &ECDSA_P384_SHA384_ASN1 as *const _ as *const (); + + if algorithm_ptr == ecdsa_p256_ptr { + let key_pair = EcdsaKeyPair::from_pkcs8( + ECDSA_P256_SHA256_ASN1_SIGNING, + &key_bytes, + &SystemRandom::new(), + ) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to create ECDSA P-256 key pair: {}", e)) + })?; + + let signature = key_pair + .sign(&SystemRandom::new(), message) + .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-256 signing failed: {}", e)))?; + + Ok(signature.as_ref().to_vec()) + } else if algorithm_ptr == ecdsa_p384_ptr { + let key_pair = EcdsaKeyPair::from_pkcs8( + ECDSA_P384_SHA384_ASN1_SIGNING, + &key_bytes, + &SystemRandom::new(), + ) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to create ECDSA P-384 key pair: {}", e)) + })?; + + let signature = key_pair + .sign(&SystemRandom::new(), message) + .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-384 signing failed: {}", e)))?; + + Ok(signature.as_ref().to_vec()) + } else { + // RSA or unsupported algorithm + Err(TcpTargetError::Unsupported( + "DSA/ECDSA signing not supported for this algorithm type".to_string(), + )) + } } diff --git a/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs b/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs index ac57451..57d3819 100644 --- a/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs +++ b/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs @@ -1,7 +1,4 @@ -use std::{ - env::{current_dir, set_current_dir}, - time::Duration, -}; +use std::{env::current_dir, time::Duration}; use tcp_connection::{ handle::{ClientHandle, ServerHandle}, @@ -18,7 +15,9 @@ impl ClientHandle<ExampleChallengeServerHandle> for ExampleChallengeClientHandle mut instance: ConnectionInstance, ) -> impl std::future::Future<Output = ()> + Send + Sync { async move { - // TODO :: Complete the implementation + let key = current_dir().unwrap().join("res").join("test_key"); + let result = instance.accept_challenge(key, "test_key").await.unwrap(); + assert_eq!(true, result); } } } @@ -30,7 +29,9 @@ impl ServerHandle<ExampleChallengeClientHandle> for ExampleChallengeServerHandle mut instance: ConnectionInstance, ) -> impl std::future::Future<Output = ()> + Send + Sync { async move { - // TODO :: Complete the implementation + let key_dir = current_dir().unwrap().join("res"); + let result = instance.challenge(key_dir).await.unwrap(); + assert_eq!(true, result); } } } @@ -39,9 +40,6 @@ impl ServerHandle<ExampleChallengeClientHandle> for ExampleChallengeServerHandle async fn test_connection_with_challenge_handle() -> Result<(), std::io::Error> { let host = "localhost"; - // Enter temp directory - set_current_dir(current_dir().unwrap().join(".temp/"))?; - // Server setup let Ok(server_target) = TcpServerTarget::< ExampleChallengeClientHandle, |
