diff --git a/flutter/integration_test/utils.dart b/flutter/integration_test/utils.dart index f8fcd7ed1..ab5f0ee16 100644 --- a/flutter/integration_test/utils.dart +++ b/flutter/integration_test/utils.dart @@ -33,7 +33,7 @@ Future startApp(WidgetTester tester) async { Future validateSettings(WidgetTester tester) async { final state = tester.state(find.byType(MaterialApp)); final benchmarkState = state.context.read(); - for (var benchmark in benchmarkState.benchmarks) { + for (var benchmark in benchmarkState.allBenchmarks) { expect(benchmark.selectedDelegate.batchSize, greaterThanOrEqualTo(0), reason: 'batchSize must >= 0'); for (var modelFile in benchmark.selectedDelegate.modelFile) { @@ -67,7 +67,7 @@ Future validateSettings(WidgetTester tester) async { Future setBenchmarks(WidgetTester tester) async { final state = tester.state(find.byType(MaterialApp)); final benchmarkState = state.context.read(); - for (var benchmark in benchmarkState.benchmarks) { + for (var benchmark in benchmarkState.allBenchmarks) { // Disable test for stable diffusion since it take too long to finish. if (benchmark.id == BenchmarkId.stableDiffusion) { benchmark.isActive = false; diff --git a/flutter/lib/benchmark/benchmark.dart b/flutter/lib/benchmark/benchmark.dart index 38b1867f3..a753338f3 100644 --- a/flutter/lib/benchmark/benchmark.dart +++ b/flutter/lib/benchmark/benchmark.dart @@ -117,7 +117,11 @@ class Benchmark { } class BenchmarkStore { - final List benchmarks = []; + final List allBenchmarks = []; + + List get activeBenchmarks { + return allBenchmarks.where((e) => e.isActive).toList(); + } BenchmarkStore({ required pb.MLPerfConfig appConfig, @@ -137,7 +141,7 @@ class BenchmarkStore { } final enabled = taskSelection[task.id] ?? true; - benchmarks.add(Benchmark( + allBenchmarks.add(Benchmark( taskConfig: task, benchmarkSettings: backendSettings, isActive: enabled, @@ -186,7 +190,7 @@ class BenchmarkStore { Map get selection { Map result = {}; - for (var item in benchmarks) { + for (var item in allBenchmarks) { result[item.id] = item.isActive; } return result; diff --git a/flutter/lib/benchmark/state.dart b/flutter/lib/benchmark/state.dart index db9b07ecd..be21b5193 100644 --- a/flutter/lib/benchmark/state.dart +++ b/flutter/lib/benchmark/state.dart @@ -54,20 +54,20 @@ class BenchmarkState extends ChangeNotifier { ExtendedResult? lastResult; num get result { - final benchmarksCount = benchmarks + final benchmarksCount = allBenchmarks .where((benchmark) => benchmark.performanceModeResult != null) .length; if (benchmarksCount == 0) return 0; final summaryThroughput = pow( - benchmarks.fold(1, (prev, i) { + allBenchmarks.fold(1, (prev, i) { return prev * (i.performanceModeResult?.throughput?.value ?? 1.0); }), 1.0 / benchmarksCount); final maxSummaryThroughput = pow( - benchmarks.fold(1, (prev, i) { + allBenchmarks.fold(1, (prev, i) { return prev * (i.info.maxThroughput); }), 1.0 / benchmarksCount); @@ -75,7 +75,9 @@ class BenchmarkState extends ChangeNotifier { return summaryThroughput / maxSummaryThroughput; } - List get benchmarks => _benchmarkStore.benchmarks; + List get allBenchmarks => _benchmarkStore.allBenchmarks; + + List get activeBenchmarks => _benchmarkStore.activeBenchmarks; late BenchmarkStore _benchmarkStore; @@ -131,25 +133,42 @@ class BenchmarkState extends ChangeNotifier { } } - Future loadResources({required bool downloadMissing}) async { + Future loadResources( + {required bool downloadMissing, + List benchmarks = const []}) async { final newAppVersion = '${BuildInfoHelper.info.version}+${BuildInfoHelper.info.buildNumber}'; var needToPurgeCache = _store.previousAppVersion != newAppVersion; _store.previousAppVersion = newAppVersion; + final selectedBenchmarks = benchmarks.isEmpty ? allBenchmarks : benchmarks; await Wakelock.enable(); - print('Start loading resources with downloadMissing=$downloadMissing'); - final resources = _benchmarkStore.listResources( + final selectedResources = _benchmarkStore.listResources( + modes: [taskRunner.perfMode, taskRunner.accuracyMode], + benchmarks: selectedBenchmarks, + ); + final allResources = _benchmarkStore.listResources( modes: [taskRunner.perfMode, taskRunner.accuracyMode], - benchmarks: benchmarks, + benchmarks: allBenchmarks, ); try { + final selectedBenchmarkIds = selectedBenchmarks + .map((e) => e.benchmarkSettings.benchmarkId) + .join(', '); + print('Start loading resources with downloadMissing=$downloadMissing ' + 'for $selectedBenchmarkIds'); await resourceManager.handleResources( - resources, - needToPurgeCache, - downloadMissing, + resources: selectedResources, + purgeOldCache: needToPurgeCache, + downloadMissing: downloadMissing, ); print('Finished loading resources with downloadMissing=$downloadMissing'); + // We still need to load all resources after download selected resources. + await resourceManager.handleResources( + resources: allResources, + purgeOldCache: false, + downloadMissing: false, + ); error = null; stackTrace = null; taskConfigFailedToLoad = false; @@ -289,7 +308,7 @@ class BenchmarkState extends ChangeNotifier { } void resetCurrentResults() { - for (var b in _benchmarkStore.benchmarks) { + for (var b in _benchmarkStore.allBenchmarks) { b.accuracyModeResult = null; b.performanceModeResult = null; } @@ -304,7 +323,7 @@ class BenchmarkState extends ChangeNotifier { lastResult = ExtendedResult.fromJson( jsonDecode(_store.previousExtendedResult) as Map); resourceManager.resultManager - .restoreResults(lastResult!.results, benchmarks); + .restoreResults(lastResult!.results, allBenchmarks); _doneRunning = true; return; } catch (e, trace) { diff --git a/flutter/lib/l10n/app_en.arb b/flutter/lib/l10n/app_en.arb index 3120a8228..4a84f1589 100644 --- a/flutter/lib/l10n/app_en.arb +++ b/flutter/lib/l10n/app_en.arb @@ -93,6 +93,7 @@ "dialogContentMissingFiles": "The following files don't exist:", "dialogContentMissingFilesHint": "Please go to the menu Resources to download the missing files.", "dialogContentChecksumError": "The following files failed checksum validation:", + "dialogContentChecksumErrorHint": "Please go to the menu Resources to clear the cache and download the files again.", "dialogContentNoSelectedBenchmarkError": "Please select at least one benchmark.", "benchModePerformanceOnly": "Performance Only", @@ -122,7 +123,9 @@ "benchInfoStableDiffusionDesc": "The Text to Image Gen AI benchmark adopts Stable Diffusion v1.5 for generating images from text prompts. It is a latent diffusion model. The benchmarked Stable Diffusion v1.5 refers to a specific configuration of the model architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet,123M CLIP ViT-L/14 text encoder for the diffusion model, and VAE Decoder of 49.5M parameters. The model was trained on 595k steps at resolution of 512x512, which enables it to generate high quality images. We refer you to https://huggingface.co/benjamin-paine/stable-diffusion-v1-5 for more information. The benchmark runs 20 denoising steps for inference, and uses a precalculated time embedding of size 1x1280. Reference models can be found here https://github.com/mlcommons/mobile_open/releases.\n\nFor latency benchmarking, we benchmark end to end, excluding the time embedding calculation and the tokenizer. For accuracy calculations, the app adopts the CLIP metric for text-to-image consistency, and further evaluation of the generated images using this Image Quality Aesthetic Assessment metric https://github.com/idealo/image-quality-assessment/tree/master?tab=readme-ov-file", "resourceDownload": "Download", + "resourceDownloadAll": "Download all", "resourceClear": "Clear", + "resourceClearAll": "Clear all", "resourceChecking": "Checking download status", "resourceDownloading": "Downloading", "resourceErrorMessage": "Some resources failed to load.\nIf you didn't change config from default you can try clearing the cache.\nIf you use a custom configuration file ensure that it has correct structure or switch back to default config.", diff --git a/flutter/lib/resources/cache_manager.dart b/flutter/lib/resources/cache_manager.dart index 973cfb5d1..38aa2c5cd 100644 --- a/flutter/lib/resources/cache_manager.dart +++ b/flutter/lib/resources/cache_manager.dart @@ -91,11 +91,12 @@ class CacheManager { return deleteLoadedResources(currentResources, atLeastDaysOld); } - Future cache( - List urls, - void Function(double, String) reportProgress, - bool purgeOldCache, - bool downloadMissing) async { + Future cache({ + required List urls, + required void Function(double, String) onProgressUpdate, + required bool purgeOldCache, + required bool downloadMissing, + }) async { final resourcesToDownload = []; _resourcesMap = {}; @@ -120,7 +121,7 @@ class CacheManager { continue; } if (downloadMissing) { - await _download(resourcesToDownload, reportProgress); + await _download(resourcesToDownload, onProgressUpdate); } if (purgeOldCache) { await purgeOutdatedCache(_oldFilesAgeInDays); @@ -132,18 +133,20 @@ class CacheManager { } Future _download( - List urls, void Function(double, String) reportProgress) async { + List urls, + void Function(double, String) onProgressUpdate, + ) async { var progress = 0.0; for (var url in urls) { progress += 0.1 / urls.length; - reportProgress(progress, url); + onProgressUpdate(progress, url); if (isResourceAnArchive(url)) { _resourcesMap[url] = await archiveCacheHelper.get(url, true); } else { _resourcesMap[url] = await fileCacheHelper.get(url, true); } progress += 0.9 / urls.length; - reportProgress(progress, url); + onProgressUpdate(progress, url); } } } diff --git a/flutter/lib/resources/resource_manager.dart b/flutter/lib/resources/resource_manager.dart index cb41d3357..970807fe7 100644 --- a/flutter/lib/resources/resource_manager.dart +++ b/flutter/lib/resources/resource_manager.dart @@ -94,8 +94,11 @@ class ResourceManager { return checksum == md5Checksum; } - Future handleResources(List resources, bool purgeOldCache, - bool downloadMissing) async { + Future handleResources({ + required List resources, + required bool purgeOldCache, + required bool downloadMissing, + }) async { _loadingPath = ''; _loadingProgress = 0.001; _done = false; @@ -114,14 +117,14 @@ class ResourceManager { final internetPaths = internetResources.map((e) => e.path).toList(); try { await cacheManager.cache( - internetPaths, - (double currentProgress, String currentPath) { + urls: internetPaths, + onProgressUpdate: (double currentProgress, String currentPath) { _loadingProgress = currentProgress; _loadingPath = currentPath; _onUpdate(); }, - purgeOldCache, - downloadMissing, + purgeOldCache: purgeOldCache, + downloadMissing: downloadMissing, ); } on SocketException { throw 'A network error has occurred. Please make sure you are connected to the internet.'; diff --git a/flutter/lib/resources/validation_helper.dart b/flutter/lib/resources/validation_helper.dart index ce577e4cd..7268cf65b 100644 --- a/flutter/lib/resources/validation_helper.dart +++ b/flutter/lib/resources/validation_helper.dart @@ -18,9 +18,6 @@ class ValidationHelper { required this.selectedRunModes, }); - List get activeBenchmarks => - benchmarkStore.benchmarks.where((e) => e.isActive).toList(); - Future validateExternalResourcesDirectory( String errorDescription) async { final dataFolderPath = resourceManager.getDataFolder(); @@ -32,7 +29,7 @@ class ValidationHelper { } final resources = benchmarkStore.listResources( modes: selectedRunModes, - benchmarks: activeBenchmarks, + benchmarks: benchmarkStore.activeBenchmarks, ); final result = await resourceManager.validateResourcesExist(resources); final missing = result[false] ?? []; @@ -42,10 +39,22 @@ class ValidationHelper { missing.mapIndexed((i, element) => '\n${i + 1}) $element').join(); } + Future validateChecksum(String errorDescription) async { + final resources = benchmarkStore.listResources( + modes: selectedRunModes, + benchmarks: benchmarkStore.activeBenchmarks, + ); + final checksumFailed = + await resourceManager.validateResourcesChecksum(resources); + if (checksumFailed.isEmpty) return ''; + final mismatchedPaths = checksumFailed.map((e) => '\n${e.path}').join(); + return errorDescription + mismatchedPaths; + } + Future validateOfflineMode(String errorDescription) async { final resources = benchmarkStore.listResources( modes: selectedRunModes, - benchmarks: activeBenchmarks, + benchmarks: benchmarkStore.activeBenchmarks, ); final internetResources = filterInternetResources(resources); if (internetResources.isEmpty) return ''; diff --git a/flutter/lib/state/task_runner.dart b/flutter/lib/state/task_runner.dart index 0b67e7ff0..6ded52c87 100644 --- a/flutter/lib/state/task_runner.dart +++ b/flutter/lib/state/task_runner.dart @@ -80,8 +80,7 @@ class TaskRunner { final cooldown = store.cooldown; final cooldownDuration = Duration(seconds: store.cooldownDuration); - final activeBenchmarks = - benchmarkStore.benchmarks.where((element) => element.isActive); + final activeBenchmarks = benchmarkStore.activeBenchmarks; final resultHelpers = []; for (final benchmark in activeBenchmarks) { diff --git a/flutter/lib/ui/home/benchmark_config_section.dart b/flutter/lib/ui/home/benchmark_config_section.dart index a10c63263..5867148ed 100644 --- a/flutter/lib/ui/home/benchmark_config_section.dart +++ b/flutter/lib/ui/home/benchmark_config_section.dart @@ -33,7 +33,7 @@ class _BenchmarkConfigSectionState extends State { l10n = AppLocalizations.of(context)!; final childrenList = []; - for (var benchmark in state.benchmarks) { + for (var benchmark in state.allBenchmarks) { childrenList.add(_listTile(benchmark)); childrenList.add(const Divider(height: 20)); } diff --git a/flutter/lib/ui/home/benchmark_result_screen.dart b/flutter/lib/ui/home/benchmark_result_screen.dart index c1bab7373..2cfe04c25 100644 --- a/flutter/lib/ui/home/benchmark_result_screen.dart +++ b/flutter/lib/ui/home/benchmark_result_screen.dart @@ -219,7 +219,7 @@ class _BenchmarkResultScreenState extends State Widget _detailSection() { final children = []; - for (final benchmark in state.benchmarks) { + for (final benchmark in state.allBenchmarks) { final row = _benchmarkResultRow(benchmark); children.add(row); children.add(const Divider()); diff --git a/flutter/lib/ui/home/benchmark_start_screen.dart b/flutter/lib/ui/home/benchmark_start_screen.dart index fd778f867..14520eb7a 100644 --- a/flutter/lib/ui/home/benchmark_start_screen.dart +++ b/flutter/lib/ui/home/benchmark_start_screen.dart @@ -55,9 +55,8 @@ class _BenchmarkStartScreenState extends State { } Widget _infoSection() { - final selectedCount = - state.benchmarks.where((e) => e.isActive).length.toString(); - final totalCount = state.benchmarks.length.toString(); + final selectedCount = state.activeBenchmarks.length.toString(); + final totalCount = state.allBenchmarks.length.toString(); final selectedBenchmarkText = l10n.mainScreenBenchmarkSelected .replaceAll('', selectedCount) .replaceAll('', totalCount); @@ -128,6 +127,17 @@ class _BenchmarkStartScreenState extends State { await showErrorDialog(context, messages); return; } + final checksumError = await state.validator + .validateChecksum(l10n.dialogContentChecksumError); + if (checksumError.isNotEmpty) { + if (!context.mounted) return; + final messages = [ + checksumError, + l10n.dialogContentChecksumErrorHint + ]; + await showErrorDialog(context, messages); + return; + } if (store.offlineMode) { final offlineError = await state.validator .validateOfflineMode(l10n.dialogContentOfflineWarning); @@ -143,8 +153,7 @@ class _BenchmarkStartScreenState extends State { } } } - final selectedCount = - state.benchmarks.where((e) => e.isActive).length; + final selectedCount = state.activeBenchmarks.length; if (selectedCount < 1) { // Workaround for Dart linter bug. See https://github.com/dart-lang/linter/issues/4007 // ignore: use_build_context_synchronously diff --git a/flutter/lib/ui/settings/resources_screen.dart b/flutter/lib/ui/settings/resources_screen.dart index 98289da1c..eaef33ef1 100644 --- a/flutter/lib/ui/settings/resources_screen.dart +++ b/flutter/lib/ui/settings/resources_screen.dart @@ -8,6 +8,7 @@ import 'package:mlperfbench/benchmark/run_mode.dart'; import 'package:mlperfbench/benchmark/state.dart'; import 'package:mlperfbench/localizations/app_localizations.dart'; import 'package:mlperfbench/store.dart'; +import 'package:mlperfbench/ui/app_styles.dart'; import 'package:mlperfbench/ui/confirm_dialog.dart'; import 'package:mlperfbench/ui/error_dialog.dart'; import 'package:mlperfbench/ui/nil.dart'; @@ -35,13 +36,14 @@ class _ResourcesScreen extends State { final children = []; - for (var benchmark in state.benchmarks) { + for (var benchmark in state.allBenchmarks) { children.add(_listTileBuilder(benchmark)); children.add(const Divider(height: 20)); } children.add(const SizedBox(height: 20)); children.add(_downloadProgress()); - children.add(_downloadButton()); + children + .add(_downloadButton(state.allBenchmarks, l10n.resourceDownloadAll)); children.add(const SizedBox(height: 20)); children.add(_clearCacheButton()); @@ -84,6 +86,7 @@ class _ResourcesScreen extends State { ], ), ), + trailing: _downloadButton([benchmark], l10n.resourceDownload), ); } @@ -102,32 +105,31 @@ class _ResourcesScreen extends State { final missing = result[false] ?? []; final existed = result[true] ?? []; final downloaded = missing.isEmpty; - return Row( - children: [ - SizedBox( - height: size, - width: size, - child: IconButton( - padding: const EdgeInsets.all(0), - icon: downloaded ? downloadedIcon : notDownloadedIcon, - onPressed: () { - showDialog( - context: context, - builder: (BuildContext context) { - return _ResourcesTable( - taskName: benchmark.info.taskName, - modeName: mode.readable, - missing: missing, - existed: existed, - ); - }, - ); - }, - ), - ), - const SizedBox(width: 10), - Text(mode.readable), - ], + return TextButton.icon( + icon: downloaded ? downloadedIcon : notDownloadedIcon, + label: Text( + mode.readable, + style: const TextStyle(color: AppColors.darkText), + ), + //iconAlignment: IconAlignment.start, + style: TextButton.styleFrom( + padding: EdgeInsets.zero, + minimumSize: const Size(50, 30), + tapTargetSize: MaterialTapTargetSize.shrinkWrap, + alignment: Alignment.centerLeft), + onPressed: () { + showDialog( + context: context, + builder: (BuildContext context) { + return _ResourcesTable( + taskName: benchmark.info.taskName, + modeName: mode.readable, + missing: missing, + existed: existed, + ); + }, + ); + }, ); } else { return Text(l10n.resourceChecking); @@ -164,12 +166,15 @@ class _ResourcesScreen extends State { ); } - Widget _downloadButton() { + Widget _downloadButton(List benchmarks, String title) { return AbsorbPointer( absorbing: downloading, child: ElevatedButton( onPressed: () async { - await state.loadResources(downloadMissing: true); + await state.loadResources( + downloadMissing: true, + benchmarks: benchmarks, + ); if (state.error != null) { if (!mounted) return; await showErrorDialog(context, [state.error.toString()]); @@ -182,7 +187,7 @@ class _ResourcesScreen extends State { backgroundColor: downloading ? Colors.grey : Colors.blue), child: FittedBox( fit: BoxFit.scaleDown, - child: Text(l10n.resourceDownload), + child: Text(title), ), ), ); @@ -213,7 +218,7 @@ class _ResourcesScreen extends State { backgroundColor: downloading ? Colors.grey : Colors.red), child: FittedBox( fit: BoxFit.scaleDown, - child: Text(l10n.resourceClear), + child: Text(l10n.resourceClearAll), ), ), ); diff --git a/flutter/unit_test/benchmark/benchmark_store_test.dart b/flutter/unit_test/benchmark/benchmark_store_test.dart index 5f8e22d06..1e916353e 100644 --- a/flutter/unit_test/benchmark/benchmark_store_test.dart +++ b/flutter/unit_test/benchmark/benchmark_store_test.dart @@ -41,12 +41,12 @@ void main() { taskSelection: {}, ); - expect(store.benchmarks.length, 1); + expect(store.allBenchmarks.length, 1); - expect(store.benchmarks.first.taskConfig, task1); - expect(store.benchmarks.first.benchmarkSettings, backendSettings1); + expect(store.allBenchmarks.first.taskConfig, task1); + expect(store.allBenchmarks.first.benchmarkSettings, backendSettings1); expect( - store.benchmarks.first.isActive, + store.allBenchmarks.first.isActive, true, reason: 'benchmarks must be enabled by default', ); @@ -59,13 +59,13 @@ void main() { taskSelection: {}, ); - expect(store.benchmarks.length, 2); + expect(store.allBenchmarks.length, 2); - expect(store.benchmarks.first.taskConfig, task2); - expect(store.benchmarks.first.benchmarkSettings, backendSettings2); + expect(store.allBenchmarks.first.taskConfig, task2); + expect(store.allBenchmarks.first.benchmarkSettings, backendSettings2); - expect(store.benchmarks.last.taskConfig, task1); - expect(store.benchmarks.last.benchmarkSettings, backendSettings1); + expect(store.allBenchmarks.last.taskConfig, task1); + expect(store.allBenchmarks.last.benchmarkSettings, backendSettings1); }); test('selection', () async { @@ -75,9 +75,9 @@ void main() { taskSelection: {task1.id: true, task2.id: false}, ); - expect(store.benchmarks.length, 2); - expect(store.benchmarks.first.isActive, true); - expect(store.benchmarks.last.isActive, false); + expect(store.allBenchmarks.length, 2); + expect(store.allBenchmarks.first.isActive, true); + expect(store.allBenchmarks.last.isActive, false); }); test('resource list: skip', () async { @@ -88,11 +88,9 @@ void main() { ); final modes = [BenchmarkRunModeEnum.performanceOnly.performanceRunMode]; - final activeBenchmarks = - store.benchmarks.where((e) => e.isActive).toList(); final resources = store.listResources( modes: modes, - benchmarks: activeBenchmarks, + benchmarks: store.activeBenchmarks, ); expect(resources.length, 0); @@ -105,11 +103,9 @@ void main() { ); final modes = [BenchmarkRunModeEnum.accuracyOnly.accuracyRunMode]; - final activeBenchmarks = - store.benchmarks.where((e) => e.isActive).toList(); final resources = store.listResources( modes: modes, - benchmarks: activeBenchmarks, + benchmarks: store.activeBenchmarks, ); expect(resources.length, 3); @@ -144,8 +140,7 @@ void main() { ); final modes = [BenchmarkRunModeEnum.performanceOnly.performanceRunMode]; - final activeBenchmarks = - store.benchmarks.where((e) => e.isActive).toList(); + final activeBenchmarks = store.activeBenchmarks; final resources = store.listResources( modes: modes, benchmarks: activeBenchmarks, @@ -179,8 +174,7 @@ void main() { BenchmarkRunModeEnum.integrationTestRun.accuracyRunMode, BenchmarkRunModeEnum.integrationTestRun.performanceRunMode, ]; - final activeBenchmarks = - store.benchmarks.where((e) => e.isActive).toList(); + final activeBenchmarks = store.activeBenchmarks; final resources = store.listResources( modes: modes, benchmarks: activeBenchmarks, diff --git a/flutter/unit_test/resources/cache_manager_test.dart b/flutter/unit_test/resources/cache_manager_test.dart index d814e525a..b91124b79 100644 --- a/flutter/unit_test/resources/cache_manager_test.dart +++ b/flutter/unit_test/resources/cache_manager_test.dart @@ -15,7 +15,12 @@ void main() async { setUp(() async { manager = CacheManager('/tmp/resources'); - await manager.cache(paths, (val, str) {}, true, true); + await manager.cache( + urls: paths, + onProgressUpdate: (val, str) {}, + purgeOldCache: true, + downloadMissing: true, + ); }); test('get', () async {