Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
InAnYan authored Sep 5, 2024
1 parent 6af91b9 commit d0415c8
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 53 deletions.
4 changes: 2 additions & 2 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
<Button fx:id="temperatureHelp"
prefWidth="20.0"/>
</HBox>
<DoubleInputField
<TextField
fx:id="temperatureTextField"
HBox.hgrow="ALWAYS"/>
</VBox>
Expand Down Expand Up @@ -211,7 +211,7 @@
<Button fx:id="ragMinScoreHelp"
prefWidth="20.0"/>
</HBox>
<DoubleInputField
<TextField
fx:id="ragMinScoreTextField"
HBox.hgrow="ALWAYS"/>
</VBox>
Expand Down
34 changes: 9 additions & 25 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import com.airhacks.afterburner.views.ViewLoader;
import com.dlsc.gemsfx.ResizableTextArea;
import com.dlsc.unitfx.DoubleInputField;
import com.dlsc.unitfx.IntegerInputField;
import de.saxsys.mvvmfx.utils.validation.visualization.ControlsFxVisualizer;
import jakarta.inject.Inject;
Expand All @@ -44,12 +43,12 @@ public class AiTab extends AbstractPreferenceTabView<AiTabViewModel> implements
@FXML private TextField apiBaseUrlTextField;
@FXML private SearchableComboBox<EmbeddingModel> embeddingModelComboBox;
@FXML private ResizableTextArea instructionTextArea;
@FXML private DoubleInputField temperatureTextField;
@FXML private TextField temperatureTextField;
@FXML private IntegerInputField contextWindowSizeTextField;
@FXML private IntegerInputField documentSplitterChunkSizeTextField;
@FXML private IntegerInputField documentSplitterOverlapSizeTextField;
@FXML private IntegerInputField ragMaxResultsCountTextField;
@FXML private DoubleInputField ragMinScoreTextField;
@FXML private TextField ragMinScoreTextField;

@FXML private Button enableAiHelp;
@FXML private Button aiProviderHelp;
Expand Down Expand Up @@ -124,9 +123,6 @@ public void initialize() {
instructionTextArea.textProperty().bindBidirectional(viewModel.instructionProperty());
instructionTextArea.disableProperty().bind(viewModel.disableExpertSettingsProperty());

temperatureTextField.valueProperty().bindBidirectional(viewModel.temperatureProperty().asObject());
temperatureTextField.disableProperty().bind(viewModel.disableExpertSettingsProperty());

// bindBidirectional doesn't work well with number input fields ({@link IntegerInputField}, {@link DoubleInputField}),
// so they are expanded into `addListener` calls.

Expand All @@ -138,18 +134,11 @@ public void initialize() {
contextWindowSizeTextField.valueProperty().set(newValue == null ? 0 : newValue.intValue());
});

temperatureTextField.valueProperty().addListener((observable, oldValue, newValue) -> {
viewModel.temperatureProperty().set(newValue == null ? 0 : newValue);
});

viewModel.temperatureProperty().addListener((observable, oldValue, newValue) -> {
temperatureTextField.valueProperty().set(newValue == null ? 0 : newValue.doubleValue());
});
contextWindowSizeTextField.disableProperty().bind(viewModel.disableExpertSettingsProperty());

temperatureTextField.textProperty().bindBidirectional(viewModel.temperatureProperty());
temperatureTextField.disableProperty().bind(viewModel.disableExpertSettingsProperty());

contextWindowSizeTextField.disableProperty().bind(viewModel.disableExpertSettingsProperty());

