diff --git a/packages/cursorless-engine/src/util/allocateHats/HatMetrics.ts b/packages/cursorless-engine/src/util/allocateHats/HatMetrics.ts index cd747f57a4..1297222540 100644 --- a/packages/cursorless-engine/src/util/allocateHats/HatMetrics.ts +++ b/packages/cursorless-engine/src/util/allocateHats/HatMetrics.ts @@ -1,5 +1,9 @@ -import { CompositeKeyMap, HatStability, TokenHat } from "@cursorless/common"; -import { memoize, min } from "lodash"; +import { + CompositeKeyMap, + DefaultMap, + HatStability, + TokenHat, +} from "@cursorless/common"; import { HatCandidate } from "./allocateHats"; /** @@ -37,26 +41,10 @@ export function hatOldTokenRank( }; } -/** - * @param tokenRank The rank of the current token, so that we don't consider - * higher ranked tokens (which already have been assigned hats) - * @param graphemeTokenRanks A map from graphemes to an ordered list of the - * ranks of tokens containing the grapheme - * @returns A metric which returns the minimum token rank among lower ranked - * tokens that contain the hat's grapheme (or Infinity if the grapheme doesn't - * appear in any lower ranked tokens) - */ -export function minimumTokenRankContainingGrapheme( - tokenRank: number, - graphemeTokenRanks: { [key: string]: number[] }, +export function leastPopularGrapheme( + graphemePopularity: DefaultMap, ): HatMetric { - const coreMetric = memoize((graphemeText: string): number => { - return ( - min(graphemeTokenRanks[graphemeText].filter((r) => r > tokenRank)) ?? - Infinity - ); - }); - return ({ grapheme: { text } }) => coreMetric(text); + return ({ grapheme: { text } }) => -graphemePopularity.get(text); } /** diff --git a/packages/cursorless-engine/src/util/allocateHats/allocateHats.ts b/packages/cursorless-engine/src/util/allocateHats/allocateHats.ts index 09bb49fc00..5a29ca2e0a 100644 --- a/packages/cursorless-engine/src/util/allocateHats/allocateHats.ts +++ b/packages/cursorless-engine/src/util/allocateHats/allocateHats.ts @@ -62,11 +62,7 @@ export function allocateHats( * Lookup tables with information about which graphemes / hats appear in which * tokens */ - const context = getHatRankingContext( - rankedTokens, - tokenOldHatMap, - tokenGraphemeSplitter, - ); + const context = getHatRankingContext(rankedTokens, tokenOldHatMap); /* All initially enabled hat styles. */ const enabledHatStyleNames = Object.keys(enabledHatStyles); @@ -83,10 +79,25 @@ export function allocateHats( () => [...enabledHatStyleNames], ); + // For every token, add that token's score to all the graphemes in the token. + // TODO: move "graphemes for tokens" into getRankedTokens + // to avoid recalculating it every time. + const graphemePopularity = new DefaultMap(() => 0); + rankedTokens.forEach(({ token }) => { + tokenGraphemeSplitter + .getTokenGraphemes(token.text) + .forEach(({ text: graphemeText }) => { + graphemePopularity.set( + graphemeText, + graphemePopularity.get(graphemeText) + 1, + ); + }); + }); + // Iterate through tokens in order of decreasing rank, assigning each one a // hat return rankedTokens - .map(({ token, rank: tokenRank }) => { + .map(({ token }) => { /** * All hats for the graphemes in this token that weren't taken by a * higher ranked token @@ -101,11 +112,21 @@ export function allocateHats( const chosenHat = chooseTokenHat( context, hatStability, - tokenRank, tokenOldHatMap.get(token), + graphemePopularity, tokenRemainingHatCandidates, ); + // Remove the token from the grapheme popularity contest. + tokenGraphemeSplitter + .getTokenGraphemes(token.text) + .forEach(({ text: graphemeText }) => { + graphemePopularity.set( + graphemeText, + graphemePopularity.get(graphemeText) - 1, + ); + }); + // If there are no hats left for the graphemes in this token, the token // will get no hat if (chosenHat == null) { diff --git a/packages/cursorless-engine/src/util/allocateHats/chooseTokenHat.ts b/packages/cursorless-engine/src/util/allocateHats/chooseTokenHat.ts index c1e583677c..9d88b6b48a 100644 --- a/packages/cursorless-engine/src/util/allocateHats/chooseTokenHat.ts +++ b/packages/cursorless-engine/src/util/allocateHats/chooseTokenHat.ts @@ -1,10 +1,10 @@ -import { HatStability, TokenHat } from "@cursorless/common"; +import { DefaultMap, HatStability, TokenHat } from "@cursorless/common"; import { HatCandidate } from "./allocateHats"; import { RankingContext } from "./getHatRankingContext"; import { hatOldTokenRank, isOldTokenHat, - minimumTokenRankContainingGrapheme, + leastPopularGrapheme, negativePenalty, penaltyEquivalenceClass, } from "./HatMetrics"; @@ -48,10 +48,10 @@ import { maxByFirstDiffering } from "./maxByFirstDiffering"; * @returns The chosen hat, or `undefined` if {@link candidates} was empty */ export function chooseTokenHat( - { hatOldTokenRanks, graphemeTokenRanks }: RankingContext, + { hatOldTokenRanks }: RankingContext, hatStability: HatStability, - tokenRank: number, oldTokenHat: TokenHat | undefined, + graphemePopularity: DefaultMap, candidates: HatCandidate[], ): HatCandidate | undefined { // We narrow down the candidates by a series of criteria until there is only @@ -71,8 +71,7 @@ export function chooseTokenHat( // 4. Narrow to the hats with the lowest penalty negativePenalty, - // 5. Prefer hats that sit on a grapheme that doesn't appear in any highly - // ranked token - minimumTokenRankContainingGrapheme(tokenRank, graphemeTokenRanks), + // 5. Avoid popular graphemes + leastPopularGrapheme(graphemePopularity), ])!; } diff --git a/packages/cursorless-engine/src/util/allocateHats/getHatRankingContext.ts b/packages/cursorless-engine/src/util/allocateHats/getHatRankingContext.ts index 8a5aa42d4c..0a460a5686 100644 --- a/packages/cursorless-engine/src/util/allocateHats/getHatRankingContext.ts +++ b/packages/cursorless-engine/src/util/allocateHats/getHatRankingContext.ts @@ -4,7 +4,6 @@ import { Token, TokenHat, } from "@cursorless/common"; -import { TokenGraphemeSplitter } from "../../tokenGraphemeSplitter"; import { RankedToken } from "./getRankedTokens"; export interface RankingContext { @@ -19,53 +18,25 @@ export interface RankingContext { }, number >; - - /** - * Maps from a grapheme to the list of ranks of the tokens in which the - * given grapheme appears. - */ - graphemeTokenRanks: { - [key: string]: number[]; - }; } export function getHatRankingContext( tokens: RankedToken[], oldTokenHatMap: CompositeKeyMap, - tokenGraphemeSplitter: TokenGraphemeSplitter, ): RankingContext { - const graphemeTokenRanks: { - [key: string]: number[]; - } = {}; - const hatOldTokenRanks = new CompositeKeyMap< { grapheme: string; hatStyle: HatStyleName }, number >(({ grapheme, hatStyle }) => [grapheme, hatStyle]); - tokens.forEach(({ token, rank }) => { + tokens.forEach(({ token }, index) => { const existingTokenHat = oldTokenHatMap.get(token); if (existingTokenHat != null) { - hatOldTokenRanks.set(existingTokenHat, rank); + hatOldTokenRanks.set(existingTokenHat, -index); } - tokenGraphemeSplitter - .getTokenGraphemes(token.text) - .forEach(({ text: graphemeText }) => { - let tokenRanksForGrapheme: number[]; - - if (graphemeText in graphemeTokenRanks) { - tokenRanksForGrapheme = graphemeTokenRanks[graphemeText]; - } else { - tokenRanksForGrapheme = []; - graphemeTokenRanks[graphemeText] = tokenRanksForGrapheme; - } - - tokenRanksForGrapheme.push(rank); - }); }); return { hatOldTokenRanks, - graphemeTokenRanks, }; } diff --git a/packages/cursorless-engine/src/util/allocateHats/getRankedTokens.ts b/packages/cursorless-engine/src/util/allocateHats/getRankedTokens.ts index 3649502aa6..5ddbf33679 100644 --- a/packages/cursorless-engine/src/util/allocateHats/getRankedTokens.ts +++ b/packages/cursorless-engine/src/util/allocateHats/getRankedTokens.ts @@ -44,7 +44,9 @@ export function getRankedTokens( ), ); - return tokens.map((token, index) => ({ token, rank: -index })); + return tokens.map((token, index) => ({ + token, + })); }); } @@ -67,11 +69,4 @@ function getRankedEditors( export interface RankedToken { token: Token; - - /** - * A number indicating how likely the token is to be used. Tokens closer to - * the cursor will be considered more likely to be used, and will receive a - * higher rank, causing them to be assigned better hats. - */ - rank: number; }