From 657789bd687725bc20af46f42388e724ead6621f Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Tue, 22 Oct 2019 10:30:11 +0800 Subject: [PATCH 1/4] Add fp16 option for NNAPI Some NNAPI accelerators are fp16 only. Add an option to allow fp32 on fp16 accelerators. --- .../classification/CameraActivity.java | 18 +++++++++++++++++- .../classification/ClassifierActivity.java | 11 ++++++----- .../classification/tflite/Classifier.java | 9 +++++---- .../tflite/ClassifierFloatEfficientNet.java | 4 ++-- .../tflite/ClassifierFloatMobileNet.java | 4 ++-- .../ClassifierQuantizedEfficientNet.java | 2 +- .../tflite/ClassifierQuantizedMobileNet.java | 2 +- .../res/layout/tfe_ic_layout_bottom_sheet.xml | 6 ++++++ 8 files changed, 40 insertions(+), 16 deletions(-) diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java index e54fcfee16c..3c3e0798481 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java @@ -43,9 +43,11 @@ import android.view.ViewTreeObserver; import android.view.WindowManager; import android.widget.AdapterView; +import android.widget.CompoundButton; import android.widget.ImageView; import android.widget.LinearLayout; import android.widget.Spinner; +import android.widget.Switch; import android.widget.TextView; import android.widget.Toast; import com.google.android.material.bottomsheet.BottomSheetBehavior; @@ -61,7 +63,7 @@ public abstract class CameraActivity extends AppCompatActivity implements OnImageAvailableListener, Camera.PreviewCallback, View.OnClickListener, - AdapterView.OnItemSelectedListener { + AdapterView.OnItemSelectedListener, Switch.OnCheckedChangeListener { private static final Logger LOGGER = new Logger(); private static final int PERMISSIONS_REQUEST = 1; @@ -97,10 +99,12 @@ public abstract class CameraActivity extends AppCompatActivity private Spinner modelSpinner; private Spinner deviceSpinner; private TextView threadsTextView; + private Switch fp16Switch; private Model model = Model.QUANTIZED_EFFICIENTNET; private Device device = Device.CPU; private int numThreads = -1; + private boolean allowFP16 = false; @Override protected void onCreate(final Bundle savedInstanceState) { @@ -116,6 +120,7 @@ protected void onCreate(final Bundle savedInstanceState) { requestPermission(); } + fp16Switch = findViewById(R.id.fp16); threadsTextView = findViewById(R.id.threads); plusImageView = findViewById(R.id.plus); minusImageView = findViewById(R.id.minus); @@ -192,9 +197,12 @@ public void onSlide(@NonNull View bottomSheet, float slideOffset) {} plusImageView.setOnClickListener(this); minusImageView.setOnClickListener(this); + fp16Switch.setOnCheckedChangeListener(this); + model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); + allowFP16 = fp16Switch.isChecked(); } protected int[] getRgbBytes() { @@ -609,6 +617,8 @@ private void setNumThreads(int numThreads) { } } + protected boolean getFP16() { return allowFP16; } + protected abstract void processImage(); protected abstract void onPreviewSizeChosen(final Size size, final int rotation); @@ -651,4 +661,10 @@ public void onItemSelected(AdapterView parent, View view, int pos, long id) { public void onNothingSelected(AdapterView parent) { // Do nothing. } + + @Override + public void onCheckedChanged(CompoundButton switchView, boolean isChecked) { + allowFP16 = isChecked; + onInferenceConfigurationChanged(); + } } diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java index 0dd00e9d25c..28a0a2e6fc1 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java @@ -64,7 +64,7 @@ public void onPreviewSizeChosen(final Size size, final int rotation) { borderedText = new BorderedText(textSizePx); borderedText.setTypeface(Typeface.MONOSPACE); - recreateClassifier(getModel(), getDevice(), getNumThreads()); + recreateClassifier(getModel(), getDevice(), getNumThreads(), getFP16()); if (classifier == null) { LOGGER.e("No classifier on preview!"); return; @@ -123,10 +123,11 @@ protected void onInferenceConfigurationChanged() { final Device device = getDevice(); final Model model = getModel(); final int numThreads = getNumThreads(); - runInBackground(() -> recreateClassifier(model, device, numThreads)); + final boolean fp16 = getFP16(); + runInBackground(() -> recreateClassifier(model, device, numThreads, fp16)); } - private void recreateClassifier(Model model, Device device, int numThreads) { + private void recreateClassifier(Model model, Device device, int numThreads, boolean fp16) { if (classifier != null) { LOGGER.d("Closing classifier."); classifier.close(); @@ -143,8 +144,8 @@ private void recreateClassifier(Model model, Device device, int numThreads) { } try { LOGGER.d( - "Creating classifier (model=%s, device=%s, numThreads=%d)", model, device, numThreads); - classifier = Classifier.create(this, model, device, numThreads); + "Creating classifier (model=%s, device=%s, numThreads=%d, allowFP16=%b)", model, device, numThreads, fp16); + classifier = Classifier.create(this, model, device, numThreads, fp16); } catch (IOException e) { LOGGER.e(e, "Failed to create classifier."); } diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java index c88d8630bb0..83d97c2b03e 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -109,14 +109,14 @@ public enum Device { * @param numThreads The number of threads to use for classification. * @return A classifier with the desired configuration. */ - public static Classifier create(Activity activity, Model model, Device device, int numThreads) + public static Classifier create(Activity activity, Model model, Device device, int numThreads, boolean fp16) throws IOException { if (model == Model.QUANTIZED_MOBILENET) { return new ClassifierQuantizedMobileNet(activity, device, numThreads); } else if (model == Model.FLOAT_MOBILENET) { - return new ClassifierFloatMobileNet(activity, device, numThreads); + return new ClassifierFloatMobileNet(activity, device, numThreads, fp16); } else if (model == Model.FLOAT_EFFICIENTNET) { - return new ClassifierFloatEfficientNet(activity, device, numThreads); + return new ClassifierFloatEfficientNet(activity, device, numThreads, fp16); } else if (model == Model.QUANTIZED_EFFICIENTNET) { return new ClassifierQuantizedEfficientNet(activity, device, numThreads); } else { @@ -195,12 +195,13 @@ public String toString() { } /** Initializes a {@code Classifier}. */ - protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + protected Classifier(Activity activity, Device device, int numThreads, boolean fp16) throws IOException { tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); switch (device) { case NNAPI: nnApiDelegate = new NnApiDelegate(); tfliteOptions.addDelegate(nnApiDelegate); + tfliteOptions.setAllowFp16PrecisionForFp32(fp16); break; case GPU: gpuDelegate = new GpuDelegate(); diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java index a48bd823384..f84f023b65d 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -40,9 +40,9 @@ public class ClassifierFloatEfficientNet extends Classifier { * * @param activity */ - public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads, boolean fp16) throws IOException { - super(activity, device, numThreads); + super(activity, device, numThreads, fp16); } @Override diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java index dd3b6aea9c9..a4fece68aed 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -42,9 +42,9 @@ public class ClassifierFloatMobileNet extends Classifier { * * @param activity */ - public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads, boolean fp16) throws IOException { - super(activity, device, numThreads); + super(activity, device, numThreads, fp16); } @Override diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java index f84313f6e28..032bec7e7e2 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -43,7 +43,7 @@ public class ClassifierQuantizedEfficientNet extends Classifier { */ public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) throws IOException { - super(activity, device, numThreads); + super(activity, device, numThreads, /*fp16=*/ false); } @Override diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java index 5f18f79c956..e68f3798dd6 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -44,7 +44,7 @@ public class ClassifierQuantizedMobileNet extends Classifier { */ public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) throws IOException { - super(activity, device, numThreads); + super(activity, device, numThreads, false); } @Override diff --git a/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml b/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml index 601b4a45d93..f4d62f21c0b 100644 --- a/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml +++ b/lite/examples/image_classification/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml @@ -224,6 +224,12 @@ android:layout_marginTop="10dp" android:background="@android:color/darker_gray" /> + + Date: Fri, 25 Oct 2019 10:45:45 +0800 Subject: [PATCH 2/4] disable FP16 switch by default enable the FP16 switch only when NNAPI and floating point model are used. --- .../lite/examples/classification/CameraActivity.java | 3 +++ .../lite/examples/classification/tflite/Classifier.java | 1 + .../classification/tflite/ClassifierQuantizedMobileNet.java | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java index 3c3e0798481..aff86dfb379 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java @@ -203,6 +203,7 @@ public void onSlide(@NonNull View bottomSheet, float slideOffset) {} device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); allowFP16 = fp16Switch.isChecked(); + fp16Switch.setEnabled(false); } protected int[] getRgbBytes() { @@ -585,6 +586,7 @@ private void setModel(Model model) { if (this.model != model) { LOGGER.d("Updating model: " + model); this.model = model; + fp16Switch.setEnabled((model == Model.FLOAT) && (device == Device.NNAPI)); onInferenceConfigurationChanged(); } } @@ -601,6 +603,7 @@ private void setDevice(Device device) { plusImageView.setEnabled(threadsEnabled); minusImageView.setEnabled(threadsEnabled); threadsTextView.setText(threadsEnabled ? String.valueOf(numThreads) : "N/A"); + fp16Switch.setEnabled((model == Model.FLOAT) && (device == Device.NNAPI)); onInferenceConfigurationChanged(); } } diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java index 83d97c2b03e..774f97a1a60 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -107,6 +107,7 @@ public enum Device { * @param model The model to use for classification. * @param device The device to use for classification. * @param numThreads The number of threads to use for classification. + * @param fp16 Allow FP32 model to run on FP16 accelerators * @return A classifier with the desired configuration. */ public static Classifier create(Activity activity, Model model, Device device, int numThreads, boolean fp16) diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java index e68f3798dd6..1d9c02176f3 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -44,7 +44,7 @@ public class ClassifierQuantizedMobileNet extends Classifier { */ public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) throws IOException { - super(activity, device, numThreads, false); + super(activity, device, numThreads, /*fp16=*/ false); } @Override From 33aa4c37175281e7a3583b21147ed05fa5dc9907 Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Sat, 16 Nov 2019 10:51:25 +0800 Subject: [PATCH 3/4] classifier could be null --- .../lite/examples/classification/ClassifierActivity.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java index 28a0a2e6fc1..188f7ad55ef 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java @@ -83,6 +83,8 @@ public void onPreviewSizeChosen(final Size size, final int rotation) { @Override protected void processImage() { rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); + final int imageSizeX = (classifier == null) ? 224 : classifier.getImageSizeX(); + final int imageSizeY = (classifier == null) ? 224 : classifier.getImageSizeY(); final int cropSize = Math.min(previewWidth, previewHeight); runInBackground( From 7243233fa0d6bb55d25ef84f2263480e106c6cb5 Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Tue, 7 Apr 2020 11:29:01 +0800 Subject: [PATCH 4/4] reflect model changes --- .../lite/examples/classification/CameraActivity.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java index aff86dfb379..93abc616b98 100644 --- a/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java +++ b/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java @@ -586,7 +586,7 @@ private void setModel(Model model) { if (this.model != model) { LOGGER.d("Updating model: " + model); this.model = model; - fp16Switch.setEnabled((model == Model.FLOAT) && (device == Device.NNAPI)); + fp16Switch.setEnabled(((model == Model.FLOAT_MOBILENET) || (model == Model.FLOAT_EFFICIENTNET)) && (device == Device.NNAPI)); onInferenceConfigurationChanged(); } } @@ -603,7 +603,7 @@ private void setDevice(Device device) { plusImageView.setEnabled(threadsEnabled); minusImageView.setEnabled(threadsEnabled); threadsTextView.setText(threadsEnabled ? String.valueOf(numThreads) : "N/A"); - fp16Switch.setEnabled((model == Model.FLOAT) && (device == Device.NNAPI)); + fp16Switch.setEnabled(((model == Model.FLOAT_MOBILENET) || (model == Model.FLOAT_EFFICIENTNET)) && (device == Device.NNAPI)); onInferenceConfigurationChanged(); } }