博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow Lite for Android 初探(附demo)
阅读量:6572 次
发布时间:2019-06-24

本文共 4491 字,大约阅读时间需要 14 分钟。

一. TensorFlow Lite

TensorFlow Lite 是用于移动设备和嵌入式设备的轻量级解决方案。TensorFlow Lite 支持 Android、iOS 甚至树莓派等多种平台。

我们知道大多数的 AI 是在云端运算的,但是在移动端使用 AI 具有无网络延迟、响应更加及时、数据隐私等特性。

对于离线的场合,云端的 AI 就无法使用了,而此时可以在移动设备中使用 TensorFlow Lite。

二. tflite 格式

TensorFlow 生成的模型是无法直接给移动端使用的,需要离线转换成.tflite文件格式。

tflite 存储格式是 flatbuffers。

FlatBuffers 是由Google开源的一个免费软件库,用于实现序列化格式。它类似于Protocol Buffers、Thrift、Apache Avro。

因此,如果要给移动端使用的话,必须把 TensorFlow 训练好的 protobuf 模型文件转换成 FlatBuffers 格式。官方提供了 toco 来实现模型格式的转换。

三. 常用的 Java API

TensorFlow Lite 提供了 C ++ 和 Java 两种类型的 API。无论哪种 API 都需要加载模型和运行模型。

而 TensorFlow Lite 的 Java API 使用了 Interpreter 类(解释器)来完成加载模型和运行模型的任务。后面的例子会看到如何使用 Interpreter。

四. TensorFlow Lite + mnist 数据集实现识别手写数字

mnist 是手写数字图片数据集,包含60000张训练样本和10000张测试样本。 测试集也是同样比例的手写数字数据。每张图片有28x28个像素点构成,每个像素点用一个灰度值表示,这里是将28x28的像素展开为一个一维的行向量(每行784个值)。

mnist 数据集获取地址:

下面的 demo 中已经包含了 mnist.tflite 模型文件。(如果没有的话,需要自己训练保存成pb文件,再转换成tflite 格式)

对于一个识别类,首先需要初始化 TensorFlow Lite 解释器,以及输入、输出。

// The tensorflow lite file    private lateinit var tflite: Interpreter    // Input byte buffer    private lateinit var inputBuffer: ByteBuffer    // Output array [batch_size, 10]    private lateinit var mnistOutput: Array
init { try { tflite = Interpreter(loadModelFile(activity)) inputBuffer = ByteBuffer.allocateDirect( BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE) inputBuffer.order(ByteOrder.nativeOrder()) mnistOutput = Array(DIM_BATCH_SIZE) { FloatArray(NUMBER_LENGTH) } Log.d(TAG, "Created a Tensorflow Lite MNIST Classifier.") } catch (e: IOException) { Log.e(TAG, "IOException loading the tflite file failed.") } }复制代码

从 asserts 文件中加载 mnist.tflite 模型:

/**     * Load the model file from the assets folder     */    @Throws(IOException::class)    private fun loadModelFile(activity: Activity): MappedByteBuffer {        val fileDescriptor = activity.assets.openFd(MODEL_PATH)        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)        val fileChannel = inputStream.channel        val startOffset = fileDescriptor.startOffset        val declaredLength = fileDescriptor.declaredLength        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)    }复制代码

真正识别手写数字是在 classify() 方法:

val digit = mnistClassifier.classify(Bitmap.createScaledBitmap(paintView.bitmap, PIXEL_WIDTH, PIXEL_WIDTH, false))复制代码

classify() 方法包含了预处理用于初始化 inputBuffer、运行 mnist 模型、识别出数字。

/**     * Classifies the number with the mnist model.     *     * @param bitmap     * @return the identified number     */    fun classify(bitmap: Bitmap): Int {        if (tflite == null) {            Log.e(TAG, "Image classifier has not been initialized; Skipped.")        }        preProcess(bitmap)        runModel()        return postProcess()    }    /**     * Converts it into the Byte Buffer to feed into the model     *     * @param bitmap     */    private fun preProcess(bitmap: Bitmap?) {        if (bitmap == null || inputBuffer == null) {            return        }        // Reset the image data        inputBuffer.rewind()        val width = bitmap.width        val height = bitmap.height        // The bitmap shape should be 28 x 28        val pixels = IntArray(width * height)        bitmap.getPixels(pixels, 0, width, 0, 0, width, height)        for (i in pixels.indices) {            // Set 0 for white and 255 for black pixels            val pixel = pixels[i]            // The color of the input is black so the blue channel will be 0xFF.            val channel = pixel and 0xff            inputBuffer.putFloat((0xff - channel).toFloat())        }    }    /**     * Run the TFLite model     */    private fun runModel() = tflite.run(inputBuffer, mnistOutput)    /**     * Go through the output and find the number that was identified.     *     * @return the number that was identified (returns -1 if one wasn't found)     */    private fun postProcess(): Int {        for (i in 0 until mnistOutput[0].size) {            val value = mnistOutput[0][i]            if (value == 1f) {                return i            }        }        return -1    }复制代码

对于 Android 有一个地方需要注意,必须在 app 模块的 build.gradle 中添加如下的语句,否则无法加载模型。

android {    ......    aaptOptions {        noCompress "tflite"    }}复制代码

demo 运行效果如下:

五. 总结

本文只是 TF Lite 的初探,很多细节并没有详细阐述。应该会在未来的文章中详细介绍。

本文 demo 的 github 地址:

当然,也可以跑一下官方的例子:


Java与Android技术栈:每周更新推送原创技术文章,欢迎扫描下方的公众号二维码并关注,期待与您的共同成长和进步。

转载地址:http://leljo.baihongyu.com/

你可能感兴趣的文章
C 小白的 thrift 环境搭建
查看>>
php闭包使用例子
查看>>
虚拟机+centOS挂载ISO步骤
查看>>
java 如何查看jdk版本&位数
查看>>
JAVA中字符串的startWith什么意思
查看>>
Deepin 系统下安装VMware并激活
查看>>
ms12_004漏洞进行渗透
查看>>
spring mvc: xml练习
查看>>
QT-提示“database not open”
查看>>
Linux常用基本命令:三剑客命令之-awk内置函数用法
查看>>
【Mac brew】代理安装brew insall
查看>>
Nginx 项目部署和配置
查看>>
laravel validate 设置为中文(验证提示为中文)
查看>>
1. ansible-playbook 变量定义与引用
查看>>
OkHttp3源码详解(五) okhttp连接池复用机制
查看>>
SQL SERVER使用ODBC 驱动建立的链接服务器调用存储过程时参数不能为NULL值
查看>>
CSS3之超出隐藏
查看>>
通用Web后台魔方NewLife.Cube
查看>>
java 泛型详解-绝对是对泛型方法讲解最详细的,没有之一
查看>>
Windows7下安装配置PostgreSQL10
查看>>