aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mingling_macros/src/chain.rs')
-rw-r--r--mingling_macros/src/chain.rs141
1 files changed, 102 insertions, 39 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs
index 02bbc6f..5ec5635 100644
--- a/mingling_macros/src/chain.rs
+++ b/mingling_macros/src/chain.rs
@@ -154,29 +154,47 @@ fn parse_chain_attr_args(attr: TokenStream) -> (proc_macro2::TokenStream, bool)
}
/// Validates that the return type of the function is `Next`.
-fn validate_return_type_is_next_process(sig: &Signature) -> Result<(), proc_macro2::TokenStream> {
+/// Checks whether the return type is `()` (unit).
+fn is_unit_return_type(sig: &Signature) -> bool {
+ match &sig.output {
+ ReturnType::Type(_, ty) => match &**ty {
+ Type::Tuple(tuple) => tuple.elems.is_empty(),
+ _ => false,
+ },
+ ReturnType::Default => true,
+ }
+}
+
+fn validate_return_type(sig: &Signature) -> Result<(), proc_macro2::TokenStream> {
+ // If return type is `()`, it's valid (no Next required)
+ if is_unit_return_type(sig) {
+ return Ok(());
+ }
+
match &sig.output {
ReturnType::Type(_, ty) => match &**ty {
Type::Path(type_path) => {
let last_segment = type_path.path.segments.last().unwrap();
if last_segment.ident != "Next" {
- return Err(
- syn::Error::new(ty.span(), "Chain function must return `Next`")
- .to_compile_error(),
- );
+ return Err(syn::Error::new(
+ ty.span(),
+ "Chain function must return `Next` or `()`",
+ )
+ .to_compile_error());
}
}
_ => {
- return Err(
- syn::Error::new(ty.span(), "Chain function must return `Next`")
- .to_compile_error(),
- );
+ return Err(syn::Error::new(
+ ty.span(),
+ "Chain function must return `Next` or `()`",
+ )
+ .to_compile_error());
}
},
ReturnType::Default => {
return Err(syn::Error::new(
sig.span(),
- "Chain function must specify a return type (must be `Next`)",
+ "Chain function must specify a return type (must be `Next` or `()`)",
)
.to_compile_error());
}
@@ -258,48 +276,73 @@ fn generate_proc_fn(
fn_name: &Ident,
fn_body_stmts: &[syn::Stmt],
is_async_fn: bool,
+ is_unit_return: bool,
) -> proc_macro2::TokenStream {
let immut_resource_stmts = generate_immut_resource_bindings(resources.iter(), program_type);
let mut_resources: Vec<_> = resources.iter().filter(|r| r.is_mut).collect();
- let wrapped_body = wrap_body_with_mut_resources(fn_body_stmts, &mut_resources, program_type);
- #[cfg(feature = "async")]
- {
+ let body_stmts: &[syn::Stmt] = if is_unit_return && has_resources {
+ let mut stmts = fn_body_stmts.to_vec();
+ stmts.push(syn::Stmt::Expr(
+ syn::parse_quote! { crate::EmptyResult::new(()).to_chain() },
+ None,
+ ));
+ // Box::leak to get a &'static [syn::Stmt]
+ Box::leak(Box::new(stmts))
+ } else {
+ fn_body_stmts
+ };
+
+ let wrapped_body = wrap_body_with_mut_resources(body_stmts, &mut_resources, program_type);
+
+ // When the function returns `()`, wrap the result with EmptyResult
+ let call_or_wrapped = if is_unit_return {
if has_resources {
quote! {
- async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
- #(#immut_resource_stmts)*
- #wrapped_body
- }
+ #(#immut_resource_stmts)*
+ #wrapped_body
}
} else {
let call = if is_async_fn {
- quote! { #fn_name(#prev_param).await.into() }
+ quote! { #fn_name(#prev_param).await; }
} else {
- quote! { #fn_name(#prev_param).into() }
+ quote! { #fn_name(#prev_param); }
};
quote! {
- async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
- #call
- }
+ #call
+ crate::EmptyResult::new(()).to_chain()
+ }
+ }
+ } else if has_resources {
+ quote! {
+ #(#immut_resource_stmts)*
+ #wrapped_body
+ }
+ } else {
+ let call = if is_async_fn {
+ quote! { #fn_name(#prev_param).await.into() }
+ } else {
+ quote! { #fn_name(#prev_param).into() }
+ };
+ quote! {
+ #call
+ }
+ };
+
+ #[cfg(feature = "async")]
+ {
+ quote! {
+ async fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #call_or_wrapped
}
}
}
#[cfg(not(feature = "async"))]
{
- if has_resources {
- quote! {
- fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
- #(#immut_resource_stmts)*
- #wrapped_body
- }
- }
- } else {
- quote! {
- fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
- #fn_name(#prev_param).into()
- }
+ quote! {
+ fn proc(#prev_param: #previous_type) -> ::mingling::ChainProcess<#program_type> {
+ #call_or_wrapped
}
}
}
@@ -316,7 +359,22 @@ fn generate_original_fn(
fn_body: &syn::Block,
is_async_fn: bool,
program_type: &proc_macro2::TokenStream,
+ is_unit_return: bool,
) -> proc_macro2::TokenStream {
+ // Both unit and Next return types need to produce `impl Into<ChainProcess<ProgramType>>`
+ let return_type = quote! { impl Into<::mingling::ChainProcess<#program_type>> };
+
+ let body = if is_unit_return {
+ quote! {
+ {
+ #fn_body
+ crate::EmptyResult::new(()).to_chain()
+ }
+ }
+ } else {
+ quote! { #fn_body }
+ };
+
#[cfg(feature = "async")]
{
let async_kw = if is_async_fn {
@@ -326,8 +384,8 @@ fn generate_original_fn(
};
quote! {
#(#fn_attrs)*
- #vis #async_kw fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> {
- #fn_body
+ #vis #async_kw fn #fn_name(#inputs) -> #return_type {
+ #body
}
}
}
@@ -336,8 +394,8 @@ fn generate_original_fn(
{
quote! {
#(#fn_attrs)*
- #vis fn #fn_name(#inputs) -> impl Into<::mingling::ChainProcess<#program_type>> {
- #fn_body
+ #vis fn #fn_name(#inputs) -> #return_type {
+ #body
}
}
}
@@ -426,8 +484,11 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
+ // Check if return type is unit
+ let is_unit_return = is_unit_return_type(&input_fn.sig);
+
// Validate return type
- if let Err(err) = validate_return_type_is_next_process(&input_fn.sig) {
+ if let Err(err) = validate_return_type(&input_fn.sig) {
return err.into();
}
@@ -482,6 +543,7 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
is_async_fn,
#[cfg(not(feature = "async"))]
false,
+ is_unit_return,
);
// Generate the original function
@@ -496,6 +558,7 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
#[cfg(not(feature = "async"))]
false,
&program_type,
+ is_unit_return,
);
// Assemble the final output