diff options
| author | 魏曹先生 <1992414357@qq.com> | 2026-03-28 00:47:46 +0800 |
|---|---|---|
| committer | 魏曹先生 <1992414357@qq.com> | 2026-03-28 00:47:46 +0800 |
| commit | 7ce68cd11516bd7cf037ecea99a92aee7c31b2c3 (patch) | |
| tree | a3923ad41c91aa21fe169fd6b4b1bf8898a82589 /mingling_macros/src/chain.rs | |
Add initial Mingling framework codebase
Diffstat (limited to 'mingling_macros/src/chain.rs')
| -rw-r--r-- | mingling_macros/src/chain.rs | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs new file mode 100644 index 0000000..0c21c79 --- /dev/null +++ b/mingling_macros/src/chain.rs @@ -0,0 +1,164 @@ +//! Chain Attribute Macro Implementation +//! +//! This module provides the `#[chain]` attribute macro for automatically +//! generating structs that implement the `Chain` trait from async functions. + +use proc_macro::TokenStream; +use quote::quote; +use syn::parse::{Parse, ParseStream}; +use syn::spanned::Spanned; +use syn::{ + FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, TypePath, parse_macro_input, +}; + +/// Parses the chain attribute arguments +struct ChainAttribute { + struct_name: Ident, +} + +impl Parse for ChainAttribute { + fn parse(input: ParseStream) -> syn::Result<Self> { + let struct_name = input.parse()?; + Ok(ChainAttribute { struct_name }) + } +} + +/// Extracts the previous type and parameter name from function arguments +fn extract_previous_info(sig: &Signature) -> syn::Result<(Pat, TypePath)> { + // The function should have exactly one parameter + if sig.inputs.len() != 1 { + return Err(syn::Error::new( + sig.inputs.span(), + "Chain function must have exactly one parameter", + )); + } + + let arg = &sig.inputs[0]; + match arg { + FnArg::Typed(PatType { pat, ty, .. }) => { + // Extract the pattern (parameter name) + let param_pat = (**pat).clone(); + + // Extract the type + match &**ty { + Type::Path(type_path) => Ok((param_pat, type_path.clone())), + _ => Err(syn::Error::new( + ty.span(), + "Parameter type must be a type path", + )), + } + } + FnArg::Receiver(_) => Err(syn::Error::new( + arg.span(), + "Chain function cannot have self parameter", + )), + } +} + +/// Extracts the return type from the function signature +fn extract_return_type(sig: &Signature) -> syn::Result<TypePath> { + match &sig.output { + ReturnType::Type(_, ty) => match &**ty { + Type::Path(type_path) => Ok(type_path.clone()), + _ => Err(syn::Error::new( + ty.span(), + "Return type must be a type path", + )), + }, + ReturnType::Default => Err(syn::Error::new( + sig.span(), + "Chain function must have a return type", + )), + } +} + +/// Implementation of the `#[chain]` attribute macro +/// +/// This macro transforms an async function into a struct that implements +/// the `Chain` trait. The struct name is specified in the attribute. +/// +/// # Examples +/// +/// ```ignore +/// use mingling_macros::chain; +/// +/// #[chain(InitEntry)] +/// pub async fn process(data: InitBegin) -> mingling::AnyOutput { +/// AnyOutput::new::<InitResult>("初始化成功!".to_string().into()) +/// } +/// ``` +/// +/// This generates: +/// ```ignore +/// pub struct InitEntry; +/// impl Chain for InitEntry { +/// type Previous = InitBegin; +/// async fn proc(data: Self::Previous) -> mingling::AnyOutput { +/// AnyOutput::new::<InitResult>("初始化成功!".to_string().into()) +/// } +/// } +/// ``` +pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream { + // Parse the attribute arguments + let chain_attr = parse_macro_input!(attr as ChainAttribute); + let struct_name = chain_attr.struct_name; + + // Parse the function item + let input_fn = parse_macro_input!(item as ItemFn); + + // Validate the function + if !input_fn.sig.asyncness.is_some() { + return syn::Error::new(input_fn.sig.span(), "Chain function must be async") + .to_compile_error() + .into(); + } + + // Extract the previous type and parameter name from function arguments + let (prev_param, previous_type) = match extract_previous_info(&input_fn.sig) { + Ok(info) => info, + Err(e) => return e.to_compile_error().into(), + }; + + // Extract the return type + let return_type = match extract_return_type(&input_fn.sig) { + Ok(ty) => ty, + Err(e) => return e.to_compile_error().into(), + }; + + // Get the function body + let fn_body = &input_fn.block; + + // Get function attributes (excluding the chain attribute) + let mut fn_attrs = input_fn.attrs.clone(); + // Remove any #[chain(...)] attributes to avoid infinite recursion + fn_attrs.retain(|attr| !attr.path().is_ident("chain")); + + // Get function visibility + let vis = &input_fn.vis; + + // Get function name + let fn_name = &input_fn.sig.ident; + + // Generate the struct and implementation + let expanded = quote! { + #(#fn_attrs)* + #vis struct #struct_name; + + impl ::mingling::Chain for #struct_name { + type Previous = #previous_type; + + async fn proc(#prev_param: Self::Previous) -> #return_type { + // Call the original function + #fn_name(#prev_param).await + } + } + + // Keep the original function for internal use + #(#fn_attrs)* + #vis async fn #fn_name(#prev_param: #previous_type) -> #return_type { + #fn_body + } + }; + + expanded.into() +} |
