Skip to content

Commit

Permalink
refactor(scoring): improve context relevancy scoring
Browse files Browse the repository at this point in the history
Remove MIN_RESULTS constant and use MULTI_FILE_THRESHOLD consistently for
minimum results filtering. Improve scoring by adding base score to symbol
matches and simplify spatial distance calculation. Add score bonus for
current file context to improve relevancy of local results.

These changes make the context selection more accurate by:
- Using consistent thresholds
- Better handling of symbol matching scores
- Prioritizing local context appropriately
  • Loading branch information
deathbeam committed Feb 19, 2025
1 parent 5014b41 commit 0884422
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions lua/CopilotChat/context.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,15 @@ local OFF_SIDE_RULE_LANGUAGES = {
local MIN_SYMBOL_SIMILARITY = 0.3
local MIN_SEMANTIC_SIMILARITY = 0.4
local MULTI_FILE_THRESHOLD = 5
local MIN_RESULTS = 3
local MAX_FILES = 2500

--- Compute the cosine similarity between two vectors
---@param a table<number>
---@param b table<number>
---@param def number
---@return number
local function spatial_distance_cosine(a, b, def)
local function spatial_distance_cosine(a, b)
if not a or not b then
return def or 0
return 0
end

local dot_product = 0
Expand All @@ -103,8 +101,9 @@ end
local function data_ranked_by_relatedness(query, data, min_similarity)
local results = {}
for _, item in ipairs(data) do
local similarity = spatial_distance_cosine(item.embedding, query.embedding, item.score)
table.insert(results, vim.tbl_extend('force', item, { score = similarity }))
local score = spatial_distance_cosine(item.embedding, query.embedding)
score = score or item.score or 0
table.insert(results, vim.tbl_extend('force', item, { score = score }))
end

table.sort(results, function(a, b)
Expand All @@ -114,7 +113,7 @@ local function data_ranked_by_relatedness(query, data, min_similarity)
-- Take top MAX_RESULTS items that meet threshold, or at least MIN_RESULTS items
local filtered = {}
for i, result in ipairs(results) do
if (result.score >= min_similarity) or (i <= MIN_RESULTS) then
if (result.score >= min_similarity) or (i <= MULTI_FILE_THRESHOLD) then
table.insert(filtered, result)
end
end
Expand Down Expand Up @@ -175,7 +174,6 @@ local function data_ranked_by_symbols(query, data, min_similarity)
local max_score = 0

for _, entry in ipairs(data) do
local score = entry.score or 0
local basename = vim.fn.fnamemodify(entry.filename, ':t'):gsub('%..*$', '')

-- Get trigrams for basename and compound version
Expand All @@ -187,7 +185,7 @@ local function data_ranked_by_symbols(query, data, min_similarity)
local compound_sim = trigram_similarity(query_trigrams, compound_trigrams)

-- Take best match
score = math.max(name_sim, compound_sim)
local score = (entry.score or 0) + math.max(name_sim, compound_sim)

-- Add symbol matches
if entry.symbols then
Expand Down Expand Up @@ -221,7 +219,7 @@ local function data_ranked_by_symbols(query, data, min_similarity)
-- Filter results while preserving top scores
local filtered_results = {}
for i, result in ipairs(results) do
if (result.score >= min_similarity) or (i <= MIN_RESULTS) then
if (result.score >= min_similarity) or (i <= MULTI_FILE_THRESHOLD) then
table.insert(filtered_results, result)
end
end
Expand Down Expand Up @@ -408,6 +406,7 @@ function M.files(winnr, with_content)
content = table.concat(chunk, '\n'),
filename = chunk_name,
filetype = 'text',
score = 0.2, -- Score bonus
})
end

Expand Down

0 comments on commit 0884422

Please sign in to comment.