aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/chain.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mingling_macros/src/chain.rs')
-rw-r--r--mingling_macros/src/chain.rs69
1 files changed, 54 insertions, 15 deletions
diff --git a/mingling_macros/src/chain.rs b/mingling_macros/src/chain.rs
index 84353e9..14e62ec 100644
--- a/mingling_macros/src/chain.rs
+++ b/mingling_macros/src/chain.rs
@@ -4,7 +4,7 @@
//! generating structs that implement the `Chain` trait from async functions.
use proc_macro::TokenStream;
-use quote::quote;
+use quote::{ToTokens, quote};
use syn::spanned::Spanned;
use syn::{
FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, TypePath, parse_macro_input,
@@ -129,7 +129,7 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
#[doc(hidden)]
#vis struct #struct_name;
- ::mingling::macros::register_type!(#previous_type);
+ ::mingling::macros::register_chain!(#previous_type, #struct_name);
impl ::mingling::Chain<ThisProgram> for #struct_name {
type Previous = #previous_type;
@@ -156,6 +156,8 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
#(#fn_attrs)*
#vis struct #struct_name;
+ ::mingling::macros::register_chain!(#previous_type, #struct_name);
+
impl ::mingling::Chain<#group_name> for #struct_name {
type Previous = #previous_type;
@@ -178,19 +180,6 @@ pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
}
};
- // Record the chain mapping
- let chain_entry = build_chain_arm(&struct_name, &previous_type);
- let chain_exist_entry = build_chain_exist_arm(&previous_type);
-
- let mut chains = crate::CHAINS.lock().unwrap();
- let mut chain_exist = crate::CHAINS_EXIST.lock().unwrap();
-
- let chain_entry = chain_entry.to_string();
- let chain_exist_entry = chain_exist_entry.to_string();
-
- chains.insert(chain_entry);
- chain_exist.insert(chain_exist_entry);
-
expanded.into()
}
@@ -207,3 +196,53 @@ pub fn build_chain_exist_arm(previous_type: &TypePath) -> proc_macro2::TokenStre
Self::#previous_type => true,
}
}
+
+pub fn register_chain(input: TokenStream) -> TokenStream {
+ // Parse the input as a comma-separated list of arguments
+ let input_parsed = syn::parse_macro_input!(input with syn::punctuated::Punctuated<syn::Expr, syn::Token![,]>::parse_terminated);
+
+ // Check that we have exactly two elements
+ if input_parsed.len() != 2 {
+ return syn::Error::new(
+ input_parsed.span(),
+ "Expected exactly two comma-separated arguments: `PreviousType, StructName`",
+ )
+ .to_compile_error()
+ .into();
+ }
+
+ // Extract the two elements
+ let previous_type_expr = &input_parsed[0];
+ let struct_name_expr = &input_parsed[1];
+
+ // Convert expressions to TypePath and Ident
+ let previous_type = match syn::parse2::<TypePath>(previous_type_expr.to_token_stream()) {
+ Ok(ty) => ty,
+ Err(e) => return e.to_compile_error().into(),
+ };
+
+ let struct_name = match syn::parse2::<syn::Ident>(struct_name_expr.to_token_stream()) {
+ Ok(ident) => ident,
+ Err(e) => return e.to_compile_error().into(),
+ };
+
+ // Record the chain mapping: previous_type => struct_name
+ let chain_entry = build_chain_arm(&struct_name, &previous_type);
+
+ // Record the chain existence check
+ let chain_exist_entry = build_chain_exist_arm(&previous_type);
+
+ let mut chains = crate::CHAINS.lock().unwrap();
+ let mut chain_exist = crate::CHAINS_EXIST.lock().unwrap();
+
+ let chain_entry_str = chain_entry.to_string();
+ let chain_exist_entry_str = chain_exist_entry.to_string();
+
+ chains.insert(chain_entry_str);
+ chain_exist.insert(chain_exist_entry_str);
+
+ quote! {
+ ::mingling::macros::register_type!(#previous_type);
+ }
+ .into()
+}