1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
|
use std::collections::BTreeMap;
use crate::{
bill::{Bills, SplitResult, SplitResultItem},
error::BillSplitError,
who::Who,
};
/// Splits every bill in `item` and returns both the itemised debts and the
/// netted settlement between each pair of people.
///
/// Steps, in order:
/// 1. validate the input with `precheck`;
/// 2. compute per-bill shares, producing summed `(debtor, creditor)` amounts
///    and per-debtor itemised records;
/// 3. append `" (Total)"` to every itemised reason;
/// 4. net opposing debts so each pair ends up with at most one settlement.
///
/// # Errors
///
/// Returns whatever error `precheck` reports when the input fails validation;
/// no calculation is performed in that case.
pub fn calculate_from(item: Bills) -> Result<SplitResult, BillSplitError> {
    precheck(&item)?;

    let (direct_transactions, mut items) = calculate_balances_and_transactions(&item);
    // Independent of the netting below: it only touches the itemised records.
    add_total_reason(&mut items);
    let final_result = calculate_net_settlements(&direct_transactions);

    Ok(SplitResult { items, final_result })
}
/// Validates a [`Bills`] collection before any arithmetic is performed.
///
/// Bills are visited in the iteration order of `item.items.values()`. The
/// first problem found is returned; later bills are not inspected.
///
/// # Errors
///
/// * [`BillSplitError::NegativePaidAmount`] if a bill's `paid` is below zero.
/// * [`BillSplitError::DuplicateSplitMembers`] if a bill lists the same person
///   more than once in its `split` list.
fn precheck(item: &Bills) -> Result<(), BillSplitError> {
    use std::collections::HashSet;

    item.items.values().try_for_each(|bill| {
        // NOTE(review): `NaN < 0.0` is false, so a NaN `paid` is not rejected
        // here. Rejecting it properly would need a dedicated error variant.
        if bill.paid < 0.0 {
            return Err(BillSplitError::NegativePaidAmount);
        }

        // `insert` returns false when the member was already present.
        let mut members = HashSet::new();
        for member in &bill.split {
            if !members.insert(member) {
                return Err(BillSplitError::DuplicateSplitMembers);
            }
        }
        Ok(())
    })
}
/// Summed debts keyed by `(debtor, creditor)`.
type DirectTransactions = BTreeMap<(Who, Who), f64>;
/// Itemised debts keyed by the person who owes them.
type ItemsByWho = BTreeMap<Who, Vec<SplitResultItem>>;

/// Splits each bill evenly among its `split` members and records who owes whom.
///
/// For every bill, each member's share is `paid / split.len()` rounded to two
/// decimal places. Every member other than `who_paid` owes that share to
/// `who_paid`. The payer's own share is not recorded, because nobody owes it.
/// Bills with an empty `split` list are skipped.
///
/// Returns a pair:
/// * the `(debtor, creditor) -> amount` map, summed over all bills, and
/// * per-debtor lists of [`SplitResultItem`]s, one entry per bill they owe on.
///
/// Each share is rounded independently, so the shares of one bill need not
/// add up to exactly `paid`. For example, 10.00 split three ways gives 3.33
/// each.
fn calculate_balances_and_transactions(item: &Bills) -> (DirectTransactions, ItemsByWho) {
    let mut direct_transactions = DirectTransactions::new();
    let mut items = ItemsByWho::new();

    for bill_item in item.items.values() {
        // Check the integer count before converting, so the guard is not a
        // float comparison; this also avoids dividing by zero.
        let member_count = bill_item.split.len();
        if member_count == 0 {
            continue;
        }

        let who_paid = &bill_item.who_paid;
        // Round each share to cents.
        let share = (bill_item.paid / member_count as f64 * 100.0).round() / 100.0;

        for person in &bill_item.split {
            // The payer does not owe their own share to themselves.
            if person == who_paid {
                continue;
            }

            *direct_transactions
                .entry((person.clone(), who_paid.clone()))
                .or_insert(0.0) += share;

            items.entry(person.clone()).or_default().push(SplitResultItem {
                payee: who_paid.clone(),
                bill: share,
                reason: bill_item.reason.clone(),
            });
        }
    }

    (direct_transactions, items)
}
/// Nets opposing debts so that each pair of people has at most one settlement.
///
/// `direct_transactions` maps `(from, to)` to the amount `from` owes `to`.
///
/// For every unordered pair `{a, b}`:
/// * the amounts in both directions are combined, with a missing direction
///   counting as `0.0`;
/// * the absolute difference is rounded to two decimal places;
/// * a single entry is emitted, keyed in the direction of the positive balance.
///
/// Pairs whose rounded net is zero produce no entry. Rounding happens before
/// that check, so sub-cent remainders never yield a `0.0` settlement.
///
/// Generic over the party type `P`, so it accepts any ordered, cloneable key.
fn calculate_net_settlements<P: Ord + Clone>(
    direct_transactions: &BTreeMap<(P, P), f64>,
) -> BTreeMap<(P, P), f64> {
    let mut final_result = BTreeMap::new();

    for ((from, to), &amount) in direct_transactions {
        let reverse_key = (to.clone(), from.clone());
        let reverse_amount = direct_transactions.get(&reverse_key).copied();

        // When both directions exist, handle the pair once, from the entry
        // whose `from` sorts first; the mirrored entry is skipped.
        if reverse_amount.is_some() && from > to {
            continue;
        }

        let net_amount = amount - reverse_amount.unwrap_or(0.0);
        let rounded = (net_amount.abs() * 100.0).round() / 100.0;
        if rounded > 0.0 {
            // Positive net: `from` owes `to`. Negative net: `to` owes `from`.
            let key = if net_amount > 0.0 {
                (from.clone(), to.clone())
            } else {
                reverse_key
            };
            final_result.insert(key, rounded);
        }
    }

    final_result
}
/// Appends the literal suffix `" (Total)"` to the `reason` of every
/// [`SplitResultItem`] in `items`, modifying the records in place.
fn add_total_reason(items: &mut BTreeMap<Who, Vec<SplitResultItem>>) {
    items
        .values_mut()
        .flatten()
        .for_each(|entry| entry.reason.push_str(" (Total)"));
}
|