Introduction to the MNIST Handwritten Digit Dataset
MNIST (Modified National Institute of Standards and Technology)
Getting the MNIST Dataset
Official website: THE MNIST DATABASE of handwritten digits

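Download the 4 files shown in the figure from the official website and decompress them. This gives the following 4 files:

train-images.idx3-ubyte (Size: 44.8MB)
: ==Training-set images: 60000 samples, each a 28×28-pixel image==

train-labels.idx1-ubyte (Size: 58.6KB)
: ==Training-set labels: 60000 labels, each a digit from 0 to 9==

t10k-images.idx3-ubyte (Size: 7.47MB)
: ==Test-set images: 10000 samples, each a 28×28-pixel image==

t10k-labels.idx1-ubyte (Size: 9.77KB)
: ==Test-set labels: 10000 labels, each a digit from 0 to 9==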
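The file sizes above follow from the IDX layout that the reading code relies on. An image file starts with a 16-byte header made of four big-endian unsigned 32-bit integers: magic number, item count, rows and columns. A label file starts with an 8-byte header holding the magic number and the item count. After the header comes one unsigned byte per pixel or per label. The following is a minimal sketch that checks this arithmetic, assuming 28×28 images. The helper names `image_file_size` and `label_file_size` are hypothetical:

```python
# Expected IDX file sizes, assuming the standard MNIST layout:
#   images: 16-byte header + count * rows * cols bytes (1 byte per pixel)
#   labels:  8-byte header + count bytes (1 byte per label)
IMAGE_HEADER_BYTES = 16  # magic, count, rows, cols (4 x uint32, big-endian)
LABEL_HEADER_BYTES = 8   # magic, count (2 x uint32, big-endian)


def image_file_size(count: int, rows: int = 28, cols: int = 28) -> int:
    """Return the expected byte size of an idx3-ubyte image file."""
    return IMAGE_HEADER_BYTES + count * rows * cols


def label_file_size(count: int) -> int:
    """Return the expected byte size of an idx1-ubyte label file."""
    return LABEL_HEADER_BYTES + count


if __name__ == "__main__":
    for name, size in [
        ("train-images.idx3-ubyte", image_file_size(60000)),
        ("train-labels.idx1-ubyte", label_file_size(60000)),
        ("t10k-images.idx3-ubyte", image_file_size(10000)),
        ("t10k-labels.idx1-ubyte", label_file_size(10000)),
    ]:
        # Report in binary units: 1 KB = 1024 bytes, 1 MB = 1024 KB
        print(f"{name}: {size} bytes ({size / 1024:.2f} KB, {size / 1024**2:.2f} MB)")
```

For example, the training images come to 16 + 60000 × 784 = 47,040,016 bytes, which is about 44.86 MB in binary units.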
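Note that the source files downloaded from the official website are 4 `.gz` archives. Each one must be decompressed separately to obtain the 4 files above.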
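You can also decompress the `.gz` archives from Python with the standard-library `gzip` and `shutil` modules. This is a minimal sketch. `gunzip_file` is a hypothetical helper name. The archive names in the `__main__` part are assumptions, so adjust them to match the files you actually downloaded:

```python
import gzip
import shutil
from pathlib import Path


def gunzip_file(src: str, dst: str) -> Path:
    """Decompress the .gz archive at `src` into the file `dst`."""
    with gzip.open(src, "rb") as f_in, open(dst, "wb") as f_out:
        # Copy in chunks so the whole file is never held in memory at once
        shutil.copyfileobj(f_in, f_out)
    return Path(dst)


if __name__ == "__main__":
    # Hypothetical archive names: adjust them to the files you downloaded
    pairs = [
        ("train-images-idx3-ubyte.gz", "train-images.idx3-ubyte"),
        ("train-labels-idx1-ubyte.gz", "train-labels.idx1-ubyte"),
        ("t10k-images-idx3-ubyte.gz", "t10k-images.idx3-ubyte"),
        ("t10k-labels-idx1-ubyte.gz", "t10k-labels.idx1-ubyte"),
    ]
    for src, dst in pairs:
        if Path(src).exists():  # skip archives that are not present
            gunzip_file(src, dst)
            print(f"{src} -> {dst}")
```

The output names here follow the list above, so the reading code can find the files by relative path.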
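Reading the MNIST Dataset
Reference: the CSDN blog post 【数据】读取mnist数据集 ("[Data] Reading the MNIST dataset")
Install the two Python libraries numpy and matplotlib, for example with the following commands: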
```shell
conda create --name py39_mnist python=3.9
conda activate py39_mnist
pip install numpy matplotlib
```
Put the 4 files in the Python project directory, because the code opens them by relative path. Then run the following code to read them:
```python
import struct

import numpy as np
import matplotlib.pyplot as plt


def load_images(file_name):
    # Read the whole idx3-ubyte file into memory
    with open(file_name, 'rb') as binfile:
        buffers = binfile.read()
    # Header: 4 big-endian unsigned 32-bit ints (magic number, count, rows, cols)
    magic, num, rows, cols = struct.unpack_from('>IIII', buffers, 0)
    Bytes = num * rows * cols
    # Pixel data: one unsigned byte per pixel, starting right after the header
    images = struct.unpack_from('>' + str(Bytes) + 'B', buffers, struct.calcsize('>IIII'))
    # One row per image, rows * cols pixels each
    images = np.reshape(images, [num, rows * cols])
    return images


def load_labels(file_name):
    # Read the whole idx1-ubyte file into memory
    with open(file_name, 'rb') as binfile:
        buffers = binfile.read()
    # Header: 2 big-endian unsigned 32-bit ints (magic number, count)
    magic, num = struct.unpack_from('>II', buffers, 0)
    # Label data: one unsigned byte (0-9) per label
    labels = struct.unpack_from('>' + str(num) + 'B', buffers, struct.calcsize('>II'))
    labels = np.reshape(labels, [num])
    return labels


train_images = load_images('train-images.idx3-ubyte')
train_labels = load_labels('train-labels.idx1-ubyte')
test_images = load_images('t10k-images.idx3-ubyte')
test_labels = load_labels('t10k-labels.idx1-ubyte')

# Show the first 30 training images in a 6 x 5 grid, each with its label
fig = plt.figure(figsize=(8, 8))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(30):
    images = np.reshape(train_images[i], [28, 28])
    ax = fig.add_subplot(6, 5, i + 1, xticks=[], yticks=[])
    ax.imshow(images, cmap=plt.cm.binary, interpolation='nearest')
    ax.text(0, 7, str(train_labels[i]))
plt.show()
```
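The `struct.unpack_from('>' + str(Bytes) + 'B', ...)` call in the code above builds a Python tuple with one int per pixel. For the training images that is 47,040,000 ints, which can be slow and memory-hungry. Here is a possible alternative, sketched under the assumption that the files follow the standard IDX layout. It lets NumPy interpret the bytes directly with `np.frombuffer`. It also checks the magic numbers: 2051 for image files and 2049 for label files. `load_images_fast` and `load_labels_fast` are hypothetical names:

```python
import struct

import numpy as np

IMAGE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions
LABEL_MAGIC = 2049  # 0x00000801: unsigned-byte data, 1 dimension


def load_images_fast(file_name):
    """Read an idx3-ubyte file into a uint8 array of shape (num, rows * cols)."""
    with open(file_name, "rb") as f:
        buffers = f.read()
    magic, num, rows, cols = struct.unpack_from(">IIII", buffers, 0)
    if magic != IMAGE_MAGIC:
        raise ValueError(f"{file_name}: unexpected magic number {magic}")
    # Interpret the bytes after the 16-byte header directly as uint8 pixels
    images = np.frombuffer(buffers, dtype=np.uint8, count=num * rows * cols,
                           offset=struct.calcsize(">IIII"))
    return images.reshape(num, rows * cols)


def load_labels_fast(file_name):
    """Read an idx1-ubyte file into a uint8 array of shape (num,)."""
    with open(file_name, "rb") as f:
        buffers = f.read()
    magic, num = struct.unpack_from(">II", buffers, 0)
    if magic != LABEL_MAGIC:
        raise ValueError(f"{file_name}: unexpected magic number {magic}")
    # Interpret the bytes after the 8-byte header directly as uint8 labels
    return np.frombuffer(buffers, dtype=np.uint8, count=num,
                         offset=struct.calcsize(">II"))
```

Unlike the original functions, these return read-only `uint8` arrays backed by the file bytes. Convert them first with something like `.astype(np.float32)` if you need to modify them or do arithmetic that could leave the 0-255 range.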
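Python 3 File Operations
Python3 Input and Output | Runoob Tutorial (菜鸟教程)
Python3 File Methods | Runoob Tutorial (菜鸟教程)
Python3 File read() Method | Runoob Tutorial (菜鸟教程)
Python 3 official documentation | struct: Interpret bytes as packed binary data
Liao Xuefeng's Python 3 Tutorial | Common Built-in Modules | struct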
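The header parsing in the reading code relies on `struct` format strings. `>` selects big-endian byte order and `I` an unsigned 32-bit integer, so `'>IIII'` reads 16 bytes as four integers. Likewise, `'>' + str(n) + 'B'` reads `n` unsigned bytes. The following sketch packs a made-up MNIST-style image header and reads it back; the values are only illustrative:

```python
import struct

HEADER_FORMAT = ">IIII"  # big-endian: magic, count, rows, cols

# Pack an illustrative image-file header followed by three pixel bytes
header = struct.pack(HEADER_FORMAT, 2051, 60000, 28, 28)
payload = header + bytes([0, 128, 255])

print(struct.calcsize(HEADER_FORMAT))  # 16
print(header.hex())  # 000008030000ea600000001c0000001c

# Read the header back from offset 0, then the pixel bytes right after it
print(struct.unpack_from(HEADER_FORMAT, payload, 0))  # (2051, 60000, 28, 28)
print(struct.unpack_from(">3B", payload, struct.calcsize(HEADER_FORMAT)))  # (0, 128, 255)

# Reading the magic number with the wrong (little-endian) byte order gives a meaningless value
print(struct.unpack_from("<I", payload, 0)[0])  # 50855936
```

The last line shows why the `>` prefix matters. Without the correct byte order, the count, rows and columns would be misread as well.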