Skip to content

Commit

Permalink
grapheme popularity as hat metric
Browse files Browse the repository at this point in the history
  • Loading branch information
josharian committed Jan 20, 2024
1 parent 2b596eb commit 5d3ceba
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 74 deletions.
30 changes: 9 additions & 21 deletions packages/cursorless-engine/src/util/allocateHats/HatMetrics.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { CompositeKeyMap, HatStability, TokenHat } from "@cursorless/common";
import { memoize, min } from "lodash";
import {
CompositeKeyMap,
DefaultMap,
HatStability,
TokenHat,
} from "@cursorless/common";
import { HatCandidate } from "./allocateHats";

/**
Expand Down Expand Up @@ -37,26 +41,10 @@ export function hatOldTokenRank(
};
}

/**
* @param tokenRank The rank of the current token, so that we don't consider
* higher ranked tokens (which already have been assigned hats)
* @param graphemeTokenRanks A map from graphemes to an ordered list of the
* ranks of tokens containing the grapheme
* @returns A metric which returns the minimum token rank among lower ranked
* tokens that contain the hat's grapheme (or Infinity if the grapheme doesn't
* appear in any lower ranked tokens)
*/
export function minimumTokenRankContainingGrapheme(
tokenRank: number,
graphemeTokenRanks: { [key: string]: number[] },
export function leastPopularGrapheme(
graphemePopularity: DefaultMap<string, number>,
): HatMetric {
const coreMetric = memoize((graphemeText: string): number => {
return (
min(graphemeTokenRanks[graphemeText].filter((r) => r > tokenRank)) ??
Infinity
);
});
return ({ grapheme: { text } }) => coreMetric(text);
return ({ grapheme: { text } }) => -graphemePopularity.get(text);
}

/**
Expand Down
35 changes: 28 additions & 7 deletions packages/cursorless-engine/src/util/allocateHats/allocateHats.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ export function allocateHats(
* Lookup tables with information about which graphemes / hats appear in which
* tokens
*/
const context = getHatRankingContext(
rankedTokens,
tokenOldHatMap,
tokenGraphemeSplitter,
);
const context = getHatRankingContext(rankedTokens, tokenOldHatMap);

/* All initially enabled hat styles. */
const enabledHatStyleNames = Object.keys(enabledHatStyles);
Expand All @@ -83,10 +79,25 @@ export function allocateHats(
() => [...enabledHatStyleNames],
);

// Build a popularity count per grapheme: every grapheme the splitter returns
// for a token adds 1 to that grapheme's count.
// TODO: move "graphemes for tokens" into getRankedTokens
// to avoid recalculating it every time.
const graphemePopularity = new DefaultMap<string, number>(() => 0);
rankedTokens.forEach(({ token }) => {
tokenGraphemeSplitter
.getTokenGraphemes(token.text)
.forEach(({ text: graphemeText }) => {
graphemePopularity.set(
graphemeText,
graphemePopularity.get(graphemeText) + 1,
);
});
});

// Iterate through tokens in order of decreasing rank, assigning each one a
// hat
return rankedTokens
.map<TokenHat | undefined>(({ token, rank: tokenRank }) => {
.map<TokenHat | undefined>(({ token }) => {
/**
* All hats for the graphemes in this token that weren't taken by a
* higher ranked token
Expand All @@ -101,11 +112,21 @@ export function allocateHats(
const chosenHat = chooseTokenHat(
context,
hatStability,
tokenRank,
tokenOldHatMap.get(token),
graphemePopularity,
tokenRemainingHatCandidates,
);

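// This token has now been handled, so subtract its graphemes from the
// popularity counts; the counts then only reflect the tokens that come
// later in the iteration.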
tokenGraphemeSplitter
.getTokenGraphemes(token.text)
.forEach(({ text: graphemeText }) => {
graphemePopularity.set(
graphemeText,
graphemePopularity.get(graphemeText) - 1,
);
});

// If there are no hats left for the graphemes in this token, the token
// will get no hat
if (chosenHat == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { HatStability, TokenHat } from "@cursorless/common";
import { DefaultMap, HatStability, TokenHat } from "@cursorless/common";
import { HatCandidate } from "./allocateHats";
import { RankingContext } from "./getHatRankingContext";
import {
hatOldTokenRank,
isOldTokenHat,
minimumTokenRankContainingGrapheme,
leastPopularGrapheme,
negativePenalty,
penaltyEquivalenceClass,
} from "./HatMetrics";
Expand Down Expand Up @@ -48,10 +48,10 @@ import { maxByFirstDiffering } from "./maxByFirstDiffering";
* @returns The chosen hat, or `undefined` if {@link candidates} was empty
*/
export function chooseTokenHat(
{ hatOldTokenRanks, graphemeTokenRanks }: RankingContext,
{ hatOldTokenRanks }: RankingContext,
hatStability: HatStability,
tokenRank: number,
oldTokenHat: TokenHat | undefined,
graphemePopularity: DefaultMap<string, number>,
candidates: HatCandidate[],
): HatCandidate | undefined {
// We narrow down the candidates by a series of criteria until there is only
Expand All @@ -71,8 +71,7 @@ export function chooseTokenHat(
// 4. Narrow to the hats with the lowest penalty
negativePenalty,

// 5. Prefer hats that sit on a grapheme that doesn't appear in any highly
// ranked token
minimumTokenRankContainingGrapheme(tokenRank, graphemeTokenRanks),
// 5. Prefer hats that sit on the least popular grapheme, i.e. the one with
// the lowest current count in graphemePopularity
leastPopularGrapheme(graphemePopularity),
])!;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
Token,
TokenHat,
} from "@cursorless/common";
import { TokenGraphemeSplitter } from "../../tokenGraphemeSplitter";
import { RankedToken } from "./getRankedTokens";

export interface RankingContext {
Expand All @@ -19,53 +18,25 @@ export interface RankingContext {
},
number
>;

/**
* Maps from a grapheme to the list of ranks of the tokens in which the
* given grapheme appears.
*/
graphemeTokenRanks: {
[key: string]: number[];
};
}

export function getHatRankingContext(
tokens: RankedToken[],
oldTokenHatMap: CompositeKeyMap<Token, TokenHat>,
tokenGraphemeSplitter: TokenGraphemeSplitter,
): RankingContext {
const graphemeTokenRanks: {
[key: string]: number[];
} = {};

const hatOldTokenRanks = new CompositeKeyMap<
{ grapheme: string; hatStyle: HatStyleName },
number
>(({ grapheme, hatStyle }) => [grapheme, hatStyle]);

tokens.forEach(({ token, rank }) => {
tokens.forEach(({ token }, index) => {
const existingTokenHat = oldTokenHatMap.get(token);
if (existingTokenHat != null) {
hatOldTokenRanks.set(existingTokenHat, rank);
hatOldTokenRanks.set(existingTokenHat, -index);
}
tokenGraphemeSplitter
.getTokenGraphemes(token.text)
.forEach(({ text: graphemeText }) => {
let tokenRanksForGrapheme: number[];

if (graphemeText in graphemeTokenRanks) {
tokenRanksForGrapheme = graphemeTokenRanks[graphemeText];
} else {
tokenRanksForGrapheme = [];
graphemeTokenRanks[graphemeText] = tokenRanksForGrapheme;
}

tokenRanksForGrapheme.push(rank);
});
});

return {
hatOldTokenRanks,
graphemeTokenRanks,
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ export function getRankedTokens(
),
);

return tokens.map((token, index) => ({ token, rank: -index }));
return tokens.map((token, index) => ({
token,
}));
});
}

Expand All @@ -67,11 +69,4 @@ function getRankedEditors(

export interface RankedToken {
token: Token;

/**
* A number indicating how likely the token is to be used. Tokens closer to
* the cursor will be considered more likely to be used, and will receive a
* higher rank, causing them to be assigned better hats.
*/
rank: number;
}

0 comments on commit 5d3ceba

Please sign in to comment.