aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/program_setup.rs
blob: 7fd9d16f28a2704ed8764a7662a770738e4c8956 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use proc_macro::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use syn::{FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, parse_macro_input};

/// Extracts the program parameter from function arguments
fn extract_program_param(sig: &Signature) -> syn::Result<(Pat, Type)> {
    // The function should have exactly one parameter
    if sig.inputs.len() != 1 {
        return Err(syn::Error::new(
            sig.inputs.span(),
            "Setup function must have exactly one parameter",
        ));
    }

    let arg = &sig.inputs[0];
    match arg {
        FnArg::Typed(PatType { pat, ty, .. }) => {
            // Extract the pattern (parameter name)
            let param_pat = (**pat).clone();
            // Extract the type as-is
            let param_type = (**ty).clone();
            Ok((param_pat, param_type))
        }
        FnArg::Receiver(_) => Err(syn::Error::new(
            arg.span(),
            "Setup function cannot have self parameter",
        )),
    }
}

/// Extracts and validates the return type
fn extract_return_type(sig: &Signature) -> syn::Result<()> {
    // Setup functions should return () or have no return type
    match &sig.output {
        ReturnType::Type(_, ty) => {
            // Check if it's ()
            match &**ty {
                Type::Tuple(tuple) if tuple.elems.is_empty() => Ok(()),
                _ => Err(syn::Error::new(
                    ty.span(),
                    "Setup function must return () or have no return type",
                )),
            }
        }
        ReturnType::Default => Ok(()),
    }
}

pub fn setup_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
    // #[program_setup] takes no arguments; always use the default program path
    let _ = attr;
    let program_path = crate::default_program_path();

    // Parse the function item
    let input_fn = parse_macro_input!(item as ItemFn);

    // Validate the function is not async
    if input_fn.sig.asyncness.is_some() {
        return syn::Error::new(input_fn.sig.span(), "Setup function cannot be async")
            .to_compile_error()
            .into();
    }

    // Extract the program parameter
    let (program_param, program_type) = match extract_program_param(&input_fn.sig) {
        Ok(info) => info,
        Err(e) => return e.to_compile_error().into(),
    };

    // Validate return type
    if let Err(e) = extract_return_type(&input_fn.sig) {
        return e.to_compile_error().into();
    }

    // Get the function body
    let fn_body = &input_fn.block;

    // Get function attributes (excluding the setup attribute)
    let mut fn_attrs = input_fn.attrs.clone();

    // Remove any #[program_setup(...)] attributes to avoid infinite recursion
    fn_attrs.retain(|attr| !attr.path().is_ident("setup"));

    // Get function visibility
    let vis = &input_fn.vis;

    // Get function name
    let fn_name = &input_fn.sig.ident;

    // Generate struct name from function name using pascal_case
    let pascal_case_name = just_fmt::pascal_case!(fn_name.to_string());
    let struct_name = Ident::new(&pascal_case_name, fn_name.span());

    // Generate the struct and implementation
    let expanded = quote! {
        #(#fn_attrs)*
        #[doc(hidden)]
        #vis struct #struct_name;

        impl ::mingling::setup::ProgramSetup<#program_path> for #struct_name {
            fn setup(self, program: &mut ::mingling::Program<#program_path>) {
                // Call the original function with the actual Program type
                #fn_name(program);
            }
        }

        // Keep the original function for internal use
        #(#fn_attrs)*
        #vis fn #fn_name(#program_param: #program_type) {
            #fn_body
        }
    };

    expanded.into()
}