Skip to content

Commit

Permalink
support non-authoritative generic sources
Browse files Browse the repository at this point in the history
  • Loading branch information
dumbPy committed Apr 21, 2024
1 parent d85406c commit 3a25d01
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions beancount_import/source/generic_importer_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
class ImporterSource(DescriptionBasedSource):
def __init__(self,
directory: str,
account: str,
importer: ImporterProtocol,
account: Optional[str]=None, # use None for importers that are not authoritative and would not clear any postings
**kwargs) -> None:
super().__init__(**kwargs)
self.directory = os.path.expanduser(directory)
Expand All @@ -57,11 +57,16 @@ def name(self) -> str:
return self.importer.name()

def prepare(self, journal: 'JournalEditor', results: SourceResults) -> None:
results.add_account(self.account)
if self.account:
results.add_account(self.account)

entries = OrderedDict() #type: Dict[Hashable, List[Directive]]
for f in self.files:
f_entries = self.importer.extract(f, existing_entries=journal.entries)
# if the importer is not authoritative, add all entries to pending
if not self.account:
results.add_pending_entries(map(self._make_import_result, f_entries))
continue
# collect all entries in current statement, grouped by hash
hashed_entries = OrderedDict() #type: Dict[Hashable, Directive]
for entry in f_entries:
Expand All @@ -77,14 +82,15 @@ def prepare(self, journal: 'JournalEditor', results: SourceResults) -> None:
n = len(entries[key_])
entries.setdefault(key_, []).extend(hashed_entries[key_][n:])

get_pending_and_invalid_entries(
raw_entries=list(itertools.chain.from_iterable(entries.values())),
journal_entries=journal.all_entries,
account_set=set([self.account]),
get_key_from_posting=_get_key_from_posting,
get_key_from_raw_entry=self._get_key_from_imported_entry,
make_import_result=self._make_import_result,
results=results)
if self.account:
get_pending_and_invalid_entries(
raw_entries=list(itertools.chain.from_iterable(entries.values())),
journal_entries=journal.all_entries,
account_set=set([self.account]),
get_key_from_posting=_get_key_from_posting,
get_key_from_raw_entry=self._get_key_from_imported_entry,
make_import_result=self._make_import_result,
results=results)

def _add_description(self, entry: Transaction):
if not isinstance(entry, Transaction): return None
Expand Down

0 comments on commit 3a25d01

Please sign in to comment.