summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author魏曹先生 <1992414357@qq.com>2025-09-21 18:10:26 +0800
committer魏曹先生 <1992414357@qq.com>2025-09-21 18:10:26 +0800
commitf271ca59abf8b3eba7db1788e38ea9efe739262a (patch)
treedc9bf1f01ab6e6c889ce4eae07a10cd6e2d404b0
parentc730a220232d6343e50aadc6d6c37b308215e401 (diff)
Complete Challenge
-rw-r--r--crates/utils/tcp_connection/src/instance.rs488
-rw-r--r--crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs16
2 files changed, 296 insertions, 208 deletions
diff --git a/crates/utils/tcp_connection/src/instance.rs b/crates/utils/tcp_connection/src/instance.rs
index db36b54..be7c956 100644
--- a/crates/utils/tcp_connection/src/instance.rs
+++ b/crates/utils/tcp_connection/src/instance.rs
@@ -1,7 +1,6 @@
use std::{path::Path, time::Duration};
-use base64::{engine::general_purpose::STANDARD, prelude::*};
-use rand::Rng;
+use rand::{Rng, TryRngCore};
use rsa::{
RsaPrivateKey, RsaPublicKey,
pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
@@ -16,46 +15,77 @@ use tokio::{
use uuid::Uuid;
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
+use ring::rand::SystemRandom;
use ring::signature::{
- self, ECDSA_P256_SHA256_ASN1, ECDSA_P384_SHA384_ASN1, RSA_PKCS1_2048_8192_SHA256,
+ self, ECDSA_P256_SHA256_ASN1, ECDSA_P384_SHA384_ASN1, EcdsaKeyPair, RSA_PKCS1_2048_8192_SHA256,
UnparsedPublicKey,
};
use crate::error::TcpTargetError;
-const CHUNK_SIZE: usize = 8 * 1024;
+const DEFAULT_CHUNK_SIZE: usize = 4096;
+const DEFAULT_TIMEOUT_SECS: u64 = 10;
-pub struct ConnectionInstance {
- stream: TcpStream,
+const ECDSA_P256_SHA256_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm =
+ &signature::ECDSA_P256_SHA256_ASN1_SIGNING;
+const ECDSA_P384_SHA384_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm =
+ &signature::ECDSA_P384_SHA384_ASN1_SIGNING;
+
+#[derive(Debug, Clone)]
+pub struct ConnectionConfig {
+ pub chunk_size: usize,
+ pub timeout_secs: u64,
+ pub enable_crc_validation: bool,
}
-impl From<TcpStream> for ConnectionInstance {
- fn from(value: TcpStream) -> Self {
- Self { stream: value }
+impl Default for ConnectionConfig {
+ fn default() -> Self {
+ Self {
+ chunk_size: DEFAULT_CHUNK_SIZE,
+ timeout_secs: DEFAULT_TIMEOUT_SECS,
+ enable_crc_validation: false,
+ }
}
}
-// Helper trait for reading u64 from TcpStream
-trait ReadU64Ext {
- async fn read_u64(&mut self) -> Result<u64, std::io::Error>;
+pub struct ConnectionInstance {
+ stream: TcpStream,
+ config: ConnectionConfig,
}
-impl ReadU64Ext for TcpStream {
- async fn read_u64(&mut self) -> Result<u64, std::io::Error> {
- let mut buf = [0u8; 8];
- self.read_exact(&mut buf).await?;
- Ok(u64::from_be_bytes(buf))
+impl From<TcpStream> for ConnectionInstance {
+ fn from(stream: TcpStream) -> Self {
+ Self {
+ stream,
+ config: ConnectionConfig::default(),
+ }
}
}
impl ConnectionInstance {
+ /// Create a new ConnectionInstance with custom configuration
+ pub fn with_config(stream: TcpStream, config: ConnectionConfig) -> Self {
+ Self { stream, config }
+ }
+
+ /// Get a reference to the current configuration
+ pub fn config(&self) -> &ConnectionConfig {
+ &self.config
+ }
+
+ /// Get a mutable reference to the current configuration
+ pub fn config_mut(&mut self) -> &mut ConnectionConfig {
+ &mut self.config
+ }
/// Serialize data and write to the target machine
pub async fn write<Data>(&mut self, data: Data) -> Result<(), TcpTargetError>
where
Data: Default + Serialize,
{
let Ok(json_text) = serde_json::to_string(&data) else {
- return Err(TcpTargetError::from("Serialize failed."));
+ return Err(TcpTargetError::Serialization(
+ "Serialize failed.".to_string(),
+ ));
};
Self::write_text(self, json_text).await?;
Ok(())
@@ -67,10 +97,12 @@ impl ConnectionInstance {
Data: Default + serde::de::DeserializeOwned,
{
let Ok(json_text) = Self::read_text(self, buffer_size).await else {
- return Err(TcpTargetError::from("Read failed."));
+ return Err(TcpTargetError::Io("Read failed.".to_string()));
};
let Ok(deser_obj) = serde_json::from_str::<Data>(&json_text) else {
- return Err(TcpTargetError::from("Deserialize failed."));
+ return Err(TcpTargetError::Serialization(
+ "Deserialize failed.".to_string(),
+ ));
};
Ok(deser_obj)
}
@@ -81,7 +113,9 @@ impl ConnectionInstance {
Data: Default + Serialize,
{
let Ok(json_text) = serde_json::to_string(&data) else {
- return Err(TcpTargetError::from("Serialize failed."));
+ return Err(TcpTargetError::Serialization(
+ "Serialize failed.".to_string(),
+ ));
};
Self::write_large_text(self, json_text).await?;
Ok(())
@@ -96,10 +130,12 @@ impl ConnectionInstance {
Data: Default + serde::de::DeserializeOwned,
{
let Ok(json_text) = Self::read_large_text(self, buffer_size).await else {
- return Err(TcpTargetError::from("Read failed."));
+ return Err(TcpTargetError::Io("Read failed.".to_string()));
};
let Ok(deser_obj) = serde_json::from_str::<Data>(&json_text) else {
- return Err(TcpTargetError::from("Deserialize failed."));
+ return Err(TcpTargetError::Serialization(
+ "Deserialize failed.".to_string(),
+ ));
};
Ok(deser_obj)
}
@@ -111,7 +147,7 @@ impl ConnectionInstance {
// Write
match self.stream.write_all(text.as_bytes()).await {
Ok(_) => Ok(()),
- Err(err) => Err(TcpTargetError::from(err.to_string())),
+ Err(err) => Err(TcpTargetError::Io(err.to_string())),
}
}
@@ -128,7 +164,7 @@ impl ConnectionInstance {
let text = String::from_utf8_lossy(&buffer[..n]).to_string();
Ok(text)
}
- Err(err) => Err(TcpTargetError::from(err.to_string())),
+ Err(err) => Err(TcpTargetError::Io(err.to_string())),
}
}
@@ -145,7 +181,7 @@ impl ConnectionInstance {
let chunk = &bytes[offset..];
let written = match self.stream.write(chunk).await {
Ok(n) => n,
- Err(err) => return Err(TcpTargetError::from(err.to_string())),
+ Err(err) => return Err(TcpTargetError::Io(err.to_string())),
};
offset += written;
}
@@ -168,7 +204,7 @@ impl ConnectionInstance {
Ok(n) => {
buffer.extend_from_slice(&chunk_buf[..n]);
}
- Err(err) => return Err(TcpTargetError::from(err.to_string())),
+ Err(err) => return Err(TcpTargetError::Io(err.to_string())),
}
}
@@ -181,59 +217,66 @@ impl ConnectionInstance {
// Validate file
if !path.exists() {
- return Err(TcpTargetError::from(format!(
+ return Err(TcpTargetError::File(format!(
"File not found: {}",
path.display()
)));
}
if path.is_dir() {
- return Err(TcpTargetError::from(format!(
+ return Err(TcpTargetError::File(format!(
"Path is directory: {}",
path.display()
)));
}
// Open file and get metadata
- let mut file = File::open(path)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
- let file_size = file
- .metadata()
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?
- .len();
- if file_size == 0 {
- return Err(TcpTargetError::from("Cannot send empty file"));
- }
+ let mut file = File::open(path).await?;
+ let file_size = file.metadata().await?.len();
+ // Allow empty files - just send the header with size 0
+
+ // Send file header (version + size + crc)
+ self.stream.write_all(&1u64.to_be_bytes()).await?;
+ self.stream.write_all(&file_size.to_be_bytes()).await?;
+
+ // Calculate and send CRC32 if enabled
+ let file_crc = if self.config.enable_crc_validation {
+ let crc32 = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
+ let mut crc_calculator = crc32.digest();
+
+ let mut temp_reader =
+ BufReader::with_capacity(self.config.chunk_size, File::open(path).await?);
+ let mut temp_buffer = vec![0u8; self.config.chunk_size];
+ let mut temp_bytes_read = 0;
+
+ while temp_bytes_read < file_size {
+ let bytes_to_read =
+ (file_size - temp_bytes_read).min(self.config.chunk_size as u64) as usize;
+ temp_reader
+ .read_exact(&mut temp_buffer[..bytes_to_read])
+ .await?;
+ crc_calculator.update(&temp_buffer[..bytes_to_read]);
+ temp_bytes_read += bytes_to_read as u64;
+ }
+
+ crc_calculator.finalize()
+ } else {
+ 0
+ };
- // Send file header (version + size)
- self.stream
- .write_all(&1u64.to_be_bytes())
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
- self.stream
- .write_all(&file_size.to_be_bytes())
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.write_all(&file_crc.to_be_bytes()).await?;
// Transfer file content
- let mut reader = BufReader::with_capacity(CHUNK_SIZE, &mut file);
+ let mut reader = BufReader::with_capacity(self.config.chunk_size, &mut file);
let mut bytes_sent = 0;
while bytes_sent < file_size {
- let buffer = reader
- .fill_buf()
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ let buffer = reader.fill_buf().await?;
if buffer.is_empty() {
break;
}
let chunk_size = buffer.len().min((file_size - bytes_sent) as usize);
- self.stream
- .write_all(&buffer[..chunk_size])
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.write_all(&buffer[..chunk_size]).await?;
reader.consume(chunk_size);
bytes_sent += chunk_size as u64;
@@ -241,26 +284,27 @@ impl ConnectionInstance {
// Verify transfer completion
if bytes_sent != file_size {
- return Err(TcpTargetError::from(format!(
+ return Err(TcpTargetError::File(format!(
"Transfer incomplete: expected {} bytes, sent {} bytes",
file_size, bytes_sent
)));
}
- self.stream
- .flush()
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.flush().await?;
// Wait for receiver confirmation
let mut ack = [0u8; 1];
- tokio::time::timeout(Duration::from_secs(10), self.stream.read_exact(&mut ack))
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ tokio::time::timeout(
+ Duration::from_secs(self.config.timeout_secs),
+ self.stream.read_exact(&mut ack),
+ )
+ .await
+ .map_err(|_| TcpTargetError::Timeout("Ack timeout".to_string()))??;
if ack[0] != 1 {
- return Err(TcpTargetError::from("Receiver verification failed"));
+ return Err(TcpTargetError::Protocol(
+ "Receiver verification failed".to_string(),
+ ));
}
Ok(())
@@ -269,35 +313,42 @@ impl ConnectionInstance {
/// Read file from target machine
pub async fn read_file(&mut self, save_path: impl AsRef<Path>) -> Result<(), TcpTargetError> {
let path = save_path.as_ref();
+ // Create CRC instance at function scope to ensure proper lifetime
+ let crc_instance = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
// Make sure parent directory exists
if let Some(parent) = path.parent() {
if !parent.exists() {
- tokio::fs::create_dir_all(parent)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ tokio::fs::create_dir_all(parent).await?;
}
}
- // Read file header (version + size)
+ // Read file header (version + size + crc)
let mut version_buf = [0u8; 8];
- self.stream
- .read_exact(&mut version_buf)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.read_exact(&mut version_buf).await?;
let version = u64::from_be_bytes(version_buf);
if version != 1 {
- return Err(TcpTargetError::from("Unsupported transfer version"));
+ return Err(TcpTargetError::Protocol(
+ "Unsupported transfer version".to_string(),
+ ));
}
let mut size_buf = [0u8; 8];
- self.stream
- .read_exact(&mut size_buf)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.read_exact(&mut size_buf).await?;
let file_size = u64::from_be_bytes(size_buf);
+
+ let mut expected_crc_buf = [0u8; 4];
+ self.stream.read_exact(&mut expected_crc_buf).await?;
+ let expected_crc = u32::from_be_bytes(expected_crc_buf);
if file_size == 0 {
- return Err(TcpTargetError::from("Cannot receive zero-length file"));
+ // Create empty file and return early
+ let _file = OpenOptions::new()
+ .write(true)
+ .create(true)
+ .truncate(true)
+ .open(path)
+ .await?;
+ return Ok(());
}
// Prepare output file
@@ -306,57 +357,74 @@ impl ConnectionInstance {
.create(true)
.truncate(true)
.open(path)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
- let mut writer = BufWriter::with_capacity(CHUNK_SIZE, file);
+ .await?;
+ let mut writer = BufWriter::with_capacity(self.config.chunk_size, file);
- // Receive file content
- let mut buffer = vec![0u8; CHUNK_SIZE];
+ // Receive file content with CRC calculation if enabled
let mut bytes_received = 0;
+ let mut buffer = vec![0u8; self.config.chunk_size];
+ let mut crc_calculator = if self.config.enable_crc_validation {
+ Some(crc_instance.digest())
+ } else {
+ None
+ };
while bytes_received < file_size {
- let read_size = buffer.len().min((file_size - bytes_received) as usize);
- self.stream
- .read_exact(&mut buffer[..read_size])
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
-
- writer
- .write_all(&buffer[..read_size])
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
- bytes_received += read_size as u64;
+ let bytes_to_read =
+ (file_size - bytes_received).min(self.config.chunk_size as u64) as usize;
+ let chunk = &mut buffer[..bytes_to_read];
+
+ self.stream.read_exact(chunk).await?;
+
+ writer.write_all(chunk).await?;
+
+ // Update CRC if validation is enabled
+ if let Some(ref mut crc) = crc_calculator {
+ crc.update(chunk);
+ }
+
+ bytes_received += bytes_to_read as u64;
+ }
+
+ // Verify transfer completion
+ if bytes_received != file_size {
+ return Err(TcpTargetError::File(format!(
+ "Transfer incomplete: expected {} bytes, received {} bytes",
+ file_size, bytes_received
+ )));
+ }
+
+ writer.flush().await?;
+
+ // Validate CRC if enabled
+ if self.config.enable_crc_validation {
+ if let Some(crc_calculator) = crc_calculator {
+ let actual_crc = crc_calculator.finalize();
+ if actual_crc != expected_crc && expected_crc != 0 {
+ return Err(TcpTargetError::File(format!(
+ "CRC validation failed: expected {:08x}, got {:08x}",
+ expected_crc, actual_crc
+ )));
+ }
+ }
}
// Final flush and sync
- writer
- .flush()
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
- writer
- .into_inner()
- .sync_all()
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ writer.flush().await?;
+ writer.into_inner().sync_all().await?;
// Verify completion
if bytes_received != file_size {
let _ = tokio::fs::remove_file(path).await;
- return Err(TcpTargetError::from(format!(
+ return Err(TcpTargetError::File(format!(
"Transfer incomplete: expected {} bytes, received {} bytes",
file_size, bytes_received
)));
}
// Send confirmation
- self.stream
- .write_all(&[1])
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
- self.stream
- .flush()
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.write_all(&[1u8]).await?;
+ self.stream.flush().await?;
Ok(())
}
@@ -366,36 +434,28 @@ impl ConnectionInstance {
public_key_dir: impl AsRef<Path>,
) -> Result<bool, TcpTargetError> {
// Generate random challenge
- let mut rng = rand::rng();
- let challenge: [u8; 32] = rng.random();
+ let mut challenge = [0u8; 32];
+ rand::rngs::OsRng
+ .try_fill_bytes(&mut challenge)
+ .map_err(|e| {
+ TcpTargetError::Crypto(format!("Failed to generate random challenge: {}", e))
+ })?;
// Send challenge to target
- self.stream
- .write_all(&challenge)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.write_all(&challenge).await?;
// Read signature from target
let mut signature = Vec::new();
let mut signature_len_buf = [0u8; 4];
- self.stream
- .read_exact(&mut signature_len_buf)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.read_exact(&mut signature_len_buf).await?;
let signature_len = u32::from_be_bytes(signature_len_buf) as usize;
signature.resize(signature_len, 0);
- self.stream
- .read_exact(&mut signature)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.read_exact(&mut signature).await?;
// Read UUID from target to identify which public key to use
let mut uuid_buf = [0u8; 16];
- self.stream
- .read_exact(&mut uuid_buf)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.read_exact(&mut uuid_buf).await?;
let user_uuid = Uuid::from_bytes(uuid_buf);
// Load appropriate public key
@@ -404,9 +464,7 @@ impl ConnectionInstance {
return Ok(false);
}
- let public_key_pem = tokio::fs::read_to_string(&public_key_path)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ let public_key_pem = tokio::fs::read_to_string(&public_key_path).await?;
// Try to verify with different key types
let verified = if let Ok(rsa_key) = RsaPublicKey::from_pkcs1_pem(&public_key_pem) {
@@ -415,9 +473,15 @@ impl ConnectionInstance {
} else if let Ok(ed25519_key) =
VerifyingKey::from_bytes(&parse_ed25519_public_key(&public_key_pem))
{
- let sig_bytes: [u8; 64] = signature.as_slice().try_into().unwrap_or([0u8; 64]);
- let sig = Signature::from_bytes(&sig_bytes);
- ed25519_key.verify(&challenge, &sig).is_ok()
+ if signature.len() == 64 {
+ let sig_bytes: [u8; 64] = signature.as_slice().try_into().map_err(|_| {
+ TcpTargetError::Crypto("Invalid signature length for Ed25519".to_string())
+ })?;
+ let sig = Signature::from_bytes(&sig_bytes);
+ ed25519_key.verify(&challenge, &sig).is_ok()
+ } else {
+ false
+ }
} else if let Ok(dsa_key_info) = parse_dsa_public_key(&public_key_pem) {
verify_dsa_signature(&dsa_key_info, &challenge, &signature)
} else {
@@ -430,50 +494,34 @@ impl ConnectionInstance {
pub async fn accept_challenge(
&mut self,
private_key_file: impl AsRef<Path>,
- verify_user_uuid: Uuid,
+ verify_public_key: &str,
) -> Result<bool, TcpTargetError> {
// Read challenge from initiator
let mut challenge = [0u8; 32];
- self.stream
- .read_exact(&mut challenge)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.read_exact(&mut challenge).await?;
// Load private key
- let private_key_pem = tokio::fs::read_to_string(&private_key_file)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ let private_key_pem = tokio::fs::read_to_string(&private_key_file).await?;
// Sign the challenge with supported key types
let signature = if let Ok(rsa_key) = RsaPrivateKey::from_pkcs1_pem(&private_key_pem) {
let padding = rsa::pkcs1v15::Pkcs1v15Sign::new::<sha2::Sha256>();
- rsa_key
- .sign(padding, &challenge)
- .map_err(|e| TcpTargetError::from(e.to_string()))?
+ rsa_key.sign(padding, &challenge)?
} else if let Ok(ed25519_key) = parse_ed25519_private_key(&private_key_pem) {
ed25519_key.sign(&challenge).to_bytes().to_vec()
} else if let Ok(dsa_key_info) = parse_dsa_private_key(&private_key_pem) {
- sign_with_dsa(&dsa_key_info, &challenge)
+ sign_with_dsa(&dsa_key_info, &challenge)?
} else {
return Ok(false);
};
// Send signature length and signature
let signature_len = signature.len() as u32;
- self.stream
- .write_all(&signature_len.to_be_bytes())
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
- self.stream
- .write_all(&signature)
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.write_all(&signature_len.to_be_bytes()).await?;
+ self.stream.write_all(&signature).await?;
// Send UUID for public key identification
- self.stream
- .write_all(verify_user_uuid.as_bytes())
- .await
- .map_err(|e| TcpTargetError::from(e.to_string()))?;
+ self.stream.write_all(verify_public_key.as_bytes()).await?;
Ok(true)
}
@@ -483,15 +531,13 @@ impl ConnectionInstance {
/// Parse Ed25519 public key from PEM format
fn parse_ed25519_public_key(pem: &str) -> [u8; 32] {
- // Simple parsing for Ed25519 public key (assuming raw 32-byte key)
- let lines: Vec<&str> = pem.lines().collect();
+ // Robust parsing for Ed25519 public key using pem crate
let mut key_bytes = [0u8; 32];
- if lines.len() >= 2 && lines[0].contains("PUBLIC KEY") {
- if let Ok(decoded) = STANDARD.decode(lines[1].trim()) {
- if decoded.len() >= 32 {
- key_bytes.copy_from_slice(&decoded[decoded.len() - 32..]);
- }
+ if let Ok(pem_data) = pem::parse(pem) {
+ if pem_data.tag() == "PUBLIC KEY" && pem_data.contents().len() >= 32 {
+ let contents = pem_data.contents();
+ key_bytes.copy_from_slice(&contents[contents.len() - 32..]);
}
}
key_bytes
@@ -499,18 +545,17 @@ fn parse_ed25519_public_key(pem: &str) -> [u8; 32] {
/// Parse Ed25519 private key from PEM format
fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> {
- let lines: Vec<&str> = pem.lines().collect();
-
- if lines.len() >= 2 && lines[0].contains("PRIVATE KEY") {
- if let Ok(decoded) = STANDARD.decode(lines[1].trim()) {
- if decoded.len() >= 32 {
- let mut seed = [0u8; 32];
- seed.copy_from_slice(&decoded[decoded.len() - 32..]);
- return Ok(SigningKey::from_bytes(&seed));
- }
+ if let Ok(pem_data) = pem::parse(pem) {
+ if pem_data.tag() == "PRIVATE KEY" && pem_data.contents().len() >= 32 {
+ let contents = pem_data.contents();
+ let mut seed = [0u8; 32];
+ seed.copy_from_slice(&contents[contents.len() - 32..]);
+ return Ok(SigningKey::from_bytes(&seed));
}
}
- Err(TcpTargetError::from("Invalid Ed25519 private key format"))
+ Err(TcpTargetError::Crypto(
+ "Invalid Ed25519 private key format".to_string(),
+ ))
}
// Helper functions for DSA support
@@ -519,23 +564,30 @@ fn parse_ed25519_private_key(pem: &str) -> Result<SigningKey, TcpTargetError> {
fn parse_dsa_public_key(
pem: &str,
) -> Result<(&'static dyn signature::VerificationAlgorithm, Vec<u8>), TcpTargetError> {
- let lines: Vec<&str> = pem.lines().collect();
+ if let Ok(pem_data) = pem::parse(pem) {
+ let contents = pem_data.contents().to_vec();
- if lines.len() >= 2 {
- if let Ok(decoded) = STANDARD.decode(lines[1].trim()) {
- // Try different DSA algorithms
- if pem.contains("ECDSA") || pem.contains("ecdsa") {
+ // Try different DSA algorithms based on PEM tag
+ match pem_data.tag() {
+ "EC PUBLIC KEY" | "PUBLIC KEY" if pem.contains("ECDSA") || pem.contains("ecdsa") => {
if pem.contains("P-256") {
- return Ok((&ECDSA_P256_SHA256_ASN1, decoded));
+ return Ok((&ECDSA_P256_SHA256_ASN1, contents));
} else if pem.contains("P-384") {
- return Ok((&ECDSA_P384_SHA384_ASN1, decoded));
+ return Ok((&ECDSA_P384_SHA384_ASN1, contents));
}
}
- // Default to RSA if no specific algorithm detected
- return Ok((&RSA_PKCS1_2048_8192_SHA256, decoded));
+ "RSA PUBLIC KEY" | "PUBLIC KEY" => {
+ return Ok((&RSA_PKCS1_2048_8192_SHA256, contents));
+ }
+ _ => {}
}
+
+ // Default to RSA for unknown types
+ return Ok((&RSA_PKCS1_2048_8192_SHA256, contents));
}
- Err(TcpTargetError::from("Invalid DSA public key format"))
+ Err(TcpTargetError::Crypto(
+ "Invalid DSA public key format".to_string(),
+ ))
}
/// Parse DSA private key information from PEM
@@ -559,13 +611,51 @@ fn verify_dsa_signature(
/// Sign with DSA (simplified - in practice this would use proper private key operations)
fn sign_with_dsa(
- _algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec<u8>),
+ algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec<u8>),
message: &[u8],
-) -> Vec<u8> {
- // Note: This is a simplified implementation. In a real scenario,
- // you would use proper private key signing operations with ring or other crypto library.
- // For now, we'll return a dummy signature for demonstration.
- let mut signature = vec![0u8; 64];
- signature[..32].copy_from_slice(&message[..32.min(message.len())]);
- signature
+) -> Result<Vec<u8>, TcpTargetError> {
+ let (algorithm, key_bytes) = algorithm_and_key;
+
+ // Handle different DSA/ECDSA algorithms by comparing algorithm identifiers
+ // Since we can't directly compare trait objects, we use pointer comparison
+ let algorithm_ptr = algorithm as *const _ as *const ();
+ let ecdsa_p256_ptr = &ECDSA_P256_SHA256_ASN1 as *const _ as *const ();
+ let ecdsa_p384_ptr = &ECDSA_P384_SHA384_ASN1 as *const _ as *const ();
+
+ if algorithm_ptr == ecdsa_p256_ptr {
+ let key_pair = EcdsaKeyPair::from_pkcs8(
+ ECDSA_P256_SHA256_ASN1_SIGNING,
+ &key_bytes,
+ &SystemRandom::new(),
+ )
+ .map_err(|e| {
+ TcpTargetError::Crypto(format!("Failed to create ECDSA P-256 key pair: {}", e))
+ })?;
+
+ let signature = key_pair
+ .sign(&SystemRandom::new(), message)
+ .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-256 signing failed: {}", e)))?;
+
+ Ok(signature.as_ref().to_vec())
+ } else if algorithm_ptr == ecdsa_p384_ptr {
+ let key_pair = EcdsaKeyPair::from_pkcs8(
+ ECDSA_P384_SHA384_ASN1_SIGNING,
+ &key_bytes,
+ &SystemRandom::new(),
+ )
+ .map_err(|e| {
+ TcpTargetError::Crypto(format!("Failed to create ECDSA P-384 key pair: {}", e))
+ })?;
+
+ let signature = key_pair
+ .sign(&SystemRandom::new(), message)
+ .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-384 signing failed: {}", e)))?;
+
+ Ok(signature.as_ref().to_vec())
+ } else {
+ // RSA or unsupported algorithm
+ Err(TcpTargetError::Unsupported(
+ "DSA/ECDSA signing not supported for this algorithm type".to_string(),
+ ))
+ }
}
diff --git a/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs b/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs
index ac57451..57d3819 100644
--- a/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs
+++ b/crates/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs
@@ -1,7 +1,4 @@
-use std::{
- env::{current_dir, set_current_dir},
- time::Duration,
-};
+use std::{env::current_dir, time::Duration};
use tcp_connection::{
handle::{ClientHandle, ServerHandle},
@@ -18,7 +15,9 @@ impl ClientHandle<ExampleChallengeServerHandle> for ExampleChallengeClientHandle
mut instance: ConnectionInstance,
) -> impl std::future::Future<Output = ()> + Send + Sync {
async move {
- // TODO :: Complete the implementation
+ let key = current_dir().unwrap().join("res").join("test_key");
+ let result = instance.accept_challenge(key, "test_key").await.unwrap();
+ assert_eq!(true, result);
}
}
}
@@ -30,7 +29,9 @@ impl ServerHandle<ExampleChallengeClientHandle> for ExampleChallengeServerHandle
mut instance: ConnectionInstance,
) -> impl std::future::Future<Output = ()> + Send + Sync {
async move {
- // TODO :: Complete the implementation
+ let key_dir = current_dir().unwrap().join("res");
+ let result = instance.challenge(key_dir).await.unwrap();
+ assert_eq!(true, result);
}
}
}
@@ -39,9 +40,6 @@ impl ServerHandle<ExampleChallengeClientHandle> for ExampleChallengeServerHandle
async fn test_connection_with_challenge_handle() -> Result<(), std::io::Error> {
let host = "localhost";
- // Enter temp directory
- set_current_dir(current_dir().unwrap().join(".temp/"))?;
-
// Server setup
let Ok(server_target) = TcpServerTarget::<
ExampleChallengeClientHandle,