Skip to content

Commit

Permalink
Add extensions-tensorflow-lite #24
Browse files Browse the repository at this point in the history
  • Loading branch information
kuloud committed Dec 27, 2019
1 parent 4ff67b7 commit c861774
Show file tree
Hide file tree
Showing 18 changed files with 406 additions and 256 deletions.
5 changes: 2 additions & 3 deletions Android/examples/demo/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,14 @@ dependencies {
implementation deps.aoe.runtime.mnn
implementation deps.aoe.runtime.ncnn


implementation deps.gson

implementation deps.kotlin

implementation 'com.didi.aoe:extensions-support:1.1.1.1'

implementation 'org.tensorflow:tensorflow-lite:2.0.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.0.0'

implementation deps.aoe.extensions.tensorflow
implementation deps.aoe.extensions.pytorch

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import com.didi.aoe.runtime.tensorflow.lite.TensorFlowLiteInterpreter;
import com.didi.aoe.runtime.tensorflow.lite.TensorFlowInterpreter;

/**
* @author noctis
*/
public class MnistInterpreter extends TensorFlowLiteInterpreter<float[], Integer, float[], float[][]> {
public class MnistInterpreter extends TensorFlowInterpreter<float[], Integer, float[], float[][]> {

@Nullable
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
package com.didi.aoe.features.pytorch

import android.graphics.Bitmap
import com.didi.aoe.pytorch.PytorchConvertor
import com.didi.aoe.runtime.pytorch.PyTorchInterpreter
import com.didi.aoe.extensions.pytorch.PytorchConvertor
import org.pytorch.Tensor
import org.pytorch.torchvision.TensorImageUtils

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ class SqueezeInterpreter :
val bmpBuffer = ByteBuffer.allocate(size)
input.copyPixelsToBuffer(bmpBuffer)
val rgba = bmpBuffer.array()
squeeze!!.inputRgba(rgba, input.width, input.height, INPUT_WIDTH,
squeeze?.inputRgba(rgba, input.width, input.height, INPUT_WIDTH,
INPUT_HEIGHT,
meanVals, norVals, 0)
val buffer = ByteBuffer.allocate(4096)
squeeze!!.run(null, buffer)
squeeze?.run(null, buffer)
buffer.order(ByteOrder.nativeOrder())
buffer.flip()
val shape = squeeze!!.getOutputTensor(0).shape()
Expand Down
2 changes: 1 addition & 1 deletion Android/extensions/pytorch/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
-->

<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.didi.aoe.pytorch" />
package="com.didi.aoe.extensions.pytorch" />
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package com.didi.aoe.pytorch
package com.didi.aoe.extensions.pytorch

import android.content.Context
import com.didi.aoe.library.api.Aoe
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package com.didi.aoe.pytorch
package com.didi.aoe.extensions.pytorch

import com.didi.aoe.runtime.pytorch.PyTorchInterpreter
import org.pytorch.Tensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package com.didi.aoe.pytorch
package com.didi.aoe.extensions.pytorch

import com.didi.aoe.library.api.convertor.Convertor
import org.pytorch.Tensor
Expand Down
1 change: 1 addition & 0 deletions Android/extensions/tensorflow-lite/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/build
60 changes: 60 additions & 0 deletions Android/extensions/tensorflow-lite/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2019 The AoE Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

apply plugin: 'com.android.library'
apply plugin: 'kotlin-android'


ext {
releaseArtifact = 'extensions-tensorflow-lite'
releaseDescription = 'The AoE tensorflow lite extensions library'
releaseVersion = aoe_version_name
}
apply from: rootProject.file('gradle/release.gradle')

android {
compileSdkVersion aoe_compile_sdk_version
defaultConfig {
minSdkVersion aoe_min_sdk_version
targetSdkVersion aoe_target_sdk_version
versionName releaseVersion
}

buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}

}

dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])

implementation deps.support.annotation

implementation deps.aoe.library.core

implementation deps.kotlin

implementation 'org.tensorflow:tensorflow-lite:2.0.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.0.0'

implementation deps.aoe.runtime.tensorflow

}

Empty file.
21 changes: 21 additions & 0 deletions Android/extensions/tensorflow-lite/proguard-rules.pro
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
18 changes: 18 additions & 0 deletions Android/extensions/tensorflow-lite/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<!--
~ Copyright 2019 The AoE Authors. All Rights Reserved.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.didi.aoe.extensions.tensorflow.lite" />
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2019 The AoE Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.didi.aoe.extensions.tensorflow.lite

import android.content.Context
import com.didi.aoe.library.api.Aoe
import com.didi.aoe.library.core.AoeClient
import com.didi.aoe.runtime.tensorflow.lite.TensorFlowMultipleInputsOutputsInterpreter
import org.tensorflow.lite.gpu.GpuDelegate

