ディープラーニングを Deeplearning4j でカジュアルに始める(その3)

f:id:Naotsugu:20210622222124g:plain


はじめに

本稿では Deeplearning4j を使った、ディープラーニングについて説明します。厳密な定義や数式には立ち入らず、意味合いと利用方法に焦点を当てていきます。

題材として、ディープラーニング界の Hello World である MNIST(Modified National Institute of Standards and Technology database)を使った手書き数字の画像認識を扱います。

ディープラーニングの考え方と、実際の実装をあわせて説明し、手書き数字の画像認識のアプリケーションを作成するまでをガイドします。今回は、全3回の内の 3 回目です。

blog1.mammb.com

前回までで、手書き数字の画像認識を行うための基本となる事柄を整理してきました。 今回は、この内容を踏まえてDeeplearning4j による実装を進めていきましょう。


Deeplearning4j のプロジェクトを準備する

ライブラリは deeplearning4j-corend4j-native-platform を利用します。

deeplearning4j-core はニューラルネットワークの実装を含み、nd4j-native-platform は行列ライブラリを提供します。

Maven の場合は以下の依存定義となります。

<dependencies>
  <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-core</artifactId>
      <version>1.0.0-M1.1</version>
  </dependency>
  <dependency>
      <groupId>org.nd4j</groupId>
      <artifactId>nd4j-native-platform</artifactId>
      <version>1.0.0-M1.1</version>
  </dependency>
</dependencies>

今回は、Gradle の Kotlin スクリプトで作成するため、build.gradle.kts は以下のようなものを準備します。

plugins {
    application
}

repositories {
    mavenCentral()
}

dependencies {
    implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-M1.1")
    implementation("org.nd4j:nd4j-native-platform:1.0.0-M1.1")
    implementation("org.slf4j:slf4j-jdk14:1.7.30")
}

application {
    mainClass.set("com.mammb.code.example.dl4j.App")
}

java {
    toolchain {
        languageVersion.set(JavaLanguageVersion.of(16))
    }
}


最初に、MNIST データベース からデータセットを作成していきます。

Deeplearning4j の提供する MnistDataSetIterator を使うこともできますが、ここでは一から作成していきます。


MNIST データベース

手書き数字のデータベースには、MNIST(Modified National Institute of Standards and Technology database) を使います。

PNG画像化されたものが https://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz にあるので、今回はこれを使うことにします。 mnist_png.tar.gz を解凍すると、trainingtesting ディレクトリの中に 0 ~ 9 のディレクトリがあり、それぞれ該当する数字の PNG 画像が格納される形となります。 このデータベースを Minst クラスとして作成します。

public class Mnist {

  private final URL url;

  private final Path baseDir;
  private final Path trainingDir;
  private final Path testingDir;

  final int imgHeight;
  final int imgWidth;
  final int outcomes;

  public Mnist() {
    try {
      this.url = new URL("https://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz");
      this.baseDir = Paths.get("./mnist");
      this.trainingDir = baseDir.resolve("mnist_png/training");
      this.testingDir = baseDir.resolve("mnist_png/testing");
      this.imgHeight = 28;
      this.imgWidth  = 28;
      this.outcomes = 10;
    } catch (MalformedURLException e) {
      throw new RuntimeException(e);
    }
  }

  // ...

}

trainingDirtestingDir が解凍後のテストデータのディレクトリを表します。 imgHeightimgWidth は PNG 画像のサイズ、outcomes が種類数で今回は 0 ~ 9 の10個となります。

Minst クラスからは、以下のように画像ファイルを取得できるものとします。

public Stream<File> testingImages() { }
public Stream<File> trainingImages() { }


最初に取り組むのは、MNIST ファイルのダウンロードと解凍になります。 .tar.gz の解凍はJDKの標準ライブラリで提供されないため commons-compress を使います。

GzipCompressorInputStreamTarArchiveInputStream を使い、以下のようにダウンロードと解凍を行うことができます。

