Skip to content

Commit

Permalink
🚧 WIP: retrieving similar packs
Browse files Browse the repository at this point in the history
  • Loading branch information
mikib0 committed Jul 12, 2024
1 parent 280dbc7 commit e787b18
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 2 deletions.
19 changes: 19 additions & 0 deletions server/src/controllers/pack/getSimilarPacks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { publicProcedure } from '../../trpc';
import { getSimilarPacksService } from '../../services/pack/pack.service';
import { z } from 'zod';

/**
* Retrieves packs that are similar to provided pack.
* @param {Object} req - the request object
* @param {Object} res - the response object
* @return {Promise} - a promise that resolves with an array of similar packs
*/
export function getSimilarPacksRoute() {
return publicProcedure
.input(z.object({ id: z.string(), limit: z.number() }))
.query(async (opts) => {
const { id, limit } = opts.input;
const packs = await getSimilarPacksService(id, limit);
return packs;
});
}
1 change: 1 addition & 0 deletions server/src/controllers/pack/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ export * from './duplicatePublicPack';
export * from './getPackById';
export * from './getPublicPacks';
export * from './scorePack';
export * from './getSimilarPacks';
2 changes: 2 additions & 0 deletions server/src/routes/trpcRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import {
getPackByIdRoute,
getPacksRoute,
getPublicPacksRoute,
getSimilarPacksRoute,
scorePackRoute,
} from '../controllers/pack';

Expand Down Expand Up @@ -157,6 +158,7 @@ export const appRouter = trpcRouter({
deletePack: deletePackRoute(), // Done
scorePack: scorePackRoute(), // Done
duplicatePublicPack: duplicatePublicPackRoute(), // Not Implemented
getSimilarPacks: getSimilarPacksRoute(),
// osm routes - currently breaking tests, see patch file
getPhotonResults: getPhotonResultsRoute(),
getTrailsOSM: getTrailsOSMRoute(),
Expand Down
1 change: 1 addition & 0 deletions server/src/services/pack/addPackService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export const addPackService = async (
await VectorClient.instance.syncRecord({
id: createdPack.id,
content: name,
namespace: 'packs',
});
});

Expand Down
49 changes: 49 additions & 0 deletions server/src/services/pack/getSimilarPacksService.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { DbClient } from '../../db/client';
import { Pack } from '../../drizzle/methods/pack';
import { VectorClient } from '../../vector/client';
import { pack as PacksTable } from '../../db/schema';
import { inArray } from 'drizzle-orm';

/**
* Retrieves packs that are similar to the provided pack.
*
* @param {string} id - ID of the pack for which to retrive similar packs for.
* @param {string} limit - Max number of similar packs to return.
* @return {Promise<any[]>} An array of similar packs.
*/
export async function getSimilarPacksService(id: string, limit: number = 5) {
const packClass = new Pack();
let pack = await packClass.findPack({
id,
});

if (!pack) {
throw new Error(`Pack with id: ${id} not found`);
}

const { matches } = await VectorClient.instance.search(
pack.name,
'packs',
limit,
);

const similarPacksResult = await DbClient.instance
.select()
.from(PacksTable)
.where(
inArray(
PacksTable.id,
matches.matches.map((m) => m.id),
),
);

// add similarity score to packs result
const similarPacks = matches.matches.map((match) => {
return {
...similarPacksResult.find((p) => p.id == match.id),
similarityScore: match.score,
};
});

return similarPacks;
}
1 change: 1 addition & 0 deletions server/src/services/pack/pack.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ export * from './getPackByIdService';
export * from './getPackService';
export * from './scorePackService';
export * from './getPublicPacksService';
export * from './getSimilarPacksService';
5 changes: 3 additions & 2 deletions server/src/vector/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class VectorClient {
// }

// New API-based search method
public async search(queryEmbedding: number[], namespace: string) {
public async search(content: string, namespace: string, topK: number) {
const values = await AiClient.getEmbedding(content);
const url = `https://api.cloudflare.com/client/v4/accounts/${this.accountId}/vectorize/indexes/${this.indexName}/vectors/query`;
const response = await fetch(url, {
method: 'POST',
Expand All @@ -78,7 +79,7 @@ class VectorClient {
Authorization: `Bearer ${this.apiKey}`,
},
body: JSON.stringify({
queries: [{ values: queryEmbedding, topK: 5, namespace }],
queries: [{ values, topK, namespace }],
}),
});

Expand Down

0 comments on commit e787b18

Please sign in to comment.