use colored::Colorize; use just_fmt::fmt_path::fmt_path_str; use just_template::{Template, tmpl, tmpl_param}; use std::{ env::current_dir, path::{Path, PathBuf}, }; use tokio::fs; const LIB_RS_TEMPLATE_PATH: &str = "policy/_policies/src/lib.rs.tmpl"; const CARGO_TOML_TEMPLATE_PATH: &str = "policy/_policies/Cargo.toml.tmpl"; const LIB_RS_PATH: &str = "./policy/_policies/src/lib.rs"; const CARGO_TOML_PATH: &str = "./policy/_policies/Cargo.toml"; #[tokio::main] async fn main() { let current_dir = current_dir().unwrap(); precheck(¤t_dir).await; println!("Updating policies ..."); let (mut lib_rs_template, mut cargo_toml_template) = { let lib_rs_template_path = current_dir.join(LIB_RS_TEMPLATE_PATH); let cargo_toml_template_path = current_dir.join(CARGO_TOML_TEMPLATE_PATH); let lib_rs_content = fs::read_to_string(&lib_rs_template_path) .await .unwrap_or_else(|_| { eprintln!( "{}", format!( "Error: Failed to read template file: {}", lib_rs_template_path.display() ) .red() ); std::process::exit(1); }); let cargo_toml_content = fs::read_to_string(&cargo_toml_template_path) .await .unwrap_or_else(|_| { eprintln!( "{}", format!( "Error: Failed to read template file: {}", cargo_toml_template_path.display() ) .red() ); std::process::exit(1); }); ( Template::from(lib_rs_content), Template::from(cargo_toml_content), ) }; let cargo_toml_pathes = find_cargo_toml_dirs(¤t_dir.join("policy")).await; println!( "{} {} crates, register to `{}`", "Found".bright_blue().bold(), cargo_toml_pathes.len(), CARGO_TOML_PATH.bright_green() ); tmpl_param!(lib_rs_template, policy_count = cargo_toml_pathes.len()); let collect_futures = cargo_toml_pathes .iter() .map(|path| collect(path)) .collect::>(); for policy in futures::future::join_all(collect_futures).await { let Some(policy) = policy else { continue }; tmpl!(cargo_toml_template += { deps { (crate_name = policy.crate_name, path = policy.path) } }); // Determine which export template to use based on detected functions if policy.matched_func_stream.is_some() { let stream_struct_id = format!( "{}::{}", policy.crate_name, policy.stream_struct_id.unwrap() ); if policy.matched_func.is_empty() { // Only stream function tmpl!(lib_rs_template += { exports_stream { ( crate_name = policy.crate_name, matched_func_stream = policy.matched_func_stream.unwrap(), has_await_stream = if policy.matched_func_stream_has_await { ".await" } else { "" }, stream_struct_id = stream_struct_id ) }, match_arms { ( crate_name = policy.crate_name, ) }, match_arms_stream { ( crate_name = policy.crate_name, stream_struct_id = stream_struct_id ) }, match_arms_stream_display { ( crate_name = policy.crate_name, stream_struct_id = stream_struct_id ) }, stream_policy_names { ( name = policy.crate_name ) } }); } else { // Both simple and stream functions tmpl!(lib_rs_template += { exports_both { ( crate_name = policy.crate_name, matched_func = policy.matched_func, has_await = if policy.matched_func_has_await { ".await" } else { "" }, matched_func_stream = policy.matched_func_stream.unwrap(), has_await_stream = if policy.matched_func_stream_has_await { ".await" } else { "" }, stream_struct_id = stream_struct_id ) }, match_arms { ( crate_name = policy.crate_name, ) }, match_arms_stream { ( crate_name = policy.crate_name, stream_struct_id = stream_struct_id ) }, match_arms_stream_display { ( crate_name = policy.crate_name, stream_struct_id = stream_struct_id ) }, policy_names { ( name = policy.crate_name, ) }, stream_policy_names { ( name = policy.crate_name ) } }); } } else { // Only simple function tmpl!(lib_rs_template += { exports_simple { ( crate_name = policy.crate_name, matched_func = policy.matched_func, has_await = if policy.matched_func_has_await { ".await" } else { "" } ) }, match_arms { ( crate_name = policy.crate_name, ) }, policy_names { ( name = policy.crate_name, ) } }); } } let (write_cargo, write_lib) = tokio::join!( fs::write(CARGO_TOML_PATH, cargo_toml_template.expand().unwrap()), fs::write(LIB_RS_PATH, lib_rs_template.expand().unwrap()) ); write_cargo.unwrap(); write_lib.unwrap(); } struct CollectedPolicy { crate_name: String, path: String, matched_func: String, matched_func_has_await: bool, matched_func_stream: Option, matched_func_stream_has_await: bool, stream_struct_id: Option, } type MatchedFuncInfo = (String, bool, Option, bool, Option); async fn collect(policy_crate_path: &Path) -> Option { let lib_rs_path = policy_crate_path.join("src").join("lib.rs"); let lib_rs_content = fs::read_to_string(&lib_rs_path).await.ok()?; let cargo_toml_content = fs::read_to_string(policy_crate_path.join("Cargo.toml")) .await .ok()?; let cargo_toml: toml::Value = toml::from_str(&cargo_toml_content).ok()?; let crate_name = cargo_toml .get("package")? .get("name")? .as_str()? .to_string(); let crate_path = fmt_path_str( policy_crate_path .strip_prefix(current_dir().unwrap()) .unwrap() .to_string_lossy(), ) .ok()?; let ( matched_func, matched_func_has_await, matched_func_stream, matched_func_stream_has_await, stream_struct_id, ) = collect_matched_func(lib_rs_content.as_str())?; if !matched_func.is_empty() { println!( "{} {} (at: `{}`) with func `{}{}{}{}(..)`", "Register:".bright_blue().bold(), crate_name, crate_path.bright_green(), "pub ".bright_magenta(), if matched_func_has_await { "async " } else { "" }.bright_magenta(), "fn ".bright_magenta(), matched_func.bright_blue(), ); } if let Some(stream_func) = &matched_func_stream { println!( "{} {} (at: `{}`) with stream func `{}{}{}{}(..)`", "Register:".bright_blue().bold(), crate_name, crate_path.bright_green(), "pub ".bright_magenta(), if matched_func_stream_has_await { "async " } else { "" } .bright_magenta(), "fn ".bright_magenta(), stream_func.bright_blue() ); } Some(CollectedPolicy { crate_name, path: crate_path, matched_func, matched_func_has_await, matched_func_stream, matched_func_stream_has_await, stream_struct_id, }) } fn collect_matched_func(lib_rs_content: &str) -> Option { let syntax_tree = syn::parse_file(lib_rs_content).ok()?; let mut matched_func = None; let mut matched_func_has_await = false; let mut matched_func_stream = None; let mut matched_func_stream_has_await = false; let mut stream_struct_id = None; // Iterate over all items, looking for functions that match the criteria for item in &syntax_tree.items { let syn::Item::Fn(func) = item else { continue }; // Check if the function visibility is pub if !matches!(func.vis, syn::Visibility::Public(_)) { continue; } let sig = &func.sig; // Check for simple chunk function (returns Vec) if check_simple_chunk_function(sig) { matched_func = Some(sig.ident.to_string()); matched_func_has_await = sig.asyncness.is_some(); } // Check for stream chunk function (returns Option) else if let Some(struct_id) = check_stream_chunk_function(sig, &syntax_tree) { matched_func_stream = Some(sig.ident.to_string()); matched_func_stream_has_await = sig.asyncness.is_some(); stream_struct_id = Some(struct_id); } } if matched_func.is_some() || matched_func_stream.is_some() { Some(( matched_func.unwrap_or_default(), matched_func_has_await, matched_func_stream, matched_func_stream_has_await, stream_struct_id, )) } else { None } } fn check_simple_chunk_function(sig: &syn::Signature) -> bool { // Check if the return type is Vec let return_type_matches = match &sig.output { syn::ReturnType::Type(_, ty) => { let syn::Type::Path(type_path) = &**ty else { return false; }; let segments = &type_path.path.segments; segments.len() == 1 && segments[0].ident == "Vec" && matches!(&segments[0].arguments, syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 && matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Path(inner_type)) if inner_type.path.segments.len() == 1 && inner_type.path.segments[0].ident == "u32" ) ) } _ => false, }; if !return_type_matches { return false; } // Check that there are exactly 2 parameters if sig.inputs.len() != 2 { return false; } // Check that the first parameter type is &[u8] let first_param_matches = match &sig.inputs[0] { syn::FnArg::Typed(pat_type) => { let syn::Type::Reference(type_ref) = &*pat_type.ty else { return false; }; let syn::Type::Slice(slice_type) = &*type_ref.elem else { return false; }; let syn::Type::Path(type_path) = &*slice_type.elem else { return false; }; type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "u8" } _ => false, }; // Check that the second parameter type is &HashMap<&str, &str> let second_param_matches = match &sig.inputs[1] { syn::FnArg::Typed(pat_type) => { let syn::Type::Reference(type_ref) = &*pat_type.ty else { return false; }; let syn::Type::Path(type_path) = &*type_ref.elem else { return false; }; type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "HashMap" && matches!(&type_path.path.segments[0].arguments, syn::PathArguments::AngleBracketed(args) if args.args.len() == 2 && matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Reference(first_ref)) if matches!(&*first_ref.elem, syn::Type::Path(first_path) if first_path.path.segments.len() == 1 && first_path.path.segments[0].ident == "str" ) ) && matches!(&args.args[1], syn::GenericArgument::Type(syn::Type::Reference(second_ref)) if matches!(&*second_ref.elem, syn::Type::Path(second_path) if second_path.path.segments.len() == 1 && second_path.path.segments[0].ident == "str" ) ) ) } _ => false, }; first_param_matches && second_param_matches } fn check_stream_chunk_function(sig: &syn::Signature, syntax_tree: &syn::File) -> Option { // Check if the return type is Option let return_type_matches = match &sig.output { syn::ReturnType::Type(_, ty) => { let syn::Type::Path(type_path) = &**ty else { return None; }; let segments = &type_path.path.segments; segments.len() == 1 && segments[0].ident == "Option" && matches!(&segments[0].arguments, syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 && matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Path(inner_type)) if inner_type.path.segments.len() == 1 && inner_type.path.segments[0].ident == "u32" ) ) } _ => false, }; if !return_type_matches { return None; } // Check that there are exactly 4 parameters if sig.inputs.len() != 4 { return None; } // Check that the first parameter type is &[u8] let first_param_matches = match &sig.inputs[0] { syn::FnArg::Typed(pat_type) => { let syn::Type::Reference(type_ref) = &*pat_type.ty else { return None; }; let syn::Type::Slice(slice_type) = &*type_ref.elem else { return None; }; let syn::Type::Path(type_path) = &*slice_type.elem else { return None; }; // Check it's u8 type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "u8" } _ => false, }; // Check that the second parameter type is u32 let second_param_matches = match &sig.inputs[1] { syn::FnArg::Typed(pat_type) => { let syn::Type::Path(type_path) = &*pat_type.ty else { return None; }; type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "u32" } _ => false, }; // Check that the third parameter type is &mut T where T is a struct defined in this crate let third_param_info = match &sig.inputs[2] { syn::FnArg::Typed(pat_type) => { let syn::Type::Reference(type_ref) = &*pat_type.ty else { return None; }; // Check it's mutable reference type_ref.mutability?; // Get the inner type let syn::Type::Path(type_path) = &*type_ref.elem else { return None; }; // Get the struct identifier if type_path.path.segments.len() != 1 { return None; } let struct_ident = type_path.path.segments[0].ident.to_string(); // Check if this struct is defined in the current crate and implements Default if is_struct_defined_in_crate(&struct_ident, syntax_tree) { Some(struct_ident) } else { None } } _ => None, }; let struct_ident = third_param_info?; // Check that the fourth parameter type is &HashMap<&str, &str> let fourth_param_matches = match &sig.inputs[3] { syn::FnArg::Typed(pat_type) => { let syn::Type::Reference(type_ref) = &*pat_type.ty else { return None; }; let syn::Type::Path(type_path) = &*type_ref.elem else { return None; }; type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "HashMap" && matches!(&type_path.path.segments[0].arguments, syn::PathArguments::AngleBracketed(args) if args.args.len() == 2 && matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Reference(first_ref)) if matches!(&*first_ref.elem, syn::Type::Path(first_path) if first_path.path.segments.len() == 1 && first_path.path.segments[0].ident == "str" ) ) && matches!(&args.args[1], syn::GenericArgument::Type(syn::Type::Reference(second_ref)) if matches!(&*second_ref.elem, syn::Type::Path(second_path) if second_path.path.segments.len() == 1 && second_path.path.segments[0].ident == "str" ) ) ) } _ => false, }; if first_param_matches && second_param_matches && fourth_param_matches { Some(struct_ident) } else { None } } fn is_struct_defined_in_crate(struct_ident: &str, syntax_tree: &syn::File) -> bool { for item in &syntax_tree.items { match item { syn::Item::Struct(item_struct) => { if item_struct.ident == struct_ident { // Check if it implements Default via derive attribute return has_default_derive(&item_struct.attrs) || has_default_trait_bound(&item_struct.generics); } } _ => continue, } } false } fn has_default_derive(attrs: &[syn::Attribute]) -> bool { for attr in attrs { if attr.path().is_ident("derive") { // Parse the attribute meta to check for Default if let syn::Meta::List(list) = &attr.meta { // Convert tokens to string and check for Default let tokens = list.tokens.to_string(); if tokens.contains("Default") { return true; } } } } false } fn has_default_trait_bound(generics: &syn::Generics) -> bool { for param in &generics.params { if let syn::GenericParam::Type(type_param) = param { for bound in &type_param.bounds { if let syn::TypeParamBound::Trait(trait_bound) = bound { let path = &trait_bound.path; if path.segments.len() == 1 && path.segments[0].ident == "Default" { return true; } } } } } false } async fn find_cargo_toml_dirs(root: &Path) -> Vec { let mut result = Vec::new(); let mut dirs_to_visit = vec![root.to_path_buf()]; while let Some(current_dir) = dirs_to_visit.pop() { let cargo_toml_path = current_dir.join("Cargo.toml"); if fs::metadata(&cargo_toml_path).await.is_ok() { result.push(current_dir); continue; } let mut read_dir = match fs::read_dir(¤t_dir).await { Ok(rd) => rd, Err(_) => continue, }; while let Ok(Some(entry)) = read_dir.next_entry().await { if let Ok(file_type) = entry.file_type().await && file_type.is_dir() { let path = entry.path(); if let Some(file_name) = path.file_name() && let Some(name_str) = file_name.to_str() && name_str.starts_with('_') { continue; } dirs_to_visit.push(path); } } } result } async fn precheck(current_dir: &Path) { let cargo_toml_path = current_dir.join("Cargo.toml"); let cargo_toml_content = fs::read_to_string(&cargo_toml_path) .await .unwrap_or_else(|_| { eprintln!( "{}", "Error: Cargo.toml not found in current directory".red() ); std::process::exit(1); }); let cargo_toml: toml::Value = toml::from_str(&cargo_toml_content).unwrap_or_else(|_| { eprintln!("{}", "Error: Failed to parse Cargo.toml".red()); std::process::exit(1); }); let package_name = cargo_toml .get("package") .unwrap_or_else(|| { eprintln!("{}", "Error: No package section in Cargo.toml".red()); std::process::exit(1); }) .get("name") .unwrap_or_else(|| { eprintln!("{}", "Error: No package.name in Cargo.toml".red()); std::process::exit(1); }) .as_str() .unwrap_or_else(|| { eprintln!("{}", "Error: package.name is not a string".red()); std::process::exit(1); }); if package_name != "butchunker" { eprintln!( "{}", format!( "Error: package.name must be 'butchunker', found '{}'", package_name ) .red() ); std::process::exit(1); } }