aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mingling_macros/src/chain.rs')
-rw-r--r--mingling_macros/src/chain.rs187
1 files changed, 155 insertions, 32 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs
index 6646a68..f9291e0 100644
--- a/mingling_macros/src/chain.rs
+++ b/mingling_macros/src/chain.rs
@@ -5,36 +5,101 @@ use syn::{
FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, TypePath, parse_macro_input,
};
-/// Extracts the previous type and parameter name from function arguments
-fn extract_previous_info(sig: &Signature) -> syn::Result<(Pat, TypePath)> {
- // The function should have exactly one parameter
- if sig.inputs.len() != 1 {
+/// Extracted information about a resource injection parameter
+struct ResourceInjection {
+ var_name: Ident,
+ full_type: Type,
+ inner_type: TypePath,
+ is_ref: bool,
+}
+
+/// Extracts the previous type and parameter name from function arguments,
+fn extract_args_info(sig: &Signature) -> syn::Result<(Pat, TypePath, Vec<ResourceInjection>)> {
+ if sig.inputs.is_empty() {
return Err(syn::Error::new(
- sig.inputs.span(),
- "Chain function must have exactly one parameter",
+ sig.span(),
+ "Chain function must have at least one parameter",
));
}
- let arg = &sig.inputs[0];
- match arg {
+ // First parameter: required, the previous chain type
+ let first_arg = &sig.inputs[0];
+ let (prev_param, previous_type) = match first_arg {
FnArg::Typed(PatType { pat, ty, .. }) => {
- // Extract the pattern (parameter name)
let param_pat = (**pat).clone();
-
- // Extract the type
match &**ty {
- Type::Path(type_path) => Ok((param_pat, type_path.clone())),
- _ => Err(syn::Error::new(
- ty.span(),
- "Parameter type must be a type path",
- )),
+ Type::Path(type_path) => (param_pat, type_path.clone()),
+ _ => {
+ return Err(syn::Error::new(
+ ty.span(),
+ "First parameter type must be a type path",
+ ));
+ }
+ }
+ }
+ FnArg::Receiver(_) => {
+ return Err(syn::Error::new(
+ first_arg.span(),
+ "Chain function cannot have self parameter",
+ ));
+ }
+ };
+
+ // 2nd to Nth parameters: optional, for resource injection
+ let mut resources = Vec::new();
+ for arg in sig.inputs.iter().skip(1) {
+ match arg {
+ FnArg::Typed(PatType { pat, ty, .. }) => {
+ // Extract the variable name – must be a simple identifier
+ let var_name = match &**pat {
+ Pat::Ident(pat_ident) => pat_ident.ident.clone(),
+ _ => {
+ return Err(syn::Error::new(
+ pat.span(),
+ "Resource injection parameter must be a simple identifier (e.g., `age: &Age`)",
+ ));
+ }
+ };
+
+ let full_type = *(*ty).clone();
+
+ // Try to extract inner type for reference patterns like `&Age` -> `Age`
+ let (inner_type, is_ref) = match &full_type {
+ Type::Reference(ref_type) => match &*ref_type.elem {
+ Type::Path(type_path) => (type_path.clone(), true),
+ _ => {
+ return Err(syn::Error::new(
+ ty.span(),
+ "Reference resource type must be a type path (e.g., `age: &Age`)",
+ ));
+ }
+ },
+ Type::Path(type_path) => (type_path.clone(), false),
+ _ => {
+ return Err(syn::Error::new(
+ ty.span(),
+ "Resource injection type must be a type path or reference to one (e.g., `age: Age` or `age: &Age`)",
+ ));
+ }
+ };
+
+ resources.push(ResourceInjection {
+ var_name,
+ full_type,
+ inner_type,
+ is_ref,
+ });
+ }
+ FnArg::Receiver(_) => {
+ return Err(syn::Error::new(
+ arg.span(),
+ "Resource injection parameter cannot be self",
+ ));
}
}
- FnArg::Receiver(_) => Err(syn::Error::new(
- arg.span(),
- "Chain function cannot have self parameter",
- )),
}
+
+ Ok((prev_param, previous_type, resources))
}
pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
@@ -103,8 +168,8 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
- // Extract the previous type and parameter name from function arguments
- let (prev_param, previous_type) = match extract_previous_info(&input_fn.sig) {
+ // Extract the previous type, parameter name, and resource injection params
+ let (prev_param, previous_type, resources) = match extract_args_info(&input_fn.sig) {
Ok(info) => info,
Err(e) => return e.to_compile_error().into(),
};
@@ -142,17 +207,65 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
quote! { #group_name }
};
+ // Build resource injection statements (let ... = ...)
+ let resource_stmts: Vec<_> = resources
+ .iter()
+ .map(|res| {
+ let var_binding_name = syn::Ident::new(
+ &format!("{}_binding", &res.var_name.to_string()),
+ res.var_name.span(),
+ );
+ let var_name = &res.var_name;
+ let full_type = &res.full_type;
+ let inner_type = &res.inner_type;
+ if res.is_ref {
+ quote! {
+ let #var_binding_name = ::mingling::this::<#program_type>()
+ .res_or_default::<#inner_type>();
+ let #var_name: #full_type = #var_binding_name.as_ref();
+ }
+ } else {
+ quote! {
+ let #var_name: #full_type = ::mingling::this::<#program_type>()
+ .res_or_default::<#full_type>();
+ }
+ }
+ })
+ .collect();
+
+ let has_resources = !resources.is_empty();
+
#[cfg(feature = "async")]
let proc_fn = if is_async_fn {
- quote! {
- async fn proc(#inputs) -> ::mingling::ChainProcess<#program_type> {
- #fn_name(#prev_param).await.into()
+ if has_resources {
+ let body_stmts = &fn_body.stmts;
+ quote! {
+ async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #(#resource_stmts)*
+ #(#body_stmts)*
+ }
+ }
+ } else {
+ quote! {
+ async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #fn_name(#prev_param).await.into()
+ }
}
}
} else {
- quote! {
- async fn proc(#inputs) -> ::mingling::ChainProcess<#program_type> {
- #fn_name(#prev_param).into()
+ if has_resources {
+ let body_stmts = &fn_body.stmts;
+ quote! {
+ async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #(#resource_stmts)*
+ #(#body_stmts)*
+ }
+ }
+ } else {
+ quote! {
+ async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #fn_name(#prev_param).into()
+ }
}
}
};
@@ -175,9 +288,19 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
};
#[cfg(not(feature = "async"))]
- let proc_fn = quote! {
- fn proc(#inputs) -> ::mingling::ChainProcess<#program_type> {
- #fn_name(#prev_param).into()
+ let proc_fn = if has_resources {
+ let body_stmts = &fn_body.stmts;
+ quote! {
+ fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #(#resource_stmts)*
+ #(#body_stmts)*
+ }
+ }
+ } else {
+ quote! {
+ fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #fn_name(#prev_param).into()
+ }
}
};
@@ -248,7 +371,7 @@ pub fn register_chain(input: TokenStream) -> TokenStream {
// Parse the input as a comma-separated list of arguments
let input_parsed = syn::parse_macro_input!(input with syn::punctuated::Punctuated<syn::Expr, syn::Token![,]>::parse_terminated);
- // Check that we have exactly two elements
+ // Check that there are exactly two elements
if input_parsed.len() != 2 {
return syn::Error::new(
input_parsed.span(),