Sequential LSTM Input Reshaping

This is a follow-up topic to

Got it working! Well, at least halfway…

ArrayList<Pair<double[], double[]>> data1 = new ArrayList<>();
ArrayList<Pair<double[], double[]>> data2 = new ArrayList<>();
try (BufferedReader br = new BufferedReader(new FileReader("C:\\...path...\\" + name + "_0.csv"))) {
	String l;
	while ((l = br.readLine()) != null) {
		Pair<double[], double[]> temp = new Pair<>(new double[5], new double[1]);
		String[] s = l.split(";");
		for (int i = 0; i < s.length; i++) {
			temp.getFirst()[i] = Double.parseDouble(s[i]);
		}
		temp.getSecond()[0] = temp.getFirst()[3];
		data1.add(temp);
	}
}
try (BufferedReader br = new BufferedReader(new FileReader("C:\\...path...\\" + name + "_1.csv"))) {
	String l;
	while ((l = br.readLine()) != null) {
		Pair<double[], double[]> temp = new Pair<>(new double[5], new double[1]);
		String[] s = l.split(";");
		for (int i = 0; i < s.length; i++) {
			temp.getFirst()[i] = Double.parseDouble(s[i]);
		}
		temp.getSecond()[0] = temp.getFirst()[3];
		data2.add(temp);
	}
}

int batchSize = 128;

DataSetIterator dsi1 = new DoublesDataSetIterator(data1, batchSize);
DataSetIterator dsi2 = new DoublesDataSetIterator(data2, batchSize);

// Set neural network parameters

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
		.seed(123) // Random number generator seed for improved repeatability. Optional.
		.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
		.weightInit(WeightInit.XAVIER)
		.updater(new Nesterovs(0.005))
		.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) // Not always required, but helps with this data set
		.gradientNormalizationThreshold(0.5)
		.list()
		.layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(5).nOut(10).build())
		.layer(1, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(2).build())
		.build();

MultiLayerNetwork mln = new MultiLayerNetwork(conf);

mln.fit(dsi1, 10);
Evaluation evaluate = mln.evaluate(dsi2);

// print the basic statistics about the trained classifier
System.out.println( "Accuracy:  " + evaluate.accuracy() );
System.out.println( "Precision: " + evaluate.precision() );
System.out.println( "Recall:    " + evaluate.recall() );

Error message:
java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2

So I need to reshape the two-dimensional input into a three-dimensional one. Does anyone know how to do that?
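
In case someone else hits the same exception: DL4J RNN layers want features shaped [miniBatchSize, nIn, timeSeriesLength] rather than a flat [rows, columns] matrix, so each example needs an explicit time axis. A minimal sketch of that packing, assuming `rows` holds the parsed double[5] lines from above and using a made-up labelling rule for the two output classes:

import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

// Sketch: pack parsed CSV rows (double[5] each) into the 3D layout
// [miniBatchSize, nIn, timeSeriesLength] that DL4J RNN layers expect.
// `rows` and the labelling rule are assumptions, not from the code above.
static DataSet toSequenceDataSet(List<double[]> rows, int timeSteps) {
	int nIn = 5;
	int nOut = 2; // two classes, to match the RnnOutputLayer above
	int miniBatch = rows.size() / timeSteps;
	INDArray features = Nd4j.create(new int[] {miniBatch, nIn, timeSteps}, 'f');
	INDArray labels = Nd4j.create(new int[] {miniBatch, nOut, timeSteps}, 'f');
	for (int b = 0; b < miniBatch; b++) {
		for (int t = 0; t < timeSteps; t++) {
			double[] row = rows.get(b * timeSteps + t);
			for (int f = 0; f < nIn; f++) {
				features.putScalar(new int[] {b, f, t}, row[f]);
			}
			int clazz = row[3] > 0 ? 1 : 0; // placeholder rule for a one-hot label
			labels.putScalar(new int[] {b, clazz, t}, 1.0);
		}
	}
	return new DataSet(features, labels);
}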

I've read through this blog post: https://machinelearningmastery.com/reshape-input-data-long-short-term-memory-networks-keras/
but all he (really) says is to use the reshape function. :frowning_face:

I'm happy to share sample data if that would help…

I can't get this to work with the Deeplearning4j framework and will probably have to fall back to Python, which is exactly what I wanted to avoid… :confounded: Does anyone happen to know whether Neuroph[1][4] also supports recurrent neural networks[2], as opposed to only feed-forward neural networks? I'm afraid not. :sob: I haven't found any other (complete) implementations for Java. Neuroph does include Hopfield networks[3], but those only accept 0 or 1. :confused: That's it for now, have a nice weekend.

[1] http://neuroph.sourceforge.net/
[2] https://en.wikipedia.org/wiki/Recurrent_neural_network#Long_short-term_memory
[3] https://de.wikipedia.org/wiki/Hopfield-Netz
[4] http://neuroph.sourceforge.net/javadoc/org/neuroph/util/NeuralNetworkType.html

Morning, can you tell me what I'm doing wrong? I've now tried it with Python (tensorflow, keras and pyplot), and my prediction for the next 100 values of the series is always 0 :frowning_face:

[Figure_1: the input series plotted against the prediction, which flatlines at 0]

import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout

from sklearn.preprocessing import MinMaxScaler

# importing the required module
import matplotlib.pyplot as plt

data = []
with open("BTCUSDT_0.csv") as fp:
    lines = fp.readlines()
    for c, line in enumerate(lines):
        x = line.split(";")
        data.append((c, (float(x[1]) + float(x[2])) / 2.0))

len0 = len(data)
len1 = 400

sc_x = MinMaxScaler()
sc_y = MinMaxScaler()

train_data_x = sc_x.fit_transform(data[:len1]).reshape(-1, 1, 1)
train_data_y = []
for i in range(len1):
    train_data_y.append(data[i][0])
    train_data_y.append(data[i][0])
train_data_y = sc_y.fit_transform(np.reshape(train_data_y, (-1, 1)))
print(train_data_x.shape)

regressor = Sequential()

regressor.add(
    LSTM(units=50, return_sequences=True, input_shape=(train_data_x.shape[1], 1))
)
regressor.add(Dropout(0.2))

regressor.add(LSTM(units=50, return_sequences=True))
regressor.add(Dropout(0.2))

regressor.add(LSTM(units=50, return_sequences=True))
regressor.add(Dropout(0.2))

regressor.add(LSTM(units=50))
regressor.add(Dropout(0.2))

regressor.add(Dense(units=1))

regressor.compile(optimizer="adam", loss="mean_squared_error")

regressor.fit(train_data_x, train_data_y, epochs=100, batch_size=32)

data2 = []
for i in range(len0):
    if i < len1:
        data2.append(data[i])
    else:
        data2.append((data[i][0], 0))
data2 = sc_x.fit_transform(data2).reshape((-1, 1, 1))
pre = regressor.predict(data2)
pre = sc_x.inverse_transform(pre.reshape(-1, 2))
data3 = []
for i in range(len0):
    data3.append(pre[i][1])
xs1 = []
for x in data:
    xs1.append(x[1])
xs2 = []
for x in data3:
    xs2.append(x)
print(len(xs1), len(xs2))
plt.plot(xs1, label="data")
plt.plot(xs2, label="data2")
plt.legend()
plt.show()

I know this code is anything but pretty… but I'd be glad if it at least somewhat worked. This is getting ridiculous…

Got it working :slight_smile: I had just made a few basic mistakes, but @eagleeye was able to set me straight on Discord. :slight_smile:

Morning, unfortunately it still doesn't work (the way I picture it). As can be seen here:

500 values are given, and I feed the NN with them. It is then supposed to produce a prediction for 600 values, of which I randomized the last 100…

But apparently exactly what I "put in" is what comes back out…

import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout

from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt

import sys
import random
import datetime
import matplotlib

data = []
with open(sys.argv[1]) as fp:
    lines = fp.readlines()
    for c, line in enumerate(lines):
        x = line.split(";")
        data.append((float(x[1]) + float(x[2])) / 2.0)

data2 = []
for i in range(9):
    temp = []
    for j in range(i * 50, i * 50 + 100):
        temp.append(data[j])
    data2.append(temp)

sc_x = MinMaxScaler()
sc_y = MinMaxScaler()

train_data_x = sc_x.fit_transform(data2).reshape(-1, 1, 1)
train_data_y = []
for i in range(9):
    for j in range(i * 50, i * 50 + 100):
        train_data_y.append(j)
train_data_y = sc_y.fit_transform(np.reshape(train_data_y, (-1, 1)))

print(train_data_x.shape, train_data_y.shape)

units_c = 100

regressor = Sequential()

regressor.add(
    LSTM(units=units_c, return_sequences=True, input_shape=(train_data_x.shape[1], 1))
)
regressor.add(Dropout(0.2))

regressor.add(LSTM(units=units_c, return_sequences=True))
regressor.add(Dropout(0.2))

regressor.add(LSTM(units=units_c, return_sequences=True))
regressor.add(Dropout(0.2))

regressor.add(LSTM(units=units_c))
regressor.add(Dropout(0.2))

regressor.add(Dense(units=1))

regressor.compile(optimizer="adam", loss="mean_squared_error")

regressor.fit(train_data_x, train_data_y, epochs=20, batch_size=16)

data3 = []
for i in range(500):
    data3.append(data[i])
for i in range(100):
    data3.append((1.01 - random.random() * 0.02) * data[-1])
test_data_x = sc_x.fit_transform([data3]).reshape(-1, 1, 1)
pre = regressor.predict(test_data_x)
pre = sc_x.inverse_transform(pre.reshape(1, -1)).reshape(-1)

xs1 = []
for i in range(500):
    xs1.append(data[i])
for i in range(100):
    xs1.append(data[-1])
xs2 = []
for i in range(600):
    xs2.append(pre[i])

nowy = datetime.datetime.now()
ys = []
for i in range(600):
    ys.append(nowy + datetime.timedelta(0, 300 * (i - 500)))
dates = matplotlib.dates.date2num(ys)

print(len(xs1), len(xs2), len(ys))
plt.plot_date(dates, xs1, label="data", linestyle="solid")
plt.plot_date(dates, xs2, label="data2", linestyle="solid")
plt.legend()
plt.show()

And, btw, what is train_data_y actually for?

It's all about the input and output (transformation). Something comes out and it looks "nice", but how can I tell whether it is even remotely correct? In other words, how meaningful is the output?:

The blue line is just a visual aid; the prediction starts from there.

Someone has done this before:

https://github.com/IsaacChanghau/StockPrediction/blob/master/src/main/java/com/isaac/stock/predict/StockPricePrediction.java

List<Pair<INDArray, INDArray>> test = iterator.getTestDataSet();

… now I just have to figure out what that is and what the library does…
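
(Reading ahead in the repository source quoted further down, each Pair seems to couple a min-max-scaled input window with the raw next-step target, roughly consumed like this sketch; that `net` is an already trained network is my assumption here.)

import java.util.List;
import javafx.util.Pair;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

// Rough consumption sketch, pieced together from the repository code
// quoted below; `net` being a trained MultiLayerNetwork is an assumption.
static void walkTestSet(MultiLayerNetwork net, List<Pair<INDArray, INDArray>> test) {
	for (Pair<INDArray, INDArray> entry : test) {
		INDArray window = entry.getKey();   // [exampleLength, VECTOR_SIZE], min-max scaled
		INDArray actual = entry.getValue(); // the next step's raw (unscaled) target
		INDArray out = net.rnnTimeStep(window); // one-step-ahead prediction
		System.out.println(out + " vs " + actual);
	}
}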

I got it running with OpenJDK JDK 15. Two problems: one with Lombok, and the fact that import javafx.util.Pair; no longer exists. (Also, the JAVA_HOME path has to be set correctly for Lombok.)

pom.xml (added Lombok and <version>2.3.2</version> for the compiler plugin):

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.isaac.stock</groupId>
    <artifactId>StockPrediction</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <deeplearning4j.version>0.9.1</deeplearning4j.version>
        <slf4j.version>1.7.21</slf4j.version>
        <opencsv.version>3.9</opencsv.version>
        <guava.version>23.0</guava.version>
        <jfreechart.version>1.0.19</jfreechart.version>
        <spark.version>2.1.0</spark.version>
    </properties>

    <dependencies>
        <!-- DL4J and ND4J Related Dependencies -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-dataframe</artifactId>
            <version>${deeplearning4j.version}</version>
        </dependency>
        <!-- Spark Dependencies -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.10</artifactId>
            <version>${spark.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.10</artifactId>
            <version>${spark.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.10</artifactId>
            <version>${spark.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <!-- OpenCSV Dependency -->
        <dependency>
            <groupId>com.opencsv</groupId>
            <artifactId>opencsv</artifactId>
            <version>${opencsv.version}</version>
        </dependency>
        <!-- SLF4J Dependency -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <!-- JFreeChart Dependency -->
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <!-- Guava dependency -->
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>${guava.version}</version>
        </dependency>
        <!-- Lombok Dependency -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.16</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>

    <build>
        <finalName>StockPrediction</finalName>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>2.3.2</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>2.3</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>com.isaac.stock.predict.StockPricePrediction</mainClass>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

</project>
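
With that, a plain mvn package should produce the shaded StockPrediction jar (assuming Maven itself runs on the JDK that JAVA_HOME points to, see above).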

Added /StockPrediction/src/main/java/com/isaac/stock/representation/MyEntry.java (in place of Pair):

package com.isaac.stock.representation;
import java.util.Map;

public class MyEntry<K, V> implements Map.Entry<K, V> {
	private final K key;
	private V value;

	public MyEntry(K key, V value) {
		this.key = key;
		this.value = value;
	}

	@Override
	public K getKey() {
		return key;
	}

	@Override
	public V getValue() {
		return value;
	}

	@Override
	public V setValue(V value) {
		V old = this.value;
		this.value = value;
		return old;
	}
}
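
Since javafx.util.Pair exposed the same getKey()/getValue() accessors, this is a drop-in replacement; a quick smoke test:

public static void main(String[] args) {
	// Quick smoke test for the Pair replacement
	MyEntry<String, Integer> e = new MyEntry<>("close", 42);
	System.out.println(e.getKey() + " = " + e.getValue()); // prints: close = 42
	e.setValue(43); // unlike javafx.util.Pair, the value is mutable
}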

Adapted /StockPrediction/src/main/java/com/isaac/stock/representation/StockDataSetIterator.java:

package com.isaac.stock.representation;

import com.google.common.collect.ImmutableMap;
import com.opencsv.CSVReader;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.io.FileReader;
import java.io.IOException;
import java.util.*;

/**
 * Created by zhanghao on 26/7/17.
 * Modified by zhanghao on 28/9/17.
 * @author ZHANG HAO
 */
public class StockDataSetIterator implements DataSetIterator {

    /** category and its index */
    private final Map<PriceCategory, Integer> featureMapIndex = ImmutableMap.of(PriceCategory.OPEN, 0, PriceCategory.CLOSE, 1,
            PriceCategory.LOW, 2, PriceCategory.HIGH, 3, PriceCategory.VOLUME, 4);

    private final int VECTOR_SIZE = 5; // number of features for a stock data
    private int miniBatchSize; // mini-batch size
    private int exampleLength = 22; // default 22, say, 22 working days per month
    private int predictLength = 1; // default 1, say, one day ahead prediction

    /** minimal values of each feature in stock dataset */
    private double[] minArray = new double[VECTOR_SIZE];
    /** maximal values of each feature in stock dataset */
    private double[] maxArray = new double[VECTOR_SIZE];

    /** feature to be selected as a training target */
    private PriceCategory category;

    /** mini-batch offset */
    private LinkedList<Integer> exampleStartOffsets = new LinkedList<>();

    /** stock dataset for training */
    private List<StockData> train;
    /** adjusted stock dataset for testing */
    private List<MyEntry<INDArray, INDArray>> test;

    public StockDataSetIterator (String filename, String symbol, int miniBatchSize, int exampleLength, double splitRatio, PriceCategory category) {
        List<StockData> stockDataList = readStockDataFromFile(filename, symbol);
        this.miniBatchSize = miniBatchSize;
        this.exampleLength = exampleLength;
        this.category = category;
        int split = (int) Math.round(stockDataList.size() * splitRatio);
        train = stockDataList.subList(0, split);
        test = generateTestDataSet(stockDataList.subList(split, stockDataList.size()));
        initializeOffsets();
    }

    /** initialize the mini-batch offsets */
    private void initializeOffsets () {
        exampleStartOffsets.clear();
        int window = exampleLength + predictLength;
        for (int i = 0; i < train.size() - window; i++) { exampleStartOffsets.add(i); }
    }

    public List<MyEntry<INDArray, INDArray>> getTestDataSet() { return test; }

    public double[] getMaxArray() { return maxArray; }

    public double[] getMinArray() { return minArray; }

    public double getMaxNum (PriceCategory category) { return maxArray[featureMapIndex.get(category)]; }

    public double getMinNum (PriceCategory category) { return minArray[featureMapIndex.get(category)]; }

    @Override
    public DataSet next(int num) {
        if (exampleStartOffsets.size() == 0) throw new NoSuchElementException();
        int actualMiniBatchSize = Math.min(num, exampleStartOffsets.size());
        INDArray input = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
        INDArray label;
        if (category.equals(PriceCategory.ALL)) label = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
        else label = Nd4j.create(new int[] {actualMiniBatchSize, predictLength, exampleLength}, 'f');
        for (int index = 0; index < actualMiniBatchSize; index++) {
            int startIdx = exampleStartOffsets.removeFirst();
            int endIdx = startIdx + exampleLength;
            StockData curData = train.get(startIdx);
            StockData nextData;
            for (int i = startIdx; i < endIdx; i++) {
                int c = i - startIdx;
                input.putScalar(new int[] {index, 0, c}, (curData.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
                input.putScalar(new int[] {index, 1, c}, (curData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
                input.putScalar(new int[] {index, 2, c}, (curData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
                input.putScalar(new int[] {index, 3, c}, (curData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
                input.putScalar(new int[] {index, 4, c}, (curData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
                nextData = train.get(i + 1);
                if (category.equals(PriceCategory.ALL)) {
                    label.putScalar(new int[] {index, 0, c}, (nextData.getOpen() - minArray[1]) / (maxArray[1] - minArray[1]));
                    label.putScalar(new int[] {index, 1, c}, (nextData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
                    label.putScalar(new int[] {index, 2, c}, (nextData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
                    label.putScalar(new int[] {index, 3, c}, (nextData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
                    label.putScalar(new int[] {index, 4, c}, (nextData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
                } else {
                    label.putScalar(new int[]{index, 0, c}, feedLabel(nextData));
                }
                curData = nextData;
            }
            if (exampleStartOffsets.size() == 0) break;
        }
        return new DataSet(input, label);
    }

    private double feedLabel(StockData data) {
        double value;
        switch (category) {
            case OPEN: value = (data.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]); break;
            case CLOSE: value = (data.getClose() - minArray[1]) / (maxArray[1] - minArray[1]); break;
            case LOW: value = (data.getLow() - minArray[2]) / (maxArray[2] - minArray[2]); break;
            case HIGH: value = (data.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]); break;
            case VOLUME: value = (data.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]); break;
            default: throw new NoSuchElementException();
        }
        return value;
    }

    @Override public int totalExamples() { return train.size() - exampleLength - predictLength; }

    @Override public int inputColumns() { return VECTOR_SIZE; }

    @Override public int totalOutcomes() {
        if (this.category.equals(PriceCategory.ALL)) return VECTOR_SIZE;
        else return predictLength;
    }

    @Override public boolean resetSupported() { return false; }

    @Override public boolean asyncSupported() { return false; }

    @Override public void reset() { initializeOffsets(); }

    @Override public int batch() { return miniBatchSize; }

    @Override public int cursor() { return totalExamples() - exampleStartOffsets.size(); }

    @Override public int numExamples() { return totalExamples(); }

    @Override public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        throw new UnsupportedOperationException("Not Implemented");
    }

    @Override public DataSetPreProcessor getPreProcessor() { throw new UnsupportedOperationException("Not Implemented"); }

    @Override public List<String> getLabels() { throw new UnsupportedOperationException("Not Implemented"); }

    @Override public boolean hasNext() { return exampleStartOffsets.size() > 0; }

    @Override public DataSet next() { return next(miniBatchSize); }
    
    private List<MyEntry<INDArray, INDArray>> generateTestDataSet (List<StockData> stockDataList) {
    	int window = exampleLength + predictLength;
    	List<MyEntry<INDArray, INDArray>> test = new ArrayList<>();
    	for (int i = 0; i < stockDataList.size() - window; i++) {
    		INDArray input = Nd4j.create(new int[] {exampleLength, VECTOR_SIZE}, 'f');
    		for (int j = i; j < i + exampleLength; j++) {
    			StockData stock = stockDataList.get(j);
    			input.putScalar(new int[] {j - i, 0}, (stock.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
    			input.putScalar(new int[] {j - i, 1}, (stock.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
    			input.putScalar(new int[] {j - i, 2}, (stock.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
    			input.putScalar(new int[] {j - i, 3}, (stock.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
    			input.putScalar(new int[] {j - i, 4}, (stock.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
    		}
            StockData stock = stockDataList.get(i + exampleLength);
            INDArray label;
            if (category.equals(PriceCategory.ALL)) {
                label = Nd4j.create(new int[]{VECTOR_SIZE}, 'f'); // ordering is set as 'f', faster construct
                label.putScalar(new int[] {0}, stock.getOpen());
                label.putScalar(new int[] {1}, stock.getClose());
                label.putScalar(new int[] {2}, stock.getLow());
                label.putScalar(new int[] {3}, stock.getHigh());
                label.putScalar(new int[] {4}, stock.getVolume());
            } else {
                label = Nd4j.create(new int[] {1}, 'f');
                switch (category) {
                    case OPEN: label.putScalar(new int[] {0}, stock.getOpen()); break;
                    case CLOSE: label.putScalar(new int[] {0}, stock.getClose()); break;
                    case LOW: label.putScalar(new int[] {0}, stock.getLow()); break;
                    case HIGH: label.putScalar(new int[] {0}, stock.getHigh()); break;
                    case VOLUME: label.putScalar(new int[] {0}, stock.getVolume()); break;
                    default: throw new NoSuchElementException();
                }
            }
    		test.add(new MyEntry<>(input, label));
    	}
    	return test;
    }

	private List<StockData> readStockDataFromFile (String filename, String symbol) {
        List<StockData> stockDataList = new ArrayList<>();
        try {
            for (int i = 0; i < maxArray.length; i++) { // initialize max and min arrays
                maxArray[i] = Double.MIN_VALUE;
                minArray[i] = Double.MAX_VALUE;
            }
            List<String[]> list = new CSVReader(new FileReader(filename)).readAll(); // load all elements in a list
            for (String[] arr : list) {
                if (!arr[1].equals(symbol)) continue;
                double[] nums = new double[VECTOR_SIZE];
                for (int i = 0; i < arr.length - 2; i++) {
                    nums[i] = Double.valueOf(arr[i + 2]);
                    if (nums[i] > maxArray[i]) maxArray[i] = nums[i];
                    if (nums[i] < minArray[i]) minArray[i] = nums[i];
                }
                stockDataList.add(new StockData(arr[0], arr[1], nums[0], nums[1], nums[2], nums[3], nums[4]));
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return stockDataList;
    }
}
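
Side note for anyone skimming: every feature in this iterator is min-max scaled into [0, 1] on the way in and mapped back after prediction. Restated as two plain helpers (my paraphrase, not code from the repo):

// Forward scaling, as done inline in next() and generateTestDataSet() above.
// Careful: this yields NaN if a feature is constant, i.e. max == min.
static double normalize(double value, double min, double max) {
	return (value - min) / (max - min);
}

// Inverse mapping, as used when converting network output back into prices.
static double denormalize(double scaled, double min, double max) {
	return scaled * (max - min) + min;
}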

And adapted the main class /StockPrediction/src/main/java/com/isaac/stock/predict/StockPricePrediction.java:

package com.isaac.stock.predict;

import com.isaac.stock.model.RecurrentNets;
import com.isaac.stock.representation.MyEntry;
import com.isaac.stock.representation.PriceCategory;
import com.isaac.stock.representation.StockDataSetIterator;
import com.isaac.stock.utils.PlotUtil;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Created by zhanghao on 26/7/17.
 * Modified by zhanghao on 28/9/17.
 * @author ZHANG HAO
 */
public class StockPricePrediction {

    private static final Logger log = LoggerFactory.getLogger(StockPricePrediction.class);

    private static int exampleLength = 22; // time series length, assume 22 working days per month

    public static void main (String[] args) throws IOException {
        String file = new ClassPathResource("prices-split-adjusted.csv").getFile().getAbsolutePath();
        String symbol = "GOOG"; // stock name
        int batchSize = 64; // mini-batch size
        double splitRatio = 0.9; // 90% for training, 10% for testing
        int epochs = 100; // training epochs

        log.info("Create dataSet iterator...");
        PriceCategory category = PriceCategory.CLOSE; // CLOSE: predict close price
        StockDataSetIterator iterator = new StockDataSetIterator(file, symbol, batchSize, exampleLength, splitRatio, category);
        log.info("Load test dataset...");
        List<MyEntry<INDArray, INDArray>> test = iterator.getTestDataSet();

        log.info("Build lstm networks...");
        MultiLayerNetwork net = RecurrentNets.buildLstmNetworks(iterator.inputColumns(), iterator.totalOutcomes());

        log.info("Training...");
        for (int i = 0; i < epochs; i++) {
            while (iterator.hasNext()) net.fit(iterator.next()); // fit model using mini-batch data
            iterator.reset(); // reset iterator
            net.rnnClearPreviousState(); // clear previous state
        }

        log.info("Saving model...");
        File locationToSave = new File("src/main/resources/StockPriceLSTM_".concat(String.valueOf(category)).concat(".zip"));
        // saveUpdater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this to train your network more in the future
        ModelSerializer.writeModel(net, locationToSave, true);

        log.info("Load model...");
        net = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

        log.info("Testing...");
        if (category.equals(PriceCategory.ALL)) {
            INDArray max = Nd4j.create(iterator.getMaxArray());
            INDArray min = Nd4j.create(iterator.getMinArray());
            predictAllCategories(net, test, max, min);
        } else {
            double max = iterator.getMaxNum(category);
            double min = iterator.getMinNum(category);
            predictPriceOneAhead(net, test, max, min, category);
        }
        log.info("Done...");
    }

    /** Predict one feature of a stock one-day ahead */
    private static void predictPriceOneAhead (MultiLayerNetwork net, List<MyEntry<INDArray, INDArray>> testData, double max, double min, PriceCategory category) {
        double[] predicts = new double[testData.size()];
        double[] actuals = new double[testData.size()];
        for (int i = 0; i < testData.size(); i++) {
            predicts[i] = net.rnnTimeStep(testData.get(i).getKey()).getDouble(exampleLength - 1) * (max - min) + min;
            actuals[i] = testData.get(i).getValue().getDouble(0);
        }
        log.info("Print out Predictions and Actual Values...");
        log.info("Predict,Actual");
        for (int i = 0; i < predicts.length; i++) log.info(predicts[i] + "," + actuals[i]);
        log.info("Plot...");
        PlotUtil.plot(predicts, actuals, String.valueOf(category));
    }

    private static void predictPriceMultiple (MultiLayerNetwork net, List<MyEntry<INDArray, INDArray>> testData, double max, double min) {
        // TODO
    }

    /** Predict all the features (open, close, low, high prices and volume) of a stock one-day ahead */
    private static void predictAllCategories (MultiLayerNetwork net, List<MyEntry<INDArray, INDArray>> testData, INDArray max, INDArray min) {
        INDArray[] predicts = new INDArray[testData.size()];
        INDArray[] actuals = new INDArray[testData.size()];
        for (int i = 0; i < testData.size(); i++) {
            predicts[i] = net.rnnTimeStep(testData.get(i).getKey()).getRow(exampleLength - 1).mul(max.sub(min)).add(min);
            actuals[i] = testData.get(i).getValue();
        }
        log.info("Print out Predictions and Actual Values...");
        log.info("Predict\tActual");
        for (int i = 0; i < predicts.length; i++) log.info(predicts[i] + "\t" + actuals[i]);
        log.info("Plot...");
        for (int n = 0; n < 5; n++) {
            double[] pred = new double[predicts.length];
            double[] actu = new double[actuals.length];
            for (int i = 0; i < predicts.length; i++) {
                pred[i] = predicts[i].getDouble(n);
                actu[i] = actuals[i].getDouble(n);
            }
            String name;
            switch (n) {
                case 0: name = "Stock OPEN Price"; break;
                case 1: name = "Stock CLOSE Price"; break;
                case 2: name = "Stock LOW Price"; break;
                case 3: name = "Stock HIGH Price"; break;
                case 4: name = "Stock VOLUME Amount"; break;
                default: throw new NoSuchElementException();
            }
            PlotUtil.plot(pred, actu, name);
        }
    }

}

Now it's chugging along, computing a prediction…

The next step will be to integrate all the relevant parts of the library into my project (MultiLayerNetwork net = RecurrentNets.buildLstmNetworks(iterator.inputColumns(), iterator.totalOutcomes()); and so on). This will be fun. :smiley:

Edit: And here is the prediction; afaik, the last tenth was predicted:

[Image: plot of predicted vs. actual values for the last tenth of the data]

I'm probably getting on your nerves by now. But I've integrated it into my project, and now it doesn't run anymore. :confused:

The StockDataSetIterator adaptation:


import com.google.common.collect.ImmutableMap;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;

/**
 * Created by zhanghao on 26/7/17. Modified by zhanghao on 28/9/17.
 * 
 * @author ZHANG HAO
 */
public class StockDataSetIterator implements DataSetIterator {
	private static final long serialVersionUID = 1L;

	public static class StockData {
		private String date; // date
		private String symbol; // stock name

		private double open; // open price
		private double close; // close price
		private double low; // low price
		private double high; // high price
		private double volume; // volume

		public StockData() {
		}

		public StockData(String date, String symbol, double open, double close, double low, double high, double volume) {
			this.date = date;
			this.symbol = symbol;
			this.open = open;
			this.close = close;
			this.low = low;
			this.high = high;
			this.volume = volume;
		}

		public String getDate() {
			return date;
		}

		public void setDate(String date) {
			this.date = date;
		}

		public String getSymbol() {
			return symbol;
		}

		public void setSymbol(String symbol) {
			this.symbol = symbol;
		}

		public double getOpen() {
			return open;
		}

		public void setOpen(double open) {
			this.open = open;
		}

		public double getClose() {
			return close;
		}

		public void setClose(double close) {
			this.close = close;
		}

		public double getLow() {
			return low;
		}

		public void setLow(double low) {
			this.low = low;
		}

		public double getHigh() {
			return high;
		}

		public void setHigh(double high) {
			this.high = high;
		}

		public double getVolume() {
			return volume;
		}

		public void setVolume(double volume) {
			this.volume = volume;
		}
	}

	public static enum PriceCategory {
		OPEN, CLOSE, LOW, HIGH, VOLUME, ALL;
	}

	/** category and its index */
	private final Map<PriceCategory, Integer> featureMapIndex = ImmutableMap.of(PriceCategory.OPEN, 0, PriceCategory.CLOSE, 1, PriceCategory.LOW, 2, PriceCategory.HIGH, 3, PriceCategory.VOLUME, 4);

	private final int VECTOR_SIZE = 5; // number of features for a stock data
	private int miniBatchSize; // mini-batch size
	private int exampleLength = 22; // default 22, say, 22 working days per month
	private int predictLength = 1; // default 1, say, one day ahead prediction

	/** minimal values of each feature in stock dataset */
	private double[] minArray = new double[VECTOR_SIZE];
	/** maximal values of each feature in stock dataset */
	private double[] maxArray = new double[VECTOR_SIZE];

	/** feature to be selected as a training target */
	private PriceCategory category;

	/** mini-batch offset */
	private LinkedList<Integer> exampleStartOffsets = new LinkedList<>();

	/** stock dataset for training */
	private List<StockData> train;
	/** adjusted stock dataset for testing */
	private List<MyEntry<INDArray, INDArray>> test;

	public StockDataSetIterator(List<StockData> stockDataList) {
		this.miniBatchSize = 64;
		this.exampleLength = 22;
		this.category = PriceCategory.ALL;
		int split = (int) Math.round(stockDataList.size() * 0.9);
		train = stockDataList.subList(0, split);
		test = generateTestDataSet(stockDataList.subList(split, stockDataList.size()));
		initializeOffsets();
	}

//    public StockDataSetIterator (String filename, String symbol, int miniBatchSize, int exampleLength, double splitRatio, PriceCategory category) {
//        List<StockData> stockDataList = readStockDataFromFile(filename, symbol);
//        this.miniBatchSize = miniBatchSize;
//        this.exampleLength = exampleLength;
//        this.category = category;
//        int split = (int) Math.round(stockDataList.size() * splitRatio);
//        train = stockDataList.subList(0, split);
//        test = generateTestDataSet(stockDataList.subList(split, stockDataList.size()));
//        initializeOffsets();
//    }

	/** initialize the mini-batch offsets */
	private void initializeOffsets() {
		exampleStartOffsets.clear();
		int window = exampleLength + predictLength;
		for (int i = 0; i < train.size() - window; i++) {
			exampleStartOffsets.add(i);
		}
	}

	public List<MyEntry<INDArray, INDArray>> getTestDataSet() {
		return test;
	}

	public double[] getMaxArray() {
		return maxArray;
	}

	public double[] getMinArray() {
		return minArray;
	}

	public double getMaxNum(PriceCategory category) {
		return maxArray[featureMapIndex.get(category)];
	}

	public double getMinNum(PriceCategory category) {
		return minArray[featureMapIndex.get(category)];
	}

	@Override
	public DataSet next(int num) {
		if (exampleStartOffsets.size() == 0)
			throw new NoSuchElementException();
		int actualMiniBatchSize = Math.min(num, exampleStartOffsets.size());
		INDArray input = Nd4j.create(new int[] { actualMiniBatchSize, VECTOR_SIZE, exampleLength }, 'f');
		INDArray label;
		if (category.equals(PriceCategory.ALL))
			label = Nd4j.create(new int[] { actualMiniBatchSize, VECTOR_SIZE, exampleLength }, 'f');
		else
			label = Nd4j.create(new int[] { actualMiniBatchSize, predictLength, exampleLength }, 'f');
		for (int index = 0; index < actualMiniBatchSize; index++) {
			int startIdx = exampleStartOffsets.removeFirst();
			int endIdx = startIdx + exampleLength;
			StockData curData = train.get(startIdx);
			StockData nextData;
			for (int i = startIdx; i < endIdx; i++) {
				int c = i - startIdx;
				input.putScalar(new int[] { index, 0, c }, (curData.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
				input.putScalar(new int[] { index, 1, c }, (curData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
				input.putScalar(new int[] { index, 2, c }, (curData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
				input.putScalar(new int[] { index, 3, c }, (curData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
				input.putScalar(new int[] { index, 4, c }, (curData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
				nextData = train.get(i + 1);
				if (category.equals(PriceCategory.ALL)) {
					label.putScalar(new int[] { index, 0, c }, (nextData.getOpen() - minArray[1]) / (maxArray[1] - minArray[1]));
					label.putScalar(new int[] { index, 1, c }, (nextData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
					label.putScalar(new int[] { index, 2, c }, (nextData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
					label.putScalar(new int[] { index, 3, c }, (nextData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
					label.putScalar(new int[] { index, 4, c }, (nextData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
				} else {
					label.putScalar(new int[] { index, 0, c }, feedLabel(nextData));
				}
				curData = nextData;
			}
			if (exampleStartOffsets.size() == 0)
				break;
		}
		return new DataSet(input, label);
	}

	private double feedLabel(StockData data) {
		double value;
		switch (category) {
		case OPEN:
			value = (data.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]);
			break;
		case CLOSE:
			value = (data.getClose() - minArray[1]) / (maxArray[1] - minArray[1]);
			break;
		case LOW:
			value = (data.getLow() - minArray[2]) / (maxArray[2] - minArray[2]);
			break;
		case HIGH:
			value = (data.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]);
			break;
		case VOLUME:
			value = (data.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]);
			break;
		default:
			throw new NoSuchElementException();
		}
		return value;
	}

	@Override
	public int totalExamples() {
		return train.size() - exampleLength - predictLength;
	}

	@Override
	public int inputColumns() {
		return VECTOR_SIZE;
	}

	@Override
	public int totalOutcomes() {
		if (this.category.equals(PriceCategory.ALL))
			return VECTOR_SIZE;
		else
			return predictLength;
	}

	@Override
	public boolean resetSupported() {
		return false;
	}

	@Override
	public boolean asyncSupported() {
		return false;
	}

	@Override
	public void reset() {
		initializeOffsets();
	}

	@Override
	public int batch() {
		return miniBatchSize;
	}

	@Override
	public int cursor() {
		return totalExamples() - exampleStartOffsets.size();
	}

	@Override
	public int numExamples() {
		return totalExamples();
	}

	@Override
	public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
		throw new UnsupportedOperationException("Not Implemented");
	}

	@Override
	public DataSetPreProcessor getPreProcessor() {
		throw new UnsupportedOperationException("Not Implemented");
	}

	@Override
	public List<String> getLabels() {
		throw new UnsupportedOperationException("Not Implemented");
	}

	@Override
	public boolean hasNext() {
		return exampleStartOffsets.size() > 0;
	}

	@Override
	public DataSet next() {
		return next(miniBatchSize);
	}

	private List<MyEntry<INDArray, INDArray>> generateTestDataSet(List<StockData> stockDataList) {
		int window = exampleLength + predictLength;
		List<MyEntry<INDArray, INDArray>> test = new ArrayList<>();
		for (int i = 0; i < stockDataList.size() - window; i++) {
			INDArray input = Nd4j.create(new int[] { exampleLength, VECTOR_SIZE }, 'f');
			for (int j = i; j < i + exampleLength; j++) {
				StockData stock = stockDataList.get(j);
				input.putScalar(new int[] { j - i, 0 }, (stock.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
				input.putScalar(new int[] { j - i, 1 }, (stock.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
				input.putScalar(new int[] { j - i, 2 }, (stock.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
				input.putScalar(new int[] { j - i, 3 }, (stock.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
				input.putScalar(new int[] { j - i, 4 }, (stock.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
			}
			StockData stock = stockDataList.get(i + exampleLength);
			INDArray label;
			if (category.equals(PriceCategory.ALL)) {
				label = Nd4j.create(new int[] { VECTOR_SIZE }, 'f'); // ordering is set as 'f', faster construct
				label.putScalar(new int[] { 0 }, stock.getOpen());
				label.putScalar(new int[] { 1 }, stock.getClose());
				label.putScalar(new int[] { 2 }, stock.getLow());
				label.putScalar(new int[] { 3 }, stock.getHigh());
				label.putScalar(new int[] { 4 }, stock.getVolume());
			} else {
				label = Nd4j.create(new int[] { 1 }, 'f');
				switch (category) {
				case OPEN:
					label.putScalar(new int[] { 0 }, stock.getOpen());
					break;
				case CLOSE:
					label.putScalar(new int[] { 0 }, stock.getClose());
					break;
				case LOW:
					label.putScalar(new int[] { 0 }, stock.getLow());
					break;
				case HIGH:
					label.putScalar(new int[] { 0 }, stock.getHigh());
					break;
				case VOLUME:
					label.putScalar(new int[] { 0 }, stock.getVolume());
					break;
				default:
					throw new NoSuchElementException();
				}
			}
			test.add(new MyEntry<>(input, label));
		}
		return test;
	}

//	private List<StockData> readStockDataFromFile (String filename, String symbol) {
//        List<StockData> stockDataList = new ArrayList<>();
//        try {
//            for (int i = 0; i < maxArray.length; i++) { // initialize max and min arrays
//                maxArray[i] = Double.MIN_VALUE;
//                minArray[i] = Double.MAX_VALUE;
//            }
//            List<String[]> list = new CSVReader(new FileReader(filename)).readAll(); // load all elements in a list
//            for (String[] arr : list) {
//                if (!arr[1].equals(symbol)) continue;
//                double[] nums = new double[VECTOR_SIZE];
//                for (int i = 0; i < arr.length - 2; i++) {
//                    nums[i] = Double.valueOf(arr[i + 2]);
//                    if (nums[i] > maxArray[i]) maxArray[i] = nums[i];
//                    if (nums[i] < minArray[i]) minArray[i] = nums[i];
//                }
//                stockDataList.add(new StockData(arr[0], arr[1], nums[0], nums[1], nums[2], nums[3], nums[4]));
//            }
//        } catch (IOException e) {
//            e.printStackTrace();
//        }
//        return stockDataList;
//    }
}

Call site:

	public void predict(BarSeries series, String symbol) {
		System.out.println("Create iterator");
		List<StockDataSetIterator.StockData> stockDataList = new ArrayList<>();
		for (int i = 0; i < series.getBarCount(); i++) {
			Bar bar = series.getBar(i);
			stockDataList.add(new StockDataSetIterator.StockData(String.valueOf(bar.getEndTime().toInstant().toEpochMilli()), symbol, bar.getOpenPrice().doubleValue(), bar.getClosePrice().doubleValue(), bar.getLowPrice().doubleValue(), bar.getHighPrice().doubleValue(), bar.getVolume().doubleValue()));
		}
		StockDataSetIterator iterator = new StockDataSetIterator(stockDataList);
		List<MyEntry<INDArray, INDArray>> testData = iterator.getTestDataSet();

		System.out.println("Build lstm networks");
		final int nIn = iterator.inputColumns();
		final int nOut = iterator.totalOutcomes();
		final int seed = 12345;
		final int iterations = 1;
		final double learningRate = 0.05;
		final int lstmLayer1Size = 256;
		final int lstmLayer2Size = 256;
		final int denseLayerSize = 32;
		final double dropoutRatio = 0.2;
		final int truncatedBPTTLength = 22;
		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations).learningRate(learningRate).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER)
				.updater(Updater.RMSPROP).regularization(true).l2(1e-4).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(lstmLayer1Size).activation(Activation.TANH).gateActivationFunction(Activation.HARDSIGMOID).dropOut(dropoutRatio).build())
				.layer(1, new GravesLSTM.Builder().nIn(lstmLayer1Size).nOut(lstmLayer2Size).activation(Activation.TANH).gateActivationFunction(Activation.HARDSIGMOID).dropOut(dropoutRatio).build())
				.layer(2, new DenseLayer.Builder().nIn(lstmLayer2Size).nOut(denseLayerSize).activation(Activation.RELU).build())
				.layer(3, new RnnOutputLayer.Builder().nIn(denseLayerSize).nOut(nOut).activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build()).backpropType(BackpropType.TruncatedBPTT)
				.tBPTTForwardLength(truncatedBPTTLength).tBPTTBackwardLength(truncatedBPTTLength).pretrain(false).backprop(true).build();

		MultiLayerNetwork net = new MultiLayerNetwork(conf);
		net.init();
		net.setListeners(new ScoreIterationListener(100));

		System.out.println("Training");
		for (int i = 0; i < 20; i++) {
			while (iterator.hasNext())
				net.fit(iterator.next()); // fit model using mini-batch data
			iterator.reset(); // reset iterator
			net.rnnClearPreviousState(); // clear previous state
		}

		System.out.println("predict...");
		final INDArray max = Nd4j.create(iterator.getMaxArray());
		final INDArray min = Nd4j.create(iterator.getMinArray());
		final int exampleLength = 22; // time series length, assume 22 working days per month

		INDArray[] predicts = new INDArray[testData.size()];
		INDArray[] actuals = new INDArray[testData.size()];
		for (int i = 0; i < testData.size(); i++) {
			predicts[i] = net.rnnTimeStep(testData.get(i).getKey()).getRow(exampleLength - 1).mul(max.sub(min)).add(min);
			actuals[i] = testData.get(i).getValue();
		}

		System.out.println("Show prediction");
		BarSeries series2 = new BaseBarSeriesBuilder().withName(symbol + "_Prediction").build();
		for (int i = 0; i < predicts.length; i++) {
			System.out.println(i);
			series2.addBar(series.getBar(i).getEndTime(), predicts[i].getDouble(0), predicts[i].getDouble(3), predicts[i].getDouble(2), predicts[i].getDouble(1), predicts[i].getDouble(4));
		}
		TypicalPriceIndicator priceIndicator = new TypicalPriceIndicator(series2);
		List<Double> scores = new ArrayList<>();
		for (int i = 0; i < series2.getBarCount(); i++) {
			scores.add(priceIndicator.getValue(i).doubleValue());
		}
		SwingUtilities.invokeLater(new Runnable() {
			public void run() {
				JFrame gp = new JFrame(series2.getName());
				gp.add(new GraphPanel(scores));
				gp.setSize(800, 600);
				gp.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
				gp.setVisible(true);
			}
		});
	}

Output:

Create iterator
...
Build lstm networks
53818 [AWT-EventQueue-0] WARN org.reflections.Reflections  - given scan urls are empty. set urls in the configuration
53939 [AWT-EventQueue-0] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork  - Starting MultiLayerNetwork with WorkspaceModes set to [training: NONE; inference: SEPARATE]
Training
55031 [AWT-EventQueue-0] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener  - Score at iteration 0 is NaN
128646 [AWT-EventQueue-0] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener  - Score at iteration 100 is NaN
...
predict...
Show prediction
0
Exception in thread "AWT-EventQueue-0" java.lang.NumberFormatException: Character N is neither a decimal digit number, decimal point, nor "e" notation exponential mark.
	at java.base/java.math.BigDecimal.<init>(BigDecimal.java:519)

The error occurs at the very end, in this line: series2.addBar(series.getBar(i).getEndTime(), predicts[i].getDouble(0), predicts[i].getDouble(3), predicts[i].getDouble(2), predicts[i].getDouble(1), predicts[i].getDouble(4));
Before that, though, the score is already NaN, which is suspicious in itself: - Score at iteration 0 is NaN
This line also looks suspicious to me: predicts[i] = net.rnnTimeStep(testData.get(i).getKey()).getRow(exampleLength - 1).mul(max.sub(min)).add(min);. Where does exampleLength = 22 come from? That value was hard-coded.

I'm using 500 input data rows; the person before me used only about 100. Could that be the cause?

Hehe, I've got it! :smiley: This Hao made one mistake, and I made one. Namely, take a look at this line from next(int num):

label.putScalar(new int[] {index, 0, c}, (nextData.getOpen() - minArray[1]) / (maxArray[1] - minArray[1]));

exactly, the 0-index is missing: getOpen() is scaled with minArray[1]/maxArray[1] instead of minArray[0]/maxArray[0]! (Whoever feels like it is welcome to pass that on to Hao somehow.)

And then something really simple: the minimum and maximum values, whose initialization I lost when commenting out readStockDataFromFile :frowning: !!! :slight_smile: :

	public StockDataSetIterator(List<StockData> stockDataList) {
		for (int i = 0; i < maxArray.length; i++) {
			maxArray[i] = Double.MIN_VALUE;
			minArray[i] = Double.MAX_VALUE;
		}
		for (StockData stockData : stockDataList) {
			double o = stockData.getOpen();
			double c = stockData.getClose();
			double l = stockData.getLow();
			double h = stockData.getHigh();
			double v = stockData.getVolume();
			double[] a = { o, c, l, h, v };
			for (int i = 0; i < a.length; i++) {
				if (maxArray[i] < a[i]) {
					maxArray[i] = a[i];
				}
				if (minArray[i] > a[i]) {
					minArray[i] = a[i];
				}
			}
		}

		this.miniBatchSize = 64;

(Dividing by 0 is always suboptimal.)
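
To keep that from biting again, a cheap sanity check right after the min/max scan would help (my own addition, not from Hao's repo):

// Hypothetical guard: fail fast instead of silently training on NaN
// when a feature is constant (max == min would divide by zero).
for (int i = 0; i < maxArray.length; i++) {
	if (maxArray[i] == minArray[i]) {
		throw new IllegalStateException(
				"Feature " + i + " is constant; min-max scaling would divide by zero");
	}
}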