-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GitOrigin-RevId: 41b5533e96aa9287a0e54c5e3f0dc6921cd2af6b
- Loading branch information
1 parent
8e4812c
commit 8eb482e
Showing
5 changed files
with
542 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright 2025 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import { | ||
AutoratingAggregatedResults, | ||
generateEvaluationReport, | ||
collectStatementsAndCommentsFromCSV, | ||
} from "./autorating_utils"; | ||
import fs from "fs"; | ||
|
||
describe("Autorating Utils", () => {
  describe("collectStatementsAndCommentsFromCSV", () => {
    it("should read statements and comments from a CSV file", () => {
      // Temporary fixture written to the working directory. It carries an extra
      // "has_hallucination" column that is not expected in the parsed result.
      const csvFilePath = "test_autorating.csv";
      const csvContent =
        '"summary","comments","has_hallucination"\n"statement 1","comment 1",1\n"statement 2","comment 2",0';
      fs.writeFileSync(csvFilePath, csvContent);

      try {
        const result = collectStatementsAndCommentsFromCSV(csvFilePath);

        expect(result).toEqual([
          { statement: "statement 1", comments: "comment 1" },
          { statement: "statement 2", comments: "comment 2" },
        ]);
      } finally {
        // Always remove the fixture, even when the call or the assertion throws,
        // so a failing run does not leave files behind.
        fs.unlinkSync(csvFilePath);
      }
    });
  });

  describe("generateEvaluationReport", () => {
    it("should generate a report with correct percentages and formatting", () => {
      const results: AutoratingAggregatedResults = {
        totalStatements: 10,
        questions: {
          "Question 1": { pass: 7, fail: 2, unsure: 1 },
          "Question 2": { pass: 5, fail: 5, unsure: 0 },
        },
      };
      const totalRuntimeMinutes = 5.25;

      const report = generateEvaluationReport(results, totalRuntimeMinutes);

      // Header section.
      expect(report).toContain("Summary Evaluation Report");
      expect(report).toContain("Total statements: 10");

      // Per-question breakdown: whole-number percentage followed by raw counts.
      expect(report).toContain("Question 1");
      expect(report).toContain("Pass: 70% (7/10)");
      expect(report).toContain("Fail: 20% (2/10)");
      expect(report).toContain("Unsure: 10% (1/10)");

      expect(report).toContain("Question 2");
      expect(report).toContain("Pass: 50% (5/10)");
      expect(report).toContain("Fail: 50% (5/10)");
      expect(report).toContain("Unsure: 0% (0/10)");

      // Runtime is printed with two decimal places.
      expect(report).toContain("Total runtime: 5.25 minutes");
    });
  });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
// Copyright 2025 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// Utility functions and types for automated evaluation of summarization results using LLMs. | ||
|
||
import fs from "fs"; | ||
import { parse } from "csv-parse/sync"; | ||
|
||
/**
 * A summary statement paired with the comments it is evaluated against.
 */
export interface StatementWithComments {
  /** The summary statement to be evaluated. */
  statement: string;
  /** The comments associated with the statement. */
  comments: string;
}
|
||
/**
 * Aggregated outcome of an autorating evaluation run.
 */
export interface AutoratingAggregatedResults {
  /** Total number of statements evaluated. */
  totalStatements: number;
  /** Per-question tallies, keyed by question; each entry holds pass/fail/unsure counts. */
  questions: Record<string, { pass: number; fail: number; unsure: number }>;
}
|
||
/**
 * Reads statements and comments from a CSV file and returns them as an array of
 * StatementWithComments objects.
 *
 * The CSV file is expected to have a header row with 'summary' and 'comments' columns;
 * any other columns are ignored. A leading UTF-8 byte order mark is stripped so that it
 * does not become part of the first header name.
 *
 * This function does not throw: if the file cannot be read or parsed, the error is logged
 * with console.error and an empty array is returned.
 *
 * @param csvFilePath The path to the CSV file.
 * @returns An array of StatementWithComments objects, or an empty array on failure.
 */
export function collectStatementsAndCommentsFromCSV(csvFilePath: string): StatementWithComments[] {
  try {
    const csvFileContent = fs.readFileSync(csvFilePath, "utf8");

    // `columns: true` makes the parser return one object per row, keyed by header name.
    const csvRecords: Array<{ summary: string; comments: string }> = parse(csvFileContent, {
      bom: true,
      columns: true,
      skip_empty_lines: true,
    });

    return csvRecords.map((record) => ({
      statement: record.summary,
      comments: record.comments,
    }));
  } catch (error) {
    // Best-effort by design: report the problem and let the caller continue with no input.
    console.error("Failed to read the input file:", error);
    return [];
  }
}
|
||
/**
 * Generates a plain-text summary evaluation report from aggregated autorating results.
 *
 * For each question, the pass/fail/unsure counts are printed as a whole-number percentage
 * of the answers recorded for that question (pass + fail + unsure), followed by the raw
 * counts. A question with no recorded answers is reported as 0% instead of "NaN%".
 *
 * @param results Aggregated results from the autorating process.
 * @param totalRuntimeMin Total runtime of the evaluation in minutes.
 * @returns A formatted report string.
 */
export function generateEvaluationReport(
  results: AutoratingAggregatedResults,
  totalRuntimeMin: number
): string {
  // Formats one "<label>: <pct>% (<count>/<total>)" line; guards against division by zero.
  const formatLine = (label: string, count: number, total: number): string => {
    const percent = total > 0 ? (count / total) * 100 : 0;
    return `${label}: ${percent.toFixed(0)}% (${count}/${total})\n`;
  };

  let report = "Summary Evaluation Report\n\n";
  report += `Total statements: ${results.totalStatements}\n\n`;
  for (const [question, counts] of Object.entries(results.questions)) {
    const totalAnswers = counts.pass + counts.fail + counts.unsure;
    report += `${question}\n`;
    report += formatLine("Pass", counts.pass, totalAnswers);
    report += formatLine("Fail", counts.fail, totalAnswers);
    report += formatLine("Unsure", counts.unsure, totalAnswers);
    report += "\n"; // Blank line between questions for readability
  }
  report += `Total runtime: ${totalRuntimeMin.toFixed(2)} minutes\n`;
  return report;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
// Copyright 2025 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import { rateHallucination } from "./hallucination_autorater"; | ||
import { VertexModel } from "../../src/models/vertex_model"; | ||
import { StatementWithComments } from "./autorating_utils"; | ||
import fs from "fs"; | ||
import * as path from "path"; | ||
|
||
jest.mock("../../src/models/vertex_model"); // Mock the VertexModel

describe("rateHallucination", () => {
  let mockModel: jest.Mocked<VertexModel>;
  const mockOutputDir = "test_output";

  beforeEach(() => {
    mockModel = new VertexModel("", "", "") as jest.Mocked<VertexModel>; // Create a mocked instance
    // Start every test from an empty output directory. `force: true` makes the removal
    // a no-op when the directory does not exist yet.
    fs.rmSync(mockOutputDir, { recursive: true, force: true });
    fs.mkdirSync(mockOutputDir);
  });

  afterEach(() => {
    fs.rmSync(mockOutputDir, { recursive: true, force: true }); // Clean up after each test
  });

  it("should correctly process summaries and generate report", async () => {
    const summaries: StatementWithComments[] = [
      { statement: "Statement 1", comments: "Comment 1" },
      { statement: "Statement 2", comments: "Comment 2" },
    ];
    const mockResponseData = {
      analysis: "Test analysis",
      answer: "YES",
      explanation: "Test explanation",
    };
    mockModel.generateData.mockResolvedValue(mockResponseData); // Mock generateData to resolve with mock data

    await rateHallucination(mockModel, summaries, mockOutputDir);

    // Check if the files were created
    const csvPath = path.join(mockOutputDir, "hallucination_autoratings.csv");
    const reportPath = path.join(mockOutputDir, "hallucination_report.txt");
    expect(fs.existsSync(csvPath)).toBe(true);
    expect(fs.existsSync(reportPath)).toBe(true);

    // Check some of the CSV content and aggregated results
    const csvContent = fs.readFileSync(csvPath, "utf8");
    expect(csvContent).toContain("Statement 1");
    expect(csvContent).toContain("YES"); // Hallucination result for Statement 1

    // Check report content
    const reportContent = fs.readFileSync(reportPath, "utf8");
    expect(reportContent).toContain("Summary Evaluation Report");
    expect(reportContent).toContain("Total statements: 2");
  });

  it("should handle LLM errors gracefully", async () => {
    const summaries: StatementWithComments[] = [
      { statement: "Statement 1", comments: "Comment 1" },
    ];
    mockModel.generateData.mockRejectedValue(new Error("LLM Error")); // Mock an LLM error
    const consoleErrorSpy = jest.spyOn(console, "error");

    try {
      await rateHallucination(mockModel, summaries, mockOutputDir);

      expect(consoleErrorSpy).toHaveBeenCalledWith(
        "Error during LLM call or parsing:",
        expect.any(Error)
      );

      // Check for NULL values in CSV due to the error
      const csvPath = path.join(mockOutputDir, "hallucination_autoratings.csv");
      expect(fs.existsSync(csvPath)).toBe(true);
      const csvContent = fs.readFileSync(csvPath, "utf8");
      expect(csvContent).toContain("NULL");
    } finally {
      // Restore even when an assertion fails so the spy does not leak into other tests.
      consoleErrorSpy.mockRestore();
    }
  });

  it("should handle invalid responses from LLM", async () => {
    const summaries: StatementWithComments[] = [
      { statement: "Statement 1", comments: "Comment 1" },
    ];
    mockModel.generateData.mockResolvedValue(null); // Mock invalid response
    const consoleWarnSpy = jest.spyOn(console, "warn");

    try {
      await rateHallucination(mockModel, summaries, mockOutputDir);

      expect(consoleWarnSpy).toHaveBeenCalledWith(
        "Skipping statement due to LLM error or invalid response."
      );
    } finally {
      // Restore even when an assertion fails so the spy does not leak into other tests.
      consoleWarnSpy.mockRestore();
    }
  });
});
Oops, something went wrong.