documentSplitterChunkSizeTextField.valueProperty().addListener((observable, oldValue, newValue) -> {
viewModel.documentSplitterChunkSizeProperty().set(newValue == null ? 0 : newValue);
});
Expand Down Expand Up @@ -180,14 +169,7 @@ public void initialize() {

ragMaxResultsCountTextField.disableProperty().bind(viewModel.disableExpertSettingsProperty());

ragMinScoreTextField.valueProperty().addListener((observable, oldValue, newValue) -> {
viewModel.ragMinScoreProperty().set(newValue == null ? 0.0 : newValue);
});

viewModel.ragMinScoreProperty().addListener((observable, oldValue, newValue) -> {
ragMinScoreTextField.valueProperty().set(newValue == null ? 0.0 : newValue.doubleValue());
});

ragMinScoreTextField.textProperty().bindBidirectional(viewModel.ragMinScoreProperty());
ragMinScoreTextField.disableProperty().bind(viewModel.disableExpertSettingsProperty());

Platform.runLater(() -> {
Expand All @@ -196,12 +178,14 @@ public void initialize() {
visualizer.initVisualization(viewModel.getApiBaseUrlValidationStatus(), apiBaseUrlTextField);
visualizer.initVisualization(viewModel.getEmbeddingModelValidationStatus(), embeddingModelComboBox);
visualizer.initVisualization(viewModel.getSystemMessageValidationStatus(), instructionTextArea);
visualizer.initVisualization(viewModel.getTemperatureValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getTemperatureTypeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getTemperatureRangeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getMessageWindowSizeValidationStatus(), contextWindowSizeTextField);
visualizer.initVisualization(viewModel.getDocumentSplitterChunkSizeValidationStatus(), documentSplitterChunkSizeTextField);
visualizer.initVisualization(viewModel.getDocumentSplitterOverlapSizeValidationStatus(), documentSplitterOverlapSizeTextField);
visualizer.initVisualization(viewModel.getRagMaxResultsCountValidationStatus(), ragMaxResultsCountTextField);
visualizer.initVisualization(viewModel.getRagMinScoreValidationStatus(), ragMinScoreTextField);
visualizer.initVisualization(viewModel.getRagMinScoreTypeValidationStatus(), ragMinScoreTextField);
visualizer.initVisualization(viewModel.getRagMinScoreRangeValidationStatus(), ragMinScoreTextField);
});

ActionFactory actionFactory = new ActionFactory();
Expand Down
77 changes: 52 additions & 25 deletions src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package org.jabref.gui.preferences.ai;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import javafx.beans.property.BooleanProperty;
import javafx.beans.property.DoubleProperty;
import javafx.beans.property.IntegerProperty;
import javafx.beans.property.ListProperty;
import javafx.beans.property.ObjectProperty;
import javafx.beans.property.ReadOnlyListProperty;
import javafx.beans.property.SimpleBooleanProperty;
import javafx.beans.property.SimpleDoubleProperty;
import javafx.beans.property.SimpleIntegerProperty;
import javafx.beans.property.SimpleListProperty;
import javafx.beans.property.SimpleObjectProperty;
Expand All @@ -22,6 +21,7 @@
import org.jabref.gui.preferences.PreferenceTabViewModel;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.LocalizedNumbers;
import org.jabref.model.strings.StringUtil;
import org.jabref.preferences.PreferencesService;
import org.jabref.preferences.ai.AiApiKeyProvider;
Expand All @@ -35,6 +35,8 @@
import de.saxsys.mvvmfx.utils.validation.Validator;

public class AiTabViewModel implements PreferenceTabViewModel {
private final Locale oldLocale;

private final BooleanProperty enableAi = new SimpleBooleanProperty();

private final ListProperty<AiProvider> aiProvidersList =
Expand Down Expand Up @@ -70,12 +72,12 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final StringProperty huggingFaceApiBaseUrl = new SimpleStringProperty();

private final StringProperty instruction = new SimpleStringProperty();
private final DoubleProperty temperature = new SimpleDoubleProperty();
private final StringProperty temperature = new SimpleStringProperty();
private final IntegerProperty contextWindowSize = new SimpleIntegerProperty();
private final IntegerProperty documentSplitterChunkSize = new SimpleIntegerProperty();
private final IntegerProperty documentSplitterOverlapSize = new SimpleIntegerProperty();
private final IntegerProperty ragMaxResultsCount = new SimpleIntegerProperty();
private final DoubleProperty ragMinScore = new SimpleDoubleProperty();
private final StringProperty ragMinScore = new SimpleStringProperty();

private final BooleanProperty disableBasicSettings = new SimpleBooleanProperty(true);
private final BooleanProperty disableExpertSettings = new SimpleBooleanProperty(true);
Expand All @@ -88,14 +90,18 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final Validator apiBaseUrlValidator;
private final Validator embeddingModelValidator;
private final Validator instructionValidator;
private final Validator temperatureValidator;
private final Validator temperatureTypeValidator;
private final Validator temperatureRangeValidator;
private final Validator contextWindowSizeValidator;
private final Validator documentSplitterChunkSizeValidator;
private final Validator documentSplitterOverlapSizeValidator;
private final Validator ragMaxResultsCountValidator;
private final Validator ragMinScoreValidator;
private final Validator ragMinScoreTypeValidator;
private final Validator ragMinScoreRangeValidator;

public AiTabViewModel(PreferencesService preferencesService, AiApiKeyProvider aiApiKeyProvider) {
this.oldLocale = Locale.getDefault();

this.aiPreferences = preferencesService.getAiPreferences();
this.aiApiKeyProvider = aiApiKeyProvider;

Expand Down Expand Up @@ -216,10 +222,15 @@ public AiTabViewModel(PreferencesService preferencesService, AiApiKeyProvider ai
message -> !StringUtil.isBlank(message),
ValidationMessage.error(Localization.lang("The instruction has to be provided")));

this.temperatureTypeValidator = new FunctionBasedValidator<>(
temperature,
temp -> LocalizedNumbers.stringToDouble(temp).isPresent(),
ValidationMessage.error(Localization.lang("Temperature must be a number")));

// Source: https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature
this.temperatureValidator = new FunctionBasedValidator<>(
this.temperatureRangeValidator = new FunctionBasedValidator<>(
temperature,
temp -> temp.doubleValue() >= 0 && temp.doubleValue() <= 2,
temp -> LocalizedNumbers.stringToDouble(temp).map(t -> t >= 0 && t <= 2).orElse(false),
ValidationMessage.error(Localization.lang("Temperature must be between 0 and 2")));

this.contextWindowSizeValidator = new FunctionBasedValidator<>(
Expand All @@ -242,10 +253,15 @@ public AiTabViewModel(PreferencesService preferencesService, AiApiKeyProvider ai
count -> count.intValue() > 0,
ValidationMessage.error(Localization.lang("RAG max results count must be greater than 0")));

this.ragMinScoreValidator = new FunctionBasedValidator<>(
this.ragMinScoreTypeValidator = new FunctionBasedValidator<>(
ragMinScore,
minScore -> LocalizedNumbers.stringToDouble(minScore).isPresent(),
ValidationMessage.error(Localization.lang("RAG minimum score must be a number")));

this.ragMinScoreRangeValidator = new FunctionBasedValidator<>(
ragMinScore,
score -> score.doubleValue() > 0 && score.doubleValue() < 1,
ValidationMessage.error(Localization.lang("RAG min score must be greater than 0 and less than 1")));
minScore -> LocalizedNumbers.stringToDouble(minScore).map(s -> s > 0 && s < 1).orElse(false),
ValidationMessage.error(Localization.lang("RAG minimum score must be greater than 0 and less than 1")));
}

@Override
Expand All @@ -271,12 +287,12 @@ public void setValues() {
selectedEmbeddingModel.setValue(aiPreferences.getEmbeddingModel());

instruction.setValue(aiPreferences.getInstruction());
temperature.setValue(aiPreferences.getTemperature());
temperature.setValue(LocalizedNumbers.doubleToString(aiPreferences.getTemperature()));
contextWindowSize.setValue(aiPreferences.getContextWindowSize());
documentSplitterChunkSize.setValue(aiPreferences.getDocumentSplitterChunkSize());
documentSplitterOverlapSize.setValue(aiPreferences.getDocumentSplitterOverlapSize());
ragMaxResultsCount.setValue(aiPreferences.getRagMaxResultsCount());
ragMinScore.setValue(aiPreferences.getRagMinScore());
ragMinScore.setValue(LocalizedNumbers.doubleToString(aiPreferences.getRagMinScore()));
}

@Override
Expand Down Expand Up @@ -304,12 +320,13 @@ public void storeSettings() {
aiPreferences.setHuggingFaceApiBaseUrl(huggingFaceApiBaseUrl.get() == null ? "" : huggingFaceApiBaseUrl.get());

aiPreferences.setInstruction(instruction.get());
aiPreferences.setTemperature(temperature.get());
// We already check the correctness of temperature and RAG minimum score in validators, so we don't need to check it here.
aiPreferences.setTemperature(LocalizedNumbers.stringToDouble(oldLocale, temperature.get()).get());
aiPreferences.setContextWindowSize(contextWindowSize.get());
aiPreferences.setDocumentSplitterChunkSize(documentSplitterChunkSize.get());
aiPreferences.setDocumentSplitterOverlapSize(documentSplitterOverlapSize.get());
aiPreferences.setRagMaxResultsCount(ragMaxResultsCount.get());
aiPreferences.setRagMinScore(ragMinScore.get());
aiPreferences.setRagMinScore(LocalizedNumbers.stringToDouble(oldLocale, ragMinScore.get()).get());
}

public void resetExpertSettings() {
Expand All @@ -321,11 +338,11 @@ public void resetExpertSettings() {
int resetContextWindowSize = AiDefaultPreferences.CONTEXT_WINDOW_SIZES.getOrDefault(selectedAiProvider.get(), Map.of()).getOrDefault(currentChatModel.get(), 0);
contextWindowSize.set(resetContextWindowSize);

temperature.set(AiDefaultPreferences.TEMPERATURE);
temperature.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.TEMPERATURE));
documentSplitterChunkSize.set(AiDefaultPreferences.DOCUMENT_SPLITTER_CHUNK_SIZE);
documentSplitterOverlapSize.set(AiDefaultPreferences.DOCUMENT_SPLITTER_OVERLAP);
ragMaxResultsCount.set(AiDefaultPreferences.RAG_MAX_RESULTS_COUNT);
ragMinScore.set(AiDefaultPreferences.RAG_MIN_SCORE);
ragMinScore.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.RAG_MIN_SCORE));
}

@Override
Expand Down Expand Up @@ -355,12 +372,14 @@ public boolean validateExpertSettings() {
apiBaseUrlValidator,
embeddingModelValidator,
instructionValidator,
temperatureValidator,
temperatureTypeValidator,
temperatureRangeValidator,
contextWindowSizeValidator,
documentSplitterChunkSizeValidator,
documentSplitterOverlapSizeValidator,
ragMaxResultsCountValidator,
ragMinScoreValidator
ragMinScoreTypeValidator,
ragMinScoreRangeValidator
);

return validators.stream().map(Validator::getValidationStatus).allMatch(ValidationStatus::isValid);
Expand Down Expand Up @@ -418,7 +437,7 @@ public StringProperty instructionProperty() {
return instruction;
}

public DoubleProperty temperatureProperty() {
public StringProperty temperatureProperty() {
return temperature;
}

Expand All @@ -438,7 +457,7 @@ public IntegerProperty ragMaxResultsCountProperty() {
return ragMaxResultsCount;
}

public DoubleProperty ragMinScoreProperty() {
public StringProperty ragMinScoreProperty() {
return ragMinScore;
}

Expand Down Expand Up @@ -470,8 +489,12 @@ public ValidationStatus getSystemMessageValidationStatus() {
return instructionValidator.getValidationStatus();
}

public ValidationStatus getTemperatureValidationStatus() {
return temperatureValidator.getValidationStatus();
public ValidationStatus getTemperatureTypeValidationStatus() {
return temperatureTypeValidator.getValidationStatus();
}

public ValidationStatus getTemperatureRangeValidationStatus() {
return temperatureRangeValidator.getValidationStatus();
}

public ValidationStatus getMessageWindowSizeValidationStatus() {
Expand All @@ -490,7 +513,11 @@ public ValidationStatus getRagMaxResultsCountValidationStatus() {
return ragMaxResultsCountValidator.getValidationStatus();
}

public ValidationStatus getRagMinScoreValidationStatus() {
return ragMinScoreValidator.getValidationStatus();
public ValidationStatus getRagMinScoreTypeValidationStatus() {
return ragMinScoreTypeValidator.getValidationStatus();
}

public ValidationStatus getRagMinScoreRangeValidationStatus() {
return ragMinScoreRangeValidator.getValidationStatus();
}
}
35 changes: 35 additions & 0 deletions src/main/java/org/jabref/logic/util/LocalizedNumbers.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.jabref.logic.util;

import java.text.NumberFormat;
import java.text.ParseException;
import java.util.Locale;
import java.util.Optional;

public class LocalizedNumbers {
public static Optional<Double> stringToDouble(String value) {
return stringToDouble(Locale.getDefault(), value);
}

public static Optional<Double> stringToDouble(Locale locale, String value) {
if (value == null) {
return Optional.empty();
}

try {
NumberFormat format = NumberFormat.getInstance(locale);
Number parsedNumber = format.parse(value);
return Optional.of(parsedNumber.doubleValue());
} catch (ParseException e) {
return Optional.empty();
}
}

public static String doubleToString(double value) {
return doubleToString(Locale.getDefault(), value);
}

public static String doubleToString(Locale locale, double value) {
NumberFormat format = NumberFormat.getInstance(locale);
return format.format(value);
}
}
4 changes: 3 additions & 1 deletion src/main/resources/l10n/JabRef_en.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,6 @@ Please\ provide\ a\ non-empty\ and\ unique\ citation\ key\ for\ this\ entry.=Ple
RAG\ -\ maximum\ results\ count=RAG - maximum results count
RAG\ -\ minimum\ score=RAG - minimum score
RAG\ max\ results\ count\ must\ be\ greater\ than\ 0=RAG max results count must be greater than 0
RAG\ min\ score\ must\ be\ greater\ than\ 0\ and\ less\ than\ 1=RAG min score must be greater than 0 and less than 1
Clear\ embeddings\ cache=Clear embeddings cache
Clear\ embeddings\ cache\ for\ current\ library?=Clear embeddings cache for current library?
Clearing\ embeddings\ cache...=Clearing embeddings cache...
Expand Down Expand Up @@ -2630,6 +2629,9 @@ Unable\ to\ generate\ summary=Unable to generate summary
Group\ %0=Group %0
AI\ chat\ with\ %0=AI chat with %0
Generating\ embeddings\ for\ %0=Generating embeddings for %0
RAG\ minimum\ score\ must\ be\ a\ number=RAG minimum score must be a number
RAG\ minimum\ score\ must\ be\ greater\ than\ 0\ and\ less\ than\ 1=RAG minimum score must be greater than 0 and less than 1
Temperature\ must\ be\ a\ number=Temperature must be a number
Link=Link
Source\ URL=Source URL
Expand Down

0 comments on commit d0415c8

Please sign in to comment.