summaryrefslogtreecommitdiff
path: root/crates/system_action
diff options
context:
space:
mode:
Diffstat (limited to 'crates/system_action')
-rw-r--r--crates/system_action/Cargo.toml4
-rw-r--r--crates/system_action/action_macros/Cargo.toml4
-rw-r--r--crates/system_action/action_macros/src/lib.rs54
-rw-r--r--crates/system_action/src/action.rs132
-rw-r--r--crates/system_action/src/action_pool.rs145
5 files changed, 293 insertions, 46 deletions
diff --git a/crates/system_action/Cargo.toml b/crates/system_action/Cargo.toml
index ee4f774..54ae454 100644
--- a/crates/system_action/Cargo.toml
+++ b/crates/system_action/Cargo.toml
@@ -9,3 +9,7 @@ action_system_macros = { path = "action_macros" }
# Serialization
serde = { version = "1.0.219", features = ["derive"] }
+serde_json = "1.0.140"
+
+# Async & Networking
+tokio = { version = "1.46.1", features = ["full"] }
diff --git a/crates/system_action/action_macros/Cargo.toml b/crates/system_action/action_macros/Cargo.toml
index 5ae14fa..869dcde 100644
--- a/crates/system_action/action_macros/Cargo.toml
+++ b/crates/system_action/action_macros/Cargo.toml
@@ -13,3 +13,7 @@ string_proc = { path = "../../utils/string_proc" }
syn = { version = "2.0", features = ["full", "extra-traits"] }
quote = "1.0"
proc-macro2 = "1.0"
+
+# Serialization
+serde = { version = "1.0.219", features = ["derive"] }
+serde_json = "1.0.140"
diff --git a/crates/system_action/action_macros/src/lib.rs b/crates/system_action/action_macros/src/lib.rs
index a7de9b6..ce50073 100644
--- a/crates/system_action/action_macros/src/lib.rs
+++ b/crates/system_action/action_macros/src/lib.rs
@@ -37,6 +37,9 @@ fn generate_action_struct(input_fn: ItemFn, _is_local: bool) -> proc_macro2::Tok
let action_name_ident = &fn_name;
+ let register_this_action = quote::format_ident!("register_{}", action_name_ident);
+ let proc_this_action = quote::format_ident!("proc_{}", action_name_ident);
+
quote! {
#[derive(Debug, Clone, Default)]
#fn_vis struct #struct_name;
@@ -55,22 +58,28 @@ fn generate_action_struct(input_fn: ItemFn, _is_local: bool) -> proc_macro2::Tok
}
}
- impl #struct_name {
- #fn_vis fn register_to_pool(pool: &mut action_system::action_pool::ActionPool) {
- pool.register::<#struct_name, #arg_type, #return_type>();
- }
+ #fn_vis fn #register_this_action(pool: &mut action_system::action_pool::ActionPool) {
+ pool.register::<#struct_name, #arg_type, #return_type>();
+ }
- #fn_vis async fn process_at_pool<'a>(
- pool: &'a action_system::action_pool::ActionPool,
- ctx: action_system::action::ActionContext,
- #arg_param_name: #arg_type
- ) -> Result<#return_type, tcp_connection::error::TcpTargetError> {
- pool.process::<#arg_type, #return_type>(
- Box::leak(string_proc::snake_case!(stringify!(#action_name_ident)).into_boxed_str()),
- ctx,
- #arg_param_name
- ).await
- }
+ #fn_vis async fn #proc_this_action(
+ pool: &action_system::action_pool::ActionPool,
+ ctx: action_system::action::ActionContext,
+ #arg_param_name: #arg_type
+ ) -> Result<#return_type, tcp_connection::error::TcpTargetError> {
+ let args_json = serde_json::to_string(&#arg_param_name)
+ .map_err(|e| {
+ tcp_connection::error::TcpTargetError::Serialization(e.to_string())
+ })?;
+ let result_json = pool.process_json(
+ Box::leak(string_proc::snake_case!(stringify!(#action_name_ident)).into_boxed_str()),
+ ctx,
+ args_json,
+ ).await?;
+ serde_json::from_str(&result_json)
+ .map_err(|e| {
+ tcp_connection::error::TcpTargetError::Serialization(e.to_string())
+ })
}
#[allow(dead_code)]
@@ -79,20 +88,20 @@ fn generate_action_struct(input_fn: ItemFn, _is_local: bool) -> proc_macro2::Tok
#[doc = "Use the generated struct instead."]
#[doc = ""]
#[doc = "Register the action to the pool."]
- #[doc = "```rust"]
- #[doc = "YourActionPascalName::register_to_pool(&mut pool);"]
+ #[doc = "```ignore"]
+ #[doc = "register_your_func(&mut pool);"]
#[doc = "```"]
#[doc = ""]
#[doc = "Process the action at the pool."]
- #[doc = "```rust"]
- #[doc = "let result = YourActionPascalName::process_at_pool(&pool, ctx, arg).await?;"]
+ #[doc = "```ignore"]
+ #[doc = "let result = proc_your_func(&pool, ctx, arg).await?;"]
#[doc = "```"]
#fn_vis #fn_sig #fn_block
}
}
fn validate_function_signature(fn_sig: &syn::Signature) {
- if !fn_sig.asyncness.is_some() {
+ if fn_sig.asyncness.is_none() {
panic!("Expected async function for Action, but found synchronous function");
}
@@ -111,13 +120,12 @@ fn validate_function_signature(fn_sig: &syn::Signature) {
};
if let syn::Type::Path(type_path) = return_type.as_ref() {
- if let Some(segment) = type_path.path.segments.last() {
- if segment.ident != "Result" {
+ if let Some(segment) = type_path.path.segments.last()
+ && segment.ident != "Result" {
panic!(
"Expected Action function to return Result<T, TcpTargetError>, but found different return type"
);
}
- }
} else {
panic!(
"Expected Action function to return Result<T, TcpTargetError>, but found no return type"
diff --git a/crates/system_action/src/action.rs b/crates/system_action/src/action.rs
index 562a142..8a6180a 100644
--- a/crates/system_action/src/action.rs
+++ b/crates/system_action/src/action.rs
@@ -1,6 +1,15 @@
+use serde::{Serialize, de::DeserializeOwned};
+use std::any::{Any, TypeId};
+use std::collections::HashMap;
+use std::sync::Arc;
use tcp_connection::{error::TcpTargetError, instance::ConnectionInstance};
+use tokio::{net::TcpStream, sync::Mutex};
-pub trait Action<Args, Return> {
+pub trait Action<Args, Return>
+where
+ Args: Serialize + DeserializeOwned + Send,
+ Return: Serialize + DeserializeOwned + Send,
+{
fn action_name() -> &'static str;
fn is_remote_action() -> bool;
@@ -13,27 +22,54 @@ pub trait Action<Args, Return> {
#[derive(Default)]
pub struct ActionContext {
- // Whether the action is executed locally or remotely
+ /// Whether the action is executed locally or remotely
local: bool,
+ /// The name of the action being executed
+ action_name: String,
+
+ /// The JSON-serialized arguments for the action
+ action_args_json: String,
+
/// The connection instance in the current context,
- /// used to interact with the machine on the other end
- instance: Option<ConnectionInstance>,
+ instance: Option<Arc<Mutex<ConnectionInstance>>>,
+
+ /// Generic data storage for arbitrary types
+ data: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl ActionContext {
/// Generate local context
pub fn local() -> Self {
- let mut ctx = ActionContext::default();
- ctx.local = true;
- ctx
+ ActionContext {
+ local: true,
+ ..Default::default()
+ }
}
/// Generate remote context
pub fn remote() -> Self {
- let mut ctx = ActionContext::default();
- ctx.local = false;
- ctx
+ ActionContext {
+ local: false,
+ ..Default::default()
+ }
+ }
+
+ /// Build connection instance from TcpStream
+ pub fn build_instance(mut self, stream: TcpStream) -> Self {
+ self.instance = Some(Arc::new(Mutex::new(ConnectionInstance::from(stream))));
+ self
+ }
+
+ /// Insert connection instance into context
+ pub fn insert_instance(mut self, instance: ConnectionInstance) -> Self {
+ self.instance = Some(Arc::new(Mutex::new(instance)));
+ self
+ }
+
+ /// Pop connection instance from context
+ pub fn pop_instance(&mut self) -> Option<Arc<Mutex<ConnectionInstance>>> {
+ self.instance.take()
}
}
@@ -49,7 +85,81 @@ impl ActionContext {
}
/// Get the connection instance in the current context
- pub fn instance(&self) -> &Option<ConnectionInstance> {
+ pub fn instance(&self) -> &Option<Arc<Mutex<ConnectionInstance>>> {
&self.instance
}
+
+ /// Get a mutable reference to the connection instance in the current context
+ pub fn instance_mut(&mut self) -> &mut Option<Arc<Mutex<ConnectionInstance>>> {
+ &mut self.instance
+ }
+
+ /// Get the action name from the context
+ pub fn action_name(&self) -> &str {
+ &self.action_name
+ }
+
+ /// Get the action arguments from the context
+ pub fn action_args_json(&self) -> &String {
+ &self.action_args_json
+ }
+
+ /// Set the action name in the context
+ pub fn set_action_name(mut self, action_name: String) -> Self {
+ self.action_name = action_name;
+ self
+ }
+
+ /// Set the action arguments in the context
+ pub fn set_action_args(mut self, action_args: String) -> Self {
+ self.action_args_json = action_args;
+ self
+ }
+
+ /// Insert arbitrary data in the context
+ pub fn insert<T: Any + Send + Sync>(mut self, value: T) -> Self {
+ self.data.insert(TypeId::of::<T>(), Arc::new(value));
+ self
+ }
+
+ /// Insert arbitrary data as Arc in the context
+ pub fn insert_arc<T: Any + Send + Sync>(mut self, value: Arc<T>) -> Self {
+ self.data.insert(TypeId::of::<T>(), value);
+ self
+ }
+
+ /// Get arbitrary data from the context
+ pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
+ self.data
+ .get(&TypeId::of::<T>())
+ .and_then(|arc| arc.downcast_ref::<T>())
+ }
+
+ /// Get arbitrary data as Arc from the context
+ pub fn get_arc<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
+ self.data
+ .get(&TypeId::of::<T>())
+ .and_then(|arc| Arc::clone(arc).downcast::<T>().ok())
+ }
+
+ /// Remove and return arbitrary data from the context
+ pub fn remove<T: Any + Send + Sync>(&mut self) -> Option<Arc<T>> {
+ self.data
+ .remove(&TypeId::of::<T>())
+ .and_then(|arc| arc.downcast::<T>().ok())
+ }
+
+ /// Check if the context contains data of a specific type
+ pub fn contains<T: Any + Send + Sync>(&self) -> bool {
+ self.data.contains_key(&TypeId::of::<T>())
+ }
+
+ /// Take ownership of the context and extract data of a specific type
+ pub fn take<T: Any + Send + Sync>(mut self) -> (Self, Option<Arc<T>>) {
+ let value = self
+ .data
+ .remove(&TypeId::of::<T>())
+ .and_then(|arc| arc.downcast::<T>().ok());
+ (self, value)
+ }
}
diff --git a/crates/system_action/src/action_pool.rs b/crates/system_action/src/action_pool.rs
index 0a1a6c7..c28de1e 100644
--- a/crates/system_action/src/action_pool.rs
+++ b/crates/system_action/src/action_pool.rs
@@ -1,11 +1,36 @@
+use std::pin::Pin;
+
+use serde::{Serialize, de::DeserializeOwned};
+use serde_json;
use tcp_connection::error::TcpTargetError;
use crate::action::{Action, ActionContext};
+type ProcBeginCallback = for<'a> fn(
+ &'a ActionContext,
+ args: &'a (dyn std::any::Any + Send + Sync),
+) -> ProcBeginFuture<'a>;
+type ProcEndCallback = fn() -> ProcEndFuture;
+
+type ProcBeginFuture<'a> = Pin<Box<dyn Future<Output = Result<(), TcpTargetError>> + Send + 'a>>;
+type ProcEndFuture = Pin<Box<dyn Future<Output = Result<(), TcpTargetError>> + Send>>;
+
/// A pool of registered actions that can be processed by name
pub struct ActionPool {
/// HashMap storing action name to action implementation mapping
actions: std::collections::HashMap<&'static str, Box<dyn ActionErased>>,
+
+ /// Callback to execute when process begins
+ on_proc_begin: Option<ProcBeginCallback>,
+
+ /// Callback to execute when process ends
+ on_proc_end: Option<ProcEndCallback>,
+}
+
+impl Default for ActionPool {
+ fn default() -> Self {
+ Self::new()
+ }
}
impl ActionPool {
@@ -13,20 +38,32 @@ impl ActionPool {
pub fn new() -> Self {
Self {
actions: std::collections::HashMap::new(),
+ on_proc_begin: None,
+ on_proc_end: None,
}
}
+ /// Sets a callback to be executed when process begins
+ pub fn set_on_proc_begin(&mut self, callback: ProcBeginCallback) {
+ self.on_proc_begin = Some(callback);
+ }
+
+ /// Sets a callback to be executed when process ends
+ pub fn set_on_proc_end(&mut self, callback: ProcEndCallback) {
+ self.on_proc_end = Some(callback);
+ }
+
/// Registers an action type with the pool
///
/// Usage:
- /// ```
+ /// ```ignore
/// action_pool.register::<MyAction, MyArgs, MyReturn>();
/// ```
pub fn register<A, Args, Return>(&mut self)
where
A: Action<Args, Return> + Send + Sync + 'static,
- Args: serde::de::DeserializeOwned + Send + Sync + 'static,
- Return: serde::Serialize + Send + Sync + 'static,
+ Args: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
+ Return: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
let action_name = A::action_name();
self.actions.insert(
@@ -38,7 +75,40 @@ impl ActionPool {
/// Processes an action by name with given context and arguments
///
/// Usage:
+ /// ```ignore
+ /// let result = action_pool.process::<MyArgs, MyReturn>("my_action", context, args).await?;
+ /// ```
+ /// Processes an action by name with JSON-serialized arguments
+ ///
+ /// Usage:
+ /// ```ignore
+ /// let result_json = action_pool.process_json("my_action", context, args_json).await?;
+ /// let result: MyReturn = serde_json::from_str(&result_json)?;
/// ```
+ pub async fn process_json<'a>(
+ &'a self,
+ action_name: &'a str,
+ context: ActionContext,
+ args_json: String,
+ ) -> Result<String, TcpTargetError> {
+ if let Some(action) = self.actions.get(action_name) {
+ // Set action name and args in context for callbacks
+ let context = context.set_action_name(action_name.to_string());
+ let context = context.set_action_args(args_json.clone());
+
+ self.exec_on_proc_begin(&context, &args_json).await?;
+ let result = action.process_json_erased(context, args_json).await?;
+ self.exec_on_proc_end().await?;
+ Ok(result)
+ } else {
+ Err(TcpTargetError::Unsupported("InvalidAction".to_string()))
+ }
+ }
+
+ /// Processes an action by name with given context and arguments
+ ///
+ /// Usage:
+ /// ```ignore
/// let result = action_pool.process::<MyArgs, MyReturn>("my_action", context, args).await?;
/// ```
pub async fn process<'a, Args, Return>(
@@ -48,34 +118,69 @@ impl ActionPool {
args: Args,
) -> Result<Return, TcpTargetError>
where
- Args: serde::de::DeserializeOwned + Send + 'static,
+ Args: serde::de::DeserializeOwned + Send + Sync + 'static,
Return: serde::Serialize + Send + 'static,
{
if let Some(action) = self.actions.get(action_name) {
+ self.exec_on_proc_begin(&context, &args).await?;
let result = action.process_erased(context, Box::new(args)).await?;
let result = *result
.downcast::<Return>()
.map_err(|_| TcpTargetError::Unsupported("InvalidArguments".to_string()))?;
+ self.exec_on_proc_end().await?;
Ok(result)
} else {
Err(TcpTargetError::Unsupported("InvalidAction".to_string()))
}
}
+
+ /// Executes the process begin callback if set
+ async fn exec_on_proc_begin(
+ &self,
+ context: &ActionContext,
+ args: &(dyn std::any::Any + Send + Sync),
+ ) -> Result<(), TcpTargetError> {
+ if let Some(callback) = &self.on_proc_begin {
+ callback(context, args).await
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Executes the process end callback if set
+ async fn exec_on_proc_end(&self) -> Result<(), TcpTargetError> {
+ if let Some(callback) = &self.on_proc_end {
+ callback().await
+ } else {
+ Ok(())
+ }
+ }
}
/// Trait for type-erased actions that can be stored in ActionPool
+type ProcessErasedFuture = std::pin::Pin<
+ Box<
+ dyn std::future::Future<Output = Result<Box<dyn std::any::Any + Send>, TcpTargetError>>
+ + Send,
+ >,
+>;
+type ProcessJsonErasedFuture =
+ std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, TcpTargetError>> + Send>>;
+
trait ActionErased: Send + Sync {
/// Processes the action with type-erased arguments and returns type-erased result
fn process_erased(
&self,
context: ActionContext,
args: Box<dyn std::any::Any + Send>,
- ) -> std::pin::Pin<
- Box<
- dyn std::future::Future<Output = Result<Box<dyn std::any::Any + Send>, TcpTargetError>>
- + Send,
- >,
- >;
+ ) -> ProcessErasedFuture;
+
+ /// Processes the action with JSON-serialized arguments and returns JSON-serialized result
+ fn process_json_erased(
+ &self,
+ context: ActionContext,
+ args_json: String,
+ ) -> ProcessJsonErasedFuture;
}
/// Wrapper struct that implements ActionErased for concrete Action types
@@ -84,8 +189,8 @@ struct ActionWrapper<A, Args, Return>(std::marker::PhantomData<(A, Args, Return)
impl<A, Args, Return> ActionErased for ActionWrapper<A, Args, Return>
where
A: Action<Args, Return> + Send + Sync,
- Args: serde::de::DeserializeOwned + Send + Sync + 'static,
- Return: serde::Serialize + Send + Sync + 'static,
+ Args: Serialize + DeserializeOwned + Send + Sync + 'static,
+ Return: Serialize + DeserializeOwned + Send + Sync + 'static,
{
fn process_erased(
&self,
@@ -105,4 +210,20 @@ where
Ok(Box::new(result) as Box<dyn std::any::Any + Send>)
})
}
+
+ fn process_json_erased(
+ &self,
+ context: ActionContext,
+ args_json: String,
+ ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, TcpTargetError>> + Send>>
+ {
+ Box::pin(async move {
+ let args: Args = serde_json::from_str(&args_json)
+ .map_err(|e| TcpTargetError::Serialization(format!("Deserialize failed: {}", e)))?;
+ let result = A::process(context, args).await?;
+ let result_json = serde_json::to_string(&result)
+ .map_err(|e| TcpTargetError::Serialization(format!("Serialize failed: {}", e)))?;
+ Ok(result_json)
+ })
+ }
}