diff options
Diffstat (limited to 'utils/tcp_connection/src')
| -rw-r--r-- | utils/tcp_connection/src/error.rs | 122 | ||||
| -rw-r--r-- | utils/tcp_connection/src/instance.rs | 542 | ||||
| -rw-r--r-- | utils/tcp_connection/src/instance_challenge.rs | 311 | ||||
| -rw-r--r-- | utils/tcp_connection/src/lib.rs | 6 |
4 files changed, 981 insertions, 0 deletions
diff --git a/utils/tcp_connection/src/error.rs b/utils/tcp_connection/src/error.rs new file mode 100644 index 0000000..32d06cc --- /dev/null +++ b/utils/tcp_connection/src/error.rs @@ -0,0 +1,122 @@ +use std::io; +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +pub enum TcpTargetError { + #[error("Authentication failed: {0}")] + Authentication(String), + + #[error("Reference sheet not allowed: {0}")] + ReferenceSheetNotAllowed(String), + + #[error("Cryptographic error: {0}")] + Crypto(String), + + #[error("File operation error: {0}")] + File(String), + + #[error("I/O error: {0}")] + Io(String), + + #[error("Invalid configuration: {0}")] + Config(String), + + #[error("Locked: {0}")] + Locked(String), + + #[error("Network error: {0}")] + Network(String), + + #[error("No result: {0}")] + NoResult(String), + + #[error("Not found: {0}")] + NotFound(String), + + #[error("Not local machine: {0}")] + NotLocal(String), + + #[error("Not remote machine: {0}")] + NotRemote(String), + + #[error("Pool already exists: {0}")] + PoolAlreadyExists(String), + + #[error("Protocol error: {0}")] + Protocol(String), + + #[error("Serialization error: {0}")] + Serialization(String), + + #[error("Timeout: {0}")] + Timeout(String), + + #[error("Unsupported operation: {0}")] + Unsupported(String), +} + +impl From<io::Error> for TcpTargetError { + fn from(error: io::Error) -> Self { + TcpTargetError::Io(error.to_string()) + } +} + +impl From<serde_json::Error> for TcpTargetError { + fn from(error: serde_json::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} + +impl From<&str> for TcpTargetError { + fn from(value: &str) -> Self { + TcpTargetError::Protocol(value.to_string()) + } +} + +impl From<String> for TcpTargetError { + fn from(value: String) -> Self { + TcpTargetError::Protocol(value) + } +} + +impl From<rsa::errors::Error> for TcpTargetError { + fn from(error: rsa::errors::Error) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From<ed25519_dalek::SignatureError> for TcpTargetError { + fn from(error: ed25519_dalek::SignatureError) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From<ring::error::Unspecified> for TcpTargetError { + fn from(error: ring::error::Unspecified) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From<base64::DecodeError> for TcpTargetError { + fn from(error: base64::DecodeError) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} + +impl From<pem::PemError> for TcpTargetError { + fn from(error: pem::PemError) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From<rmp_serde::encode::Error> for TcpTargetError { + fn from(error: rmp_serde::encode::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} + +impl From<rmp_serde::decode::Error> for TcpTargetError { + fn from(error: rmp_serde::decode::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} diff --git a/utils/tcp_connection/src/instance.rs b/utils/tcp_connection/src/instance.rs new file mode 100644 index 0000000..8e6886c --- /dev/null +++ b/utils/tcp_connection/src/instance.rs @@ -0,0 +1,542 @@ +use std::{path::Path, time::Duration}; + +use serde::Serialize; +use tokio::{ + fs::{File, OpenOptions}, + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::TcpStream, +}; + +use ring::signature::{self}; + +use crate::error::TcpTargetError; + +const DEFAULT_CHUNK_SIZE: usize = 4096; +const DEFAULT_TIMEOUT_SECS: u64 = 10; + +const ECDSA_P256_SHA256_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P256_SHA256_ASN1_SIGNING; +const ECDSA_P384_SHA384_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P384_SHA384_ASN1_SIGNING; + +#[derive(Debug, Clone)] +pub struct ConnectionConfig { + pub chunk_size: usize, + pub timeout_secs: u64, + pub enable_crc_validation: bool, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + chunk_size: DEFAULT_CHUNK_SIZE, + timeout_secs: DEFAULT_TIMEOUT_SECS, + enable_crc_validation: false, + } + } +} + +pub struct ConnectionInstance { + pub(crate) stream: TcpStream, + config: ConnectionConfig, +} + +impl From<TcpStream> for ConnectionInstance { + fn from(stream: TcpStream) -> Self { + Self { + stream, + config: ConnectionConfig::default(), + } + } +} + +impl ConnectionInstance { + /// Create a new ConnectionInstance with custom configuration + pub fn with_config(stream: TcpStream, config: ConnectionConfig) -> Self { + Self { stream, config } + } + + /// Get a reference to the current configuration + pub fn config(&self) -> &ConnectionConfig { + &self.config + } + + /// Get a mutable reference to the current configuration + pub fn config_mut(&mut self) -> &mut ConnectionConfig { + &mut self.config + } + /// Serialize data and write to the target machine + pub async fn write<Data>(&mut self, data: Data) -> Result<(), TcpTargetError> + where + Data: Default + Serialize, + { + let Ok(json_text) = serde_json::to_string(&data) else { + return Err(TcpTargetError::Serialization( + "Serialize failed.".to_string(), + )); + }; + Self::write_text(self, json_text).await?; + Ok(()) + } + + /// Serialize data to MessagePack and write to the target machine + pub async fn write_msgpack<Data>(&mut self, data: Data) -> Result<(), TcpTargetError> + where + Data: Serialize, + { + let msgpack_data = rmp_serde::to_vec(&data)?; + let len = msgpack_data.len() as u32; + + self.stream.write_all(&len.to_be_bytes()).await?; + self.stream.write_all(&msgpack_data).await?; + Ok(()) + } + + /// Read data from target machine and deserialize from MessagePack + pub async fn read_msgpack<Data>(&mut self) -> Result<Data, TcpTargetError> + where + Data: serde::de::DeserializeOwned, + { + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf) as usize; + + let mut buffer = vec![0; len]; + self.stream.read_exact(&mut buffer).await?; + + let data = rmp_serde::from_slice(&buffer)?; + Ok(data) + } + + /// Read data from target machine and deserialize + pub async fn read<Data>(&mut self) -> Result<Data, TcpTargetError> + where + Data: Default + serde::de::DeserializeOwned, + { + let Ok(json_text) = Self::read_text(self).await else { + return Err(TcpTargetError::Io("Read failed.".to_string())); + }; + let Ok(deser_obj) = serde_json::from_str::<Data>(&json_text) else { + return Err(TcpTargetError::Serialization( + "Deserialize failed.".to_string(), + )); + }; + Ok(deser_obj) + } + + /// Serialize data and write to the target machine + pub async fn write_large<Data>(&mut self, data: Data) -> Result<(), TcpTargetError> + where + Data: Default + Serialize, + { + let Ok(json_text) = serde_json::to_string(&data) else { + return Err(TcpTargetError::Serialization( + "Serialize failed.".to_string(), + )); + }; + Self::write_large_text(self, json_text).await?; + Ok(()) + } + + /// Read data from target machine and deserialize + pub async fn read_large<Data>( + &mut self, + buffer_size: impl Into<u32>, + ) -> Result<Data, TcpTargetError> + where + Data: Default + serde::de::DeserializeOwned, + { + let Ok(json_text) = Self::read_large_text(self, buffer_size).await else { + return Err(TcpTargetError::Io("Read failed.".to_string())); + }; + let Ok(deser_obj) = serde_json::from_str::<Data>(&json_text) else { + return Err(TcpTargetError::Serialization( + "Deserialize failed.".to_string(), + )); + }; + Ok(deser_obj) + } + + /// Write text to the target machine + pub async fn write_text(&mut self, text: impl Into<String>) -> Result<(), TcpTargetError> { + let text = text.into(); + let bytes = text.as_bytes(); + let len = bytes.len() as u32; + + self.stream.write_all(&len.to_be_bytes()).await?; + match self.stream.write_all(bytes).await { + Ok(_) => Ok(()), + Err(err) => Err(TcpTargetError::Io(err.to_string())), + } + } + + /// Read text from the target machine + pub async fn read_text(&mut self) -> Result<String, TcpTargetError> { + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf) as usize; + + let mut buffer = vec![0; len]; + self.stream.read_exact(&mut buffer).await?; + + match String::from_utf8(buffer) { + Ok(text) => Ok(text), + Err(err) => Err(TcpTargetError::Serialization(format!( + "Invalid UTF-8 sequence: {}", + err + ))), + } + } + + /// Write large text to the target machine (chunked) + pub async fn write_large_text( + &mut self, + text: impl Into<String>, + ) -> Result<(), TcpTargetError> { + let text = text.into(); + let bytes = text.as_bytes(); + let mut offset = 0; + + while offset < bytes.len() { + let chunk = &bytes[offset..]; + let written = match self.stream.write(chunk).await { + Ok(n) => n, + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + }; + offset += written; + } + + Ok(()) + } + + /// Read large text from the target machine (chunked) + pub async fn read_large_text( + &mut self, + chunk_size: impl Into<u32>, + ) -> Result<String, TcpTargetError> { + let chunk_size = chunk_size.into() as usize; + let mut buffer = Vec::new(); + let mut chunk_buf = vec![0; chunk_size]; + + loop { + match self.stream.read(&mut chunk_buf).await { + Ok(0) => break, // EOF + Ok(n) => { + buffer.extend_from_slice(&chunk_buf[..n]); + } + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + Ok(String::from_utf8_lossy(&buffer).to_string()) + } + + /// Write large MessagePack data to the target machine (chunked) + pub async fn write_large_msgpack<Data>( + &mut self, + data: Data, + chunk_size: impl Into<u32>, + ) -> Result<(), TcpTargetError> + where + Data: Serialize, + { + let msgpack_data = rmp_serde::to_vec(&data)?; + let chunk_size = chunk_size.into() as usize; + let len = msgpack_data.len() as u32; + + // Write total length first + self.stream.write_all(&len.to_be_bytes()).await?; + + // Write data in chunks + let mut offset = 0; + while offset < msgpack_data.len() { + let end = std::cmp::min(offset + chunk_size, msgpack_data.len()); + let chunk = &msgpack_data[offset..end]; + match self.stream.write(chunk).await { + Ok(n) => offset += n, + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + Ok(()) + } + + /// Read large MessagePack data from the target machine (chunked) + pub async fn read_large_msgpack<Data>( + &mut self, + chunk_size: impl Into<u32>, + ) -> Result<Data, TcpTargetError> + where + Data: serde::de::DeserializeOwned, + { + let chunk_size = chunk_size.into() as usize; + + // Read total length first + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let total_len = u32::from_be_bytes(len_buf) as usize; + + // Read data in chunks + let mut buffer = Vec::with_capacity(total_len); + let mut remaining = total_len; + let mut chunk_buf = vec![0; chunk_size]; + + while remaining > 0 { + let read_size = std::cmp::min(chunk_size, remaining); + let chunk = &mut chunk_buf[..read_size]; + + match self.stream.read_exact(chunk).await { + Ok(_) => { + buffer.extend_from_slice(chunk); + remaining -= read_size; + } + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + let data = rmp_serde::from_slice(&buffer)?; + Ok(data) + } + + /// Write file to target machine. + pub async fn write_file(&mut self, file_path: impl AsRef<Path>) -> Result<(), TcpTargetError> { + let path = file_path.as_ref(); + + // Validate file + if !path.exists() { + return Err(TcpTargetError::File(format!( + "File not found: {}", + path.display() + ))); + } + if path.is_dir() { + return Err(TcpTargetError::File(format!( + "Path is directory: {}", + path.display() + ))); + } + + // Open file and get metadata + let mut file = File::open(path).await?; + let file_size = file.metadata().await?.len(); + + // Send file header (version + size + crc) + self.stream.write_all(&1u64.to_be_bytes()).await?; + self.stream.write_all(&file_size.to_be_bytes()).await?; + + // Calculate and send CRC32 if enabled + let file_crc = if self.config.enable_crc_validation { + let crc32 = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC); + let mut crc_calculator = crc32.digest(); + + let mut temp_reader = + BufReader::with_capacity(self.config.chunk_size, File::open(path).await?); + let mut temp_buffer = vec![0u8; self.config.chunk_size]; + let mut temp_bytes_read = 0; + + while temp_bytes_read < file_size { + let bytes_to_read = + (file_size - temp_bytes_read).min(self.config.chunk_size as u64) as usize; + temp_reader + .read_exact(&mut temp_buffer[..bytes_to_read]) + .await?; + crc_calculator.update(&temp_buffer[..bytes_to_read]); + temp_bytes_read += bytes_to_read as u64; + } + + crc_calculator.finalize() + } else { + 0 + }; + + self.stream.write_all(&file_crc.to_be_bytes()).await?; + + // If file size is 0, skip content transfer + if file_size == 0 { + self.stream.flush().await?; + + // Wait for receiver confirmation + let mut ack = [0u8; 1]; + tokio::time::timeout( + Duration::from_secs(self.config.timeout_secs), + self.stream.read_exact(&mut ack), + ) + .await + .map_err(|_| TcpTargetError::Timeout("Ack timeout".to_string()))??; + + if ack[0] != 1 { + return Err(TcpTargetError::Protocol( + "Receiver verification failed".to_string(), + )); + } + + return Ok(()); + } + + // Transfer file content + let mut reader = BufReader::with_capacity(self.config.chunk_size, &mut file); + let mut bytes_sent = 0; + + while bytes_sent < file_size { + let buffer = reader.fill_buf().await?; + if buffer.is_empty() { + break; + } + + let chunk_size = buffer.len().min((file_size - bytes_sent) as usize); + self.stream.write_all(&buffer[..chunk_size]).await?; + reader.consume(chunk_size); + + bytes_sent += chunk_size as u64; + } + + // Verify transfer completion + if bytes_sent != file_size { + return Err(TcpTargetError::File(format!( + "Transfer incomplete: expected {} bytes, sent {} bytes", + file_size, bytes_sent + ))); + } + + self.stream.flush().await?; + + // Wait for receiver confirmation + let mut ack = [0u8; 1]; + tokio::time::timeout( + Duration::from_secs(self.config.timeout_secs), + self.stream.read_exact(&mut ack), + ) + .await + .map_err(|_| TcpTargetError::Timeout("Ack timeout".to_string()))??; + + if ack[0] != 1 { + return Err(TcpTargetError::Protocol( + "Receiver verification failed".to_string(), + )); + } + + Ok(()) + } + + /// Read file from target machine + pub async fn read_file(&mut self, save_path: impl AsRef<Path>) -> Result<(), TcpTargetError> { + let path = save_path.as_ref(); + // Create CRC instance at function scope to ensure proper lifetime + let crc_instance = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC); + + // Make sure parent directory exists + if let Some(parent) = path.parent() + && !parent.exists() + { + tokio::fs::create_dir_all(parent).await?; + } + + // Read file header (version + size + crc) + let mut version_buf = [0u8; 8]; + self.stream.read_exact(&mut version_buf).await?; + let version = u64::from_be_bytes(version_buf); + if version != 1 { + return Err(TcpTargetError::Protocol( + "Unsupported transfer version".to_string(), + )); + } + + let mut size_buf = [0u8; 8]; + self.stream.read_exact(&mut size_buf).await?; + let file_size = u64::from_be_bytes(size_buf); + + let mut expected_crc_buf = [0u8; 4]; + self.stream.read_exact(&mut expected_crc_buf).await?; + let expected_crc = u32::from_be_bytes(expected_crc_buf); + if file_size == 0 { + // Create empty file and return early + let _file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .await?; + // Send confirmation + self.stream.write_all(&[1u8]).await?; + self.stream.flush().await?; + return Ok(()); + } + + // Prepare output file + let file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .await?; + let mut writer = BufWriter::with_capacity(self.config.chunk_size, file); + + // Receive file content with CRC calculation if enabled + let mut bytes_received = 0; + let mut buffer = vec![0u8; self.config.chunk_size]; + let mut crc_calculator = if self.config.enable_crc_validation { + Some(crc_instance.digest()) + } else { + None + }; + + while bytes_received < file_size { + let bytes_to_read = + (file_size - bytes_received).min(self.config.chunk_size as u64) as usize; + let chunk = &mut buffer[..bytes_to_read]; + + self.stream.read_exact(chunk).await?; + + writer.write_all(chunk).await?; + + // Update CRC if validation is enabled + if let Some(ref mut crc) = crc_calculator { + crc.update(chunk); + } + + bytes_received += bytes_to_read as u64; + } + + // Verify transfer completion + if bytes_received != file_size { + return Err(TcpTargetError::File(format!( + "Transfer incomplete: expected {} bytes, received {} bytes", + file_size, bytes_received + ))); + } + + writer.flush().await?; + + // Validate CRC if enabled + if self.config.enable_crc_validation + && let Some(crc_calculator) = crc_calculator + { + let actual_crc = crc_calculator.finalize(); + if actual_crc != expected_crc && expected_crc != 0 { + return Err(TcpTargetError::File(format!( + "CRC validation failed: expected {:08x}, got {:08x}", + expected_crc, actual_crc + ))); + } + } + + // Final flush and sync + writer.flush().await?; + writer.into_inner().sync_all().await?; + + // Verify completion + if bytes_received != file_size { + let _ = tokio::fs::remove_file(path).await; + return Err(TcpTargetError::File(format!( + "Transfer incomplete: expected {} bytes, received {} bytes", + file_size, bytes_received + ))); + } + + // Send confirmation + self.stream.write_all(&[1u8]).await?; + self.stream.flush().await?; + + Ok(()) + } +} diff --git a/utils/tcp_connection/src/instance_challenge.rs b/utils/tcp_connection/src/instance_challenge.rs new file mode 100644 index 0000000..3a7f6a3 --- /dev/null +++ b/utils/tcp_connection/src/instance_challenge.rs @@ -0,0 +1,311 @@ +use std::path::Path; + +use rand::TryRngCore; +use rsa::{ + RsaPrivateKey, RsaPublicKey, + pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey}, + sha2, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey}; +use ring::rand::SystemRandom; +use ring::signature::{ + self, ECDSA_P256_SHA256_ASN1, ECDSA_P384_SHA384_ASN1, EcdsaKeyPair, RSA_PKCS1_2048_8192_SHA256, + UnparsedPublicKey, +}; + +use crate::{error::TcpTargetError, instance::ConnectionInstance}; + +const ECDSA_P256_SHA256_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P256_SHA256_ASN1_SIGNING; +const ECDSA_P384_SHA384_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P384_SHA384_ASN1_SIGNING; + +impl ConnectionInstance { + /// Initiates a challenge to the target machine to verify connection security + /// + /// This method performs a cryptographic challenge-response authentication: + /// 1. Generates a random 32-byte challenge + /// 2. Sends the challenge to the target machine + /// 3. Receives a digital signature of the challenge + /// 4. Verifies the signature using the appropriate public key + /// + /// # Arguments + /// * `public_key_dir` - Directory containing public key files for verification + /// + /// # Returns + /// * `Ok((true, "KeyId"))` - Challenge verification successful + /// * `Ok((false, "KeyId"))` - Challenge verification failed + /// * `Err(TcpTargetError)` - Error during challenge process + pub async fn challenge( + &mut self, + public_key_dir: impl AsRef<Path>, + ) -> Result<(bool, String), TcpTargetError> { + // Generate random challenge + let mut challenge = [0u8; 32]; + rand::rngs::OsRng + .try_fill_bytes(&mut challenge) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to generate random challenge: {}", e)) + })?; + + // Send challenge to target + self.stream.write_all(&challenge).await?; + self.stream.flush().await?; + + // Read signature from target + let mut signature = Vec::new(); + let mut signature_len_buf = [0u8; 4]; + self.stream.read_exact(&mut signature_len_buf).await?; + + let signature_len = u32::from_be_bytes(signature_len_buf) as usize; + signature.resize(signature_len, 0); + self.stream.read_exact(&mut signature).await?; + + // Read key identifier from target to identify which public key to use + let mut key_id_len_buf = [0u8; 4]; + self.stream.read_exact(&mut key_id_len_buf).await?; + let key_id_len = u32::from_be_bytes(key_id_len_buf) as usize; + + let mut key_id_buf = vec![0u8; key_id_len]; + self.stream.read_exact(&mut key_id_buf).await?; + let key_id = String::from_utf8(key_id_buf) + .map_err(|e| TcpTargetError::Crypto(format!("Invalid key identifier: {}", e)))?; + + // Load appropriate public key + let public_key_path = public_key_dir.as_ref().join(format!("{}.pem", key_id)); + if !public_key_path.exists() { + return Ok((false, key_id)); + } + + let public_key_pem = tokio::fs::read_to_string(&public_key_path).await?; + + // Try to verify with different key types + let verified = if let Ok(rsa_key) = RsaPublicKey::from_pkcs1_pem(&public_key_pem) { + let padding = rsa::pkcs1v15::Pkcs1v15Sign::new::<sha2::Sha256>(); + rsa_key.verify(padding, &challenge, &signature).is_ok() + } else if let Ok(ed25519_key) = + VerifyingKey::from_bytes(&parse_ed25519_public_key(&public_key_pem)) + { + if signature.len() == 64 { + let sig_bytes: [u8; 64] = signature.as_slice().try_into().map_err(|_| { + TcpTargetError::Crypto("Invalid signature length for Ed25519".to_string()) + })?; + let sig = Signature::from_bytes(&sig_bytes); + ed25519_key.verify(&challenge, &sig).is_ok() + } else { + false + } + } else if let Ok(dsa_key_info) = parse_dsa_public_key(&public_key_pem) { + verify_dsa_signature(&dsa_key_info, &challenge, &signature) + } else { + false + }; + + Ok((verified, key_id)) + } + + /// Accepts a challenge from the target machine to verify connection security + /// + /// This method performs a cryptographic challenge-response authentication: + /// 1. Receives a random 32-byte challenge from the target machine + /// 2. Signs the challenge using the appropriate private key + /// 3. Sends the digital signature back to the target machine + /// 4. Sends the key identifier for public key verification + /// + /// # Arguments + /// * `private_key_file` - Path to the private key file for signing + /// * `verify_public_key` - Key identifier for public key verification + /// + /// # Returns + /// * `Ok(true)` - Challenge response sent successfully + /// * `Ok(false)` - Private key format not supported + /// * `Err(TcpTargetError)` - Error during challenge response process + pub async fn accept_challenge( + &mut self, + private_key_file: impl AsRef<Path>, + verify_public_key: &str, + ) -> Result<bool, TcpTargetError> { + // Read challenge from initiator + let mut challenge = [0u8; 32]; + self.stream.read_exact(&mut challenge).await?; + + // Load private key + let private_key_pem = tokio::fs::read_to_string(&private_key_file) + .await + .map_err(|e| { + TcpTargetError::NotFound(format!( + "Read private key \"{}\" failed: \"{}\"", + private_key_file + .as_ref() + .display() + .to_string() + .split("/") + .last() + .unwrap_or("UNKNOWN"), + e + )) + })?; + + // Sign the challenge with supported key types + let signature = if let Ok(rsa_key) = RsaPrivateKey::from_pkcs1_pem(&private_key_pem) { + let padding = rsa::pkcs1v15::Pkcs1v15Sign::new::<sha2::Sha256>(); + rsa_key.sign(padding, &challenge)? + } else if let Ok(ed25519_key) = parse_ed25519_private_key(&private_key_pem) { + ed25519_key.sign(&challenge).to_bytes().to_vec() + } else if let Ok(dsa_key_info) = parse_dsa_private_key(&private_key_pem) { + sign_with_dsa(&dsa_key_info, &challenge)? + } else { + return Ok(false); + }; + + // Send signature length and signature + let signature_len = signature.len() as u32; + self.stream.write_all(&signature_len.to_be_bytes()).await?; + self.stream.flush().await?; + self.stream.write_all(&signature).await?; + self.stream.flush().await?; + + // Send key identifier for public key identification + let key_id_bytes = verify_public_key.as_bytes(); + let key_id_len = key_id_bytes.len() as u32; + self.stream.write_all(&key_id_len.to_be_bytes()).await?; + self.stream.flush().await?; + self.stream.write_all(key_id_bytes).await?; + self.stream.flush().await?; + + Ok(true) + } +} + +/// Parse Ed25519 public key from PEM format +fn parse_ed25519_public_key(pem: &str) -> [u8; 32] { + // Robust parsing for Ed25519 public key using pem crate + let mut key_bytes = [0u8; 32]; + + if let Ok(pem_data) = pem::parse(pem) + && pem_data.tag() == "PUBLIC KEY" + && pem_data.contents().len() >= 32 + { + let contents = pem_data.contents(); + key_bytes.copy_from_slice(&contents[contents.len() - 32..]); + } + key_bytes +} + +/// Parse Ed25519 private key from PEM format +fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> { + if let Ok(pem_data) = pem::parse(pem) + && pem_data.tag() == "PRIVATE KEY" + && pem_data.contents().len() >= 32 + { + let contents = pem_data.contents(); + let mut seed = [0u8; 32]; + seed.copy_from_slice(&contents[contents.len() - 32..]); + return Ok(SigningKey::from_bytes(&seed)); + } + Err(TcpTargetError::Crypto( + "Invalid Ed25519 private key format".to_string(), + )) +} + +/// Parse DSA public key information from PEM +fn parse_dsa_public_key( + pem: &str, +) -> Result<(&'static dyn signature::VerificationAlgorithm, Vec<u8>), TcpTargetError> { + if let Ok(pem_data) = pem::parse(pem) { + let contents = pem_data.contents().to_vec(); + + // Try different DSA algorithms based on PEM tag + match pem_data.tag() { + "EC PUBLIC KEY" | "PUBLIC KEY" if pem.contains("ECDSA") || pem.contains("ecdsa") => { + if pem.contains("P-256") { + return Ok((&ECDSA_P256_SHA256_ASN1, contents)); + } else if pem.contains("P-384") { + return Ok((&ECDSA_P384_SHA384_ASN1, contents)); + } + } + "RSA PUBLIC KEY" | "PUBLIC KEY" => { + return Ok((&RSA_PKCS1_2048_8192_SHA256, contents)); + } + _ => {} + } + + // Default to RSA for unknown types + return Ok((&RSA_PKCS1_2048_8192_SHA256, contents)); + } + Err(TcpTargetError::Crypto( + "Invalid DSA public key format".to_string(), + )) +} + +/// Parse DSA private key information from PEM +fn parse_dsa_private_key( + pem: &str, +) -> Result<(&'static dyn signature::VerificationAlgorithm, Vec<u8>), TcpTargetError> { + // For DSA, private key verification uses the same algorithm as public key + parse_dsa_public_key(pem) +} + +/// Verify DSA signature +fn verify_dsa_signature( + algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec<u8>), + message: &[u8], + signature: &[u8], +) -> bool { + let (algorithm, key_bytes) = algorithm_and_key; + let public_key = UnparsedPublicKey::new(*algorithm, key_bytes); + public_key.verify(message, signature).is_ok() +} + +/// Sign with DSA +fn sign_with_dsa( + algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec<u8>), + message: &[u8], +) -> Result<Vec<u8>, TcpTargetError> { + let (algorithm, key_bytes) = algorithm_and_key; + + // Handle different DSA/ECDSA algorithms by comparing algorithm identifiers + // Since we can't directly compare trait objects, we use pointer comparison + let algorithm_ptr = algorithm as *const _ as *const (); + let ecdsa_p256_ptr = &ECDSA_P256_SHA256_ASN1 as *const _ as *const (); + let ecdsa_p384_ptr = &ECDSA_P384_SHA384_ASN1 as *const _ as *const (); + + if algorithm_ptr == ecdsa_p256_ptr { + let key_pair = EcdsaKeyPair::from_pkcs8( + ECDSA_P256_SHA256_ASN1_SIGNING, + key_bytes, + &SystemRandom::new(), + ) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to create ECDSA P-256 key pair: {}", e)) + })?; + + let signature = key_pair + .sign(&SystemRandom::new(), message) + .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-256 signing failed: {}", e)))?; + + Ok(signature.as_ref().to_vec()) + } else if algorithm_ptr == ecdsa_p384_ptr { + let key_pair = EcdsaKeyPair::from_pkcs8( + ECDSA_P384_SHA384_ASN1_SIGNING, + key_bytes, + &SystemRandom::new(), + ) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to create ECDSA P-384 key pair: {}", e)) + })?; + + let signature = key_pair + .sign(&SystemRandom::new(), message) + .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-384 signing failed: {}", e)))?; + + Ok(signature.as_ref().to_vec()) + } else { + // RSA or unsupported algorithm + Err(TcpTargetError::Unsupported( + "DSA/ECDSA signing not supported for this algorithm type".to_string(), + )) + } +} diff --git a/utils/tcp_connection/src/lib.rs b/utils/tcp_connection/src/lib.rs new file mode 100644 index 0000000..6a2e599 --- /dev/null +++ b/utils/tcp_connection/src/lib.rs @@ -0,0 +1,6 @@ +#[allow(dead_code)] +pub mod instance; + +pub mod instance_challenge; + +pub mod error; |
