Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fp16 option for NNAPI in Android classification example #105

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -43,9 +43,11 @@
import android.view.ViewTreeObserver;
import android.view.WindowManager;
import android.widget.AdapterView;
import android.widget.CompoundButton;
import android.widget.ImageView;
import android.widget.LinearLayout;
import android.widget.Spinner;
import android.widget.Switch;
import android.widget.TextView;
import android.widget.Toast;
import com.google.android.material.bottomsheet.BottomSheetBehavior;
@@ -61,7 +63,7 @@ public abstract class CameraActivity extends AppCompatActivity
implements OnImageAvailableListener,
Camera.PreviewCallback,
View.OnClickListener,
AdapterView.OnItemSelectedListener {
AdapterView.OnItemSelectedListener, Switch.OnCheckedChangeListener {
private static final Logger LOGGER = new Logger();

private static final int PERMISSIONS_REQUEST = 1;
@@ -97,10 +99,12 @@ public abstract class CameraActivity extends AppCompatActivity
private Spinner modelSpinner;
private Spinner deviceSpinner;
private TextView threadsTextView;
private Switch fp16Switch;

private Model model = Model.QUANTIZED_EFFICIENTNET;
private Device device = Device.CPU;
private int numThreads = -1;
private boolean allowFP16 = false;

@Override
protected void onCreate(final Bundle savedInstanceState) {
@@ -116,6 +120,7 @@ protected void onCreate(final Bundle savedInstanceState) {
requestPermission();
}

fp16Switch = findViewById(R.id.fp16);
threadsTextView = findViewById(R.id.threads);
plusImageView = findViewById(R.id.plus);
minusImageView = findViewById(R.id.minus);
@@ -192,9 +197,13 @@ public void onSlide(@NonNull View bottomSheet, float slideOffset) {}
plusImageView.setOnClickListener(this);
minusImageView.setOnClickListener(this);

fp16Switch.setOnCheckedChangeListener(this);

model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase());
device = Device.valueOf(deviceSpinner.getSelectedItem().toString());
numThreads = Integer.parseInt(threadsTextView.getText().toString().trim());
allowFP16 = fp16Switch.isChecked();
fp16Switch.setEnabled(false);
}

protected int[] getRgbBytes() {
@@ -577,6 +586,7 @@ private void setModel(Model model) {
if (this.model != model) {
LOGGER.d("Updating model: " + model);
this.model = model;
fp16Switch.setEnabled(((model == Model.FLOAT_MOBILENET) || (model == Model.FLOAT_EFFICIENTNET)) && (device == Device.NNAPI));
onInferenceConfigurationChanged();
}
}
@@ -593,6 +603,7 @@ private void setDevice(Device device) {
plusImageView.setEnabled(threadsEnabled);
minusImageView.setEnabled(threadsEnabled);
threadsTextView.setText(threadsEnabled ? String.valueOf(numThreads) : "N/A");
fp16Switch.setEnabled(((model == Model.FLOAT_MOBILENET) || (model == Model.FLOAT_EFFICIENTNET)) && (device == Device.NNAPI));
onInferenceConfigurationChanged();
}
}
@@ -609,6 +620,8 @@ private void setNumThreads(int numThreads) {
}
}

protected boolean getFP16() { return allowFP16; }

protected abstract void processImage();

protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
@@ -651,4 +664,10 @@ public void onItemSelected(AdapterView<?> parent, View view, int pos, long id) {
public void onNothingSelected(AdapterView<?> parent) {
// Do nothing.
}

@Override
public void onCheckedChanged(CompoundButton switchView, boolean isChecked) {
allowFP16 = isChecked;
onInferenceConfigurationChanged();
}
}
Original file line number Diff line number Diff line change
@@ -64,7 +64,7 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);

