summaryrefslogtreecommitdiff
path: root/systems/action/action_macros/src
diff options
context:
space:
mode:
Diffstat (limited to 'systems/action/action_macros/src')
-rw-r--r--systems/action/action_macros/src/lib.rs248
1 files changed, 248 insertions, 0 deletions
diff --git a/systems/action/action_macros/src/lib.rs b/systems/action/action_macros/src/lib.rs
new file mode 100644
index 0000000..e6616b4
--- /dev/null
+++ b/systems/action/action_macros/src/lib.rs
@@ -0,0 +1,248 @@
+use proc_macro::TokenStream;
+use quote::quote;
+use syn::{ItemFn, parse_macro_input};
+
+/// # Macro - Generate Action
+///
+/// When annotating a function with the `#[action_gen]` macro in the following format, it generates boilerplate code for client-server interaction
+///
+/// ```ignore
+/// #[action_gen]
+/// async fn action_name(ctx: ActionContext, argument: YourArgument) -> Result<YourResult, TcpTargetError> {
+/// // Write your client and server logic here
+/// if ctx.is_proc_on_remote() {
+/// // Server logic
+/// }
+/// if ctx.is_proc_on_local() {
+/// // Client logic
+/// }
+/// }
+/// ```
+///
+/// > WARNING:
+/// > For Argument and Result types, the `action_gen` macro only supports types that derive serde's Serialize and Deserialize
+///
+/// ## Generated Code
+///
+/// `action_gen` will generate the following:
+///
+/// 1. Complete implementation of Action<Args, Return>
+/// 2. Process / Register method
+///
+/// ## How to use
+///
+/// You can use your generated method as follows
+///
+/// ```ignore
+/// async fn main() -> Result<(), TcpTargetError> {
+///
+/// // Prepare your argument
+/// let args = YourArgument::default();
+///
+/// // Create a pool and context
+/// let mut pool = ActionPool::new();
+/// let ctx = ActionContext::local();
+///
+/// // Register your action
+/// register_your_action(&mut pool);
+///
+/// // Process your action
+/// proc_your_action(&pool, ctx, args).await?;
+///
+/// Ok(())
+/// }
+/// ```
+#[proc_macro_attribute]
+pub fn action_gen(attr: TokenStream, item: TokenStream) -> TokenStream {
+ let input_fn = parse_macro_input!(item as ItemFn);
+ let is_local = if attr.is_empty() {
+ false
+ } else {
+ let attr_str = attr.to_string();
+ attr_str == "local" || attr_str.contains("local")
+ };
+
+ generate_action_struct(input_fn, is_local).into()
+}
+
+fn generate_action_struct(input_fn: ItemFn, _is_local: bool) -> proc_macro2::TokenStream {
+ let fn_vis = &input_fn.vis;
+ let fn_sig = &input_fn.sig;
+ let fn_name = &fn_sig.ident;
+ let fn_block = &input_fn.block;
+
+ validate_function_signature(fn_sig);
+
+ let (context_param_name, arg_param_name, arg_type, return_type) =
+ extract_parameters_and_types(fn_sig);
+
+ let struct_name = quote::format_ident!("{}", convert_to_pascal_case(&fn_name.to_string()));
+
+ let action_name_ident = &fn_name;
+
+ let register_this_action = quote::format_ident!("register_{}", action_name_ident);
+ let proc_this_action = quote::format_ident!("proc_{}", action_name_ident);
+
+ quote! {
+ #[derive(Debug, Clone, Default)]
+ #fn_vis struct #struct_name;
+
+ impl action_system::action::Action<#arg_type, #return_type> for #struct_name {
+ fn action_name() -> &'static str {
+ Box::leak(string_proc::snake_case!(stringify!(#action_name_ident)).into_boxed_str())
+ }
+
+ fn is_remote_action() -> bool {
+ !#_is_local
+ }
+
+ async fn process(#context_param_name: action_system::action::ActionContext, #arg_param_name: #arg_type) -> Result<#return_type, tcp_connection::error::TcpTargetError> {
+ #fn_block
+ }
+ }
+
+ #fn_vis fn #register_this_action(pool: &mut action_system::action_pool::ActionPool) {
+ pool.register::<#struct_name, #arg_type, #return_type>();
+ }
+
+ #fn_vis async fn #proc_this_action(
+ pool: &action_system::action_pool::ActionPool,
+ mut ctx: action_system::action::ActionContext,
+ #arg_param_name: #arg_type
+ ) -> Result<#return_type, tcp_connection::error::TcpTargetError> {
+ ctx.set_is_remote_action(!#_is_local);
+ let args_json = serde_json::to_string(&#arg_param_name)
+ .map_err(|e| {
+ tcp_connection::error::TcpTargetError::Serialization(e.to_string())
+ })?;
+ let result_json = pool.process_json(
+ Box::leak(string_proc::snake_case!(stringify!(#action_name_ident)).into_boxed_str()),
+ ctx,
+ args_json,
+ ).await?;
+ serde_json::from_str(&result_json)
+ .map_err(|e| {
+ tcp_connection::error::TcpTargetError::Serialization(e.to_string())
+ })
+ }
+
+ #[allow(dead_code)]
+ #[deprecated = "This function is used by #[action_gen] as a template."]
+ #[doc = "Template function for #[[action_gen]] - do not call directly."]
+ #[doc = "Use the generated struct instead."]
+ #[doc = ""]
+ #[doc = "Register the action to the pool."]
+ #[doc = "```ignore"]
+ #[doc = "register_your_func(&mut pool);"]
+ #[doc = "```"]
+ #[doc = ""]
+ #[doc = "Process the action at the pool."]
+ #[doc = "```ignore"]
+ #[doc = "let result = proc_your_func(&pool, ctx, arg).await?;"]
+ #[doc = "```"]
+ #fn_vis #fn_sig #fn_block
+ }
+}
+
+fn validate_function_signature(fn_sig: &syn::Signature) {
+ if fn_sig.asyncness.is_none() {
+ panic!("Expected async function for Action, but found synchronous function");
+ }
+
+ if fn_sig.inputs.len() != 2 {
+ panic!(
+ "Expected exactly 2 arguments for Action function: ctx: ActionContext and arg: T, but found {} arguments",
+ fn_sig.inputs.len()
+ );
+ }
+
+ let return_type = match &fn_sig.output {
+ syn::ReturnType::Type(_, ty) => ty,
+ _ => panic!(
+ "Expected Action function to return Result<T, TcpTargetError>, but found no return type"
+ ),
+ };
+
+ if let syn::Type::Path(type_path) = return_type.as_ref() {
+ if let Some(segment) = type_path.path.segments.last()
+ && segment.ident != "Result"
+ {
+ panic!(
+ "Expected Action function to return Result<T, TcpTargetError>, but found different return type"
+ );
+ }
+ } else {
+ panic!(
+ "Expected Action function to return Result<T, TcpTargetError>, but found no return type"
+ );
+ }
+}
+
+fn convert_to_pascal_case(s: &str) -> String {
+ s.split('_')
+ .map(|word| {
+ let mut chars = word.chars();
+ match chars.next() {
+ None => String::new(),
+ Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
+ }
+ })
+ .collect()
+}
+
+fn extract_parameters_and_types(
+ fn_sig: &syn::Signature,
+) -> (
+ proc_macro2::TokenStream,
+ proc_macro2::TokenStream,
+ proc_macro2::TokenStream,
+ proc_macro2::TokenStream,
+) {
+ let mut inputs = fn_sig.inputs.iter();
+
+ let context_param = match inputs.next() {
+ Some(syn::FnArg::Typed(pat_type)) => {
+ let pat = &pat_type.pat;
+ quote::quote!(#pat)
+ }
+ _ => {
+ panic!("Expected the first argument to be a typed parameter, but found something else")
+ }
+ };
+
+ let arg_param = match inputs.next() {
+ Some(syn::FnArg::Typed(pat_type)) => {
+ let pat = &pat_type.pat;
+ let ty = &pat_type.ty;
+ (quote::quote!(#pat), quote::quote!(#ty))
+ }
+ _ => {
+ panic!("Expected the second argument to be a typed parameter, but found something else")
+ }
+ };
+
+ let (arg_param_name, arg_type) = arg_param;
+
+ let return_type = match &fn_sig.output {
+ syn::ReturnType::Type(_, ty) => {
+ if let syn::Type::Path(type_path) = ty.as_ref() {
+ if let syn::PathArguments::AngleBracketed(args) =
+ &type_path.path.segments.last().unwrap().arguments
+ {
+ if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
+ quote::quote!(#ty)
+ } else {
+ panic!("Expected to extract the success type of Result, but failed");
+ }
+ } else {
+ panic!("Expected Result type to have generic parameters, but found none");
+ }
+ } else {
+ panic!("Expected return type to be Result, but found different type");
+ }
+ }
+ _ => panic!("Expected function to have return type, but found none"),
+ };
+
+ (context_param, arg_param_name, arg_type, return_type)
+}