diff --git a/OpenNSFW/build.gradle b/OpenNSFW/build.gradle index 2f0faa8bc..2e80dd125 100644 --- a/OpenNSFW/build.gradle +++ b/OpenNSFW/build.gradle @@ -9,8 +9,8 @@ android { defaultConfig { minSdkVersion 16 targetSdkVersion 29 - versionCode 1 - versionName "1.3.2" + versionCode 3 + versionName "1.3.3" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" consumerProguardFiles 'consumer-rules.pro' diff --git a/OpenNSFW/src/main/assets/nsfw.tflite b/OpenNSFW/src/main/assets/nsfw.tflite deleted file mode 100644 index 9583ed20d..000000000 Binary files a/OpenNSFW/src/main/assets/nsfw.tflite and /dev/null differ diff --git a/OpenNSFW/src/main/java/com/zwy/opennsfw/Classifier.kt b/OpenNSFW/src/main/java/com/zwy/opennsfw/Classifier.kt deleted file mode 100644 index 3321c2ec4..000000000 --- a/OpenNSFW/src/main/java/com/zwy/opennsfw/Classifier.kt +++ /dev/null @@ -1,205 +0,0 @@ -//package com.zwy.opennsfw -// -//import android.content.Context -//import android.content.res.AssetManager -//import android.graphics.Bitmap -//import android.graphics.Color -//import android.os.SystemClock -//import android.util.Log -//import org.tensorflow.lite.Interpreter -//import org.tensorflow.lite.gpu.GpuDelegate -//import java.io.FileInputStream -//import java.lang.Math.max -//import java.nio.ByteBuffer -//import java.nio.ByteOrder -//import java.nio.MappedByteBuffer -//import java.nio.channels.FileChannel -// -//class Classifier -//private constructor(assetManager: AssetManager, isGPU: Boolean?, numThreads: Int) { -// -// /** -// * 数据宽高 -// */ -// private val INPUT_WIDTH = 224 -// -// /** -// * 数据宽高 -// */ -// private val INPUT_HEIGHT = 224 -// -// /** -// * 通道 -// */ -// private val BYTES_PER_CHANNEL_NUM = 4 -// -// /** -// * Resize后的数据源 -// */ -// private val intValues = IntArray(INPUT_WIDTH * INPUT_HEIGHT) -// -// /** -// * 载入模型的客户端 -// */ -// private var tfliteModel: MappedByteBuffer? = null -// -// /** -// * GPU代理 -// */ -// private var gpuDelegate: GpuDelegate? = null -// -// /** -// * Tensorflow Lite -// */ -// private var tflite: Interpreter? = null -// -// /** -// * 喂入模型的最终数据源 -// */ -// private val imgData: ByteBuffer? -// -// -// init { -// tfliteModel = loadModelFile(assetManager) -// val tfliteOptions = Interpreter.Options() -// if (isGPU == true) { -// gpuDelegate = GpuDelegate() -// tfliteOptions.addDelegate(gpuDelegate) -// } -// tfliteOptions.setNumThreads(numThreads) -// tflite = Interpreter(tfliteModel!!, tfliteOptions) -// -// val tensor = tflite!!.getInputTensor(tflite!!.getInputIndex("input")) -// val stringBuilder = (" \n" -// + "dataType : " + -// tensor.dataType() + -// "\n" + -// "numBytes : " + -// tensor.numBytes() + -// "\n" + -// "numDimensions : " + -// tensor.numDimensions() + -// "\n" + -// "numElements : " + -// tensor.numElements() + -// "\n" + -// "shape : " + -// tensor.shape().size) -// Log.d(TAG, stringBuilder) -// -// imgData = ByteBuffer.allocateDirect( -// DIM_BATCH_SIZE -// * INPUT_WIDTH -// * INPUT_HEIGHT -// * DIM_PIXEL_SIZE -// * BYTES_PER_CHANNEL_NUM -// ) -// -// imgData!!.order(ByteOrder.LITTLE_ENDIAN) -// Log.d(TAG, "Tensorflow Lite Image Classifier Initialization Success.") -// } -// -// /** -// * Memory-map the model file in Assets. -// */ -// private fun loadModelFile(assetManager: AssetManager): MappedByteBuffer { -// val context:Context -// val fileDescriptor = assetManager.openFd("nsfw.tflite") -// val inputStream = FileInputStream(fileDescriptor.fileDescriptor) -// val fileChannel = inputStream.channel -// val startOffset = fileDescriptor.startOffset -// val declaredLength = fileDescriptor.declaredLength -// return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) -// } -// -// -// /** -// * Writes Image data into a `ByteBuffer`. -// */ -// private fun convertBitmapToByteBuffer(bitmap_: Bitmap) { -// if (imgData == null || bitmap_ == null) { -// return -// } -// imgData.rewind() -// val W = bitmap_.width -// val H = bitmap_.height -// -// val w_off = max((W - INPUT_WIDTH) / 2, 0) -// val h_off = max((H - INPUT_HEIGHT) / 2, 0) -// -// //把每个像素的颜色值转为int 存入intValues -// bitmap_.getPixels(intValues, 0, INPUT_WIDTH, h_off, w_off, INPUT_WIDTH, INPUT_HEIGHT) -// // Convert the image to floating point. -// val startTime = SystemClock.uptimeMillis() -// for (color in intValues) { -// val r1 = Color.red(color) -// val g1 = Color.green(color) -// val b1 = Color.blue(color) -// -// val rr1 = r1 - 123 -// val gg1 = g1 - 117 -// val bb1 = b1 - 104 -// -// imgData.putFloat(bb1.toFloat()) -// imgData.putFloat(gg1.toFloat()) -// imgData.putFloat(rr1.toFloat()) -// } -// val endTime = SystemClock.uptimeMillis() -// Log.d(TAG, "Timecost to put values into ByteBuffer: " + (endTime - startTime) + "ms") -// } -// -// fun run(bitmap: Bitmap): NsfwBean { -// -// val bitmap_256 = Bitmap.createScaledBitmap(bitmap, 256, 256, true) -// -// //Writes image data into byteBuffer -// convertBitmapToByteBuffer(bitmap_256) -// -// val startTime = SystemClock.uptimeMillis() -// // out -// val outArray = Array(1) { FloatArray(2) } -// -// tflite!!.run(imgData, outArray) -// -// val endTime = SystemClock.uptimeMillis() -// -// Log.d(TAG, "SFW score :" + outArray[0][0] + ",NSFW score :" + outArray[0][1]) -// Log.d(TAG, "Timecost to run model inference: " + (endTime - startTime) + "ms") -// return NsfwBean(outArray[0][0], outArray[0][1]) -// } -// -// /** -// * Closes the interpreter and model to release resources. -// */ -// fun close() { -// if (tflite != null) { -// tflite!!.close() -// tflite = null -// Log.d(TAG, "Tensorflow Lite Image Classifier close.") -// } -// if (gpuDelegate != null) { -// gpuDelegate!!.close() -// Log.d(TAG, "Tensorflow Lite Image gpuDelegate close.") -// gpuDelegate = null -// } -// tfliteModel = null -// Log.d(TAG, "Tensorflow Lite destroyed.") -// } -// -// companion object { -// -// val TAG = "open_nsfw_android" -// /** -// * Dimensions of inputs. -// */ -// private val DIM_BATCH_SIZE = 1 -// -// private val DIM_PIXEL_SIZE = 3 -// -// fun create(assetManager: AssetManager, isAddGpuDelegate: Boolean?, numThreads: Int): Classifier { -// return Classifier(assetManager, isAddGpuDelegate!!, numThreads) -// } -// -// } -// -//} diff --git a/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Classifier.kt b/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Classifier.kt index 1a80bb94f..6113f6b2a 100644 --- a/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Classifier.kt +++ b/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Classifier.kt @@ -11,12 +11,10 @@ import mClassifier import org.tensorflow.lite.Interpreter import org.tensorflow.lite.gpu.GpuDelegate import java.io.ByteArrayOutputStream -import java.io.FileInputStream +import java.io.File import java.lang.Math.max import java.nio.ByteBuffer import java.nio.ByteOrder -import java.nio.MappedByteBuffer -import java.nio.channels.FileChannel class Classifier private constructor(config: Config) { @@ -101,19 +99,31 @@ class Classifier private constructor(config: Config) { return this } + fun nsfwModuleFilePath(nsfwModuleFilePath: String): Build { + config.nsfwModuleFilePath = nsfwModuleFilePath + return this + } + fun build(): Classifier { return get(config) } } + //"/data/user/0/com.zwy.demo/files/nsfw.tflite" init { + + val file = File( + config.nsfwModuleFilePath + ?: throw java.lang.NullPointerException("未配置模型路径,请调用Classifier.Build().nsfwModuleFilePath(模型路径)初始化") + ) + if (!file.exists()) throw NullPointerException("模型加载失败,请确认路径是否正确") try { tflite = - Interpreter(loadModelFile(config.context!!), getTfLiteOptions(config.isOpenGPU)) + Interpreter(file, getTfLiteOptions(config.isOpenGPU)) if (config.isOpenGPU) "开启GPU加速成功".d() } catch (e: Exception) { "不支持GPU加速".e() - tflite = Interpreter(loadModelFile(config.context!!), getTfLiteOptions(false)) + tflite = Interpreter(file, getTfLiteOptions(false)) } imgData = ByteBuffer.allocateDirect( @@ -136,14 +146,6 @@ class Classifier private constructor(config: Config) { } } - private fun loadModelFile(context: Context): MappedByteBuffer { - val fileDescriptor = context.assets.openFd("nsfw.tflite") - val inputStream = FileInputStream(fileDescriptor.fileDescriptor) - val fileChannel = inputStream.channel - val startOffset = fileDescriptor.startOffset - val declaredLength = fileDescriptor.declaredLength - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) - } private fun convertBitmapToByteBuffer(bitmap_: Bitmap) { imgData.rewind() diff --git a/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Config.kt b/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Config.kt index 206e3b383..c4339d360 100644 --- a/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Config.kt +++ b/OpenNSFW/src/main/java/com/zwy/opennsfw/core/Config.kt @@ -7,6 +7,17 @@ data class Config( * 是否开启GPU加速 */ var isOpenGPU: Boolean = true, + /** + * 扫描占用的线程数 + */ var numThreads: Int = 1, - var context: Context? + /** + * 全局配置的context + */ + var context: Context?, + + /** + * nsfw模型存放目录 + */ + var nsfwModuleFilePath: String? = null ) \ No newline at end of file diff --git a/demo/build.gradle b/demo/build.gradle index a1dfc5119..cd8ba281b 100644 --- a/demo/build.gradle +++ b/demo/build.gradle @@ -11,8 +11,8 @@ android { applicationId "com.zwy.demo" minSdkVersion 19 targetSdkVersion 29 - versionCode 1 - versionName "1.3.2" + versionCode 3 + versionName "1.3.3" multiDexEnabled true testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" } @@ -83,8 +83,8 @@ dependencies { implementation "org.jetbrains.anko:anko:0.10.5" //NSFW鉴黄库 - implementation 'com.github.devzwy:open_nsfw_android:1.3.2' -// implementation project(path: ':OpenNSFW') +// implementation 'com.github.devzwy:open_nsfw_android:1.3.2' + implementation project(path: ':OpenNSFW') implementation 'pub.devrel:easypermissions:3.0.0' implementation 'com.github.LuckSiege.PictureSelector:picture_library:2.2.5' diff --git a/demo/src/main/java/com/zwy/demo/NSFWApplication.kt b/demo/src/main/java/com/zwy/demo/NSFWApplication.kt index ba1b7f6b6..3c3cdc7ea 100644 --- a/demo/src/main/java/com/zwy/demo/NSFWApplication.kt +++ b/demo/src/main/java/com/zwy/demo/NSFWApplication.kt @@ -12,6 +12,7 @@ import org.koin.android.ext.koin.androidLogger import org.koin.core.context.startKoin import java.util.* + class NSFWApplication : MultiDexApplication() { @@ -30,6 +31,7 @@ class NSFWApplication : MultiDexApplication() { .context(this) //必须调用 否则会有异常抛出 // .isOpenGPU(true)//默认不开启GPU加速,默认为true // .numThreads(100) //分配的线程数 根据手机配置设置,默认1 +// .nsfwModuleFilePath("/data/user/0/com.zwy.demo/files/nsfw.tflite") .build() //全局注入对象 startKoin { @@ -47,4 +49,6 @@ class NSFWApplication : MultiDexApplication() { lateinit var context: Context var startTime: Long = 0 } + + } \ No newline at end of file diff --git a/demo/src/main/java/com/zwy/demo/views/MainActivity.kt b/demo/src/main/java/com/zwy/demo/views/MainActivity.kt index a15033e09..127003506 100644 --- a/demo/src/main/java/com/zwy/demo/views/MainActivity.kt +++ b/demo/src/main/java/com/zwy/demo/views/MainActivity.kt @@ -1,11 +1,17 @@ package com.zwy.demo.views +import android.app.Activity import android.os.Bundle import com.zwy.demo.R import com.zwy.demo.base.BaseActivity import com.zwy.demo.databinding.MainLayoutBinding import com.zwy.demo.models.MainViewModel +import d import org.jetbrains.anko.toast +import java.io.File +import java.io.FileOutputStream +import java.io.IOException +import java.io.InputStream class MainActivity : BaseActivity() { /** @@ -19,6 +25,7 @@ class MainActivity : BaseActivity() { override fun initData() { binding.titles = viewModel.titles viewModel.getTitles() +// copyAssetsFile2Phone(this,"nsfw.tflite") } private var mExitTime: Long = 0 @@ -33,4 +40,35 @@ class MainActivity : BaseActivity() { } +// /** +// * 将文件从assets目录,考贝到 /data/data/包名/files/ 目录中。assets 目录中的文件,会不经压缩打包至APK包中,使用时还应从apk包中导出来 +// * @param fileName 文件名,如aaa.txt +// */ +// fun copyAssetsFile2Phone(activity: Activity, fileName: String) { +// try { +// val inputStream: InputStream = activity.assets.open(fileName) +// //getFilesDir() 获得当前APP的安装路径 /data/data/包名/files 目录 +// val file = File( +// activity.filesDir.absolutePath + File.separator.toString() + fileName +// ) +// if (!file.exists() || file.length()==0L) { +// val fos = FileOutputStream(file) //如果文件不存在,FileOutputStream会自动创建文件 +// var len = -1 +// val buffer = ByteArray(1024) +// while (inputStream.read(buffer).also({ len = it }) != -1) { +// fos.write(buffer, 0, len) +// } +// fos.flush() //刷新缓存区 +// inputStream.close() +// fos.close() +// "模型文件复制完毕".d() +// } else { +// "模型文件已存在,无需复制".d() +// } +// } catch (e: IOException) { +// e.printStackTrace() +// } +// } + + }