recreateClassifier(getModel(), getDevice(), getNumThreads());
recreateClassifier(getModel(), getDevice(), getNumThreads(), getFP16());
if (classifier == null) {
LOGGER.e("No classifier on preview!");
return;
@@ -83,6 +83,8 @@ public void onPreviewSizeChosen(final Size size, final int rotation) {
@Override
protected void processImage() {
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
final int imageSizeX = (classifier == null) ? 224 : classifier.getImageSizeX();
final int imageSizeY = (classifier == null) ? 224 : classifier.getImageSizeY();
final int cropSize = Math.min(previewWidth, previewHeight);

runInBackground(
@@ -123,10 +125,11 @@ protected void onInferenceConfigurationChanged() {
final Device device = getDevice();
final Model model = getModel();
final int numThreads = getNumThreads();
runInBackground(() -> recreateClassifier(model, device, numThreads));
final boolean fp16 = getFP16();
runInBackground(() -> recreateClassifier(model, device, numThreads, fp16));
}

private void recreateClassifier(Model model, Device device, int numThreads) {
private void recreateClassifier(Model model, Device device, int numThreads, boolean fp16) {
if (classifier != null) {
LOGGER.d("Closing classifier.");
classifier.close();
@@ -143,8 +146,8 @@ private void recreateClassifier(Model model, Device device, int numThreads) {
}
try {
LOGGER.d(
"Creating classifier (model=%s, device=%s, numThreads=%d)", model, device, numThreads);
classifier = Classifier.create(this, model, device, numThreads);
"Creating classifier (model=%s, device=%s, numThreads=%d, allowFP16=%b)", model, device, numThreads, fp16);
classifier = Classifier.create(this, model, device, numThreads, fp16);
} catch (IOException e) {
LOGGER.e(e, "Failed to create classifier.");
}
Original file line number Diff line number Diff line change
@@ -107,16 +107,17 @@ public enum Device {
* @param model The model to use for classification.
* @param device The device to use for classification.
* @param numThreads The number of threads to use for classification.
* @param fp16 Allow FP32 model to run on FP16 accelerators
* @return A classifier with the desired configuration.
*/
public static Classifier create(Activity activity, Model model, Device device, int numThreads)
public static Classifier create(Activity activity, Model model, Device device, int numThreads, boolean fp16)
throws IOException {
if (model == Model.QUANTIZED_MOBILENET) {
return new ClassifierQuantizedMobileNet(activity, device, numThreads);
} else if (model == Model.FLOAT_MOBILENET) {
return new ClassifierFloatMobileNet(activity, device, numThreads);
return new ClassifierFloatMobileNet(activity, device, numThreads, fp16);
} else if (model == Model.FLOAT_EFFICIENTNET) {
return new ClassifierFloatEfficientNet(activity, device, numThreads);
return new ClassifierFloatEfficientNet(activity, device, numThreads, fp16);
} else if (model == Model.QUANTIZED_EFFICIENTNET) {
return new ClassifierQuantizedEfficientNet(activity, device, numThreads);
} else {
@@ -195,12 +196,13 @@ public String toString() {
}

/** Initializes a {@code Classifier}. */
protected Classifier(Activity activity, Device device, int numThreads) throws IOException {
protected Classifier(Activity activity, Device device, int numThreads, boolean fp16) throws IOException {
tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
switch (device) {
case NNAPI:
nnApiDelegate = new NnApiDelegate();
tfliteOptions.addDelegate(nnApiDelegate);
tfliteOptions.setAllowFp16PrecisionForFp32(fp16);
break;
case GPU:
gpuDelegate = new GpuDelegate();
Original file line number Diff line number Diff line change
@@ -40,9 +40,9 @@ public class ClassifierFloatEfficientNet extends Classifier {
*
* @param activity
*/
public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads)
public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads, boolean fp16)
throws IOException {
super(activity, device, numThreads);
super(activity, device, numThreads, fp16);
}

@Override
Original file line number Diff line number Diff line change
@@ -42,9 +42,9 @@ public class ClassifierFloatMobileNet extends Classifier {
*
* @param activity
*/
public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads)
public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads, boolean fp16)
throws IOException {
super(activity, device, numThreads);
super(activity, device, numThreads, fp16);
}

@Override
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ public class ClassifierQuantizedEfficientNet extends Classifier {
*/
public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads)
throws IOException {
super(activity, device, numThreads);
super(activity, device, numThreads, /*fp16=*/ false);
}

@Override
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ public class ClassifierQuantizedMobileNet extends Classifier {
*/
public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads)
throws IOException {
super(activity, device, numThreads);
super(activity, device, numThreads, /*fp16=*/ false);
}

@Override
Original file line number Diff line number Diff line change
@@ -224,6 +224,12 @@
android:layout_marginTop="10dp"
android:background="@android:color/darker_gray" />

<Switch
android:id="@+id/fp16"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="AllowFP16" />

<RelativeLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"