From a1416c2f3fdadd6be2b4099363f65bb2b57820a3 Mon Sep 17 00:00:00 2001 From: Jinyang Gao Date: Mon, 17 Jun 2024 16:40:23 +0800 Subject: [PATCH] Create cifar100.py add cifar100 for testing the model selection --- examples/cnn_ms/data/cifar100.py | 81 ++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 examples/cnn_ms/data/cifar100.py diff --git a/examples/cnn_ms/data/cifar100.py b/examples/cnn_ms/data/cifar100.py new file mode 100644 index 000000000..b9f121b0a --- /dev/null +++ b/examples/cnn_ms/data/cifar100.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +try: + import pickle +except ImportError: + import cPickle as pickle + +import numpy as np +import os +import sys + + +def load_dataset(filepath): + with open(filepath, 'rb') as fd: + try: + cifar100 = pickle.load(fd, encoding='latin1') + except TypeError: + cifar100 = pickle.load(fd) + image = cifar100['data'].astype(dtype=np.uint8) + image = image.reshape((-1, 3, 32, 32)) + label = np.asarray(cifar100['fine_labels'], dtype=np.uint8) + label = label.reshape(label.size, 1) + return image, label + + +def load_train_data(dir_path='/tmp/cifar-100-python'): + images, labels = load_dataset(check_dataset_exist(dir_path + "/train")) + return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32) + + +def load_test_data(dir_path='/tmp/cifar-100-python'): + images, labels = load_dataset(check_dataset_exist(dir_path + "/test")) + return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32) + + +def check_dataset_exist(dirpath): + if not os.path.exists(dirpath): + print( + 'Please download the cifar100 dataset using python data/download_cifar100.py' + ) + sys.exit(0) + return dirpath + + +def normalize(train_x, val_x): + mean = [0.4914, 0.4822, 0.4465] + std = [0.2023, 0.1994, 0.2010] + train_x /= 255 + val_x /= 255 + for ch in range(0, 2): + train_x[:, ch, :, :] -= mean[ch] + train_x[:, ch, :, :] /= std[ch] + val_x[:, ch, :, :] -= mean[ch] + val_x[:, ch, :, :] /= std[ch] + return train_x, val_x + + +def load(): + train_x, train_y = load_train_data() + val_x, val_y = load_test_data() + train_x, val_x = normalize(train_x, val_x) + train_y = train_y.flatten() + val_y = val_y.flatten() + return train_x, train_y, val_x, val_y