Skip to content

Commit ded9602

Browse files
committed
test(cifar-10): 实现cifar-10数据加载
1 parent 73bbcaa commit ded9602

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

data/load_cifar_10.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# @Time : 19-6-7 下午12:05
4+
# @Author : zj
5+
6+
import numpy as np
7+
import os
8+
import cv2
9+
10+
data_path = '/home/lab305/Documents/data/decompress_cifar_10'
11+
12+
cate_list = list(range(10))
13+
14+
dst_size = (32, 32)
15+
16+
17+
def read_image(img_path, isGray=False):
18+
if isGray:
19+
return cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
20+
else:
21+
return cv2.imread(img_path)
22+
23+
24+
def resize_image(src, dst_size):
25+
if src.shape == dst_size:
26+
return src
27+
return cv2.resize(src, dst_size)
28+
29+
30+
def change_channel(input):
31+
if len(input.shape) == 2:
32+
# 灰度图
33+
dst_shape = [1]
34+
dst_shape.extend(input.shape)
35+
return input.reshape(dst_shape)
36+
else:
37+
# 彩色图
38+
return input.transpose(2, 0, 1)
39+
40+
41+
def load_cifar_10_data(shuffle=True):
42+
"""
43+
加载mnist数据
44+
"""
45+
train_dir = os.path.join(data_path, 'train')
46+
test_dir = os.path.join(data_path, 'test')
47+
48+
x_train = []
49+
x_test = []
50+
y_train = []
51+
y_test = []
52+
train_file_list = []
53+
for i in cate_list:
54+
data_dir = os.path.join(train_dir, str(i))
55+
file_list = os.listdir(data_dir)
56+
for filename in file_list:
57+
file_path = os.path.join(data_dir, filename)
58+
train_file_list.append(file_path)
59+
60+
# 读取测试集图像
61+
data_dir = os.path.join(test_dir, str(i))
62+
file_list = os.listdir(data_dir)
63+
for filename in file_list:
64+
file_path = os.path.join(data_dir, filename)
65+
img = read_image(file_path)
66+
if img is not None:
67+
x_test.append(img.reshape(-1))
68+
y_test.append(i)
69+
70+
train_file_list = np.array(train_file_list)
71+
if shuffle:
72+
np.random.shuffle(train_file_list)
73+
74+
# 读取训练集图像
75+
for file_path in train_file_list:
76+
img = read_image(file_path)
77+
if img is not None:
78+
x_train.append(img.reshape(-1))
79+
y_train.append(int(os.path.split(file_path)[0].split('/')[-1]))
80+
81+
return np.array(x_train), np.array(x_test), np.array(y_train), np.array(y_test)

0 commit comments

Comments
 (0)