summaryrefslogtreecommitdiff
path: root/mingling_macros/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mingling_macros/src/chain.rs')
-rw-r--r--mingling_macros/src/chain.rs97
1 files changed, 79 insertions, 18 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs
index 1fddd3b..f8b1e1c 100644
--- a/mingling_macros/src/chain.rs
+++ b/mingling_macros/src/chain.rs
@@ -1,10 +1,10 @@
//! Chain Attribute Macro Implementation
//!
-//! This module provides the `#[chain]` attribute macro for automatically
+//! This module provides the `#[chain(Group)]` attribute macro for automatically
//! generating structs that implement the `Chain` trait from async functions.
use proc_macro::TokenStream;
-use quote::quote;
+use quote::{ToTokens, quote};
use syn::spanned::Spanned;
use syn::{
FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, TypePath, parse_macro_input,
@@ -59,7 +59,18 @@ fn extract_return_type(sig: &Signature) -> syn::Result<TypePath> {
}
}
-pub fn chain_attr(item: TokenStream) -> TokenStream {
+pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
+ // Parse the attribute arguments (e.g., MyProgram from #[chain(MyProgram)])
+ // If no argument is provided, use DefaultProgram
+ let (group_name, use_crate_prefix) = if attr.is_empty() {
+ (
+ Ident::new("DefaultProgram", proc_macro2::Span::call_site()),
+ true,
+ )
+ } else {
+ (parse_macro_input!(attr as Ident), false)
+ };
+
// Parse the function item
let input_fn = parse_macro_input!(item as ItemFn);
@@ -82,11 +93,22 @@ pub fn chain_attr(item: TokenStream) -> TokenStream {
Err(e) => return e.to_compile_error().into(),
};
+ // Ensure the return type is named "GroupProcess"
+ if return_type.path.segments.last().unwrap().ident != "GroupProcess" {
+ return syn::Error::new(
+ return_type.span(),
+ "Return type must be 'mingling::marker::GroupProcess'",
+ )
+ .to_compile_error()
+ .into();
+ }
+
// Get the function body
let fn_body = &input_fn.block;
// Get function attributes (excluding the chain attribute)
let mut fn_attrs = input_fn.attrs.clone();
+
// Remove any #[chain(...)] attributes to avoid infinite recursion
fn_attrs.retain(|attr| !attr.path().is_ident("chain"));
@@ -101,23 +123,55 @@ pub fn chain_attr(item: TokenStream) -> TokenStream {
let struct_name = Ident::new(&pascal_case_name, fn_name.span());
// Generate the struct and implementation
- let expanded = quote! {
- #(#fn_attrs)*
- #vis struct #struct_name;
-
- impl ::mingling::Chain for #struct_name {
- type Previous = #previous_type;
+ let expanded = if use_crate_prefix {
+ quote! {
+ #(#fn_attrs)*
+ #vis struct #struct_name;
+
+ impl ::mingling::Chain<DefaultProgram> for #struct_name {
+ type Previous = #previous_type;
+
+ async fn proc(#prev_param: Self::Previous) ->
+ ::mingling::ChainProcess<DefaultProgram>
+ {
+ let _ = GroupProcess;
+ // Call the original function
+ #fn_name(#prev_param).await
+ }
+ }
- async fn proc(#prev_param: Self::Previous) -> #return_type {
- // Call the original function
- #fn_name(#prev_param).await
+ // Keep the original function for internal use
+ #(#fn_attrs)*
+ #vis async fn #fn_name(#prev_param: #previous_type)
+ -> ::mingling::ChainProcess<DefaultProgram>
+ {
+ #fn_body
}
}
+ } else {
+ quote! {
+ #(#fn_attrs)*
+ #vis struct #struct_name;
+
+ impl ::mingling::Chain<#group_name> for #struct_name {
+ type Previous = #previous_type;
+
+ async fn proc(#prev_param: Self::Previous) ->
+ ::mingling::ChainProcess<#group_name>
+ {
+ let _ = GroupProcess;
+ // Call the original function
+ #fn_name(#prev_param).await
+ }
+ }
- // Keep the original function for internal use
- #(#fn_attrs)*
- #vis async fn #fn_name(#prev_param: #previous_type) -> #return_type {
- #fn_body
+ // Keep the original function for internal use
+ #(#fn_attrs)*
+ #vis async fn #fn_name(#prev_param: #previous_type)
+ -> ::mingling::ChainProcess<#group_name>
+ {
+ #fn_body
+ }
}
};
@@ -127,21 +181,28 @@ pub fn chain_attr(item: TokenStream) -> TokenStream {
};
let chain_exist_entry = quote! {
- id if id == std::any::TypeId::of::<#previous_type>() => true,
+ Self::#previous_type => true,
};
let mut chains = crate::CHAINS.lock().unwrap();
let mut chain_exist = crate::CHAINS_EXIST.lock().unwrap();
+ let mut packed_types = crate::PACKED_TYPES.lock().unwrap();
let chain_entry = chain_entry.to_string();
let chain_exist_entry = chain_exist_entry.to_string();
+ let previous_type_str = previous_type.to_token_stream().to_string();
if !chains.contains(&chain_entry) {
chains.push(chain_entry);
}
- if !chains.contains(&chain_exist_entry) {
+
+ if !chain_exist.contains(&chain_exist_entry) {
chain_exist.push(chain_exist_entry);
}
+ if !packed_types.contains(&previous_type_str) {
+ packed_types.push(previous_type_str);
+ }
+
expanded.into()
}