Skip to content

Commit

Permalink
finished implementing recommendation alg
Browse files Browse the repository at this point in the history
  • Loading branch information
kylezryr committed Apr 20, 2024
1 parent 3c181e6 commit 8135d84
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 62 deletions.
96 changes: 54 additions & 42 deletions src/app/(tabs)/home/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
fetchFeaturedStoryPreviews,
fetchNewStories,
fetchRecommendedStories,
fetchStoryPreviewById,
} from '../../../queries/stories';
import { StoryCard, StoryPreview, Story } from '../../../queries/types';
import globalStyles from '../../../styles/globalStyles';
Expand All @@ -30,7 +31,16 @@ function HomeScreen() {
const [recommendedStories, setRecommendedStories] = useState<StoryCard[]>([]);
const [newStories, setNewStories] = useState<StoryCard[]>([]);

const setRecentStory = async (recentStories: StoryPreview[]) => {
const getRecentStory = async () => {
try {
const jsonValue = await AsyncStorage.getItem('GWN_RECENT_STORIES_ARRAY');
return jsonValue != null ? JSON.parse(jsonValue) : [];
} catch (error) {
console.log(error);
}
};

const setRecentStory = async (recentStories: StoryCard[]) => {
try {
const jsonValue = JSON.stringify(recentStories);
await AsyncStorage.setItem('GWN_RECENT_STORIES_ARRAY', jsonValue);
Expand All @@ -39,60 +49,77 @@ function HomeScreen() {
}
};

useEffect(() => {
const getRecentStory = async () => {
try {
const jsonValue = await AsyncStorage.getItem(
'GWN_RECENT_STORIES_ARRAY',
);
return jsonValue != null ? JSON.parse(jsonValue) : [];
} catch (error) {
console.log(error);
const handleStoryPreviewPressed = (story: StoryPreview) => {
recentlyViewedStacking(story);
router.push({
pathname: '/story',
params: { storyId: story.id.toString() },
});
};

const handleStoryCardPressed = async (story: StoryCard) => {
const newStoryArray = await fetchStoryPreviewById(story.id);
recentlyViewedStacking(newStoryArray[0]);
router.push({
pathname: '/story',
params: { storyId: story.id.toString() },
});
};

const recentlyViewedStacking = async (story: StoryPreview) => {
const maxArrayLength = 5;
const newRecentlyViewed = [...recentlyViewed];

for (let i = 0; i < recentlyViewed.length; i++) {
if (story.id === recentlyViewed[i].id) {
newRecentlyViewed.splice(i, 1);
break;
}
};
}

if (newRecentlyViewed.length >= maxArrayLength) {
newRecentlyViewed.splice(-1, 1);
}

newRecentlyViewed.splice(0, 0, story);

setRecentStory(newRecentlyViewed);
setRecentlyViewed(newRecentlyViewed);
};

useEffect(() => {
const getRecommendedStories = async () => {
const recentStoryResponse = await getRecentStory();
// console.log('recentStoryResponse', recentStoryResponse);
// setRecentlyViewed(recentStoryResponse);
// console.log('state: recentlyViewed', recentlyViewed);
const recommendedStoriesResponse =
await fetchRecommendedStories(recentStoryResponse);
setRecommendedStories(recommendedStoriesResponse);
// return recommendedStoriesResponse;
};

(async () => {
const [
usernameResponse,
featuredStoryResponse,
featuredStoryDescriptionResponse,
// recommendedStoriesResponse,
newStoriesResponse,
// recentStoryResponse,
recentStoryResponse,
] = await Promise.all([
fetchUsername(user?.id).catch(() => ''),
fetchFeaturedStoryPreviews().catch(() => []),
fetchFeaturedStoriesDescription().catch(() => ''),
// fetchRecommendedStories(recentlyViewed).catch(() => []), // need to set recentlyViewed before
fetchNewStories().catch(() => []),
// getRecentStory(),
getRecentStory(),
]);
setUsername(usernameResponse);
setFeaturedStories(featuredStoryResponse);
setFeaturedStoriesDescription(featuredStoryDescriptionResponse);
// setRecommendedStories(recommendedStoriesResponse);
setNewStories(newStoriesResponse);
// setRecentlyViewed(recentStoryResponse);
setRecentlyViewed(recentStoryResponse);
await getRecommendedStories();
})().finally(() => {
setLoading(false);
});
}, [user]);

useEffect(() => {}, []);

// console.log(recommendedStories);
return (
<SafeAreaView
style={[globalStyles.container, { marginLeft: -8, marginRight: -32 }]}
Expand Down Expand Up @@ -136,12 +163,7 @@ function HomeScreen() {
tags={story.genre_medium
.concat(story.tone)
.concat(story.topic)}
pressFunction={() =>
router.push({
pathname: '/story',
params: { storyId: story.id.toString() },
})
}
pressFunction={() => handleStoryPreviewPressed(story)}
/>
))}
</View>
Expand All @@ -164,12 +186,7 @@ function HomeScreen() {
title={story.title}
author={story.author_name}
authorImage={story.author_image}
pressFunction={() =>
router.push({
pathname: '/story',
params: { storyId: story.id.toString() },
})
}
pressFunction={() => handleStoryCardPressed(story)}
image={story.featured_media}
/>
))}
Expand All @@ -193,12 +210,7 @@ function HomeScreen() {
title={story.title}
author={story.author_name}
authorImage={story.author_image}
pressFunction={() =>
router.push({
pathname: '/story',
params: { storyId: story.id.toString() },
})
}
pressFunction={() => handleStoryCardPressed(story)}
image={story.featured_media}
/>
))}
Expand Down
87 changes: 67 additions & 20 deletions src/queries/stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,81 @@ export async function fetchFeaturedStoriesDescription(): Promise<string> {
}

export async function fetchRecommendedStories(
recentlyViewed: StoryCard[],
): Promise<StoryCard[]> {
const recentlyViewedID = recentlyViewed[0].id; //change to take in multiple stories

const getStoryEmbedding = async () => {
const { data } = await supabase
.from('stories')
.select('embedding')
.eq('id', recentlyViewedID);

if (error) {
console.log(error);
throw new Error(
`An error occured when trying to fetch embeddings: ${error.details}`,
);
} else {
if (data) return data[0].embedding as number;
}
inputStories: StoryPreview[],
): Promise<StoryPreview[]> {
if (inputStories.length == 0) {
return [];
}
const storyIDs = inputStories.map(story => story.id);

//fill storyIDs with 0's if less than 5 ids
for (let n = storyIDs.length; n < 5; n++) {
storyIDs[n] = 0;
}

//get embedding vectors for each of the inputs
const getStoryEmbeddings = async () => {
const embeddings = inputStories.map(async story => {
const { data, error } = await supabase
.from('stories')
.select('embedding')
.eq('id', story.id);

if (error) {
console.log(error);
throw new Error(
`An error occured when trying to fetch embeddings: ${error.details}`,
);
} else {
if (data) {
return data[0].embedding as string;
}
}
});

return await Promise.all(embeddings);
};

const embedding = await getStoryEmbedding();
//get embeddings of every story in inputStory
const embeddingsArray = await getStoryEmbeddings();
const newEmbeddingsArray = [];
for (let k = 0; k < embeddingsArray.length; k++) {
const stringLength = embeddingsArray[k]?.length;
if (stringLength) {
const embedding = embeddingsArray[k]?.substring(1, stringLength - 1);
const formattedEmbedding = embedding?.split(',');
newEmbeddingsArray[k] = formattedEmbedding;
}
}
const embeddingsLength =
newEmbeddingsArray.length > 5 ? 5 : newEmbeddingsArray.length;

//calculate average embedding vector
const averageEmbedding: number[] = [];
for (let m = 0; m < 384; m++) {
averageEmbedding[m] = 0;
}
for (let i = 0; i < embeddingsLength; i++) {
const vector = newEmbeddingsArray[i];
if (vector) {
for (let j = 0; j < vector.length; j++) {
const element = parseFloat(vector[j]);
averageEmbedding[j] += element / embeddingsLength;
}
}
}

const { data, error } = await supabase.rpc(
'fetch_users_recommended_stories',
{
query_embedding: embedding,
query_embedding: averageEmbedding,
match_threshold: 0.0,
match_count: 5,
storyid1: storyIDs[0],
storyid2: storyIDs[1],
storyid3: storyIDs[2],
storyid4: storyIDs[3],
storyid5: storyIDs[4],
},
);

Expand Down

0 comments on commit 8135d84

Please sign in to comment.