aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mingling_macros/src/chain.rs')
-rw-r--r--mingling_macros/src/chain.rs65
1 files changed, 47 insertions, 18 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs
index f9291e0..496a0a4 100644
--- a/mingling_macros/src/chain.rs
+++ b/mingling_macros/src/chain.rs
@@ -11,6 +11,7 @@ struct ResourceInjection {
full_type: Type,
inner_type: TypePath,
is_ref: bool,
+ is_mut: bool,
}
/// Extracts the previous type and parameter name from function arguments,
@@ -64,9 +65,13 @@ fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<Resourc
let full_type = *(*ty).clone();
// Try to extract inner type for reference patterns like `&Age` -> `Age`
- let (inner_type, is_ref) = match &full_type {
+ // and `&mut Age` -> `Age`
+ let (inner_type, is_ref, is_mut) = match &full_type {
Type::Reference(ref_type) => match &*ref_type.elem {
- Type::Path(type_path) => (type_path.clone(), true),
+ Type::Path(type_path) => {
+ let is_mut = ref_type.mutability.is_some();
+ (type_path.clone(), true, is_mut)
+ }
_ => {
return Err(syn::Error::new(
ty.span(),
@@ -74,7 +79,7 @@ fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<Resourc
));
}
},
- Type::Path(type_path) => (type_path.clone(), false),
+ Type::Path(type_path) => (type_path.clone(), false, false),
_ => {
return Err(syn::Error::new(
ty.span(),
@@ -88,6 +93,7 @@ fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<Resourc
full_type,
inner_type,
is_ref,
+ is_mut,
});
}
FnArg::Receiver(_) => {
@@ -207,8 +213,12 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
quote! { #group_name }
};
- // Build resource injection statements (let ... = ...)
- let resource_stmts: Vec<_> = resources
+ // Separate resources into immutable refs and mutable refs
+ let immut_resources: Vec<_> = resources.iter().filter(|r| !r.is_mut).collect();
+ let mut_resources: Vec<_> = resources.iter().filter(|r| r.is_mut).collect();
+
+ // Build resource injection statements for immutable references (let ... = ...)
+ let immut_resource_stmts: Vec<_> = immut_resources
.iter()
.map(|res| {
let var_binding_name = syn::Ident::new(
@@ -233,16 +243,37 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
})
.collect();
- let has_resources = !resources.is_empty();
+ // Build nested __modify_res_and_return_any wrappers for mutable references
+ // The innermost layer is the original function body, wrapping outward for each mutable resource.
+ // The function returns a value, so the return values need to be properly chained.
+ let body_stmts = &fn_body.stmts;
+ let mut wrapped_body = quote! {
+ #(#body_stmts)*
+ };
+
+ // Wrap from inside to outside: the first mutable parameter becomes the outermost wrapper,
+ // and the last mutable parameter becomes the innermost wrapper.
+ // Therefore iterate mut_resources and wrap outward.
+ for res in mut_resources.iter() {
+ let var_name = &res.var_name;
+ let inner_type = &res.inner_type;
+ wrapped_body = quote! {
+ ::mingling::this::<#program_type>().__modify_res_and_return_any(|#var_name: &mut #inner_type| {
+ #wrapped_body
+ }).into()
+ };
+ }
+
+ let has_immut_resources = !immut_resources.is_empty();
+ let has_mut_resources = !mut_resources.is_empty();
#[cfg(feature = "async")]
let proc_fn = if is_async_fn {
- if has_resources {
- let body_stmts = &fn_body.stmts;
+ if has_immut_resources || has_mut_resources {
quote! {
async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
- #(#resource_stmts)*
- #(#body_stmts)*
+ #(#immut_resource_stmts)*
+ #wrapped_body
}
}
} else {
@@ -253,12 +284,11 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
} else {
- if has_resources {
- let body_stmts = &fn_body.stmts;
+ if has_immut_resources || has_mut_resources {
quote! {
async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
- #(#resource_stmts)*
- #(#body_stmts)*
+ #(#immut_resource_stmts)*
+ #wrapped_body
}
}
} else {
@@ -288,12 +318,11 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
};
#[cfg(not(feature = "async"))]
- let proc_fn = if has_resources {
- let body_stmts = &fn_body.stmts;
+ let proc_fn = if has_immut_resources || has_mut_resources {
quote! {
fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
- #(#resource_stmts)*
- #(#body_stmts)*
+ #(#immut_resource_stmts)*
+ #wrapped_body
}
}
} else {