Skip to content

Commit

Permalink
chore: [vertexai] switch endpoint to prod (#10143)
Browse files Browse the repository at this point in the history
Allow made the following changes
1. Returns unmodifiable history and allow history to be changed with
   getHistory()
2. Don't block getHistory() forever if the last round ends with abnormal
   response.

* fix the prod endpoint
  • Loading branch information
ZhenyiQ authored Dec 12, 2023
1 parent bcd9ea4 commit 026ff86
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (predictionServiceClient == null) {
PredictionServiceSettings settings =
PredictionServiceSettings.newBuilder()
.setEndpoint(
String.format("%s-autopush-aiplatform.sandbox.googleapis.com:443", this.location))
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location))
.setCredentialsProvider(FixedCredentialsProvider.create(this.credentials))
.build();
predictionServiceClient = PredictionServiceClient.create(settings);
Expand All @@ -162,8 +161,7 @@ public PredictionServiceClient getPredictionServiceRestClient() throws IOExcepti
if (predictionServiceRestClient == null) {
PredictionServiceSettings settings =
PredictionServiceSettings.newHttpJsonBuilder()
.setEndpoint(
String.format("%s-autopush-aiplatform.sandbox.googleapis.com:443", this.location))
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", this.location))
.setCredentialsProvider(FixedCredentialsProvider.create(this.credentials))
.build();
predictionServiceRestClient = PredictionServiceClient.create(settings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.cloud.vertexai.api.SafetySetting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Represents a conversation between the user and the model */
Expand Down Expand Up @@ -301,18 +302,25 @@ private void checkLastResponseAndEditHistory() throws IllegalStateException {
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
// We also remove the request from the history.
removeLastContent();
currentResponseStream = null;
throw new IllegalStateException(
String.format(
"Response stream did not finish normally. Finish reason is %s.", finishReason));
"The last round of conversation will not be added to history because response"
+ " stream did not finish normally. Finish reason is %s.",
finishReason));
}
history.add(getContent(response));
} else if (currentResponseStream == null && currentResponse != null) {
FinishReason finishReason = getFinishReason(currentResponse);
// We also remove the request from the history.
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
// We also remove the request from the history.
removeLastContent();
currentResponse = null;
throw new IllegalStateException(
String.format("Response did not finish normally. Finish reason is %s.", finishReason));
String.format(
"The last round of conversation will not be added to history because response did"
+ " not finish normally. Finish reason is %s.",
finishReason));
}
history.add(getContent(currentResponse));
currentResponse = null;
Expand All @@ -322,10 +330,26 @@ private void checkLastResponseAndEditHistory() throws IllegalStateException {
/**
* Returns the history of the conversation.
*
* @return the history of the conversation.
* @return an unmodifiable history of the conversation.
*/
public List<Content> getHistory() {
checkLastResponseAndEditHistory();
return history;
try {
checkLastResponseAndEditHistory();
} catch (IllegalStateException e) {
if (e.getMessage()
.contains("The last round of conversation will not be added to history because")) {
IllegalStateException modifiedExecption =
new IllegalStateException("Rerun getHistory() to get cleaned history.");
modifiedExecption.initCause(e);
throw modifiedExecption;
}
throw e;
}
return Collections.unmodifiableList(history);
}

/** Set the history to a list of Content */
public void setHistory(List<Content> history) {
this.history = history;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,29 @@ public void sendMessageWithText_historyContainsTwoTurns() throws IOException {
assertThat(history.get(1).getParts(0).getText()).isEqualTo(FULL_RESPONSE_TEXT);
}

@Test
public void sendMessageWithTextThenModifyHistory_historyChangedToNewContentList()
throws IOException {

// (Arrange) Set up the return value of the generateContent
when(mockGenerativeModel.generateContent(
Arrays.asList(ContentMaker.fromString(SAMPLE_MESSAGE1)), null, null))
.thenReturn(RESPONSE_FROM_UNARY_CALL);

// (Act) Send text message via sendMessage and get the history.
GenerateContentResponse response = chat.sendMessage(SAMPLE_MESSAGE1);
List<Content> history = chat.getHistory();
// (Assert) Assert that 1) the first content contains the user request text, and 2) the second
// content in history contains the response.
assertThat(history.get(0).getParts(0).getText()).isEqualTo(SAMPLE_MESSAGE1);
assertThat(history.get(1).getParts(0).getText()).isEqualTo(FULL_RESPONSE_TEXT);

// (Act) Set history to an empty list
chat.setHistory(Arrays.asList());
// (Assert) Asser that the history is empty.
assertThat(chat.getHistory().size()).isEqualTo(0);
}

@Test
public void sendMessageStreamWithText_throwsIllegalStateExceptionWhenFinishReasonIsNotSTOP()
throws IOException {
Expand All @@ -234,10 +257,11 @@ public void sendMessageStreamWithText_throwsIllegalStateExceptionWhenFinishReaso
// reason.
IllegalStateException thrown =
assertThrows(IllegalStateException.class, () -> chat.getHistory());
assertThat(thrown)
.hasMessageThat()
.isEqualTo(
"Response stream did not finish normally. Finish reason is FINISH_REASON_UNSPECIFIED.");
assertThat(thrown).hasMessageThat().isEqualTo("Rerun getHistory() to get cleaned history.");

// Assert that the history can be fetched again and it's empty.
List<Content> history = chat.getHistory();
assertThat(history.size()).isEqualTo(0);
}

@Test
Expand All @@ -255,8 +279,9 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot
// reason.
IllegalStateException thrown =
assertThrows(IllegalStateException.class, () -> chat.getHistory());
assertThat(thrown)
.hasMessageThat()
.isEqualTo("Response did not finish normally. Finish reason is SAFETY.");
assertThat(thrown).hasMessageThat().isEqualTo("Rerun getHistory() to get cleaned history.");
// Assert that the history can be fetched again and it's empty.
List<Content> history = chat.getHistory();
assertThat(history.size()).isEqualTo(0);
}
}

0 comments on commit 026ff86

Please sign in to comment.