From 30dc718170ece8beabc5a2c57944f82c69b7cbf6 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer Date: Sat, 21 Sep 2024 10:41:43 -0700 Subject: [PATCH] training.py: two tweaks to feature selection (#226) 1. Include posting amounts as a feature. This allows us to distinguish different classes of payments to the same payee (e.g. recurring membership fees, which often have a constant amount, from individual purchases). 2. For example key/value pairs, include the key by itself (with no substring of the value) as a feature. This is useful because different account types often have non-overlapping sets of example keys, and including the bare key as a feature allows the decision tree to be effectively segmented by account type fairly close to the root. These two very small changes significantly improve training accuracy on my journal, from 94.81% to 99.32% (an 86% reduction in error rate!). --- beancount_import/training.py | 5 +++-- beancount_import/training_test.py | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/beancount_import/training.py b/beancount_import/training.py index 30f74173..007233c4 100644 --- a/beancount_import/training.py +++ b/beancount_import/training.py @@ -30,10 +30,11 @@ def get_features(example: PredictionInput) -> Dict[str, bool]: features = collections.defaultdict(lambda: False) # type: Dict[str, bool] features['account:%s' % example.source_account] = True - - # For now, skip amount and date. + features['amount:%s' % example.amount.currency] = example.amount.number + # For now, skip date. 
for key, values in example.key_value_pairs.items(): + features[key] = True if isinstance(values, str): values = (values, ) for value in values: diff --git a/beancount_import/training_test.py b/beancount_import/training_test.py index aaabf27e..2ae65eb9 100644 --- a/beancount_import/training_test.py +++ b/beancount_import/training_test.py @@ -1,6 +1,7 @@ import datetime from beancount.core.data import Amount +from beancount.core.number import D from . import test_util from . import training @@ -21,7 +22,10 @@ def test_get_features(): 'a:hello': True, 'b:foo': True, 'b:bar': True, - 'b:foo bar': True + 'b:foo bar': True, + 'a': True, + 'b': True, + 'amount:USD': D(3) }