畳みこみニューラルネットを0から実装する (第1回)
2022年1月に加筆を行って記事を書き直しました!
こんにちはtatsyです。
今、いろいろな分野でディープ・ラーニング(Deep Learning)が大流行しています。
最近ではディープ・ラーニングを使うだけなら、Theanoであったり、Caffeであったりといったライブラリなどが充実しており、0からこれを実装する必要性は薄いです (2015年当時、2022年にはどちらも使われておらず、TensorFlowとPyTorchが主流というのが興味深いです)。
とはいうものの、中身をある程度理解しようと思えば、自分で実装するに勝る学習法はないでしょう。というわけで、今回は畳みこみニューラルネット(Convolutional Neural Network)を最終目標としてディープ・ラーニングの手法を少しずつ実装していきたいと思います。
今回の記事は「0から」実装することを目的としているので、処理を早くする都合上C++を使って実装します。NumPyなどを使うことにすればPythonなどでも実装は可能だと思います。
なお、早速で申し訳ないのですが、行列の演算についてはソースコードが冗長になるのを防ぐためにEigenというライブラリを使います。
第1回の記事の内容
第1回の今回はMNISTという手書き文字の画像データをサイトから落としてきて、ライブラリで読めるようにするあたりまでをやってみたいと思います。
MNISTのデータ形式
MNISTというのは手書きの文字が28×28の白黒画像で保存されているデータベースです。本家のウェブページは以下になります。
http://yann.lecun.com/exdb/mnist/
このページの上の方に4つのファイルが置いてあります。これらには上から
- トレーニング用の画像
- トレーニング画像に対する数字のラベル
- テスト用の画像
- テスト画像に対する数字のラベル
が含まれています。データ形式についての解説は先ほどのページの一番下に書かれています (以下に日本語訳も掲載しておきます)。
画像データ
バイト数 | 形式 | 内容 |
---|---|---|
4 Byte | 符号付き32ビット整数 | マジック・ナンバー (2051) |
4 Byte | 符号付き32ビット整数 | データの数 (トレーニング用は60000個、テスト用は10000個) |
4 Byte | 符号付き32ビット整数 | 画像の高さ (行数) |
4 Byte | 符号付き32ビット整数 | 画像の幅 (列数) |
1 Byteずつ | 符号なし8ビット整数 | 各画像の輝度値 (0-255) |
ラベルデータ
バイト数 | 形式 | 内容 |
---|---|---|
4 Byte | 符号付き32ビット整数 | マジック・ナンバー (2049) |
4 Byte | 符号付き32ビット整数 | データの数 (トレーニング用は60000個、テスト用は10000個) |
1 Byteずつ | 符号なし8ビット整数 | 各画像が表す数字 (0-9) |
基本的には、この表に従ってデータを読んでいけば良いのですが、画像、ラベルいずれのデータについても値がビッグ・エンディアンで保存されていることに注意が必要です。
とはいっても、実際に注意しなければならないのは複数バイトのデータであるデータのみだけなので、上の表で形式が「符号付き32ビット整数」となっているものだけです。
ビッグエンディアンで記録された整数は、以下のparse_bigendian
のように4バイトのデータを受け取って、インデックスが小さいバイトが実際の整数の上位バイトに来るように変換を行えばOKです。
/**
* Convert integer from little endian to big endian
* リトルエンディアンとビッグエンディアンの相互変換
*/
inline uint32_t parse_bigendian(uint8_t *buf) {
uint32_t ret = 0;
for (int i = 0; i < 4; i++) {
ret = (ret << 8) | buf[i];
}
return ret;
}
実装例
全体のソースコードは私のGitHubからご覧いただけます。適宜更新する予定なので最新版は以下からご確認ください。
https://github.com/tatsy/educnn
なお以下のコードでは各ファイルごとに重要部分だけを抜粋しています。
ヘッダファイル (mnist.h)
全体ソース: https://github.com/tatsy/educnn/blob/master/sources/mnist.h
画像データの読み取り
/**
* Load image data
* 画像データの読み込み
*/
inline Matrix load_images(const std::string &filename) {
std::ifstream reader(filename.c_str(), std::ios::in | std::ios::binary);
if (reader.fail()) {
std::cerr << "Failed to open data: " << filename << std::endl;
exit(1);
}
uint8_t temp[4];
// read magic number
// マジックナンバーの読み込み
reader.read((char *)temp, sizeof(char) * 4);
const int magic = parse_bigendian(temp);
Assertion(magic == 2051, "Invalid magic number!");
// read number of data
// データの数の読み込み
reader.read((char *)temp, sizeof(char) * 4);
const int n_image = parse_bigendian(temp);
// read image height (# of rows)
// 画像の高さ(行数)を読む
reader.read((char *)temp, sizeof(char) * 4);
const int rows = parse_bigendian(temp);
// read image width (# of columns)
// 画像の幅(列数)を読む
reader.read((char *)temp, sizeof(char) * 4);
const int cols = parse_bigendian(temp);
// read pixel values
// 画像の画素値を読む
uint8_t *buf = new uint8_t[rows * cols];
Matrix ret = Matrix::Zero(n_image, rows * cols);
for (int i = 0; i < n_image; i++) {
reader.read((char *)buf, sizeof(char) * rows * cols);
for (int j = 0; j < rows * cols; j++) {
ret(i, j) = buf[j] / 255.0;
}
}
delete[] buf;
reader.close();
return ret;
}
ラベルデータの読み取り
/**
* Load label data
* ラベルデータの読み込み
*/
inline Matrix load_labels(const std::string &filename) {
std::ifstream reader(filename.c_str(), std::ios::in | std::ios::binary);
if (reader.fail()) {
std::cerr << "Failed to open labels: " << filename << std::endl;
exit(1);
}
uint8_t temp[4];
// read magic number
// マジックナンバーの読み込み
reader.read((char *)temp, sizeof(char) * 4);
const int magic = parse_bigendian(temp);
Assertion(magic == 2049, "Invalid magic number!");
// read number of labes
// ラベル数の読み込み
reader.read((char *)temp, sizeof(char) * 4);
const int n_data = parse_bigendian(temp);
// read label index and convert it to one-hot vector
// ラベル番号を読み取ってone-hotベクトルに変換する
Matrix ret = Matrix::Zero(n_data, 10);
for (int i = 0; i < n_data; i++) {
char digit;
reader.read((char *)&digit, sizeof(char));
ret(i, digit) = 1.0;
}
reader.close();
return ret;
}
いずれのプログラムも最初にご紹介した表に従って ifstream
でバイトデータを読み取っています。ラベルデータについては、以後学習に使うためにインデックスのデータをone-hotベクトルに変換していることに注意してください。
今回は第1回ということで解説は簡単にMNISTの読み込みについてご紹介しました。第2回はシグモイド関数を使った通常のニューラルネットについて解説します。
最後までお読みいただき、ありがとうございました。