aboutsummaryrefslogtreecommitdiff
path: root/src/calc.rs
blob: d1209c3b6d0c7b7a1fea6a9e58cb7230741c54b2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use std::collections::BTreeMap;

use crate::{
    bill::{Bills, SplitResult, SplitResultItem},
    error::BillSplitError,
    who::Who,
};

/// Splits a set of bills and works out who owes whom.
///
/// The input is validated first, then every bill is broken into per-person
/// shares, and finally debts running in opposite directions between the same
/// two people are netted against each other.
///
/// The returned `SplitResult` carries both the per-person item lists (with
/// `" (Total)"` appended to every reason) and the netted settlements.
///
/// # Errors
///
/// Returns whatever `precheck` rejects: a bill with a negative paid amount,
/// or a bill whose split list names the same person more than once.
pub fn calculate_from(item: Bills) -> Result<SplitResult, BillSplitError> {
    // Bail out early on malformed input.
    precheck(&item)?;

    // Per-pair running totals plus the itemised record for each debtor.
    let (direct_transactions, mut items) = calculate_balances_and_transactions(&item);

    // Mark every itemised reason with the " (Total)" suffix.
    add_total_reason(&mut items);

    // Reduce the pairwise totals to one net settlement per pair.
    let final_result = calculate_net_settlements(&direct_transactions);

    Ok(SplitResult {
        items,
        final_result,
    })
}

/// Validates every bill before any splitting is attempted.
///
/// Bills are examined one at a time; within a bill the paid amount is checked
/// before the split list, and the first problem found is returned.
///
/// # Errors
///
/// * `BillSplitError::NegativePaidAmount` if a bill's `paid` is below zero.
/// * `BillSplitError::DuplicateSplitMembers` if a bill's split list contains
///   the same person more than once.
fn precheck(item: &Bills) -> Result<(), BillSplitError> {
    for entry in item.items.values() {
        if entry.paid < 0.0 {
            return Err(BillSplitError::NegativePaidAmount);
        }

        // `insert` returns false on a repeat, so `all` stops at the first duplicate.
        let mut members = std::collections::HashSet::new();
        let all_unique = entry.split.iter().all(|member| members.insert(member));
        if !all_unique {
            return Err(BillSplitError::DuplicateSplitMembers);
        }
    }

    Ok(())
}

/// Running total owed per ordered pair: the key is `(debtor, payer)` and the
/// value is the sum of the debtor's shares of bills that payer covered.
type DirectTransactions = BTreeMap<(Who, Who), f64>;
/// Itemised debts grouped by debtor: for each person, one `SplitResultItem`
/// per bill they owe a share of.
type ItemsByWho = BTreeMap<Who, Vec<SplitResultItem>>;

fn calculate_balances_and_transactions(item: &Bills) -> (DirectTransactions, ItemsByWho) {
    let mut direct_transactions: BTreeMap<(Who, Who), f64> = BTreeMap::new();
    let mut items: BTreeMap<Who, Vec<SplitResultItem>> = BTreeMap::new();

    for bill_item in item.items.values() {
        let who_paid = &bill_item.who_paid;
        let paid = bill_item.paid;
        let split_count = bill_item.split.len() as f64;

        if split_count == 0.0 {
            continue;
        }

        // Round
        let share = (paid / split_count * 100.0).round() / 100.0;

        // Calculate the amount each person should pay
        for person in &bill_item.split {
            // If the payer is also in the split list, deduct their own share
            if person != who_paid {
                // Record direct transaction
                let key = (person.clone(), who_paid.clone());
                *direct_transactions.entry(key).or_insert(0.0) += share;

                // Add to full record
                let bill_result_item = SplitResultItem {
                    payee: who_paid.clone(),
                    bill: share,
                    reason: bill_item.reason.clone(),
                };

                items
                    .entry(person.clone())
                    .or_default()
                    .push(bill_result_item);
            }
        }
    }

    (direct_transactions, items)
}

/// Collapses direct debts into at most one net settlement per pair of people.
///
/// `direct_transactions` maps `(debtor, creditor)` to the total the debtor
/// owes the creditor. When debts run in both directions between two people,
/// one is subtracted from the other and only the difference is reported,
/// keyed by whoever still owes money. A one-way entry is passed through as
/// is, provided its amount is positive.
///
/// Every reported amount is rounded to two decimal places, and a settlement
/// that rounds to zero is left out, so the result never contains a `0.0`
/// entry.
fn calculate_net_settlements(
    direct_transactions: &BTreeMap<(Who, Who), f64>,
) -> BTreeMap<(Who, Who), f64> {
    let round_to_cents = |value: f64| (value * 100.0).round() / 100.0;

    let mut final_result: BTreeMap<(Who, Who), f64> = BTreeMap::new();

    for ((from, to), amount) in direct_transactions {
        let reverse = direct_transactions.get(&(to.clone(), from.clone()));

        // When both directions exist, the map's ordering guarantees the entry
        // with the smaller key was visited first and already settled the
        // pair, so the larger-keyed twin is skipped here.
        if reverse.is_some() && from > to {
            continue;
        }

        let (debtor, creditor, owed) = match reverse {
            Some(back) => {
                let net = *amount - *back;
                if net >= 0.0 {
                    (from, to, net)
                } else {
                    // The reverse debt is larger, so the direction flips.
                    (to, from, -net)
                }
            }
            // One-way debt: a non-positive amount fails the check below.
            None => (from, to, *amount),
        };

        // Round first, then test: anything under half a cent would otherwise
        // be reported as a meaningless 0.00 settlement.
        let settled = round_to_cents(owed);
        if settled > 0.0 {
            final_result.insert((debtor.clone(), creditor.clone()), settled);
        }
    }

    final_result
}

/// Appends `" (Total)"` to the reason of every recorded item.
///
/// The suffix is added unconditionally to each item in each person's list;
/// no other field is touched and no items are added or removed.
fn add_total_reason(items: &mut BTreeMap<Who, Vec<SplitResultItem>>) {
    items
        .values_mut()
        .flatten()
        .for_each(|entry| entry.reason = format!("{} (Total)", entry.reason));
}