summaryrefslogtreecommitdiff
path: root/rola-utils/functions/src/levenshtein_distance.rs
blob: 7b8be6c86be93c1053992b9663b4cc9aa25b3e36 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/// Namespace type for Levenshtein (edit-distance) string utilities.
///
/// The type carries no data; all functionality is exposed through associated
/// functions such as `LevenshteinDistance::compare` and
/// `LevenshteinDistance::filter_similar`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LevenshteinDistance;

impl LevenshteinDistance {
    /// Returns every entry of `slice` whose edit distance to `str` is at most
    /// `threshold` (the bound is inclusive).
    ///
    /// Matches are returned in the order they appear in `slice` and borrow
    /// from the same data as the input entries (lifetime `'a`).
    pub fn filter_similar<'a>(str: &str, slice: &[&'a str], threshold: usize) -> Vec<&'a str> {
        let mut close_enough = Vec::new();
        for &candidate in slice {
            if levenshtein_distance(str, candidate) <= threshold {
                close_enough.push(candidate);
            }
        }
        close_enough
    }

    /// Computes the edit distance between `a` and `b`.
    ///
    /// Both arguments may be any type that can be viewed as a `&str`.
    pub fn compare(a: impl AsRef<str>, b: impl AsRef<str>) -> usize {
        levenshtein_distance(a.as_ref(), b.as_ref())
    }
}

/// Computes the Levenshtein distance between `a` and `b`: the minimum number
/// of single-character insertions, deletions and substitutions that turn one
/// string into the other.
///
/// Strings are compared `char` by `char` (not byte by byte), so a multi-byte
/// character counts as a single edit. Only two rows of the dynamic-programming
/// table are kept at a time, each `b.chars().count() + 1` entries long.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // Against an empty string the distance is simply the other string's length.
    if source.is_empty() {
        return target.len();
    }
    if target.is_empty() {
        return source.len();
    }

    let width = target.len();
    // `above` is the completed row for the previous source character; it
    // starts as the cost of building each prefix of `target` from nothing.
    let mut above: Vec<usize> = (0..=width).collect();
    // `row` is the row currently being filled in.
    let mut row: Vec<usize> = vec![0; width + 1];

    for (row_idx, &src_ch) in source.iter().enumerate() {
        // Column 0: delete every source character seen so far.
        row[0] = row_idx + 1;
        for col in 1..=width {
            let substitution = above[col - 1] + usize::from(src_ch != target[col - 1]);
            let deletion = above[col] + 1;
            let insertion = row[col - 1] + 1;
            row[col] = substitution.min(deletion).min(insertion);
        }
        // The row just filled becomes the reference row for the next pass.
        std::mem::swap(&mut above, &mut row);
    }

    // After the final swap the last computed row lives in `above`.
    above[width]
}

/// Unit tests for the public `LevenshteinDistance` API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_levenshtein_distance_identical_strings() {
        assert_eq!(LevenshteinDistance::compare("hello", "hello"), 0);
    }

    #[test]
    fn test_levenshtein_distance_completely_different() {
        // No characters in common: every position needs a substitution.
        let dist = LevenshteinDistance::compare("abc", "xyz");
        assert_eq!(dist, 3);
    }

    #[test]
    fn test_levenshtein_distance_one_insertion() {
        assert_eq!(LevenshteinDistance::compare("cat", "cats"), 1);
    }

    #[test]
    fn test_levenshtein_distance_one_deletion() {
        assert_eq!(LevenshteinDistance::compare("dogs", "dog"), 1);
    }

    #[test]
    fn test_levenshtein_distance_one_substitution() {
        assert_eq!(LevenshteinDistance::compare("cat", "cut"), 1);
    }

    #[test]
    fn test_levenshtein_distance_empty_strings() {
        assert_eq!(LevenshteinDistance::compare("", ""), 0);
    }

    #[test]
    fn test_levenshtein_distance_empty_vs_nonempty() {
        assert_eq!(LevenshteinDistance::compare("", "abc"), 3);
        assert_eq!(LevenshteinDistance::compare("xyz", ""), 3);
    }

    #[test]
    fn test_levenshtein_distance_unicode() {
        // Chinese characters: each one is a single `char`, so each counts as
        // one edit even though it spans several bytes.
        assert_eq!(LevenshteinDistance::compare("你好", "你好"), 0);
        assert_eq!(LevenshteinDistance::compare("你好", "您好"), 1);
        assert_eq!(LevenshteinDistance::compare("你好", "您不好"), 2);
    }

    #[test]
    fn test_filter_similar_exact_threshold() {
        // Distances from "cat":
        // cat 0, cart 1, ca 1, cats 1, cut 1, bar 2, dog 3.
        let words = ["cat", "cart", "ca", "cats", "cut", "bar", "dog"];
        let result = LevenshteinDistance::filter_similar("cat", &words, 2);
        // "bar" sits exactly on the threshold and must be kept (the bound is
        // inclusive); "dog" is one edit past it and must be dropped. Input
        // order is preserved.
        assert_eq!(result, ["cat", "cart", "ca", "cats", "cut", "bar"]);
    }

    #[test]
    fn test_filter_similar_empty_slice() {
        let words: [&str; 0] = [];
        let result = LevenshteinDistance::filter_similar("hello", &words, 3);
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_similar_threshold_zero() {
        // With a threshold of zero only exact matches survive.
        let words = ["hello", "hallo", "hello!"];
        let result = LevenshteinDistance::filter_similar("hello", &words, 0);
        assert_eq!(result, ["hello"]);
    }
}