diff --git a/.env.example b/.env.example index e7f75f0f..ce8fccd3 100644 --- a/.env.example +++ b/.env.example @@ -35,6 +35,8 @@ OPENAI_COMPLETION_MAX_TOKENS= OPENAI_COMPLETION_FREQUENCY_PENALTY= OPENAI_COMPLETION_PRESENCE_PENALTY= OPENAI_IMAGE_GENERATION_SIZE= +OPENAI_IMAGE_GENERATION_MODEL= +OPENAI_IMAGE_GENERATION_QUALITY= LINE_TIMEOUT= LINE_CHANNEL_ACCESS_TOKEN= @@ -43,3 +45,8 @@ LINE_CHANNEL_SECRET= SERPAPI_TIMEOUT= SERPAPI_API_KEY= SERPAPI_LOCATION= + +PROVIDER_BASE_URL= +PROVIDER_BASE_TOKEN= +PROVIDER_BASE_MODEL= + diff --git a/app/app.js b/app/app.js index 30d0850c..995f0bbd 100644 --- a/app/app.js +++ b/app/app.js @@ -47,7 +47,7 @@ const handleEvents = async (events = []) => ( events .map((event) => new Event(event)) .filter((event) => event.isMessage) - .filter((event) => event.isText || event.isAudio) + .filter((event) => event.isText || event.isAudio || event.isImage) .map((event) => new Context(event)) .map((context) => context.initialize()), )) diff --git a/app/context.js b/app/context.js index 930459c2..95abba8a 100644 --- a/app/context.js +++ b/app/context.js @@ -9,6 +9,7 @@ import { addMark, convertText, fetchAudio, + fetchImage, fetchGroup, fetchUser, generateTranscription, @@ -87,6 +88,9 @@ class Context { const text = this.transcription.replace(config.BOT_NAME, '').trim(); return addMark(text); } + if (this.event.isImage) { + return this.transcription.trim(); + } return '?'; } @@ -99,6 +103,10 @@ class Context { const text = this.transcription.toLowerCase(); return text.startsWith(config.BOT_NAME.toLowerCase()); } + if (this.event.isImage) { + const text = this.transcription.toLowerCase(); + return text.startsWith(config.BOT_NAME.toLowerCase()); + } return false; } @@ -116,6 +124,13 @@ class Context { return this.pushError(err); } } + if (this.event.isImage) { + try { + await this.saveImage(); + } catch (err) { + return this.pushError(err); + } + } updateHistory(this.id, (history) => history.write(this.source.name, this.trimmedText)); return 
this; } @@ -171,6 +186,11 @@ class Context { this.transcription = convertText(text); } + async saveImage() { + const base64String = await fetchImage(this.event.messageId); + this.transcription = base64String; + } + /** * @param {Object} param * @param {string} param.text diff --git a/app/handlers/talk.js b/app/handlers/talk.js index 7ce0d6d7..7a1558dc 100644 --- a/app/handlers/talk.js +++ b/app/handlers/talk.js @@ -2,7 +2,7 @@ import config from '../../config/index.js'; import { t } from '../../locales/index.js'; import { ROLE_AI, ROLE_HUMAN } from '../../services/openai.js'; import { generateCompletion } from '../../utils/index.js'; -import { COMMAND_BOT_CONTINUE, COMMAND_BOT_TALK } from '../commands/index.js'; +import { COMMAND_BOT_CONTINUE, COMMAND_BOT_TALK, COMMAND_BOT_FORGET } from '../commands/index.js'; import Context from '../context.js'; import { updateHistory } from '../history/index.js'; import { getPrompt, setPrompt } from '../prompt/index.js'; @@ -24,14 +24,30 @@ const check = (context) => ( const exec = (context) => check(context) && ( async () => { const prompt = getPrompt(context.userId); - prompt.write(ROLE_HUMAN, `${t('__COMPLETION_DEFAULT_AI_TONE')(config.BOT_TONE)}${context.trimmedText}`).write(ROLE_AI); try { - const { text, isFinishReasonStop } = await generateCompletion({ prompt }); - prompt.patch(text); + const obj = { + text: "", + actions: [] + }; + + if (context.event.isImage) { + context.pushText('Get Image', [COMMAND_BOT_FORGET]); + obj.text = context.trimmedText; + prompt.writeImageMsg(ROLE_HUMAN, obj.text).write(ROLE_AI); + prompt.patch('Get Image!!'); + updateHistory(context.id, (history) => history.writeImageMsg(ROLE_HUMAN, obj.text)); + + } else { + prompt.write(ROLE_HUMAN, `${t('__COMPLETION_DEFAULT_AI_TONE')(config.BOT_TONE)}${context.trimmedText}`).write(ROLE_AI); + const { text, isFinishReasonStop } = await generateCompletion({ prompt }); + obj.text = text; + obj.actions = isFinishReasonStop ? 
[COMMAND_BOT_FORGET] : [COMMAND_BOT_CONTINUE]; + context.pushText(obj.text, obj.actions); + prompt.patch(obj.text); + updateHistory(context.id, (history) => history.write(config.BOT_NAME, obj.text)); + } + setPrompt(context.userId, prompt); - updateHistory(context.id, (history) => history.write(config.BOT_NAME, text)); - const actions = isFinishReasonStop ? [] : [COMMAND_BOT_CONTINUE]; - context.pushText(text, actions); } catch (err) { context.pushError(err); } diff --git a/app/history/history.js b/app/history/history.js index 631e1120..c2126a4b 100644 --- a/app/history/history.js +++ b/app/history/history.js @@ -40,6 +40,25 @@ class History { return this; } + writeImageMsg(role, content = '') { + this.messages.push({ + role: role, + content: [ + { + type: 'text', + text: '這是一張圖片' + }, + { + type: 'image_url', + image_url: { + url: content + } + } + ] + }); + return this; + } + /** * @param {string} content */ diff --git a/app/history/index.js b/app/history/index.js index 602a1e23..532c183d 100644 --- a/app/history/index.js +++ b/app/history/index.js @@ -23,6 +23,7 @@ const updateHistory = (contextId, callback) => { const history = getHistory(contextId); callback(history); setHistory(contextId, history); + printHistories(); }; /** @@ -37,7 +38,7 @@ const printHistories = () => { .filter((contextId) => getHistory(contextId).messages.length > 0) .map((contextId) => `\n=== ${contextId.slice(0, 6)} ===\n\n${getHistory(contextId).toString()}\n`); if (messages.length < 1) return; - console.info(messages.join('')); + // console.info("printHistories", messages.join('')); }; export { diff --git a/app/models/event.js b/app/models/event.js index 0dca4e6e..90ed8a29 100644 --- a/app/models/event.js +++ b/app/models/event.js @@ -3,6 +3,7 @@ import { MESSAGE_TYPE_AUDIO, MESSAGE_TYPE_STICKER, MESSAGE_TYPE_TEXT, + MESSAGE_TYPE_IMAGE, SOURCE_TYPE_GROUP, } from '../../services/line.js'; @@ -62,6 +63,10 @@ class Event { return this.message.type === MESSAGE_TYPE_AUDIO; } + get isImage() { 
return this.message.type === MESSAGE_TYPE_IMAGE; + } + /** * @returns {string} */ diff --git a/app/prompt/message.js b/app/prompt/message.js index fddf9d1c..757ca042 100644 --- a/app/prompt/message.js +++ b/app/prompt/message.js @@ -15,11 +15,15 @@ class Message { get isEnquiring() { return this.content === TYPE_SUM - || this.content === TYPE_ANALYZE - || this.content === TYPE_TRANSLATE; + || this.content === TYPE_ANALYZE + || this.content === TYPE_TRANSLATE; } toString() { + if (Array.isArray(this.content)) { + return `\n${this.role}: ${this.content[0].text}`; + } + return this.role ? `\n${this.role}: ${this.content}` : this.content; } } diff --git a/app/prompt/prompt.js b/app/prompt/prompt.js index 07df1b17..2526f265 100644 --- a/app/prompt/prompt.js +++ b/app/prompt/prompt.js @@ -49,6 +49,24 @@ class Prompt { return this; } + writeImageMsg(role, content = '') { + const tempContent = [ + { + type: 'text', + text: '這是一張圖片' + }, + { + type: 'image_url', + image_url: { + url: content + } + } + ]; + + this.messages.push(new Message({ role, content: tempContent })); + return this; + } + /** * @param {string} content */ diff --git a/config/index.js b/config/index.js index ba16065e..7abeedec 100644 --- a/config/index.js +++ b/config/index.js @@ -34,20 +34,25 @@ const config = Object.freeze({ VERCEL_DEPLOY_HOOK_URL: env.VERCEL_DEPLOY_HOOK_URL || null, OPENAI_TIMEOUT: env.OPENAI_TIMEOUT || env.APP_API_TIMEOUT, OPENAI_API_KEY: env.OPENAI_API_KEY || null, - OPENAI_BASE_URL: env.OPENAI_BASE_URL || 'https://api.openai.com', + OPENAI_BASE_URL: env.OPENAI_BASE_URL || 'https://api.openai.com/v1', OPENAI_COMPLETION_MODEL: env.OPENAI_COMPLETION_MODEL || 'gpt-3.5-turbo', OPENAI_COMPLETION_TEMPERATURE: Number(env.OPENAI_COMPLETION_TEMPERATURE) || 1, - OPENAI_COMPLETION_MAX_TOKENS: Number(env.OPENAI_COMPLETION_MAX_TOKENS) || 64, + OPENAI_COMPLETION_MAX_TOKENS: Number(env.OPENAI_COMPLETION_MAX_TOKENS) || 200, OPENAI_COMPLETION_FREQUENCY_PENALTY: 
Number(env.OPENAI_COMPLETION_FREQUENCY_PENALTY) || 0, OPENAI_COMPLETION_PRESENCE_PENALTY: Number(env.OPENAI_COMPLETION_PRESENCE_PENALTY) || 0.6, OPENAI_COMPLETION_STOP_SEQUENCES: env.OPENAI_COMPLETION_STOP_SEQUENCES ? String(env.OPENAI_COMPLETION_STOP_SEQUENCES).split(',') : [' assistant:', ' user:'], OPENAI_IMAGE_GENERATION_SIZE: env.OPENAI_IMAGE_GENERATION_SIZE || '256x256', + OPENAI_IMAGE_GENERATION_MODEL: env.OPENAI_IMAGE_GENERATION_MODEL || 'dall-e-2', + OPENAI_IMAGE_GENERATION_QUALITY: env.OPENAI_IMAGE_GENERATION_QUALITY || 'standard', LINE_TIMEOUT: env.LINE_TIMEOUT || env.APP_API_TIMEOUT, LINE_CHANNEL_ACCESS_TOKEN: env.LINE_CHANNEL_ACCESS_TOKEN || null, LINE_CHANNEL_SECRET: env.LINE_CHANNEL_SECRET || null, SERPAPI_TIMEOUT: env.SERPAPI_TIMEOUT || env.APP_API_TIMEOUT, SERPAPI_API_KEY: env.SERPAPI_API_KEY || null, SERPAPI_LOCATION: env.SERPAPI_LOCATION || 'tw', + PROVIDER_BASE_URL: env.PROVIDER_BASE_URL || 'https://api.openai.com/v1', + PROVIDER_BASE_TOKEN: env.PROVIDER_BASE_TOKEN || null, + PROVIDER_BASE_MODEL: env.PROVIDER_BASE_MODEL || 'gpt-3.5-turbo', }); export default config; diff --git a/services/openai.js b/services/openai.js index aa1ac599..3f6e9c17 100644 --- a/services/openai.js +++ b/services/openai.js @@ -15,19 +15,29 @@ export const IMAGE_SIZE_512 = '512x512'; export const IMAGE_SIZE_1024 = '1024x1024'; export const MODEL_GPT_3_5_TURBO = 'gpt-3.5-turbo'; -export const MODEL_GPT_4 = 'gpt-4'; export const MODEL_WHISPER_1 = 'whisper-1'; +export const MODEL_GPT_4_TURBO = 'gpt-4-turbo'; + +const BASE_URL = config.PROVIDER_BASE_URL; const client = axios.create({ - baseURL: config.OPENAI_BASE_URL, timeout: config.OPENAI_TIMEOUT, headers: { + 'Provider': '', 'Accept-Encoding': 'gzip, deflate, compress', + "HTTP-Referer": `https://line.me`, // Optional, for including your app on openrouter.ai rankings. + "X-Title": `LINE Chatbot`, // Optional, for including your app on openrouter.ai rankings. 
}, }); client.interceptors.request.use((c) => { - c.headers.Authorization = `Bearer ${config.OPENAI_API_KEY}`; + if (c.headers.Provider === 'openai') { + c.headers.Authorization = `Bearer ${config.OPENAI_API_KEY}`; + + } else { + c.headers.Authorization = `Bearer ${config.PROVIDER_BASE_TOKEN}`; + + } return handleRequest(c); }); @@ -39,30 +49,51 @@ client.interceptors.response.use(handleFulfilled, (err) => { }); const createChatCompletion = ({ - model = config.OPENAI_COMPLETION_MODEL, + model = config.PROVIDER_BASE_MODEL, messages, temperature = config.OPENAI_COMPLETION_TEMPERATURE, maxTokens = config.OPENAI_COMPLETION_MAX_TOKENS, frequencyPenalty = config.OPENAI_COMPLETION_FREQUENCY_PENALTY, presencePenalty = config.OPENAI_COMPLETION_PRESENCE_PENALTY, -}) => client.post('/v1/chat/completions', { - model, - messages, - temperature, - max_tokens: maxTokens, - frequency_penalty: frequencyPenalty, - presence_penalty: presencePenalty, -}); +}) => { + const body = { + model, + messages, + temperature, + max_tokens: maxTokens, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, + }; + + let isAboutImageCompletion = false; + messages.forEach(element => { + if (element.role === ROLE_AI && element.content === "Get Image!!") { + body['model'] = MODEL_GPT_4_TURBO; + isAboutImageCompletion = true; + } + }); + + if (isAboutImageCompletion) { + return client.post(config.OPENAI_BASE_URL + '/chat/completions', body, { + headers: { + Provider: 'openai', + }, + }) + } else { + return client.post(BASE_URL + '/chat/completions', body) + } + +}; const createTextCompletion = ({ - model = config.OPENAI_COMPLETION_MODEL, + model = config.PROVIDER_BASE_MODEL, prompt, temperature = config.OPENAI_COMPLETION_TEMPERATURE, maxTokens = config.OPENAI_COMPLETION_MAX_TOKENS, frequencyPenalty = config.OPENAI_COMPLETION_FREQUENCY_PENALTY, presencePenalty = config.OPENAI_COMPLETION_PRESENCE_PENALTY, stop = config.OPENAI_COMPLETION_STOP_SEQUENCES, -}) => 
client.post('/v1/completions', { model, prompt, temperature, @@ -76,11 +107,25 @@ const createImage = ({ prompt, n = 1, size = IMAGE_SIZE_256, -}) => client.post('/v1/images/generations', { - prompt, - n, - size, -}); +}) => { + + // DALL-E 3 only supports 1024x1024 images. + if (config.OPENAI_IMAGE_GENERATION_MODEL === 'dall-e-3' && (size === IMAGE_SIZE_256 || size === IMAGE_SIZE_512)) { + size = IMAGE_SIZE_1024; + } + + return client.post(config.OPENAI_BASE_URL + '/images/generations', { + "model": config.OPENAI_IMAGE_GENERATION_MODEL, + "quality": config.OPENAI_IMAGE_GENERATION_QUALITY, + prompt, + n, + size, + }, { + headers: { + Provider: 'openai', + }, + }) +}; const createAudioTranscriptions = ({ buffer, @@ -90,14 +135,35 @@ const createAudioTranscriptions = ({ const formData = new FormData(); formData.append('file', buffer, file); formData.append('model', model); - return client.post('/v1/audio/transcriptions', formData.getBuffer(), { - headers: formData.getHeaders(), + const headers = formData.getHeaders(); + headers['Provider'] = 'openai'; + return client.post(config.OPENAI_BASE_URL + '/audio/transcriptions', formData.getBuffer(), { + headers, }); }; +const createVisionTranscriptions = ({ messages, + model = MODEL_GPT_4_TURBO, + temperature = config.OPENAI_COMPLETION_TEMPERATURE, + maxTokens = config.OPENAI_COMPLETION_MAX_TOKENS, + frequencyPenalty = config.OPENAI_COMPLETION_FREQUENCY_PENALTY, + presencePenalty = config.OPENAI_COMPLETION_PRESENCE_PENALTY, +}) => { + // messages may contain image_url content parts for vision requests. + return client.post(BASE_URL + '/chat/completions', { + model, + messages, + temperature, + max_tokens: maxTokens, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, + }) +}; + export { createChatCompletion, createTextCompletion, createImage, createAudioTranscriptions, + createVisionTranscriptions, }; diff --git a/utils/fetch-image.js b/utils/fetch-image.js new file mode 100644 index 00000000..94c0ba75 --- /dev/null +++ 
b/utils/fetch-image.js @@ -0,0 +1,12 @@ +import { fetchContent } from '../services/line.js'; + +/** + * @param {string} messageId + * @returns {Promise<string>} + */ +const fetchImage = async (messageId) => { + const { data } = await fetchContent({ messageId }); + return 'data:image/jpeg;base64,' + Buffer.from(data, 'binary').toString('base64'); +}; + +export default fetchImage; diff --git a/utils/generate-completion.js b/utils/generate-completion.js index 458570e3..f08ea1a5 100644 --- a/utils/generate-completion.js +++ b/utils/generate-completion.js @@ -20,10 +20,12 @@ class Completion { } } -const isChatCompletionModel = (model) => ( - String(model).startsWith('ft:gpt') - || String(model).startsWith('gpt') -); +// NOTE: every completion is now routed through a chat-completions provider +// gateway, so all configured models are treated as chat models. The old +// prefix check (String(model).startsWith('ft:gpt') || startsWith('gpt')) +// was removed on purpose; restore it if text-completion models come back. + +const isChatCompletionModel = (model) => true; /** * @param {Object} param diff --git a/utils/index.js b/utils/index.js index 9483f3be..05ccda18 100644 --- a/utils/index.js +++ b/utils/index.js @@ -2,6 +2,7 @@ import addMark from './add-mark.js'; import convertText from './convert-text.js'; import fetchAnswer from './fetch-answer.js'; import fetchAudio from './fetch-audio.js'; import fetchEnvironment from './fetch-environment.js'; import fetchGroup from './fetch-group.js'; +import fetchImage from './fetch-image.js'; import fetchUser from './fetch-user.js'; @@ -19,6 +20,7 @@ export { convertText, fetchAnswer, fetchAudio, fetchEnvironment, fetchGroup, + fetchImage, fetchUser,