Skip to content

Commit

Permalink
Upgrade file editor
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 22, 2024
1 parent 458ece9 commit cbb5ec6
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 52 deletions.
257 changes: 212 additions & 45 deletions keras/src/saving/file_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import zipfile

import h5py
import rich.console

from keras.src.saving import saving_lib
from keras.src.saving.saving_lib import H5IOStore
from keras.src.utils import io_utils
from keras.src.utils import summary_utils

try:
import IPython as ipython
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
self.metadata = None
self.config = None
self.model = None
self.console = rich.console.Console(highlight=False)

if filepath.endswith(".keras"):
zf = zipfile.ZipFile(filepath, "r")
Expand All @@ -61,13 +63,15 @@ def __init__(
f"Received: filepath={filepath}"
)

self.weights_dict = self._extract_weights_from_store(
weights_dict, object_metadata = self._extract_weights_from_store(
weights_store.h5_file
)
io_utils.print_msg(self._generate_filepath_info())
self.weights_dict = weights_dict
self.object_metadata = object_metadata # {path: object_name}
self.console.print(self._generate_filepath_info(rich_style=True))

if self.metadata is not None:
io_utils.print_msg(self._generate_metadata_info())
self.console.print(self._generate_metadata_info(rich_style=True))

def weights_summary(self):
if is_ipython_notebook():
Expand All @@ -79,33 +83,167 @@ def compare_to_reference(self, model):
# TODO
raise NotImplementedError()

def delete_layer(self, layer_name):
# TODO
raise NotImplementedError()
def _edit_object(self, edit_fn, source_name, target_name=None):
if target_name is not None and "/" in target_name:
raise ValueError(
"Argument `target_name` should be a leaf name, "
"not a full path name. "
f"Received: target_name='{target_name}'"
)
if "/" in source_name:
# It's a path
elements = source_name.split("/")
weights_dict = self.weights_dict
for e in elements[:-1]:
if e not in weights_dict:
raise ValueError(
f"Path '{source_name}' not found in model."
)
weights_dict = weights_dict[e]
if elements[-1] not in weights_dict:
raise ValueError(f"Path '{source_name}' not found in model.")
edit_fn(
weights_dict, source_name=elements[-1], target_name=target_name
)
else:
# Ensure uniqueness
def count_occurences(d, name, count=0):
    """Count the dicts in the nested mapping `d` that have `name` as a key.

    Args:
        d: Dict to search; values that are dicts are searched recursively.
        name: Key to look for at every nesting level.
        count: Running total carried into this call.

    Returns:
        `count` plus the number of dicts (including `d` itself) that
        contain `name` as a key.
    """
    for k in d:
        if isinstance(d[k], dict):
            # The recursive call already returns the running total, so it
            # must be assigned. Adding it to `count` would count every
            # earlier match a second time.
            count = count_occurences(d[k], name, count)
    if name in d:
        count += 1
    return count

occurences = count_occurences(self.weights_dict, source_name)
if occurences > 1:
raise ValueError(
f"Name '{source_name}' occurs more than once in the model; "
"try passing a complete path"
)
if occurences == 0:
raise ValueError(
f"Source name '{source_name}' does not appear in the "
"model. Use `editor.weights_summary()` "
"to list all objects."
)

def add_layer(self, layer_name, weights):
# TODO
raise NotImplementedError()
def _edit(d):
    """Apply `edit_fn` to every dict, at any depth, that holds `source_name`."""
    # Nested dicts are handled first; `d` itself is only edited once the
    # loop is over, so it is never modified while being iterated.
    for child in d.values():
        if isinstance(child, dict):
            _edit(child)
    if source_name in d:
        edit_fn(d, source_name=source_name, target_name=target_name)

def rename_layer(self, source_name, target_name):
# TODO
raise NotImplementedError()
_edit(self.weights_dict)

def delete_weight(self, layer_name, weight_name):
# TODO
raise NotImplementedError()
def rename_object(self, source_name, target_name):
def rename_fn(weights_dict, source_name, target_name):
    """Move the entry stored under `source_name` to `target_name`, in place.

    Args:
        weights_dict: Dict that directly contains `source_name`.
        source_name: Existing key to rename.
        target_name: New key for the same value.

    Raises:
        KeyError: If `source_name` is not a key of `weights_dict`.
    """
    # Pop first, then assign. Assigning first and popping afterwards would
    # delete the entry altogether when both names are identical.
    weights_dict[target_name] = weights_dict.pop(source_name)

def add_weight(self, layer_name, weight_name, weight_value):
# TODO
raise NotImplementedError()
self._edit_object(rename_fn, source_name, target_name)

def resave_weights(self, fpath):
# TODO
raise NotImplementedError()
def delete_object(self, name):
    """Remove the object called `name` from the weights tree.

    Locating `name` is delegated to `self._edit_object`.

    Args:
        name: Name of the object to remove.
    """

    def _drop(container, source_name, target_name=None):
        # `target_name` is unused here; it is only part of the signature
        # that `_edit_object` calls edit functions with.
        del container[source_name]

    self._edit_object(_drop, name)

def add_object(self, name, weights):
    """Insert a new object holding `weights` into the weights tree.

    Args:
        name: Either a plain name, added at the top level of
            `self.weights_dict`, or a "/"-separated path whose parent
            components must already exist.
        weights: Dict stored as the new object's content.

    Raises:
        ValueError: If `weights` is not a dict, or if a parent component
            of a path-style `name` is missing.
    """
    if not isinstance(weights, dict):
        raise ValueError(
            "Argument `weights` should be a dict "
            "where keys are weight names (usually '0', '1', etc.) "
            "and values are NumPy arrays. "
            f"Received: type(weights)={type(weights)}"
        )

    if "/" not in name:
        # Plain name: attach directly at the root.
        self.weights_dict[name] = weights
        return

    # Path: walk down through the parents, then attach under the leaf name.
    *parents, leaf = name.split("/")
    container = self.weights_dict
    for part in parents:
        if part not in container:
            missing = "/".join(parents)
            raise ValueError(f"Path '{missing}' not found in model.")
        container = container[part]
    container[leaf] = weights

def delete_weight(self, object_name, weight_name):
    """Remove a single weight from the object called `object_name`.

    Locating `object_name` is delegated to `self._edit_object`.

    Args:
        object_name: Name of the object that owns the weight.
        weight_name: Key of the weight to remove from that object.

    Raises:
        ValueError: If the object has no weight named `weight_name`.
    """

    def _drop_weight(container, source_name, target_name=None):
        object_weights = container[source_name]
        if weight_name in object_weights:
            object_weights.pop(weight_name)
            return
        raise ValueError(
            f"Weight {weight_name} not found "
            f"in object {object_name}. "
            "Weights found: "
            f"{list(object_weights.keys())}"
        )

    self._edit_object(_drop_weight, object_name)

def add_weights(self, object_name, weights):
    """Merge `weights` into the object called `object_name`.

    Locating `object_name` is delegated to `self._edit_object`. Entries of
    `weights` whose keys already exist in the object replace the old values.

    Args:
        object_name: Name of the object to update.
        weights: Dict of weights to merge into the object.

    Raises:
        ValueError: If `weights` is not a dict.
    """
    if not isinstance(weights, dict):
        raise ValueError(
            "Argument `weights` should be a dict "
            "where keys are weight names (usually '0', '1', etc.) "
            "and values are NumPy arrays. "
            f"Received: type(weights)={type(weights)}"
        )

    def _merge(container, source_name, target_name=None):
        existing = container[source_name]
        existing.update(weights)

    self._edit_object(_merge, object_name)

def resave_weights(self, filepath):
    """Write the current content of `self.weights_dict` to a new H5 file.

    Args:
        filepath: Destination path; converted with `str()` and required
            to end in `.weights.h5`.

    Raises:
        ValueError: If `filepath` does not end in `.weights.h5`.
    """
    filepath = str(filepath)
    if not filepath.endswith(".weights.h5"):
        raise ValueError(
            "Invalid `filepath` argument: "
            "expected a `.weights.h5` extension. "
            f"Received: filepath={filepath}"
        )
    weights_store = H5IOStore(filepath, mode="w")

    def _save(weights_dict, weights_store, inner_path):
        # Non-dict values at this level are collected and written together
        # into a single store entry for `inner_path`.
        vars_to_create = {}
        for name, value in weights_dict.items():
            if isinstance(value, dict):
                # Empty dicts are skipped: nothing is written for them.
                if value:
                    _save(
                        value,
                        weights_store,
                        inner_path=inner_path + "/" + name,
                    )
            else:
                # e.g. name="0", value=HDF5Dataset
                vars_to_create[name] = value
        if vars_to_create:
            var_store = weights_store.make(inner_path)
            for name, value in vars_to_create.items():
                var_store[name] = value

    try:
        _save(self.weights_dict, weights_store, inner_path="")
    finally:
        # Close the store even when writing fails, so the file handle is
        # not left open.
        weights_store.close()

def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
metadata = metadata or {}

object_metadata = {}
for k, v in data.attrs.items():
object_metadata[k] = v
if object_metadata:
metadata[inner_path] = object_metadata

def _extract_weights_from_store(self, data):
result = collections.OrderedDict()
for key in data.keys():
inner_path = inner_path + "/" + key
value = data[key]
if isinstance(value, h5py.Group):
if len(value) == 0:
Expand All @@ -115,54 +253,83 @@ def _extract_weights_from_store(self, data):

if hasattr(value, "keys"):
if "vars" in value.keys():
result[key] = self._extract_weights_from_store(
value["vars"]
result[key], metadata = self._extract_weights_from_store(
value["vars"], metadata=metadata, inner_path=inner_path
)
else:
result[key] = self._extract_weights_from_store(value)
result[key], metadata = self._extract_weights_from_store(
value, metadata=metadata, inner_path=inner_path
)
else:
result[key] = value
return result
return result, metadata

def _generate_filepath_info(self):
return f"Keras model file '{self.filepath}'"
def _generate_filepath_info(self, rich_style=False):
if rich_style:
filepath = f"'{self.filepath}'"
filepath = f"{summary_utils.highlight_symbol(filepath)}"
else:
filepath = f"'{self.filepath}'"
return f"Keras model file {filepath}"

def _generate_config_info(self):
def _generate_config_info(self, rich_style=False):
return pprint.pformat(self.config)

def _generate_metadata_info(self):
return (
f"Saved with Keras {self.metadata['keras_version']} "
f"- date: {self.metadata['date_saved']}"
)
def _generate_metadata_info(self, rich_style=False):
version = self.metadata["keras_version"]
date = self.metadata["date_saved"]
if rich_style:
version = f"{summary_utils.highlight_symbol(version)}"
date = f"{summary_utils.highlight_symbol(date)}"
return f"Saved with Keras {version} " f"- date: {date}"

def _print_weights_structure(
self, weights_dict, indent=0, is_last=True, prefix=""
self, weights_dict, indent=0, is_first=True, prefix="", inner_path=""
):
for idx, (key, value) in enumerate(weights_dict.items()):
is_last_item = idx == len(weights_dict) - 1
connector = "└─ " if is_last_item else "├─ "
inner_path = inner_path + "/" + key
is_last = idx == len(weights_dict) - 1
if is_first:
is_first = False
connector = "> "
elif is_last:
connector = "└─ "
else:
connector = "├─ "

if isinstance(value, dict):
io_utils.print_msg(f"{prefix}{connector}{key}")
new_prefix = prefix + (" " if is_last_item else "│ ")
bold_key = summary_utils.bold_text(key)
object_label = f"{prefix}{connector}{bold_key}"
if inner_path in self.object_metadata:
metadata = self.object_metadata[inner_path]
if "name" in metadata:
name = metadata["name"]
object_label += f" ('{name}')"
self.console.print(object_label)
if is_last:
appended = " "
else:
appended = "│ "
new_prefix = prefix + appended
self._print_weights_structure(
value,
indent + 1,
is_last=is_last_item,
is_first=is_first,
prefix=new_prefix,
inner_path=inner_path,
)
else:
if isinstance(value, h5py.Dataset):
io_utils.print_msg(
f"{prefix}{connector}{key}:"
bold_key = summary_utils.bold_text(key)
self.console.print(
f"{prefix}{connector}{bold_key}:"
+ f" shape={value.shape}, dtype={value.dtype}"
)
else:
io_utils.print_msg(f"{prefix}{connector}{key}: {value}")
self.console.print(f"{prefix}{connector}{key}: {value}")

def _weights_summary_cli(self):
io_utils.print_msg("Weights structure")
self.console.print("Weights structure")
self._print_weights_structure(self.weights_dict, prefix=" " * 2)

def _weights_summary_iteractive(self):
Expand Down
25 changes: 18 additions & 7 deletions keras/src/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,13 @@ def _save_state(
return

if hasattr(saveable, "save_own_variables") and weights_store:
saveable.save_own_variables(weights_store.make(inner_path))
if hasattr(saveable, "name") and isinstance(saveable.name, str):
metadata = {"name": saveable.name}
else:
metadata = None
saveable.save_own_variables(
weights_store.make(inner_path, metadata=metadata)
)
if hasattr(saveable, "save_assets") and assets_store:
saveable.save_assets(assets_store.make(inner_path))

Expand Down Expand Up @@ -924,8 +930,8 @@ def __init__(self, root_path, archive=None, mode="r"):
else:
self.h5_file = h5py.File(root_path, mode=self.mode)

def make(self, path):
return H5Entry(self.h5_file, path, mode="w")
def make(self, path, metadata=None):
return H5Entry(self.h5_file, path, mode="w", metadata=metadata)

def get(self, path):
return H5Entry(self.h5_file, path, mode="r")
Expand All @@ -941,10 +947,11 @@ def close(self):
class H5Entry:
"""Leaf entry in a H5IOStore."""

def __init__(self, h5_file, path, mode):
def __init__(self, h5_file, path, mode, metadata=None):
self.h5_file = h5_file
self.path = path
self.mode = mode
self.metadata = metadata

if mode == "w":
if not path:
Expand All @@ -953,11 +960,15 @@ def __init__(self, h5_file, path, mode):
self.group = self.h5_file.create_group(self.path).create_group(
"vars"
)
if self.metadata:
for k, v in self.metadata.items():
self.group.attrs[k] = v
else:
found = False
if not path:
self.group = self.h5_file["vars"]
found = True
if "vars" in self.h5_file:
self.group = self.h5_file["vars"]
found = True
elif path in self.h5_file and "vars" in self.h5_file[path]:
self.group = self.h5_file[path]["vars"]
found = True
Expand Down Expand Up @@ -1026,7 +1037,7 @@ def __init__(self, root_path, archive=None, mode="r"):
self.f = open(root_path, mode="rb")
self.contents = np.load(self.f, allow_pickle=True)

def make(self, path):
def make(self, path, metadata=None):
if not path:
self.contents["__root__"] = {}
return self.contents["__root__"]
Expand Down

0 comments on commit cbb5ec6

Please sign in to comment.