Skip to content

feat(spring-ai-bedrock-converse): Introduce BedrockProxyChatOptions #1760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -131,7 +130,7 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch

private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;

private FunctionCallingOptions defaultOptions;
private BedrockProxyChatOptions defaultOptions;

/**
* Observation registry used for instrumentation.
Expand All @@ -144,7 +143,7 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
private ChatModelObservationConvention observationConvention;

public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, BedrockProxyChatOptions defaultOptions,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
ObservationRegistry observationRegistry) {

Expand Down Expand Up @@ -305,17 +304,14 @@ else if (message.getMessageType() == MessageType.TOOL) {
.map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getContent()).build())
.toList();

FunctionCallingOptions updatedRuntimeOptions = (FunctionCallingOptions) this.defaultOptions.copy();
BedrockProxyChatOptions updatedRuntimeOptions = (BedrockProxyChatOptions) this.defaultOptions.copy();

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof FunctionCallingOptions) {
var functionCallingOptions = (FunctionCallingOptions) prompt.getOptions();
updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions)
.merge(functionCallingOptions);
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = (BedrockProxyChatOptions) updatedRuntimeOptions.merge(functionCallingOptions);
}
else if (prompt.getOptions() instanceof ChatOptions) {
var chatOptions = (ChatOptions) prompt.getOptions();
updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions);
else if (prompt.getOptions() instanceof ChatOptions chatOptions) {
updatedRuntimeOptions = updatedRuntimeOptions.merge(chatOptions);
}
}

Expand All @@ -334,6 +330,7 @@ else if (prompt.getOptions() instanceof ChatOptions) {
? updatedRuntimeOptions.getTemperature().floatValue() : null)
.topP(updatedRuntimeOptions.getTopP() != null ? updatedRuntimeOptions.getTopP().floatValue() : null)
.build();

Document additionalModelRequestFields = ConverseApiUtils
.getChatOptionsAdditionalModelRequestFields(this.defaultOptions, prompt.getOptions());

Expand Down Expand Up @@ -586,7 +583,7 @@ public static final class Builder {

private Duration timeout = Duration.ofMinutes(10);

private FunctionCallingOptions defaultOptions = new FunctionCallingOptionsBuilder().build();
private BedrockProxyChatOptions defaultOptions = BedrockProxyChatOptions.builder().build();

private FunctionCallbackContext functionCallbackContext;

Expand Down Expand Up @@ -621,7 +618,7 @@ public Builder withTimeout(Duration timeout) {
return this;
}

public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) {
public Builder withDefaultOptions(BedrockProxyChatOptions defaultOptions) {
Assert.notNull(defaultOptions, "'defaultOptions' must not be null.");
this.defaultOptions = defaultOptions;
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.converse;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
* @author Christian Tzolov
* @since 1.0.0
*/
public class BedrockProxyChatOptions implements FunctionCallingOptions {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why BedrockProxyChatOptions can not implement ChatOptions here as I think we got all of the ChatOptions as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FunctionCallingOptions implements ChatOptions already

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. I got deceived by the explicit declaration of ChatOptions in some of the model option classes. I just pushed a minor PR to fix this: #1762

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not specific to this PR, but, do you think it would be good to have a DefaultFunctionCallingOptions as the default implementation for FunctionCallingOptions so that all the Model Chat options can extend? or, any reason why we didn't have the default implementation at the first place?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have something like this called PortableFunctionCallingOptions, but it needs some cleaning regarding the Builder/Object relation ship there.
I started by extending it but then for the sake of the flexibility for the draft I've copied its content.
IMO Before merging this PR we would likely revert to the PortableFunctionCallingOptions parent, unless there are some constrains.


private List<FunctionCallback> functionCallbacks = new ArrayList<>();

private Set<String> functions = new HashSet<>();

private String model;

private Double frequencyPenalty;

private Integer maxTokens;

private Double presencePenalty;

private List<String> stopSequences;

private Double temperature;

private Integer topK;

private Double topP;

private Boolean proxyToolCalls = false;

private Map<String, Object> context = new HashMap<>();

private Map<String, Object> additional = new HashMap<>();

public static BedrockProxyChatOptionsBuilder builder() {
return new BedrockProxyChatOptionsBuilder();
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return Collections.unmodifiableList(this.functionCallbacks);
}

public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
this.functionCallbacks = new ArrayList<>(functionCallbacks);
}

@Override
public Set<String> getFunctions() {
return Collections.unmodifiableSet(this.functions);
}

public void setFunctions(Set<String> functions) {
Assert.notNull(functions, "Functions must not be null");
this.functions = new HashSet<>(functions);
}

@Override
public String getModel() {
return this.model;
}

public void setModel(String model) {
this.model = model;
}

@Override
public Double getFrequencyPenalty() {
return this.frequencyPenalty;
}

