aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mingling_macros/src/chain.rs')
-rw-r--r--mingling_macros/src/chain.rs76
1 files changed, 17 insertions, 59 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs
index 7894362..daa6b1c 100644
--- a/mingling_macros/src/chain.rs
+++ b/mingling_macros/src/chain.rs
@@ -9,9 +9,7 @@
use proc_macro::TokenStream;
use quote::{ToTokens, quote};
use syn::spanned::Spanned;
-use syn::{
- FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, TypePath, parse_macro_input,
-};
+use syn::{FnArg, Ident, ItemFn, Pat, PatType, Signature, Type, TypePath, parse_macro_input};
/// Extracts the previous type and parameter name from function arguments
fn extract_previous_info(sig: &Signature) -> syn::Result<(Pat, TypePath)> {
@@ -45,23 +43,6 @@ fn extract_previous_info(sig: &Signature) -> syn::Result<(Pat, TypePath)> {
}
}
-/// Extracts the return type from the function signature
-fn extract_return_type(sig: &Signature) -> syn::Result<TypePath> {
- match &sig.output {
- ReturnType::Type(_, ty) => match &**ty {
- Type::Path(type_path) => Ok(type_path.clone()),
- _ => Err(syn::Error::new(
- ty.span(),
- "Return type must be a type path",
- )),
- },
- ReturnType::Default => Err(syn::Error::new(
- sig.span(),
- "Chain function must have a return type",
- )),
- }
-}
-
pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
// Parse the attribute arguments (e.g., MyProgram from #[chain(MyProgram)])
// If no argument is provided, use ThisProgram
@@ -100,21 +81,9 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
Err(e) => return e.to_compile_error().into(),
};
- // Extract the return type
- let return_type = match extract_return_type(&input_fn.sig) {
- Ok(ty) => ty,
- Err(e) => return e.to_compile_error().into(),
- };
-
- // Ensure the return type is named "NextProcess"
- if return_type.path.segments.last().unwrap().ident != "NextProcess" {
- return syn::Error::new(
- return_type.span(),
- "Return type must be 'mingling::marker::NextProcess'",
- )
- .to_compile_error()
- .into();
- }
+ // Get the function signature components for direct substitution
+ let sig = &input_fn.sig;
+ let inputs = &sig.inputs;
// Get the function body
let fn_body = &input_fn.block;
@@ -135,24 +104,23 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
let pascal_case_name = just_fmt::pascal_case!(fn_name.to_string());
let struct_name = Ident::new(&pascal_case_name, fn_name.span());
+ // Determine the program type for the return type
+ let program_type = if use_crate_prefix {
+ quote! { ThisProgram }
+ } else {
+ quote! { #group_name }
+ };
+
#[cfg(feature = "async")]
let proc_fn = if is_async_fn {
quote! {
- async fn proc(#prev_param: Self::Previous) ->
- ::mingling::ChainProcess<ThisProgram>
- {
- let _ = NextProcess;
- // Call the original function
+ async fn proc(#inputs) -> ::mingling::ChainProcess<#program_type> {
#fn_name(#prev_param).await.into()
}
}
} else {
quote! {
- async fn proc(#prev_param: Self::Previous) ->
- ::mingling::ChainProcess<ThisProgram>
- {
- let _ = NextProcess;
- // Call the original function
+ async fn proc(#inputs) -> ::mingling::ChainProcess<#program_type> {
#fn_name(#prev_param).into()
}
}
@@ -162,18 +130,14 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
let origin_proc_fn = if is_async_fn {
quote! {
#(#fn_attrs)*
- #vis async fn #fn_name(#prev_param: #previous_type)
- -> impl Into<::mingling::ChainProcess<#group_name>>
- {
+ #vis async fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> {
#fn_body
}
}
} else {
quote! {
#(#fn_attrs)*
- #vis fn #fn_name(#prev_param: #previous_type)
- -> impl Into<::mingling::ChainProcess<#group_name>>
- {
+ #vis fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> {
#fn_body
}
}
@@ -181,11 +145,7 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
#[cfg(not(feature = "async"))]
let proc_fn = quote! {
- fn proc(#prev_param: Self::Previous) ->
- ::mingling::ChainProcess<ThisProgram>
- {
- let _ = NextProcess;
- // Call the original function
+ fn proc(#inputs) -> ::mingling::ChainProcess<#program_type> {
#fn_name(#prev_param).into()
}
};
@@ -193,9 +153,7 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
#[cfg(not(feature = "async"))]
let origin_proc_fn = quote! {
#(#fn_attrs)*
- #vis fn #fn_name(#prev_param: #previous_type)
- -> impl Into<::mingling::ChainProcess<#group_name>>
- {
+ #vis fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> {
#fn_body
}
};