Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: download resources per task #960

Merged
merged 9 commits into from
Feb 25, 2025
4 changes: 2 additions & 2 deletions flutter/integration_test/utils.dart
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Future<void> startApp(WidgetTester tester) async {
Future<void> validateSettings(WidgetTester tester) async {
final state = tester.state(find.byType(MaterialApp));
final benchmarkState = state.context.read<BenchmarkState>();
for (var benchmark in benchmarkState.benchmarks) {
for (var benchmark in benchmarkState.allBenchmarks) {
expect(benchmark.selectedDelegate.batchSize, greaterThanOrEqualTo(0),
reason: 'batchSize must >= 0');
for (var modelFile in benchmark.selectedDelegate.modelFile) {
Expand Down Expand Up @@ -67,7 +67,7 @@ Future<void> validateSettings(WidgetTester tester) async {
Future<void> setBenchmarks(WidgetTester tester) async {
final state = tester.state(find.byType(MaterialApp));
final benchmarkState = state.context.read<BenchmarkState>();
for (var benchmark in benchmarkState.benchmarks) {
for (var benchmark in benchmarkState.allBenchmarks) {
// Disable test for stable diffusion since it take too long to finish.
if (benchmark.id == BenchmarkId.stableDiffusion) {
benchmark.isActive = false;
Expand Down
10 changes: 7 additions & 3 deletions flutter/lib/benchmark/benchmark.dart
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ class Benchmark {
}

class BenchmarkStore {
final List<Benchmark> benchmarks = <Benchmark>[];
final List<Benchmark> allBenchmarks = <Benchmark>[];

List<Benchmark> get activeBenchmarks {
return allBenchmarks.where((e) => e.isActive).toList();
}

BenchmarkStore({
required pb.MLPerfConfig appConfig,
Expand All @@ -137,7 +141,7 @@ class BenchmarkStore {
}

final enabled = taskSelection[task.id] ?? true;
benchmarks.add(Benchmark(
allBenchmarks.add(Benchmark(
taskConfig: task,
benchmarkSettings: backendSettings,
isActive: enabled,
Expand Down Expand Up @@ -186,7 +190,7 @@ class BenchmarkStore {

Map<String, bool> get selection {
Map<String, bool> result = {};
for (var item in benchmarks) {
for (var item in allBenchmarks) {
result[item.id] = item.isActive;
}
return result;
Expand Down
45 changes: 32 additions & 13 deletions flutter/lib/benchmark/state.dart
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,30 @@ class BenchmarkState extends ChangeNotifier {
ExtendedResult? lastResult;

num get result {
final benchmarksCount = benchmarks
final benchmarksCount = allBenchmarks
.where((benchmark) => benchmark.performanceModeResult != null)
.length;

if (benchmarksCount == 0) return 0;

final summaryThroughput = pow(
benchmarks.fold<double>(1, (prev, i) {
allBenchmarks.fold<double>(1, (prev, i) {
return prev * (i.performanceModeResult?.throughput?.value ?? 1.0);
}),
1.0 / benchmarksCount);

final maxSummaryThroughput = pow(
benchmarks.fold<double>(1, (prev, i) {
allBenchmarks.fold<double>(1, (prev, i) {
return prev * (i.info.maxThroughput);
}),
1.0 / benchmarksCount);

return summaryThroughput / maxSummaryThroughput;
}

List<Benchmark> get benchmarks => _benchmarkStore.benchmarks;
List<Benchmark> get allBenchmarks => _benchmarkStore.allBenchmarks;

List<Benchmark> get activeBenchmarks => _benchmarkStore.activeBenchmarks;

late BenchmarkStore _benchmarkStore;

Expand Down Expand Up @@ -131,25 +133,42 @@ class BenchmarkState extends ChangeNotifier {
}
}

Future<void> loadResources({required bool downloadMissing}) async {
Future<void> loadResources(
{required bool downloadMissing,
List<Benchmark> benchmarks = const []}) async {
final newAppVersion =
'${BuildInfoHelper.info.version}+${BuildInfoHelper.info.buildNumber}';
var needToPurgeCache = _store.previousAppVersion != newAppVersion;
_store.previousAppVersion = newAppVersion;

final selectedBenchmarks = benchmarks.isEmpty ? allBenchmarks : benchmarks;
await Wakelock.enable();
print('Start loading resources with downloadMissing=$downloadMissing');
final resources = _benchmarkStore.listResources(
final selectedResources = _benchmarkStore.listResources(
modes: [taskRunner.perfMode, taskRunner.accuracyMode],
benchmarks: selectedBenchmarks,
);
final allResources = _benchmarkStore.listResources(
modes: [taskRunner.perfMode, taskRunner.accuracyMode],
benchmarks: benchmarks,
benchmarks: allBenchmarks,
);
try {
final selectedBenchmarkIds = selectedBenchmarks
.map((e) => e.benchmarkSettings.benchmarkId)
.join(', ');
print('Start loading resources with downloadMissing=$downloadMissing '
'for $selectedBenchmarkIds');
await resourceManager.handleResources(
resources,
needToPurgeCache,
downloadMissing,
resources: selectedResources,
purgeOldCache: needToPurgeCache,
downloadMissing: downloadMissing,
);
print('Finished loading resources with downloadMissing=$downloadMissing');
// We still need to load all resources after download selected resources.
await resourceManager.handleResources(
resources: allResources,
purgeOldCache: false,
downloadMissing: false,
);
error = null;
stackTrace = null;
taskConfigFailedToLoad = false;
Expand Down Expand Up @@ -289,7 +308,7 @@ class BenchmarkState extends ChangeNotifier {
}

void resetCurrentResults() {
for (var b in _benchmarkStore.benchmarks) {
for (var b in _benchmarkStore.allBenchmarks) {
b.accuracyModeResult = null;
b.performanceModeResult = null;
}
Expand All @@ -304,7 +323,7 @@ class BenchmarkState extends ChangeNotifier {
lastResult = ExtendedResult.fromJson(
jsonDecode(_store.previousExtendedResult) as Map<String, dynamic>);
resourceManager.resultManager
.restoreResults(lastResult!.results, benchmarks);
.restoreResults(lastResult!.results, allBenchmarks);
_doneRunning = true;
return;
} catch (e, trace) {
Expand Down
3 changes: 3 additions & 0 deletions flutter/lib/l10n/app_en.arb
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"dialogContentMissingFiles": "The following files don't exist:",
"dialogContentMissingFilesHint": "Please go to the menu Resources to download the missing files.",
"dialogContentChecksumError": "The following files failed checksum validation:",
"dialogContentChecksumErrorHint": "Please go to the menu Resources to clear the cache and download the files again.",
"dialogContentNoSelectedBenchmarkError": "Please select at least one benchmark.",

"benchModePerformanceOnly": "Performance Only",
Expand Down Expand Up @@ -122,7 +123,9 @@
"benchInfoStableDiffusionDesc": "The Text to Image Gen AI benchmark adopts Stable Diffusion v1.5 for generating images from text prompts. It is a latent diffusion model. The benchmarked Stable Diffusion v1.5 refers to a specific configuration of the model architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet,123M CLIP ViT-L/14 text encoder for the diffusion model, and VAE Decoder of 49.5M parameters. The model was trained on 595k steps at resolution of 512x512, which enables it to generate high quality images. We refer you to https://huggingface.co/benjamin-paine/stable-diffusion-v1-5 for more information. The benchmark runs 20 denoising steps for inference, and uses a precalculated time embedding of size 1x1280. Reference models can be found here https://github.com/mlcommons/mobile_open/releases.\n\nFor latency benchmarking, we benchmark end to end, excluding the time embedding calculation and the tokenizer. For accuracy calculations, the app adopts the CLIP metric for text-to-image consistency, and further evaluation of the generated images using this Image Quality Aesthetic Assessment metric https://github.com/idealo/image-quality-assessment/tree/master?tab=readme-ov-file",

"resourceDownload": "Download",
"resourceDownloadAll": "Download all",
"resourceClear": "Clear",
"resourceClearAll": "Clear all",
"resourceChecking": "Checking download status",
"resourceDownloading": "Downloading",
"resourceErrorMessage": "Some resources failed to load.\nIf you didn't change config from default you can try clearing the cache.\nIf you use a custom configuration file ensure that it has correct structure or switch back to default config.",
Expand Down
21 changes: 12 additions & 9 deletions flutter/lib/resources/cache_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ class CacheManager {
return deleteLoadedResources(currentResources, atLeastDaysOld);
}

Future<void> cache(
List<String> urls,
void Function(double, String) reportProgress,
bool purgeOldCache,
bool downloadMissing) async {
Future<void> cache({
required List<String> urls,
required void Function(double, String) onProgressUpdate,
required bool purgeOldCache,
required bool downloadMissing,
}) async {
final resourcesToDownload = <String>[];
_resourcesMap = {};

Expand All @@ -120,7 +121,7 @@ class CacheManager {
continue;
}
if (downloadMissing) {
await _download(resourcesToDownload, reportProgress);
await _download(resourcesToDownload, onProgressUpdate);
}
if (purgeOldCache) {
await purgeOutdatedCache(_oldFilesAgeInDays);
Expand All @@ -132,18 +133,20 @@ class CacheManager {
}

Future<void> _download(
List<String> urls, void Function(double, String) reportProgress) async {
List<String> urls,
void Function(double, String) onProgressUpdate,
) async {
var progress = 0.0;
for (var url in urls) {
progress += 0.1 / urls.length;
reportProgress(progress, url);
onProgressUpdate(progress, url);
if (isResourceAnArchive(url)) {
_resourcesMap[url] = await archiveCacheHelper.get(url, true);
} else {
_resourcesMap[url] = await fileCacheHelper.get(url, true);
}
progress += 0.9 / urls.length;
reportProgress(progress, url);
onProgressUpdate(progress, url);
}
}
}
15 changes: 9 additions & 6 deletions flutter/lib/resources/resource_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ class ResourceManager {
return checksum == md5Checksum;
}

Future<void> handleResources(List<Resource> resources, bool purgeOldCache,
bool downloadMissing) async {
Future<void> handleResources({
required List<Resource> resources,
required bool purgeOldCache,
required bool downloadMissing,
}) async {
_loadingPath = '';
_loadingProgress = 0.001;
_done = false;
Expand All @@ -114,14 +117,14 @@ class ResourceManager {
final internetPaths = internetResources.map((e) => e.path).toList();
try {
await cacheManager.cache(
internetPaths,
(double currentProgress, String currentPath) {
urls: internetPaths,
onProgressUpdate: (double currentProgress, String currentPath) {
_loadingProgress = currentProgress;
_loadingPath = currentPath;
_onUpdate();
},
purgeOldCache,
downloadMissing,
purgeOldCache: purgeOldCache,
downloadMissing: downloadMissing,
);
} on SocketException {
throw 'A network error has occurred. Please make sure you are connected to the internet.';
Expand Down
19 changes: 14 additions & 5 deletions flutter/lib/resources/validation_helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ class ValidationHelper {
required this.selectedRunModes,
});

List<Benchmark> get activeBenchmarks =>
benchmarkStore.benchmarks.where((e) => e.isActive).toList();

Future<String> validateExternalResourcesDirectory(
String errorDescription) async {
final dataFolderPath = resourceManager.getDataFolder();
Expand All @@ -32,7 +29,7 @@ class ValidationHelper {
}
final resources = benchmarkStore.listResources(
modes: selectedRunModes,
benchmarks: activeBenchmarks,
benchmarks: benchmarkStore.activeBenchmarks,
);
final result = await resourceManager.validateResourcesExist(resources);
final missing = result[false] ?? [];
Expand All @@ -42,10 +39,22 @@ class ValidationHelper {
missing.mapIndexed((i, element) => '\n${i + 1}) $element').join();
}

Future<String> validateChecksum(String errorDescription) async {
final resources = benchmarkStore.listResources(
modes: selectedRunModes,
benchmarks: benchmarkStore.activeBenchmarks,
);
final checksumFailed =
await resourceManager.validateResourcesChecksum(resources);
if (checksumFailed.isEmpty) return '';
final mismatchedPaths = checksumFailed.map((e) => '\n${e.path}').join();
return errorDescription + mismatchedPaths;
}

Future<String> validateOfflineMode(String errorDescription) async {
final resources = benchmarkStore.listResources(
modes: selectedRunModes,
benchmarks: activeBenchmarks,
benchmarks: benchmarkStore.activeBenchmarks,
);
final internetResources = filterInternetResources(resources);
if (internetResources.isEmpty) return '';
Expand Down
3 changes: 1 addition & 2 deletions flutter/lib/state/task_runner.dart
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ class TaskRunner {
final cooldown = store.cooldown;
final cooldownDuration = Duration(seconds: store.cooldownDuration);

final activeBenchmarks =
benchmarkStore.benchmarks.where((element) => element.isActive);
final activeBenchmarks = benchmarkStore.activeBenchmarks;

final resultHelpers = <ResultHelper>[];
for (final benchmark in activeBenchmarks) {
Expand Down
2 changes: 1 addition & 1 deletion flutter/lib/ui/home/benchmark_config_section.dart
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class _BenchmarkConfigSectionState extends State<BenchmarkConfigSection> {
l10n = AppLocalizations.of(context)!;
final childrenList = <Widget>[];

for (var benchmark in state.benchmarks) {
for (var benchmark in state.allBenchmarks) {
childrenList.add(_listTile(benchmark));
childrenList.add(const Divider(height: 20));
}
Expand Down
2 changes: 1 addition & 1 deletion flutter/lib/ui/home/benchmark_result_screen.dart
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class _BenchmarkResultScreenState extends State<BenchmarkResultScreen>

Widget _detailSection() {
final children = <Widget>[];
for (final benchmark in state.benchmarks) {
for (final benchmark in state.allBenchmarks) {
final row = _benchmarkResultRow(benchmark);
children.add(row);
children.add(const Divider());
Expand Down
19 changes: 14 additions & 5 deletions flutter/lib/ui/home/benchmark_start_screen.dart
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ class _BenchmarkStartScreenState extends State<BenchmarkStartScreen> {
}

Widget _infoSection() {
final selectedCount =
state.benchmarks.where((e) => e.isActive).length.toString();
final totalCount = state.benchmarks.length.toString();
final selectedCount = state.activeBenchmarks.length.toString();
final totalCount = state.allBenchmarks.length.toString();
final selectedBenchmarkText = l10n.mainScreenBenchmarkSelected
.replaceAll('<selected>', selectedCount)
.replaceAll('<total>', totalCount);
Expand Down Expand Up @@ -128,6 +127,17 @@ class _BenchmarkStartScreenState extends State<BenchmarkStartScreen> {
await showErrorDialog(context, messages);
return;
}
final checksumError = await state.validator
.validateChecksum(l10n.dialogContentChecksumError);
if (checksumError.isNotEmpty) {
if (!context.mounted) return;
final messages = [
checksumError,
l10n.dialogContentChecksumErrorHint
];
await showErrorDialog(context, messages);
return;
}
if (store.offlineMode) {
final offlineError = await state.validator
.validateOfflineMode(l10n.dialogContentOfflineWarning);
Expand All @@ -143,8 +153,7 @@ class _BenchmarkStartScreenState extends State<BenchmarkStartScreen> {
}
}
}
final selectedCount =
state.benchmarks.where((e) => e.isActive).length;
final selectedCount = state.activeBenchmarks.length;
if (selectedCount < 1) {
// Workaround for Dart linter bug. See https://github.com/dart-lang/linter/issues/4007
// ignore: use_build_context_synchronously
Expand Down
Loading
Loading