diff options
Diffstat (limited to 'crates/system_action')
| -rw-r--r-- | crates/system_action/action_macros/src/lib.rs | 4 | ||||
| -rw-r--r-- | crates/system_action/src/action.rs | 70 | ||||
| -rw-r--r-- | crates/system_action/src/action_pool.rs | 19 |
3 files changed, 75 insertions, 18 deletions
diff --git a/crates/system_action/action_macros/src/lib.rs b/crates/system_action/action_macros/src/lib.rs index aa1c696..d1a47ee 100644 --- a/crates/system_action/action_macros/src/lib.rs +++ b/crates/system_action/action_macros/src/lib.rs @@ -89,12 +89,12 @@ fn generate_action_struct(input_fn: ItemFn, _is_local: bool) -> proc_macro2::Tok #[doc = ""] #[doc = "Register the action to the pool."] #[doc = "```ignore"] - #[doc = "YourAction::register_to_pool(&mut pool);"] + #[doc = "register_your_func(&mut pool);"] #[doc = "```"] #[doc = ""] #[doc = "Process the action at the pool."] #[doc = "```ignore"] - #[doc = "let result = YourAction::process_at_pool(&pool, ctx, arg).await?;"] + #[doc = "let result = proc_your_func(&pool, ctx, arg).await?;"] #[doc = "```"] #fn_vis #fn_sig #fn_block } diff --git a/crates/system_action/src/action.rs b/crates/system_action/src/action.rs index e7d2d8c..3ae5711 100644 --- a/crates/system_action/src/action.rs +++ b/crates/system_action/src/action.rs @@ -1,6 +1,9 @@ use serde::{Serialize, de::DeserializeOwned}; +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::sync::Arc; use tcp_connection::{error::TcpTargetError, instance::ConnectionInstance}; -use tokio::net::TcpStream; +use tokio::{net::TcpStream, sync::Mutex}; pub trait Action<Args, Return> where @@ -29,8 +32,10 @@ pub struct ActionContext { action_args_json: String, /// The connection instance in the current context, - /// used to interact with the machine on the other end - instance: Option<ConnectionInstance>, + instance: Option<Arc<Mutex<ConnectionInstance>>>, + + /// Generic data storage for arbitrary types + data: HashMap<TypeId, Arc<dyn Any + Send + Sync>>, } impl ActionContext { @@ -50,18 +55,18 @@ impl ActionContext { /// Build connection instance from TcpStream pub fn build_instance(mut self, stream: TcpStream) -> Self { - self.instance = Some(ConnectionInstance::from(stream)); + self.instance = Some(Arc::new(Mutex::new(ConnectionInstance::from(stream)))); self } /// Insert connection instance into context pub fn insert_instance(mut self, instance: ConnectionInstance) -> Self { - self.instance = Some(instance); + self.instance = Some(Arc::new(Mutex::new(instance))); self } /// Pop connection instance from context - pub fn pop_instance(&mut self) -> Option<ConnectionInstance> { + pub fn pop_instance(&mut self) -> Option<Arc<Mutex<ConnectionInstance>>> { self.instance.take() } } @@ -78,12 +83,12 @@ impl ActionContext { } /// Get the connection instance in the current context - pub fn instance(&self) -> &Option<ConnectionInstance> { + pub fn instance(&self) -> &Option<Arc<Mutex<ConnectionInstance>>> { &self.instance } /// Get a mutable reference to the connection instance in the current context - pub fn instance_mut(&mut self) -> &mut Option<ConnectionInstance> { + pub fn instance_mut(&mut self) -> &mut Option<Arc<Mutex<ConnectionInstance>>> { &mut self.instance } @@ -104,8 +109,55 @@ impl ActionContext { } /// Set the action arguments in the context - pub fn set_action_args_json(mut self, action_args: String) -> Self { + pub fn set_action_args(mut self, action_args: String) -> Self { self.action_args_json = action_args; self } + + /// Insert arbitrary data in the context + pub fn insert<T: Any + Send + Sync>(mut self, value: T) -> Self { + self.data.insert(TypeId::of::<T>(), Arc::new(value)); + self + } + + /// Insert arbitrary data as Arc in the context + pub fn insert_arc<T: Any + Send + Sync>(mut self, value: Arc<T>) -> Self { + self.data.insert(TypeId::of::<T>(), value); + self + } + + /// Get arbitrary data from the context + pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> { + self.data + .get(&TypeId::of::<T>()) + .and_then(|arc| arc.downcast_ref::<T>()) + } + + /// Get arbitrary data as Arc from the context + pub fn get_arc<T: Any + Send + Sync>(&self) -> Option<Arc<T>> { + self.data + .get(&TypeId::of::<T>()) + .and_then(|arc| Arc::clone(arc).downcast::<T>().ok()) + } + + /// Remove and return arbitrary data from the context + pub fn remove<T: Any + Send + Sync>(&mut self) -> Option<Arc<T>> { + self.data + .remove(&TypeId::of::<T>()) + .and_then(|arc| arc.downcast::<T>().ok()) + } + + /// Check if the context contains data of a specific type + pub fn contains<T: Any + Send + Sync>(&self) -> bool { + self.data.contains_key(&TypeId::of::<T>()) + } + + /// Take ownership of the context and extract data of a specific type + pub fn take<T: Any + Send + Sync>(mut self) -> (Self, Option<Arc<T>>) { + let value = self + .data + .remove(&TypeId::of::<T>()) + .and_then(|arc| arc.downcast::<T>().ok()); + (self, value) + } } diff --git a/crates/system_action/src/action_pool.rs b/crates/system_action/src/action_pool.rs index 7e93fc4..f3e178a 100644 --- a/crates/system_action/src/action_pool.rs +++ b/crates/system_action/src/action_pool.rs @@ -8,7 +8,8 @@ use crate::action::{Action, ActionContext}; type ProcBeginCallback = for<'a> fn( - &'a mut ActionContext, + &'a ActionContext, + args: &'a (dyn std::any::Any + Send + Sync), ) -> Pin<Box<dyn Future<Output = Result<(), TcpTargetError>> + Send + 'a>>; type ProcEndCallback = fn() -> Pin<Box<dyn Future<Output = Result<(), TcpTargetError>> + Send>>; @@ -85,9 +86,9 @@ impl ActionPool { if let Some(action) = self.actions.get(action_name) { // Set action name and args in context for callbacks let context = context.set_action_name(action_name.to_string()); - let mut context = context.set_action_args_json(args_json.clone()); + let context = context.set_action_args(args_json.clone()); - let _ = self.exec_on_proc_begin(&mut context).await?; + let _ = self.exec_on_proc_begin(&context, &args_json).await?; let result = action.process_json_erased(context, args_json).await?; let _ = self.exec_on_proc_end().await?; Ok(result) @@ -109,11 +110,11 @@ impl ActionPool { args: Args, ) -> Result<Return, TcpTargetError> where - Args: serde::de::DeserializeOwned + Send + 'static, + Args: serde::de::DeserializeOwned + Send + Sync + 'static, Return: serde::Serialize + Send + 'static, { if let Some(action) = self.actions.get(action_name) { - let _ = self.exec_on_proc_begin(&mut context).await?; + let _ = self.exec_on_proc_begin(&context, &args).await?; let result = action.process_erased(context, Box::new(args)).await?; let result = *result .downcast::<Return>() @@ -126,9 +127,13 @@ impl ActionPool { } /// Executes the process begin callback if set - async fn exec_on_proc_begin(&self, context: &mut ActionContext) -> Result<(), TcpTargetError> { + async fn exec_on_proc_begin( + &self, + context: &ActionContext, + args: &(dyn std::any::Any + Send + Sync), + ) -> Result<(), TcpTargetError> { if let Some(callback) = &self.on_proc_begin { - callback(context).await + callback(context, args).await } else { Ok(()) } |
