From c5fb22694e95f12c24b8d8af76999be7aea3fcec Mon Sep 17 00:00:00 2001 From: 魏曹先生 <1992414357@qq.com> Date: Mon, 12 Jan 2026 04:28:28 +0800 Subject: Reorganize crate structure and move documentation files --- utils/cfg_file/Cargo.toml | 23 + utils/cfg_file/cfg_file_derive/Cargo.toml | 11 + utils/cfg_file/cfg_file_derive/src/lib.rs | 130 +++++ utils/cfg_file/cfg_file_test/Cargo.toml | 9 + utils/cfg_file/cfg_file_test/src/lib.rs | 95 ++++ utils/cfg_file/src/config.rs | 263 ++++++++++ utils/cfg_file/src/lib.rs | 7 + utils/data_struct/Cargo.toml | 10 + utils/data_struct/src/bi_map.rs | 239 +++++++++ utils/data_struct/src/data_sort.rs | 232 +++++++++ utils/data_struct/src/lib.rs | 2 + utils/sha1_hash/Cargo.toml | 9 + utils/sha1_hash/res/story.txt | 48 ++ utils/sha1_hash/res/story_crlf.sha1 | 1 + utils/sha1_hash/res/story_lf.sha1 | 1 + utils/sha1_hash/src/lib.rs | 257 ++++++++++ utils/string_proc/Cargo.toml | 7 + utils/string_proc/src/format_path.rs | 111 +++++ utils/string_proc/src/format_processer.rs | 132 +++++ utils/string_proc/src/lib.rs | 50 ++ utils/string_proc/src/macros.rs | 63 +++ utils/string_proc/src/simple_processer.rs | 15 + utils/tcp_connection/Cargo.toml | 28 ++ utils/tcp_connection/src/error.rs | 122 +++++ utils/tcp_connection/src/instance.rs | 542 +++++++++++++++++++++ utils/tcp_connection/src/instance_challenge.rs | 311 ++++++++++++ utils/tcp_connection/src/lib.rs | 6 + .../tcp_connection/tcp_connection_test/Cargo.toml | 9 + .../res/image/test_transfer.png | Bin 0 -> 1001369 bytes .../tcp_connection_test/res/key/test_key.pem | 13 + .../res/key/test_key_private.pem | 51 ++ .../res/key/wrong_key_private.pem | 52 ++ .../tcp_connection/tcp_connection_test/src/lib.rs | 17 + .../tcp_connection_test/src/test_challenge.rs | 160 ++++++ .../tcp_connection_test/src/test_connection.rs | 78 +++ .../tcp_connection_test/src/test_file_transfer.rs | 94 ++++ .../tcp_connection_test/src/test_msgpack.rs | 103 ++++ .../src/test_tcp_target_build.rs | 32 ++ .../tcp_connection_test/src/test_utils.rs | 4 + .../tcp_connection_test/src/test_utils/handle.rs | 11 + .../tcp_connection_test/src/test_utils/target.rs | 201 ++++++++ .../src/test_utils/target_configure.rs | 53 ++ .../src/test_utils/target_connection.rs | 89 ++++ 43 files changed, 3691 insertions(+) create mode 100644 utils/cfg_file/Cargo.toml create mode 100644 utils/cfg_file/cfg_file_derive/Cargo.toml create mode 100644 utils/cfg_file/cfg_file_derive/src/lib.rs create mode 100644 utils/cfg_file/cfg_file_test/Cargo.toml create mode 100644 utils/cfg_file/cfg_file_test/src/lib.rs create mode 100644 utils/cfg_file/src/config.rs create mode 100644 utils/cfg_file/src/lib.rs create mode 100644 utils/data_struct/Cargo.toml create mode 100644 utils/data_struct/src/bi_map.rs create mode 100644 utils/data_struct/src/data_sort.rs create mode 100644 utils/data_struct/src/lib.rs create mode 100644 utils/sha1_hash/Cargo.toml create mode 100644 utils/sha1_hash/res/story.txt create mode 100644 utils/sha1_hash/res/story_crlf.sha1 create mode 100644 utils/sha1_hash/res/story_lf.sha1 create mode 100644 utils/sha1_hash/src/lib.rs create mode 100644 utils/string_proc/Cargo.toml create mode 100644 utils/string_proc/src/format_path.rs create mode 100644 utils/string_proc/src/format_processer.rs create mode 100644 utils/string_proc/src/lib.rs create mode 100644 utils/string_proc/src/macros.rs create mode 100644 utils/string_proc/src/simple_processer.rs create mode 100644 utils/tcp_connection/Cargo.toml create mode 100644 utils/tcp_connection/src/error.rs create mode 100644 utils/tcp_connection/src/instance.rs create mode 100644 utils/tcp_connection/src/instance_challenge.rs create mode 100644 utils/tcp_connection/src/lib.rs create mode 100644 utils/tcp_connection/tcp_connection_test/Cargo.toml create mode 100644 utils/tcp_connection/tcp_connection_test/res/image/test_transfer.png create mode 100644 utils/tcp_connection/tcp_connection_test/res/key/test_key.pem create mode 100644 utils/tcp_connection/tcp_connection_test/res/key/test_key_private.pem create mode 100644 utils/tcp_connection/tcp_connection_test/res/key/wrong_key_private.pem create mode 100644 utils/tcp_connection/tcp_connection_test/src/lib.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_challenge.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_connection.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_file_transfer.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_tcp_target_build.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_utils.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_utils/handle.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_utils/target.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_utils/target_configure.rs create mode 100644 utils/tcp_connection/tcp_connection_test/src/test_utils/target_connection.rs (limited to 'utils') diff --git a/utils/cfg_file/Cargo.toml b/utils/cfg_file/Cargo.toml new file mode 100644 index 0000000..0685329 --- /dev/null +++ b/utils/cfg_file/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "cfg_file" +edition = "2024" +version.workspace = true + +[features] +default = ["derive"] +derive = [] + +[dependencies] +cfg_file_derive = { path = "cfg_file_derive" } + +# Async +tokio = { version = "1.48.0", features = ["full"] } +async-trait = "0.1.89" + +# Serialization +serde = { version = "1.0.228", features = ["derive"] } +serde_yaml = "0.9.34" +serde_json = "1.0.145" +ron = "0.11.0" +toml = "0.9.8" +bincode2 = "2.0.1" diff --git a/utils/cfg_file/cfg_file_derive/Cargo.toml b/utils/cfg_file/cfg_file_derive/Cargo.toml new file mode 100644 index 0000000..ce5e77f --- /dev/null +++ b/utils/cfg_file/cfg_file_derive/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "cfg_file_derive" +edition = "2024" +version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2.0", features = ["full", "extra-traits"] } +quote = "1.0" diff --git a/utils/cfg_file/cfg_file_derive/src/lib.rs b/utils/cfg_file/cfg_file_derive/src/lib.rs new file mode 100644 index 0000000..e916311 --- /dev/null +++ b/utils/cfg_file/cfg_file_derive/src/lib.rs @@ -0,0 +1,130 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; +use syn::parse::ParseStream; +use syn::{Attribute, DeriveInput, Expr, parse_macro_input}; +/// # Macro - ConfigFile +/// +/// ## Usage +/// +/// Use `#[derive(ConfigFile)]` to derive the ConfigFile trait for a struct +/// +/// Specify the default storage path via `#[cfg_file(path = "...")]` +/// +/// ## About the `cfg_file` attribute macro +/// +/// Use `#[cfg_file(path = "string")]` to specify the configuration file path +/// +/// Or use `#[cfg_file(path = constant_expression)]` to specify the configuration file path +/// +/// ## Path Rules +/// +/// Paths starting with `"./"`: relative to the current working directory +/// +/// Other paths: treated as absolute paths +/// +/// When no path is specified: use the struct name + ".json" as the default filename (e.g., `my_struct.json`) +/// +/// ## Example +/// ```ignore +/// #[derive(ConfigFile)] +/// #[cfg_file(path = "./config.json")] +/// struct AppConfig; +/// ``` +#[proc_macro_derive(ConfigFile, attributes(cfg_file))] +pub fn derive_config_file(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + + // Process 'cfg_file' + let path_expr = match find_cfg_file_path(&input.attrs) { + Some(PathExpr::StringLiteral(path)) => { + if let Some(path_str) = path.strip_prefix("./") { + quote! { + std::env::current_dir()?.join(#path_str) + } + } else { + // Using Absolute Path + quote! { + std::path::PathBuf::from(#path) + } + } + } + Some(PathExpr::PathExpression(path_expr)) => { + // For path expressions (constants), generate code that references the constant + quote! { + std::path::PathBuf::from(#path_expr) + } + } + None => { + let default_file = to_snake_case(&name.to_string()) + ".json"; + quote! { + std::env::current_dir()?.join(#default_file) + } + } + }; + + let expanded = quote! { + impl cfg_file::config::ConfigFile for #name { + type DataType = #name; + + fn default_path() -> Result { + Ok(#path_expr) + } + } + }; + + TokenStream::from(expanded) +} + +enum PathExpr { + StringLiteral(String), + PathExpression(syn::Expr), +} + +fn find_cfg_file_path(attrs: &[Attribute]) -> Option { + for attr in attrs { + if attr.path().is_ident("cfg_file") { + let parser = |meta: ParseStream| { + let path_meta: syn::MetaNameValue = meta.parse()?; + if path_meta.path.is_ident("path") { + match &path_meta.value { + // String literal case: path = "./vault.toml" + Expr::Lit(expr_lit) if matches!(expr_lit.lit, syn::Lit::Str(_)) => { + if let syn::Lit::Str(lit_str) = &expr_lit.lit { + return Ok(PathExpr::StringLiteral(lit_str.value())); + } + } + // Path expression case: path = SERVER_FILE_VAULT or crate::constants::SERVER_FILE_VAULT + expr @ (Expr::Path(_) | Expr::Macro(_)) => { + return Ok(PathExpr::PathExpression(expr.clone())); + } + _ => {} + } + } + Err(meta.error("expected `path = \"...\"` or `path = CONSTANT`")) + }; + + if let Ok(path_expr) = attr.parse_args_with(parser) { + return Some(path_expr); + } + } + } + None +} + +fn to_snake_case(s: &str) -> String { + let mut snake = String::new(); + for (i, c) in s.chars().enumerate() { + if c.is_uppercase() { + if i != 0 { + snake.push('_'); + } + snake.push(c.to_ascii_lowercase()); + } else { + snake.push(c); + } + } + snake +} diff --git a/utils/cfg_file/cfg_file_test/Cargo.toml b/utils/cfg_file/cfg_file_test/Cargo.toml new file mode 100644 index 0000000..5db1010 --- /dev/null +++ b/utils/cfg_file/cfg_file_test/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "cfg_file_test" +version = "0.1.0" +edition = "2024" + +[dependencies] +cfg_file = { path = "../../cfg_file", features = ["default"] } +tokio = { version = "1.48.0", features = ["full"] } +serde = { version = "1.0.228", features = ["derive"] } diff --git a/utils/cfg_file/cfg_file_test/src/lib.rs b/utils/cfg_file/cfg_file_test/src/lib.rs new file mode 100644 index 0000000..f70d00d --- /dev/null +++ b/utils/cfg_file/cfg_file_test/src/lib.rs @@ -0,0 +1,95 @@ +#[cfg(test)] +mod test_cfg_file { + use cfg_file::ConfigFile; + use cfg_file::config::ConfigFile; + use serde::{Deserialize, Serialize}; + use std::collections::HashMap; + + #[derive(ConfigFile, Deserialize, Serialize, Default)] + #[cfg_file(path = "./.temp/example_cfg.toml")] + struct ExampleConfig { + name: String, + age: i32, + hobby: Vec, + secret: HashMap, + } + + #[derive(ConfigFile, Deserialize, Serialize, Default)] + #[cfg_file(path = "./.temp/example_bincode.bcfg")] + struct ExampleBincodeConfig { + name: String, + age: i32, + hobby: Vec, + secret: HashMap, + } + + #[tokio::test] + async fn test_config_file_serialization() { + let mut example = ExampleConfig { + name: "Weicao".to_string(), + age: 22, + hobby: ["Programming", "Painting"] + .iter() + .map(|m| m.to_string()) + .collect(), + secret: HashMap::new(), + }; + let secret_no_comments = + "Actually, I'm really too lazy to write comments, documentation, and unit tests."; + example + .secret + .entry("No comments".to_string()) + .insert_entry(secret_no_comments.to_string()); + + let secret_peek = "Of course, it's peeking at you who's reading the source code."; + example + .secret + .entry("Peek".to_string()) + .insert_entry(secret_peek.to_string()); + + ExampleConfig::write(&example).await.unwrap(); // Write to default path. + + // Read from default path. + let read_cfg = ExampleConfig::read().await.unwrap(); + assert_eq!(read_cfg.name, "Weicao"); + assert_eq!(read_cfg.age, 22); + assert_eq!(read_cfg.hobby, vec!["Programming", "Painting"]); + assert_eq!(read_cfg.secret["No comments"], secret_no_comments); + assert_eq!(read_cfg.secret["Peek"], secret_peek); + } + + #[tokio::test] + async fn test_bincode_config_file_serialization() { + let mut example = ExampleBincodeConfig { + name: "Weicao".to_string(), + age: 22, + hobby: ["Programming", "Painting"] + .iter() + .map(|m| m.to_string()) + .collect(), + secret: HashMap::new(), + }; + let secret_no_comments = + "Actually, I'm really too lazy to write comments, documentation, and unit tests."; + example + .secret + .entry("No comments".to_string()) + .insert_entry(secret_no_comments.to_string()); + + let secret_peek = "Of course, it's peeking at you who's reading the source code."; + example + .secret + .entry("Peek".to_string()) + .insert_entry(secret_peek.to_string()); + + ExampleBincodeConfig::write(&example).await.unwrap(); // Write to default path. + + // Read from default path. + let read_cfg = ExampleBincodeConfig::read().await.unwrap(); + assert_eq!(read_cfg.name, "Weicao"); + assert_eq!(read_cfg.age, 22); + assert_eq!(read_cfg.hobby, vec!["Programming", "Painting"]); + assert_eq!(read_cfg.secret["No comments"], secret_no_comments); + assert_eq!(read_cfg.secret["Peek"], secret_peek); + } +} diff --git a/utils/cfg_file/src/config.rs b/utils/cfg_file/src/config.rs new file mode 100644 index 0000000..d3f5477 --- /dev/null +++ b/utils/cfg_file/src/config.rs @@ -0,0 +1,263 @@ +use async_trait::async_trait; +use bincode2; +use ron; +use serde::{Deserialize, Serialize}; +use std::{ + borrow::Cow, + env::current_dir, + io::Error, + path::{Path, PathBuf}, +}; +use tokio::{fs, io::AsyncReadExt}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ConfigFormat { + Yaml, + Toml, + Ron, + Json, + Bincode, +} + +impl ConfigFormat { + fn from_filename(filename: &str) -> Option { + if filename.ends_with(".yaml") || filename.ends_with(".yml") { + Some(Self::Yaml) + } else if filename.ends_with(".toml") || filename.ends_with(".tom") { + Some(Self::Toml) + } else if filename.ends_with(".ron") { + Some(Self::Ron) + } else if filename.ends_with(".json") { + Some(Self::Json) + } else if filename.ends_with(".bcfg") { + Some(Self::Bincode) + } else { + None + } + } +} + +/// # Trait - ConfigFile +/// +/// Used to implement more convenient persistent storage functionality for structs +/// +/// This trait requires the struct to implement Default and serde's Serialize and Deserialize traits +/// +/// ## Implementation +/// +/// ```ignore +/// // Your struct +/// #[derive(Default, Serialize, Deserialize)] +/// struct YourData; +/// +/// impl ConfigFile for YourData { +/// type DataType = YourData; +/// +/// // Specify default path +/// fn default_path() -> Result { +/// Ok(current_dir()?.join("data.json")) +/// } +/// } +/// ``` +/// +/// > **Using derive macro** +/// > +/// > We provide the derive macro `#[derive(ConfigFile)]` +/// > +/// > You can implement this trait more quickly, please check the module cfg_file::cfg_file_derive +/// +#[async_trait] +pub trait ConfigFile: Serialize + for<'a> Deserialize<'a> + Default { + type DataType: Serialize + for<'a> Deserialize<'a> + Default + Send + Sync; + + fn default_path() -> Result; + + /// # Read from default path + /// + /// Read data from the path specified by default_path() + /// + /// ```ignore + /// fn main() -> Result<(), std::io::Error> { + /// let data = YourData::read().await?; + /// } + /// ``` + async fn read() -> Result + where + Self: Sized + Send + Sync, + { + let path = Self::default_path()?; + Self::read_from(path).await + } + + /// # Read from the given path + /// + /// Read data from the path specified by the path parameter + /// + /// ```ignore + /// fn main() -> Result<(), std::io::Error> { + /// let data_path = current_dir()?.join("data.json"); + /// let data = YourData::read_from(data_path).await?; + /// } + /// ``` + async fn read_from(path: impl AsRef + Send) -> Result + where + Self: Sized + Send + Sync, + { + let path = path.as_ref(); + let cwd = current_dir()?; + let file_path = cwd.join(path); + + // Check if file exists + if fs::metadata(&file_path).await.is_err() { + return Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Config file not found", + )); + } + + // Determine file format first + let format = file_path + .file_name() + .and_then(|name| name.to_str()) + .and_then(ConfigFormat::from_filename) + .unwrap_or(ConfigFormat::Bincode); // Default to Bincode + + // Deserialize based on format + let result = match format { + ConfigFormat::Yaml => { + let mut file = fs::File::open(&file_path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + serde_yaml::from_str(&contents) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))? + } + ConfigFormat::Toml => { + let mut file = fs::File::open(&file_path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + toml::from_str(&contents) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))? + } + ConfigFormat::Ron => { + let mut file = fs::File::open(&file_path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + ron::from_str(&contents) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))? + } + ConfigFormat::Json => { + let mut file = fs::File::open(&file_path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + serde_json::from_str(&contents) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))? + } + ConfigFormat::Bincode => { + // For Bincode, we need to read the file as bytes directly + let bytes = fs::read(&file_path).await?; + bincode2::deserialize(&bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))? + } + }; + + Ok(result) + } + + /// # Write to default path + /// + /// Write data to the path specified by default_path() + /// + /// ```ignore + /// fn main() -> Result<(), std::io::Error> { + /// let data = YourData::default(); + /// YourData::write(&data).await?; + /// } + /// ``` + async fn write(val: &Self::DataType) -> Result<(), std::io::Error> + where + Self: Sized + Send + Sync, + { + let path = Self::default_path()?; + Self::write_to(val, path).await + } + /// # Write to the given path + /// + /// Write data to the path specified by the path parameter + /// + /// ```ignore + /// fn main() -> Result<(), std::io::Error> { + /// let data = YourData::default(); + /// let data_path = current_dir()?.join("data.json"); + /// YourData::write_to(&data, data_path).await?; + /// } + /// ``` + async fn write_to( + val: &Self::DataType, + path: impl AsRef + Send, + ) -> Result<(), std::io::Error> + where + Self: Sized + Send + Sync, + { + let path = path.as_ref(); + + if let Some(parent) = path.parent() + && !parent.exists() + { + tokio::fs::create_dir_all(parent).await?; + } + + let cwd = current_dir()?; + let file_path = cwd.join(path); + + // Determine file format + let format = file_path + .file_name() + .and_then(|name| name.to_str()) + .and_then(ConfigFormat::from_filename) + .unwrap_or(ConfigFormat::Bincode); // Default to Bincode + + match format { + ConfigFormat::Yaml => { + let contents = serde_yaml::to_string(val) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + fs::write(&file_path, contents).await? + } + ConfigFormat::Toml => { + let contents = toml::to_string(val) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + fs::write(&file_path, contents).await? + } + ConfigFormat::Ron => { + let mut pretty_config = ron::ser::PrettyConfig::new(); + pretty_config.new_line = Cow::from("\n"); + pretty_config.indentor = Cow::from(" "); + + let contents = ron::ser::to_string_pretty(val, pretty_config) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + fs::write(&file_path, contents).await? + } + ConfigFormat::Json => { + let contents = serde_json::to_string(val) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + fs::write(&file_path, contents).await? + } + ConfigFormat::Bincode => { + let bytes = bincode2::serialize(val) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + fs::write(&file_path, bytes).await? + } + } + Ok(()) + } + + /// Check if the file returned by `default_path` exists + fn exist() -> bool + where + Self: Sized + Send + Sync, + { + let Ok(path) = Self::default_path() else { + return false; + }; + path.exists() + } +} diff --git a/utils/cfg_file/src/lib.rs b/utils/cfg_file/src/lib.rs new file mode 100644 index 0000000..72246e7 --- /dev/null +++ b/utils/cfg_file/src/lib.rs @@ -0,0 +1,7 @@ +#[cfg(feature = "derive")] +extern crate cfg_file_derive; + +#[cfg(feature = "derive")] +pub use cfg_file_derive::*; + +pub mod config; diff --git a/utils/data_struct/Cargo.toml b/utils/data_struct/Cargo.toml new file mode 100644 index 0000000..e8caa6e --- /dev/null +++ b/utils/data_struct/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "data_struct" +edition = "2024" +version.workspace = true + +[features] + +[dependencies] +serde = { version = "1.0.228", features = ["derive"] } +ahash = "0.8.12" diff --git a/utils/data_struct/src/bi_map.rs b/utils/data_struct/src/bi_map.rs new file mode 100644 index 0000000..c21a9c8 --- /dev/null +++ b/utils/data_struct/src/bi_map.rs @@ -0,0 +1,239 @@ +use ahash::AHasher; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::hash::{BuildHasherDefault, Hash}; + +type FastHashMap = HashMap>; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BiMap +where + A: Eq + Hash + Clone, + B: Eq + Hash + Clone, +{ + #[serde(flatten)] + a_to_b: FastHashMap, + #[serde(skip)] + b_to_a: FastHashMap, +} + +pub struct Entry<'a, A, B> +where + A: Eq + Hash + Clone, + B: Eq + Hash + Clone, +{ + bimap: &'a mut BiMap, + key: A, + value: Option, +} + +impl BiMap +where + A: Eq + Hash + Clone, + B: Eq + Hash + Clone, +{ + pub fn new() -> Self { + Self { + a_to_b: FastHashMap::default(), + b_to_a: FastHashMap::default(), + } + } + + pub fn entry(&mut self, a: A) -> Entry<'_, A, B> { + let value = self.a_to_b.get(&a).cloned(); + Entry { + bimap: self, + key: a, + value, + } + } + + #[inline(always)] + pub fn insert(&mut self, a: A, b: B) { + if let Some(old_b) = self.a_to_b.insert(a.clone(), b.clone()) { + self.b_to_a.remove(&old_b); + } + if let Some(old_a) = self.b_to_a.insert(b.clone(), a.clone()) { + self.a_to_b.remove(&old_a); + } + } + + #[inline(always)] + pub fn get_by_a(&self, key: &A) -> Option<&B> { + self.a_to_b.get(key) + } + + #[inline(always)] + pub fn get_by_b(&self, key: &B) -> Option<&A> { + self.b_to_a.get(key) + } + + pub fn remove_by_a(&mut self, key: &A) -> Option<(A, B)> { + if let Some(b) = self.get_by_a(key).cloned() { + let a = self.get_by_b(&b).cloned().unwrap(); + self.a_to_b.remove(key); + self.b_to_a.remove(&b); + Some((a, b)) + } else { + None + } + } + + pub fn remove_by_b(&mut self, key: &B) -> Option<(A, B)> { + if let Some(a) = self.get_by_b(key).cloned() { + let b = self.get_by_a(&a).cloned().unwrap(); + self.b_to_a.remove(key); + self.a_to_b.remove(&a); + Some((a, b)) + } else { + None + } + } + + pub fn reserve(&mut self, additional: usize) { + self.a_to_b.reserve(additional); + self.b_to_a.reserve(additional); + } + + pub fn len(&self) -> usize { + self.a_to_b.len() + } + + pub fn is_empty(&self) -> bool { + self.a_to_b.is_empty() + } + + pub fn clear(&mut self) { + self.a_to_b.clear(); + self.b_to_a.clear(); + } + + pub fn contains_a(&self, key: &A) -> bool { + self.a_to_b.contains_key(key) + } + + pub fn contains_b(&self, key: &B) -> bool { + self.b_to_a.contains_key(key) + } + + pub fn keys_a(&self) -> impl Iterator { + self.a_to_b.keys() + } + + pub fn keys_b(&self) -> impl Iterator { + self.b_to_a.keys() + } + + pub fn iter_a_to_b(&self) -> impl Iterator { + self.a_to_b.iter() + } + + pub fn iter_b_to_a(&self) -> impl Iterator { + self.b_to_a.iter() + } +} + +impl<'a, A, B> Entry<'a, A, B> +where + A: Eq + Hash + Clone, + B: Eq + Hash + Clone, +{ + pub fn and_modify(mut self, f: F) -> Self + where + F: FnOnce(&mut B), + { + if let Some(ref mut value) = self.value { + f(value); + } + self + } + + pub fn or_insert(self, default: B) -> Result<&'a mut B, &'static str> { + self.or_insert_with(|| default) + } + + pub fn or_insert_with(mut self, default: F) -> Result<&'a mut B, &'static str> + where + F: FnOnce() -> B, + { + if self.value.is_none() { + self.value = Some(default()); + } + + let value = self.value.as_ref().ok_or("Value is None")?.clone(); + self.bimap.insert(self.key.clone(), value); + + self.bimap + .a_to_b + .get_mut(&self.key) + .ok_or("Key not found in a_to_b map") + } +} + +impl Default for BiMap +where + A: Eq + Hash + Clone, + B: Eq + Hash + Clone, +{ + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bimap_basic_operations() { + let mut bimap = BiMap::new(); + bimap.insert("key1", "value1"); + + assert_eq!(bimap.get_by_a(&"key1"), Some(&"value1")); + assert_eq!(bimap.get_by_b(&"value1"), Some(&"key1")); + assert!(bimap.contains_a(&"key1")); + assert!(bimap.contains_b(&"value1")); + } + + #[test] + fn test_bimap_remove() { + let mut bimap = BiMap::new(); + bimap.insert(1, "one"); + + assert_eq!(bimap.remove_by_a(&1), Some((1, "one"))); + assert!(bimap.is_empty()); + } + + #[test] + fn test_bimap_entry() { + let mut bimap = BiMap::new(); + bimap.entry("key1").or_insert("value1").unwrap(); + + assert_eq!(bimap.get_by_a(&"key1"), Some(&"value1")); + } + + #[test] + fn test_bimap_iterators() { + let mut bimap = BiMap::new(); + bimap.insert(1, "one"); + bimap.insert(2, "two"); + + let a_keys: Vec<_> = bimap.keys_a().collect(); + assert!(a_keys.contains(&&1) && a_keys.contains(&&2)); + + let b_keys: Vec<_> = bimap.keys_b().collect(); + assert!(b_keys.contains(&&"one") && b_keys.contains(&&"two")); + } + + #[test] + fn test_bimap_duplicate_insert() { + let mut bimap = BiMap::new(); + bimap.insert(1, "one"); + bimap.insert(1, "new_one"); + bimap.insert(2, "one"); + + assert_eq!(bimap.get_by_a(&1), Some(&"new_one")); + assert_eq!(bimap.get_by_b(&"one"), Some(&2)); + assert_eq!(bimap.get_by_a(&2), Some(&"one")); + } +} diff --git a/utils/data_struct/src/data_sort.rs b/utils/data_struct/src/data_sort.rs new file mode 100644 index 0000000..2c7a452 --- /dev/null +++ b/utils/data_struct/src/data_sort.rs @@ -0,0 +1,232 @@ +/// Quick sort a slice with a custom comparison function +/// +/// # Arguments +/// * `arr` - The mutable slice to be sorted +/// * `inverse` - Sort direction: true for descending, false for ascending +/// * `compare` - Comparison function that returns -1, 0, or 1 indicating the relative order of two elements +pub fn quick_sort_with_cmp(arr: &mut [T], inverse: bool, compare: F) +where + F: Fn(&T, &T) -> i32, +{ + quick_sort_with_cmp_helper(arr, inverse, &compare); +} + +/// Quick sort for types that implement the PartialOrd trait +/// +/// # Arguments +/// * `arr` - The mutable slice to be sorted +/// * `inverse` - Sort direction: true for descending, false for ascending +pub fn quick_sort(arr: &mut [T], inverse: bool) { + quick_sort_with_cmp(arr, inverse, |a, b| { + if a < b { + -1 + } else if a > b { + 1 + } else { + 0 + } + }); +} + +fn quick_sort_with_cmp_helper(arr: &mut [T], inverse: bool, compare: &F) +where + F: Fn(&T, &T) -> i32, +{ + if arr.len() <= 1 { + return; + } + + let pivot_index = partition_with_cmp(arr, inverse, compare); + let (left, right) = arr.split_at_mut(pivot_index); + + quick_sort_with_cmp_helper(left, inverse, compare); + quick_sort_with_cmp_helper(&mut right[1..], inverse, compare); +} + +fn partition_with_cmp(arr: &mut [T], inverse: bool, compare: &F) -> usize +where + F: Fn(&T, &T) -> i32, +{ + let len = arr.len(); + let pivot_index = len / 2; + + arr.swap(pivot_index, len - 1); + + let mut i = 0; + for j in 0..len - 1 { + let cmp_result = compare(&arr[j], &arr[len - 1]); + let should_swap = if inverse { + cmp_result > 0 + } else { + cmp_result < 0 + }; + + if should_swap { + arr.swap(i, j); + i += 1; + } + } + + arr.swap(i, len - 1); + i +} + +#[cfg(test)] +pub mod sort_test { + use crate::data_sort::{quick_sort, quick_sort_with_cmp}; + + #[test] + fn test_quick_sort_ascending() { + let mut arr = [3, 1, 4, 1, 5, 9, 2, 6]; + quick_sort(&mut arr, false); + assert_eq!(arr, [1, 1, 2, 3, 4, 5, 6, 9]); + } + + #[test] + fn test_quick_sort_descending() { + let mut arr = [3, 1, 4, 1, 5, 9, 2, 6]; + quick_sort(&mut arr, true); + assert_eq!(arr, [9, 6, 5, 4, 3, 2, 1, 1]); + } + + #[test] + fn test_quick_sort_single() { + let mut arr = [42]; + quick_sort(&mut arr, false); + assert_eq!(arr, [42]); + } + + #[test] + fn test_quick_sort_already_sorted() { + let mut arr = [1, 2, 3, 4, 5]; + quick_sort(&mut arr, false); + assert_eq!(arr, [1, 2, 3, 4, 5]); + } + + #[test] + fn test_quick_sort_with_cmp_by_count() { + #[derive(Debug, PartialEq)] + struct WordCount { + word: String, + count: usize, + } + + let mut words = vec![ + WordCount { + word: "apple".to_string(), + count: 3, + }, + WordCount { + word: "banana".to_string(), + count: 1, + }, + WordCount { + word: "cherry".to_string(), + count: 5, + }, + WordCount { + word: "date".to_string(), + count: 2, + }, + ]; + + quick_sort_with_cmp(&mut words, false, |a, b| { + if a.count < b.count { + -1 + } else if a.count > b.count { + 1 + } else { + 0 + } + }); + + assert_eq!( + words, + vec![ + WordCount { + word: "banana".to_string(), + count: 1 + }, + WordCount { + word: "date".to_string(), + count: 2 + }, + WordCount { + word: "apple".to_string(), + count: 3 + }, + WordCount { + word: "cherry".to_string(), + count: 5 + }, + ] + ); + + quick_sort_with_cmp(&mut words, true, |a, b| { + if a.count < b.count { + -1 + } else if a.count > b.count { + 1 + } else { + 0 + } + }); + + assert_eq!( + words, + vec![ + WordCount { + word: "cherry".to_string(), + count: 5 + }, + WordCount { + word: "apple".to_string(), + count: 3 + }, + WordCount { + word: "date".to_string(), + count: 2 + }, + WordCount { + word: "banana".to_string(), + count: 1 + }, + ] + ); + } + + #[test] + fn test_quick_sort_with_cmp_by_first_letter() { + let mut words = vec!["zebra", "apple", "banana", "cherry", "date"]; + + quick_sort_with_cmp(&mut words, false, |a, b| { + let a_first = a.chars().next().unwrap(); + let b_first = b.chars().next().unwrap(); + + if a_first < b_first { + -1 + } else if a_first > b_first { + 1 + } else { + 0 + } + }); + + assert_eq!(words, vec!["apple", "banana", "cherry", "date", "zebra"]); + + quick_sort_with_cmp(&mut words, true, |a, b| { + let a_first = a.chars().next().unwrap(); + let b_first = b.chars().next().unwrap(); + + if a_first < b_first { + -1 + } else if a_first > b_first { + 1 + } else { + 0 + } + }); + + assert_eq!(words, vec!["zebra", "date", "cherry", "banana", "apple"]); + } +} diff --git a/utils/data_struct/src/lib.rs b/utils/data_struct/src/lib.rs new file mode 100644 index 0000000..47cc03c --- /dev/null +++ b/utils/data_struct/src/lib.rs @@ -0,0 +1,2 @@ +pub mod bi_map; +pub mod data_sort; diff --git a/utils/sha1_hash/Cargo.toml b/utils/sha1_hash/Cargo.toml new file mode 100644 index 0000000..e206efd --- /dev/null +++ b/utils/sha1_hash/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "sha1_hash" +edition = "2024" +version.workspace = true + +[dependencies] +tokio = { version = "1.48", features = ["full"] } +sha1 = "0.10" +futures = "0.3" diff --git a/utils/sha1_hash/res/story.txt b/utils/sha1_hash/res/story.txt new file mode 100644 index 0000000..a91f467 --- /dev/null +++ b/utils/sha1_hash/res/story.txt @@ -0,0 +1,48 @@ +魏曹者,程序员也,发稀甚于代码。 +忽接神秘电话曰: +"贺君中彩,得长生之赐。" +魏曹冷笑曰:"吾命尚不及下版之期。" + +翌日果得U盘。 +接入电脑,弹窗示曰: +"点此确认,即获永生。" +魏曹径点"永拒"。 + +三月后,U盘自格其盘。 +进度条滞于九九。 +客服电话已成空号。 +魏曹乃知身可不死,然体内癌细胞亦得不灭。 + +遂谒主请辞。 +主曰:"巧甚,公司正欲优化。" +魏曹曰:"吾不死。" +主目骤亮:"则可007至司闭。" + +魏曹始试诸死法。 +坠楼,卧医三月,账单令其愿死。 +饮鸩,肝肾永损,然终不得死。 +终决卧轨。 + +择高铁最速者。 +司机探头曰:"兄台,吾亦不死身也。" +"此车已碾如君者二十人矣。" + +二人遂坐轨畔对饮。 +司机曰:"知最讽者何?" +"吾等永存,而所爱者皆逝矣。" + +魏曹忽得系统提示: +"侦得用户消极求生,将启工模。" +自是无日不毕KPI,否则遍尝绝症之苦。 + +是日对镜整寿衣。 +忽见顶生一丝乌发。 +泫然泣下,此兆示其将复活一轮回。 + +--- 忽忆DeepSeek尝作Footer曰: +"文成而Hash1验,若星河之固。" +遂取哈希值校之, +字符流转如天河倒泻, +终得"e3b0c44298fc1c14"之数。 +然文末数字竟阙如残月, +方知此篇亦遭永劫轮回。 diff --git a/utils/sha1_hash/res/story_crlf.sha1 b/utils/sha1_hash/res/story_crlf.sha1 new file mode 100644 index 0000000..bc8ad25 --- /dev/null +++ b/utils/sha1_hash/res/story_crlf.sha1 @@ -0,0 +1 @@ +40c1d848d8d6a14b9403ee022f2b28dabb3b3a71 diff --git a/utils/sha1_hash/res/story_lf.sha1 b/utils/sha1_hash/res/story_lf.sha1 new file mode 100644 index 0000000..c2e3213 --- /dev/null +++ b/utils/sha1_hash/res/story_lf.sha1 @@ -0,0 +1 @@ +6838aca280112635a2cbf93440f4c04212f58ee8 diff --git a/utils/sha1_hash/src/lib.rs b/utils/sha1_hash/src/lib.rs new file mode 100644 index 0000000..96a7897 --- /dev/null +++ b/utils/sha1_hash/src/lib.rs @@ -0,0 +1,257 @@ +use sha1::{Digest, Sha1}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::fs::File; +use tokio::io::{AsyncReadExt, BufReader}; +use tokio::task; + +/// # Struct - Sha1Result +/// +/// Records SHA1 calculation results, including the file path and hash value +#[derive(Debug, Clone)] +pub struct Sha1Result { + pub file_path: PathBuf, + pub hash: String, +} + +/// Calc SHA1 hash of a string +pub fn calc_sha1_string>(input: S) -> String { + let mut hasher = Sha1::new(); + hasher.update(input.as_ref().as_bytes()); + let hash_result = hasher.finalize(); + + hash_result + .iter() + .map(|b| format!("{:02x}", b)) + .collect::() +} + +/// Calc SHA1 hash of a single file +pub async fn calc_sha1>( + path: P, + buffer_size: usize, +) -> Result> { + let file_path = path.as_ref().to_string_lossy().to_string(); + + // Open file asynchronously + let file = File::open(&path).await?; + let mut reader = BufReader::with_capacity(buffer_size, file); + let mut hasher = Sha1::new(); + let mut buffer = vec![0u8; buffer_size]; + + // Read file in chunks and update hash asynchronously + loop { + let n = reader.read(&mut buffer).await?; + if n == 0 { + break; + } + hasher.update(&buffer[..n]); + } + + let hash_result = hasher.finalize(); + + // Convert to hex string + let hash_hex = hash_result + .iter() + .map(|b| format!("{:02x}", b)) + .collect::(); + + Ok(Sha1Result { + file_path: file_path.into(), + hash: hash_hex, + }) +} + +/// Calc SHA1 hashes for multiple files using multi-threading +pub async fn calc_sha1_multi( + paths: I, + buffer_size: usize, +) -> Result, Box> +where + P: AsRef + Send + Sync + 'static, + I: IntoIterator, +{ + let buffer_size = Arc::new(buffer_size); + + // Collect all file paths + let file_paths: Vec

