Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
khram2003 committed May 22, 2024
2 parents 0bc235a + 7e0398d commit cc9087c
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 14 deletions.
18 changes: 16 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "coqpilot",
"displayName": "Сoqpilot",
"description": "An ai based completion extension for Coq interactive prover.",
"displayName": "Coqpilot",
"description": "An AI based completion extension for Coq interactive prover.",
"icon": "etc/img/logo.ico",
"repository": {
"type": "git",
Expand Down Expand Up @@ -46,6 +46,10 @@
{
"command": "coqpilot.perform_completion_in_selection",
"title": "Coqpilot: Try to prove holes (admitted goals) in the selection"
},
{
"command": "coqpilot.shorten_proof_in_selection",
"title": "Coqpilot: Try to shorten the proof in the selection"
}
],
"menus": {
Expand All @@ -54,6 +58,11 @@
"command": "coqpilot.perform_completion_in_selection",
"when": "editorTextFocus && editorHasSelection && resourceLangId == coq",
"group": "queries"
},
{
"command": "coqpilot.shorten_proof_in_selection",
"when": "editorTextFocus && editorHasSelection && resourceLangId == coq",
"group": "queries"
}
]
},
Expand Down Expand Up @@ -217,6 +226,11 @@
"description": "Prompt for the Grazie model to begin chat with. It is sent as a system message, which means it has more impact than other messages.",
"default": "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'."
},
"refactoringPrompt": {
"type": "string",
"description": "Prompt for the Grazie model to begin chat with.",
"default": "You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'."
},
"maxTokensToGenerate": {
"type": "number",
"description": "Number of tokens that the model is allowed to generate as a response message (i.e. message with proof).",
Expand Down
16 changes: 12 additions & 4 deletions src/coqParser/parseCoqFile.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { readFileSync } from "fs";
import { Position, Range } from "vscode-languageclient";
import { Position } from "vscode-languageclient";

import { CoqLspClient } from "../coqLsp/coqLspClient";
import { FlecheDocument, RangedSpan } from "../coqLsp/coqLspTypes";
Expand Down Expand Up @@ -203,7 +203,8 @@ function parseProof(
let index = spanIndex;
let proven = false;
const proof: ProofStep[] = [];
let endPos: Range | null = null;
const startPos: Position = ast[index].range.start;
let endPos: Position = startPos;
let proofContainsAdmit = false;
let proofHoles: ProofStep[] = [];

Expand All @@ -228,7 +229,7 @@ function parseProof(
);
proof.push(proofStep);
proven = true;
endPos = span.range;
endPos = span.range.end;

if (
checkIfExprEAdmit(getExpr(span)) ||
Expand Down Expand Up @@ -257,9 +258,16 @@ function parseProof(
throw new CoqParsingError("invalid or incomplete proof");
}

const proofRange = {
start: startPos,
end: endPos,
};

console.log(proofRange);

const proofObj = new TheoremProof(
proof,
endPos,
proofRange,
proofContainsAdmit,
proofHoles
);
Expand Down
2 changes: 1 addition & 1 deletion src/coqParser/parsedTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ export class ProofStep {
export class TheoremProof {
constructor(
public proof_steps: ProofStep[],
public end_pos: Range,
public proof_range: Range,
public is_incomplete: boolean,
public holes: ProofStep[]
) {}
Expand Down
66 changes: 65 additions & 1 deletion src/core/inspectSourceFile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import * as path from "path";
import { CoqLspClient } from "../coqLsp/coqLspClient";

import { parseCoqFile } from "../coqParser/parseCoqFile";
import { ProofStep, Theorem } from "../coqParser/parsedTypes";
import { ProofStep, Theorem, TheoremProof } from "../coqParser/parsedTypes";
import { Uri } from "../utils/uri";

import {
Expand Down Expand Up @@ -42,6 +42,34 @@ export async function inspectSourceFile(
return [completionContexts, sourceFileEnvironmentWithCompleteProofs];
}

export async function inspectSourceShortening(
fileVersion: number,
shouldShortenProof: (proof: TheoremProof) => boolean,
fileUri: Uri,
client: CoqLspClient
): Promise<AnalyzedFile> {
const sourceFileEnvironment = await createSourceFileEnvironment(
fileVersion,
fileUri,
client
);
const completionContexts = await createSohrteningContexts(
fileVersion,
shouldShortenProof,
sourceFileEnvironment.fileTheorems,
fileUri,
client
);
const sourceFileEnvironmentWithCompleteProofs: SourceFileEnvironment = {
...sourceFileEnvironment,
fileTheorems: sourceFileEnvironment.fileTheorems.filter(
(thr) => thr.proof && !thr.proof.is_incomplete
),
};

return [completionContexts, sourceFileEnvironmentWithCompleteProofs];
}

async function createCompletionContexts(
fileVersion: number,
shouldCompleteHole: (hole: ProofStep) => boolean,
Expand Down Expand Up @@ -74,6 +102,42 @@ async function createCompletionContexts(
return completionContexts;
}

async function createSohrteningContexts(
fileVersion: number,
shouldShortenProof: (proof: TheoremProof) => boolean,
fileTheorems: Theorem[],
fileUri: Uri,
client: CoqLspClient
): Promise<CompletionContext[]> {
const proofsToShorten = fileTheorems
.filter((thr) => thr.proof)
.map((thr) => thr.proof!)
.filter(shouldShortenProof);

let completionContexts: CompletionContext[] = [];
if (proofsToShorten.length != 1) {

Check warning on line 118 in src/core/inspectSourceFile.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Expected '!==' and instead saw '!='
return completionContexts;
}

const proofToShorten = proofsToShorten[0];

const goal = await client.getFirstGoalAtPoint(
proofToShorten.proof_range.start,
fileUri,
fileVersion
);
if (!(goal instanceof Error)) {
completionContexts.push({
proofGoal: goal,
prefixEndPosition: proofToShorten.proof_range.start,
admitEndPosition: proofToShorten.proof_range.end,
});
}

Check warning on line 136 in src/core/inspectSourceFile.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Delete `····⏎`

return completionContexts;
}

export async function createSourceFileEnvironment(
fileVersion: number,
fileUri: Uri,
Expand Down
119 changes: 115 additions & 4 deletions src/extension/coqPilot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import {
SourceFileEnvironment,
} from "../core/completionGenerator";
import { CoqProofChecker } from "../core/coqProofChecker";
import { inspectSourceFile } from "../core/inspectSourceFile";
import { inspectSourceFile, inspectSourceShortening } from "../core/inspectSourceFile";

Check warning on line 25 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Replace `·inspectSourceFile,·inspectSourceShortening·` with `⏎····inspectSourceFile,⏎····inspectSourceShortening,⏎`

import { ProofStep } from "../coqParser/parsedTypes";
import { ProofStep, TheoremProof } from "../coqParser/parsedTypes";
import { Uri } from "../utils/uri";

import {
Expand Down Expand Up @@ -73,10 +73,44 @@ export class CoqPilot {
"perform_completion_for_all_admits",
this.performCompletionForAllAdmits.bind(this)
);
this.registerEditorCommand(
"shorten_proof_in_selection",
this.shortenProofUnderCursor.bind(this)
)

Check warning on line 79 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Insert `;`

Check warning on line 79 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Missing semicolon

this.vscodeExtensionContext.subscriptions.push(this);
}

async shortenProofUnderCursor(editor: TextEditor) {
const selection = editor.selection;
this.shortenProofWithProgress(
(proof) => (proof.proof_range.start.line <= selection.start.line) && (proof.proof_range.end.line >= selection.end.line),

Check warning on line 87 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Replace `·(proof.proof_range.start.line·<=·selection.start.line)·&&·(proof.proof_range.end.line·>=·selection.end.line)` with `⏎················proof.proof_range.start.line·<=·selection.start.line·&&⏎················proof.proof_range.end.line·>=·selection.end.line`
editor
);
}

private async shortenProofWithProgress(shouldShorten: (proof: TheoremProof) => boolean, editor: TextEditor) {

Check warning on line 92 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Replace `shouldShorten:·(proof:·TheoremProof)·=>·boolean,·editor:·TextEditor` with `⏎········shouldShorten:·(proof:·TheoremProof)·=>·boolean,⏎········editor:·TextEditor⏎····`
await window.withProgress(
{
location: ProgressLocation.Window,
title: `${pluginId}: In progress`,
},
async () => {
try {
await this.shortenProof(shouldShorten, editor);
} catch (error) {
if (error instanceof SettingsValidationError) {
error.showAsMessageToUser();
} else if (error instanceof Error) {
showMessageToUser(error.message, "error");
console.error(error);
}
}
}
);
}

Check warning on line 111 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Delete `⏎····`


async performCompletionUnderCursor(editor: TextEditor) {
const cursorPosition = editor.selection.active;
this.performSpecificCompletionsWithProgress(
Expand Down Expand Up @@ -127,6 +161,38 @@ export class CoqPilot {
);
}

private async shortenProof(shouldShorten: (proof: TheoremProof) => boolean, editor: TextEditor) {

Check warning on line 164 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Replace `shouldShorten:·(proof:·TheoremProof)·=>·boolean,·editor:·TextEditor` with `⏎········shouldShorten:·(proof:·TheoremProof)·=>·boolean,⏎········editor:·TextEditor⏎····`
const [shorteningContexts, sourceFileEnvironment, processEnvironment] =
await this.prepareForShortening(
shouldShorten,
editor.document.version,
editor.document.uri.fsPath
);

const unsubscribeFromLLMServicesEventsCallback =
subscribeToHandleLLMServicesEvents(
this.globalExtensionState.llmServices,
this.globalExtensionState.eventLogger
);

try {
const shorteningContext = shorteningContexts[0];
let shorterningPromise = this.performSingleCompletion(
shorteningContext,
sourceFileEnvironment,
processEnvironment,
editor,
true
);

await Promise.all([shorterningPromise]);
} finally {
unsubscribeFromLLMServicesEventsCallback();
}
}


Check warning on line 194 in src/extension/coqPilot.ts

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 4.14)

Delete `⏎⏎····`

private async performSpecificCompletions(
shouldCompleteHole: (hole: ProofStep) => boolean,
editor: TextEditor
Expand Down Expand Up @@ -166,7 +232,8 @@ export class CoqPilot {
completionContext: CompletionContext,
sourceFileEnvironment: SourceFileEnvironment,
processEnvironment: ProcessEnvironment,
editor: TextEditor
editor: TextEditor,
refactorMode: boolean = false
) {
const result = await generateCompletion(
completionContext,
Expand All @@ -176,7 +243,10 @@ export class CoqPilot {
);

if (result instanceof SuccessGenerationResult) {
const flatProof = this.prepareCompletionForInsertion(result.data);
let flatProof = this.prepareCompletionForInsertion(result.data);
if (refactorMode) {
flatProof = this.prepareRefactoredCompletionForInsertion(result.data);
}
const vscodeHoleRange = toVSCodeRange({
start: completionContext.prefixEndPosition,
end: completionContext.admitEndPosition,
Expand Down Expand Up @@ -229,6 +299,14 @@ export class CoqPilot {
.trim();
}

private prepareRefactoredCompletionForInsertion(text: string) {
const flatProof = text.replace(/\n/g, " ");
const formattedProof = flatProof.trim()
.slice(1, flatProof.length - 2)
.trim();
return `Proof.\n ${formattedProof}\n Qed.`;
}

private async prepareForCompletions(
shouldCompleteHole: (hole: ProofStep) => boolean,
fileVersion: number,
Expand Down Expand Up @@ -263,6 +341,39 @@ export class CoqPilot {
return [completionContexts, sourceFileEnvironment, processEnvironment];
}

private async prepareForShortening(
shouldShorten: (proof: TheoremProof) => boolean,
fileVersion: number,
filePath: string
): Promise<[CompletionContext[], SourceFileEnvironment, ProcessEnvironment]> {
const fileUri = Uri.fromPath(filePath);
const coqLspServerConfig = CoqLspConfig.createServerConfig();
const coqLspClientConfig = CoqLspConfig.createClientConfig();
const client = new CoqLspClient(coqLspServerConfig, coqLspClientConfig);
const contextTheoremsRanker = buildTheoremsRankerFromConfig();

const coqProofChecker = new CoqProofChecker(client);
const [completionContexts, sourceFileEnvironment] = await inspectSourceShortening(
fileVersion,
shouldShorten,
fileUri,
client
);

const processEnvironment: ProcessEnvironment = {
coqProofChecker: coqProofChecker,
modelsParams: readAndValidateUserModelsParams(
workspace.getConfiguration(pluginId),
this.globalExtensionState.llmServices
),
services: this.globalExtensionState.llmServices,
theoremRanker: contextTheoremsRanker,
};

return [completionContexts, sourceFileEnvironment, processEnvironment];
}


private registerEditorCommand(
command: string,
fn: (editor: TextEditor) => void
Expand Down
5 changes: 4 additions & 1 deletion src/llm/userModelParams.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ export interface UserModelParams {

systemPrompt?: string;

refactoringPrompt?: string;

maxTokensToGenerate?: number;
/**
* Includes tokens that the model generates as an answer message,
Expand Down Expand Up @@ -83,8 +85,9 @@ export const userModelParamsSchema: JSONSchemaType<UserModelParams> = {
properties: {
modelId: { type: "string" },
choices: { type: "number", nullable: true },

systemPrompt: { type: "string", nullable: true },
refactoringPrompt: { type: "string", nullable: true },

maxTokensToGenerate: { type: "number", nullable: true },
tokensLimit: { type: "number", nullable: true },
Expand Down
2 changes: 1 addition & 1 deletion src/test/coqParser/parseCoqFile.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ suite("Coq file parser tests", () => {
expect(theorem.statement).toEqual(theoremData[i].statement);
expect(theorem.statement_range.start).toEqual(theoremData[i].start);
expect(theorem.proof).not.toBeNullish();
expect(theorem.proof?.end_pos.end).toEqual(theoremData[i].end);
expect(theorem.proof?.proof_range.end).toEqual(theoremData[i].end);
expect(theorem.proof?.is_incomplete).toEqual(
theoremData[i].isIncomplete
);
Expand Down

0 comments on commit cc9087c

Please sign in to comment.