Skip to content

Commit

Permalink
Removed example tests, corrected private variable naming
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasKiljanczykDev committed May 20, 2021
1 parent bdc163e commit f4d87b0
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 105 deletions.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ open class MainActivity : AppCompatActivity() {
private lateinit var mDetector: Detector
private var mDetectionProcessor: DetectionProcessor? = null

private lateinit var sourceBitmap: Bitmap
private lateinit var cropBitmap: Bitmap
private lateinit var mSourceBitmap: Bitmap
private lateinit var mCropBitmap: Bitmap

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
mBinding = ActivityMainBinding.inflate(layoutInflater)
setContentView(mBinding.root)

sourceBitmap = assets.open("kite.jpg").use { inputStream ->
mSourceBitmap = assets.open("kite.jpg").use { inputStream ->
BitmapFactory.decodeStream(inputStream)
}

cropBitmap = processBitmap(sourceBitmap, DETECTION_MODEL.inputSize)
mBinding.imageView.setImageBitmap(cropBitmap)
mCropBitmap = processBitmap(mSourceBitmap, DETECTION_MODEL.inputSize)
mBinding.imageView.setImageBitmap(mCropBitmap)

setUpDetector()
lifecycleScope.launch(Dispatchers.Main) {
Expand Down Expand Up @@ -104,7 +104,7 @@ open class MainActivity : AppCompatActivity() {

mBinding.detectButton.setOnClickListener {
lifecycleScope.launch(Dispatchers.Default) {
mDetectionProcessor?.processImage(cropBitmap)
mDetectionProcessor?.processImage(mCropBitmap)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class DetectionProcessor(
previewWidth,
previewHeight,
((rotation + 1) % 4) * 90,
showScore = SHOW_SCORE
mShowScore = SHOW_SCORE
)
trackingOverlay.setTracker(mTracker)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ interface Detector {
* A unique identifier for what has been recognized. Specific to the detected class,
* not the instance of the [Detection] object.
*/
private val id: String,
private val mId: String,
/**
* Display name for the [Detection].
*/
Expand All @@ -36,7 +36,7 @@ interface Detector {
val detectedClass: Int
) : Comparable<Detection> {
override fun toString(): String {
var resultString = "[$id] $className "
var resultString = "[$mId] $className "
resultString += "(%.1f%%) ".format(score * 100.0f)
resultString += "$boundingBox"
return resultString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import kotlin.math.min
@Suppress("UNNECESSARY_NOT_NULL_ASSERTION")
internal class YoloV4Detector(
assetManager: AssetManager,
private val detectionModel: DetectionModel,
private val minimumScore: Float,
private val mDetectionModel: DetectionModel,
private val mMinimumScore: Float,
) : Detector {

private companion object {
Expand All @@ -33,53 +33,50 @@ internal class YoloV4Detector(

}

private val inputSize: Int = detectionModel.inputSize
private val mInputSize: Int = mDetectionModel.inputSize

// Config values.
// Pre-allocated buffers.
private val labels: List<String>
private val interpreter: Interpreter
private val mLabels: List<String>
private val mInterpreter: Interpreter
private val mNmsThresh = 0.6f

private val intValues = IntArray(inputSize * inputSize)

private val byteBuffer: Array<ByteBuffer>

private val outputMap: MutableMap<Int, Array<Array<FloatArray>>> = HashMap()
// Pre-allocated buffers.
private val intValues = IntArray(mInputSize * mInputSize)
private val mByteBuffer: Array<ByteBuffer>
private val mOutputMap: MutableMap<Int, Array<Array<FloatArray>>> = HashMap()

init {

val labelsFilename = detectionModel.labelFilePath
val labelsFilename = mDetectionModel.labelFilePath
.split("file:///android_asset/")
.toTypedArray()[1]

labels = assetManager.open(labelsFilename)
mLabels = assetManager.open(labelsFilename)
.use { it.readBytes() }
.decodeToString()
.trim()
.split("\n")
.map { it.trim() }

interpreter = initializeInterpreter(assetManager)
mInterpreter = initializeInterpreter(assetManager)

val numBytesPerChannel = if (detectionModel.isQuantized) {
val numBytesPerChannel = if (mDetectionModel.isQuantized) {
1 // Quantized (int8)
} else {
4 // Floating point (fp32)
}

// input size * input size * pixel count (RGB) * pixel size (int8/fp32)
byteBuffer = arrayOf(
ByteBuffer.allocateDirect(inputSize * inputSize * 3 * numBytesPerChannel)
mByteBuffer = arrayOf(
ByteBuffer.allocateDirect(mInputSize * mInputSize * 3 * numBytesPerChannel)
)
byteBuffer[0].order(ByteOrder.nativeOrder())
mByteBuffer[0].order(ByteOrder.nativeOrder())

outputMap[0] = arrayOf(Array(detectionModel.outputSize) { FloatArray(numBytesPerChannel) })
outputMap[1] = arrayOf(Array(detectionModel.outputSize) { FloatArray(labels.size) })
mOutputMap[0] = arrayOf(Array(mDetectionModel.outputSize) { FloatArray(numBytesPerChannel) })
mOutputMap[1] = arrayOf(Array(mDetectionModel.outputSize) { FloatArray(mLabels.size) })
}

override fun getDetectionModel(): DetectionModel {
return detectionModel
return mDetectionModel
}

override fun runDetection(bitmap: Bitmap): List<Detection> {
Expand All @@ -105,7 +102,7 @@ internal class YoloV4Detector(
}
}

return assetManager.openFd(detectionModel.modelFilename).use { fileDescriptor ->
return assetManager.openFd(mDetectionModel.modelFilename).use { fileDescriptor ->
val fileInputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileByteBuffer = fileInputStream.channel.map(
FileChannel.MapMode.READ_ONLY,
Expand All @@ -122,36 +119,36 @@ internal class YoloV4Detector(
*/
private fun convertBitmapToByteBuffer(bitmap: Bitmap) {
val startTime = SystemClock.uptimeMillis()
val scaledBitmap = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true)
val scaledBitmap = Bitmap.createScaledBitmap(bitmap, mInputSize, mInputSize, true)

scaledBitmap.getPixels(intValues, 0, inputSize, 0, 0, inputSize, inputSize)
scaledBitmap.getPixels(intValues, 0, mInputSize, 0, 0, mInputSize, mInputSize)
scaledBitmap.recycle()

byteBuffer[0].clear()
mByteBuffer[0].clear()
for (pixel in intValues) {
val r = (pixel and 0xFF) / 255.0f
val g = (pixel shr 8 and 0xFF) / 255.0f
val b = (pixel shr 16 and 0xFF) / 255.0f

byteBuffer[0].putFloat(r)
byteBuffer[0].putFloat(g)
byteBuffer[0].putFloat(b)
mByteBuffer[0].putFloat(r)
mByteBuffer[0].putFloat(g)
mByteBuffer[0].putFloat(b)
}
Log.v(TAG, "ByteBuffer conversion time : ${SystemClock.uptimeMillis() - startTime} ms")
}

private fun getDetections(imageWidth: Int, imageHeight: Int): List<Detection> {
interpreter.runForMultipleInputsOutputs(byteBuffer, outputMap as Map<Int, Any>)
mInterpreter.runForMultipleInputsOutputs(mByteBuffer, mOutputMap as Map<Int, Any>)

val boundingBoxes = outputMap[0]!![0]
val outScore = outputMap[1]!![0]
val boundingBoxes = mOutputMap[0]!![0]
val outScore = mOutputMap[1]!![0]

return outScore.zip(boundingBoxes)
.mapIndexedNotNull { index, (classScores, boundingBoxes) ->
val bestClassIndex: Int = labels.indices.maxByOrNull { classScores[it] }!!
val bestClassIndex: Int = mLabels.indices.maxByOrNull { classScores[it] }!!
val bestScore = classScores[bestClassIndex]

if (bestScore <= minimumScore) {
if (bestScore <= mMinimumScore) {
return@mapIndexedNotNull null
}

Expand All @@ -167,8 +164,8 @@ internal class YoloV4Detector(
)

return@mapIndexedNotNull Detection(
id = index.toString(),
className = labels[bestClassIndex],
mId = index.toString(),
className = mLabels[bestClassIndex],
detectedClass = bestClassIndex,
score = bestScore,
boundingBox = rectF
Expand All @@ -179,7 +176,7 @@ internal class YoloV4Detector(
private fun nms(detections: List<Detection>): List<Detection> {
val nmsList: MutableList<Detection> = mutableListOf()

for (labelIndex in labels.indices) {
for (labelIndex in mLabels.indices) {
val priorityQueue = PriorityQueue<Detection>(50)
priorityQueue.addAll(detections.filter { it.detectedClass == labelIndex })

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,29 @@ import java.nio.ByteBuffer
* Utility class for converting [Image] to [Bitmap].
*/
class ImageToBitmapConverter(context: Context, image: Image) {
private val bitmap: Bitmap =
private val mBitmap: Bitmap =
Bitmap.createBitmap(image.width, image.height, Bitmap.Config.ARGB_8888)

private val renderScript: RenderScript = RenderScript.create(context)
private val mRenderScript: RenderScript = RenderScript.create(context)

private val scriptYuvToRgb: ScriptIntrinsicYuvToRGB =
ScriptIntrinsicYuvToRGB.create(renderScript, Element.U8_3(renderScript))
private val mScriptYuvToRgb: ScriptIntrinsicYuvToRGB =
ScriptIntrinsicYuvToRGB.create(mRenderScript, Element.U8_3(mRenderScript))

private val elemType = Type.Builder(renderScript, Element.YUV(renderScript))
private val mElemType = Type.Builder(mRenderScript, Element.YUV(mRenderScript))
.setYuvFormat(ImageFormat.YUV_420_888)
.create()

private val inputAllocation: Allocation =
private val mInputAllocation: Allocation =
Allocation.createSized(
renderScript,
elemType.element,
mRenderScript,
mElemType.element,
image.planes.sumOf { it.buffer.remaining() }
)

private val outputAllocation: Allocation = Allocation.createFromBitmap(renderScript, bitmap)
private val mOutputAllocation: Allocation = Allocation.createFromBitmap(mRenderScript, mBitmap)

init {
scriptYuvToRgb.setInput(inputAllocation)
mScriptYuvToRgb.setInput(mInputAllocation)
}

/**
Expand All @@ -42,11 +42,11 @@ class ImageToBitmapConverter(context: Context, image: Image) {
fun imageToBitmap(image: Image): Bitmap {
val yuvBuffer: ByteArray = yuv420ToByteArray(image)

inputAllocation.copyFrom(yuvBuffer)
scriptYuvToRgb.forEach(outputAllocation)
outputAllocation.copyTo(bitmap)
mInputAllocation.copyFrom(yuvBuffer)
mScriptYuvToRgb.forEach(mOutputAllocation)
mOutputAllocation.copyTo(mBitmap)

return bitmap
return mBitmap
}

private fun yuv420ToByteArray(image: Image): ByteArray {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MultiBoxTracker(
private val mFrameWidth: Int,
private val mFrameHeight: Int,
private val mOrientation: Int,
private val showScore: Boolean = true
private val mShowScore: Boolean = true
) {

private companion object {
Expand Down Expand Up @@ -120,14 +120,14 @@ class MultiBoxTracker(
trackedDetection.boxPaint
)

val labelString = if (showScore && trackedDetection.title.isNotBlank()) {
val labelString = if (mShowScore && trackedDetection.title.isNotBlank()) {
"%s %.2f%%".format(
trackedDetection.title,
100 * trackedDetection.score
)
} else if (trackedDetection.title.isNotBlank()) {
trackedDetection.title
} else if (showScore) {
} else if (mShowScore) {
"%.2f%%".format(100 * trackedDetection.score)
} else ""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import android.view.View
class TrackingOverlayView(context: Context, attrs: AttributeSet?) : View(context, attrs) {
private var mTracker: MultiBoxTracker? = null


override fun draw(canvas: Canvas) {
super.draw(canvas)
mTracker?.draw(canvas)
Expand Down

0 comments on commit f4d87b0

Please sign in to comment.