Skip to content

Commit

Permalink
Multiple providers support
Browse files Browse the repository at this point in the history
  • Loading branch information
arshad-yaseen committed Aug 19, 2024
1 parent 0d8f18a commit 003657a
Show file tree
Hide file tree
Showing 27 changed files with 482 additions and 344 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"eslint": "^8.57.0",
"groq-sdk": "^0.3.2",
"monaco-editor": "^0.50.0",
"openai": "^4.56.0",
"prettier": "^3.2.5",
"release-it": "^17.2.1",
"tsup": "^8.0.2",
Expand Down
293 changes: 179 additions & 114 deletions pnpm-lock.yaml

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions src/classes/completion-cache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import {CompletionCacheItem, EditorModel, EditorPosition} from '../types';
import {getTextBeforeCursorInLine} from '../utils/editor';

/**
* CompletionCache class manages a cache of completion code completions.
* It provides methods to get, add, and clear completion cache items.
* The cache implements a First-In-First-Out (FIFO) mechanism.
*/
export class CompletionCache {
private static readonly MAX_CACHE_SIZE = 10;
private cache: ReadonlyArray<CompletionCacheItem> = [];

public getCompletionCache(
position: Readonly<EditorPosition>,
model: Readonly<EditorModel>,
): ReadonlyArray<CompletionCacheItem> {
return this.cache.filter(cache =>
this.isCacheItemValid(cache, position, model),
);
}

public addCompletionCache(cacheItem: Readonly<CompletionCacheItem>): void {
this.cache = [
...this.cache.slice(-(CompletionCache.MAX_CACHE_SIZE - 1)),
cacheItem,
];
}

public clearCompletionCache(): void {
this.cache = [];
}

private isCacheItemValid(
cache: Readonly<CompletionCacheItem>,
position: Readonly<EditorPosition>,
model: Readonly<EditorModel>,
): boolean {
const currentValueInRange = model.getValueInRange(cache.range);
const currentTextBeforeCursorInLine = getTextBeforeCursorInLine(
position,
model,
);

return (
currentTextBeforeCursorInLine.startsWith(cache.textBeforeCursorInLine) &&
this.isPositionValid(cache, position, currentValueInRange)
);
}

private isPositionValid(
cache: Readonly<CompletionCacheItem>,
position: Readonly<EditorPosition>,
currentValueInRange: string,
): boolean {
return (
(cache.range.startLineNumber === position.lineNumber &&
position.column === cache.range.startColumn) ||
(cache.completion.startsWith(currentValueInRange) &&
cache.range.startLineNumber === position.lineNumber &&
position.column >=
cache.range.startColumn - currentValueInRange.length &&
position.column <= cache.range.endColumn)
);
}
}
Original file line number Diff line number Diff line change
@@ -1,25 +1,38 @@
import {EditorModel, EditorPosition} from '../../types';
import {getTextBeforeCursor} from '../../utils';
import {EditorModel, EditorPosition} from '../types';
import {getTextBeforeCursor} from '../utils';

/**
* This class is responsible for formatting code completions
* to ensure that they are displayed correctly in the editor.
*/
export class CompletionFormatter {
private formattedCompletion = '';
private originalCompletion = '';
private readonly model: EditorModel;
private readonly cursorPosition: EditorPosition;
private readonly lineCount: number;
private readonly model: Readonly<EditorModel>;
private readonly cursorPosition: Readonly<EditorPosition>;
private originalCompletion: string = '';
private formattedCompletion: string = '';

constructor(model: EditorModel, position: EditorPosition) {
private constructor(
model: Readonly<EditorModel>,
position: Readonly<EditorPosition>,
) {
this.model = model;
this.cursorPosition = position;
this.lineCount = model.getLineCount();
}

// Remove blank lines from the completion
private ignoreBlankLines(): this {
public static create(
model: Readonly<EditorModel>,
position: Readonly<EditorPosition>,
): CompletionFormatter {
return new CompletionFormatter(model, position);
}

public setCompletion(completion: string): CompletionFormatter {
this.originalCompletion = completion;
this.formattedCompletion = completion;
return this;
}

public ignoreBlankLines(): CompletionFormatter {
if (
this.formattedCompletion.trimStart() === '' &&
this.originalCompletion !== '\n'
Expand All @@ -29,17 +42,14 @@ export class CompletionFormatter {
return this;
}

// Normalize code or text by trimming whitespace
private normalise(text: string): string {
return text?.trim() ?? '';
}

// Remove duplicates from the start of the completion
private removeDuplicatesFromStartOfCompletion(): this {
public removeDuplicatesFromStartOfCompletion(): CompletionFormatter {
const before = getTextBeforeCursor(this.cursorPosition, this.model).trim();
const completion = this.normalise(this.formattedCompletion);

// Handle start duplicates
let startOverlapLength = 0;
const maxStartLength = Math.min(completion.length, before.length);
for (let length = 1; length <= maxStartLength; length++) {
Expand All @@ -52,7 +62,6 @@ export class CompletionFormatter {
}
}

// Apply the trimming
if (startOverlapLength > 0) {
this.formattedCompletion =
this.formattedCompletion.slice(startOverlapLength);
Expand All @@ -61,12 +70,11 @@ export class CompletionFormatter {
return this;
}

// Prevent suggesting completions that duplicate existing lines
private preventDuplicateLines(): this {
public preventDuplicateLines(): CompletionFormatter {
for (
let nextLineIndex = this.cursorPosition.lineNumber + 1;
nextLineIndex < this.cursorPosition.lineNumber + 3 &&
nextLineIndex < this.lineCount;
nextLineIndex < this.model.getLineCount();
nextLineIndex++
) {
const line = this.model.getLineContent(nextLineIndex);
Expand All @@ -78,14 +86,12 @@ export class CompletionFormatter {
return this;
}

// Remove any trailing line breaks
public removeInvalidLineBreaks(): this {
public removeInvalidLineBreaks(): CompletionFormatter {
this.formattedCompletion = this.formattedCompletion.trimEnd();
return this;
}

// Remove leading whitespace that would push the completion past the cursor position
private trimStart(): this {
public trimStart(): CompletionFormatter {
const firstNonSpaceIndex = this.formattedCompletion.search(/\S/);
if (firstNonSpaceIndex > this.cursorPosition.column - 1) {
this.formattedCompletion =
Expand All @@ -94,17 +100,7 @@ export class CompletionFormatter {
return this;
}

// Apply all formatting rules to the completion
public format(completion: string): string {
this.originalCompletion = completion;
this.formattedCompletion = completion;

this.ignoreBlankLines()
.removeDuplicatesFromStartOfCompletion()
.preventDuplicateLines()
.removeInvalidLineBreaks()
.trimStart();

public build(): string {
return this.formattedCompletion;
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import {EditorModel, EditorPosition} from '../../types';
import {
isCharAfterCursor,
isCursorAtStartWithTextAround,
} from '../../utils/completion';
import {EditorModel, EditorPosition} from '../types';
import {isCharAfterCursor, isCursorAtStartWithTextAround} from '../utils';

export class CompletionValidator {
private cursorPosition: EditorPosition;
Expand Down
2 changes: 0 additions & 2 deletions src/classes/completion/index.ts

This file was deleted.

45 changes: 33 additions & 12 deletions src/classes/copilot.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,38 @@
import {
COMPLETION_API_ENDPOINT,
COMPLETION_MODEL_IDS,
COMPLETION_PROVIDER_MODEL_MAP,
DEFAULT_COMPLETION_CREATE_PARAMS,
DEFAULT_COMPLETION_MODEL,
GROQ_COMPLETION_API_ENDPOINT,
DEFAULT_COMPLETION_PROVIDER,
} from '../constants';
import {ErrorContext, handleError} from '../error';
import {generateSystemPrompt, generateUserPrompt} from '../helpers';
import {
Completion,
CompletionCreateParams,
CompletionMetadata,
CompletionModel,
CompletionProvider,
CompletionRequest,
CompletionResponse,
CopilotOptions,
GroqCompletion,
GroqCompletionCreateParams,
Endpoint,
} from '../types';
import {HTTP} from '../utils';
import {HTTP, joinWithAnd} from '../utils';

/**
* Copilot class for handling completions using the Groq API.
*/
export class Copilot {
private readonly apiKey: string;
private readonly model: CompletionModel;
private readonly provider: CompletionProvider;

/**
* Initializes the Copilot with an API key and optional configuration.
* @param {string} apiKey - The Groq API key.
* @param {CopilotOptions} [options] - Optional parameters to configure the completion model.
* @param {CopilotOptions<CompletionProvider>} [options] - Optional parameters to configure the completion model.
* @throws {Error} If the API key is not provided.
*/
constructor(apiKey: string, options?: CopilotOptions) {
Expand All @@ -37,6 +42,7 @@ export class Copilot {

this.apiKey = apiKey;
this.model = options?.model || DEFAULT_COMPLETION_MODEL;
this.provider = options?.provider || DEFAULT_COMPLETION_PROVIDER;
}

/**
Expand All @@ -50,11 +56,13 @@ export class Copilot {
try {
const body = this.createRequestBody(completionMetadata);
const headers = this.createHeaders();
const endpoint = this.getEndpoint();

const completion = await HTTP.POST<
GroqCompletion,
GroqCompletionCreateParams
>(GROQ_COMPLETION_API_ENDPOINT, body, {headers});
const completion = await HTTP.POST<Completion, CompletionCreateParams>(
endpoint,
body,
{headers},
);

if (!completion.choices?.length) {
throw new Error('No completion choices received from API');
Expand All @@ -63,16 +71,29 @@ export class Copilot {
return {completion: completion.choices[0].message.content};
} catch (_err) {
handleError(_err, ErrorContext.COPILOT_COMPLETION_FETCH);
return {error: 'Failed to generate completion'};
return {error: 'Failed to fetch completion', completion: null};
}
}

private getEndpoint(): Endpoint {
return COMPLETION_API_ENDPOINT[this.provider];
}

private getModelId(): string {
if (!COMPLETION_PROVIDER_MODEL_MAP[this.provider].includes(this.model)) {
throw new Error(
`Model ${this.model} is not supported by ${this.provider} provider. Supported models: ${joinWithAnd(COMPLETION_PROVIDER_MODEL_MAP[this.provider])}`,
);
}
return COMPLETION_MODEL_IDS[this.model];
}

private createRequestBody(
completionMetadata: CompletionMetadata,
): GroqCompletionCreateParams {
): CompletionCreateParams {
return {
...DEFAULT_COMPLETION_CREATE_PARAMS,
model: COMPLETION_MODEL_IDS[this.model],
model: this.getModelId(),
messages: [
{role: 'system', content: generateSystemPrompt(completionMetadata)},
{role: 'user', content: generateUserPrompt(completionMetadata)},
Expand Down
3 changes: 2 additions & 1 deletion src/classes/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from './completion';
export * from './copilot';
export * from './completion-formatter';
export * from './completion-validator';
17 changes: 15 additions & 2 deletions src/constants/completion.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
import {
CompletionModel,
CompletionProvider,
GroqCompletionCreateParamsExcludingModelAndMessages,
} from '../types';

export const COMPLETION_MODEL_IDS: Record<CompletionModel, string> = {
llama: 'llama3-70b-8192',
'gpt-4o-mini': 'gpt-4o-mini',
};

export const COMPLETION_PROVIDER_MODEL_MAP: Record<
CompletionProvider,
CompletionModel[]
> = {
groq: ['llama'],
openai: ['gpt-4o-mini'],
};

export const DEFAULT_COMPLETION_MODEL: CompletionModel = 'llama';
export const DEFAULT_COMPLETION_PROVIDER: CompletionProvider = 'groq';

export const GROQ_COMPLETION_API_ENDPOINT =
'https://api.groq.com/openai/v1/chat/completions';
export const COMPLETION_API_ENDPOINT: Record<CompletionProvider, string> = {
groq: 'https://api.groq.com/openai/v1/chat/completions',
openai: 'https://api.openai.com/v1/chat/completions',
};

export const DEFAULT_COMPLETION_CREATE_PARAMS: GroqCompletionCreateParamsExcludingModelAndMessages =
{
Expand Down
Loading

0 comments on commit 003657a

Please sign in to comment.