Skip to content

Commit

Permalink
Merge pull request #333 from cdcd72/feature/support-gpt-4o
Browse files Browse the repository at this point in the history
Support GPT-4o model
  • Loading branch information
memochou1993 authored Jul 9, 2024
2 parents fa947ca + 34adcc9 commit 4793599
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 20 deletions.
2 changes: 1 addition & 1 deletion app/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ const handleEvents = async (events = []) => (
events
.map((event) => new Event(event))
.filter((event) => event.isMessage)
.filter((event) => event.isText || event.isAudio)
.filter((event) => event.isText || event.isAudio || event.isImage)
.map((event) => new Context(event))
.map((context) => context.initialize()),
))
Expand Down
24 changes: 22 additions & 2 deletions app/context.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
addMark,
convertText,
fetchAudio,
fetchImage,
fetchGroup,
fetchUser,
generateTranscription,
Expand Down Expand Up @@ -87,6 +88,9 @@ class Context {
const text = this.transcription.replace(config.BOT_NAME, '').trim();
return addMark(text);
}
if (this.event.isImage) {
return this.transcription.trim();
}
return '?';
}

Expand All @@ -99,6 +103,10 @@ class Context {
const text = this.transcription.toLowerCase();
return text.startsWith(config.BOT_NAME.toLowerCase());
}
if (this.event.isImage) {
const text = this.transcription.toLowerCase();
return text.startsWith(config.BOT_NAME.toLowerCase());
}
return false;
}

Expand All @@ -111,7 +119,14 @@ class Context {
}
if (this.event.isAudio) {
try {
await this.transcribe();
await this.transcribeAudio();
} catch (err) {
return this.pushError(err);
}
}
if (this.event.isImage) {
try {
await this.transcribeImage();
} catch (err) {
return this.pushError(err);
}
Expand Down Expand Up @@ -163,14 +178,19 @@ class Context {
this.source = new Source(sources[this.id]);
}

async transcribe() {
async transcribeAudio() {
const buffer = await fetchAudio(this.event.messageId);
const file = `/tmp/${this.event.messageId}.m4a`;
fs.writeFileSync(file, buffer);
const { text } = await generateTranscription({ file, buffer });
this.transcription = convertText(text);
}

async transcribeImage() {
const base64String = await fetchImage(this.event.messageId);
this.transcription = base64String;
}

/**
* @param {Object} param
* @param {string} param.text
Expand Down
25 changes: 17 additions & 8 deletions app/handlers/talk.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import config from '../../config/index.js';
import { t } from '../../locales/index.js';
import { ROLE_AI, ROLE_HUMAN } from '../../services/openai.js';
import { generateCompletion } from '../../utils/index.js';
import { COMMAND_BOT_CONTINUE, COMMAND_BOT_TALK } from '../commands/index.js';
import { COMMAND_BOT_CONTINUE, COMMAND_BOT_TALK, COMMAND_BOT_FORGET } from '../commands/index.js';
import Context from '../context.js';
import { updateHistory } from '../history/index.js';
import { getPrompt, setPrompt } from '../prompt/index.js';
Expand All @@ -24,14 +24,23 @@ const check = (context) => (
const exec = (context) => check(context) && (
async () => {
const prompt = getPrompt(context.userId);
prompt.write(ROLE_HUMAN, `${t('__COMPLETION_DEFAULT_AI_TONE')(config.BOT_TONE)}${context.trimmedText}`).write(ROLE_AI);
try {
const { text, isFinishReasonStop } = await generateCompletion({ prompt });
prompt.patch(text);
setPrompt(context.userId, prompt);
updateHistory(context.id, (history) => history.write(config.BOT_NAME, text));
const actions = isFinishReasonStop ? [] : [COMMAND_BOT_CONTINUE];
context.pushText(text, actions);
if (context.event.isImage) {
const text = context.trimmedText;
prompt.writeImage(ROLE_HUMAN, text).write(ROLE_AI);
prompt.patch('Get Image');
setPrompt(context.userId, prompt);
updateHistory(context.id, (history) => history.writeImage(ROLE_HUMAN, text));
context.pushText(t('__COMPLETION_GOT_IMAGE_REPLY'), [COMMAND_BOT_FORGET]);
} else {
prompt.write(ROLE_HUMAN, `${t('__COMPLETION_DEFAULT_AI_TONE')(config.BOT_TONE)}${context.trimmedText}`).write(ROLE_AI);
const { text, isFinishReasonStop } = await generateCompletion({ prompt });
prompt.patch(text);
setPrompt(context.userId, prompt);
updateHistory(context.id, (history) => history.write(config.BOT_NAME, text));
const actions = isFinishReasonStop ? [COMMAND_BOT_FORGET] : [COMMAND_BOT_CONTINUE];
context.pushText(text, actions);
}
} catch (err) {
context.pushError(err);
}
Expand Down
21 changes: 21 additions & 0 deletions app/history/history.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ class History {
return this;
}

/**
* @param {string} role
* @param {string} content
*/
writeImage(role, content = '') {
const imageContent = [
{
type: 'text',
text: '這是一張圖片',
},
{
type: 'image',
image_url: {
url: content,
},
},
];
this.messages.push(new Message({ role, content: imageContent }));
return this;
}

/**
* @param {string} content
*/
Expand Down
8 changes: 8 additions & 0 deletions app/models/event.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
MESSAGE_TYPE_AUDIO,
MESSAGE_TYPE_STICKER,
MESSAGE_TYPE_TEXT,
MESSAGE_TYPE_IMAGE,
SOURCE_TYPE_GROUP,
} from '../../services/line.js';

Expand Down Expand Up @@ -62,6 +63,13 @@ class Event {
return this.message.type === MESSAGE_TYPE_AUDIO;
}

/**
* @returns {boolean}
*/
get isImage() {
return this.message.type === MESSAGE_TYPE_IMAGE;
}

/**
* @returns {string}
*/
Expand Down
3 changes: 3 additions & 0 deletions app/prompt/message.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class Message {
}

toString() {
if (Array.isArray(this.content)) {
return `\n${this.role}: ${this.content[0].text}`;
}
return this.role ? `\n${this.role}: ${this.content}` : this.content;
}
}
Expand Down
21 changes: 21 additions & 0 deletions app/prompt/prompt.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ class Prompt {
return this;
}

