From 692c24329928cea0d941f63a881e69bbb616066b Mon Sep 17 00:00:00 2001 From: suhas-koheda Date: Tue, 7 Jan 2025 22:40:39 +0530 Subject: [PATCH] Checking Safety Settings and Tool Config --- .../googleaigemini/spring/AutoConfig.java | 79 +++++++++---------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/langchain4j-google-ai-gemini-spring-boot-starter/src/main/java/dev/langchain4j/googleaigemini/spring/AutoConfig.java b/langchain4j-google-ai-gemini-spring-boot-starter/src/main/java/dev/langchain4j/googleaigemini/spring/AutoConfig.java index bd9aef4..879296b 100644 --- a/langchain4j-google-ai-gemini-spring-boot-starter/src/main/java/dev/langchain4j/googleaigemini/spring/AutoConfig.java +++ b/langchain4j-google-ai-gemini-spring-boot-starter/src/main/java/dev/langchain4j/googleaigemini/spring/AutoConfig.java @@ -24,7 +24,7 @@ public class AutoConfig { }) GoogleAiGeminiChatModel googleAiGeminiChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getChatModel(); - return GoogleAiGeminiChatModel.builder() + GoogleAiGeminiChatModel.GoogleAiGeminiChatModelBuilder builder = GoogleAiGeminiChatModel.builder() .apiKey(chatModelProperties.apiKey()) .modelName(chatModelProperties.modelName()) .temperature(chatModelProperties.temperature()) @@ -32,13 +32,23 @@ GoogleAiGeminiChatModel googleAiGeminiChatModel(Properties properties) { .topK(chatModelProperties.topK()) .maxOutputTokens(chatModelProperties.maxOutputTokens()) .responseFormat(chatModelProperties.responseFormat()) - .logRequestsAndResponses(chatModelProperties.logRequestsAndResponses()) - .safetySettings(checkSafetySettingForNull(chatModelProperties.safetySetting())) - .toolConfig( - checkGeminiModeForNull(chatModelProperties.functionCallingConfig()), - checkFunctionNamesForNull(chatModelProperties.functionCallingConfig()) - ) - .build(); + .logRequestsAndResponses(chatModelProperties.logRequestsAndResponses()); + + if (chatModelProperties.safetySetting() != null && !chatModelProperties.safetySetting().isEmpty()) { + builder.safetySettings(chatModelProperties.safetySetting()); + } + + if (chatModelProperties.functionCallingConfig() != null) { + builder.toolConfig(chatModelProperties + .functionCallingConfig() + .geminiMode(), + chatModelProperties + .functionCallingConfig() + .allowedFunctionNames() + .toArray(new String[0])); + } + + return builder.build(); } @Bean @@ -48,20 +58,31 @@ GoogleAiGeminiChatModel googleAiGeminiChatModel(Properties properties) { }) GoogleAiGeminiStreamingChatModel googleAiGeminiStreamingChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); - return GoogleAiGeminiStreamingChatModel.builder() + GoogleAiGeminiStreamingChatModel.GoogleAiGeminiStreamingChatModelBuilder builder = GoogleAiGeminiStreamingChatModel.builder() .apiKey(chatModelProperties.apiKey()) .modelName(chatModelProperties.modelName()) .temperature(chatModelProperties.temperature()) .topP(chatModelProperties.topP()) .topK(chatModelProperties.topK()) + .maxOutputTokens(chatModelProperties.maxOutputTokens()) .responseFormat(chatModelProperties.responseFormat()) - .logRequestsAndResponses(chatModelProperties.logRequestsAndResponses()) - .safetySettings(checkSafetySettingForNull(chatModelProperties.safetySetting())) - .toolConfig( - checkGeminiModeForNull(chatModelProperties.functionCallingConfig()), - checkFunctionNamesForNull(chatModelProperties.functionCallingConfig()) - ) - .build(); + .logRequestsAndResponses(chatModelProperties.logRequestsAndResponses()); + if (chatModelProperties.safetySetting() != null && !chatModelProperties.safetySetting().isEmpty()) { + builder.safetySettings(chatModelProperties.safetySetting()); + } + + if (chatModelProperties + .functionCallingConfig() != null) { + builder.toolConfig(chatModelProperties + .functionCallingConfig() + .geminiMode(), + chatModelProperties + .functionCallingConfig() + .allowedFunctionNames() + .toArray(new String[0])); + } + + return builder.build(); } @Bean @@ -82,30 +103,4 @@ GoogleAiEmbeddingModel googleAiEmbeddingModel(Properties properties) { .titleMetadataKey(embeddingModelProperties.titleMetadataKey()) .build(); } - - private String[] checkFunctionNamesForNull(GeminiFunctionCallingConfig geminiFunctionCallingConfig) { - if(geminiFunctionCallingConfig==null){ - return new String[0]; - } - return geminiFunctionCallingConfig.allowedFunctionNames().toArray(new String[0]); - } - - private GeminiMode checkGeminiModeForNull(GeminiFunctionCallingConfig geminiFunctionCallingConfig) { - if(geminiFunctionCallingConfig==null){ - return GeminiMode.NONE; - } - return geminiFunctionCallingConfig.geminiMode(); - } - - private Map checkSafetySettingForNull(Map safetySetting) { - if(safetySetting==null || safetySetting.isEmpty()){ - Map defaultMap= new HashMap<>(); - defaultMap.put(HARM_CATEGORY_CIVIC_INTEGRITY,HARM_BLOCK_THRESHOLD_UNSPECIFIED); - return defaultMap; - } - Map userMap= new HashMap<>(); - safetySetting.keySet().forEach(key-> userMap.put(key,safetySetting.get(key))); - return userMap; - } - } \ No newline at end of file