aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mingling_macros/src/chain.rs')
-rw-r--r--mingling_macros/src/chain.rs38
1 files changed, 37 insertions, 1 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs
index a91949d..e7b2db2 100644
--- a/mingling_macros/src/chain.rs
+++ b/mingling_macros/src/chain.rs
@@ -1,7 +1,9 @@
use proc_macro::TokenStream;
use quote::{ToTokens, quote};
use syn::spanned::Spanned;
-use syn::{FnArg, Ident, ItemFn, Pat, PatType, Signature, Type, TypePath, parse_macro_input};
+use syn::{
+ FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, TypePath, parse_macro_input,
+};
/// Extracts the previous type and parameter name from function arguments
fn extract_previous_info(sig: &Signature) -> syn::Result<(Pat, TypePath)> {
@@ -67,6 +69,40 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
+ // Check that return type is NextProcess
+ let return_type = &input_fn.sig.output;
+ match return_type {
+ ReturnType::Type(_, ty) => {
+ // Check if the return type is NextProcess
+ match &**ty {
+ Type::Path(type_path) => {
+ let last_segment = type_path.path.segments.last().unwrap();
+ if last_segment.ident.to_string() != "NextProcess" {
+ return syn::Error::new(
+ ty.span(),
+ "Chain function must return `NextProcess`",
+ )
+ .to_compile_error()
+ .into();
+ }
+ }
+ _ => {
+ return syn::Error::new(ty.span(), "Chain function must return `NextProcess`")
+ .to_compile_error()
+ .into();
+ }
+ }
+ }
+ ReturnType::Default => {
+ return syn::Error::new(
+ input_fn.sig.span(),
+ "Chain function must specify a return type (must be `NextProcess`)",
+ )
+ .to_compile_error()
+ .into();
+ }
+ }
+
// Extract the previous type and parameter name from function arguments
let (prev_param, previous_type) = match extract_previous_info(&input_fn.sig) {
Ok(info) => info,