Skip to content

Commit

Permalink
Add option to create and/or write xns11 files
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-kipawa committed Feb 19, 2024
1 parent 5abf9d4 commit 8ea3ed5
Showing 1 changed file with 43 additions and 13 deletions.
56 changes: 43 additions & 13 deletions mikeio1d/xns11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import functools

import os.path
from pathlib import Path

import pandas as pd

from .cross_sections import CrossSection
from .cross_sections import CrossSectionCollection

from DHI.Mike1D.CrossSectionModule import CrossSectionDataFactory
from DHI.Mike1D.CrossSectionModule import CrossSectionData
from DHI.Mike1D.Generic import Connection, Diagnostics, Location


Expand Down Expand Up @@ -79,16 +82,20 @@ def wrapper(self, *args, **kwargs):


class Xns11:
def __init__(self, file_path=None):
self.file_path = file_path
self.file = None
self._closed = True
def __init__(self, file_path: str | Path = None):
self.file_path: str | Path = file_path
self._cross_section_data_factory = CrossSectionDataFactory()
self._cross_section_data = None

self._reach_names = None
self.__reaches = None
self._topoid_names = None
self.__topoids = None

# Load the file on initialization
self._load_file()
self._init_cross_section_data()
self._closed = True

self.xsections = CrossSectionCollection()
self._init_xsections()

Expand All @@ -99,15 +106,15 @@ def _load_file(self):
"""Load the file."""
if not os.path.exists(self.file_path):
raise FileExistsError(f"File {self.file_path} does not exist.")
self.file = CrossSectionDataFactory().Open(
self._cross_section_data = self._cross_section_data_factory.Open(
Connection.Create(self.file_path), Diagnostics("Error loading file.")
)
self._closed = False

def _get_info(self) -> str:
info = []
if self.file_path:
info.append(f"# Cross sections: {str(self.file.Count)}")
info.append(f"# Cross sections: {str(self._cross_section_data.Count)}")
info.append(f"Interpolation type: {str(self.interpolation_type)}")

info = str.join("\n", info)
Expand All @@ -119,19 +126,42 @@ def __enter__(self):
def __exit__(self, *excinfo):
self.close()

def _init_cross_section_data(self):
"""Initialize the CrossSectionData object."""
if self.file_path and os.path.exists(self.file_path):
return self._load_file()
self._cross_section_data = CrossSectionData()

def _init_xsections(self):
"""Initialize the cross sections."""
for xs in self.file:
for xs in self._cross_section_data:
self.xsections.add_xsection(CrossSection(xs))

def info(self):
"""Prints information about the result file."""
info = self._get_info()
print(info)

def write(self, file_path: str | Path = None):
"""Write data to the file."""
file_path = file_path if file_path else self.file_path

if not file_path:
raise ValueError("A file path must be provided.")

file_path = Path(file_path)
if not file_path.suffix == ".xns11":
raise ValueError("The file extension must be .xns11.")

current_con_path = Path(self._cross_section_data.Connection.FilePath.Path)
if not file_path.exists() or not current_con_path.resolve().samefile(file_path.resolve()):
self._cross_section_data.Connection = Connection.Create(str(file_path))

self._cross_section_data_factory.Save(self._cross_section_data)

def close(self):
"""Close the file handle."""
self.file.Finalize()
self._cross_section_data.Finalize()
self._closed = True

@property
Expand All @@ -151,13 +181,13 @@ def interpolation_type(self):
- Middling: 2
Interpolation happens during runtime by requesting values at neighbour cross sections and interpolate between those.
"""
return self.file.XSInterpolationType
return self._cross_section_data.XSInterpolationType

@property
def _topoids(self):
if self.__topoids:
return self.__topoids
return list(self.file.GetReachTopoIdEnumerable())
return list(self._cross_section_data.GetReachTopoIdEnumerable())

@property
@_not_closed
Expand All @@ -171,7 +201,7 @@ def topoid_names(self):
def _reaches(self):
if self.__reaches:
return self.__reaches
return list(self.file.GetReachTopoIdEnumerable())
return list(self._cross_section_data.GetReachTopoIdEnumerable())

@property
@_not_closed
Expand Down Expand Up @@ -202,7 +232,7 @@ def _get_values(self, points):
location = Location()
location.ID = reach.value
location.Chainage = chainage.value
geometry = self.file.FindClosestCrossSection(
geometry = self._cross_section_data.FindClosestCrossSection(
location, topoid.value
).BaseCrossSection.Points
x, z = [], []
Expand Down

0 comments on commit 8ea3ed5

Please sign in to comment.