summaryrefslogtreecommitdiff
path: root/src/bin/butckrepo-refresh.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/butckrepo-refresh.rs')
-rw-r--r--src/bin/butckrepo-refresh.rs619
1 files changed, 619 insertions, 0 deletions
diff --git a/src/bin/butckrepo-refresh.rs b/src/bin/butckrepo-refresh.rs
new file mode 100644
index 0000000..9184efb
--- /dev/null
+++ b/src/bin/butckrepo-refresh.rs
@@ -0,0 +1,619 @@
+use colored::Colorize;
+use just_fmt::fmt_path::fmt_path_str;
+use just_template::{Template, tmpl, tmpl_param};
+use std::{
+ env::current_dir,
+ path::{Path, PathBuf},
+};
+use tokio::fs;
+
+const LIB_RS_TEMPLATE_PATH: &str = "policy/_policies/src/lib.rs.t";
+const CARGO_TOML_TEMPLATE_PATH: &str = "policy/_policies/Cargo.toml.t";
+const LIB_RS_PATH: &str = "./policy/_policies/src/lib.rs";
+const CARGO_TOML_PATH: &str = "./policy/_policies/Cargo.toml";
+
+#[tokio::main]
+async fn main() {
+ let current_dir = current_dir().unwrap();
+ precheck(&current_dir).await;
+
+ println!("Updating policies ...");
+ let (mut lib_rs_template, mut cargo_toml_template) = {
+ let lib_rs_template_path = current_dir.join("policy/_policies/src/lib.rs.t");
+ let cargo_toml_template_path = current_dir.join("policy/_policies/Cargo.toml.t");
+
+ let lib_rs_content = fs::read_to_string(&lib_rs_template_path)
+ .await
+ .unwrap_or_else(|_| {
+ eprintln!(
+ "{}",
+ format!(
+ "Error: Failed to read template file: {}",
+ lib_rs_template_path.display()
+ )
+ .red()
+ );
+ std::process::exit(1);
+ });
+
+ let cargo_toml_content = fs::read_to_string(&cargo_toml_template_path)
+ .await
+ .unwrap_or_else(|_| {
+ eprintln!(
+ "{}",
+ format!(
+ "Error: Failed to read template file: {}",
+ cargo_toml_template_path.display()
+ )
+ .red()
+ );
+ std::process::exit(1);
+ });
+
+ (
+ Template::from(lib_rs_content),
+ Template::from(cargo_toml_content),
+ )
+ };
+
+ let cargo_toml_pathes = find_cargo_toml_dirs(&current_dir.join("policy")).await;
+ println!(
+ "Found {} crates, register to `{}`",
+ cargo_toml_pathes.len(),
+ CARGO_TOML_PATH.bright_green()
+ );
+
+ tmpl_param!(lib_rs_template, policy_count = cargo_toml_pathes.len());
+
+ let collect_futures = cargo_toml_pathes.iter().map(collect).collect::<Vec<_>>();
+
+ for policy in futures::future::join_all(collect_futures).await {
+ let Some(policy) = policy else { continue };
+ tmpl!(cargo_toml_template += {
+ deps { (crate_name = policy.crate_name, path = policy.path) }
+ });
+ // Determine which export template to use based on detected functions
+ if policy.matched_func_stream.is_some() {
+ let stream_struct_id = format!(
+ "{}::{}",
+ policy.crate_name,
+ policy.stream_struct_id.unwrap()
+ );
+ if policy.matched_func.is_empty() {
+ // Only stream function
+ tmpl!(lib_rs_template += {
+ exports_stream { (
+ crate_name = policy.crate_name,
+ matched_func_stream = policy.matched_func_stream.unwrap(),
+ has_await_stream =
+ if policy.matched_func_stream_has_await { ".await" } else { "" },
+ stream_struct_id = stream_struct_id
+ ) },
+ match_arms { (
+ crate_name = policy.crate_name,
+ ) },
+ match_arms_stream { (
+ crate_name = policy.crate_name,
+ stream_struct_id = stream_struct_id
+ ) },
+ policy_names { (
+ name = policy.crate_name,
+ ) }
+ });
+ } else {
+ // Both simple and stream functions
+ tmpl!(lib_rs_template += {
+ exports_both { (
+ crate_name = policy.crate_name,
+ matched_func = policy.matched_func,
+ has_await =
+ if policy.matched_func_has_await { ".await" } else { "" },
+ matched_func_stream = policy.matched_func_stream.unwrap(),
+ has_await_stream =
+ if policy.matched_func_stream_has_await { ".await" } else { "" },
+ stream_struct_id = stream_struct_id
+ ) },
+ match_arms { (
+ crate_name = policy.crate_name,
+ ) },
+ match_arms_stream { (
+ crate_name = policy.crate_name,
+ stream_struct_id = stream_struct_id
+ ) },
+ policy_names { (
+ name = policy.crate_name,
+ ) }
+ });
+ }
+ } else {
+ // Only simple function
+ tmpl!(lib_rs_template += {
+ exports_simple { (
+ crate_name = policy.crate_name,
+ matched_func = policy.matched_func,
+ has_await =
+ if policy.matched_func_has_await { ".await" } else { "" }
+ ) },
+ match_arms { (
+ crate_name = policy.crate_name,
+ ) },
+ policy_names { (
+ name = policy.crate_name,
+ ) }
+ });
+ }
+ }
+
+ let (write_cargo, write_lib) = tokio::join!(
+ fs::write(CARGO_TOML_PATH, cargo_toml_template.expand().unwrap()),
+ fs::write(LIB_RS_PATH, lib_rs_template.expand().unwrap())
+ );
+ write_cargo.unwrap();
+ write_lib.unwrap();
+}
+
+struct CollectedPolicy {
+ crate_name: String,
+ path: String,
+ matched_func: String,
+ matched_func_has_await: bool,
+ matched_func_stream: Option<String>,
+ matched_func_stream_has_await: bool,
+ stream_struct_id: Option<String>,
+}
+
+async fn collect(policy_crate_path: &PathBuf) -> Option<CollectedPolicy> {
+ let lib_rs_path = policy_crate_path.join("src").join("lib.rs");
+ let lib_rs_content = fs::read_to_string(&lib_rs_path).await.ok()?;
+
+ let cargo_toml_content = fs::read_to_string(policy_crate_path.join("Cargo.toml"))
+ .await
+ .ok()?;
+ let cargo_toml: toml::Value = toml::from_str(&cargo_toml_content).ok()?;
+ let crate_name = cargo_toml
+ .get("package")?
+ .get("name")?
+ .as_str()?
+ .to_string();
+ let crate_path = fmt_path_str(
+ policy_crate_path
+ .strip_prefix(current_dir().unwrap())
+ .unwrap()
+ .to_string_lossy(),
+ )
+ .ok()?;
+
+ let (
+ matched_func,
+ matched_func_has_await,
+ matched_func_stream,
+ matched_func_stream_has_await,
+ stream_struct_id,
+ ) = collect_matched_func(lib_rs_content.as_str())?;
+
+ println!(
+ "{} {} (at: `{}`) with func `{}{}{}{}(..)`",
+ "Register:".bright_blue().bold(),
+ crate_name,
+ crate_path.bright_green(),
+ "pub ".bright_magenta(),
+ if matched_func_has_await { "async " } else { "" }.bright_magenta(),
+ "fn ".bright_magenta(),
+ matched_func.bright_blue(),
+ );
+ if let Some(stream_func) = &matched_func_stream {
+ println!(
+ " and stream func `{}{}{}{}(..)`",
+ "pub ".bright_magenta(),
+ if matched_func_stream_has_await {
+ "async "
+ } else {
+ ""
+ }
+ .bright_magenta(),
+ "fn ".bright_magenta(),
+ stream_func.bright_blue()
+ );
+ }
+
+ Some(CollectedPolicy {
+ crate_name,
+ path: crate_path,
+ matched_func,
+ matched_func_has_await,
+ matched_func_stream,
+ matched_func_stream_has_await,
+ stream_struct_id,
+ })
+}
+
+fn collect_matched_func(
+ lib_rs_content: &str,
+) -> Option<(String, bool, Option<String>, bool, Option<String>)> {
+ let syntax_tree = syn::parse_file(lib_rs_content).ok()?;
+
+ let mut matched_func = None;
+ let mut matched_func_has_await = false;
+ let mut matched_func_stream = None;
+ let mut matched_func_stream_has_await = false;
+ let mut stream_struct_id = None;
+
+ // Iterate over all items, looking for functions that match the criteria
+ for item in &syntax_tree.items {
+ let syn::Item::Fn(func) = item else { continue };
+
+ // Check if the function visibility is pub
+ if !matches!(func.vis, syn::Visibility::Public(_)) {
+ continue;
+ }
+
+ let sig = &func.sig;
+
+ // Check for simple chunk function (returns Vec<u32>)
+ if check_simple_chunk_function(sig) {
+ matched_func = Some(sig.ident.to_string());
+ matched_func_has_await = sig.asyncness.is_some();
+ }
+ // Check for stream chunk function (returns Option<u8>)
+ else if let Some(struct_id) = check_stream_chunk_function(sig, &syntax_tree) {
+ matched_func_stream = Some(sig.ident.to_string());
+ matched_func_stream_has_await = sig.asyncness.is_some();
+ stream_struct_id = Some(struct_id);
+ }
+ }
+
+ if matched_func.is_some() || matched_func_stream.is_some() {
+ Some((
+ matched_func.unwrap_or_default(),
+ matched_func_has_await,
+ matched_func_stream,
+ matched_func_stream_has_await,
+ stream_struct_id,
+ ))
+ } else {
+ None
+ }
+}
+
+fn check_simple_chunk_function(sig: &syn::Signature) -> bool {
+ // Check if the return type is Vec<u32>
+ let return_type_matches = match &sig.output {
+ syn::ReturnType::Type(_, ty) => {
+ let syn::Type::Path(type_path) = &**ty else {
+ return false;
+ };
+ let segments = &type_path.path.segments;
+
+ segments.len() == 1
+ && segments[0].ident == "Vec"
+ && matches!(&segments[0].arguments, syn::PathArguments::AngleBracketed(args)
+ if args.args.len() == 1 &&
+ matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Path(inner_type))
+ if inner_type.path.segments.len() == 1 &&
+ inner_type.path.segments[0].ident == "u32"
+ )
+ )
+ }
+ _ => false,
+ };
+
+ if !return_type_matches {
+ return false;
+ }
+
+ // Check that there are exactly 2 parameters
+ if sig.inputs.len() != 2 {
+ return false;
+ }
+
+ // Check that the first parameter type is &[u8]
+ let first_param_matches = match &sig.inputs[0] {
+ syn::FnArg::Typed(pat_type) => {
+ let syn::Type::Reference(type_ref) = &*pat_type.ty else {
+ return false;
+ };
+ let syn::Type::Slice(slice_type) = &*type_ref.elem else {
+ return false;
+ };
+ let syn::Type::Path(type_path) = &*slice_type.elem else {
+ return false;
+ };
+
+ type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "u8"
+ }
+ _ => false,
+ };
+
+ // Check that the second parameter type is &HashMap<&str, &str>
+ let second_param_matches = match &sig.inputs[1] {
+ syn::FnArg::Typed(pat_type) => {
+ let syn::Type::Reference(type_ref) = &*pat_type.ty else {
+ return false;
+ };
+ let syn::Type::Path(type_path) = &*type_ref.elem else {
+ return false;
+ };
+
+ type_path.path.segments.len() == 1
+ && type_path.path.segments[0].ident == "HashMap"
+ && matches!(&type_path.path.segments[0].arguments, syn::PathArguments::AngleBracketed(args)
+ if args.args.len() == 2 &&
+ matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Reference(first_ref))
+ if matches!(&*first_ref.elem, syn::Type::Path(first_path)
+ if first_path.path.segments.len() == 1 &&
+ first_path.path.segments[0].ident == "str"
+ )
+ ) &&
+ matches!(&args.args[1], syn::GenericArgument::Type(syn::Type::Reference(second_ref))
+ if matches!(&*second_ref.elem, syn::Type::Path(second_path)
+ if second_path.path.segments.len() == 1 &&
+ second_path.path.segments[0].ident == "str"
+ )
+ )
+ )
+ }
+ _ => false,
+ };
+
+ first_param_matches && second_param_matches
+}
+
+fn check_stream_chunk_function(sig: &syn::Signature, syntax_tree: &syn::File) -> Option<String> {
+ // Check if the return type is Option<u32>
+ let return_type_matches = match &sig.output {
+ syn::ReturnType::Type(_, ty) => {
+ let syn::Type::Path(type_path) = &**ty else {
+ return None;
+ };
+ let segments = &type_path.path.segments;
+
+ segments.len() == 1
+ && segments[0].ident == "Option"
+ && matches!(&segments[0].arguments, syn::PathArguments::AngleBracketed(args)
+ if args.args.len() == 1 &&
+ matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Path(inner_type))
+ if inner_type.path.segments.len() == 1 &&
+ inner_type.path.segments[0].ident == "u32"
+ )
+ )
+ }
+ _ => false,
+ };
+
+ if !return_type_matches {
+ return None;
+ }
+
+ // Check that there are exactly 4 parameters
+ if sig.inputs.len() != 4 {
+ return None;
+ }
+
+ // Check that the first parameter type is &[u8]
+ let first_param_matches = match &sig.inputs[0] {
+ syn::FnArg::Typed(pat_type) => {
+ let syn::Type::Reference(type_ref) = &*pat_type.ty else {
+ return None;
+ };
+ let syn::Type::Slice(slice_type) = &*type_ref.elem else {
+ return None;
+ };
+ let syn::Type::Path(type_path) = &*slice_type.elem else {
+ return None;
+ };
+
+ // Check it's u8
+ type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "u8"
+ }
+ _ => false,
+ };
+
+ // Check that the second parameter type is u32
+ let second_param_matches = match &sig.inputs[1] {
+ syn::FnArg::Typed(pat_type) => {
+ let syn::Type::Path(type_path) = &*pat_type.ty else {
+ return None;
+ };
+ type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "u32"
+ }
+ _ => false,
+ };
+
+ // Check that the third parameter type is &mut T where T is a struct defined in this crate
+ let third_param_info = match &sig.inputs[2] {
+ syn::FnArg::Typed(pat_type) => {
+ let syn::Type::Reference(type_ref) = &*pat_type.ty else {
+ return None;
+ };
+
+ // Check it's mutable reference
+ type_ref.mutability?;
+
+ // Get the inner type
+ let syn::Type::Path(type_path) = &*type_ref.elem else {
+ return None;
+ };
+
+ // Get the struct identifier
+ if type_path.path.segments.len() != 1 {
+ return None;
+ }
+
+ let struct_ident = type_path.path.segments[0].ident.to_string();
+
+ // Check if this struct is defined in the current crate and implements Default
+ if is_struct_defined_in_crate(&struct_ident, syntax_tree) {
+ Some(struct_ident)
+ } else {
+ None
+ }
+ }
+ _ => None,
+ };
+
+ let struct_ident = third_param_info?;
+
+ // Check that the fourth parameter type is &HashMap<&str, &str>
+ let fourth_param_matches = match &sig.inputs[3] {
+ syn::FnArg::Typed(pat_type) => {
+ let syn::Type::Reference(type_ref) = &*pat_type.ty else {
+ return None;
+ };
+ let syn::Type::Path(type_path) = &*type_ref.elem else {
+ return None;
+ };
+
+ type_path.path.segments.len() == 1
+ && type_path.path.segments[0].ident == "HashMap"
+ && matches!(&type_path.path.segments[0].arguments, syn::PathArguments::AngleBracketed(args)
+ if args.args.len() == 2 &&
+ matches!(&args.args[0], syn::GenericArgument::Type(syn::Type::Reference(first_ref))
+ if matches!(&*first_ref.elem, syn::Type::Path(first_path)
+ if first_path.path.segments.len() == 1 &&
+ first_path.path.segments[0].ident == "str"
+ )
+ ) &&
+ matches!(&args.args[1], syn::GenericArgument::Type(syn::Type::Reference(second_ref))
+ if matches!(&*second_ref.elem, syn::Type::Path(second_path)
+ if second_path.path.segments.len() == 1 &&
+ second_path.path.segments[0].ident == "str"
+ )
+ )
+ )
+ }
+ _ => false,
+ };
+
+ if first_param_matches && second_param_matches && fourth_param_matches {
+ Some(struct_ident)
+ } else {
+ None
+ }
+}
+
+fn is_struct_defined_in_crate(struct_ident: &str, syntax_tree: &syn::File) -> bool {
+ for item in &syntax_tree.items {
+ match item {
+ syn::Item::Struct(item_struct) => {
+ if item_struct.ident == struct_ident {
+ // Check if it implements Default via derive attribute
+ return has_default_derive(&item_struct.attrs)
+ || has_default_trait_bound(&item_struct.generics);
+ }
+ }
+ _ => continue,
+ }
+ }
+ false
+}
+
+fn has_default_derive(attrs: &[syn::Attribute]) -> bool {
+ for attr in attrs {
+ if attr.path().is_ident("derive") {
+ // Parse the attribute meta to check for Default
+ if let syn::Meta::List(list) = &attr.meta {
+ // Convert tokens to string and check for Default
+ let tokens = list.tokens.to_string();
+ if tokens.contains("Default") {
+ return true;
+ }
+ }
+ }
+ }
+ false
+}
+
+fn has_default_trait_bound(generics: &syn::Generics) -> bool {
+ for param in &generics.params {
+ if let syn::GenericParam::Type(type_param) = param {
+ for bound in &type_param.bounds {
+ if let syn::TypeParamBound::Trait(trait_bound) = bound {
+ let path = &trait_bound.path;
+ if path.segments.len() == 1 && path.segments[0].ident == "Default" {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ false
+}
+
+async fn find_cargo_toml_dirs(root: &Path) -> Vec<PathBuf> {
+ let mut result = Vec::new();
+ let mut dirs_to_visit = vec![root.to_path_buf()];
+
+ while let Some(current_dir) = dirs_to_visit.pop() {
+ let cargo_toml_path = current_dir.join("Cargo.toml");
+ if fs::metadata(&cargo_toml_path).await.is_ok() {
+ result.push(current_dir);
+ continue;
+ }
+
+ let mut read_dir = match fs::read_dir(&current_dir).await {
+ Ok(rd) => rd,
+ Err(_) => continue,
+ };
+
+ while let Ok(Some(entry)) = read_dir.next_entry().await {
+ if let Ok(file_type) = entry.file_type().await
+ && file_type.is_dir()
+ {
+ let path = entry.path();
+ if let Some(file_name) = path.file_name()
+ && let Some(name_str) = file_name.to_str()
+ && name_str.starts_with('_')
+ {
+ continue;
+ }
+ dirs_to_visit.push(path);
+ }
+ }
+ }
+
+ result
+}
+
+async fn precheck(current_dir: &Path) {
+ let cargo_toml_path = current_dir.join("Cargo.toml");
+ let cargo_toml_content = fs::read_to_string(&cargo_toml_path)
+ .await
+ .unwrap_or_else(|_| {
+ eprintln!(
+ "{}",
+ "Error: Cargo.toml not found in current directory".red()
+ );
+ std::process::exit(1);
+ });
+ let cargo_toml: toml::Value = toml::from_str(&cargo_toml_content).unwrap_or_else(|_| {
+ eprintln!("{}", "Error: Failed to parse Cargo.toml".red());
+ std::process::exit(1);
+ });
+ let package_name = cargo_toml
+ .get("package")
+ .unwrap_or_else(|| {
+ eprintln!("{}", "Error: No package section in Cargo.toml".red());
+ std::process::exit(1);
+ })
+ .get("name")
+ .unwrap_or_else(|| {
+ eprintln!("{}", "Error: No package.name in Cargo.toml".red());
+ std::process::exit(1);
+ })
+ .as_str()
+ .unwrap_or_else(|| {
+ eprintln!("{}", "Error: package.name is not a string".red());
+ std::process::exit(1);
+ });
+ if package_name != "butchunker" {
+ eprintln!(
+ "{}",
+ format!(
+ "Error: package.name must be 'butchunker', found '{}'",
+ package_name
+ )
+ .red()
+ );
+ std::process::exit(1);
+ }
+}