diff options
Diffstat (limited to 'mingling_macros/src')
| -rw-r--r-- | mingling_macros/src/chain.rs | 472 |
1 files changed, 287 insertions, 185 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs index fd1db65..111c584 100644 --- a/mingling_macros/src/chain.rs +++ b/mingling_macros/src/chain.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use proc_macro::TokenStream; use quote::{ToTokens, quote}; use syn::spanned::Spanned; @@ -35,7 +37,9 @@ fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<Resourc return Err(syn::Error::new( type_path.span(), format!( - "The type `{}` in #[chain] function must be a simple single-segment type, e.g. `Empty` instead of `other::Empty`. Qualified paths with `::` are not allowed here.", + "The type `{}` in #[chain] function must be a simple single-segment type, \ + e.g. `Empty` instead of `other::Empty`. \ + Qualified paths with `::` are not allowed here.", quote! { #type_path } ), )); @@ -45,7 +49,9 @@ fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<Resourc Type::Reference(_) => { return Err(syn::Error::new( ty.span(), - "The first parameter (previous type) must be taken by move, not by reference. Use `prev: SomeEntry` instead of `prev: &SomeEntry`.", + "The first parameter (previous type) must be taken by move, \ + not by reference. \ + Use `prev: SomeEntry` instead of `prev: &SomeEntry`.", )); } _ => { @@ -100,13 +106,15 @@ fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<Resourc Type::Path(_) => { return Err(syn::Error::new( ty.span(), - "Resource injection parameter must be a reference (`&T` or `&mut T`), not an owned value. Use `age: &Age` instead of `age: Age`.", + "Resource injection parameter must be a reference (`&T` or `&mut T`), \ + not an owned value. Use `age: &Age` instead of `age: Age`.", )); } _ => { return Err(syn::Error::new( ty.span(), - "Resource injection type must be a type path or reference to one (e.g., `age: Age` or `age: &Age`)", + "Resource injection type must be a type path or reference to one \ + (e.g., `age: Age` or `age: &Age`)", )); } }; @@ -131,129 +139,62 @@ fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<Resourc Ok((prev_param, previous_type, resources)) } -pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream { - // Parse the attribute arguments (e.g., MyProgram from #[chain(MyProgram)]) - // If no argument is provided, use ThisProgram - let (group_name, use_crate_prefix) = if attr.is_empty() { +/// Parses the `#[chain(...)]` attribute arguments. +/// +/// Returns: +/// - `program_path`: the token stream representing the program type +/// - `use_crate_prefix`: whether to use the default crate-defined program path +fn parse_chain_attr_args(attr: TokenStream) -> (proc_macro2::TokenStream, bool) { + if attr.is_empty() { (crate::default_program_path(), true) } else { - let path: syn::Path = parse_macro_input!(attr as syn::Path); + let path: syn::Path = syn::parse(attr).expect("#[chain(..)] argument must be a path"); (quote! { #path }, false) - }; - - // Parse the function item - let input_fn = parse_macro_input!(item as ItemFn); - - // In `async` mode, check if the function is an async function - #[cfg(feature = "async")] - let is_async_fn = input_fn.sig.asyncness.is_some(); - - // Validate the chain functions is a regular function - #[cfg(not(feature = "async"))] - { - if input_fn.sig.asyncness.is_some() { - return syn::Error::new( - input_fn.sig.span(), - "Chain function cannot be async when async feature is disabled", - ) - .to_compile_error() - .into(); - } } +} - // Check that return type is NextProcess - let return_type = &input_fn.sig.output; - match return_type { - ReturnType::Type(_, ty) => { - // Check if the return type is NextProcess - match &**ty { - Type::Path(type_path) => { - let last_segment = type_path.path.segments.last().unwrap(); - if last_segment.ident != "NextProcess" { - return syn::Error::new( - ty.span(), - "Chain function must return `NextProcess`", - ) - .to_compile_error() - .into(); - } - } - _ => { - return syn::Error::new(ty.span(), "Chain function must return `NextProcess`") - .to_compile_error() - .into(); +/// Validates that the return type of the function is `NextProcess`. +fn validate_return_type_is_next_process(sig: &Signature) -> Result<(), proc_macro2::TokenStream> { + match &sig.output { + ReturnType::Type(_, ty) => match &**ty { + Type::Path(type_path) => { + let last_segment = type_path.path.segments.last().unwrap(); + if last_segment.ident != "NextProcess" { + return Err(syn::Error::new( + ty.span(), + "Chain function must return `NextProcess`", + ) + .to_compile_error()); } } - } + _ => { + return Err( + syn::Error::new(ty.span(), "Chain function must return `NextProcess`") + .to_compile_error(), + ); + } + }, ReturnType::Default => { - return syn::Error::new( - input_fn.sig.span(), + return Err(syn::Error::new( + sig.span(), "Chain function must specify a return type (must be `NextProcess`)", ) - .to_compile_error() - .into(); - } - } - - // Extract the previous type, parameter name, and resource injection params - let (prev_param, previous_type, resources) = match extract_args_info(&input_fn.sig) { - Ok(info) => info, - Err(e) => return e.to_compile_error().into(), - }; - - // Get the function signature components for direct substitution - let sig = &input_fn.sig; - let inputs = &sig.inputs; - - // Get the function body - let fn_body = &input_fn.block; - - // Get function attributes (excluding the chain attribute) - let mut fn_attrs = input_fn.attrs.clone(); - - // Remove any #[chain(...)] attributes to avoid infinite recursion - fn_attrs.retain(|attr| !attr.path().is_ident("chain")); - - // Get function visibility - let vis = &input_fn.vis; - - // Get function name - let fn_name = &input_fn.sig.ident; - - // Generate struct name from function name using snake_case - let internal_name = format!( - "__internal_chain_{}", - just_fmt::snake_case!(fn_name.to_string()) - ); - let struct_name = Ident::new(&internal_name, fn_name.span()); - - // Determine the program type for the return type - let program_type = if use_crate_prefix { - crate::default_program_path() - } else { - group_name.clone() - }; - - // Check for async fn + &mut combination, which is not supported - #[cfg(feature = "async")] - if is_async_fn { - if let Some(mut_res) = resources.iter().find(|r| r.is_mut) { - return syn::Error::new( - mut_res.var_name.span(), - "Cannot use `&mut` resource injection in async chain function. ", - ) - .to_compile_error() - .into(); + .to_compile_error()); } } + Ok(()) +} - // Separate resources into immutable refs and mutable refs - let immut_resources: Vec<_> = resources.iter().filter(|r| !r.is_mut).collect(); - let mut_resources: Vec<_> = resources.iter().filter(|r| r.is_mut).collect(); - - // Build resource injection statements for immutable references (let ... = ...) - let immut_resource_stmts: Vec<_> = immut_resources - .iter() +/// Generates `let` binding statements for immutable resource injection parameters. +/// +/// Each immutable reference parameter gets a `_binding` variable that holds the +/// `res_or_default` result, then a shadowing `let` that borrows from it via `.as_ref()`. +fn generate_immut_resource_bindings<'a>( + resources: impl Iterator<Item = &'a ResourceInjection>, + program_type: &proc_macro2::TokenStream, +) -> Vec<proc_macro2::TokenStream> { + resources + .filter(|r| !r.is_mut) .map(|res| { let var_binding_name = syn::Ident::new( &format!("{}_binding", &res.var_name.to_string()), @@ -275,34 +216,57 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream { } } }) - .collect(); - - // Build nested __modify_res_and_return_any wrappers for mutable references. - // The innermost layer is the original function body, wrapping outward for each - // mutable resource. - let body_stmts = &fn_body.stmts; - let mut wrapped_body = quote! { - #(#body_stmts)* + .collect() +} + +/// Wraps the function body in nested `__modify_res_and_return_any` closures for +/// each mutable resource parameter. The innermost closure gets the original body, +/// and each mutable parameter wraps outward from last to first. +fn wrap_body_with_mut_resources( + fn_body_stmts: &[syn::Stmt], + mut_resources: &[&ResourceInjection], + program_type: &proc_macro2::TokenStream, +) -> proc_macro2::TokenStream { + let mut wrapped = quote! { + #(#fn_body_stmts)* }; - // Wrap from inside to outside: the first mutable parameter becomes the outermost wrapper, - // and the last mutable parameter becomes the innermost wrapper. for res in mut_resources.iter() { let var_name = &res.var_name; let inner_type = &res.inner_type; - wrapped_body = quote! { + wrapped = quote! { ::mingling::this::<#program_type>().__modify_res_and_return_any(|#var_name: &mut #inner_type| { - #wrapped_body + #wrapped }).into() }; } - let has_immut_resources = !immut_resources.is_empty(); - let has_mut_resources = !mut_resources.is_empty(); + wrapped +} + +/// Builds the `proc` function implementation that serves as the actual chain +/// entry point inside the generated `Chain` impl. +/// +/// * Without resources: delegates directly to the original function. +/// * With resources: inlines the body and prepends resource bindings. +#[allow(unused_variables)] +fn generate_proc_fn( + has_resources: bool, + resources: &[ResourceInjection], + program_type: &proc_macro2::TokenStream, + previous_type: &TypePath, + prev_param: &Pat, + fn_name: &Ident, + fn_body_stmts: &[syn::Stmt], + is_async_fn: bool, +) -> proc_macro2::TokenStream { + let immut_resource_stmts = generate_immut_resource_bindings(resources.iter(), program_type); + let mut_resources: Vec<_> = resources.iter().filter(|r| r.is_mut).collect(); + let wrapped_body = wrap_body_with_mut_resources(fn_body_stmts, &mut_resources, program_type); #[cfg(feature = "async")] - let proc_fn = if is_async_fn { - if has_immut_resources || has_mut_resources { + { + if has_resources { quote! { async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { #(#immut_resource_stmts)* @@ -310,108 +274,246 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream { } } } else { + let call = if is_async_fn { + quote! { #fn_name(#prev_param).await.into() } + } else { + quote! { #fn_name(#prev_param).into() } + }; quote! { async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { - #fn_name(#prev_param).await.into() + #call } } } - } else { - if has_immut_resources || has_mut_resources { + } + + #[cfg(not(feature = "async"))] + { + if has_resources { quote! { - async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { + fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { #(#immut_resource_stmts)* #wrapped_body } } } else { quote! { - async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { + fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { #fn_name(#prev_param).into() } } } - }; + } +} +/// Generates the original function signature (kept for backwards compatibility / +/// internal use), with its return type changed to `impl Into<ChainProcess<..>>`. +#[allow(unused_variables)] +fn generate_original_fn( + fn_attrs: &[syn::Attribute], + vis: &syn::Visibility, + fn_name: &Ident, + inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>, + fn_body: &syn::Block, + is_async_fn: bool, + program_type: &proc_macro2::TokenStream, +) -> proc_macro2::TokenStream { #[cfg(feature = "async")] - let origin_proc_fn = if is_async_fn { + { + let async_kw = if is_async_fn { + quote! { async } + } else { + quote! {} + }; quote! { #(#fn_attrs)* - #vis async fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> { + #vis #async_kw fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> { #fn_body } } - } else { + } + + #[cfg(not(feature = "async"))] + { quote! { #(#fn_attrs)* #vis fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> { #fn_body } } - }; + } +} - #[cfg(not(feature = "async"))] - let proc_fn = if has_immut_resources || has_mut_resources { - quote! { - fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { - #(#immut_resource_stmts)* - #wrapped_body - } - } +/// Assembles the final expanded output: hidden struct, `register_chain!` invocation, +/// `Chain` impl with the `proc` method, and the original function. +fn generate_struct_and_impl( + fn_attrs: &[syn::Attribute], + vis: &syn::Visibility, + struct_name: &Ident, + previous_type: &TypePath, + previous_type_str: &proc_macro2::TokenStream, + group_name: &proc_macro2::TokenStream, + program_type: &proc_macro2::TokenStream, + use_crate_prefix: bool, + proc_fn: proc_macro2::TokenStream, + origin_proc_fn: proc_macro2::TokenStream, +) -> proc_macro2::TokenStream { + let chain_type = if use_crate_prefix { + program_type } else { - quote! { - fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> { - #fn_name(#prev_param).into() - } - } + group_name }; - #[cfg(not(feature = "async"))] - let origin_proc_fn = quote! { + quote! { #(#fn_attrs)* - #vis fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> { - #fn_body + #[doc(hidden)] + #[allow(non_camel_case_types)] + #vis struct #struct_name; + + ::mingling::macros::register_chain!(#previous_type_str, #struct_name); + + impl ::mingling::Chain<#chain_type> for #struct_name { + type Previous = #previous_type; + + #proc_fn } - }; - // Generate the struct and implementation - let expanded = if use_crate_prefix { - quote! { - #(#fn_attrs)* - #[doc(hidden)] - #[allow(non_camel_case_types)] - #vis struct #struct_name; + // Keep the original function for internal use + #origin_proc_fn + } +} - ::mingling::macros::register_chain!(#previous_type, #struct_name); +/// Ensures the function is not async when the `async` feature is disabled. +#[cfg(not(feature = "async"))] +fn reject_async(sig: &Signature) -> Result<(), proc_macro2::TokenStream> { + if sig.asyncness.is_some() { + return Err(syn::Error::new( + sig.span(), + "Chain function cannot be async when async feature is disabled", + ) + .to_compile_error()); + } + Ok(()) +} - impl ::mingling::Chain<#program_type> for #struct_name { - type Previous = #previous_type; +/// Ensures no `&mut` resource injection is used in async functions. +#[cfg(feature = "async")] +fn reject_mut_in_async(resources: &[ResourceInjection]) -> Result<(), proc_macro2::TokenStream> { + if let Some(mut_res) = resources.iter().find(|r| r.is_mut) { + return Err(syn::Error::new( + mut_res.var_name.span(), + "Cannot use `&mut` resource injection in async chain function.", + ) + .to_compile_error()); + } + Ok(()) +} - #proc_fn - } +pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream { + // Parse attribute arguments + let (group_name, use_crate_prefix) = parse_chain_attr_args(attr); - // Keep the original function for internal use - #origin_proc_fn - } - } else { - quote! { - #(#fn_attrs)* - #[allow(non_camel_case_types)] - #vis struct #struct_name; + // Parse the function item + let input_fn = parse_macro_input!(item as ItemFn); - ::mingling::macros::register_chain!(#previous_type, #struct_name); + // Handle async feature gate + #[cfg(feature = "async")] + let is_async_fn = input_fn.sig.asyncness.is_some(); - impl ::mingling::Chain<#group_name> for #struct_name { - type Previous = #previous_type; + #[cfg(not(feature = "async"))] + { + if let Err(err) = reject_async(&input_fn.sig) { + return err.into(); + } + } - #proc_fn - } + // Validate return type + if let Err(err) = validate_return_type_is_next_process(&input_fn.sig) { + return err.into(); + } + + // Extract the previous type, parameter name, and resource injection params + let (prev_param, previous_type, resources) = match extract_args_info(&input_fn.sig) { + Ok(info) => info, + Err(e) => return e.to_compile_error().into(), + }; - // Keep the original function for internal use - #origin_proc_fn + // Reject `&mut` in async chains + #[cfg(feature = "async")] + if is_async_fn { + if let Err(err) = reject_mut_in_async(&resources) { + return err.into(); } + } + + // Prepare building blocks + let sig = &input_fn.sig; + let inputs = &sig.inputs; + let fn_body = &input_fn.block; + let mut fn_attrs = input_fn.attrs.clone(); + fn_attrs.retain(|attr| !attr.path().is_ident("chain")); + let vis = &input_fn.vis; + let fn_name = &input_fn.sig.ident; + let has_resources = !resources.is_empty(); + + // Generate struct name + let internal_name = format!( + "__internal_chain_{}", + just_fmt::snake_case!(fn_name.to_string()) + ); + let struct_name = Ident::new(&internal_name, fn_name.span()); + + // Determine the program type for the return type + let program_type = if use_crate_prefix { + crate::default_program_path() + } else { + group_name.clone() }; + // Generate the `proc` function + let proc_fn = generate_proc_fn( + has_resources, + &resources, + &program_type, + &previous_type, + &prev_param, + fn_name, + &fn_body.stmts, + #[cfg(feature = "async")] + is_async_fn, + #[cfg(not(feature = "async"))] + false, + ); + + // Generate the original function + let origin_proc_fn = generate_original_fn( + &fn_attrs, + vis, + fn_name, + inputs, + fn_body, + #[cfg(feature = "async")] + is_async_fn, + #[cfg(not(feature = "async"))] + false, + &program_type, + ); + + // Assemble the final output + let previous_type_str = quote! { #previous_type }; + let expanded = generate_struct_and_impl( + &fn_attrs, + vis, + &struct_name, + &previous_type, + &previous_type_str, + &group_name, + &program_type, + use_crate_prefix, + proc_fn, + origin_proc_fn, + ); + expanded.into() } |
