diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java
index e54fcfee16c..93abc616b98 100644
--- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java
+++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java
@@ -43,9 +43,11 @@
import android.view.ViewTreeObserver;
import android.view.WindowManager;
import android.widget.AdapterView;
+import android.widget.CompoundButton;
import android.widget.ImageView;
import android.widget.LinearLayout;
import android.widget.Spinner;
+import android.widget.Switch;
import android.widget.TextView;
import android.widget.Toast;
import com.google.android.material.bottomsheet.BottomSheetBehavior;
@@ -61,7 +63,7 @@ public abstract class CameraActivity extends AppCompatActivity
implements OnImageAvailableListener,
Camera.PreviewCallback,
View.OnClickListener,
- AdapterView.OnItemSelectedListener {
+ AdapterView.OnItemSelectedListener, Switch.OnCheckedChangeListener {
private static final Logger LOGGER = new Logger();
private static final int PERMISSIONS_REQUEST = 1;
@@ -97,10 +99,12 @@ public abstract class CameraActivity extends AppCompatActivity
private Spinner modelSpinner;
private Spinner deviceSpinner;
private TextView threadsTextView;
+ private Switch fp16Switch;
private Model model = Model.QUANTIZED_EFFICIENTNET;
private Device device = Device.CPU;
private int numThreads = -1;
+ private boolean allowFP16 = false;
@Override
protected void onCreate(final Bundle savedInstanceState) {
@@ -116,6 +120,7 @@ protected void onCreate(final Bundle savedInstanceState) {
requestPermission();
}
+ fp16Switch = findViewById(R.id.fp16);
threadsTextView = findViewById(R.id.threads);
plusImageView = findViewById(R.id.plus);
minusImageView = findViewById(R.id.minus);
@@ -192,9 +197,13 @@ public void onSlide(@NonNull View bottomSheet, float slideOffset) {}
plusImageView.setOnClickListener(this);
minusImageView.setOnClickListener(this);
+ fp16Switch.setOnCheckedChangeListener(this);
+
model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase());
device = Device.valueOf(deviceSpinner.getSelectedItem().toString());
numThreads = Integer.parseInt(threadsTextView.getText().toString().trim());
+ allowFP16 = fp16Switch.isChecked();
+ fp16Switch.setEnabled(false);
}
protected int[] getRgbBytes() {
@@ -577,6 +586,7 @@ private void setModel(Model model) {
if (this.model != model) {
LOGGER.d("Updating model: " + model);
this.model = model;
+ fp16Switch.setEnabled(((model == Model.FLOAT_MOBILENET) || (model == Model.FLOAT_EFFICIENTNET)) && (device == Device.NNAPI));
onInferenceConfigurationChanged();
}
}
@@ -593,6 +603,7 @@ private void setDevice(Device device) {
plusImageView.setEnabled(threadsEnabled);
minusImageView.setEnabled(threadsEnabled);
threadsTextView.setText(threadsEnabled ? String.valueOf(numThreads) : "N/A");
+ fp16Switch.setEnabled(((model == Model.FLOAT_MOBILENET) || (model == Model.FLOAT_EFFICIENTNET)) && (device == Device.NNAPI));
onInferenceConfigurationChanged();
}
}
@@ -609,6 +620,8 @@ private void setNumThreads(int numThreads) {
}
}
+ protected boolean getFP16() { return allowFP16; }
+
protected abstract void processImage();
protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
@@ -651,4 +664,10 @@ public void onItemSelected(AdapterView> parent, View view, int pos, long id) {
public void onNothingSelected(AdapterView> parent) {
// Do nothing.
}
+
+ @Override
+ public void onCheckedChanged(CompoundButton switchView, boolean isChecked) {
+ allowFP16 = isChecked;
+ onInferenceConfigurationChanged();
+ }
}
diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java
index 0dd00e9d25c..188f7ad55ef 100644
--- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java
+++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java
@@ -64,7 +64,7 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
- recreateClassifier(getModel(), getDevice(), getNumThreads());
+ recreateClassifier(getModel(), getDevice(), getNumThreads(), getFP16());
if (classifier == null) {
LOGGER.e("No classifier on preview!");
return;
@@ -83,6 +83,8 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {
@Override
protected void processImage() {
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
+ final int imageSizeX = (classifier == null) ? 224 : classifier.getImageSizeX();
+ final int imageSizeY = (classifier == null) ? 224 : classifier.getImageSizeY();
final int cropSize = Math.min(previewWidth, previewHeight);
runInBackground(
@@ -123,10 +125,11 @@ protected void onInferenceConfigurationChanged() {
final Device device = getDevice();
final Model model = getModel();
final int numThreads = getNumThreads();
- runInBackground(() -> recreateClassifier(model, device, numThreads));
+ final boolean fp16 = getFP16();
+ runInBackground(() -> recreateClassifier(model, device, numThreads, fp16));
}
- private void recreateClassifier(Model model, Device device, int numThreads) {
+ private void recreateClassifier(Model model, Device device, int numThreads, boolean fp16) {
if (classifier != null) {
LOGGER.d("Closing classifier.");
classifier.close();
@@ -143,8 +146,8 @@ private void recreateClassifier(Model model, Device device, int numThreads) {
}
try {
LOGGER.d(
- "Creating classifier (model=%s, device=%s, numThreads=%d)", model, device, numThreads);
- classifier = Classifier.create(this, model, device, numThreads);
+ "Creating classifier (model=%s, device=%s, numThreads=%d, allowFP16=%b)", model, device, numThreads, fp16);
+ classifier = Classifier.create(this, model, device, numThreads, fp16);
} catch (IOException e) {
LOGGER.e(e, "Failed to create classifier.");
}
diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java
index c88d8630bb0..774f97a1a60 100644
--- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java
+++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java
@@ -107,16 +107,17 @@ public enum Device {
* @param model The model to use for classification.
* @param device The device to use for classification.
* @param numThreads The number of threads to use for classification.
+ * @param fp16 Allow FP32 model to run on FP16 accelerators
* @return A classifier with the desired configuration.
*/
- public static Classifier create(Activity activity, Model model, Device device, int numThreads)
+ public static Classifier create(Activity activity, Model model, Device device, int numThreads, boolean fp16)
throws IOException {
if (model == Model.QUANTIZED_MOBILENET) {
return new ClassifierQuantizedMobileNet(activity, device, numThreads);
} else if (model == Model.FLOAT_MOBILENET) {
- return new ClassifierFloatMobileNet(activity, device, numThreads);
+ return new ClassifierFloatMobileNet(activity, device, numThreads, fp16);
} else if (model == Model.FLOAT_EFFICIENTNET) {
- return new ClassifierFloatEfficientNet(activity, device, numThreads);
+ return new ClassifierFloatEfficientNet(activity, device, numThreads, fp16);
} else if (model == Model.QUANTIZED_EFFICIENTNET) {
return new ClassifierQuantizedEfficientNet(activity, device, numThreads);
} else {
@@ -195,12 +196,13 @@ public String toString() {
}
/** Initializes a {@code Classifier}. */
- protected Classifier(Activity activity, Device device, int numThreads) throws IOException {
+ protected Classifier(Activity activity, Device device, int numThreads, boolean fp16) throws IOException {
tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
switch (device) {
case NNAPI:
nnApiDelegate = new NnApiDelegate();
tfliteOptions.addDelegate(nnApiDelegate);
+ tfliteOptions.setAllowFp16PrecisionForFp32(fp16);
break;
case GPU:
gpuDelegate = new GpuDelegate();
diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java
index a48bd823384..f84f023b65d 100644
--- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java
+++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java
@@ -40,9 +40,9 @@ public class ClassifierFloatEfficientNet extends Classifier {
*
* @param activity
*/
- public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads)
+ public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads, boolean fp16)
throws IOException {
- super(activity, device, numThreads);
+ super(activity, device, numThreads, fp16);
}
@Override
diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java
index dd3b6aea9c9..a4fece68aed 100644
--- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java
+++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java
@@ -42,9 +42,9 @@ public class ClassifierFloatMobileNet extends Classifier {
*
* @param activity
*/
- public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads)
+ public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads, boolean fp16)
throws IOException {
- super(activity, device, numThreads);
+ super(activity, device, numThreads, fp16);
}
@Override
diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java
index f84313f6e28..032bec7e7e2 100644
--- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java
+++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java
@@ -43,7 +43,7 @@ public class ClassifierQuantizedEfficientNet extends Classifier {
*/
public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads)
throws IOException {
- super(activity, device, numThreads);
+ super(activity, device, numThreads, /*fp16=*/ false);
}
@Override
diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java
index 5f18f79c956..1d9c02176f3 100644
--- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java
+++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java
@@ -44,7 +44,7 @@ public class ClassifierQuantizedMobileNet extends Classifier {
*/
public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads)
throws IOException {
- super(activity, device, numThreads);
+ super(activity, device, numThreads, /*fp16=*/ false);
}
@Override
diff --git a/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml b/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml
index 601b4a45d93..f4d62f21c0b 100644
--- a/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml
+++ b/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml
@@ -224,6 +224,12 @@
android:layout_marginTop="10dp"
android:background="@android:color/darker_gray" />
+
+