Skip to content

Commit

Permalink
[mobile] Enhance ML (#1114)
Browse files Browse the repository at this point in the history
## Description
- Increase the interaction timeout to 15s
- Make sure that models aren't downloaded over mobile data, and that it
resumes initialization when network conditions are favorable

## Tests
- [x] Tested manually
  • Loading branch information
vishnukvmd authored Mar 15, 2024
2 parents 11ccb37 + 14c7533 commit bf3b257
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ class MachineLearningController {

static const kMaximumTemperature = 42; // 42 degree celsius
static const kMinimumBatteryLevel = 20; // 20%
static const kInitialInteractionTimeout = Duration(seconds: 10);
static const kDefaultInteractionTimeout = Duration(seconds: 5);
static const kDefaultInteractionTimeout = Duration(seconds: 15);
static const kUnhealthyStates = ["over_heat", "over_voltage", "dead"];

bool _isDeviceHealthy = true;
Expand All @@ -28,7 +27,7 @@ class MachineLearningController {

void init() {
if (Platform.isAndroid) {
_startInteractionTimer(timeout: kInitialInteractionTimeout);
_startInteractionTimer();
BatteryInfoPlugin()
.androidBatteryInfoStream
.listen((AndroidBatteryInfo? batteryInfo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import "package:logging/logging.dart";
import "package:photos/core/errors.dart";

import "package:photos/core/event_bus.dart";
import "package:photos/core/network/network.dart";
import "package:photos/events/event.dart";
import "package:photos/services/remote_assets_service.dart";

Expand All @@ -17,9 +16,9 @@ abstract class MLFramework {
static final _logger = Logger("MLFramework");

final bool shouldDownloadOverMobileData;
final _initializationCompleter = Completer<void>();

InitializationState _state = InitializationState.notInitialized;
final _initializationCompleter = Completer<void>();

MLFramework(this.shouldDownloadOverMobileData) {
Connectivity()
Expand Down Expand Up @@ -65,6 +64,7 @@ abstract class MLFramework {
/// instead of a CDN.
Future<void> init() async {
try {
_initState = InitializationState.initializing;
await Future.wait([_initImageModel(), _initTextModel()]);
} catch (e, s) {
_logger.warning(e, s);
Expand Down Expand Up @@ -102,41 +102,32 @@ abstract class MLFramework {
if (!kImageEncoderEnabled) {
return;
}
_initState = InitializationState.initializingImageModel;
final imageModel =
await RemoteAssetsService.instance.getAsset(getImageModelRemotePath());
final imageModel = await _getModel(getImageModelRemotePath());
await loadImageModel(imageModel.path);
_initState = InitializationState.initializedImageModel;
}

Future<void> _initTextModel() async {
_initState = InitializationState.initializingTextModel;
final textModel =
await RemoteAssetsService.instance.getAsset(getTextModelRemotePath());
final textModel = await _getModel(getTextModelRemotePath());
await loadTextModel(textModel.path);
_initState = InitializationState.initializedTextModel;
}

Future<void> _downloadFile(
String url,
String savePath, {
Future<File> _getModel(
String url, {
int trialCount = 1,
}) async {
if (await RemoteAssetsService.instance.hasAsset(url)) {
return RemoteAssetsService.instance.getAsset(url);
}
if (!await _canDownload()) {
_initState = InitializationState.waitingForNetwork;
throw WiFiUnavailableError();
}
_logger.info("Downloading " + url);
final existingFile = File(savePath);
if (await existingFile.exists()) {
await existingFile.delete();
}
try {
await NetworkClient.instance.getDio().download(url, savePath);
return RemoteAssetsService.instance.getAsset(url);
} catch (e, s) {
_logger.severe(e, s);
if (trialCount < kMaximumRetrials) {
return _downloadFile(url, savePath, trialCount: trialCount + 1);
return _getModel(url, trialCount: trialCount + 1);
} else {
rethrow;
}
Expand All @@ -159,9 +150,6 @@ class MLFrameworkInitializationUpdateEvent extends Event {
enum InitializationState {
notInitialized,
waitingForNetwork,
initializingImageModel,
initializedImageModel,
initializingTextModel,
initializedTextModel,
initializing,
initialized,
}
5 changes: 5 additions & 0 deletions mobile/lib/services/remote_assets_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class RemoteAssetsService {
}
}

Future<bool> hasAsset(String remotePath) async {
final path = await _getLocalPath(remotePath);
return File(path).exists();
}

Future<String> _getLocalPath(String remotePath) async {
return (await getApplicationSupportDirectory()).path +
"/assets/" +
Expand Down

0 comments on commit bf3b257

Please sign in to comment.