Commit

Match NumPy regarding .npy for names in .npz files
Closes #48.

When writing a .npz file, NumPy unconditionally adds ".npy" to every
name. When generating the public list of names when reading a .npz
file, NumPy strips a single ".npy" (if present) from each name. When
accessing an array in a .npz file by name, it first checks if that
exact name is in the .npz file, and if not, it then tries the name
with ".npy" appended. While I personally dislike the
inconsistency/implicitness of this behavior, ndarray-npy should follow
it for compatibility with NumPy.
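
The write-side and listing-side rules described above can be illustrated with two small string helpers. This is only a sketch of the naming convention, not code from ndarray-npy; `stored_name` and `public_name` are hypothetical names introduced here for illustration.

```rust
/// Hypothetical helper: the entry name stored in the archive when writing.
/// ".npy" is appended unconditionally, even if `name` already ends in ".npy".
fn stored_name(name: &str) -> String {
    format!("{name}.npy")
}

/// Hypothetical helper: the name reported when listing the archive.
/// At most one trailing ".npy" is stripped; other names pass through unchanged.
fn public_name(stored: &str) -> &str {
    stored.strip_suffix(".npy").unwrap_or(stored)
}

fn main() {
    // Writing always appends the extension.
    assert_eq!(stored_name("weights"), "weights.npy");
    assert_eq!(stored_name("weights.npy"), "weights.npy.npy");

    // Listing strips a single suffix, so a written name round-trips.
    assert_eq!(public_name(&stored_name("weights")), "weights");
    assert_eq!(public_name("weights.npy.npy"), "weights.npy");

    // Entries without the suffix (e.g. written by another tool) are unchanged.
    assert_eq!(public_name("raw"), "raw");
}
```

Because only one suffix is stripped, a name passed to the writer always comes back unchanged from the listing, including names that already end in ".npy".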
jturner314 committed Sep 14, 2024
1 parent 7e6ea69 commit b7e27b6
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/npz.rs
@@ -105,6 +105,8 @@ impl<W: Write + Seek> NpzWriter<W> {

/// Adds an array with the specified `name` to the `.npz` file.
///
+ /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
+ ///
/// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
/// [`aview0`](ndarray::aview0).
pub fn add_array<N, S, D>(
@@ -118,7 +120,7 @@ impl<W: Write + Seek> NpzWriter<W> {
S: Data,
D: Dimension,
{
- self.zip.start_file(name, self.options)?;
+ self.zip.start_file(name.into() + ".npy", self.options)?;
// Buffering when writing individual arrays is beneficial even when the
// underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
// only exception I saw in testing was the "compressed, in-memory
@@ -221,20 +223,40 @@ impl<R: Read + Seek> NpzReader<R> {
}

/// Returns the names of all of the arrays in the file.
+ ///
+ /// Note that a single ".npy" suffix (if present) will be stripped from each name; this matches
+ /// NumPy's behavior.
pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
Ok((0..self.zip.len())
- .map(|i| Ok(self.zip.by_index(i)?.name().to_owned()))
+ .map(|i| {
+ let file = self.zip.by_index(i)?;
+ let name = file.name();
+ let stripped = name.strip_suffix(".npy").unwrap_or(name);
+ Ok(stripped.to_owned())
+ })
.collect::<Result<_, ZipError>>()?)
}

/// Reads an array by name.
+ ///
+ /// Note that this first checks for `name` in the `.npz` file, and if that is not present,
+ /// checks for `format!("{name}.npy")`. This matches NumPy's behavior.
pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
- Ok(ArrayBase::<S, D>::read_npy(self.zip.by_name(name)?)?)
+ // TODO: Combine the two cases into a single `let file = match { ... }` once
+ // https://github.com/rust-lang/rust/issues/47680 is resolved.
+ match self.zip.by_name(name) {
+ Ok(file) => return Ok(ArrayBase::<S, D>::read_npy(file)?),
+ Err(ZipError::FileNotFound) => {}
+ Err(err) => return Err(err.into()),
+ };
+ Ok(ArrayBase::<S, D>::read_npy(
+ self.zip.by_name(&format!("{name}.npy"))?,
+ )?)
}

/// Reads an array by index in the `.npz` file.
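
The lookup order that `by_name` now follows (exact name first, then the name with ".npy" appended) can be sketched with a plain map standing in for the zip archive. This is an illustrative model only: the `lookup` function and the `HashMap<String, u8>` archive are assumptions made for the example, not part of ndarray-npy or the `zip` crate.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the archive: stored entry name -> payload.
/// Tries the exact name first and falls back to `{name}.npy` only on a miss.
fn lookup<'a>(entries: &'a HashMap<String, u8>, name: &str) -> Option<&'a u8> {
    entries
        .get(name)
        .or_else(|| entries.get(&format!("{name}.npy")))
}

fn main() {
    let mut entries = HashMap::new();
    entries.insert("a.npy".to_string(), 1u8);
    entries.insert("a.npy.npy".to_string(), 2u8);

    // "a" is not an entry, so the fallback finds "a.npy".
    assert_eq!(lookup(&entries, "a"), Some(&1));
    // "a.npy" exists exactly, so it wins over the fallback "a.npy.npy".
    assert_eq!(lookup(&entries, "a.npy"), Some(&1));
    // The doubly suffixed entry is still reachable by its exact name.
    assert_eq!(lookup(&entries, "a.npy.npy"), Some(&2));
    // Neither "b" nor "b.npy" exists.
    assert_eq!(lookup(&entries, "b"), None);
}
```

One consequence of this order is that an entry stored as "a.npy.npy" is listed as "a.npy", but looking up "a.npy" returns the entry stored as "a.npy" when both exist.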
