Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save fields as 32 bit floats #6038

Merged
merged 1 commit into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/ert/config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _fetch_from_ensemble(

def _transform_data(
self, data_array: xr.DataArray
) -> np.ma.MaskedArray[Any, np.dtype[np.double]]:
) -> np.ma.MaskedArray[Any, np.dtype[np.float32]]:
return np.ma.MaskedArray( # type: ignore
_field_truncate(
field_transform(
Expand Down Expand Up @@ -228,20 +228,20 @@ def mask(self) -> Any:
@overload
def field_transform(
data: xr.DataArray, transform_name: Optional[str]
) -> Union[npt.NDArray[np.double], xr.DataArray]:
) -> Union[npt.NDArray[np.float32], xr.DataArray]:
pass


@overload
def field_transform(
data: npt.NDArray[np.double], transform_name: Optional[str]
) -> npt.NDArray[np.double]:
data: npt.NDArray[np.float32], transform_name: Optional[str]
) -> npt.NDArray[np.float32]:
pass


def field_transform(
data: Union[xr.DataArray, npt.NDArray[np.double]], transform_name: Optional[str]
) -> Union[npt.NDArray[np.double], xr.DataArray]:
data: Union[xr.DataArray, npt.NDArray[np.float32]], transform_name: Optional[str]
) -> Union[npt.NDArray[np.float32], xr.DataArray]:
if transform_name is None:
return data
return TRANSFORM_FUNCTIONS[transform_name](data) # type: ignore
Expand Down
10 changes: 5 additions & 5 deletions src/ert/field_utils/field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def read_field(
field_name: str,
mask: npt.NDArray[np.bool_],
shape: Shape,
) -> np.ma.MaskedArray[Any, np.dtype[np.double]]:
) -> np.ma.MaskedArray[Any, np.dtype[np.float32]]:
path = Path(field_path)
file_extension = path.suffix[1:].upper()
try:
Expand All @@ -115,22 +115,22 @@ def read_field(
raise ValueError(
f'Could not read {field_path}. Unrecognized suffix "{file_extension}"'
) from err
ext = path.suffix
values: Union[npt.NDArray[np.double], np.ma.MaskedArray[Any, np.dtype[np.double]]]
values: Union[npt.NDArray[np.float32], np.ma.MaskedArray[Any, np.dtype[np.float32]]]
if file_format in ROFF_FORMATS:
values = import_roff(field_path, field_name)
elif file_format == FieldFileFormat.GRDECL:
values = import_grdecl(path, field_name, shape, dtype=np.double)
values = import_grdecl(path, field_name, shape, dtype=np.float32)
elif file_format == FieldFileFormat.BGRDECL:
values = import_bgrdecl(field_path, field_name, shape)
else:
ext = path.suffix
raise ValueError(f'Could not read {field_path}. Unrecognized suffix "{ext}"')

return np.ma.MaskedArray(data=values, mask=mask, fill_value=np.nan) # type: ignore


def save_field(
field: np.ma.MaskedArray[Any, np.dtype[np.double]],
field: np.ma.MaskedArray[Any, np.dtype[np.float32]],
field_name: str,
output_path: _PathLike,
file_format: FieldFileFormat,
Expand Down
12 changes: 7 additions & 5 deletions src/ert/field_utils/grdecl_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def import_grdecl(
filename: Union[str, os.PathLike[str]],
name: str,
dimensions: Tuple[int, int, int],
dtype: npt.DTypeLike = float,
) -> npt.NDArray[np.double]:
dtype: npt.DTypeLike = np.float32,
) -> npt.NDArray[np.float32]:
"""
Read a field from a grdecl file, see open_grdecl for description
of format.
Expand Down Expand Up @@ -193,7 +193,7 @@ def import_bgrdecl(
file_path: Union[str, os.PathLike[str]],
field_name: str,
dimensions: Tuple[int, int, int],
) -> npt.NDArray[np.double]:
) -> npt.NDArray[np.float32]:
field_name = field_name.strip()
with open(file_path, "rb") as f:
for entry in ecl_data_io.lazy_read(f):
Expand All @@ -211,15 +211,17 @@ def import_bgrdecl(
f"Attempted to import integer typed field {field_name}"
f" in {file_path}"
)
values = values.astype(np.float64)
values = values.astype(np.float32)
return values.reshape(dimensions, order="F")

raise ValueError(f"Did not find field parameter {field_name} in {file_path}")


# pylint: disable=too-many-arguments
def export_grdecl(
values: Union[np.ma.MaskedArray[Any, np.dtype[np.double]], npt.NDArray[np.double]],
values: Union[
np.ma.MaskedArray[Any, np.dtype[np.float32]], npt.NDArray[np.float32]
],
file_path: Union[str, os.PathLike[str]],
param_name: str,
binary: bool,
Expand Down
11 changes: 7 additions & 4 deletions src/ert/field_utils/roff_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
_PathLike = Union[str, PathLike[str]]


RMS_UNDEFINED_FLOAT = -999.0
RMS_UNDEFINED_FLOAT = np.float32(-999.0)


def export_roff(
data: np.ma.MaskedArray[Any, np.dtype[np.double]],
data: np.ma.MaskedArray[Any, np.dtype[np.float32]],
filelike: Union[TextIO, BinaryIO, _PathLike],
parameter_name: str,
binary: bool,
Expand Down Expand Up @@ -48,9 +48,9 @@ def export_roff(
roffio.write(filelike, file_contents, roff_format=roff_format)


def import_roff(
def import_roff( # pylint: disable=R0912
filelike: Union[TextIO, BinaryIO, _PathLike], name: Optional[str] = None
) -> np.ma.MaskedArray[Any, np.dtype[np.double]]:
) -> np.ma.MaskedArray[Any, np.dtype[np.float32]]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JHolba
Can I assume that the import code does the right thing and loads 32-bit floats, since that is what is stored in .roff files?

Copy link
Contributor

@JHolba JHolba Sep 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like that should return a 32-bit array as long as the file is 32-bit.
The typing is currently incorrect.
It will return a 64-bit array if the file is 64-bit, though.
@eivindjahren

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added an explicit conversion to 32-bit in case the data is 64-bit, as xtgeo can create 64-bit .roff files.

looking_for = {
"dimensions": {
"nX": None,
Expand Down Expand Up @@ -104,6 +104,9 @@ def should_skip_parameter(key: Tuple[str, str]) -> bool:
if np.issubdtype(data.dtype, np.integer):
raise ValueError("Ert does not support discrete roff field parameters")
if np.issubdtype(data.dtype, np.floating):
if data.dtype == np.float64:
# RMS can only work with 32 bit roff files
data = data.astype(np.float32)
dim = looking_for["dimensions"]
if dim["nX"] * dim["nY"] * dim["nZ"] != data.size:
raise ValueError(
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/storage/test_local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def test_that_egrid_files_are_saved_and_loaded_correctly(tmp_path):
ensemble_dir = tmp_path / "ensembles" / str(ensemble.id)
assert ensemble_dir.exists()

data = np.full_like(mask_values, np.nan, dtype=np.double)
np.place(data, mask_values, np.array([1.2, 1.1, 4.3, 3.1]))
data = np.full_like(mask_values, np.nan, dtype=np.float32)
np.place(data, mask_values, np.array([1.2, 1.1, 4.3, 3.1], dtype=np.float32))
da = xr.DataArray(
data.reshape((4, 5, 1)), name="values", dims=["x", "y", "z"] # type: ignore
)
Expand All @@ -43,8 +43,8 @@ def test_that_grid_files_are_saved_and_loaded_correctly(tmp_path):
grid = EclGridGenerator.create_rectangular((4, 5, 1), (1, 1, 1), actnum=mask)
grid.save_GRID(f"{experiment.mount_point}/grid.GRID")

data = np.full_like(mask, np.nan, dtype=np.double)
np.place(data, mask, np.array([1.2, 1.1, 4.3, 3.1]))
data = np.full_like(mask, np.nan, dtype=np.float32)
np.place(data, mask, np.array([1.2, 1.1, 4.3, 3.1], dtype=np.float32))
da = xr.DataArray(
data.reshape((4, 5, 1)), name="values", dims=["x", "y", "z"] # type: ignore
)
Expand Down
Loading