Skip to content

Commit

Permalink
feat(jest): add bench mock
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Dec 20, 2023
1 parent b173d18 commit 9df745c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
17 changes: 12 additions & 5 deletions jest/mock.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ const { NativeModules, DeviceEventEmitter } = require('react-native')

if (!NativeModules.RNLlama) {
NativeModules.RNLlama = {
initContext: jest.fn(() => Promise.resolve({
contextId: 1,
gpu: false,
reasonNoGPU: 'Test',
})),
initContext: jest.fn(() =>
Promise.resolve({
contextId: 1,
gpu: false,
reasonNoGPU: 'Test',
}),
),

completion: jest.fn(async (contextId, jobId) => {
const testResult = {
Expand Down Expand Up @@ -150,6 +152,11 @@ if (!NativeModules.RNLlama) {
})),
saveSession: jest.fn(async () => 1),

bench: jest.fn(
async () =>
'["test 3B Q4_0",1600655360,2779683840,16.211304,0.021748,38.570646,1.195800]',
),

releaseContext: jest.fn(() => Promise.resolve()),
releaseAllContexts: jest.fn(() => Promise.resolve()),

Expand Down
14 changes: 13 additions & 1 deletion src/__tests__/__snapshots__/index.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,19 @@ Array [
]
`;

exports[`Mock 2`] = `
exports[`Mock: bench 1`] = `
Object {
"modelDesc": "test 3B Q4_0",
"modelNParams": 2779683840,
"modelSize": 1600655360,
"ppAvg": 16.211304,
"ppStd": 0.021748,
"tgAvg": 38.570646,
"tgStd": 1.1958,
}
`;

exports[`Mock: completion result 1`] = `
Object {
"completion_probabilities": Array [
Object {
Expand Down
4 changes: 3 additions & 1 deletion src/__tests__/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ test('Mock', async () => {
events.push(data)
})
expect(events).toMatchSnapshot()
expect(completionResult).toMatchSnapshot()
expect(completionResult).toMatchSnapshot('completion result')

expect(await context.bench(512, 128, 1, 3)).toMatchSnapshot('bench')

await context.release()
await releaseAllLlama()
Expand Down
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ export class LlamaContext {

async bench(pp: number, tg: number, pl: number, nr: number): Promise<BenchResult> {
const result = await RNLlama.bench(this.id, pp, tg, pl, nr)
console.log(result)
const [
modelDesc,
modelSize,
Expand Down

0 comments on commit 9df745c

Please sign in to comment.