diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart b/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart index 94f4805831..3abd57db70 100644 --- a/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart +++ b/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart @@ -16,9 +16,9 @@ abstract class MLFramework { static final _logger = Logger("MLFramework"); final bool shouldDownloadOverMobileData; + final _initializationCompleter = Completer(); InitializationState _state = InitializationState.notInitialized; - final _initializationCompleter = Completer(); MLFramework(this.shouldDownloadOverMobileData) { Connectivity() @@ -64,6 +64,7 @@ abstract class MLFramework { /// instead of a CDN. Future init() async { try { + _initState = InitializationState.initializing; await Future.wait([_initImageModel(), _initTextModel()]); } catch (e, s) { _logger.warning(e, s); @@ -101,17 +102,13 @@ abstract class MLFramework { if (!kImageEncoderEnabled) { return; } - _initState = InitializationState.initializingImageModel; final imageModel = await _getModel(getImageModelRemotePath()); await loadImageModel(imageModel.path); - _initState = InitializationState.initializedImageModel; } Future _initTextModel() async { - _initState = InitializationState.initializingTextModel; final textModel = await _getModel(getTextModelRemotePath()); await loadTextModel(textModel.path); - _initState = InitializationState.initializedTextModel; } Future _getModel( @@ -153,9 +150,6 @@ class MLFrameworkInitializationUpdateEvent extends Event { enum InitializationState { notInitialized, waitingForNetwork, - initializingImageModel, - initializedImageModel, - initializingTextModel, - initializedTextModel, + initializing, initialized, }