Skip to content

Commit

Permalink
Mypy type safety: round 8
Browse files Browse the repository at this point in the history
  • Loading branch information
TheBB committed Feb 15, 2024
1 parent 316e63d commit e3b5492
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 78 deletions.
43 changes: 24 additions & 19 deletions splipy/io/spl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from itertools import islice
from pathlib import Path
from typing import Union, TextIO, Type, Optional, Iterator
from types import TracebackType

import numpy as np
from typing_extensions import Self

from ..curve import Curve
from ..surface import Surface
Expand All @@ -13,21 +17,29 @@

class SPL(MasterIO):

def __init__(self, filename):
if not filename.endswith('.spl'):
filename += '.spl'
self.filename = filename
self.trimming_curves = []
filename: str
fstream: TextIO

def __enter__(self):
self.fstream = open(self.filename, 'r')
def __init__(self, filename: Union[Path, str]) -> None:
self.filename = str(filename)

def __enter__(self) -> Self:
self.fstream = open(self.filename, 'r').__enter__()
return self

def lines(self):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]
) -> None:
self.fstream.__exit__(exc_type, exc_val, exc_tb)

def lines(self) -> Iterator[str]:
for line in self.fstream:
yield line.split('#', maxsplit=1)[0].strip()

def read(self):
def read(self) -> list[SplineObject]:
lines = self.lines()

version = next(lines).split()
Expand All @@ -46,19 +58,12 @@ def read(self):
knots = [[float(k) for k in islice(lines, nkts)] for nkts in nknots]
bases = [BSplineBasis(p, kts, -1) for p, kts in zip(orders, knots)]

cpts = np.array([float(k) for k in islice(lines, totcoeffs * physdim)])
cpts = np.array([float(k) for k in islice(lines, totcoeffs * physdim)], dtype=float)
cpts = cpts.reshape(physdim, *(ncoeffs[::-1])).transpose()

if pardim == 1:
patch = Curve(*bases, controlpoints=cpts, raw=True)
elif pardim == 2:
patch = Surface(*bases, controlpoints=cpts, raw=True)
elif pardim == 3:
patch = Volume(*bases, controlpoints=cpts, raw=True)
if 1 <= pardim <= 3:
patch = SplineObject.constructor(pardim)(bases, cpts, raw=True)
else:
patch = SplineObject(bases, controlpoints=cpts, raw=True)

return [patch]

def __exit__(self, exc_type, exc_value, traceback):
self.fstream.close()
162 changes: 103 additions & 59 deletions splipy/io/stl.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
#coding:utf-8

import struct
from abc import ABC, abstractmethod
from typing import Union, Optional, Type, Sequence, TextIO, BinaryIO, cast
from types import TracebackType
from pathlib import Path

import numpy as np
from typing_extensions import Self

from ..surface import Surface
from ..volume import Volume
from ..utils import ensure_listlike
from ..splinemodel import SplineModel
from ..splineobject import SplineObject
from ..types import Scalars

from .master import MasterIO

Expand All @@ -25,28 +30,29 @@
BINARY_FACET = "12fH"


class ASCII_STL_Writer(object):
""" Export 3D objects build of 3 or 4 vertices as ASCII STL file.
"""
def __init__(self, stream):
self.fp = stream
self._write_header()
Face = Sequence[Scalars]

def _write_header(self):
self.fp.write("solid python\n")

def close(self):
self.fp.write("endsolid python\n")
class StlWriter(ABC):

def _write(self, face):
self.fp.write(ASCII_FACET.format(face=face))
@abstractmethod
def _write_header(self) -> None:
...

@abstractmethod
def _write(self, face: Face) -> None:
...

def _split(self, face):
@abstractmethod
def close(self) -> None:
...

def _split(self, face: Face) -> tuple[Face, Face]:
p1, p2, p3, p4 = face
return (p1, p2, p3), (p3, p4, p1)

def add_face(self, face):
""" Add one face with 3 or 4 vertices. """
def add_face(self, face: Face) -> None:
"""Add one face with 3 or 4 vertices."""
if len(face) == 4:
face1, face2 = self._split(face)
self._write(face1)
Expand All @@ -56,30 +62,50 @@ def add_face(self, face):
else:
raise ValueError('only 3 or 4 vertices for each face')

def add_faces(self, faces):
def add_faces(self, faces: Sequence[Face]) -> None:
""" Add many faces. """
for face in faces:
self.add_face(face)

class BINARY_STL_Writer(ASCII_STL_Writer):
""" Export 3D objects build of 3 or 4 vertices as binary STL file.
"""
def __init__(self, stream):
self.counter = 0
#### new-style classes way of calling super constructor
# super(Binary_STL_Writer, self).__init__(stream)

#### old-style classes way of doing it
ASCII_STL_Writer.__init__(self, stream)
class AsciiStlWriter(StlWriter):
"""Export 3D objects build of 3 or 4 vertices as ASCII STL file."""

def close(self):
fp: TextIO

def __init__(self, stream: TextIO) -> None:
self.fp = stream
self._write_header()

