Deeplearning4j: Deep learning and ETL for the JVM
- 11 Aug 2020
Eclipse Deeplearning4j is an open source, distributed, deep learning library for the JVM. Deeplearning4j is written in Java and is compatible with any JVM language, such as Scala, Clojure, or Kotlin. The underlying computations are written in C, C++, and CUDA. Keras will serve as the Python API. Integrated with Hadoop and Apache Spark, Deeplearning4j brings AI to business environments for use on distributed GPUs and CPUs.
Deeplearning4j is actually a stack of projects intended to support all the needs of a JVM-based deep learning application. Beyond Deeplearning4j itself (the high-level API), it includes ND4J (general-purpose linear algebra), SameDiff (graph-based automatic differentiation), DataVec (ETL), Arbiter (hyperparameter search), and the C++ library LibND4J (which underpins all of the above). LibND4J in turn calls on standard libraries for CPU and GPU support, such as OpenBLAS, OneDNN (MKL-DNN), cuDNN, and cuBLAS.
The goal of Eclipse Deeplearning4j is to provide a core set of components for building applications that incorporate AI. AI products within an enterprise often have a wider scope than just machine learning. The overall goal of the distribution is to provide smart defaults for building deep learning applications.
PyTorch, probably the leading deep learning framework for research, only supports immediate mode; it has interfaces for Python, C++, and Java. H2O Sparkling Water integrates the H2O open source, distributed in-memory machine learning platform with Spark. H2O has interfaces for Java and Scala, Python, R, and H2O Flow notebooks.
Commercial support for Deeplearning4j can be purchased from Konduit, which also supports many of the developers working on the project.
How Deeplearning4j works
Deeplearning4j treats the tasks of loading data and training algorithms as separate processes. You load and transform the data using the DataVec library, and train models using tensors and the ND4J library.
You ingest data through a RecordReader interface and walk through the data using a RecordReaderDataSetIterator. You can choose a DataNormalization class to use as a preprocessor for your DataSetIterator. Use the ImagePreProcessingScaler for image data, the NormalizerMinMaxScaler if you have a uniform range along all dimensions of your input data, and NormalizerStandardize for most other cases. If necessary, you can implement a custom DataNormalization class.
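To make the difference between the two scalers concrete, here is a plain-Java sketch of the per-feature arithmetic. It assumes min-max scaling to the [0, 1] range and standardization to zero mean and unit (population) variance; it illustrates the math only, is not the ND4J implementation, and the class name NormalizationSketch is hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch of per-feature (per-column) normalization math;
// not ND4J's NormalizerMinMaxScaler or NormalizerStandardize code.
class NormalizationSketch {

    // Min-max scaling: maps each column to [0, 1] using its observed min and max.
    static double[][] minMaxScale(double[][] data) {
        int rows = data.length, cols = data[0].length;
        double[][] out = new double[rows][cols];
        for (int c = 0; c < cols; c++) {
            double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
            for (double[] row : data) {
                min = Math.min(min, row[c]);
                max = Math.max(max, row[c]);
            }
            double range = max - min;
            for (int r = 0; r < rows; r++) {
                // Guard against a constant column (zero range).
                out[r][c] = range == 0 ? 0.0 : (data[r][c] - min) / range;
            }
        }
        return out;
    }

    // Standardization: subtracts the column mean, divides by the column's
    // population standard deviation.
    static double[][] standardize(double[][] data) {
        int rows = data.length, cols = data[0].length;
        double[][] out = new double[rows][cols];
        for (int c = 0; c < cols; c++) {
            double mean = 0.0;
            for (double[] row : data) mean += row[c];
            mean /= rows;
            double var = 0.0;
            for (double[] row : data) var += (row[c] - mean) * (row[c] - mean);
            double std = Math.sqrt(var / rows);
            for (int r = 0; r < rows; r++) {
                out[r][c] = std == 0 ? 0.0 : (data[r][c] - mean) / std;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Two features on very different scales.
        double[][] data = { {1, 100}, {2, 200}, {3, 300} };
        System.out.println(Arrays.deepToString(minMaxScale(data)));
        System.out.println(Arrays.deepToString(standardize(data)));
    }
}
```

In a real pipeline, compute these statistics on the training data and reuse them unchanged for the test data.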
DataSet objects are containers for the features and labels of your data, and keep the values in several instances of INDArray: one for the features of your examples, one for the labels, and two additional ones for masking, if you are using time series data. In the case of the features, the INDArray is a tensor of size Number of Examples x Number of Features. Typically you'll divide the data into mini-batches for training; the number of examples in each mini-batch should be small enough to fit in memory but large enough to give a good estimate of the gradient.
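The mini-batch split can be sketched in plain Java, with a 2-D array standing in for the Number of Examples x Number of Features INDArray. This illustrates the idea only, not how DataVec or ND4J store data; MiniBatchSketch is a hypothetical name and the batch size of 32 is arbitrary.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch: split an (examples x features) matrix into mini-batches,
// roughly what an iterator hands to the trainer one batch at a time.
class MiniBatchSketch {

    static List<double[][]> toMiniBatches(double[][] features, int batchSize) {
        List<double[][]> batches = new ArrayList<>();
        for (int start = 0; start < features.length; start += batchSize) {
            int end = Math.min(start + batchSize, features.length);
            // Each batch keeps the (examples x features) layout with fewer rows.
            // copyOfRange is shallow: rows are shared with the original array.
            batches.add(Arrays.copyOfRange(features, start, end));
        }
        return batches;
    }

    public static void main(String[] args) {
        // 150 examples x 4 features, the same shape as the Iris dataset.
        double[][] features = new double[150][4];
        for (double[][] batch : toMiniBatches(features, 32)) {
            System.out.println(batch.length + " x " + batch[0].length);
        }
    }
}
```

With 150 examples and a batch size of 32, this yields four full batches of 32 rows and a final batch of 22 rows.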
If you look at the Deeplearning4j code for defining models, such as the Java example below, you’ll see that it’s a very high-level API, similar to Keras. In fact, the planned Python interface to Deeplearning4j will use Keras; right now, if you have a Keras model, you can import it into Deeplearning4j.
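import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

// learningRate, numInputs, numHiddenNodes, and numOutputs are defined elsewhere.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .updater(new Nesterovs(learningRate, 0.9))   // Nesterov momentum, momentum = 0.9
    .list(
        // Hidden layer: fully connected, ReLU activation
        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
            .activation(Activation.RELU).build(),
        // Output layer: softmax activation, negative log-likelihood loss
        new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
            .activation(Activation.SOFTMAX).nIn(numHiddenNodes).nOut(numOutputs).build()
    )
    .build();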
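The configuration pairs SGD as the optimization algorithm with a Nesterov momentum updater. To show what a momentum updater does on each step, here is a plain-Java sketch of Nesterov accelerated gradient minimizing the one-dimensional function f(x) = (x - 3)^2. It uses the textbook look-ahead formulation, which is an assumption: ND4J's Nesterovs updater may compute an algebraically rearranged variant. The learning rate of 0.1 and the class name NesterovSketch are hypothetical choices.

```java
// Hypothetical sketch of Nesterov momentum (look-ahead form) on f(x) = (x - 3)^2.
class NesterovSketch {

    // Gradient of f(x) = (x - 3)^2.
    static double grad(double x) {
        return 2.0 * (x - 3.0);
    }

    static double minimize(double x0, double learningRate, double momentum, int steps) {
        double x = x0;
        double velocity = 0.0;
        for (int i = 0; i < steps; i++) {
            // Evaluate the gradient at the look-ahead point x + momentum * velocity.
            double g = grad(x + momentum * velocity);
            velocity = momentum * velocity - learningRate * g;
            x += velocity;
        }
        return x;
    }

    public static void main(String[] args) {
        // Momentum 0.9 matches the configuration; the learning rate is arbitrary.
        System.out.println(minimize(0.0, 0.1, 0.9, 200));
    }
}
```

The look-ahead gradient lets the velocity correct itself before it overshoots, which is the main difference from classical momentum.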
The MultiLayerNetwork class (configured with a MultiLayerConfiguration, as above) is the simplest network API available in Eclipse Deeplearning4j; for DAG structures, use ComputationGraph instead. Note that the optimization algorithm (SGD in this example) is specified separately from the updater (Nesterov in this example). This very simple neural network has one dense layer with a ReLU activation function and one output layer with -log(likelihood) loss and a softmax activation function, and is trained by backpropagation. More complex networks may also use layers such as EmbeddingLayer, drawn from the two dozen supported layer types and sixteen layer space types (used by Arbiter for hyperparameter search).
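The forward pass and loss of the network just described can be sketched in plain Java: a dense layer followed by ReLU, then a dense output layer followed by softmax, scored with the negative log-likelihood of the true class. The weights, input, and class name are hypothetical, and the code shows the math rather than Deeplearning4j internals.

```java
// Hypothetical sketch of dense -> ReLU -> dense -> softmax -> NLL loss.
class ForwardPassSketch {

    // Dense layer: out[j] = b[j] + sum_i in[i] * w[i][j]
    static double[] dense(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double sum = b[j];
            for (int i = 0; i < in.length; i++) sum += in[i] * w[i][j];
            out[j] = sum;
        }
        return out;
    }

    static double[] relu(double[] z) {
        double[] out = new double[z.length];
        for (int i = 0; i < z.length; i++) out[i] = Math.max(0.0, z[i]);
        return out;
    }

    // Softmax, subtracting the max logit for numerical stability.
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) out[i] /= sum;
        return out;
    }

    // Negative log-likelihood of the true class under the predicted distribution.
    static double negativeLogLikelihood(double[] probs, int label) {
        return -Math.log(probs[label]);
    }

    public static void main(String[] args) {
        double[] x = {1.0, -2.0};                   // 2 inputs
        double[][] w1 = {{1.0, 0.5}, {0.5, -1.0}};  // 2 inputs to 2 hidden units
        double[] b1 = {0.0, 0.0};
        double[][] w2 = {{1.0, -1.0}, {0.0, 1.0}};  // 2 hidden units to 2 classes
        double[] b2 = {0.0, 0.0};

        double[] hidden = relu(dense(x, w1, b1));
        double[] probs = softmax(dense(hidden, w2, b2));
        System.out.println(negativeLogLikelihood(probs, 0));
    }
}
```

For the input in main, the hidden activations are [0.0, 2.5] and the output logits are [0.0, 2.5], so the loss for class 0 is ln(1 + e^2.5), roughly 2.58.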
The simplest way to train the model is to call the model's .fit() method with your DataSetIterator as an argument. You can also reset the iterator and call .fit() again for as many epochs as you need, or use an …
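The epoch loop described above (reset the iterator, then fit every mini-batch) looks roughly like this in plain Java. BatchIterator is a hypothetical stand-in for a DataSetIterator that yields index ranges, and the "model" is a single weight w fitted to y = 2x with mini-batch gradient descent on mean squared error; none of this is Deeplearning4j code.

```java
// Hypothetical sketch of epoch-based mini-batch training with an iterator reset.
class EpochLoopSketch {

    // Mimics the hasNext()/next()/reset() shape of a dataset iterator.
    static class BatchIterator {
        private final int numExamples;
        private final int batchSize;
        private int cursor = 0;

        BatchIterator(int numExamples, int batchSize) {
            this.numExamples = numExamples;
            this.batchSize = batchSize;
        }

        boolean hasNext() {
            return cursor < numExamples;
        }

        // Returns the [start, end) range of the next mini-batch.
        int[] next() {
            int start = cursor;
            cursor = Math.min(cursor + batchSize, numExamples);
            return new int[] {start, cursor};
        }

        void reset() {
            cursor = 0;
        }
    }

    // Fits w in y = w * x by mini-batch gradient descent on mean squared error.
    static double train(double[] xs, double[] ys, int batchSize, int epochs, double learningRate) {
        BatchIterator it = new BatchIterator(xs.length, batchSize);
        double w = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            it.reset(); // rewind so every epoch sees all examples again
            while (it.hasNext()) {
                int[] range = it.next();
                double grad = 0.0;
                for (int i = range[0]; i < range[1]; i++) {
                    // d/dw of (w * x - y)^2 is 2 * (w * x - y) * x
                    grad += 2.0 * (w * xs[i] - ys[i]) * xs[i];
                }
                w -= learningRate * grad / (range[1] - range[0]);
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] xs = new double[10];
        double[] ys = new double[10];
        for (int i = 0; i < xs.length; i++) {
            xs[i] = (i + 1) / 10.0;
            ys[i] = 2.0 * xs[i]; // noise-free data with true weight 2
        }
        System.out.println(train(xs, ys, 4, 50, 0.5));
    }
}
```

Without the reset() call, the iterator would be exhausted after the first epoch and later epochs would do nothing.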
To test model performance, use an Evaluation class to see how well the trained model fits your test data, which should not be the same as the training data.
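The headline metrics an evaluation reports can all be derived from a confusion matrix. The plain-Java sketch below computes accuracy plus macro-averaged precision, recall, and F1 (each computed per class, then averaged). Macro averaging is an assumption made for illustration, and MetricsSketch and the example matrix are hypothetical.

```java
// Hypothetical sketch of metrics from a confusion matrix
// (rows = actual class, columns = predicted class).
class MetricsSketch {

    interface PerClassMetric {
        double apply(int[][] cm, int c);
    }

    static double accuracy(int[][] cm) {
        int correct = 0, total = 0;
        for (int i = 0; i < cm.length; i++) {
            for (int j = 0; j < cm.length; j++) {
                total += cm[i][j];
                if (i == j) correct += cm[i][j];
            }
        }
        return (double) correct / total;
    }

    // Of the examples predicted as class c, the fraction that really are c.
    static double precision(int[][] cm, int c) {
        int predicted = 0; // column sum
        for (int[] row : cm) predicted += row[c];
        return predicted == 0 ? 0.0 : (double) cm[c][c] / predicted;
    }

    // Of the examples that really are class c, the fraction predicted as c.
    static double recall(int[][] cm, int c) {
        int actual = 0; // row sum
        for (int v : cm[c]) actual += v;
        return actual == 0 ? 0.0 : (double) cm[c][c] / actual;
    }

    static double f1(int[][] cm, int c) {
        double p = precision(cm, c), r = recall(cm, c);
        return p + r == 0 ? 0.0 : 2 * p * r / (p + r);
    }

    // Macro average: unweighted mean of a per-class metric.
    static double macro(int[][] cm, PerClassMetric metric) {
        double sum = 0.0;
        for (int c = 0; c < cm.length; c++) sum += metric.apply(cm, c);
        return sum / cm.length;
    }

    public static void main(String[] args) {
        // Hypothetical 3-class results: one class-1 example predicted as class 2.
        int[][] cm = { {10, 0, 0}, {0, 9, 1}, {0, 0, 10} };
        System.out.printf("accuracy  %.4f%n", accuracy(cm));
        System.out.printf("precision %.4f%n", macro(cm, MetricsSketch::precision));
        System.out.printf("recall    %.4f%n", macro(cm, MetricsSketch::recall));
        System.out.printf("f1        %.4f%n", macro(cm, MetricsSketch::f1));
    }
}
```

For the example matrix, with a single misclassified example out of 30, accuracy is 29/30 (about 0.967), macro precision is about 0.970, and macro recall is about 0.967.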
Deeplearning4j provides a listener facility to help you monitor your network's performance visually; listeners are called after each mini-batch is processed. One of the most often used listeners is …
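The listener mechanism is a callback pattern: the training loop notifies every registered listener after each mini-batch. The plain-Java sketch below shows that pattern with hypothetical names (ScoreListener, PrintEveryNListener) and a fake, steadily shrinking score in place of a real loss; it does not use the Deeplearning4j listener API.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of a per-mini-batch listener callback.
class ListenerSketch {

    interface ScoreListener {
        void iterationDone(int iteration, double score);
    }

    // Records and prints the score every n iterations.
    static class PrintEveryNListener implements ScoreListener {
        private final int n;
        final List<Double> recorded = new ArrayList<>();

        PrintEveryNListener(int n) {
            this.n = n;
        }

        @Override
        public void iterationDone(int iteration, double score) {
            if (iteration % n == 0) {
                recorded.add(score);
                System.out.println("iteration " + iteration + ", score " + score);
            }
        }
    }

    // Fake training loop: the "score" simply decays to stand in for a loss.
    static void train(int iterations, List<? extends ScoreListener> listeners) {
        double score = 1.0;
        for (int it = 0; it < iterations; it++) {
            score *= 0.9; // pretend one mini-batch of training happened
            for (ScoreListener l : listeners) l.iterationDone(it, score);
        }
    }

    public static void main(String[] args) {
        List<ScoreListener> listeners = new ArrayList<>();
        listeners.add(new PrintEveryNListener(10));
        train(50, listeners);
    }
}
```

With a reporting interval of 10 and 50 iterations, the listener fires on iterations 0, 10, 20, 30, and 40. Reporting only every N iterations keeps logging overhead low when training runs through thousands of mini-batches.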
Installing and testing Deeplearning4j
At the moment, the easiest way to try out Deeplearning4j is by using the official quick start. It requires a relatively recent version of Java, an installation of Maven, a working Git, and a copy of IntelliJ IDEA (preferred) or Eclipse. There are also a few user-contributed quick starts. Start by cloning the eclipse/deeplearning4j-examples repo to your own machine with Git or GitHub Desktop. Then install the projects with Maven from the dl4j-examples folder.
dl4j-examples % mvn clean install
[INFO] Scanning for projects...
[WARNING]
[WARNING] Some problems were encountered while building the effective model for org.deeplearning4j:dl4j-examples:jar:1.0.0-beta7
[WARNING] 'build.plugins.plugin.(groupId:artifactId)' must be unique but found duplicate declaration of plugin org.apache.maven.plugins:maven-compiler-plugin @ line 250, column 21
[WARNING]
[WARNING] It is highly recommended to fix these problems because they threaten the stability of your build.
[WARNING]
[WARNING] For this reason, future Maven versions might no longer support building such malformed projects.
[WARNING]
[INFO]
[INFO] ------------------< org.deeplearning4j:dl4j-examples >------------------
[INFO] Building Introduction to DL4J 1.0.0-beta7
[INFO] --------------------------------[ jar ]---------------------------------
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-enforcer-plugin/1.0.1/maven-enforcer-plugin-1.0.1.pom
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-enforcer-plugin/1.0.1/maven-enforcer-plugin-1.0.1.pom (6.5 kB at 4.4 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/enforcer/enforcer/1.0.1/enforcer-1.0.1.pom
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/enforcer/enforcer/1.0.1/enforcer-1.0.1.pom (11 kB at 137 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-enforcer-plugin/1.0.1/maven-enforcer-plugin-1.0.1.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-enforcer-plugin/1.0.1/maven-enforcer-plugin-1.0.1.jar (22 kB at 396 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/codehaus/mojo/exec-maven-plugin/1.4.0/exec-maven-plugin-1.4.0.pom
Downloaded from central: https://repo.maven.apache.org/maven2/org/codehaus/mojo/exec-maven-plugin/1.4.0/exec-maven-plugin-1.4.0.pom (12 kB at 283 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/codehaus/mojo/exec-maven-plugin/1.4.0/exec-maven-plugin-1.4.0.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/codehaus/mojo/exec-maven-plugin/1.4.0/exec-maven-plugin-1.4.0.jar (46 kB at 924 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/com/lewisd/lint-maven-plugin/0.0.11/lint-maven-plugin-0.0.11.pom
Downloaded from central: https://repo.maven.apache.org/maven2/com/lewisd/lint-maven-plugin/0.0.11/lint-maven-plugin-0.0.11.pom (19 kB at 430 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/com/lewisd/lint-maven-plugin/0.0.11/lint-maven-plugin-0.0.11.jar
Downloaded from central: https://repo.maven.apache.org/maven2/com/lewisd/lint-maven-plugin/0.0.11/lint-maven-plugin-0.0.11.jar (106 kB at 1.6 MB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-compiler-plugin/3.5.1/maven-compiler-plugin-3.5.1.pom
…
[WARNING] - org.agrona.collections.Hashing
[WARNING] - org.agrona.collections.Long2ObjectCache$ValueIterator
[WARNING] - org.agrona.collections.Int2ObjectHashMap$EntrySet
[WARNING] - org.agrona.concurrent.SleepingIdleStrategy
[WARNING] - org.agrona.collections.MutableInteger
[WARNING] - org.agrona.collections.Int2IntHashMap
[WARNING] - org.agrona.collections.IntIntConsumer
[WARNING] - org.agrona.concurrent.status.StatusIndicator
[WARNING] - 175 more...
[WARNING] javafx-base-14-mac.jar, javafx-graphics-14-mac.jar, jakarta.xml.bind-api-2.3.2.jar define 1 overlapping classes:
[WARNING] - module-info
[WARNING] protobuf-1.0.0-beta7.jar, guava-19.0.jar define 3 overlapping classes:
[WARNING] - com.google.thirdparty.publicsuffix.TrieParser
[WARNING] - com.google.thirdparty.publicsuffix.PublicSuffixPatterns
[WARNING] - com.google.thirdparty.publicsuffix.PublicSuffixType
[WARNING] jsr305-3.0.2.jar, guava-1.0.0-beta7.jar define 35 overlapping classes:
[WARNING] - javax.annotation.RegEx
[WARNING] - javax.annotation.concurrent.Immutable
[WARNING] - javax.annotation.meta.TypeQualifierDefault
[WARNING] - javax.annotation.meta.TypeQualifier
[WARNING] - javax.annotation.Syntax
[WARNING] - javax.annotation.CheckReturnValue
[WARNING] - javax.annotation.CheckForNull
[WARNING] - javax.annotation.Nonnull
[WARNING] - javax.annotation.meta.TypeQualifierNickname
[WARNING] - javax.annotation.MatchesPattern
[WARNING] - 25 more...
[WARNING] maven-shade-plugin has detected that some class files are
[WARNING] present in two or more JARs. When this happens, only one
[WARNING] single version of the class is copied to the uber jar.
[WARNING] Usually this is not harmful and you can skip these warnings,
[WARNING] otherwise try to manually exclude artifacts based on
[WARNING] mvn dependency:tree -Ddetail=true and the above output.
[WARNING] See http://maven.apache.org/plugins/maven-shade-plugin/
[INFO] Attaching shaded artifact.
[INFO]
[INFO] --- maven-install-plugin:2.4:install (default-install) @ dl4j-examples ---
[INFO] Installing /Volumes/Data/repos/deeplearning4j-examples/dl4j-examples/target/dl4j-examples-1.0.0-beta7.jar to /Users/martinheller/.m2/repository/org/deeplearning4j/dl4j-examples/1.0.0-beta7/dl4j-examples-1.0.0-beta7.jar
[INFO] Installing /Volumes/Data/repos/deeplearning4j-examples/dl4j-examples/pom.xml to /Users/martinheller/.m2/repository/org/deeplearning4j/dl4j-examples/1.0.0-beta7/dl4j-examples-1.0.0-beta7.pom
[INFO] Installing /Volumes/Data/repos/deeplearning4j-examples/dl4j-examples/target/dl4j-examples-1.0.0-beta7-shaded.jar to /Users/martinheller/.m2/repository/org/deeplearning4j/dl4j-examples/1.0.0-beta7/dl4j-examples-1.0.0-beta7-shaded.jar
[INFO] ------------------------------------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 05:07 min
[INFO] Finished at: 2020-07-10T10:58:55-04:00
[INFO] ------------------------------------------------------------------------
dl4j-examples %
Once the installation is complete, open the dl4j-examples/ directory with IntelliJ IDEA and try running some of the examples.
The well-known Iris dataset has just 150 samples and is generally easy to model, although a few of the irises are often misclassified. The model used here is a three-layer dense neural network.
Running the Iris classifier shown in the previous figure yields a fairly good fit: accuracy, precision, recall, and F1 score are all ~98%. Note in the confusion matrix that only one of the test cases was misclassified.
The Linear Classifier demo runs in a few seconds and generates probability plots for the training and test datasets. The data was generated specifically to be linearly separable into two classes.
A multi-layer perceptron (MLP) classification model for the MNIST hand-written digit dataset yields accuracy, precision, recall, and F1 score all ~97% after about 14K iterations. That isn’t as good or as fast as the results of convolutional neural networks (such as LeNet) on this dataset.