はじめに
本稿では Deeplearning4j を使った、ディープラーニングについて説明します。厳密な定義や数式には立ち入らず、意味合いと利用方法に焦点を当てていきます。
題材として、ディープラーニング界の Hello World である MNIST(Modified National Institute of Standards and Technology database)を使った手書き数字の画像認識を扱います。
ディープラーニングの考え方と、実際の実装をあわせて説明し、手書き数字の画像認識のアプリケーションを作成するまでをガイドします。今回は、全3回の内の 3 回目です。
前回までで、手書き数字の画像認識を行うための基本となる事柄を整理してきました。 今回は、この内容を踏まえてDeeplearning4j による実装を進めていきましょう。
Deeplearning4j のプロジェクトを準備する
ライブラリは deeplearning4j-core
と nd4j-native-platform
を利用します。
deeplearning4j-core
はニューラルネットワークの実装を含み、nd4j-native-platform
は行列ライブラリを提供します。
Maven の場合は以下の依存定義となります。
<dependencies> <dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-core</artifactId> <version>1.0.0-M1.1</version> </dependency> <dependency> <groupId>org.nd4j</groupId> <artifactId>nd4j-native-platform</artifactId> <version>1.0.0-M1.1</version> </dependency> </dependencies>
今回は、Gradle の Kotlin スクリプトで作成するため、build.gradle.kts は以下のようなものを準備します。
plugins { application } repositories { mavenCentral() } dependencies { implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-M1.1") implementation("org.nd4j:nd4j-native-platform:1.0.0-M1.1") implementation("org.slf4j:slf4j-jdk14:1.7.30") } application { mainClass.set("com.mammb.code.example.dl4j.App") } java { toolchain { languageVersion.set(JavaLanguageVersion.of(16)) } }
最初に、MNIST データベース からデータセットを作成していきます。
Deeplearning4j の提供する MnistDataSetIterator
を使うこともできますが、ここでは一から作成していきます。
MNIST データベース
手書き数字のデータベースには、MNIST(Modified National Institute of Standards and Technology database) を使います。
PNG画像化されたものが https://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
にあるので、今回はこれを使うことにします。
mnist_png.tar.gz
を解凍すると、training
と testing
ディレクトリの中に 0 ~ 9 のディレクトリがあり、それぞれ該当する数字の PNG 画像が格納される形となります。
このデータベースを Minst
クラスとして作成します。
public class Mnist { private final URL url; private final Path baseDir; private final Path trainingDir; private final Path testingDir; final int imgHeight; final int imgWidth; final int outcomes; public Mnist() { try { this.url = new URL("https://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz"); this.baseDir = Paths.get("./mnist"); this.trainingDir = baseDir.resolve("mnist_png/training"); this.testingDir = baseDir.resolve("mnist_png/testing"); this.imgHeight = 28; this.imgWidth = 28; this.outcomes = 10; } catch (MalformedURLException e) { throw new RuntimeException(e); } } // ... }
trainingDir
と testingDir
が解凍後のテストデータのディレクトリを表します。
imgHeight
と imgWidth
は PNG 画像のサイズ、outcomes
が種類数で今回は 0 ~ 9 の10個となります。
Minst
クラスからは、以下のように画像ファイルを取得できるものとします。
public Stream<File> testingImages() { } public Stream<File> trainingImages() { }
最初に取り組むのは、MNIST ファイルのダウンロードと解凍になります。
.tar.gz
の解凍はJDKの標準ライブラリで提供されないため commons-compress
を使います。
GzipCompressorInputStream
と TarArchiveInputStream
を使い、以下のようにダウンロードと解凍を行うことができます。
private void fetch() { baseDir.toFile().mkdirs(); try (InputStream gzi = new GzipCompressorInputStream(url.openStream()); ArchiveInputStream in = new TarArchiveInputStream(gzi)) { ArchiveEntry entry; while ((entry = in.getNextEntry()) != null) { if (!in.canReadEntryData(entry)) { continue; } File file = baseDir.resolve(entry.getName()).toFile(); if (entry.isDirectory()) { if (!file.isDirectory() && !file.mkdirs()) { throw new IOException("failed to create directory " + file); } } else { File parent = file.getParentFile(); if (!parent.isDirectory() && !parent.mkdirs()) { throw new IOException("failed to create directory " + parent); } try (OutputStream o = Files.newOutputStream(file.toPath())) { IOUtils.copy(in, o); } } } } catch (IOException e) { throw new RuntimeException(e); } }
ダウンロードファイルはそれなりに大きいので、既にダウンロード済みの場合は再利用します。 以下のユーティリティメソッドを用意しておきましょう。
private Stream<File> images(Path path) { if (!exists()) { fetch(); } return IntStream.range(0, outcomes) .mapToObj(Integer::toString) .map(path::resolve) .map(Path::toFile) .map(File::listFiles) .flatMap(Stream::of) .filter(File::isFile); } private boolean exists() { File file = baseDir.toFile(); return file.exists() && file.list().length > 0; }
public メソッドとして以下のようにして、Minst
クラスは完成にしましょう。
public int nIn() { return imgHeight * imgWidth; } public long countTrainingImage() { return images(trainingDir).count(); } public long countTestingImage() { return images(testingDir).count(); } public Stream<File> trainingImages() { return images(trainingDir); } public Stream<File> testingImages() { return images(testingDir); } public File selectAny() { File[] files = testingDir.resolve( Integer.toString(new Random().nextInt(10))) .toFile() .listFiles(); return files[new Random().nextInt(files.length)]; }
全体像はMnist.java を参照してください。
データセットの定義
Deeplearning4j 用のデータセットとして MnistSet
クラスを作成します。
このクラスは、ネットワークモデル構築に必要な情報と、ネットワークモデルに流すデータセット DataSetIterator
を供給するものとし、以下の公開メソッドを提供するものとします。
public DataSetIterator iterator() { } public DataSetIterator iteratorTesting() { } public int nIn() { } public int outcomes() { }
MnistSet
クラスは以下のようにコンストラクトします。
public class MnistSet { private final Mnist mnist; private final NativeImageLoader imageLoader; private final ImagePreProcessingScaler imageScaler; public MnistSet() { this.mnist = new Mnist(); this.imageLoader = new NativeImageLoader(mnist.imgHeight, mnist.imgWidth); this.imageScaler = new ImagePreProcessingScaler(0, 1); } // ... }
Mnist
は先に作成したMNISTデータベースです。画像ファイルは Deeplearning4j の提供する NativeImageLoader
にて行列形式で読み込み、画像のスケール変換を、こちらも Deeplearning4j の提供する ImagePreProcessingScaler
にて行います。
DataSetIterator
の生成は以下のようになります。
public DataSetIterator iterator(Stream<File> images, int samples) { final INDArray in = Nd4j.create(samples, mnist.nIn()); final INDArray out = Nd4j.create(samples, mnist.outcomes); final AtomicInteger index = new AtomicInteger(); images.forEach(file -> { try { int n = index.getAndIncrement(); INDArray img = imageLoader.asRowVector(file); imageScaler.transform(img); in.putRow(n, img); int label = Integer.parseInt(file.toPath().getParent().getFileName().toString()); out.put(n, label, 1.0); // one-hot } catch (IOException e) { throw new RuntimeException(e); } }); List<DataSet> list = new DataSet(in, out).asList(); Collections.shuffle(list, new Random(System.currentTimeMillis())); int batchSize = 10; return new ListDataSetIterator<>(list, batchSize); }
INDArray in
が全入力画像を配列形式としたもの、INDArray out
が画像に対する正解ラベルの配列を表します。
これらの配列から、以下のように DataSet
を作成し、ランダムに並べ替えたうえで、ListDataSetIterator
を生成します。
List<DataSet> list = new DataSet(in, out).asList(); Collections.shuffle(list, new Random(System.currentTimeMillis())); int batchSize = 10; return new ListDataSetIterator<>(list, batchSize);
INDArray in
に格納する1つの画像データは、以下のように配列形式とした上で、0から1の float 値に正規化したものを使います。
INDArray img = imageLoader.asRowVector(file); imageScaler.transform(img);
正解ラベルは画像ファイルの格納ディレクトリ名から構築できるため、以下のように正解ラベルの位置に 1、その他は 0 となるものを準備しています。
int label = Integer.parseInt(file.toPath().getParent().getFileName().toString()); out.put(n, label, 1.0); // one-hot
これで、以下のように training データと testing データを DataSetIterator
の形で取得できるようになります。
public DataSetIterator iterator() { return iterator(mnist.trainingImages(), Long.valueOf(mnist.countTrainingImage()).intValue()); } public DataSetIterator iteratorTesting() { return iterator(mnist.testingImages(), Long.valueOf(mnist.countTestingImage()).intValue()); } public int nIn() { return mnist.nIn(); } public int outcomes() { return mnist.outcomes; }
全体像は MnistSet.java を参照してください。
ネットワークモデル
データセットの準備が整ったので、ネットワークモデルの作成に入ることができます。
NetworkModel
は以下のようなクラスとして定義します。
public class NetworkModel { private final MnistSet mnistSet; private final MultiLayerNetwork model; public NetworkModel(MnistSet mnistSet) { this.mnistSet = mnistSet; this.model = buildNetwork(mnistSet.nIn(), mnistSet.outcomes()); } // ... }
buildNetwork()
は前回までで説明した内容のまま、以下のようになります。
private static MultiLayerNetwork buildNetwork(int nIn, int nOut) { //create the first, input layer with xavier initialization DenseLayer denseLayer = new DenseLayer.Builder() .nIn(nIn) .nOut(1000) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .build(); // create hidden layer OutputLayer outputLayer = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) // loss function .nIn(1000) .nOut(nOut) .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER) .build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(123) //include a random seed for reproducibility // use stochastic gradient descent as an optimization algorithm .updater(new Nesterovs(0.006, 0.9)) .l2(1e-4) .list() .layer(denseLayer) .layer(outputLayer) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(new ScoreIterationListener(500)); return model; }
ネットワークモデルのトレーニングは以下のように行うことができます。
public NetworkModel init() { int nEpochs = 2; // Number of training epochs model.fit(mnistSet.iterator(), nEpochs); Evaluation eval = model.evaluate(mnistSet.iteratorTesting()); System.out.println(eval.stats()); return this; }
model.fit(mnistSet.iterator(), nEpochs)
によりデータセットでモデルの学習を行うことができます。
モデルの保存と復元
一度学習したモデルは、保存して再利用できるようにしておきましょう。
学習済みのモデルは以下のように保存できます。
private static final String MODEL = "model.zip"; public void writeModel() { try { ModelSerializer.writeModel(model, MODEL, true); } catch (IOException e) { throw new RuntimeException(e); } }
保存したモデルからの復元は以下のようになります。
public MultiLayerNetwork restoreModel() { try { return ModelSerializer.restoreMultiLayerNetwork(MODEL); } catch (IOException e) { throw new RuntimeException(e); } }
モデルの保存と復元を行うよう、コンストラクタと init()
を以下のように変更しておきます。
public NetworkModel(MnistSet mnistSet) { // ... this.model = Paths.get(MODEL).toFile().exists() ? restoreModel() : buildNetwork(mnistSet.nIn(), mnistSet.outcomes()); } public NetworkModel init() { if (Paths.get(MODEL).toFile().exists()) { return this; } // ... writeModel(); return this; }
全体像は NetworkModel.java を参照してください。
サーバの実装
ブラウザから canvas に手書きした文字を認識できるようにサーバの実装を追加します。
クライアント側は、canvas へ入力したものを画像としてポストするようにします。
以下のような HTML を作成しておきます。
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <title>MNIST</title> </head> <body> <p>Clear : Right click</p> <canvas id="canvas" width="112" height="112" style="border: solid 1px #000;"></canvas> <p id="msg"></p> <script> const canvas = document.getElementById('canvas'); const ctx = canvas.getContext('2d'); ctx.lineWidth = 5; ctx.strokeStyle = 'rgb(255, 255, 255)'; ctx.lineCap = 'round'; ctx.fillStyle = 'black'; ctx.fillRect(0, 0, canvas.width, canvas.height); var clicked = false; canvas.addEventListener('mousedown', e => { if (event.which !== 1) return; clicked = true; ctx.beginPath(); ctx.moveTo(e.offsetX, e.offsetY); }); canvas.addEventListener('mouseup', e => { if (event.which !== 1) return; clicked = false; let data = canvas.toDataURL("image/png"); fetch("post", { method: "POST", body: JSON.stringify(data) }) .then(res => res.text()) .then(text => document.getElementById('msg').innerText = text); }); canvas.addEventListener('mousemove', e => { if (!clicked) return false; ctx.lineTo(e.offsetX, e.offsetY); ctx.stroke(); }); canvas.addEventListener('contextmenu', e => { e.preventDefault(); ctx.fillRect(0, 0, canvas.width, canvas.height); return false; }); </script> </body> </html>
サーバ側は、ポストされた画像を、構築済みのモデルで評価して結果をJSONで返すだけの単純なもので、以下のようになります。
public class Server { private final HttpServer server; private final String contextRoot; private final int port; private final Function<byte[], String> fn; private Server(String contextRoot, int port, Function<byte[], String> fn) { try { this.contextRoot = contextRoot; this.port = port; this.fn = fn; this.server = HttpServer.create(new InetSocketAddress(port), 0); this.server.createContext(contextRoot, this::handle); } catch (IOException e) { throw new RuntimeException(e); } } public Server(Function<byte[], String> fn) { this("/dl4j", 8080, fn); } public void start() { server.start(); } void handle(HttpExchange exchange) { String path = exchange.getRequestURI().normalize().getPath(); if (path.endsWith(".html")) { writePage("/index.html", exchange); } else { try (InputStream in = exchange.getRequestBody()) { var body = new String(in.readAllBytes(), StandardCharsets.UTF_8) .replace("\"data:image/png;base64,", "") .replace("\"", ""); byte[] bytes = Base64.getDecoder().decode(body); String res = fn.apply(bytes); writeJson(res, exchange); } catch (Exception e) { log.error(e.getMessage(), e); write(500, e.getMessage().getBytes(), "text/html", exchange); } } } private void writeJson(String res, HttpExchange exchange) { write(200, res.getBytes(StandardCharsets.UTF_8), "application/json", exchange); } private void writePage(String path, HttpExchange exchange) { try (InputStream is = getClass().getResourceAsStream(path)) { write(200, is.readAllBytes(), "text/html", exchange); } catch (IOException e) { throw new RuntimeException(e); } } private void write(int rCode, byte[] bytes, String contentType, HttpExchange exchange) { try (OutputStream os = exchange.getResponseBody()) { var header = String.format("%s; charset=%s", contentType, StandardCharsets.UTF_8); exchange.getResponseHeaders().set("Content-Type", header); exchange.sendResponseHeaders(rCode, bytes.length); os.write(bytes); } catch (IOException e) { throw new RuntimeException(e); } } }
全体像は Server.java を参照してください。
アプリケーションの起動
アプリケーションは以下のように起動し、手書き文字認識のサーバとして動作させます。
public class App { public static void main(String[] args) { var dataSet = new MnistSet(); var model = new NetworkModel(dataSet).init(); var server = new Server(model::outputAsString); server.start(); server.browse(); } }
以上の実装は以下を参照してください。
アプリケーションは以下で起動し、
$ ./gradlew run
ブラウザ画面から手書きした画像認識が可能となります。
まとめ
Deeplearning4j による手書き数字の画像認識の実装について見てきました。
実装内容を見た上で、もう一度以下の内容を見ると、理解しやすいかと思います。
ディープラーニングを Deeplearning4j でカジュアルに始める(その1) - A Memorandum
ディープラーニングを Deeplearning4j でカジュアルに始める(その2) - A Memorandum
今回のネットワークモデルは、サンプルとして以下の単純なものでした
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(123) .updater(new Nesterovs(0.006, 0.9)) .l2(1e-4) .list() .layer(denseLayer) .layer(outputLayer) .build();
ニューラルネットワークは様々に構成することができ、損失関数や初期値の与え方、重みの更新方法など、多くの手法が提案されています。
Deeplearning4j では、ネットワークの定義をビルダーパターンとして、様々なアルゴリズムを用いて構築することができます。
例えば、今回は扱わなかった畳み込み層を設けたネットワークを以下のように定義することができます。
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(123) .l2(0.0005) .weightInit(WeightInit.XAVIER) .updater(new Adam(1e-3)) .list() .layer(new ConvolutionLayer.Builder(5, 5) .nIn(1) .stride(1, 1) .nOut(20) .activation(Activation.IDENTITY) .build()) .layer(new SubsamplingLayer.Builder(PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build()) .layer(new ConvolutionLayer.Builder(5, 5) .stride(1,1) .nOut(50) .activation(Activation.IDENTITY) .build()) .layer(new SubsamplingLayer.Builder(PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build()) .layer(new DenseLayer.Builder().activation(Activation.RELU) .nOut(500).build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(nOut) .activation(Activation.SOFTMAX) .build()) .setInputType(InputType.convolutionalFlat(28, 28, 1)) .build();
畳み込み層(ConvolutionLayer
)とプーリング層(SubsamplingLayer
) によるネットワーク構成となっています。
パラメータの決定については、ある程度の知識と経験が必要になってきますが、先の実装のネットワーク定義を置き換えることで、認識精度の改善を行うことができます。
Deeplearning4j はもうすぐ 1.0.0 に達するので、Java でもニューラルネットワークを試してみてはいかがでしょうか。