请教:在安卓 tf lite 使用训练好的模型,出现 DataType error

2020-06-11 21:57:43 +08:00
 a421

原本的目的是移植一个模型到安卓,遇到问题后,重新做了个简单的模型验证,出现同样的问题。

python 训练的代码


model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')

xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)
ys = np.array([-3.0, -1.0, 0.0, 3.0, 5.0, 7.0], dtype=np.float32)

model.fit(xs, ys, epochs=500)
keras_file = 'linear.h5'
keras.models.save_model(model, keras_file)

转换成 .tflite 后,在安卓使用

Interpreter interpreter = new Interpreter(FileUtil.loadMappedFile(activity, "linear.tflite"));
interpreter.allocateTensors();
int probabilityTensorIndex = 0;
int[] probabilityShape =
        interpreter.getOutputTensor(probabilityTensorIndex).shape(); //
DataType probabilityDataType = interpreter.getOutputTensor(probabilityTensorIndex).dataType();
TensorBuffer outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);

int inputTensorIndex = 0;
DataType inputDataType = interpreter.getInputTensor(inputTensorIndex).dataType();
int[] inputShape = interpreter.getInputTensor(inputTensorIndex).shape();
TensorBuffer inputBuffer = TensorBuffer.createFixedSize(inputShape, inputDataType);
final float[] input = {10};
inputBuffer.loadArray(input);

interpreter.run(inputBuffer, outputProbabilityBuffer);

报错是

I/tflite: Initialized TensorFlow Lite runtime.
E/AndroidRuntime: FATAL EXCEPTION: inference
    Process: com.example.my1application, PID: 26839
    java.lang.IllegalArgumentException: DataType error: cannot resolve DataType of org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat
        at org.tensorflow.lite.Tensor.dataTypeOf(Tensor.java:344)
        at org.tensorflow.lite.Tensor.throwIfTypeIsIncompatible(Tensor.java:397)
        at org.tensorflow.lite.Tensor.getInputShapeIfDifferent(Tensor.java:287)
        at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:137)
        at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:316)
        at org.tensorflow.lite.Interpreter.run(Interpreter.java:277)
        at com.example.my1application.DisplayMessageActivity$1.run(DisplayMessageActivity.java:114)
        at android.os.Handler.handleCallback(Handler.java:815)
        at android.os.Handler.dispatchMessage(Handler.java:104)
        at android.os.Looper.loop(Looper.java:207)
        at android.os.HandlerThread.run(HandlerThread.java:61)
1465 次点击
所在节点    机器学习
0 条回复

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/680821

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX