Skip to content

Commit

Permalink
feat: support Claude v3 (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
jessieweiyi authored May 14, 2024
1 parent 27b0a12 commit c0cb406
Show file tree
Hide file tree
Showing 46 changed files with 680 additions and 783 deletions.
2 changes: 1 addition & 1 deletion .ncurc.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .projen/deps.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions .projenrc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const monorepo = new MonorepoProject({
'node-localstorage',
'prompts',
'tsconfig-paths',
'@pnpm/logger'
],
tsconfig: {
compilerOptions: {
Expand Down
2 changes: 1 addition & 1 deletion .syncpackrc.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions demo/api/generated/runtime/python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 11 additions & 1 deletion demo/corpus/logic/.projen/deps.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion demo/corpus/logic/package.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion demo/corpus/logic/src/indexing/vectorstore.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { VectorStore } from 'langchain/vectorstores/base';
import { VectorStore } from '@langchain/core/vectorstores';
import { getEmbeddingsByModelId } from '../embedding/util';
import { vectorStoreFactory } from '../vectorstore';

Expand Down
2 changes: 1 addition & 1 deletion demo/corpus/logic/src/indexing/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import fs from 'node:fs/promises';
import { getLogger } from '@aws/galileo-sdk/lib/common';
import { PGVectorStore } from '@aws/galileo-sdk/lib/vectorstores';
import { MetricUnits } from '@aws-lambda-powertools/metrics';
import { Document } from 'langchain/document';
import { Document } from '@langchain/core/documents';
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { IndexEntity, IndexingCache } from './datastore';
import { chunkArray } from './utils';
Expand Down
4 changes: 2 additions & 2 deletions demo/corpus/logic/src/vectorstore/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
PDX-License-Identifier: Apache-2.0 */
import { PGVectorStore, PGVectorStoreOptions } from '@aws/galileo-sdk/lib/vectorstores';
import { RDSConnConfig, getRDSConnConfig } from '@aws/galileo-sdk/lib/vectorstores/pgvector/rds';
import { Embeddings } from 'langchain/embeddings/base';
import { VectorStore } from 'langchain/vectorstores/base';
import { Embeddings } from '@langchain/core/embeddings';
import { VectorStore } from '@langchain/core/vectorstores';
import { ENV } from '../env';

let __RDS_CONN__: RDSConnConfig;
Expand Down
10 changes: 5 additions & 5 deletions demo/website/.projen/deps.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion demo/website/package.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ export const PromptEditor: FC<PromptEditorProps> = (props) => {
) : (
<CodeEditor
ref={editorRef as any}
value={value}
value={typeof value === 'string' ? value : ''}
onDelayedChange={({ detail }) => onChange(detail.value)}
completions={completions}
language="handlebars"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */

import { Document } from 'langchain/document';
import { AIMessage, BaseMessage, HumanMessage } from 'langchain/schema';
import { Document } from '@langchain/core/documents';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import '@aws/galileo-sdk/lib/langchain/patch';

export const CONTEXT_DOCUMENTS = [
Expand Down
1 change: 1 addition & 0 deletions package.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/galileo-cli/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions packages/galileo-cli/src/lib/account-utils/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ export async function listBedrockModels({
export async function listBedrockTextModels(options: Required<CredentialsParams>): Promise<BedrockModelSummary[]> {
const models = await listBedrockModels(options);
return models.filter((v) => {
const modalities = new Set<string>([...v.inputModalities, ...v.outputModalities]);
return modalities.size === 1 && modalities.has(BedrockModality.TEXT);
return v.inputModalities.includes(BedrockModality.TEXT) && v.outputModalities.includes(BedrockModality.TEXT);
});
}
12 changes: 11 additions & 1 deletion packages/galileo-sdk/.projen/deps.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion packages/galileo-sdk/package.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions packages/galileo-sdk/src/chat/chain.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { PromptTemplate } from '@langchain/core/prompts';
import { BaseRetriever } from '@langchain/core/retrievers';
import { ChainValues } from '@langchain/core/utils/types';
import { BaseLanguageModel } from 'langchain/base_language';
import { CallbackManagerForChainRun } from 'langchain/callbacks';
import { BaseChain, ChainInputs, LLMChain, QAChainParams, StuffDocumentsChain } from 'langchain/chains';
import { PromptTemplate } from 'langchain/prompts';
import { ChainValues } from 'langchain/schema';
import { BaseRetriever } from 'langchain/schema/retriever';
import { ResolvedLLMChainConfig } from './config/index.js';
import { getLogger } from '../common/index.js';
import { startPerfMetric } from '../common/metrics/index.js';
Expand Down Expand Up @@ -222,7 +222,7 @@ export class ChatEngineChain extends BaseChain implements ChatEngineChainInput {
const $$CombineDocumentsExecutionTime = startPerfMetric('Chain.QA.ExecutionTime', {
highResolution: true,
});
const result = await this.qaChain.call(inputs, runManager?.getChild('combine_documents'));
const result = await this.qaChain.invoke(inputs, runManager?.getChild('combine_documents'));
$$CombineDocumentsExecutionTime();
logger.debug('Chain:condenseQuestionChain:output', { output: result });

Expand Down
2 changes: 1 addition & 1 deletion packages/galileo-sdk/src/chat/config/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { PromptTemplate } from '@langchain/core/prompts';
import type {
ChatEngineConfig,
ChatEngineChainConfig,
Expand All @@ -8,7 +9,6 @@ import type {
ChatEngineMemoryConfig,
} from 'api-typescript-runtime';
import { BaseLanguageModel } from 'langchain/base_language';
import { PromptTemplate } from 'langchain/prompts';
import { difference } from 'lodash';
import { ChainType } from '../../schema/index.js';
import { mergeConfig } from '../../utils/merge.js';
Expand Down
8 changes: 4 additions & 4 deletions packages/galileo-sdk/src/chat/context.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { BedrockChat } from '@langchain/community/chat_models/bedrock';
import { SageMakerEndpoint } from '@langchain/community/llms/sagemaker_endpoint';
import { PromptTemplate } from '@langchain/core/prompts';
import { BaseLanguageModel } from 'langchain/base_language';
import { Bedrock } from 'langchain/llms/bedrock';
import { SageMakerEndpoint } from 'langchain/llms/sagemaker_endpoint';
import { PromptTemplate } from 'langchain/prompts';
import { merge } from 'lodash';
import { getLogger } from '../common/index.js';
import { ModelAdapter } from '../models/adapter.js';
Expand Down Expand Up @@ -97,7 +97,7 @@ export class ChatEngineContext {
};
logger.debug('Resolved bedrock kwargs', { modelKwargs });

llm = new Bedrock({
llm = new BedrockChat({
verbose: options?.verbose,
// Support cross-account endpoint if enabled and provided in env
// Otherwise default to execution role credentials
Expand Down
2 changes: 1 addition & 1 deletion packages/galileo-sdk/src/chat/dynamodb/lib/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
QueryCommand,
QueryCommandInput,
} from '@aws-sdk/lib-dynamodb';
import { Document } from 'langchain/document';
import { Document } from '@langchain/core/documents';
import { v4 as uuidv4 } from 'uuid';
import { listChatMessageSources } from './sources.js';
import {
Expand Down
5 changes: 3 additions & 2 deletions packages/galileo-sdk/src/chat/dynamodb/message-history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ PDX-License-Identifier: Apache-2.0 */
import { DynamoDBClient, DynamoDBClientConfig } from '@aws-sdk/client-dynamodb';
import { DynamoDBDocumentClient } from '@aws-sdk/lib-dynamodb';

import { Document } from 'langchain/document';
import { BaseMessage, BaseListChatMessageHistory, HumanMessage, AIMessage, StoredMessage } from 'langchain/schema';
import { BaseListChatMessageHistory } from '@langchain/core/chat_history';
import { Document } from '@langchain/core/documents';
import { BaseMessage, HumanMessage, AIMessage, StoredMessage } from '@langchain/core/messages';
import * as lib from './lib/index.js';
import { getLogger } from '../../common/index.js';
import { mapStoredMessagesToChatMessages } from '../../langchain/stores/messages/utils.js';
Expand Down
2 changes: 1 addition & 1 deletion packages/galileo-sdk/src/chat/engine.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import '../langchain/patch.js';
import { BaseRetriever } from '@langchain/core/retrievers';
import { BaseLanguageModel } from 'langchain/base_language';
import { BaseRetriever } from 'langchain/schema/retriever';
import { ChatEngineChain, ChatEngineChainFromInput } from './chain.js';
import { ChatEngineConfig, resolveChatEngineConfig } from './config/index.js';
import { DynamoDBChatMessageHistory } from './dynamodb/message-history.js';
Expand Down
2 changes: 1 addition & 1 deletion packages/galileo-sdk/src/chat/search.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { Document } from '@langchain/core/documents';
import { fetch } from 'cross-fetch';
import { Document } from 'langchain/document';
import {
RemoteLangChainRetrieverParams,
RemoteLangChainRetriever,
Expand Down
2 changes: 1 addition & 1 deletion packages/galileo-sdk/src/langchain/output_parsers/pojo.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */

import { BaseOutputParser, OutputParserException } from 'langchain/schema/output_parser';
import { BaseOutputParser, OutputParserException } from '@langchain/core/output_parsers';

export class PojoOutputParser<T extends any = any> extends BaseOutputParser<T> {
static lc_name() {
Expand Down
2 changes: 1 addition & 1 deletion packages/galileo-sdk/src/langchain/patch.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { BaseMessage } from 'langchain/schema';
import { BaseMessage } from '@langchain/core/messages';

// Patch langchain BaseMessage prototype to have "type" getter
// Currently only has _getType() method and we need simple getter for template
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
HumanMessage,
StoredMessage,
SystemMessage,
} from 'langchain/schema';
} from '@langchain/core/messages';

interface StoredMessageV1 {
type: string;
Expand Down
2 changes: 1 addition & 1 deletion packages/galileo-sdk/src/models/adapter.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*! Copyright [Amazon.com](http://amazon.com/), Inc. or its affiliates. All Rights Reserved.
PDX-License-Identifier: Apache-2.0 */
import { BaseSageMakerContentHandler } from 'langchain/llms/sagemaker_endpoint';
import { BaseSageMakerContentHandler } from '@langchain/community/llms/sagemaker_endpoint';
import { set, get, isEmpty } from 'lodash';
import { BaseChatTemplatePartials } from '../prompt/templates/chat/base.js';
import { ChatTemplateTypedRuntimeRecord, PromptTemplateStore } from '../prompt/templates/store/registry.js';
Expand Down
6 changes: 3 additions & 3 deletions packages/galileo-sdk/src/models/llms/anthropic/claude.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
PDX-License-Identifier: Apache-2.0 */
// @ts-ignore - .test files are ignored
import type {} from '@types/jest';
import { AIMessage, HumanMessage, SystemMessage } from 'langchain/schema';
import { CLAUDE_V2_ADAPTER } from './claude';
import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { CLAUDE_ADAPTER } from './claude';
import { resolvePromptTemplateByChainType } from '../../../prompt/templates/store/resolver.js';
import { ChainType } from '../../../schema';
import { ModelAdapter } from '../../adapter';

describe('models/llms/anthropic/claude', () => {
describe('adapter', () => {
const adapter = new ModelAdapter(CLAUDE_V2_ADAPTER);
const adapter = new ModelAdapter(CLAUDE_ADAPTER);

test('should render qa prompt', async () => {
const template = await resolvePromptTemplateByChainType(ChainType.QA, adapter.prompt?.chat?.QA);
Expand Down
Loading

0 comments on commit c0cb406

Please sign in to comment.