Scripts to remove selected vectors, verify GT, and some file conversions for the Bing filter index
Showing 9 changed files with 872 additions and 0 deletions.
@@ -0,0 +1,149 @@
import sys
import re
import argparse
from sklearn.metrics import pairwise_distances
import parse_common as pc

# This is a very lame implementation of GT computation, so that we avoid errors
# and can compare against the GTs generated by more sophisticated methods.
# This code assumes that all the data fits in memory.

K = 100
STR_NO_FILTER = "__NONE__"
# Maps the --dist_fn argument to the metric name passed to sklearn's pairwise_distances.
# Note: "linear" is a kernel name rather than a pairwise_distances metric, so the "mips"
# entry will likely need special handling before it can be used here.
dist_fn_sklearn_metric_map = {"l2": "euclidean",
                              "cosine": "cosine",
                              "mips": "linear"}

ground_truth = {}

class FilteredVectors:
    def __init__(self, filter):
        self.data = []
        self.filter = filter
        self.orig_id_fil_id_map = {}
        self.count = 0
        self.data_mat = None

    def add_vector(self, vector, orig_id):
        self.data.append(vector)
        self.orig_id_fil_id_map[self.count] = orig_id
        self.count += 1

    def assign_data_mat(self, data_mat):
        self.data_mat = data_mat

    def __len__(self):
        return len(self.data) if self.data_mat is None else self.data_mat.num_rows

    def __getitem__(self, index):
        return self.data[index]

    def __str__(self) -> str:
        return self.filter + " filters out " + str(self.count) + " vectors"

def parse_filter_line(line, delim_regex, line_number, filter_file_name):
    line = line.strip()
    filters = re.split(delim_regex, line)
    if len(filters) == 0:
        raise Exception(f"Filter line: {line} at line number: {line_number} in file {filter_file_name} does not have any filters")
    return filters


def process_filters(filter_file, data_mat, is_query):
    filters = []
    filtered_vectors = {}
    if filter_file is not None:
        with open(filter_file, mode='r', encoding='UTF-8') as f:
            line_num = 0
            for line in f:
                filters_of_point = parse_filter_line(line, ',', line_num + 1, filter_file)
                filters.append(filters_of_point)
                for filter in filters_of_point:
                    if filter not in filtered_vectors:
                        filtered_vectors[filter] = FilteredVectors(filter)
                    filtered_vectors[filter].add_vector(data_mat.get_vector(line_num), line_num)
                line_num += 1
    else:
        if not is_query:
            filtered_vectors[STR_NO_FILTER] = FilteredVectors(STR_NO_FILTER)
            filtered_vectors[STR_NO_FILTER].assign_data_mat(data_mat)
        else:
            # To simplify the code, we copy the queries one-by-one into the data list.
            filtered_vectors[STR_NO_FILTER] = FilteredVectors(STR_NO_FILTER)
            for i in range(data_mat.num_rows):
                filtered_vectors[STR_NO_FILTER].add_vector(data_mat.get_vector(i), i)

    for filter in filtered_vectors:
        print(filtered_vectors[filter])

    unique_vector_ids = set()
    for filtered_vector in filtered_vectors.values():
        unique_vector_ids.update(filtered_vector.orig_id_fil_id_map.values())
    all_ids = set(range(data_mat.num_rows))
    if len(all_ids.difference(unique_vector_ids)) > 0:
        raise Exception(f"Missing vectors in filters: {all_ids.difference(unique_vector_ids)}")

    return filters, filtered_vectors

def compute_filtered_gt(base_filtered_vector, query_filtered_vector, dist_fn):
    print(f"Computing GT for filter: {query_filtered_vector.filter}, base count: {len(base_filtered_vector)}, query count: {len(query_filtered_vector)}")
    for fil_q_id, query_vector in enumerate(query_filtered_vector.data):
        qv = query_vector.reshape(1, -1)
        if base_filtered_vector.data_mat is not None:
            dist = pairwise_distances(base_filtered_vector.data_mat.data, qv, metric=dist_fn_sklearn_metric_map[dist_fn])
        else:
            dist = pairwise_distances(base_filtered_vector.data, qv, metric=dist_fn_sklearn_metric_map[dist_fn])

        index_dist_pairs = [(i, dist[i][0]) for i in range(len(dist))]
        index_dist_pairs.sort(key=lambda x: x[1])
        k = min(K, len(index_dist_pairs))
        top_k_matches = index_dist_pairs[:k]
        orig_query_id = query_filtered_vector.orig_id_fil_id_map[fil_q_id]
        ground_truth[orig_query_id] = []
        for match in top_k_matches:
            ground_truth[orig_query_id].append((base_filtered_vector.orig_id_fil_id_map[match[0]], match[1]))


def compute_gt(base_filtered_vectors, query_filtered_vectors, dist_fn):
    for query_filter in query_filtered_vectors.keys():
        if query_filter not in base_filtered_vectors:
            print(f"Filter: {query_filter} in query does not exist in base")
            continue
        base_filtered_vector = base_filtered_vectors[query_filter]
        query_filtered_vector = query_filtered_vectors[query_filter]
        compute_filtered_gt(base_filtered_vector, query_filtered_vector, dist_fn)

    print(ground_truth)

def main(args):
    data_type_code, data_type_size = pc.get_data_type_code(args.data_type)
    base_data = pc.DataMat(data_type_code, data_type_size)
    base_data.load_bin(args.base_file)

    query_data = pc.DataMat(data_type_code, data_type_size)
    query_data.load_bin(args.query_file)

    print("Grouping base vectors by filters\n")
    base_filters, base_filtered_vectors = process_filters(args.base_filter_file, base_data, is_query=False)
    print("Grouping query vectors by filters\n")
    query_filters, query_filtered_vectors = process_filters(args.query_filter_file, query_data, is_query=True)

    compute_gt(base_filtered_vectors, query_filtered_vectors, args.dist_fn)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate GTs for a given base and query file and compare with any existing GT file. Assumes that there is a ton of memory to process the data!!!', prog='GTVerifier.py')
    parser.add_argument('--base_file', type=str, help='Base file', required=True)
    parser.add_argument('--query_file', type=str, help='Query file', required=True)
    parser.add_argument('--base_filter_file', type=str, help='Base filter file', required=False)
    parser.add_argument('--query_filter_file', type=str, help='Query filter file', required=False)
    parser.add_argument('--output_gt_file', type=str, help='Output GT file', required=True)
    parser.add_argument('--existing_gt_file', type=str, help='Existing GT file', required=False)
    parser.add_argument('--dist_fn', type=str, help='Distance function (l2, cosine, or mips)', required=True)
    parser.add_argument('--data_type', type=str, help='Data type of the vectors (float, int8, or uint8)', required=True)
    args = parser.parse_args()

    main(args)
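The script above brute-forces the ground truth per filter with sklearn's pairwise_distances. As a sanity check of the same idea in isolation, here is a minimal sketch (not part of this commit; the array shapes, K, and variable names are made up for illustration):

import numpy as np
from sklearn.metrics import pairwise_distances

# Toy data: 1000 base vectors and 5 queries in 16 dimensions (illustrative shapes only).
base = np.random.rand(1000, 16).astype(np.float32)
queries = np.random.rand(5, 16).astype(np.float32)
K = 10

# One distance matrix of shape (num_queries, num_base), then keep the K closest base ids per row.
dist = pairwise_distances(queries, base, metric="euclidean")
gt_ids = np.argsort(dist, axis=1)[:, :K]
gt_dists = np.take_along_axis(dist, gt_ids, axis=1)
print(gt_ids.shape, gt_dists.shape)  # (5, 10) (5, 10)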
@@ -0,0 +1,39 @@
import sys
import parse_common as pc

def main(input_vector_file, data_type, ids_to_remove_file, output_file_prefix, filter_file_for_vectors):
    data_type_code, data_type_size = pc.get_data_type_code(data_type)
    vectors = pc.DataMat(data_type_code, data_type_size)
    vectors.load_bin(input_vector_file)
    vector_ids_to_remove = set()
    with open(ids_to_remove_file, "r") as f:
        for line in f:
            vector_ids_to_remove.add(int(line.strip()))

    filters = []
    if filter_file_for_vectors is not None:
        with open(filter_file_for_vectors, "r") as f:
            for line in f:
                filters.append(line.strip())

    vectors.remove_rows(vector_ids_to_remove)

    output_bin_file = output_file_prefix + "_vecs.bin"
    vectors.save_bin(output_bin_file)
    print(f"Removed {len(vector_ids_to_remove)} vectors. Output written to {output_bin_file}")

    if len(filters) > 0:
        output_filters_file = output_file_prefix + "_filters.txt"
        output_filters = [filter for idx, filter in enumerate(filters) if idx not in vector_ids_to_remove]
        with open(output_filters_file, "w") as f:
            for output_filter in output_filters:
                f.write(output_filter + "\n")
        print(f"Removed {len(vector_ids_to_remove)} filters. Output written to {output_filters_file}")

if __name__ == "__main__":
    if len(sys.argv) != 5 and len(sys.argv) != 6:
        print("Usage: <program> <input_vector_file> <data_type_format(float|uint8|int8)> <ids_to_remove_file> <output_file_prefix (program adds _vecs.bin for the vector file and _filters.txt for the filter file)> [<filter_file_for_vectors>]")
        sys.exit(1)
    else:
        main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5] if len(sys.argv) > 5 else None)
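The heavy lifting here is in pc.DataMat.remove_rows, which is defined in parse_common and not shown in this diff. For intuition only, a minimal sketch of the same row-removal idea on a plain numpy array, with made-up data:

import numpy as np

# Toy matrix of 5 vectors with 4 dimensions each, and two row ids to drop (illustrative only).
vectors = np.arange(20, dtype=np.float32).reshape(5, 4)
ids_to_remove = {1, 3}

# Keep every row whose index is not in the removal set; the order of surviving rows is preserved.
keep_mask = np.array([i not in ids_to_remove for i in range(vectors.shape[0])])
filtered = vectors[keep_mask]
print(filtered.shape)  # (3, 4)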
@@ -0,0 +1,189 @@
import sys
import pandas as pd

base_unique_filters = set()
query_unique_filters = set()

base_joined_filters_inv_indx = {}
query_joined_filters_inv_indx = {}

joined_filters_of_point = []
joined_filters_of_query_point = []

def create_joined_filters(per_category_label_lists):
    assert(len(per_category_label_lists) == 3)
    joined_filters = []
    for l1 in per_category_label_lists[0]:
        for l2 in per_category_label_lists[1]:
            for l3 in per_category_label_lists[2]:
                joined_filters.append(f"{l1}_{l2}_{l3}")
    return joined_filters

def parse_filter_line(line, separator, line_number, base_filter_file):
    line = line.strip()
    if line == '':
        print(f"Empty line at line number: {line_number} in {base_filter_file}")
    parts = line.split(separator)
    if len(parts) != 3:
        print(f"line: {line} at line number: {line_number} in {base_filter_file} does not have 3 parts when split by {separator}")

    return parts


def append_category_name_to_labels(category_id, labels):
    category_name = f"C{category_id+1}"
    named_labels = [f"{category_name}={part}" for part in labels]
    return named_labels

def load_base_file_filters(base_filter_file):
    with open(base_filter_file, mode='r', encoding='UTF-8') as f:
        count = 0
        for line in f:
            count += 1
            cs_filters_per_category = parse_filter_line(line, '|', count, base_filter_file)
            per_category_label_lists = []
            for i in range(3):
                cat_labels = append_category_name_to_labels(i, cs_filters_per_category[i].split(','))
                assert(len(cat_labels) > 0)
                base_unique_filters.update(cat_labels)
                per_category_label_lists.append(cat_labels)

            joined_filters = create_joined_filters(per_category_label_lists)
            joined_filters_of_point.append([])
            for joined_filter in joined_filters:
                joined_filters_of_point[count-1].append(joined_filter)
                if joined_filter not in base_joined_filters_inv_indx:
                    base_joined_filters_inv_indx[joined_filter] = []
                base_joined_filters_inv_indx[joined_filter].append(count)

            if count % 500000 == 0:
                print(f"Processed {count} lines in {base_filter_file}")

    print(f"Obtained {len(base_unique_filters)} distinct filters from {base_filter_file}, line count: {count}")
    print(f"After joining number of filters is: {len(base_joined_filters_inv_indx)}")

def load_query_file_filters(query_filter_file):
    with open(query_filter_file, mode='r', encoding='UTF-8') as f:
        count = 0
        for line in f:
            count += 1
            cs_filters_per_category = parse_filter_line(line, '|', count, query_filter_file)
            per_category_label_lists = []
            for i in range(3):
                cat_labels = append_category_name_to_labels(i, cs_filters_per_category[i].split(','))
                assert(len(cat_labels) > 0)
                query_unique_filters.update(cat_labels)
                per_category_label_lists.append(cat_labels)

            joined_filters = create_joined_filters(per_category_label_lists)
            joined_filters_of_query_point.append([])
            for joined_filter in joined_filters:
                joined_filters_of_query_point[count-1].append(joined_filter)
                if joined_filter not in query_joined_filters_inv_indx:
                    query_joined_filters_inv_indx[joined_filter] = []
                query_joined_filters_inv_indx[joined_filter].append(count)

    print(f"Obtained {len(query_unique_filters)} distinct filters from {query_filter_file}, line count = {count}")
    print(f"After joining number of filters is: {len(query_joined_filters_inv_indx)}")

def analyze():
    missing_query_filters = query_unique_filters.difference(base_unique_filters)
    if len(missing_query_filters) > 0:
        print(f"Warning: found the following query filters not in base:{missing_query_filters}")

    bujf = set()
    bujf.update(base_joined_filters_inv_indx.keys())
    qujf = set()
    qujf.update(query_joined_filters_inv_indx.keys())
    missing_joined_filters = qujf.difference(bujf)
    if len(missing_joined_filters) > 0:
        print(f"Warning: found the following joined query filters not in base:{missing_joined_filters}")

    mjqf_query_ids_map = {}
    for filter in missing_joined_filters:
        mjqf_query_ids_map[filter] = []
        for index, filters_of_point in enumerate(joined_filters_of_query_point):
            if filter == filters_of_point[0]:
                mjqf_query_ids_map[filter].append(index)

    with open('missing_joined_query_filters.txt', mode='w', encoding='UTF-8') as f:
        for filter in mjqf_query_ids_map:
            f.write(f"{filter}\t{len(mjqf_query_ids_map[filter])}\t{mjqf_query_ids_map[filter]}\n")

    print(f"Number of unique base filters: {len(base_unique_filters)}")
    print(f"Number of unique query filters: {len(query_unique_filters)}")
    print(f"Number of joined base filters: {len(base_joined_filters_inv_indx)}")
    print(f"Number of joined query filters: {len(query_joined_filters_inv_indx)}")

def write_joined_filters(output_file_prefix):
    base_joined_filters_file = output_file_prefix + '_base_joined_filters.txt'

    with open(base_joined_filters_file, mode='w', encoding='UTF-8') as f:
        for filters_of_point in joined_filters_of_point:
            joined_str = ','.join([x for x in filters_of_point])
            f.write(f"{joined_str}\n")
    print(f"Base joined filters written to {base_joined_filters_file}")

    query_joined_filters_file = output_file_prefix + '_query_joined_filters.txt'
    with open(query_joined_filters_file, mode='w', encoding='UTF-8') as f:
        for filters_of_point in joined_filters_of_query_point:
            joined_str = ','.join([x for x in filters_of_point])
            f.write(f"{joined_str}\n")
    print(f"Query joined filters written to {query_joined_filters_file}")

    base_unique_filters_file = output_file_prefix + '_base_unique_filters.txt'
    with open(base_unique_filters_file, mode='w', encoding='UTF-8') as f:
        sorted_list = sorted(base_unique_filters)
        for filter in sorted_list:
            f.write(f"{filter}\n")
    print(f"Base unique filters written to {base_unique_filters_file}")

    query_unique_filters_file = output_file_prefix + '_query_unique_filters.txt'
    with open(query_unique_filters_file, mode='w', encoding='UTF-8') as f:
        sorted_list = sorted(query_unique_filters)
        for filter in sorted_list:
            f.write(f"{filter}\n")
    print(f"Query unique filters written to {query_unique_filters_file}")

    base_joined_unique_filters_file = output_file_prefix + '_base_joined_unique_filters.txt'
    with open(base_joined_unique_filters_file, mode='w', encoding='UTF-8') as f:
        sorted_list = sorted(base_joined_filters_inv_indx.keys())
        for filter in sorted_list:
            f.write(f"{filter}\t{len(base_joined_filters_inv_indx[filter])}\n")
    print(f"Base joined unique filters written to {base_joined_unique_filters_file}")

    query_unique_joined_filters_file = output_file_prefix + '_query_joined_unique_filters.txt'
    with open(query_unique_joined_filters_file, mode='w', encoding='UTF-8') as f:
        sorted_list = sorted(query_joined_filters_inv_indx.keys())
        for filter in sorted_list:
            f.write(f"{filter}\t{len(query_joined_filters_inv_indx[filter])}\n")
    print(f"Query joined unique filters written to {query_unique_joined_filters_file}")

def main(base_filter_file, query_filter_file, output_path_prefix):
    load_base_file_filters(base_filter_file)
    load_query_file_filters(query_filter_file)
    analyze()
    write_joined_filters(output_path_prefix)

if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Usage: AdsMultiFilterAnalyzer.py <base_filter_file> <query_filter_file> <output_file_prefix>")
        print("The base file should have label categories separated by | and the labels within a category separated by commas")
        print("The query file should have labels separated by |")
        sys.exit(1)
    else:
        main(sys.argv[1], sys.argv[2], sys.argv[3])
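create_joined_filters above takes the cross product of the three per-category label lists. A minimal standalone illustration of that join with made-up labels (not part of this commit):

from itertools import product

# Toy per-category labels (illustrative only), already prefixed the way append_category_name_to_labels would do it.
per_category = [["C1=a", "C1=b"], ["C2=x"], ["C3=1", "C3=2"]]

# itertools.product yields one tuple per combination, matching the nested loops in create_joined_filters.
joined = ["_".join(combo) for combo in product(*per_category)]
print(joined)
# ['C1=a_C2=x_C3=1', 'C1=a_C2=x_C3=2', 'C1=b_C2=x_C3=1', 'C1=b_C2=x_C3=2']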