MNIST手写数字数据集介绍

MNIST(Mixed National Institute of Standards and Technology)

获取 MNIST 数据集

数据库官网 THE MNIST DATABASEof handwritten digits

image-20220510195738189

在官网上下载如图所示的4个文件并解压,得到下面4个文件

  • train-images.idx3-ubyte, Size: 44.8MB: ==训练集图片, 60000个样本, 每个样本是25*25像素的图片==
  • train-labels.idx1-ubyte, Size: 58.6KB: ==训练集标签, 60000个标签, 每个标签是0-9的数字==
  • t10k-images.idx3-ubyte, Size: 7.47MB: ==测试集图片, 10000个样本, 每个样本是25*25像素的图片==
  • t10k-labels.idx1-ubyte, Size: 9.77KB: ==测试集标签, 10000个标签, 每个标签是0-9的数字==

注意, 从官网下载的源文件是4个.gz格式的压缩包, 需要分别解压才能得到上面的4个文件.

读取 MNIST 数据集

参考CSDN博客【数据】读取mnist数据集

安装numpy和matplotlib两个python库, 参考以下代码

1
2
3
4
5
6
# 新建环境
conda create --name py39_mnist python=3.9
# 激活环境
conda activate py39_mnist
# pip安装两个库
pip install numpy matplotlib

将4个文件放入Python工程文件的目录下(因为使用的是相对路径),执行下面的代码进行读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import struct
import numpy as np
import matplotlib.pyplot as plt


def load_images(file_name):
# 在读取或写入一个文件之前,你必须使用 Python 内置open()函数来打开它
# file object = open(file_name [, access_mode][, buffering])
# file_name是包含您要访问的文件名的字符串值
# access_mode指定该文件已被打开,即读,写,追加等方式
# 0表示不使用缓冲,1表示在访问一个文件时进行缓冲
# 这里rb表示只能以二进制读取的方式打开一个文件
binfile = open(file_name, 'rb')
# 从一个打开的文件读取数据
buffers = binfile.read()
# 读取image文件前4个整型数字,‘>’表示大端,‘IIII’表示4个int
# ‘0’表示offset=0(偏移量=0),即从0位置开始读取数据
magic, num, rows, cols = struct.unpack_from('>IIII', buffers, 0)
# 'train-images.idx3-ubyte'中整个images数据大小为60000*28*28
# 't10k-images.idx3-ubyte'中整个images数据大小为10000*28*28
# 每个数据都是unsigned char类型,大小为1个字节(byte)
Bytes = num * rows * cols
# ‘>’表示大端,‘B’表示unsigned char,‘str(Bytes)’表示把Bytes转换成字符串
# '>'+str(Bytes)+'B'表示从大端开始读取Bytes个unsigned char类型个数据
# ‘struct.calcsize('>IIII')’表示计算‘>IIII’的大小
# 这里的大小是struct.calcsize('>IIII') = 16 = 4*sizeof(int)
# 即offset = 16, 不读取前16为的数据,从第17位开始读取
# struct.unpack_from()返回的类型是元组,即type(images) = <class 'tuple'>
# 从'train-images.idx3-ubyte'中读取的images元组中有47040000 = 60000*28*28个元素
# 从't10k-images.idx3-ubyte'中读取的images元组中有7840000 = 10000*28*28个元素
images = struct.unpack_from('>' + str(Bytes) + 'B', buffers, struct.calcsize('>IIII'))
# 关闭文件
binfile.close()
# 将从'train-images.idx3-ubyte'中读取的images元组转换为[60000,784]型数组
# 将从't10k-images.idx3-ubyte'中读取的images元组转换为[10000,784]型数组
images = np.reshape(images, [num, rows * cols])
return images


def load_labels(file_name):
# 打开文件
binfile = open(file_name, 'rb')
# 从一个打开的文件读取数据
buffers = binfile.read()
# 读取label文件前2个整形数字,label的长度为num
magic, num = struct.unpack_from('>II', buffers, 0)
# 读取labels数据
labels = struct.unpack_from('>' + str(num) + "B", buffers, struct.calcsize('>II'))
# 关闭文件
binfile.close()
# 转换为一维数组
labels = np.reshape(labels, [num])
return labels


# 下面四行('')中的内容是路径,这里使用的是相对路径
train_images = load_images('train-images.idx3-ubyte')
train_labels = load_labels('train-labels.idx1-ubyte')
test_images = load_images('t10k-images.idx3-ubyte')
test_labels = load_labels('t10k-labels.idx1-ubyte')

# 将读取的图像绘制出来
fig = plt.figure(figsize=(8, 8))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(30):
images = np.reshape(train_images[i], [28, 28])
ax = fig.add_subplot(6, 5, i+1, xticks=[], yticks=[])
ax.imshow(images, cmap=plt.cm.binary, interpolation='nearest')
ax.text(0, 7, str(train_labels[i]))
plt.show()

Python3的文件操作

Python3 输入和输出 | 菜鸟教程

Python3 File(文件)方法 | 菜鸟教程

Python3 File read() 方法 | 菜鸟教程

Python3官方文档 | struct— 将字节串解读为打包的二进制数据

廖雪峰 Python 3 教程|常用内建模块|struct