private void fetch() {

    baseDir.toFile().mkdirs();

    try (InputStream gzi = new GzipCompressorInputStream(url.openStream());
         ArchiveInputStream in = new TarArchiveInputStream(gzi)) {

        ArchiveEntry entry;
        while ((entry = in.getNextEntry()) != null) {
            if (!in.canReadEntryData(entry)) {
                continue;
            }
            File file = baseDir.resolve(entry.getName()).toFile();
            if (entry.isDirectory()) {
                if (!file.isDirectory() && !file.mkdirs()) {
                    throw new IOException("failed to create directory " + file);
                }
            } else {
                File parent = file.getParentFile();
                if (!parent.isDirectory() && !parent.mkdirs()) {
                    throw new IOException("failed to create directory " + parent);
                }
                try (OutputStream o = Files.newOutputStream(file.toPath())) {
                    IOUtils.copy(in, o);
                }
            }
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}

ダウンロードファイルはそれなりに大きいので、既にダウンロード済みの場合は再利用します。 以下のユーティリティメソッドを用意しておきましょう。

  private Stream<File> images(Path path) {
    if (!exists()) {
        fetch();
    }
    return IntStream.range(0, outcomes)
            .mapToObj(Integer::toString)
            .map(path::resolve)
            .map(Path::toFile)
            .map(File::listFiles)
            .flatMap(Stream::of)
            .filter(File::isFile);
  }

  private boolean exists() {
    File file = baseDir.toFile();
    return file.exists() && file.list().length > 0;
  }

public メソッドとして以下のようにして、Minst クラスは完成にしましょう。

  public int nIn() {
      return imgHeight * imgWidth;
  }

  public long countTrainingImage() {
      return images(trainingDir).count();
  }


  public long countTestingImage() {
      return images(testingDir).count();
  }


  public Stream<File> trainingImages() {
      return images(trainingDir);
  }


  public Stream<File> testingImages() {
      return images(testingDir);
  }

  public File selectAny() {
      File[] files = testingDir.resolve(
              Integer.toString(new Random().nextInt(10)))
          .toFile()
          .listFiles();
      return files[new Random().nextInt(files.length)];
  }

全体像はMnist.java を参照してください。


データセットの定義

Deeplearning4j 用のデータセットとして MnistSet クラスを作成します。 このクラスは、ネットワークモデル構築に必要な情報と、ネットワークモデルに流すデータセット DataSetIterator を供給するものとし、以下の公開メソッドを提供するものとします。

  public DataSetIterator iterator() { }
  public DataSetIterator iteratorTesting() { }

  public int nIn() { }
  public int outcomes() { }

MnistSet クラスは以下のようにコンストラクトします。

public class MnistSet {

  private final Mnist mnist;
  private final NativeImageLoader imageLoader;
  private final ImagePreProcessingScaler imageScaler;

  public MnistSet() {
      this.mnist = new Mnist();
      this.imageLoader = new NativeImageLoader(mnist.imgHeight, mnist.imgWidth);
      this.imageScaler = new ImagePreProcessingScaler(0, 1);
  }
  // ...
}

Mnist は先に作成したMNISTデータベースです。画像ファイルは Deeplearning4j の提供する NativeImageLoader にて行列形式で読み込み、画像のスケール変換を、こちらも Deeplearning4j の提供する ImagePreProcessingScaler にて行います。

DataSetIterator の生成は以下のようになります。

  public DataSetIterator iterator(Stream<File> images, int samples) {
      final INDArray in  = Nd4j.create(samples, mnist.nIn());
      final INDArray out = Nd4j.create(samples, mnist.outcomes);

      final AtomicInteger index = new AtomicInteger();
      images.forEach(file -> {
          try {
              int n = index.getAndIncrement();

              INDArray img = imageLoader.asRowVector(file);
              imageScaler.transform(img);
              in.putRow(n, img);

              int label = Integer.parseInt(file.toPath().getParent().getFileName().toString());
              out.put(n, label, 1.0); // one-hot

          } catch (IOException e) {
              throw new RuntimeException(e);
          }
      });

      List<DataSet> list = new DataSet(in, out).asList();
      Collections.shuffle(list, new Random(System.currentTimeMillis()));

      int batchSize = 10;
      return new ListDataSetIterator<>(list, batchSize);
  }

INDArray in が全入力画像を配列形式としたもの、INDArray out が画像に対する正解ラベルの配列を表します。

これらの配列から、以下のように DataSet を作成し、ランダムに並べ替えたうえで、ListDataSetIterator を生成します。

List<DataSet> list = new DataSet(in, out).asList();
Collections.shuffle(list, new Random(System.currentTimeMillis()));

int batchSize = 10;
return new ListDataSetIterator<>(list, batchSize);

INDArray in に格納する1つの画像データは、以下のように配列形式とした上で、0から1の float 値に正規化したものを使います。

INDArray img = imageLoader.asRowVector(file);
imageScaler.transform(img);

正解ラベルは画像ファイルの格納ディレクトリ名から構築できるため、以下のように正解ラベルの位置に 1、その他は 0 となるものを準備しています。

int label = Integer.parseInt(file.toPath().getParent().getFileName().toString());
out.put(n, label, 1.0); // one-hot

これで、以下のように training データと testing データを DataSetIterator の形で取得できるようになります。

  public DataSetIterator iterator() {
      return iterator(mnist.trainingImages(),
              Long.valueOf(mnist.countTrainingImage()).intValue());
  }

  public DataSetIterator iteratorTesting() {
      return iterator(mnist.testingImages(),
              Long.valueOf(mnist.countTestingImage()).intValue());
  }

  public int nIn() {
      return mnist.nIn();
  }

  public int outcomes() {
      return mnist.outcomes;
  }

全体像は MnistSet.java を参照してください。


ネットワークモデル

データセットの準備が整ったので、ネットワークモデルの作成に入ることができます。

NetworkModel は以下のようなクラスとして定義します。

public class NetworkModel {

    private final MnistSet mnistSet;
    private final MultiLayerNetwork model;

    public NetworkModel(MnistSet mnistSet) {
        this.mnistSet = mnistSet;
        this.model = buildNetwork(mnistSet.nIn(), mnistSet.outcomes());
    }
    // ...
}

buildNetwork() は前回までで説明した内容のまま、以下のようになります。

  private static MultiLayerNetwork buildNetwork(int nIn, int nOut) {

      //create the first, input layer with xavier initialization
      DenseLayer denseLayer = new DenseLayer.Builder()
              .nIn(nIn)
              .nOut(1000)
              .activation(Activation.RELU)
              .weightInit(WeightInit.XAVIER)
              .build();

      // create hidden layer
      OutputLayer outputLayer = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) // loss function
              .nIn(1000)
              .nOut(nOut)
              .activation(Activation.SOFTMAX)
              .weightInit(WeightInit.XAVIER)
              .build();

      MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
              .seed(123) //include a random seed for reproducibility
              // use stochastic gradient descent as an optimization algorithm
              .updater(new Nesterovs(0.006, 0.9))
              .l2(1e-4)
              .list()
              .layer(denseLayer)
              .layer(outputLayer)
              .build();

      MultiLayerNetwork model = new MultiLayerNetwork(conf);
      model.init();
      model.setListeners(new ScoreIterationListener(500));

      return model;
  }

ネットワークモデルのトレーニングは以下のように行うことができます。

  public NetworkModel init() {
      int nEpochs = 2; // Number of training epochs
      model.fit(mnistSet.iterator(), nEpochs);
      Evaluation eval = model.evaluate(mnistSet.iteratorTesting());
      System.out.println(eval.stats());
      return this;
  }

model.fit(mnistSet.iterator(), nEpochs) によりデータセットでモデルの学習を行うことができます。


モデルの保存と復元

一度学習したモデルは、保存して再利用できるようにしておきましょう。

学習済みのモデルは以下のように保存できます。

  private static final String MODEL = "model.zip";

  public void writeModel() {
      try {
          ModelSerializer.writeModel(model, MODEL, true);
      } catch (IOException e) {
          throw new RuntimeException(e);
      }
  }

保存したモデルからの復元は以下のようになります。

  public MultiLayerNetwork restoreModel() {
      try {
          return ModelSerializer.restoreMultiLayerNetwork(MODEL);
      } catch (IOException e) {
          throw new RuntimeException(e);
      }
  }


モデルの保存と復元を行うよう、コンストラクタと init() を以下のように変更しておきます。

    public NetworkModel(MnistSet mnistSet) {
        // ...
        this.model = Paths.get(MODEL).toFile().exists()
                ? restoreModel()
                : buildNetwork(mnistSet.nIn(), mnistSet.outcomes());
    }

    public NetworkModel init() {
        if (Paths.get(MODEL).toFile().exists()) {
            return this;
        }
        // ...
        writeModel();
        return this;
    }

全体像は NetworkModel.java を参照してください。


サーバの実装

ブラウザから canvas に手書きした文字を認識できるようにサーバの実装を追加します。

クライアント側は、canvas へ入力したものを画像としてポストするようにします。

以下のような HTML を作成しておきます。

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>MNIST</title>
</head>
<body>
  <p>Clear : Right click</p>
  <canvas id="canvas" width="112" height="112" style="border: solid 1px #000;"></canvas>
  <p id="msg"></p>
<script>
    const canvas = document.getElementById('canvas');
    const ctx = canvas.getContext('2d');
    ctx.lineWidth = 5;
    ctx.strokeStyle = 'rgb(255, 255, 255)';
    ctx.lineCap = 'round';
    ctx.fillStyle = 'black';
    ctx.fillRect(0, 0, canvas.width, canvas.height);

    var clicked = false;
    canvas.addEventListener('mousedown', e => {
        if (event.which !== 1) return;
        clicked = true;
        ctx.beginPath();
        ctx.moveTo(e.offsetX, e.offsetY);
    });

    canvas.addEventListener('mouseup', e => {
        if (event.which !== 1) return;
        clicked = false;
        let data = canvas.toDataURL("image/png");
        fetch("post", {
            method: "POST",
            body: JSON.stringify(data)
        })
        .then(res => res.text())
        .then(text => document.getElementById('msg').innerText = text);
    });

    canvas.addEventListener('mousemove', e => {
        if (!clicked) return false;
        ctx.lineTo(e.offsetX, e.offsetY);
        ctx.stroke();
    });

    canvas.addEventListener('contextmenu', e => {
        e.preventDefault();
        ctx.fillRect(0, 0, canvas.width, canvas.height);
        return false;
    });

  </script>
</body>
</html>

サーバ側は、ポストされた画像を、構築済みのモデルで評価して結果をJSONで返すだけの単純なもので、以下のようになります。

public class Server {

    private final HttpServer server;
    private final String contextRoot;
    private final int port;
    private final Function<byte[], String> fn;

    private Server(String contextRoot, int port, Function<byte[], String> fn) {
        try {
            this.contextRoot = contextRoot;
            this.port = port;
            this.fn = fn;
            this.server = HttpServer.create(new InetSocketAddress(port), 0);
            this.server.createContext(contextRoot, this::handle);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public Server(Function<byte[], String> fn) {
        this("/dl4j", 8080, fn);
    }

    public void start() {
        server.start();
    }

    void handle(HttpExchange exchange) {
        String path = exchange.getRequestURI().normalize().getPath();
        if (path.endsWith(".html")) {
            writePage("/index.html", exchange);
        } else {
            try (InputStream in = exchange.getRequestBody()) {
                var body = new String(in.readAllBytes(), StandardCharsets.UTF_8)
                        .replace("\"data:image/png;base64,", "")
                        .replace("\"", "");
                byte[] bytes = Base64.getDecoder().decode(body);
                String res = fn.apply(bytes);
                writeJson(res, exchange);
            } catch (Exception e) {
                log.error(e.getMessage(), e);
                write(500, e.getMessage().getBytes(), "text/html", exchange);
            }
        }
    }

    private void writeJson(String res, HttpExchange exchange) {
        write(200, res.getBytes(StandardCharsets.UTF_8), "application/json", exchange);
    }

    private void writePage(String path, HttpExchange exchange) {
        try (InputStream is = getClass().getResourceAsStream(path)) {
            write(200, is.readAllBytes(), "text/html", exchange);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private void write(int rCode, byte[] bytes, String contentType, HttpExchange exchange) {
        try (OutputStream os = exchange.getResponseBody()) {
            var header = String.format("%s; charset=%s", contentType, StandardCharsets.UTF_8);
            exchange.getResponseHeaders().set("Content-Type", header);
            exchange.sendResponseHeaders(rCode, bytes.length);
            os.write(bytes);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

}

全体像は Server.java を参照してください。


アプリケーションの起動

アプリケーションは以下のように起動し、手書き文字認識のサーバとして動作させます。

public class App {
    public static void main(String[] args) {
        var dataSet = new MnistSet();
        var model = new NetworkModel(dataSet).init();
        var server = new Server(model::outputAsString);
        server.start();
        server.browse();
    }
}

以上の実装は以下を参照してください。

github.com

アプリケーションは以下で起動し、

$ ./gradlew run

ブラウザ画面から手書きした画像認識が可能となります。

f:id:Naotsugu:20210823211734p:plain


まとめ

Deeplearning4j による手書き数字の画像認識の実装について見てきました。

実装内容を見た上で、もう一度以下の内容を見ると、理解しやすいかと思います。

ディープラーニングを Deeplearning4j でカジュアルに始める(その1) - A Memorandum

ディープラーニングを Deeplearning4j でカジュアルに始める(その2) - A Memorandum


今回のネットワークモデルは、サンプルとして以下の単純なものでした

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(123)
        .updater(new Nesterovs(0.006, 0.9))
        .l2(1e-4)
        .list()
        .layer(denseLayer)
        .layer(outputLayer)
        .build();


ニューラルネットワークは様々に構成することができ、損失関数や初期値の与え方、重みの更新方法など、多くの手法が提案されています。

Deeplearning4j では、ネットワークの定義をビルダーパターンとして、様々なアルゴリズムを用いて構築することができます。

例えば、今回は扱わなかった畳み込み層を設けたネットワークを以下のように定義することができます。

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
      .seed(123)
      .l2(0.0005)
      .weightInit(WeightInit.XAVIER)
      .updater(new Adam(1e-3))
      .list()
      .layer(new ConvolutionLayer.Builder(5, 5)
              .nIn(1)
              .stride(1, 1)
              .nOut(20)
              .activation(Activation.IDENTITY)
              .build())
      .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
              .kernelSize(2, 2)
              .stride(2, 2)
              .build())
      .layer(new ConvolutionLayer.Builder(5, 5)
              .stride(1,1)
              .nOut(50)
              .activation(Activation.IDENTITY)
              .build())
      .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
              .kernelSize(2, 2)
              .stride(2, 2)
              .build())
      .layer(new DenseLayer.Builder().activation(Activation.RELU)
              .nOut(500).build())
      .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
              .nOut(nOut)
              .activation(Activation.SOFTMAX)
              .build())
      .setInputType(InputType.convolutionalFlat(28, 28, 1))
      .build();

畳み込み層(ConvolutionLayer)とプーリング層(SubsamplingLayer) によるネットワーク構成となっています。

パラメータの決定については、ある程度の知識と経験が必要になってきますが、先の実装のネットワーク定義を置き換えることで、認識精度の改善を行うことができます。

Deeplearning4j はもうすぐ 1.0.0 に達するので、Java でもニューラルネットワークを試してみてはいかがでしょうか。