def _write_header(self):
def _write_header(self) -> None:
self.fp.write("solid python\n")

def close(self) -> None:
self.fp.write("endsolid python\n")

def _write(self, face: Face) -> None:
self.fp.write(ASCII_FACET.format(face=face))


class BinaryStlWriter(StlWriter):
"""Export 3D objects build of 3 or 4 vertices as binary STL file."""

counter: int
fp: BinaryIO

def __init__(self, stream: BinaryIO) -> None:
self.counter = 0
self.fp = stream
self._write_header()

def close(self) -> None:
self._write_header()

def _write_header(self) -> None:
self.fp.seek(0)
self.fp.write(struct.pack(BINARY_HEADER, b'Python Binary STL Writer', self.counter))

def _write(self, face):
def _write(self, face: Face) -> None:
self.counter += 1
data = [
0., 0., 0.,
Expand All @@ -92,42 +118,65 @@ def _write(self, face):


class STL(MasterIO):
def __init__(self, filename, binary=True):
if filename[-4:] != '.stl':
filename += '.stl'
self.filename = filename
self.binary = binary

def __enter__(self):
filename: str
binary: bool

writer: Union[BinaryStlWriter, AsciiStlWriter]

def __init__(self, filename: Union[str, Path], binary: bool = True) -> None:
self.filename = str(filename)
self.binary = binary

def __enter__(self) -> Self:
if self.binary:
fp = open(self.filename, 'wb')
self.writer = BINARY_STL_Writer(fp)
self.writer = BinaryStlWriter(open(self.filename, 'wb'))
else:
fp = open(self.filename, 'w')
self.writer = ASCII_STL_Writer(fp)
self.writer = AsciiStlWriter(open(self.filename, 'w'))
return self

def write(self, obj, n=None):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]
) -> None:
self.writer.close()
self.writer.fp.__exit__(exc_type, exc_val, exc_tb)

def write(
self,
obj: Union[SplineModel, Sequence[SplineObject], SplineObject],
n: Optional[Union[int, Sequence[int]]] = None,
) -> None:
if isinstance(obj, SplineModel):
if obj.pardim == 3: # volume model
for surface in obj.boundary():
self.write_surface(surface.obj,n)
for node in obj.boundary():
self.write_surface(cast(Surface, node.obj), n)
elif obj.pardim == 2: # surface model
for surface in obj:
self.write_surface(surface, n)
self.write_surface(cast(Surface, surface), n)

elif isinstance(obj, Volume):
for surface in obj.faces():
if surface is not None: # happens with periodic volumes
self.write_surface(surface, n)
for face in obj.faces():
if face is not None: # happens with periodic volumes
self.write_surface(face, n)

elif isinstance(obj, Surface):
self.write_surface(obj, n)

elif isinstance(obj, Sequence):
for o in obj:
self.write(o)

else:
raise ValueError('Unsopported object for STL format')

def write_surface(self, surface, n=None):
def write_surface(
self,
surface: Surface,
n: Optional[Union[int, Sequence[int]]] = None,
) -> None:
# choose evaluation points as one of three cases:
# 1. specified with input
# 2. linear splines, only picks knots
Expand All @@ -142,9 +191,9 @@ def write_surface(self, surface, n=None):
else:
knots = surface.knots(0)
p = surface.order(0)
u = [np.linspace(k0,k1, 2*p-3, endpoint=False) for (k0,k1) in zip(knots[:-1], knots[1:])]
u = [point for element in u for point in element] + list(knots)
u = np.sort(u)
ut = [np.linspace(k0,k1, 2*p-3, endpoint=False) for (k0,k1) in zip(knots[:-1], knots[1:])]
ut = [point for element in ut for point in element] + list(knots)
u = np.sort(ut)

if n is not None:
v = np.linspace(surface.start(1), surface.end(1), n[1])
Expand All @@ -153,9 +202,9 @@ def write_surface(self, surface, n=None):
else:
knots = surface.knots(1)
p = surface.order(1)
v = [np.linspace(k0,k1, 2*p-3, endpoint=False) for (k0,k1) in zip(knots[:-1], knots[1:])]
v = [point for element in v for point in element] + list(knots)
v = np.sort(v)
vt = [np.linspace(k0,k1, 2*p-3, endpoint=False) for (k0,k1) in zip(knots[:-1], knots[1:])]
vt = [point for element in vt for point in element] + list(knots)
v = np.sort(vt)

# perform evaluation and make sure that we have 3 components (in case of 2D geometries)
x = surface(u,v)
Expand All @@ -166,8 +215,3 @@ def write_surface(self, surface, n=None):
faces = [[x[i,j], x[i,j+1], x[i+1,j+1], x[i+1,j]] for i in range(x.shape[0]-1) for j in range(x.shape[1]-1)]

self.writer.add_faces(faces)

def __exit__(self, exc_type, exc_value, traceback):
self.writer.close()
self.writer.fp.close()

0 comments on commit e3b5492

Please sign in to comment.