TensorFlow mit Java

Das hier is von einem anderen Thema abgetrennt, in dem es ursprünglich nur um einen Rant gegen Protocol Buffers ging - das Thema TensorFlow könnte aber von allgemeinerem Interesse sein :wink:


Tensorflow an sich aus Java zu verwenden ist mit/dank Maven erfrischend trivial: Dependency rein und los geht’s.

Um rauszufinden, was so eine Zeile wie

g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();

macht, reicht die JavaDoc aber nicht: Ist „Const“ einfach ein Name oder ein „reservierter Bezeichner“? Welche „attr“'s gibt es? Hier offenbar „dtype“ und „value“. Gibt es die immer? Welche gibt es noch? Welche Werte dürfen/müssen die haben? Das geht aus der oben verlinkten Datei hervor:

name: "Const"
output_arg {
  name: "output"
  type_attr: "dtype"
}
attr {
  name: "value"
  type: "tensor"
  description: "Attr `value` is the tensor to return."
}
attr {
  name: "dtype"
  type: "type"
}
summary: "Returns a constant tensor."

Bisher macht mein Generator aus diesen Dingern jetzt nur sowas wie

    /**
     * Gradient op for <code>MirrorPad</code> op. This op folds a mirror-padded tensor.<p>This operation folds the padded areas of <code>input</code> by <code>MirrorPad</code> according to the
     * <code>paddings</code> you specify. <code>paddings</code> must be the same as <code>paddings</code> argument
     * given to the corresponding <code>MirrorPad</code> op.</p>
     * <p>The folded size of each dimension D of the output is:</p>
     * <p><code>input.dim_size(D) - paddings(D, 0) - paddings(D, 1)</code></p>
     * <p>For example:</p>
     * <pre><code># 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
     * # 'paddings' is [[0, 1]], [0, 1]].
     * # 'mode' is SYMMETRIC.
     * # rank of 't' is 2.
     * pad(t, paddings) ==&gt; [[ 1,  5]
     *                       [11, 28]]
     * </code></pre>
     * 
     * @param input
     *     The input tensor to be folded.
     * @param paddings
     *     A two-column matrix specifying the padding sizes. The number of
     *     rows must be the same as the rank of <code>input</code>.
     */
    public<T >Output<T> mirrorPadGrad(Output<T> input, Output<T> paddings) {
        return null;
    }

(also ziemlich leer…). Mal schauen, ob da was vernünftiges rauskommen kann…

Das sieht doch schonmal ziemlich gut aus. Aus dem return null, noch was passendes zu machen dürfte auch zu schaffen sein, aber doch sehr aufwendig.
Aber ich verstehe jetzt was du meinst. Ja, es wäre schön von Google gewesen, da gleich eine richtige API auszuliefern, die etwa so aussieht, wie das, was du da generiert hast. Das wäre verwendbarer und hätte den Konzern wohl auch nicht in die Insolvenz getrieben. Ich muss schon schmunzeln, wenn ich mir vorstelle Code zu sehen der aus zig opBuilderaufrufen besteht um damit dann irgend ein Model zu trainieren.

Hab auch mal geschaut ob das für Java schonmal jemand in “Schön” gebaut hat, aber das was dem ganzen am nächsten kommt ist in clojure unter dem namespace helpers zu finden. https://github.com/kieranbrowne/clojure-tensorflow-interop/blob/master/src/clojure_tensorflow_interop/helpers.clj
Und wohl auch nicht vollständig oder gar kommentiert.

Naja, das ist eher „Spielerei“. Die Methoden werden jetzt zwar schon gefüllt,…

public Output<Float> fakeQuantWithMinMaxArgsGradient(Output<Float> gradients, Output<Float> inputs) {
    OperationBuilder operationBuilder = graph.opBuilder("FakeQuantWithMinMaxArgsGradient", "FakeQuantWithMinMaxArgsGradient");
    operationBuilder = operationBuilder.addInput(gradients);
    operationBuilder = operationBuilder.addInput(inputs);
    Operation operation = operationBuilder.build();
    Object result = operation.output(0);
    return ((Output<Float> ) result);
}

…aber ich kenne die Konzepte von TensorFlow (noch?) nicht genau genug, um das was sinnvolles machen zu können.

Intern arbeiten die (also Google und Contributors) (natürlich) auch schon an was besserem, das wird um [Java][Feature] Generating operation methods used to build a graph · Issue #7149 · tensorflow/tensorflow · GitHub herum diskutiert. Den Code kann man sich schon ansehen. Das sieht zumindest übersichtlicher und Objektorientierter aus, als irgendwelche anonymen „ops“ mit Strings, von denen keiner weiß, wo sie herkommen und was sie bedeuten.

Das ganze liegt aber in einem Branch, der gerade nicht sooo aktiv aussieht. Dazu kommt, dass die wohl was mit einem Tool namens „Bazel“ und irgendwelchen Annotation-Processors machen… da müßte man sich wohl länger reinfräsen…

Das Clojure-Ding … ja (auch seit ein paar Monaten still, und) sieht recht unvollständig aus. (Außerdem tun mir bei Clojure ja immer noch die Augen ein bißchen weh :wink: ). Natürlich gibt’s noch anderes, sowas wie das um tensorframes/Ops.scala at master · databricks/tensorframes · GitHub herum für Scala (die sind anscheinend auch diesen Protobuffer-Trampelpfad runtergelaufen.

Insgesamt scheint TensorFlow schon *räusper* „sehr komplex“ zu sein, und sowohl technisch (in bezug auf die API, mit den Buildern und den magischen Strings) als auch in bezug auf die Konzepte wird man da einiges an Einarbeitung brauchen.