/**
* @param {string} role
* @param {string} content
*/
writeImage(role, content = '') {
const imageContent = [
{
type: 'text',
text: '這是一張圖片',
},
{
type: 'image_url',
image_url: {
url: content,
},
},
];
this.messages.push(new Message({ role, content: imageContent }));
return this;
}

/**
* @param {string} content
*/
Expand Down
1 change: 1 addition & 0 deletions locales/en.js
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ const en = {
__COMPLETION_SEARCH_NOT_FOUND: '查無資料', // TODO
__COMPLETION_QUOTATION_MARK_OPENING: '"',
__COMPLETION_QUOTATION_MARK_CLOSING: '"',
__COMPLETION_GOT_IMAGE_REPLY: 'The image has been obtained, please explain the intention.',
__ERROR_ECONNABORTED: 'Timed out',
__ERROR_UNKNOWN: 'Something went wrong',
__ERROR_MAX_GROUPS_REACHED: 'Maximum groups reached',
Expand Down
1 change: 1 addition & 0 deletions locales/ja.js
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ const ja = {
__COMPLETION_SEARCH_NOT_FOUND: '查無資料', // TODO
__COMPLETION_QUOTATION_MARK_OPENING: '「',
__COMPLETION_QUOTATION_MARK_CLOSING: '」',
__COMPLETION_GOT_IMAGE_REPLY: '画像を取得しました、意図を説明してください。',
__ERROR_ECONNABORTED: '接続がタイムアウトしました。',
__ERROR_UNKNOWN: '技術的な問題が発生しています。',
__ERROR_MAX_GROUPS_REACHED: '最大ユーザー数に達しています。',
Expand Down
1 change: 1 addition & 0 deletions locales/zh.js
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ const zh = {
__COMPLETION_SEARCH_NOT_FOUND: '查無資料',
__COMPLETION_QUOTATION_MARK_OPENING: '「',
__COMPLETION_QUOTATION_MARK_CLOSING: '」',
__COMPLETION_GOT_IMAGE_REPLY: '已取得圖片,請說明意圖。',
__ERROR_ECONNABORTED: '這個問題太複雜了',
__ERROR_UNKNOWN: '系統出了點狀況',
__ERROR_MAX_GROUPS_REACHED: '群組數量到達上限了',
Expand Down
31 changes: 22 additions & 9 deletions services/openai.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const IMAGE_SIZE_512 = '512x512';
export const IMAGE_SIZE_1024 = '1024x1024';

export const MODEL_GPT_3_5_TURBO = 'gpt-3.5-turbo';
export const MODEL_GPT_4 = 'gpt-4';
export const MODEL_GPT_4_OMNI = 'gpt-4o';
export const MODEL_WHISPER_1 = 'whisper-1';

const client = axios.create({
Expand All @@ -38,21 +38,34 @@ client.interceptors.response.use(handleFulfilled, (err) => {
return handleRejected(err);
});

/**
 * Whether this message list belongs to a vision exchange: after receiving an
 * image, the bot records a sentinel AI message whose content is 'Get Image'.
 * Uses Array#some (short-circuits) instead of a forEach with a mutable flag.
 * @param {Object} param
 * @param {Array<{ role: string, content: * }>} param.messages
 * @returns {boolean}
 */
const isAboutImageCompletion = ({ messages }) => (
  messages.some(({ role, content }) => role === ROLE_AI && content === 'Get Image')
);

/**
 * Call the OpenAI chat completions endpoint.
 * Vision conversations (detected via the 'Get Image' sentinel) are forced to
 * the image-capable gpt-4o model regardless of the configured default.
 * NOTE: the diff rendering had concatenated the deleted and added function
 * bodies, leaving a dangling `}) => {`; this is the reconstructed merged form.
 * @param {Object} param - completion options; unspecified ones fall back to config
 * @returns {Promise} axios response promise for POST /v1/chat/completions
 */
const createChatCompletion = ({
  model = config.OPENAI_COMPLETION_MODEL,
  messages,
  temperature = config.OPENAI_COMPLETION_TEMPERATURE,
  maxTokens = config.OPENAI_COMPLETION_MAX_TOKENS,
  frequencyPenalty = config.OPENAI_COMPLETION_FREQUENCY_PENALTY,
  presencePenalty = config.OPENAI_COMPLETION_PRESENCE_PENALTY,
}) => {
  const body = {
    model: isAboutImageCompletion({ messages }) ? MODEL_GPT_4_OMNI : model,
    messages,
    temperature,
    max_tokens: maxTokens,
    frequency_penalty: frequencyPenalty,
    presence_penalty: presencePenalty,
  };
  return client.post('/v1/chat/completions', body);
};

const createTextCompletion = ({
model = config.OPENAI_COMPLETION_MODEL,
Expand Down
12 changes: 12 additions & 0 deletions utils/fetch-image.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { fetchContent } from '../services/line.js';

/**
 * Download a LINE image message and encode it as a JPEG data URL.
 * @param {string} messageId - LINE message identifier
 * @returns {Promise<string>} `data:image/jpeg;base64,...` string
 */
const fetchImage = async (messageId) => {
  const { data } = await fetchContent({ messageId });
  const encoded = Buffer.from(data, 'binary').toString('base64');
  return `data:image/jpeg;base64,${encoded}`;
};

export default fetchImage;
2 changes: 2 additions & 0 deletions utils/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import addMark from './add-mark.js';
import convertText from './convert-text.js';
import fetchAnswer from './fetch-answer.js';
import fetchAudio from './fetch-audio.js';
import fetchImage from './fetch-image.js';
import fetchEnvironment from './fetch-environment.js';
import fetchGroup from './fetch-group.js';
import fetchUser from './fetch-user.js';
Expand All @@ -19,6 +20,7 @@ export {
convertText,
fetchAnswer,
fetchAudio,
fetchImage,
fetchEnvironment,
fetchGroup,
fetchUser,
Expand Down

0 comments on commit 4793599

Please sign in to comment.