public void setFrequencyPenalty(Double frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
}

@Override
public Integer getMaxTokens() {
return this.maxTokens;
}

public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}

@Override
public Double getPresencePenalty() {
return this.presencePenalty;
}

public void setPresencePenalty(Double presencePenalty) {
this.presencePenalty = presencePenalty;
}

@Override
public List<String> getStopSequences() {
return this.stopSequences;
}

public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

@Override
public Double getTemperature() {
return this.temperature;
}

public void setTemperature(Double temperature) {
this.temperature = temperature;
}

@Override
public Integer getTopK() {
return this.topK;
}

public void setTopK(Integer topK) {
this.topK = topK;
}

@Override
public Double getTopP() {
return this.topP;
}

public void setTopP(Double topP) {
this.topP = topP;
}

@Override
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
}

public void setProxyToolCalls(Boolean proxyToolCalls) {
this.proxyToolCalls = proxyToolCalls;
}

public Map<String, Object> getToolContext() {
return Collections.unmodifiableMap(this.context);
}

public void setToolContext(Map<String, Object> context) {
Assert.notNull(context, "Context must not be null");
this.context = new HashMap<>(context);
}

public Map<String, Object> getAdditional() {
return Collections.unmodifiableMap(this.additional);
}

public void setAdditional(Map<String, Object> additional) {
Assert.notNull(additional, "Additional must not be null");
this.additional = new HashMap<>(additional);
}

@Override
public ChatOptions copy() {
return new BedrockProxyChatOptionsBuilder().model(this.model)
.frequencyPenalty(this.frequencyPenalty)
.maxTokens(this.maxTokens)
.presencePenalty(this.presencePenalty)
.stopSequences(this.stopSequences != null ? new ArrayList<>(this.stopSequences) : null)
.temperature(this.temperature)
.topK(this.topK)
.topP(this.topP)
.functions(new HashSet<>(this.functions))
.functionCallbacks(new ArrayList<>(this.functionCallbacks))
.proxyToolCalls(this.proxyToolCalls)
.toolContext(new HashMap<>(this.getToolContext()))
.additional(new HashMap<>(this.additional))
.build();
}

public BedrockProxyChatOptions merge(FunctionCallingOptions options) {

var builder = builder().model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.model)
.frequencyPenalty(
options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.frequencyPenalty)
.maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.maxTokens)
.presencePenalty(options.getPresencePenalty() != null ? options.getPresencePenalty() : this.presencePenalty)
.stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.stopSequences)
.temperature(options.getTemperature() != null ? options.getTemperature() : this.temperature)
.topK(options.getTopK() != null ? options.getTopK() : this.topK)
.topP(options.getTopP() != null ? options.getTopP() : this.topP)
.proxyToolCalls(options.getProxyToolCalls() != null ? options.getProxyToolCalls() : this.proxyToolCalls);

Set<String> functions = new HashSet<>();
if (!CollectionUtils.isEmpty(this.functions)) {
functions.addAll(this.functions);
}
if (!CollectionUtils.isEmpty(options.getFunctions())) {
functions.addAll(options.getFunctions());
}
builder.functions(functions);

List<FunctionCallback> functionCallbacks = new ArrayList<>();
if (!CollectionUtils.isEmpty(this.functionCallbacks)) {
functionCallbacks.addAll(this.functionCallbacks);
}
if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) {
functionCallbacks.addAll(options.getFunctionCallbacks());
}
builder.functionCallbacks(functionCallbacks);

Map<String, Object> context = new HashMap<>();
if (!CollectionUtils.isEmpty(this.context)) {
context.putAll(this.context);
}
if (!CollectionUtils.isEmpty(options.getToolContext())) {
context.putAll(options.getToolContext());
}
builder.toolContext(context);

Map<String, Object> additional = new HashMap<>();
if (!CollectionUtils.isEmpty(this.additional)) {
context.putAll(this.additional);
}

if (options instanceof BedrockProxyChatOptions bedrockProxyChatOptions) {
if (!CollectionUtils.isEmpty(bedrockProxyChatOptions.getAdditional())) {
additional.putAll(bedrockProxyChatOptions.getAdditional());
}
}
builder.additional(additional);

return builder.build();
}

public BedrockProxyChatOptions merge(ChatOptions options) {

var builder = BedrockProxyChatOptions.builder()
.model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.model)
.frequencyPenalty(
options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.frequencyPenalty)
.maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.maxTokens)
.presencePenalty(options.getPresencePenalty() != null ? options.getPresencePenalty() : this.presencePenalty)
.stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.stopSequences)
.temperature(options.getTemperature() != null ? options.getTemperature() : this.temperature)
.topK(options.getTopK() != null ? options.getTopK() : this.topK)
.topP(options.getTopP() != null ? options.getTopP() : this.topP);

return builder.build();
}

}
Loading