Skip to content

Commit

Permalink
Add test to get_known_periods
Browse files Browse the repository at this point in the history
  • Loading branch information
Mauko Quiroga committed Feb 8, 2020
1 parent c475f2a commit 9fa7172
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
65 changes: 37 additions & 28 deletions openfisca_core/data_storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import os
import shutil
from typing import Optional
from typing import List, Optional

import numpy

Expand All @@ -25,6 +25,26 @@ def put(self, value: numpy.ndarray, period: periods.Period) -> None:
def delete(self, period: Optional[periods.Period] = None) -> None:
...

def get_known_periods(self) -> List[periods.Period]:
...

def get_memory_usage(self):
if not self._arrays:
return {
"nb_arrays": 0,
"total_nb_bytes": 0,
"cell_size": numpy.nan,
}

nb_arrays = len(self._arrays)
array = next(iter(self._arrays.values()))

return {
"nb_arrays": nb_arrays,
"total_nb_bytes": array.nbytes * nb_arrays,
"cell_size": array.itemsize,
}


class InMemoryStorage(StorageLike):
"""
Expand All @@ -41,18 +61,20 @@ def __init__(self, is_eternal: bool = False) -> None:
def get(self, period: periods.Period) -> numpy.ndarray:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)

period = periods.period(period)
values = self._arrays.get(period)

if values is None:
return None

return values

def put(self, value, period):
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)

period = periods.period(period)
self._arrays[period] = value

def delete(self, period: Optional[periods.Period] = None) -> None:
Expand All @@ -72,25 +94,8 @@ def delete(self, period: Optional[periods.Period] = None) -> None:
if not period.contains(period_item)
}

def get_known_periods(self):
return self._arrays.keys()

def get_memory_usage(self):
if not self._arrays:
return {
"nb_arrays": 0,
"total_nb_bytes": 0,
"cell_size": numpy.nan,
}

nb_arrays = len(self._arrays)
array = next(iter(self._arrays.values()))

return {
"nb_arrays": nb_arrays,
"total_nb_bytes": array.nbytes * nb_arrays,
"cell_size": array.itemsize,
}
def get_known_periods(self) -> List[periods.Period]:
return list(self._arrays.keys())


class OnDiskStorage(StorageLike):
Expand Down Expand Up @@ -126,30 +131,34 @@ def _decode_file(self, file):
def get(self, period: periods.Period) -> numpy.ndarray:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)

period = periods.period(period)
values = self._files.get(period)

if values is None:
return None

return self._decode_file(values)

def put(self, value: numpy.ndarray, period: periods.Period) -> None:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)

period = periods.period(period)
filename = str(period)
path = os.path.join(self.storage_dir, filename) + '.npy'

if isinstance(value, indexed_enums.EnumArray):
self._enums[path] = value.possible_values
value = value.view(numpy.ndarray)

numpy.save(path, value)
self._files[period] = path

def delete(self, period: Optional[periods.Period] = None) -> None:
if period is None:
self._files = {}
return
return None

if self.is_eternal:
period = periods.period(periods.ETERNITY)
Expand All @@ -163,12 +172,12 @@ def delete(self, period: Optional[periods.Period] = None) -> None:
if not period.contains(period_item)
}

def get_known_periods(self):
return self._files.keys()
def get_known_periods(self) -> List[periods.Period]:
return list(self._files.keys())

def restore(self):
self._files = files = {}
# Restore self._files from content of storage_dir.
# Restore self._arrays from content of storage_dir.
for filename in os.listdir(self.storage_dir):
if not filename.endswith('.npy'):
continue
Expand Down
9 changes: 9 additions & 0 deletions tests/core/data_storage/test_in_memory_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,12 @@ def test_delete_when_is_eternal(eternal_storage, value):
result = storage.get("qwerty"), storage.get("azerty")

assert result == (None, None)


def test_get_known_periods(storage, value, period):
storage = storage()
storage.put(value, period)

result = storage.get_known_periods()

assert result == [period]

0 comments on commit 9fa7172

Please sign in to comment.