/**
*
*
* @author noctis
* @since 1.1.0
*/

fun Aoe.Companion.createAoeClient(context: Context, modelPath: String,
convertor: TensorFlowMultipleInputsOutputsInterpreter<*, *, *, *>, useGpu: Boolean): AoeClient {
if (useGpu) {
convertor.addDelegate(GpuDelegate())
}
val client = AoeClient(context, convertor, modelPath)
return client
}
3 changes: 2 additions & 1 deletion Android/global_config.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ ext {
mnn : isDebug() ? project(':runtime-mnn') : "com.didi.aoe:runtime-mnn:$aoe_version_name",
],
extensions: [
pytorch: isDebug() ? project(':extensions-pytorch') : "com.didi.aoe:extensions-pytorch:$aoe_version_name"
pytorch : isDebug() ? project(':extensions-pytorch') : "com.didi.aoe:extensions-pytorch:$aoe_version_name",
tensorflow: isDebug() ? project(':extensions-tensorflow-lite') : "com.didi.aoe:extensions-tensorflow-lite:$aoe_version_name",
]
],
// aoe 计划后续实现全量使用 kotlin,注解依赖暂时使用较低的 26.x 以保证最小的适配成本
Expand Down
3 changes: 1 addition & 2 deletions Android/settings.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
//include ':aoe-pytorch'

module('examples', 'demo')

module('library', 'core')
Expand All @@ -14,6 +12,7 @@ module('runtime', 'ncnn')
module('runtime', 'pytorch')

module('extensions', 'pytorch')
module('extensions', 'tensorflow-lite')

// -------------------------------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,40 +1,40 @@
package com.didi.aoe.runtime.tensorflow.lite;

import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import com.didi.aoe.library.api.convertor.Convertor;

import java.util.Map;

/**
* 基于TensorFlow Lite的运行时Interpreter封装,用于单输入,单输出的常见场景。多路输入的场景不要继承这个类,继承
* 它的父类TensorFlowLiteMultipleInputsOutputsInterpreter,实现preProcessMulti和postProcessMulti即可。
*
* @param <TInput> 范型,业务输入数据
* @param <TOutput> 范型,业务输出数据
* @param <TModelInput> 范型,模型输入数据
* @param <TModelOutput> 范型,模型输出数据
* @author noctis
*/
public abstract class TensorFlowLiteInterpreter<TInput, TOutput, TModelInput, TModelOutput> extends
TensorFlowLiteMultipleInputsOutputsInterpreter<TInput, TOutput, Object, TModelOutput> implements
Convertor<TInput, TOutput, TModelInput, TModelOutput> {

@Nullable
@Override
public final Object[] preProcessMulti(@NonNull TInput tInput) {
Object[] inputs = new Object[1];
inputs[0] = preProcess(tInput);
return inputs;
}

@Nullable
@Override
public final TOutput postProcessMulti(@Nullable Map<Integer, TModelOutput> modelOutput) {
if (modelOutput != null && !modelOutput.isEmpty()) {
return postProcess(modelOutput.get(0));
}
return null;
}

}
package com.didi.aoe.runtime.tensorflow.lite;

import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import com.didi.aoe.library.api.convertor.Convertor;

import java.util.Map;

/**
* 基于TensorFlow Lite的运行时Interpreter封装,用于单输入,单输出的常见场景。多路输入的场景不要继承这个类,继承
* 它的父类TensorFlowLiteMultipleInputsOutputsInterpreter,实现preProcessMulti和postProcessMulti即可。
*
* @param <TInput> 范型,业务输入数据
* @param <TOutput> 范型,业务输出数据
* @param <TModelInput> 范型,模型输入数据
* @param <TModelOutput> 范型,模型输出数据
* @author noctis
*/
public abstract class TensorFlowInterpreter<TInput, TOutput, TModelInput, TModelOutput> extends
TensorFlowMultipleInputsOutputsInterpreter<TInput, TOutput, Object, TModelOutput> implements
Convertor<TInput, TOutput, TModelInput, TModelOutput> {

@Nullable
@Override
public final Object[] preProcessMulti(@NonNull TInput tInput) {
Object[] inputs = new Object[1];
inputs[0] = preProcess(tInput);
return inputs;
}

@Nullable
@Override
public final TOutput postProcessMulti(@Nullable Map<Integer, TModelOutput> modelOutput) {
if (modelOutput != null && !modelOutput.isEmpty()) {
return postProcess(modelOutput.get(0));
}
return null;
}

}
Loading

0 comments on commit c861774

Please sign in to comment.