aboutsummaryrefslogtreecommitdiff
path: root/mingling_macros/src/dispatch_tree_gen.rs
blob: 840fe15127d96f20d7acf1cf2416efe097710577 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
//! Dispatch Tree Generation
//!
//! This module generates the dispatch tree code for the `dispatch_tree` feature.
//! It builds a compact, hardcoded match tree at compile time to achieve O(len)
//! command dispatch.
//!
//! # Algorithm
//!
//! For each depth, group nodes by the character at that depth.
//! - If a group has only one node: emit `starts_with` check for the full name.
//! - If a group has multiple nodes: emit a `match raw_chars.nth(depth)` arm and recurse.
//! - At the leaf: call `Dispatcher::begin` on the matched dispatcher.
//!
//! Names are matched with a trailing space (e.g. "hello ") to ensure exact boundary.

use proc_macro2::TokenStream;
use quote::quote;
use std::collections::BTreeMap;

/// Generate the `get_nodes()` function body for a ProgramCollect impl.
pub fn gen_get_nodes(entries: &[(String, String, String)]) -> TokenStream {
    let mut node_entries = Vec::new();

    for (node_name, _disp_type, _entry_name) in entries {
        let static_name_str = format!("__internal_dispatcher_{}", node_name.replace('.', "_"));
        let static_ident = syn::Ident::new(&static_name_str, proc_macro2::Span::call_site());

        let node_display_name = node_name.replace('.', " ");
        let node_display_lit = syn::LitStr::new(&node_display_name, proc_macro2::Span::call_site());

        node_entries.push(quote! {
            (#node_display_lit.to_string(), & #static_ident)
        });
    }

    quote! {
        fn get_nodes() -> Vec<(String, &'static (dyn ::mingling::Dispatcher<Self::Enum> + Send + Sync))> {
            vec![
                #(#node_entries),*
            ]
        }
    }
}

/// Generate the `dispatch_args_trie()` function body for a ProgramCollect impl.
///
/// Builds a hardcoded match tree: at each depth, group nodes by character.
/// Single-node groups use `starts_with`; multi-node groups recurse with `nth()` match.
pub fn gen_dispatch_args_trie(entries: &[(String, String, String)]) -> TokenStream {
    // Prepare (display_name, disp_type) pairs.
    // display_name = node_name.replace('.', " ")
    let nodes: Vec<(String, String)> = entries
        .iter()
        .map(|(name, disp, _)| (name.replace('.', " "), disp.clone()))
        .collect();

    let dispatch_body = build_dispatch_body(&nodes, 0);

    quote! {
        fn dispatch_args_trie(
            raw: &Vec<String>,
        ) -> Result<::mingling::AnyOutput<Self::Enum>, ::mingling::error::ProgramInternalExecuteError>
        {
            let raw_string = format!("{} ", raw.join(" "));
            let raw_str = raw_string.as_str();
            let mut raw_chars = raw_str.chars();
            #dispatch_body
        }
    }
}

/// Recursively build the trie match body.
///
/// `nodes`: slice of (display_name, disp_type) for commands that share the same prefix so far.
/// `depth`: The character index currently being matched.
///
/// Returns a `TokenStream` representing the match block at this depth.
fn build_dispatch_body(nodes: &[(String, String)], depth: usize) -> TokenStream {
    if nodes.is_empty() {
        return quote! {
            return Ok(Self::build_dispatcher_not_found(raw.clone()));
        };
    }

    // Group by character at `depth`
    // Nodes that end exactly at this depth (name is closed – rare but possible, e.g. "hell")
    // are collected into `exact_nodes`. All others go into `groups[char]`.
    let mut groups: BTreeMap<char, Vec<(String, String)>> = BTreeMap::new();
    let mut exact_nodes: Vec<(String, String)> = Vec::new();

    for (name, disp_type) in nodes {
        if let Some(ch) = name.chars().nth(depth) {
            groups
                .entry(ch)
                .or_default()
                .push((name.clone(), disp_type.clone()));
        } else {
            exact_nodes.push((name.clone(), disp_type.clone()));
        }
    }

    // Build a dispatch arm for a single node via `starts_with`
    let make_starts_with_arm = |name: &str, disp_type: &str| -> TokenStream {
        let name_space = format!("{} ", name);
        let name_lit = syn::LitStr::new(&name_space, proc_macro2::Span::call_site());
        let disp_ident = syn::Ident::new(disp_type, proc_macro2::Span::call_site());
        quote! {
            if let Some(stripped) = raw_str.strip_prefix(#name_lit) {
                let __cp = <#disp_ident as ::mingling::Dispatcher<Self::Enum>>::begin(
                    &#disp_ident::default(),
                    stripped
                        .split_whitespace()
                        .map(String::from)
                        .collect::<Vec<String>>(),
                );
                return match __cp {
                    ::mingling::ChainProcess::Ok(any_output) => Ok(any_output.0),
                    ::mingling::ChainProcess::Err(chain_process_error) => {
                        Err(chain_process_error.into())
                    }
                };
            }
        }
    };

    // Build match arms
    let mut arms = Vec::new();

    for (&ch, sub_nodes) in &groups {
        let ch_char = ch;

        if sub_nodes.len() == 1 {
            // Only one candidate – use `starts_with` directly at this depth.
            let (name, disp_type) = &sub_nodes[0];
            let arm = make_starts_with_arm(name, disp_type);
            arms.push(quote! {
                Some(#ch_char) => {
                    #arm
                    return Ok(Self::build_dispatcher_not_found(raw.clone()));
                }
            });
        } else {
            // Multiple candidates – recurse deeper
            let sub_body = build_dispatch_body(sub_nodes, depth + 1);
            arms.push(quote! {
                Some(#ch_char) => {
                    #sub_body
                }
            });
        }
    }

    // Exact-match nodes at this depth
    // These are names that are a prefix of other names (e.g. "hell" when "hello" exists).
    // They are tried first via `starts_with`, then fall through to `raw_chars.nth(depth)` match.
    let exact_checks: Vec<TokenStream> = exact_nodes
        .iter()
        .map(|(name, disp_type)| make_starts_with_arm(name, disp_type))
        .collect();

    // Assemble
    // If there are exact nodes, first check starts_with, then do matching.
    if !exact_checks.is_empty() && !groups.is_empty() {
        // Exact nodes + deeper groups: do starts_with checks, then match on nth(depth)
        let match_body = quote! {
            match raw_chars.nth(0) {
                #(#arms)*
                _ => return Ok(Self::build_dispatcher_not_found(raw.clone())),
            }
        };
        quote! {
            #(#exact_checks)*
            #match_body
        }
    } else if !exact_checks.is_empty() {
        // Only exact nodes, no deeper groups
        quote! {
            #(#exact_checks)*
            return Ok(Self::build_dispatcher_not_found(raw.clone()));
        }
    } else if arms.is_empty() {
        // Only fallback (shouldn't happen if nodes is non-empty)
        quote! {
            return Ok(Self::build_dispatcher_not_found(raw.clone()));
        }
    } else {
        // Only group arms
        quote! {
            match raw_chars.nth(0) {
                #(#arms)*
                _ => return Ok(Self::build_dispatcher_not_found(raw.clone())),
            }
        }
    }
}