Gpt 4 refactor (#71)
* refactor: tokenCount counts array of strings instead of conversation

* refactor: logic to decide when to use GPT4

---------

Co-authored-by: Lauris Skraucis <[email protected]>
supalarry and Lauris Skraucis authored Aug 15, 2023
Parent: bebc683 · Commit: 841279c
Showing 4 changed files with 27 additions and 9 deletions.
@@ -10,6 +10,7 @@ import { chatCompletion } from '@adaptly/services/openai/utils/chatCompletion';
 import { getMessageContent } from '@adaptly/services/openai/utils/getMessageContent';
 import { tokenCount } from '@adaptly/services/openai/utils/tokenCount';
 import { GPT35_MODEL, GPT4_MODEL, MAX_NUM_TOKENS_8K } from '@adaptly/services/openai/client';
+import { getConversationContents } from '@adaptly/services/openai/utils/getConversationContents';
 
 export type BreakingChanges = {
     cursorVersion: string;
@@ -121,7 +122,8 @@ async function extractBreakingChanges(packageName: string, cursorVersion: string
     let completionData: CreateChatCompletionResponse;
 
     try {
-        const firstCheckModel = tokenCount(breakingChangesConversation, GPT4_MODEL) < MAX_NUM_TOKENS_8K ? GPT4_MODEL : GPT35_MODEL;
+        const conversationContents = getConversationContents(breakingChangesConversation);
+        const firstCheckModel = tokenCount([...conversationContents, changelog], GPT4_MODEL) < MAX_NUM_TOKENS_8K ? GPT4_MODEL : GPT35_MODEL;
         // here we need to specify function to have good JSON reply structure
         const completion = await chatCompletion(breakingChangesConversation, firstCheckModel);
         Logger.info('ChatGPT: Breaking changes extracted', { packageName, cursorVersion, breakingChanges: completion.data.choices });
@@ -133,12 +135,20 @@ async function extractBreakingChanges(packageName: string, cursorVersion: string
     let breakingChanges = getBreakingChangesFromChatCompletion(completionData);
 
     if (breakingChanges.length > 0) {
+        const modelMessage = completionData.choices[0].message;
+
+        if (modelMessage) {
+            breakingChangesConversation.push(modelMessage);
+        }
+
         breakingChangesConversation.push({
             role: RoleSystem,
             content: `${breakingChangesDoubleCheckPrompt}`
         });
 
-        const secondCheckModel = tokenCount(breakingChangesConversation, GPT4_MODEL) < MAX_NUM_TOKENS_8K ? GPT4_MODEL : GPT35_MODEL;
+        const nextConversationContents = getConversationContents(breakingChangesConversation);
+        const secondCheckModel =
+            tokenCount([...nextConversationContents, JSON.stringify(completionData)], GPT4_MODEL) < MAX_NUM_TOKENS_8K ? GPT4_MODEL : GPT35_MODEL;
 
         const completion = await chatCompletion(breakingChangesConversation, secondCheckModel);
         Logger.info('Double check ChatGPT: Breaking changes extracted', { packageName, cursorVersion, breakingChanges: completion.data.choices });
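Both hunks apply the same rule: flatten the conversation, append whatever extra text is about to be sent (the changelog on the first pass, the serialized completion on the second), and pick GPT-4 only if the total fits its 8K context window, falling back to the 16k GPT-3.5 model otherwise. A minimal standalone sketch of that rule — the constant value and model IDs below are assumptions standing in for the exports of '@adaptly/services/openai/client':

// Sketch only: assumed values, not the real exports of openai/client.
const MAX_NUM_TOKENS_8K = 8192;
const GPT4_MODEL = 'gpt-4-0613' as const;
const GPT35_MODEL = 'gpt-3.5-turbo-16k-0613' as const;

type Model = typeof GPT4_MODEL | typeof GPT35_MODEL;

// Count with the GPT-4 tokenizer, since GPT-4's window is the one we are
// trying to fit; fall back to GPT-3.5 16k when the prompt is too large.
function pickModel(texts: string[], tokenCount: (texts: string[], model: Model) => number): Model {
    return tokenCount(texts, GPT4_MODEL) < MAX_NUM_TOKENS_8K ? GPT4_MODEL : GPT35_MODEL;
}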
7 changes: 7 additions & 0 deletions source/services/openai/utils/getConversationContents.ts
@@ -0,0 +1,7 @@
+import { ChatCompletionRequestMessage } from 'openai';
+
+export function getConversationContents(conversation: ChatCompletionRequestMessage[]): string[] {
+    const messages = conversation.map((message) => message.content);
+
+    return messages.filter((message) => message !== undefined) as string[];
+}
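This helper exists because tokenCount (below) now takes plain strings: it maps a conversation to its content fields and drops the undefined ones (content is optional in the openai v3 typings). A hypothetical call, with messages invented for illustration:

// Hypothetical usage of getConversationContents.
const conversation: ChatCompletionRequestMessage[] = [
    { role: 'system', content: 'You extract breaking changes from changelogs.' },
    { role: 'user', content: 'Changelog for v2.0.0: ...' }
];

const contents = getConversationContents(conversation);
// contents: ['You extract breaking changes from changelogs.', 'Changelog for v2.0.0: ...']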
7 changes: 5 additions & 2 deletions source/services/openai/utils/tokenCount.test.ts
@@ -1,4 +1,5 @@
 import { RoleSystem } from '../types';
+import { getConversationContents } from './getConversationContents';
 import { tokenCount } from './tokenCount';
 import { ChatCompletionRequestMessage } from 'openai';
 
@@ -19,7 +20,8 @@ describe('tokenCount', () => {
             }
         ];
 
-        const result = tokenCount(conversation, 'gpt-3.5-turbo-16k-0613');
+        const conversationContents = getConversationContents(conversation);
+        const result = tokenCount(conversationContents, 'gpt-3.5-turbo-16k-0613');
 
         // Replace with expected token count for the given conversation and model
         const expectedTokenCount = conversation.reduce((acc, message) => acc + message.tokenCount, 0);
@@ -43,7 +45,8 @@
             }
         ];
 
-        const result = tokenCount(conversation, 'gpt-4-0613');
+        const conversationContents = getConversationContents(conversation);
+        const result = tokenCount(conversationContents, 'gpt-4-0613');
 
         // Replace with expected token count for the given conversation and model
        const expectedTokenCount = conversation.reduce((acc, message) => acc + message.tokenCount, 0);
8 changes: 3 additions & 5 deletions source/services/openai/utils/tokenCount.ts
@@ -5,15 +5,13 @@ import { ChatCompletionRequestMessage } from 'openai';
 const enc4 = encoding_for_model('gpt-4');
 const enc35 = encoding_for_model('gpt-3.5-turbo');
 
-export function tokenCount(conversation: ChatCompletionRequestMessage[], model: 'gpt-3.5-turbo-16k-0613' | 'gpt-4-0613'): number {
+export function tokenCount(texts: string[], model: 'gpt-3.5-turbo-16k-0613' | 'gpt-4-0613'): number {
     let length = 0;
 
     const enc = model === 'gpt-3.5-turbo-16k-0613' ? enc35 : enc4;
 
-    for (let message of conversation) {
-        if (message.content) {
-            length += enc.encode(message.content).length;
-        }
+    for (let text of texts) {
+        length += enc.encode(text).length;
     }
 
     return length;
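The payoff of the string[] signature: callers can count message contents and any extra payload (a changelog, a serialized completion) in one pass, instead of forcing everything into conversation shape first. A hypothetical call, with the strings invented for illustration:

// Hypothetical usage of the refactored tokenCount.
const total = tokenCount(['system prompt', 'user message', 'extra payload'], 'gpt-4-0613');
console.log(total); // total GPT-4 tokens across all three strings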