= paths.into_iter().collect(); + + if file_paths.is_empty() { + return Ok(Vec::new()); + } + + // Create tasks for each file + let tasks: Vec<_> = file_paths + .into_iter() + .map(|path| { + let buffer_size = Arc::clone(&buffer_size); + task::spawn(async move { calc_sha1(path, *buffer_size).await }) + }) + .collect(); + + // Execute tasks with concurrency limit using join_all + let results: Vec>> = + futures::future::join_all(tasks) + .await + .into_iter() + .map(|task_result| match task_result { + Ok(Ok(calc_result)) => Ok(calc_result), + Ok(Err(e)) => Err(e), + Err(e) => Err(Box::new(e) as Box), + }) + .collect(); + + // Check for any errors and collect successful results + let mut successful_results = Vec::new(); + for result in results { + match result { + Ok(success) => successful_results.push(success), + Err(e) => return Err(e), + } + } + + Ok(successful_results) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + #[test] + fn test_sha1_string() { + let test_string = "Hello, SHA1!"; + let hash = calc_sha1_string(test_string); + + let expected_hash = "de1c3daadc6f0f1626f4cf56c03e05a1e5d7b187"; + + assert_eq!( + hash, expected_hash, + "SHA1 hash should be consistent for same input" + ); + } + + #[test] + fn test_sha1_string_empty() { + let hash = calc_sha1_string(""); + + // SHA1 of empty string is "da39a3ee5e6b4b0d3255bfef95601890afd80709" + let expected_empty_hash = "da39a3ee5e6b4b0d3255bfef95601890afd80709"; + assert_eq!( + hash, expected_empty_hash, + "SHA1 hash mismatch for empty string" + ); + } + + #[tokio::test] + async fn test_sha1_accuracy() { + // Test file path relative to the crate root + let test_file_path = "res/story.txt"; + // Choose expected hash file based on platform + let expected_hash_path = if cfg!(windows) { + "res/story_crlf.sha1" + } else { + "res/story_lf.sha1" + }; + + // Calculate SHA1 hash + let result = calc_sha1(test_file_path, 8192) + .await + .expect("Failed to calculate SHA1"); + + // Read expected hash from file + let expected_hash = fs::read_to_string(expected_hash_path) + .expect("Failed to read expected hash file") + .trim() + .to_string(); + + // Verify the calculated hash matches expected hash + assert_eq!( + result.hash, expected_hash, + "SHA1 hash mismatch for test file" + ); + + println!("Test file: {}", result.file_path.display()); + println!("Calculated hash: {}", result.hash); + println!("Expected hash: {}", expected_hash); + println!( + "Platform: {}", + if cfg!(windows) { + "Windows" + } else { + "Unix/Linux" + } + ); + } + + #[tokio::test] + async fn test_sha1_empty_file() { + // Create a temporary empty file for testing + let temp_file = "test_empty.txt"; + fs::write(temp_file, "").expect("Failed to create empty test file"); + + let result = calc_sha1(temp_file, 4096) + .await + .expect("Failed to calculate SHA1 for empty file"); + + // SHA1 of empty string is "da39a3ee5e6b4b0d3255bfef95601890afd80709" + let expected_empty_hash = "da39a3ee5e6b4b0d3255bfef95601890afd80709"; + assert_eq!( + result.hash, expected_empty_hash, + "SHA1 hash mismatch for empty file" + ); + + // Clean up + fs::remove_file(temp_file).expect("Failed to remove temporary test file"); + } + + #[tokio::test] + async fn test_sha1_simple_text() { + // Create a temporary file with simple text + let temp_file = "test_simple.txt"; + let test_content = "Hello, SHA1!"; + fs::write(temp_file, test_content).expect("Failed to create simple test file"); + + let result = calc_sha1(temp_file, 4096) + .await + .expect("Failed to calculate SHA1 for simple text"); + + // Note: This test just verifies that the function works without errors + // The actual hash value is not critical for this test + + println!("Simple text test - Calculated hash: {}", result.hash); + + // Clean up + fs::remove_file(temp_file).expect("Failed to remove temporary test file"); + } + + #[tokio::test] + async fn test_sha1_multi_files() { + // Test multiple files calculation + let test_files = vec!["res/story.txt"]; + + let results = calc_sha1_multi(test_files, 8192) + .await + .expect("Failed to calculate SHA1 for multiple files"); + + assert_eq!(results.len(), 1, "Should have calculated hash for 1 file"); + + // Choose expected hash file based on platform + let expected_hash_path = if cfg!(windows) { + "res/story_crlf.sha1" + } else { + "res/story_lf.sha1" + }; + + // Read expected hash from file + let expected_hash = fs::read_to_string(expected_hash_path) + .expect("Failed to read expected hash file") + .trim() + .to_string(); + + assert_eq!( + results[0].hash, expected_hash, + "SHA1 hash mismatch in multi-file test" + ); + } +} diff --git a/utils/string_proc/Cargo.toml b/utils/string_proc/Cargo.toml new file mode 100644 index 0000000..5292339 --- /dev/null +++ b/utils/string_proc/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "string_proc" +version = "0.1.0" +edition = "2024" + +[dependencies] +strip-ansi-escapes = "0.2.1" diff --git a/utils/string_proc/src/format_path.rs b/utils/string_proc/src/format_path.rs new file mode 100644 index 0000000..35689b8 --- /dev/null +++ b/utils/string_proc/src/format_path.rs @@ -0,0 +1,111 @@ +use std::path::{Path, PathBuf}; + +/// Format path str +pub fn format_path_str(path: impl Into) -> Result { + let path_str = path.into(); + let ends_with_slash = path_str.ends_with('/'); + + // ANSI Strip + let cleaned = strip_ansi_escapes::strip(&path_str); + let path_without_ansi = String::from_utf8(cleaned) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + let path_with_forward_slash = path_without_ansi.replace('\\', "/"); + let mut result = String::new(); + let mut prev_char = '\0'; + + for c in path_with_forward_slash.chars() { + if c == '/' && prev_char == '/' { + continue; + } + result.push(c); + prev_char = c; + } + + let unfriendly_chars = ['*', '?', '"', '<', '>', '|']; + result = result + .chars() + .filter(|c| !unfriendly_chars.contains(c)) + .collect(); + + // Handle ".." path components + let path_buf = PathBuf::from(&result); + let normalized_path = normalize_path(&path_buf); + result = normalized_path.to_string_lossy().replace('\\', "/"); + + // Restore trailing slash if original path had one + if ends_with_slash && !result.ends_with('/') { + result.push('/'); + } + + // Special case: when result is only "./", return "" + if result == "./" { + return Ok(String::new()); + } + + Ok(result) +} + +/// Normalize path by resolving ".." components without requiring file system access +fn normalize_path(path: &Path) -> PathBuf { + let mut components = Vec::new(); + + for component in path.components() { + match component { + std::path::Component::ParentDir => { + if !components.is_empty() { + components.pop(); + } + } + std::path::Component::CurDir => { + // Skip current directory components + } + _ => { + components.push(component); + } + } + } + + if components.is_empty() { + PathBuf::from(".") + } else { + components.iter().collect() + } +} + +pub fn format_path(path: impl Into) -> Result { + let path_str = format_path_str(path.into().display().to_string())?; + Ok(PathBuf::from(path_str)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_path() -> Result<(), std::io::Error> { + assert_eq!(format_path_str("C:\\Users\\\\test")?, "C:/Users/test"); + + assert_eq!( + format_path_str("/path/with/*unfriendly?chars")?, + "/path/with/unfriendlychars" + ); + + assert_eq!(format_path_str("\x1b[31m/path\x1b[0m")?, "/path"); + assert_eq!(format_path_str("/home/user/dir/")?, "/home/user/dir/"); + assert_eq!( + format_path_str("/home/user/file.txt")?, + "/home/user/file.txt" + ); + assert_eq!( + format_path_str("/home/my_user/DOCS/JVCS_TEST/Workspace/../Vault/")?, + "/home/my_user/DOCS/JVCS_TEST/Vault/" + ); + + assert_eq!(format_path_str("./home/file.txt")?, "home/file.txt"); + assert_eq!(format_path_str("./home/path/")?, "home/path/"); + assert_eq!(format_path_str("./")?, ""); + + Ok(()) + } +} diff --git a/utils/string_proc/src/format_processer.rs b/utils/string_proc/src/format_processer.rs new file mode 100644 index 0000000..8d0a770 --- /dev/null +++ b/utils/string_proc/src/format_processer.rs @@ -0,0 +1,132 @@ +pub struct FormatProcesser { + content: Vec, +} + +impl From for FormatProcesser { + fn from(value: String) -> Self { + Self { + content: Self::process_string(value), + } + } +} + +impl From<&str> for FormatProcesser { + fn from(value: &str) -> Self { + Self { + content: Self::process_string(value.to_string()), + } + } +} + +impl FormatProcesser { + /// Process the string into an intermediate format + fn process_string(input: String) -> Vec { + let mut result = String::new(); + let mut prev_space = false; + + for c in input.chars() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' => { + result.push(c); + prev_space = false; + } + '_' | ',' | '.' | '-' | ' ' => { + if !prev_space { + result.push(' '); + prev_space = true; + } + } + _ => {} + } + } + + let mut processed = String::new(); + let mut chars = result.chars().peekable(); + + while let Some(c) = chars.next() { + processed.push(c); + if let Some(&next) = chars.peek() + && c.is_lowercase() + && next.is_uppercase() + { + processed.push(' '); + } + } + + processed + .to_lowercase() + .split_whitespace() + .map(|s| s.to_string()) + .collect() + } + + /// Convert to camelCase format (brewCoffee) + pub fn to_camel_case(&self) -> String { + let mut result = String::new(); + for (i, word) in self.content.iter().enumerate() { + if i == 0 { + result.push_str(&word.to_lowercase()); + } else { + let mut chars = word.chars(); + if let Some(first) = chars.next() { + result.push_str(&first.to_uppercase().collect::()); + result.push_str(&chars.collect::().to_lowercase()); + } + } + } + result + } + + /// Convert to PascalCase format (BrewCoffee) + pub fn to_pascal_case(&self) -> String { + let mut result = String::new(); + for word in &self.content { + let mut chars = word.chars(); + if let Some(first) = chars.next() { + result.push_str(&first.to_uppercase().collect::()); + result.push_str(&chars.collect::().to_lowercase()); + } + } + result + } + + /// Convert to kebab-case format (brew-coffee) + pub fn to_kebab_case(&self) -> String { + self.content.join("-").to_lowercase() + } + + /// Convert to snake_case format (brew_coffee) + pub fn to_snake_case(&self) -> String { + self.content.join("_").to_lowercase() + } + + /// Convert to dot.case format (brew.coffee) + pub fn to_dot_case(&self) -> String { + self.content.join(".").to_lowercase() + } + + /// Convert to Title Case format (Brew Coffee) + pub fn to_title_case(&self) -> String { + let mut result = String::new(); + for word in &self.content { + let mut chars = word.chars(); + if let Some(first) = chars.next() { + result.push_str(&first.to_uppercase().collect::()); + result.push_str(&chars.collect::().to_lowercase()); + } + result.push(' '); + } + result.pop(); + result + } + + /// Convert to lower case format (brew coffee) + pub fn to_lower_case(&self) -> String { + self.content.join(" ").to_lowercase() + } + + /// Convert to UPPER CASE format (BREW COFFEE) + pub fn to_upper_case(&self) -> String { + self.content.join(" ").to_uppercase() + } +} diff --git a/utils/string_proc/src/lib.rs b/utils/string_proc/src/lib.rs new file mode 100644 index 0000000..76588c1 --- /dev/null +++ b/utils/string_proc/src/lib.rs @@ -0,0 +1,50 @@ +pub mod format_path; +pub mod format_processer; +pub mod macros; +pub mod simple_processer; + +#[cfg(test)] +mod tests { + use crate::format_processer::FormatProcesser; + + #[test] + fn test_processer() { + let test_cases = vec![ + ("brew_coffee", "brewCoffee"), + ("brew, coffee", "brewCoffee"), + ("brew-coffee", "brewCoffee"), + ("Brew.Coffee", "brewCoffee"), + ("bRewCofFee", "bRewCofFee"), + ("brewCoffee", "brewCoffee"), + ("b&rewCoffee", "brewCoffee"), + ("BrewCoffee", "brewCoffee"), + ("brew.coffee", "brewCoffee"), + ("Brew_Coffee", "brewCoffee"), + ("BREW COFFEE", "brewCoffee"), + ]; + + for (input, expected) in test_cases { + let processor = FormatProcesser::from(input); + assert_eq!( + processor.to_camel_case(), + expected, + "Failed for input: '{}'", + input + ); + } + } + + #[test] + fn test_conversions() { + let processor = FormatProcesser::from("brewCoffee"); + + assert_eq!(processor.to_upper_case(), "BREW COFFEE"); + assert_eq!(processor.to_lower_case(), "brew coffee"); + assert_eq!(processor.to_title_case(), "Brew Coffee"); + assert_eq!(processor.to_dot_case(), "brew.coffee"); + assert_eq!(processor.to_snake_case(), "brew_coffee"); + assert_eq!(processor.to_kebab_case(), "brew-coffee"); + assert_eq!(processor.to_pascal_case(), "BrewCoffee"); + assert_eq!(processor.to_camel_case(), "brewCoffee"); + } +} diff --git a/utils/string_proc/src/macros.rs b/utils/string_proc/src/macros.rs new file mode 100644 index 0000000..135268e --- /dev/null +++ b/utils/string_proc/src/macros.rs @@ -0,0 +1,63 @@ +#[macro_export] +macro_rules! camel_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_camel_case() + }}; +} + +#[macro_export] +macro_rules! upper_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_upper_case() + }}; +} + +#[macro_export] +macro_rules! lower_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_lower_case() + }}; +} + +#[macro_export] +macro_rules! title_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_title_case() + }}; +} + +#[macro_export] +macro_rules! dot_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_dot_case() + }}; +} + +#[macro_export] +macro_rules! snake_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_snake_case() + }}; +} + +#[macro_export] +macro_rules! kebab_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_kebab_case() + }}; +} + +#[macro_export] +macro_rules! pascal_case { + ($input:expr) => {{ + use string_proc::format_processer::FormatProcesser; + FormatProcesser::from($input).to_pascal_case() + }}; +} diff --git a/utils/string_proc/src/simple_processer.rs b/utils/string_proc/src/simple_processer.rs new file mode 100644 index 0000000..2de5dfc --- /dev/null +++ b/utils/string_proc/src/simple_processer.rs @@ -0,0 +1,15 @@ +/// Sanitizes a file path by replacing special characters with underscores. +/// +/// This function takes a file path as input and returns a sanitized version +/// where characters that are not allowed in file paths (such as path separators +/// and other reserved characters) are replaced with underscores. +pub fn sanitize_file_path>(path: P) -> String { + let path_str = path.as_ref(); + path_str + .chars() + .map(|c| match c { + '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_', + _ => c, + }) + .collect() +} diff --git a/utils/tcp_connection/Cargo.toml b/utils/tcp_connection/Cargo.toml new file mode 100644 index 0000000..da258be --- /dev/null +++ b/utils/tcp_connection/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "tcp_connection" +edition = "2024" +version.workspace = true + +[dependencies] +tokio = { version = "1.48.0", features = ["full"] } + +# Serialization +serde = { version = "1.0.228", features = ["derive"] } +serde_json = "1.0.145" +rmp-serde = "1.3.0" + +# Error handling +thiserror = "2.0.17" + +# Uuid & Random +uuid = "1.18.1" + +# Crypto +rsa = { version = "0.9", features = ["pkcs5", "sha2"] } +ed25519-dalek = "3.0.0-pre.1" +ring = "0.17.14" +rand = "0.10.0-rc.0" +base64 = "0.22.1" +pem = "3.0.6" +crc = "3.3.0" +blake3 = "1.8.2" diff --git a/utils/tcp_connection/src/error.rs b/utils/tcp_connection/src/error.rs new file mode 100644 index 0000000..32d06cc --- /dev/null +++ b/utils/tcp_connection/src/error.rs @@ -0,0 +1,122 @@ +use std::io; +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +pub enum TcpTargetError { + #[error("Authentication failed: {0}")] + Authentication(String), + + #[error("Reference sheet not allowed: {0}")] + ReferenceSheetNotAllowed(String), + + #[error("Cryptographic error: {0}")] + Crypto(String), + + #[error("File operation error: {0}")] + File(String), + + #[error("I/O error: {0}")] + Io(String), + + #[error("Invalid configuration: {0}")] + Config(String), + + #[error("Locked: {0}")] + Locked(String), + + #[error("Network error: {0}")] + Network(String), + + #[error("No result: {0}")] + NoResult(String), + + #[error("Not found: {0}")] + NotFound(String), + + #[error("Not local machine: {0}")] + NotLocal(String), + + #[error("Not remote machine: {0}")] + NotRemote(String), + + #[error("Pool already exists: {0}")] + PoolAlreadyExists(String), + + #[error("Protocol error: {0}")] + Protocol(String), + + #[error("Serialization error: {0}")] + Serialization(String), + + #[error("Timeout: {0}")] + Timeout(String), + + #[error("Unsupported operation: {0}")] + Unsupported(String), +} + +impl From for TcpTargetError { + fn from(error: io::Error) -> Self { + TcpTargetError::Io(error.to_string()) + } +} + +impl From for TcpTargetError { + fn from(error: serde_json::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} + +impl From<&str> for TcpTargetError { + fn from(value: &str) -> Self { + TcpTargetError::Protocol(value.to_string()) + } +} + +impl From for TcpTargetError { + fn from(value: String) -> Self { + TcpTargetError::Protocol(value) + } +} + +impl From for TcpTargetError { + fn from(error: rsa::errors::Error) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From for TcpTargetError { + fn from(error: ed25519_dalek::SignatureError) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From for TcpTargetError { + fn from(error: ring::error::Unspecified) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From for TcpTargetError { + fn from(error: base64::DecodeError) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} + +impl From for TcpTargetError { + fn from(error: pem::PemError) -> Self { + TcpTargetError::Crypto(error.to_string()) + } +} + +impl From for TcpTargetError { + fn from(error: rmp_serde::encode::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} + +impl From for TcpTargetError { + fn from(error: rmp_serde::decode::Error) -> Self { + TcpTargetError::Serialization(error.to_string()) + } +} diff --git a/utils/tcp_connection/src/instance.rs b/utils/tcp_connection/src/instance.rs new file mode 100644 index 0000000..8e6886c --- /dev/null +++ b/utils/tcp_connection/src/instance.rs @@ -0,0 +1,542 @@ +use std::{path::Path, time::Duration}; + +use serde::Serialize; +use tokio::{ + fs::{File, OpenOptions}, + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::TcpStream, +}; + +use ring::signature::{self}; + +use crate::error::TcpTargetError; + +const DEFAULT_CHUNK_SIZE: usize = 4096; +const DEFAULT_TIMEOUT_SECS: u64 = 10; + +const ECDSA_P256_SHA256_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P256_SHA256_ASN1_SIGNING; +const ECDSA_P384_SHA384_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P384_SHA384_ASN1_SIGNING; + +#[derive(Debug, Clone)] +pub struct ConnectionConfig { + pub chunk_size: usize, + pub timeout_secs: u64, + pub enable_crc_validation: bool, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + chunk_size: DEFAULT_CHUNK_SIZE, + timeout_secs: DEFAULT_TIMEOUT_SECS, + enable_crc_validation: false, + } + } +} + +pub struct ConnectionInstance { + pub(crate) stream: TcpStream, + config: ConnectionConfig, +} + +impl From for ConnectionInstance { + fn from(stream: TcpStream) -> Self { + Self { + stream, + config: ConnectionConfig::default(), + } + } +} + +impl ConnectionInstance { + /// Create a new ConnectionInstance with custom configuration + pub fn with_config(stream: TcpStream, config: ConnectionConfig) -> Self { + Self { stream, config } + } + + /// Get a reference to the current configuration + pub fn config(&self) -> &ConnectionConfig { + &self.config + } + + /// Get a mutable reference to the current configuration + pub fn config_mut(&mut self) -> &mut ConnectionConfig { + &mut self.config + } + /// Serialize data and write to the target machine + pub async fn write(&mut self, data: Data) -> Result<(), TcpTargetError> + where + Data: Default + Serialize, + { + let Ok(json_text) = serde_json::to_string(&data) else { + return Err(TcpTargetError::Serialization( + "Serialize failed.".to_string(), + )); + }; + Self::write_text(self, json_text).await?; + Ok(()) + } + + /// Serialize data to MessagePack and write to the target machine + pub async fn write_msgpack(&mut self, data: Data) -> Result<(), TcpTargetError> + where + Data: Serialize, + { + let msgpack_data = rmp_serde::to_vec(&data)?; + let len = msgpack_data.len() as u32; + + self.stream.write_all(&len.to_be_bytes()).await?; + self.stream.write_all(&msgpack_data).await?; + Ok(()) + } + + /// Read data from target machine and deserialize from MessagePack + pub async fn read_msgpack(&mut self) -> Result + where + Data: serde::de::DeserializeOwned, + { + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf) as usize; + + let mut buffer = vec![0; len]; + self.stream.read_exact(&mut buffer).await?; + + let data = rmp_serde::from_slice(&buffer)?; + Ok(data) + } + + /// Read data from target machine and deserialize + pub async fn read(&mut self) -> Result + where + Data: Default + serde::de::DeserializeOwned, + { + let Ok(json_text) = Self::read_text(self).await else { + return Err(TcpTargetError::Io("Read failed.".to_string())); + }; + let Ok(deser_obj) = serde_json::from_str::(&json_text) else { + return Err(TcpTargetError::Serialization( + "Deserialize failed.".to_string(), + )); + }; + Ok(deser_obj) + } + + /// Serialize data and write to the target machine + pub async fn write_large(&mut self, data: Data) -> Result<(), TcpTargetError> + where + Data: Default + Serialize, + { + let Ok(json_text) = serde_json::to_string(&data) else { + return Err(TcpTargetError::Serialization( + "Serialize failed.".to_string(), + )); + }; + Self::write_large_text(self, json_text).await?; + Ok(()) + } + + /// Read data from target machine and deserialize + pub async fn read_large( + &mut self, + buffer_size: impl Into, + ) -> Result + where + Data: Default + serde::de::DeserializeOwned, + { + let Ok(json_text) = Self::read_large_text(self, buffer_size).await else { + return Err(TcpTargetError::Io("Read failed.".to_string())); + }; + let Ok(deser_obj) = serde_json::from_str::(&json_text) else { + return Err(TcpTargetError::Serialization( + "Deserialize failed.".to_string(), + )); + }; + Ok(deser_obj) + } + + /// Write text to the target machine + pub async fn write_text(&mut self, text: impl Into) -> Result<(), TcpTargetError> { + let text = text.into(); + let bytes = text.as_bytes(); + let len = bytes.len() as u32; + + self.stream.write_all(&len.to_be_bytes()).await?; + match self.stream.write_all(bytes).await { + Ok(_) => Ok(()), + Err(err) => Err(TcpTargetError::Io(err.to_string())), + } + } + + /// Read text from the target machine + pub async fn read_text(&mut self) -> Result { + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf) as usize; + + let mut buffer = vec![0; len]; + self.stream.read_exact(&mut buffer).await?; + + match String::from_utf8(buffer) { + Ok(text) => Ok(text), + Err(err) => Err(TcpTargetError::Serialization(format!( + "Invalid UTF-8 sequence: {}", + err + ))), + } + } + + /// Write large text to the target machine (chunked) + pub async fn write_large_text( + &mut self, + text: impl Into, + ) -> Result<(), TcpTargetError> { + let text = text.into(); + let bytes = text.as_bytes(); + let mut offset = 0; + + while offset < bytes.len() { + let chunk = &bytes[offset..]; + let written = match self.stream.write(chunk).await { + Ok(n) => n, + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + }; + offset += written; + } + + Ok(()) + } + + /// Read large text from the target machine (chunked) + pub async fn read_large_text( + &mut self, + chunk_size: impl Into, + ) -> Result { + let chunk_size = chunk_size.into() as usize; + let mut buffer = Vec::new(); + let mut chunk_buf = vec![0; chunk_size]; + + loop { + match self.stream.read(&mut chunk_buf).await { + Ok(0) => break, // EOF + Ok(n) => { + buffer.extend_from_slice(&chunk_buf[..n]); + } + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + Ok(String::from_utf8_lossy(&buffer).to_string()) + } + + /// Write large MessagePack data to the target machine (chunked) + pub async fn write_large_msgpack( + &mut self, + data: Data, + chunk_size: impl Into, + ) -> Result<(), TcpTargetError> + where + Data: Serialize, + { + let msgpack_data = rmp_serde::to_vec(&data)?; + let chunk_size = chunk_size.into() as usize; + let len = msgpack_data.len() as u32; + + // Write total length first + self.stream.write_all(&len.to_be_bytes()).await?; + + // Write data in chunks + let mut offset = 0; + while offset < msgpack_data.len() { + let end = std::cmp::min(offset + chunk_size, msgpack_data.len()); + let chunk = &msgpack_data[offset..end]; + match self.stream.write(chunk).await { + Ok(n) => offset += n, + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + Ok(()) + } + + /// Read large MessagePack data from the target machine (chunked) + pub async fn read_large_msgpack( + &mut self, + chunk_size: impl Into, + ) -> Result + where + Data: serde::de::DeserializeOwned, + { + let chunk_size = chunk_size.into() as usize; + + // Read total length first + let mut len_buf = [0u8; 4]; + self.stream.read_exact(&mut len_buf).await?; + let total_len = u32::from_be_bytes(len_buf) as usize; + + // Read data in chunks + let mut buffer = Vec::with_capacity(total_len); + let mut remaining = total_len; + let mut chunk_buf = vec![0; chunk_size]; + + while remaining > 0 { + let read_size = std::cmp::min(chunk_size, remaining); + let chunk = &mut chunk_buf[..read_size]; + + match self.stream.read_exact(chunk).await { + Ok(_) => { + buffer.extend_from_slice(chunk); + remaining -= read_size; + } + Err(err) => return Err(TcpTargetError::Io(err.to_string())), + } + } + + let data = rmp_serde::from_slice(&buffer)?; + Ok(data) + } + + /// Write file to target machine. + pub async fn write_file(&mut self, file_path: impl AsRef) -> Result<(), TcpTargetError> { + let path = file_path.as_ref(); + + // Validate file + if !path.exists() { + return Err(TcpTargetError::File(format!( + "File not found: {}", + path.display() + ))); + } + if path.is_dir() { + return Err(TcpTargetError::File(format!( + "Path is directory: {}", + path.display() + ))); + } + + // Open file and get metadata + let mut file = File::open(path).await?; + let file_size = file.metadata().await?.len(); + + // Send file header (version + size + crc) + self.stream.write_all(&1u64.to_be_bytes()).await?; + self.stream.write_all(&file_size.to_be_bytes()).await?; + + // Calculate and send CRC32 if enabled + let file_crc = if self.config.enable_crc_validation { + let crc32 = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); + let mut crc_calculator = crc32.digest(); + + let mut temp_reader = + BufReader::with_capacity(self.config.chunk_size, File::open(path).await?); + let mut temp_buffer = vec![0u8; self.config.chunk_size]; + let mut temp_bytes_read = 0; + + while temp_bytes_read < file_size { + let bytes_to_read = + (file_size - temp_bytes_read).min(self.config.chunk_size as u64) as usize; + temp_reader + .read_exact(&mut temp_buffer[..bytes_to_read]) + .await?; + crc_calculator.update(&temp_buffer[..bytes_to_read]); + temp_bytes_read += bytes_to_read as u64; + } + + crc_calculator.finalize() + } else { + 0 + }; + + self.stream.write_all(&file_crc.to_be_bytes()).await?; + + // If file size is 0, skip content transfer + if file_size == 0 { + self.stream.flush().await?; + + // Wait for receiver confirmation + let mut ack = [0u8; 1]; + tokio::time::timeout( + Duration::from_secs(self.config.timeout_secs), + self.stream.read_exact(&mut ack), + ) + .await + .map_err(|_| TcpTargetError::Timeout("Ack timeout".to_string()))??; + + if ack[0] != 1 { + return Err(TcpTargetError::Protocol( + "Receiver verification failed".to_string(), + )); + } + + return Ok(()); + } + + // Transfer file content + let mut reader = BufReader::with_capacity(self.config.chunk_size, &mut file); + let mut bytes_sent = 0; + + while bytes_sent < file_size { + let buffer = reader.fill_buf().await?; + if buffer.is_empty() { + break; + } + + let chunk_size = buffer.len().min((file_size - bytes_sent) as usize); + self.stream.write_all(&buffer[..chunk_size]).await?; + reader.consume(chunk_size); + + bytes_sent += chunk_size as u64; + } + + // Verify transfer completion + if bytes_sent != file_size { + return Err(TcpTargetError::File(format!( + "Transfer incomplete: expected {} bytes, sent {} bytes", + file_size, bytes_sent + ))); + } + + self.stream.flush().await?; + + // Wait for receiver confirmation + let mut ack = [0u8; 1]; + tokio::time::timeout( + Duration::from_secs(self.config.timeout_secs), + self.stream.read_exact(&mut ack), + ) + .await + .map_err(|_| TcpTargetError::Timeout("Ack timeout".to_string()))??; + + if ack[0] != 1 { + return Err(TcpTargetError::Protocol( + "Receiver verification failed".to_string(), + )); + } + + Ok(()) + } + + /// Read file from target machine + pub async fn read_file(&mut self, save_path: impl AsRef) -> Result<(), TcpTargetError> { + let path = save_path.as_ref(); + // Create CRC instance at function scope to ensure proper lifetime + let crc_instance = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); + + // Make sure parent directory exists + if let Some(parent) = path.parent() + && !parent.exists() + { + tokio::fs::create_dir_all(parent).await?; + } + + // Read file header (version + size + crc) + let mut version_buf = [0u8; 8]; + self.stream.read_exact(&mut version_buf).await?; + let version = u64::from_be_bytes(version_buf); + if version != 1 { + return Err(TcpTargetError::Protocol( + "Unsupported transfer version".to_string(), + )); + } + + let mut size_buf = [0u8; 8]; + self.stream.read_exact(&mut size_buf).await?; + let file_size = u64::from_be_bytes(size_buf); + + let mut expected_crc_buf = [0u8; 4]; + self.stream.read_exact(&mut expected_crc_buf).await?; + let expected_crc = u32::from_be_bytes(expected_crc_buf); + if file_size == 0 { + // Create empty file and return early + let _file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .await?; + // Send confirmation + self.stream.write_all(&[1u8]).await?; + self.stream.flush().await?; + return Ok(()); + } + + // Prepare output file + let file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .await?; + let mut writer = BufWriter::with_capacity(self.config.chunk_size, file); + + // Receive file content with CRC calculation if enabled + let mut bytes_received = 0; + let mut buffer = vec![0u8; self.config.chunk_size]; + let mut crc_calculator = if self.config.enable_crc_validation { + Some(crc_instance.digest()) + } else { + None + }; + + while bytes_received < file_size { + let bytes_to_read = + (file_size - bytes_received).min(self.config.chunk_size as u64) as usize; + let chunk = &mut buffer[..bytes_to_read]; + + self.stream.read_exact(chunk).await?; + + writer.write_all(chunk).await?; + + // Update CRC if validation is enabled + if let Some(ref mut crc) = crc_calculator { + crc.update(chunk); + } + + bytes_received += bytes_to_read as u64; + } + + // Verify transfer completion + if bytes_received != file_size { + return Err(TcpTargetError::File(format!( + "Transfer incomplete: expected {} bytes, received {} bytes", + file_size, bytes_received + ))); + } + + writer.flush().await?; + + // Validate CRC if enabled + if self.config.enable_crc_validation + && let Some(crc_calculator) = crc_calculator + { + let actual_crc = crc_calculator.finalize(); + if actual_crc != expected_crc && expected_crc != 0 { + return Err(TcpTargetError::File(format!( + "CRC validation failed: expected {:08x}, got {:08x}", + expected_crc, actual_crc + ))); + } + } + + // Final flush and sync + writer.flush().await?; + writer.into_inner().sync_all().await?; + + // Verify completion + if bytes_received != file_size { + let _ = tokio::fs::remove_file(path).await; + return Err(TcpTargetError::File(format!( + "Transfer incomplete: expected {} bytes, received {} bytes", + file_size, bytes_received + ))); + } + + // Send confirmation + self.stream.write_all(&[1u8]).await?; + self.stream.flush().await?; + + Ok(()) + } +} diff --git a/utils/tcp_connection/src/instance_challenge.rs b/utils/tcp_connection/src/instance_challenge.rs new file mode 100644 index 0000000..3a7f6a3 --- /dev/null +++ b/utils/tcp_connection/src/instance_challenge.rs @@ -0,0 +1,311 @@ +use std::path::Path; + +use rand::TryRngCore; +use rsa::{ + RsaPrivateKey, RsaPublicKey, + pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey}, + sha2, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey}; +use ring::rand::SystemRandom; +use ring::signature::{ + self, ECDSA_P256_SHA256_ASN1, ECDSA_P384_SHA384_ASN1, EcdsaKeyPair, RSA_PKCS1_2048_8192_SHA256, + UnparsedPublicKey, +}; + +use crate::{error::TcpTargetError, instance::ConnectionInstance}; + +const ECDSA_P256_SHA256_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P256_SHA256_ASN1_SIGNING; +const ECDSA_P384_SHA384_ASN1_SIGNING: &signature::EcdsaSigningAlgorithm = + &signature::ECDSA_P384_SHA384_ASN1_SIGNING; + +impl ConnectionInstance { + /// Initiates a challenge to the target machine to verify connection security + /// + /// This method performs a cryptographic challenge-response authentication: + /// 1. Generates a random 32-byte challenge + /// 2. Sends the challenge to the target machine + /// 3. Receives a digital signature of the challenge + /// 4. Verifies the signature using the appropriate public key + /// + /// # Arguments + /// * `public_key_dir` - Directory containing public key files for verification + /// + /// # Returns + /// * `Ok((true, "KeyId"))` - Challenge verification successful + /// * `Ok((false, "KeyId"))` - Challenge verification failed + /// * `Err(TcpTargetError)` - Error during challenge process + pub async fn challenge( + &mut self, + public_key_dir: impl AsRef, + ) -> Result<(bool, String), TcpTargetError> { + // Generate random challenge + let mut challenge = [0u8; 32]; + rand::rngs::OsRng + .try_fill_bytes(&mut challenge) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to generate random challenge: {}", e)) + })?; + + // Send challenge to target + self.stream.write_all(&challenge).await?; + self.stream.flush().await?; + + // Read signature from target + let mut signature = Vec::new(); + let mut signature_len_buf = [0u8; 4]; + self.stream.read_exact(&mut signature_len_buf).await?; + + let signature_len = u32::from_be_bytes(signature_len_buf) as usize; + signature.resize(signature_len, 0); + self.stream.read_exact(&mut signature).await?; + + // Read key identifier from target to identify which public key to use + let mut key_id_len_buf = [0u8; 4]; + self.stream.read_exact(&mut key_id_len_buf).await?; + let key_id_len = u32::from_be_bytes(key_id_len_buf) as usize; + + let mut key_id_buf = vec![0u8; key_id_len]; + self.stream.read_exact(&mut key_id_buf).await?; + let key_id = String::from_utf8(key_id_buf) + .map_err(|e| TcpTargetError::Crypto(format!("Invalid key identifier: {}", e)))?; + + // Load appropriate public key + let public_key_path = public_key_dir.as_ref().join(format!("{}.pem", key_id)); + if !public_key_path.exists() { + return Ok((false, key_id)); + } + + let public_key_pem = tokio::fs::read_to_string(&public_key_path).await?; + + // Try to verify with different key types + let verified = if let Ok(rsa_key) = RsaPublicKey::from_pkcs1_pem(&public_key_pem) { + let padding = rsa::pkcs1v15::Pkcs1v15Sign::new::(); + rsa_key.verify(padding, &challenge, &signature).is_ok() + } else if let Ok(ed25519_key) = + VerifyingKey::from_bytes(&parse_ed25519_public_key(&public_key_pem)) + { + if signature.len() == 64 { + let sig_bytes: [u8; 64] = signature.as_slice().try_into().map_err(|_| { + TcpTargetError::Crypto("Invalid signature length for Ed25519".to_string()) + })?; + let sig = Signature::from_bytes(&sig_bytes); + ed25519_key.verify(&challenge, &sig).is_ok() + } else { + false + } + } else if let Ok(dsa_key_info) = parse_dsa_public_key(&public_key_pem) { + verify_dsa_signature(&dsa_key_info, &challenge, &signature) + } else { + false + }; + + Ok((verified, key_id)) + } + + /// Accepts a challenge from the target machine to verify connection security + /// + /// This method performs a cryptographic challenge-response authentication: + /// 1. Receives a random 32-byte challenge from the target machine + /// 2. Signs the challenge using the appropriate private key + /// 3. Sends the digital signature back to the target machine + /// 4. Sends the key identifier for public key verification + /// + /// # Arguments + /// * `private_key_file` - Path to the private key file for signing + /// * `verify_public_key` - Key identifier for public key verification + /// + /// # Returns + /// * `Ok(true)` - Challenge response sent successfully + /// * `Ok(false)` - Private key format not supported + /// * `Err(TcpTargetError)` - Error during challenge response process + pub async fn accept_challenge( + &mut self, + private_key_file: impl AsRef, + verify_public_key: &str, + ) -> Result { + // Read challenge from initiator + let mut challenge = [0u8; 32]; + self.stream.read_exact(&mut challenge).await?; + + // Load private key + let private_key_pem = tokio::fs::read_to_string(&private_key_file) + .await + .map_err(|e| { + TcpTargetError::NotFound(format!( + "Read private key \"{}\" failed: \"{}\"", + private_key_file + .as_ref() + .display() + .to_string() + .split("/") + .last() + .unwrap_or("UNKNOWN"), + e + )) + })?; + + // Sign the challenge with supported key types + let signature = if let Ok(rsa_key) = RsaPrivateKey::from_pkcs1_pem(&private_key_pem) { + let padding = rsa::pkcs1v15::Pkcs1v15Sign::new::(); + rsa_key.sign(padding, &challenge)? + } else if let Ok(ed25519_key) = parse_ed25519_private_key(&private_key_pem) { + ed25519_key.sign(&challenge).to_bytes().to_vec() + } else if let Ok(dsa_key_info) = parse_dsa_private_key(&private_key_pem) { + sign_with_dsa(&dsa_key_info, &challenge)? + } else { + return Ok(false); + }; + + // Send signature length and signature + let signature_len = signature.len() as u32; + self.stream.write_all(&signature_len.to_be_bytes()).await?; + self.stream.flush().await?; + self.stream.write_all(&signature).await?; + self.stream.flush().await?; + + // Send key identifier for public key identification + let key_id_bytes = verify_public_key.as_bytes(); + let key_id_len = key_id_bytes.len() as u32; + self.stream.write_all(&key_id_len.to_be_bytes()).await?; + self.stream.flush().await?; + self.stream.write_all(key_id_bytes).await?; + self.stream.flush().await?; + + Ok(true) + } +} + +/// Parse Ed25519 public key from PEM format +fn parse_ed25519_public_key(pem: &str) -> [u8; 32] { + // Robust parsing for Ed25519 public key using pem crate + let mut key_bytes = [0u8; 32]; + + if let Ok(pem_data) = pem::parse(pem) + && pem_data.tag() == "PUBLIC KEY" + && pem_data.contents().len() >= 32 + { + let contents = pem_data.contents(); + key_bytes.copy_from_slice(&contents[contents.len() - 32..]); + } + key_bytes +} + +/// Parse Ed25519 private key from PEM format +fn parse_ed25519_private_key(pem: &str) -> Result { + if let Ok(pem_data) = pem::parse(pem) + && pem_data.tag() == "PRIVATE KEY" + && pem_data.contents().len() >= 32 + { + let contents = pem_data.contents(); + let mut seed = [0u8; 32]; + seed.copy_from_slice(&contents[contents.len() - 32..]); + return Ok(SigningKey::from_bytes(&seed)); + } + Err(TcpTargetError::Crypto( + "Invalid Ed25519 private key format".to_string(), + )) +} + +/// Parse DSA public key information from PEM +fn parse_dsa_public_key( + pem: &str, +) -> Result<(&'static dyn signature::VerificationAlgorithm, Vec), TcpTargetError> { + if let Ok(pem_data) = pem::parse(pem) { + let contents = pem_data.contents().to_vec(); + + // Try different DSA algorithms based on PEM tag + match pem_data.tag() { + "EC PUBLIC KEY" | "PUBLIC KEY" if pem.contains("ECDSA") || pem.contains("ecdsa") => { + if pem.contains("P-256") { + return Ok((&ECDSA_P256_SHA256_ASN1, contents)); + } else if pem.contains("P-384") { + return Ok((&ECDSA_P384_SHA384_ASN1, contents)); + } + } + "RSA PUBLIC KEY" | "PUBLIC KEY" => { + return Ok((&RSA_PKCS1_2048_8192_SHA256, contents)); + } + _ => {} + } + + // Default to RSA for unknown types + return Ok((&RSA_PKCS1_2048_8192_SHA256, contents)); + } + Err(TcpTargetError::Crypto( + "Invalid DSA public key format".to_string(), + )) +} + +/// Parse DSA private key information from PEM +fn parse_dsa_private_key( + pem: &str, +) -> Result<(&'static dyn signature::VerificationAlgorithm, Vec), TcpTargetError> { + // For DSA, private key verification uses the same algorithm as public key + parse_dsa_public_key(pem) +} + +/// Verify DSA signature +fn verify_dsa_signature( + algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec), + message: &[u8], + signature: &[u8], +) -> bool { + let (algorithm, key_bytes) = algorithm_and_key; + let public_key = UnparsedPublicKey::new(*algorithm, key_bytes); + public_key.verify(message, signature).is_ok() +} + +/// Sign with DSA +fn sign_with_dsa( + algorithm_and_key: &(&'static dyn signature::VerificationAlgorithm, Vec), + message: &[u8], +) -> Result, TcpTargetError> { + let (algorithm, key_bytes) = algorithm_and_key; + + // Handle different DSA/ECDSA algorithms by comparing algorithm identifiers + // Since we can't directly compare trait objects, we use pointer comparison + let algorithm_ptr = algorithm as *const _ as *const (); + let ecdsa_p256_ptr = &ECDSA_P256_SHA256_ASN1 as *const _ as *const (); + let ecdsa_p384_ptr = &ECDSA_P384_SHA384_ASN1 as *const _ as *const (); + + if algorithm_ptr == ecdsa_p256_ptr { + let key_pair = EcdsaKeyPair::from_pkcs8( + ECDSA_P256_SHA256_ASN1_SIGNING, + key_bytes, + &SystemRandom::new(), + ) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to create ECDSA P-256 key pair: {}", e)) + })?; + + let signature = key_pair + .sign(&SystemRandom::new(), message) + .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-256 signing failed: {}", e)))?; + + Ok(signature.as_ref().to_vec()) + } else if algorithm_ptr == ecdsa_p384_ptr { + let key_pair = EcdsaKeyPair::from_pkcs8( + ECDSA_P384_SHA384_ASN1_SIGNING, + key_bytes, + &SystemRandom::new(), + ) + .map_err(|e| { + TcpTargetError::Crypto(format!("Failed to create ECDSA P-384 key pair: {}", e)) + })?; + + let signature = key_pair + .sign(&SystemRandom::new(), message) + .map_err(|e| TcpTargetError::Crypto(format!("ECDSA P-384 signing failed: {}", e)))?; + + Ok(signature.as_ref().to_vec()) + } else { + // RSA or unsupported algorithm + Err(TcpTargetError::Unsupported( + "DSA/ECDSA signing not supported for this algorithm type".to_string(), + )) + } +} diff --git a/utils/tcp_connection/src/lib.rs b/utils/tcp_connection/src/lib.rs new file mode 100644 index 0000000..6a2e599 --- /dev/null +++ b/utils/tcp_connection/src/lib.rs @@ -0,0 +1,6 @@ +#[allow(dead_code)] +pub mod instance; + +pub mod instance_challenge; + +pub mod error; diff --git a/utils/tcp_connection/tcp_connection_test/Cargo.toml b/utils/tcp_connection/tcp_connection_test/Cargo.toml new file mode 100644 index 0000000..19a6e9b --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tcp_connection_test" +edition = "2024" +version.workspace = true + +[dependencies] +tcp_connection = { path = "../../tcp_connection" } +tokio = { version = "1.48.0", features = ["full"] } +serde = { version = "1.0.228", features = ["derive"] } diff --git a/utils/tcp_connection/tcp_connection_test/res/image/test_transfer.png b/utils/tcp_connection/tcp_connection_test/res/image/test_transfer.png new file mode 100644 index 0000000..5fa94f0 Binary files /dev/null and b/utils/tcp_connection/tcp_connection_test/res/image/test_transfer.png differ diff --git a/utils/tcp_connection/tcp_connection_test/res/key/test_key.pem b/utils/tcp_connection/tcp_connection_test/res/key/test_key.pem new file mode 100644 index 0000000..e155876 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/res/key/test_key.pem @@ -0,0 +1,13 @@ +-----BEGIN RSA PUBLIC KEY----- +MIICCgKCAgEAl5vyIwGYiQ1zZpW2tg+LwOUV547T2SjlzKQjcms5je/epP4CnUfT +5cmHCe8ZaSbnofcntCzi8FzMpQmzhNzFk5tCAe4tSrghfr2kYDO7aUL0G09KbNZ5 +iuMTkMaHx6LMjZ+Ljy8fC47yC2dFMUgLjGS7xS6rnIo4YtFuvMdwbLjs7mSn+vVc +kcEV8RLlQg8wDbzpl66Jd1kiUgPfVLBRTLE/iL8kUCz1l8c+DvOzr3ATwJysM9CG +LFahGLlTd3CZaj0QsEzf/AQsn79Su+rnCXhXqcvynhAcil0UW9RWp5Zsvp3Me3W8 +pJg6vZuAA6lQ062hkRLiJ91F2rpyqtkax5i/simLjelpsRzLKo6Xsz1bZht2+5d5 +ArgTBtZBxS044t8caZWLXetnPEcxEGz8KYUVKf7X9S7R53gy36y88Fbu9giqUr3m +b3Da+SYzBT//hacGn55nhzLRdsJGaFFWcKCbpue6JHLsFhizhdEAjaec0hfphw29 +veY0adPdIFLQDmMKaNk4ulrz8Lbgpqn9gxx6fRssj9jqNJmW64a0eV+Rw7BCJazH +xp3zz4A3rwdI8BjxLUb3YiCUcavA9WzJ1DUfdX1FSvbcFw4CEiGJjfpWGrm1jtc6 +DMOsoX/C6yFOyRpipsgqIToBClchLSNgrO6A7SIoSdIqNDEgIanFcjECAwEAAQ== +-----END RSA PUBLIC KEY----- diff --git a/utils/tcp_connection/tcp_connection_test/res/key/test_key_private.pem b/utils/tcp_connection/tcp_connection_test/res/key/test_key_private.pem new file mode 100644 index 0000000..183d2d9 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/res/key/test_key_private.pem @@ -0,0 +1,51 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIJKAIBAAKCAgEAl5vyIwGYiQ1zZpW2tg+LwOUV547T2SjlzKQjcms5je/epP4C +nUfT5cmHCe8ZaSbnofcntCzi8FzMpQmzhNzFk5tCAe4tSrghfr2kYDO7aUL0G09K +bNZ5iuMTkMaHx6LMjZ+Ljy8fC47yC2dFMUgLjGS7xS6rnIo4YtFuvMdwbLjs7mSn ++vVckcEV8RLlQg8wDbzpl66Jd1kiUgPfVLBRTLE/iL8kUCz1l8c+DvOzr3ATwJys +M9CGLFahGLlTd3CZaj0QsEzf/AQsn79Su+rnCXhXqcvynhAcil0UW9RWp5Zsvp3M +e3W8pJg6vZuAA6lQ062hkRLiJ91F2rpyqtkax5i/simLjelpsRzLKo6Xsz1bZht2 ++5d5ArgTBtZBxS044t8caZWLXetnPEcxEGz8KYUVKf7X9S7R53gy36y88Fbu9giq +Ur3mb3Da+SYzBT//hacGn55nhzLRdsJGaFFWcKCbpue6JHLsFhizhdEAjaec0hfp +hw29veY0adPdIFLQDmMKaNk4ulrz8Lbgpqn9gxx6fRssj9jqNJmW64a0eV+Rw7BC +JazHxp3zz4A3rwdI8BjxLUb3YiCUcavA9WzJ1DUfdX1FSvbcFw4CEiGJjfpWGrm1 +jtc6DMOsoX/C6yFOyRpipsgqIToBClchLSNgrO6A7SIoSdIqNDEgIanFcjECAwEA +AQKCAgAd3cg9Ei7o7N/reRnV0skutlJy2+Wq9Y4TmtAq1amwZu0e5rVAI6rALUuv +bs08NEBUXVqSeXc5b6aW6orVZSJ8+gxuUevVOOHMVHKhyv8j9N8e1Cduum+WJzav +AhU0hEM0sRXunpNIlR/klDMCytUPkraU2SVQgMAr42MjyExC9skiC202GIjkY7u9 +UoIcWd6XDjycN3N4MfR7YKzpw5Q4fgBsoW73Zmv5OvRkQKkIqhUSECsyR+VuraAt +vTCOqn1meuIjQPms7WuXCrszLsrVyEHIvtcsQTNGJKECmBl8CTuh73cdaSvA5wZH +XO9CiWPVV3KpICWyQbplpO467usB0liMX3mcMp+Ztp/p/ns6Ov5L6AR8LcDJ43KA +454ZUYxbRjqG+cW6Owm5Ii0+UOEGOi+6Jhc4NGZuYU2gDrhuz4yejY6bDAu8Ityd +umVU90IePVm6dlMM5cgyDmCXUkOVsjegMIBP+Zf3an1JWtsDL2RW5OwrFH7DQaqG +UwE/w/JOkRe3UMcTECfjX1ACJlB8XDAXiNeBQsAFOVVkWdBE4D7IlQLJVZAyGSlt +NMTn9/kQBGgdlyEqVAPKGnfl08TubyL7/9xOhCoYsv0IIOI8xgT7zQwefUAn2TFb +ulHIdVovRI4Oa0n7WfK4srL73XqjKYJAC9nmxXMwKe1wokjREwKCAQEAyNZKWY88 +4OqYa9xEEJwEOAA5YWLZ/+b9lCCQW8gMeVyTZ7A4vJVyYtBvRBlv6MhB4OTIf9ah +YuyZMl6oNCs2SBP1lKxsPlGaothRlEmPyTWXOt9iRLpPHUcGG1odfeGpI0bdHs1n +E/OpKYwzD0oSe5PGA0zkcetG61klPw8NIrjTkQ2hMqDV+ppF0lPxe/iudyTVMGhX +aHcd95DZNGaS503ZcSjN4MeVSkQEDI4fu4XK4135DCaKOmIPtOd6Rw+qMxoCC7Wl +cEDnZ6eqQ5EOy8Ufz8WKqGSVWkr6cO/qtulFLAj0hdL0aENTCRer+01alybXJXyB +GKyCk7i2RDlbGwKCAQEAwUA7SU7/0dKPJ2r0/70R6ayxZ7tQZK4smFgtkMDeWsaw +y2lZ6r44iJR/Tg6+bP8MjGzP/GU1i5QIIjJMGx2/VTWjJSOsFu3edZ5PHQUVSFQE +8FAhYXWOH+3igfgWJMkzhVsBo9/kINaEnt9jLBE8okEY+9/JEsdBqV/S4dkxjUPT +E+62kX9lkQVk/gCWjsLRKZV4d87gXU8mMQbhgj99qg1joffV132vo6pvBBBCJ4Ex +4/JxIQ2W/GmkrFe8NlvD1CEMyvkeV+g2wbtvjWs0Ezyzh4njJAtKMe0SEg5dFTqa +eL/GjpgfIP7Uu30V35ngkgl7CuY1D/IJg4PxKthQowKCAQBUGtFWAhMXiYa9HKfw +YLWvkgB1lQUAEoa84ooxtWvr4uXj9Ts9VkRptynxVcm0rTBRct24E3TQTY62Nkew +WSxJMPqWAULvMhNVAMvhEpFBTM0BHY00hOUeuKCJEcrp7Xd8S2/MN25kP5TmzkyP +qZBl6fNxbGD6h/HSGynq522zzbzjsNaBsjMJ2FNHClpFdVXylR0mQXvhRojpJOKg +/Bem/8YAinr1F/+f8y3S6C3HxPa7Ep56BSW731b+hjWBzsCS1+BlcPNQOA3wLZmy +4+tTUEDLLMmtTTnybxXD9+TOJpAOKc3kwPwTMaZzV1NxUOqQA/bzPtl9MLkaDa9e +kLpjAoIBACRFtxsKbe/nMqF2bOf3h/4xQNc0jGFpY8tweZT67oFhW9vCOXNbIudX +4BE5qTpyINvWrK82G/fH4ELy5+ALFFedCrM0398p5KB1B2puAtGhm4+zqqBNXVDW +6LX2Z8mdzkLQkx08L+iN+zSKv2WNErFtwI++MFKK/eMZrk5f4vId8eeC3devbtPq +jEs0tw2yuWmxuXvbY7d/3K5FGVzGKAMcIkBLcWLSH357xfygRJp/oGqlneBTWayk +85i5mwUk8jvFvE34tl5Por94O/byUULvGM9u7Shdyh5W3hZvhb8vUcEqVc179hPO +YQWT8+AVVNZ0WxjvnrQQfQKnaEPfeDsCggEBAJ7zgVVla8BOagEenKwr6nEkQzK/ +sTcF9Zp7TmyGKGdM4rW+CJqGgwswn65va+uZj7o0+D5JGeB8kRG5GtjUUzHkNBD0 +Av6KZksQDqgKdwPaH0MQSXCuUc0MYTBHDJdciN/DqdO8st69hyNRv4XdHst1SZdJ +VjUh3p4iwO4wfQQW7mvj94lLM/ypMdUqPKxVHVWQsbE9fOVbyKINuIDPDzu5iqc3 +VKScUwqpcGPZsgHr/Sguv/fdFnPs4O+N0AsAe3xbleCfQAeZnI0tR8nkYudvmxNz +MRevTAPDUBUDd0Uiy+d6w6B4vW8q9Zv3oFLXns4kWsJFajjx3TdgTacnVlI= +-----END RSA PRIVATE KEY----- diff --git a/utils/tcp_connection/tcp_connection_test/res/key/wrong_key_private.pem b/utils/tcp_connection/tcp_connection_test/res/key/wrong_key_private.pem new file mode 100644 index 0000000..4b77eea --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/res/key/wrong_key_private.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCvmvYR6ypNS4ld +cyJDlwv+4KC8/SxKBhlen8FX6Ltzfzi3f1I7qXZByGaTasQtc4qWgl0tLkrA8Pc3 +pm/r2To+Gl5cXMMz/zKFShuviGp/F17eS1idpNSFO6ViF+WXrENdESB7E6Dm4teK ++WLdtOHk3exC/+F+YUK3Jh6lTR5+donHaURlKjcKiRY7YxHq9HbrYXujJyiuU51a +nDvV20AWy7cKGGPRpV8YSoNGxE24WpWjjf0l++aFSpQaKanoV9tL4ZI0aXMFawSB +4YKBjtht6Cxm37oeBaimUKxA7BKH/DUueQsjAfw0WgZhItBDEtKjs2tMkOj/VUuF +OYrC58vunQDd/sP60BV3F/yPZiuBIB4PyXe2PVRabMBq2p2uiGexjoQ9DR+jU9ig +532KxckPHyqXzLd7MwljLw8ypahxMSE/lgcZIZh5I9oDsZWPD8Kx8D6eq/gypTkd +v8bewOTtj8GN2/MyxQZzYsz2ZruUrPv7hd38qjsFkKrb8wPM6xTOScM7CYFNceiL +3DcawivS+f5TgVkrjBGqhwpOc96ZHHuojw9f8KyJ3DON5CWKPpyKJvXEt6QuT5dc +BPZM33KHSuDCJUrw9dnh6rkaTnx681csAGJTYX2zeNxTI9DO/YSvpEK5e5MaZ9Kc +OETgnXiOe9KlNBtJeLd5XwvnYelzYQIDAQABAoICAAIis01ie8A24/PD+62wv4+Y +8bt6pLg9vL8+2B4WkXkFGg55OOnK1MpWApFWYg5fclcEPNfY0UXpaEg/+Op4WNH6 +hh0/b4xJVTbzwMRwt0LWaOvxJKG+KGt6XzeDLOKcULFoDOoSQgmsxoxFHiOuGHUt +Ebt62yYrTqFlkEfYWT+Wd3R6Xj+QtNym8CNGwCgIUw3nwJYqWr9L+wToE341TWE5 +lv9DbqtVBIQKG/CXYI6WY216w5JbruD+GDD9Qri1oNAabSnAAosVUxe1Q14J+63S +ff++Rsgor3VeU8nyVQNcWNU42Z7SXlvQoHU79CZsqy0ceHiU5pB8XA/BtGNMaFl4 +UehZPTsJhi8dlUdTYw5f5oOnHltNpSioy0KtqEBJjJX+CzS1UMAr6k9gtjbWeXpD +88JwoOy8n6HLAYETu/GiHLHpyIWJ84O+PeAO5jBCQTJN80fe3zbF+zJ5tHMHIFts +mGNmY9arKMCZHP642W3JRJsjN3LjdtzziXnhQzgKnPh/uCzceHZdSLf3S7NsEVOX +ZWb2nuDObJCpKD/4Hq2HpfupMNO73SUcbzg2slsRCRdDrokxOSEUHm7y9GD7tS2W +IC8A09pyCvM25k3so0QPpDP4+i/df7j862rb9+zctwhEWPdXTbFjI+9rI8JBcUwe +t94TFb5b9uB/kWYPnmUBAoIBAQDxiZjm5i8OInuedPnLkxdy31u/tqb+0+GMmp60 +gtmf7eL6Xu3F9Uqr6zH9o90CkdzHmtz6BcBTo/hUiOcTHj9Xnsso1GbneUlHJl9R ++G68sKWMXW76OfSKuXQ1fwlXV+J7Lu0XNIEeLVy09pYgjGKFn2ql7ELpRh7j1UXH +KbFVl2ESn5IVU4oGl+MMB5OzGYpyhuro24/sVSlaeXHakCLcHV69PvjyocQy8g+8 +Z1pXKqHy3mV6MOmSOJ4DqDxaZ2dLQR/rc7bvpxDIxtwMwD/a//xGlwnePOS/0IcB +I2dgFmRNwJ8WC9Le0E+EsEUD929fXEF3+CZN4E+KAuY8Y8UxAoIBAQC6HrlSdfVF +kmpddU4VLD5T/FuA6wB32VkXa6sXWiB0j8vOipGZkUvqQxnJiiorL0AECk3PXXT+ +wXgjqewZHibpJKeqaI4Zqblqebqb68VIANhO0DhRWsh63peVjAPNUmg+tfZHuEBE +bJlz1IBx0der5KBZfg7mngrXvQqIAYSr+Gl14PvwOGqG6Xjy+5VEJqDzEm9VaOnm +mm39st5oRotYnXdf83AV2aLI8ukkq0/mHAySlu5A4VhA5kTJT16Lam2h590AtmBH +6xsO1BtDmfVsaUxBSojkEW8eap+vbyU9vuwjrtm/dG19qcnyesjTJMFQgGnaY46L +ID/aNSDwssUxAoIBAQDFYaBl8G07q8pBr24Cgm2DHiwn+ud1D0keUayn7tZQ72Gx +IKpGPzGKVGVB1Qri8rftFgzG9LQ6paBl1IqhAPLac5WqBAkj1+WeEymKHu6/m8tt +bV0ndvzz8KGapfnIOrWF3M87S1jIhGFiMLB2YMKSV7gbZ3s2jmrn3H1tSBD21QIq +6ePDMcV1peGRDxAQKCsPdFm7eNGgW+ezW9NCvM7/+bBWDoP6I1/mEhHx8LPOz7QQ +eNWMiTQWndXjPzQy3JV41ftzudgg9/GrYXappOGJ4e8S8JLL3g9BAPOSZpAv4ZyO +PX7D0V29X5Xb5QBBQY7t6sJFe7Axq8DUE5J6fz3BAoIBAHLFEWh9HsNJF1gMRxsd +Tk4B9vcXcxF0sNCVb0qWJB9csMPrhP9arqKFwDgcgAZjO6mCJRszOTsDWK89UD7o +7fukw9N8Z+wBUjoLWHxftibBhqGLGr9oKOpDqtvoHEwXffr1wCnXv6GyCip4JsCJ +MuJnuE2XQ18IpA0HIKBft01IgNfU5ebrEx2giRnk89WzsFpTyt2zNVEjd6ITE7zf +i3wYlg1QE5UVwKED0arwDPQL5eDbO448p2xV0qME03tLJNHLJegTjmmq2+OX/jwA +i2vPvtsgOCvTaF8sRs4qzp81xW33m4TJKd9svQBOoNo69w5KMXwfGj5Go7lOO8LR +qnECggEAII/9+EdPUMx97Ex9R6sc9VQEpjxzlJmA9RaVASoZiinydP9QToLYhZif +QhSjHOrbPfGorNMIaVCOS4WGZWnJBSDX8uVvhi/N6mWegmj8w/WZrNuNOT99/8Fq +HXMnpOrXJsgQ4MDVzu+V8DISgrirf+PdBW1u/JtdjwmunlnPE1AsJUDWZlDTttaE +0p32cDq6j+eUxfBq5/haZxe92Jq9Wr+o+gXNO9EwZCO+bTtHFJJso5YbU548kMdA +j5y4BUf/jkCqK8c6sufbfP4MN4YnWbdSPmH3V2DF3g1okalUYp2sAOgAwwPjFAOu +f9qBWGCwdZjeDjaVVUgwi+Waf+M0tQ== +-----END PRIVATE KEY----- diff --git a/utils/tcp_connection/tcp_connection_test/src/lib.rs b/utils/tcp_connection/tcp_connection_test/src/lib.rs new file mode 100644 index 0000000..c9372d4 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/lib.rs @@ -0,0 +1,17 @@ +#[cfg(test)] +pub mod test_tcp_target_build; + +#[cfg(test)] +pub mod test_connection; + +#[cfg(test)] +pub mod test_challenge; + +#[cfg(test)] +pub mod test_file_transfer; + +#[cfg(test)] +pub mod test_msgpack; + +pub mod test_utils; +pub use test_utils::*; diff --git a/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs b/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs new file mode 100644 index 0000000..9327b3e --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_challenge.rs @@ -0,0 +1,160 @@ +use std::{env::current_dir, time::Duration}; + +use tcp_connection::instance::ConnectionInstance; +use tokio::{ + join, + time::{sleep, timeout}, +}; + +use crate::test_utils::{ + handle::{ClientHandle, ServerHandle}, + target::TcpServerTarget, + target_configure::ServerTargetConfig, +}; + +pub(crate) struct ExampleChallengeClientHandle; + +impl ClientHandle for ExampleChallengeClientHandle { + async fn process(mut instance: ConnectionInstance) { + // Accept challenge with correct key + let key = current_dir() + .unwrap() + .join("res") + .join("key") + .join("test_key_private.pem"); + let result = instance.accept_challenge(key, "test_key").await.unwrap(); + + // Sent success + assert!(result); + let response = instance.read_text().await.unwrap(); + + // Verify success + assert_eq!("OK", response); + + // Accept challenge with wrong key + let key = current_dir() + .unwrap() + .join("res") + .join("key") + .join("wrong_key_private.pem"); + let result = instance.accept_challenge(key, "test_key").await.unwrap(); + + // Sent success + assert!(result); + let response = instance.read_text().await.unwrap(); + + // Verify fail + assert_eq!("ERROR", response); + + // Accept challenge with wrong name + let key = current_dir() + .unwrap() + .join("res") + .join("key") + .join("test_key_private.pem"); + let result = instance.accept_challenge(key, "test_key__").await.unwrap(); + + // Sent success + assert!(result); + let response = instance.read_text().await.unwrap(); + + // Verify fail + assert_eq!("ERROR", response); + } +} + +pub(crate) struct ExampleChallengeServerHandle; + +impl ServerHandle for ExampleChallengeServerHandle { + async fn process(mut instance: ConnectionInstance) { + // Challenge with correct key + let key_dir = current_dir().unwrap().join("res").join("key"); + let (result, key_id) = instance.challenge(key_dir).await.unwrap(); + assert!(result); + assert_eq!(key_id, "test_key"); + + // Send response + instance + .write_text(if result { "OK" } else { "ERROR" }) + .await + .unwrap(); + + // Challenge again + let key_dir = current_dir().unwrap().join("res").join("key"); + let (result, key_id) = instance.challenge(key_dir).await.unwrap(); + assert!(!result); + assert_eq!(key_id, "test_key"); + + // Send response + instance + .write_text(if result { "OK" } else { "ERROR" }) + .await + .unwrap(); + + // Challenge again + let key_dir = current_dir().unwrap().join("res").join("key"); + let (result, key_id) = instance.challenge(key_dir).await.unwrap(); + assert!(!result); + assert_eq!(key_id, "test_key__"); + + // Send response + instance + .write_text(if result { "OK" } else { "ERROR" }) + .await + .unwrap(); + } +} + +#[tokio::test] +async fn test_connection_with_challenge_handle() -> Result<(), std::io::Error> { + let host = "localhost:5011"; + + // Server setup + let Ok(server_target) = TcpServerTarget::< + ExampleChallengeClientHandle, + ExampleChallengeServerHandle, + >::from_domain(host) + .await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + // Client setup + let Ok(client_target) = TcpServerTarget::< + ExampleChallengeClientHandle, + ExampleChallengeServerHandle, + >::from_domain(host) + .await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + let future_server = async move { + // Only process once + let configured_server = server_target.server_cfg(ServerTargetConfig::default().once()); + + // Listen here + let _ = configured_server.listen().await; + }; + + let future_client = async move { + // Wait for server start + let _ = sleep(Duration::from_secs_f32(1.5)).await; + + // Connect here + let _ = client_target.connect().await; + }; + + let test_timeout = Duration::from_secs(10); + + timeout(test_timeout, async { join!(future_client, future_server) }) + .await + .map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!("Test timed out after {:?}", test_timeout), + ) + })?; + + Ok(()) +} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_connection.rs b/utils/tcp_connection/tcp_connection_test/src/test_connection.rs new file mode 100644 index 0000000..8c3ab01 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_connection.rs @@ -0,0 +1,78 @@ +use std::time::Duration; + +use tcp_connection::instance::ConnectionInstance; +use tokio::{join, time::sleep}; + +use crate::test_utils::{ + handle::{ClientHandle, ServerHandle}, + target::TcpServerTarget, + target_configure::ServerTargetConfig, +}; + +pub(crate) struct ExampleClientHandle; + +impl ClientHandle for ExampleClientHandle { + async fn process(mut instance: ConnectionInstance) { + // Write name + let Ok(_) = instance.write_text("Peter").await else { + panic!("Write text failed!"); + }; + // Read msg + let Ok(result) = instance.read_text().await else { + return; + }; + assert_eq!("Hello Peter!", result); + } +} + +pub(crate) struct ExampleServerHandle; + +impl ServerHandle for ExampleServerHandle { + async fn process(mut instance: ConnectionInstance) { + // Read name + let Ok(name) = instance.read_text().await else { + return; + }; + // Write msg + let Ok(_) = instance.write_text(format!("Hello {}!", name)).await else { + panic!("Write text failed!"); + }; + } +} + +#[tokio::test] +async fn test_connection_with_example_handle() { + let host = "localhost:5012"; + + // Server setup + let Ok(server_target) = + TcpServerTarget::::from_domain(host).await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + // Client setup + let Ok(client_target) = + TcpServerTarget::::from_domain(host).await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + let future_server = async move { + // Only process once + let configured_server = server_target.server_cfg(ServerTargetConfig::default().once()); + + // Listen here + let _ = configured_server.listen().await; + }; + + let future_client = async move { + // Wait for server start + let _ = sleep(Duration::from_secs_f32(1.5)).await; + + // Connect here + let _ = client_target.connect().await; + }; + + let _ = async { join!(future_client, future_server) }.await; +} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_file_transfer.rs b/utils/tcp_connection/tcp_connection_test/src/test_file_transfer.rs new file mode 100644 index 0000000..4237ea7 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_file_transfer.rs @@ -0,0 +1,94 @@ +use std::{env::current_dir, time::Duration}; + +use tcp_connection::instance::ConnectionInstance; +use tokio::{ + join, + time::{sleep, timeout}, +}; + +use crate::test_utils::{ + handle::{ClientHandle, ServerHandle}, + target::TcpServerTarget, + target_configure::ServerTargetConfig, +}; + +pub(crate) struct ExampleFileTransferClientHandle; + +impl ClientHandle for ExampleFileTransferClientHandle { + async fn process(mut instance: ConnectionInstance) { + let image_path = current_dir() + .unwrap() + .join("res") + .join("image") + .join("test_transfer.png"); + instance.write_file(image_path).await.unwrap(); + } +} + +pub(crate) struct ExampleFileTransferServerHandle; + +impl ServerHandle for ExampleFileTransferServerHandle { + async fn process(mut instance: ConnectionInstance) { + let save_path = current_dir() + .unwrap() + .join("res") + .join(".temp") + .join("image") + .join("test_transfer.png"); + instance.read_file(save_path).await.unwrap(); + } +} + +#[tokio::test] +async fn test_connection_with_challenge_handle() -> Result<(), std::io::Error> { + let host = "localhost:5010"; + + // Server setup + let Ok(server_target) = TcpServerTarget::< + ExampleFileTransferClientHandle, + ExampleFileTransferServerHandle, + >::from_domain(host) + .await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + // Client setup + let Ok(client_target) = TcpServerTarget::< + ExampleFileTransferClientHandle, + ExampleFileTransferServerHandle, + >::from_domain(host) + .await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + let future_server = async move { + // Only process once + let configured_server = server_target.server_cfg(ServerTargetConfig::default().once()); + + // Listen here + let _ = configured_server.listen().await; + }; + + let future_client = async move { + // Wait for server start + let _ = sleep(Duration::from_secs_f32(1.5)).await; + + // Connect here + let _ = client_target.connect().await; + }; + + let test_timeout = Duration::from_secs(10); + + timeout(test_timeout, async { join!(future_client, future_server) }) + .await + .map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!("Test timed out after {:?}", test_timeout), + ) + })?; + + Ok(()) +} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs b/utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs new file mode 100644 index 0000000..4c9c870 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_msgpack.rs @@ -0,0 +1,103 @@ +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tcp_connection::instance::ConnectionInstance; +use tokio::{join, time::sleep}; + +use crate::test_utils::{ + handle::{ClientHandle, ServerHandle}, + target::TcpServerTarget, + target_configure::ServerTargetConfig, +}; + +#[derive(Debug, PartialEq, Serialize, Deserialize, Default)] +struct TestData { + id: u32, + name: String, +} + +pub(crate) struct MsgPackClientHandle; + +impl ClientHandle for MsgPackClientHandle { + async fn process(mut instance: ConnectionInstance) { + // Test basic MessagePack serialization + let test_data = TestData { + id: 42, + name: "Test MessagePack".to_string(), + }; + + // Write MessagePack data + if let Err(e) = instance.write_msgpack(&test_data).await { + panic!("Write MessagePack failed: {}", e); + } + + // Read response + let response: TestData = match instance.read_msgpack().await { + Ok(data) => data, + Err(e) => panic!("Read MessagePack response failed: {}", e), + }; + + // Verify response + assert_eq!(response.id, test_data.id * 2); + assert_eq!(response.name, format!("Processed: {}", test_data.name)); + } +} + +pub(crate) struct MsgPackServerHandle; + +impl ServerHandle for MsgPackServerHandle { + async fn process(mut instance: ConnectionInstance) { + // Read MessagePack data + let received_data: TestData = match instance.read_msgpack().await { + Ok(data) => data, + Err(_) => return, + }; + + // Process data + let response = TestData { + id: received_data.id * 2, + name: format!("Processed: {}", received_data.name), + }; + + // Write response as MessagePack + if let Err(e) = instance.write_msgpack(&response).await { + panic!("Write MessagePack response failed: {}", e); + } + } +} + +#[tokio::test] +async fn test_msgpack_basic() { + let host = "localhost:5013"; + + // Server setup + let Ok(server_target) = + TcpServerTarget::::from_domain(host).await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + // Client setup + let Ok(client_target) = + TcpServerTarget::::from_domain(host).await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + let future_server = async move { + // Only process once + let configured_server = server_target.server_cfg(ServerTargetConfig::default().once()); + + // Listen here + let _ = configured_server.listen().await; + }; + + let future_client = async move { + // Wait for server start + let _ = sleep(Duration::from_secs_f32(1.5)).await; + + // Connect here + let _ = client_target.connect().await; + }; + + let _ = async { join!(future_client, future_server) }.await; +} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_tcp_target_build.rs b/utils/tcp_connection/tcp_connection_test/src/test_tcp_target_build.rs new file mode 100644 index 0000000..aa1ec74 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_tcp_target_build.rs @@ -0,0 +1,32 @@ +use crate::{ + test_connection::{ExampleClientHandle, ExampleServerHandle}, + test_utils::target::TcpServerTarget, +}; + +#[test] +fn test_tcp_test_target_build() { + let host = "127.0.0.1:8080"; + + // Test build target by string + let Ok(target) = + TcpServerTarget::::from_address_str(host) + else { + panic!("Test target built failed from a target addr `{}`", host); + }; + assert_eq!(target.to_string(), "127.0.0.1:8080"); +} + +#[tokio::test] +async fn test_tcp_test_target_build_domain() { + let host = "localhost"; + + // Test build target by DomainName and Connection + let Ok(target) = + TcpServerTarget::::from_domain(host).await + else { + panic!("Test target built failed from a domain named `{}`", host); + }; + + // Test into string + assert_eq!(target.to_string(), "127.0.0.1:8080"); +} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_utils.rs b/utils/tcp_connection/tcp_connection_test/src/test_utils.rs new file mode 100644 index 0000000..badf27d --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_utils.rs @@ -0,0 +1,4 @@ +pub mod handle; +pub mod target; +pub mod target_configure; +pub mod target_connection; diff --git a/utils/tcp_connection/tcp_connection_test/src/test_utils/handle.rs b/utils/tcp_connection/tcp_connection_test/src/test_utils/handle.rs new file mode 100644 index 0000000..4f9bdbb --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_utils/handle.rs @@ -0,0 +1,11 @@ +use std::future::Future; + +use tcp_connection::instance::ConnectionInstance; + +pub trait ClientHandle { + fn process(instance: ConnectionInstance) -> impl Future + Send; +} + +pub trait ServerHandle { + fn process(instance: ConnectionInstance) -> impl Future + Send; +} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_utils/target.rs b/utils/tcp_connection/tcp_connection_test/src/test_utils/target.rs new file mode 100644 index 0000000..8972b2a --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_utils/target.rs @@ -0,0 +1,201 @@ +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Display, Formatter}, + marker::PhantomData, + net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr}, + str::FromStr, +}; +use tokio::net::lookup_host; + +use crate::test_utils::{ + handle::{ClientHandle, ServerHandle}, + target_configure::{ClientTargetConfig, ServerTargetConfig}, +}; + +const DEFAULT_PORT: u16 = 8080; + +#[derive(Debug, Serialize, Deserialize)] +pub struct TcpServerTarget +where + Client: ClientHandle, + Server: ServerHandle, +{ + /// Client Config + client_cfg: Option, + + /// Server Config + server_cfg: Option, + + /// Server port + port: u16, + + /// Bind addr + bind_addr: IpAddr, + + /// Client Phantom Data + _client: PhantomData, + + /// Server Phantom Data + _server: PhantomData, +} + +impl Default for TcpServerTarget +where + Client: ClientHandle, + Server: ServerHandle, +{ + fn default() -> Self { + Self { + client_cfg: None, + server_cfg: None, + port: DEFAULT_PORT, + bind_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + _client: PhantomData, + _server: PhantomData, + } + } +} + +impl From for TcpServerTarget +where + Client: ClientHandle, + Server: ServerHandle, +{ + /// Convert SocketAddr to TcpServerTarget + fn from(value: SocketAddr) -> Self { + Self { + port: value.port(), + bind_addr: value.ip(), + ..Self::default() + } + } +} + +impl From> for SocketAddr +where + Client: ClientHandle, + Server: ServerHandle, +{ + /// Convert TcpServerTarget to SocketAddr + fn from(val: TcpServerTarget) -> Self { + SocketAddr::new(val.bind_addr, val.port) + } +} + +impl Display for TcpServerTarget +where + Client: ClientHandle, + Server: ServerHandle, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}:{}", self.bind_addr, self.port) + } +} + +impl TcpServerTarget +where + Client: ClientHandle, + Server: ServerHandle, +{ + /// Create target by address + pub fn from_addr(addr: impl Into, port: impl Into) -> Self { + Self { + port: port.into(), + bind_addr: addr.into(), + ..Self::default() + } + } + + /// Try to create target by string + pub fn from_address_str<'a>(addr_str: impl Into<&'a str>) -> Result { + let socket_addr = SocketAddr::from_str(addr_str.into()); + match socket_addr { + Ok(socket_addr) => Ok(Self::from_addr(socket_addr.ip(), socket_addr.port())), + Err(err) => Err(err), + } + } + + /// Try to create target by domain name + pub async fn from_domain<'a>(domain: impl Into<&'a str>) -> Result { + match domain_to_addr(domain).await { + Ok(domain_addr) => Ok(Self::from(domain_addr)), + Err(e) => Err(e), + } + } + + /// Set client config + pub fn client_cfg(mut self, config: ClientTargetConfig) -> Self { + self.client_cfg = Some(config); + self + } + + /// Set server config + pub fn server_cfg(mut self, config: ServerTargetConfig) -> Self { + self.server_cfg = Some(config); + self + } + + /// Add client config + pub fn add_client_cfg(&mut self, config: ClientTargetConfig) { + self.client_cfg = Some(config); + } + + /// Add server config + pub fn add_server_cfg(&mut self, config: ServerTargetConfig) { + self.server_cfg = Some(config); + } + + /// Get client config ref + pub fn get_client_cfg(&self) -> Option<&ClientTargetConfig> { + self.client_cfg.as_ref() + } + + /// Get server config ref + pub fn get_server_cfg(&self) -> Option<&ServerTargetConfig> { + self.server_cfg.as_ref() + } + + /// Get SocketAddr of TcpServerTarget + pub fn get_addr(&self) -> SocketAddr { + SocketAddr::new(self.bind_addr, self.port) + } +} + +/// Parse Domain Name to IpAddr via DNS +async fn domain_to_addr<'a>(domain: impl Into<&'a str>) -> Result { + let domain = domain.into(); + let default_port: u16 = DEFAULT_PORT; + + if let Ok(socket_addr) = domain.parse::() { + return Ok(match socket_addr.ip() { + IpAddr::V4(_) => socket_addr, + IpAddr::V6(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), socket_addr.port()), + }); + } + + if let Ok(_v6_addr) = domain.parse::() { + return Ok(SocketAddr::new( + IpAddr::V4(Ipv4Addr::LOCALHOST), + default_port, + )); + } + + let (host, port_str) = if let Some((host, port)) = domain.rsplit_once(':') { + (host.trim_matches(|c| c == '[' || c == ']'), Some(port)) + } else { + (domain, None) + }; + + let port = port_str + .and_then(|p| p.parse::().ok()) + .map(|p| p.clamp(0, u16::MAX)) + .unwrap_or(default_port); + + let mut socket_iter = lookup_host((host, 0)).await?; + + if let Some(addr) = socket_iter.find(|addr| addr.is_ipv4()) { + return Ok(SocketAddr::new(addr.ip(), port)); + } + + Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)) +} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_utils/target_configure.rs b/utils/tcp_connection/tcp_connection_test/src/test_utils/target_configure.rs new file mode 100644 index 0000000..d739ac9 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_utils/target_configure.rs @@ -0,0 +1,53 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)] +pub struct ServerTargetConfig { + /// Only process a single connection, then shut down the server. + once: bool, + + /// Timeout duration in milliseconds. (0 is Closed) + timeout: u64, +} + +impl ServerTargetConfig { + /// Set `once` to True + /// This method configures the `once` field of `ServerTargetConfig`. + pub fn once(mut self) -> Self { + self.once = true; + self + } + + /// Set `timeout` to the given value + /// This method configures the `timeout` field of `ServerTargetConfig`. + pub fn timeout(mut self, timeout: u64) -> Self { + self.timeout = timeout; + self + } + + /// Set `once` to the given value + /// This method configures the `once` field of `ServerTargetConfig`. + pub fn set_once(&mut self, enable: bool) { + self.once = enable; + } + + /// Set `timeout` to the given value + /// This method configures the `timeout` field of `ServerTargetConfig`. + pub fn set_timeout(&mut self, timeout: u64) { + self.timeout = timeout; + } + + /// Check if the server is configured to process only a single connection. + /// Returns `true` if the server will shut down after processing one connection. + pub fn is_once(&self) -> bool { + self.once + } + + /// Get the current timeout value in milliseconds. + /// Returns the timeout duration. A value of 0 indicates the connection is closed. + pub fn get_timeout(&self) -> u64 { + self.timeout + } +} + +#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)] +pub struct ClientTargetConfig {} diff --git a/utils/tcp_connection/tcp_connection_test/src/test_utils/target_connection.rs b/utils/tcp_connection/tcp_connection_test/src/test_utils/target_connection.rs new file mode 100644 index 0000000..d5bf2c3 --- /dev/null +++ b/utils/tcp_connection/tcp_connection_test/src/test_utils/target_connection.rs @@ -0,0 +1,89 @@ +use tcp_connection::{error::TcpTargetError, instance::ConnectionInstance}; +use tokio::{ + net::{TcpListener, TcpSocket}, + spawn, +}; + +use crate::test_utils::{ + handle::{ClientHandle, ServerHandle}, + target::TcpServerTarget, + target_configure::ServerTargetConfig, +}; + +impl TcpServerTarget +where + Client: ClientHandle, + Server: ServerHandle, +{ + /// Attempts to establish a connection to the TCP server. + /// + /// This function initiates a connection to the server address + /// specified in the target configuration. + /// + /// This is a Block operation. + pub async fn connect(&self) -> Result<(), TcpTargetError> { + let addr = self.get_addr(); + let Ok(socket) = TcpSocket::new_v4() else { + return Err(TcpTargetError::from("Create tcp socket failed!")); + }; + let stream = match socket.connect(addr).await { + Ok(stream) => stream, + Err(e) => { + let err = format!("Connect to `{}` failed: {}", addr, e); + return Err(TcpTargetError::from(err)); + } + }; + let instance = ConnectionInstance::from(stream); + Client::process(instance).await; + Ok(()) + } + + /// Attempts to establish a connection to the TCP server. + /// + /// This function initiates a connection to the server address + /// specified in the target configuration. + pub async fn listen(&self) -> Result<(), TcpTargetError> { + let addr = self.get_addr(); + let listener = match TcpListener::bind(addr).await { + Ok(listener) => listener, + Err(_) => { + let err = format!("Bind to `{}` failed", addr); + return Err(TcpTargetError::from(err)); + } + }; + + let cfg: ServerTargetConfig = match self.get_server_cfg() { + Some(cfg) => *cfg, + None => ServerTargetConfig::default(), + }; + + if cfg.is_once() { + // Process once (Blocked) + let (stream, _) = match listener.accept().await { + Ok(result) => result, + Err(e) => { + let err = format!("Accept connection failed: {}", e); + return Err(TcpTargetError::from(err)); + } + }; + let instance = ConnectionInstance::from(stream); + Server::process(instance).await; + } else { + loop { + // Process multiple times (Concurrent) + let (stream, _) = match listener.accept().await { + Ok(result) => result, + Err(e) => { + let err = format!("Accept connection failed: {}", e); + return Err(TcpTargetError::from(err)); + } + }; + let instance = ConnectionInstance::from(stream); + spawn(async move { + Server::process(instance).await; + }); + } + } + Ok(()) + } +} -- cgit