
Commit

Scripts to remove selected vectors, verify gt, and some file conversions for Bing filter index
gopal-msr committed Oct 7, 2024
1 parent bae427e commit 2ede651
Showing 9 changed files with 872 additions and 0 deletions.
149 changes: 149 additions & 0 deletions scripts/datamanip/GTVerifier.py
@@ -0,0 +1,149 @@
import sys
import re
import argparse
from sklearn.metrics import pairwise_distances
import parse_common as pc


# This is a deliberately simple brute-force implementation of GT computation, so that we avoid
# errors and can compare against the GTs generated by more sophisticated methods.
# This code assumes that all the data fits in memory.

K = 100
STR_NO_FILTER = "__NONE__"
dist_fn_sklearn_metric_map = {"l2" : "euclidean",
"cosine" : "cosine",
"mips" : "linear"}


ground_truth = {}

class FilteredVectors:
def __init__(self, filter):
self.data = []
self.filter = filter
self.orig_id_fil_id_map = {}
self.count = 0
self.data_mat = None

def add_vector(self, vector, orig_id):
self.data.append(vector)
self.orig_id_fil_id_map[self.count] = orig_id
self.count += 1

def assign_data_mat(self, data_mat):
self.data_mat = data_mat

def __len__(self):
return len(self.data) if self.data_mat is None else self.data_mat.num_rows

def __getitem__(self, index):
return self.data[index]

def __str__(self) -> str:
        return f"{self.filter} has {len(self)} vectors"


def parse_filter_line(line, delim_regex, line_number, filter_file_name):
line = line.strip()
    filters = [f for f in re.split(delim_regex, line) if f != '']
if len(filters) == 0:
raise Exception(f"Filter line: {line} at line number: {line_number} in file {filter_file_name} does not have any filters")
return filters

def process_filters(filter_file, data_mat, is_query):
filters = []
filtered_vectors = {}
if filter_file is not None:
with open(filter_file, mode='r', encoding='UTF-8') as f:
line_num = 0
for line in f:
filters_of_point = parse_filter_line(line, ',', line_num+1, filter_file)
filters.append(filters_of_point)
for filter in filters_of_point:
if filter not in filtered_vectors:
filtered_vectors[filter] = FilteredVectors(filter)
filtered_vectors[filter].add_vector(data_mat.get_vector(line_num), line_num)
line_num += 1
else:
if not is_query:
filtered_vectors[STR_NO_FILTER] = FilteredVectors(STR_NO_FILTER)
filtered_vectors[STR_NO_FILTER].assign_data_mat(data_mat)
else:
#to simplify the code, we copy the queries one-by-one into the data list.
filtered_vectors[STR_NO_FILTER] = FilteredVectors(STR_NO_FILTER)
for i in range(data_mat.num_rows):
filtered_vectors[STR_NO_FILTER].add_vector(data_mat.get_vector(i), i)

for filter in filtered_vectors:
print(filtered_vectors[filter])

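    # Sanity check: every row in the data must belong to at least one filter group,
    # otherwise it would be silently excluded from the ground truth.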
unique_vector_ids = set()
for filtered_vector in filtered_vectors.values():
unique_vector_ids.update(filtered_vector.orig_id_fil_id_map.values())
all_ids = set(range(data_mat.num_rows))
if len(all_ids.difference(unique_vector_ids)) > 0:
raise Exception(f"Missing vectors in filters: {all_ids.difference(unique_vector_ids)}")

return filters, filtered_vectors

def compute_filtered_gt(base_filtered_vector, query_filtered_vector, dist_fn):
print(f"Computing GT for filter: {query_filtered_vector.filter}, base count: {len(base_filtered_vector)}, query count: {len(query_filtered_vector)}")
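    # For each query in this filter group, compute distances to every base vector in the
    # group and keep the K nearest (original base id, distance) pairs.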
for fil_q_id, query_vector in enumerate(query_filtered_vector.data):
qv = query_vector.reshape(1, -1)
if base_filtered_vector.data_mat is not None:
dist = pairwise_distances(base_filtered_vector.data_mat.data, qv, metric=dist_fn_sklearn_metric_map[dist_fn])
else:
dist = pairwise_distances(base_filtered_vector.data, qv, metric=dist_fn_sklearn_metric_map[dist_fn])

index_dist_pairs = [(i, dist[i][0]) for i in range(len(dist))]
index_dist_pairs.sort(key=lambda x: x[1])
k = min(K, len(index_dist_pairs))
top_k_matches = index_dist_pairs[:k]
orig_query_id = query_filtered_vector.orig_id_fil_id_map[fil_q_id]
ground_truth[orig_query_id] = []
for match in top_k_matches:
ground_truth[orig_query_id].append((base_filtered_vector.orig_id_fil_id_map[match[0]], match[1]))

def compute_gt(base_filtered_vectors, query_filtered_vectors, dist_fn):
for query_filter in query_filtered_vectors.keys():
if query_filter not in base_filtered_vectors:
print(f"Filter: {query_filter} in query does not exist in base")
continue
base_filtered_vector = base_filtered_vectors[query_filter]
query_filtered_vector = query_filtered_vectors[query_filter]
compute_filtered_gt(base_filtered_vector, query_filtered_vector, dist_fn)

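    # For now the ground truth is only printed; the --output_gt_file and --existing_gt_file
    # arguments are not consumed here.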
print(ground_truth)

def main(args):
data_type_code, data_type_size = pc.get_data_type_code(args.data_type)
base_data = pc.DataMat(data_type_code, data_type_size)
base_data.load_bin(args.base_file)

query_data = pc.DataMat(data_type_code, data_type_size)
query_data.load_bin(args.query_file)

print("Grouping base vectors by filters\n")
base_filters, base_filtered_vectors = process_filters(args.base_filter_file, base_data, is_query=False)
print("Grouping query vectors by filters\n")
query_filters, query_filtered_vectors = process_filters(args.query_filter_file, query_data, is_query=True)

compute_gt(base_filtered_vectors, query_filtered_vectors, args.dist_fn)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate GTs for a given base and query file and compare with any existing GT file. Assumes that all the data fits in memory.', prog='GTVerifier.py')
parser.add_argument('--base_file', type=str, help='Base file', required=True)
parser.add_argument('--query_file', type=str, help='Query file', required=True)
parser.add_argument('--base_filter_file', type=str, help='Base filter file', required=False)
parser.add_argument('--query_filter_file', type=str, help='Query filter file', required=False)
parser.add_argument('--output_gt_file', type=str, help='Output GT file', required=True)
parser.add_argument('--existing_gt_file', type=str, help='Existing GT file', required=False)
    parser.add_argument('--dist_fn', type=str, help='Distance function (l2, cosine, or mips)', required=True)
    parser.add_argument('--data_type', type=str, help='Data type (float, uint8, or int8)', required=True)
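    # Illustrative invocation (file names here are hypothetical):
    #   python GTVerifier.py --base_file base.bin --query_file queries.bin \
    #     --base_filter_file base_filters.txt --query_filter_file query_filters.txt \
    #     --output_gt_file gt.bin --dist_fn l2 --data_type float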
args = parser.parse_args()


main(args)


39 changes: 39 additions & 0 deletions scripts/datamanip/RemoveSelectedVectors.py
@@ -0,0 +1,39 @@
import sys
import parse_common as pc

def main(input_vector_file, data_type, ids_to_remove_file, output_file_prefix, filter_file_for_vectors):
data_type_code, data_type_size = pc.get_data_type_code(data_type)
vectors = pc.DataMat(data_type_code, data_type_size)
vectors.load_bin(input_vector_file)
vector_ids_to_remove = set()
with open(ids_to_remove_file, "r") as f:
for line in f:
vector_ids_to_remove.add(int(line.strip()))

filters = []
if filter_file_for_vectors is not None:
with open(filter_file_for_vectors, "r") as f:
for line in f:
filters.append(line.strip())

vectors.remove_rows(vector_ids_to_remove)

output_bin_file = output_file_prefix + "_vecs.bin"
vectors.save_bin(output_bin_file)
print(f"Removed {len(vector_ids_to_remove)} vectors. Output written to {output_bin_file}")

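    # The filter file (if given) is assumed to be line-aligned with the vector file, so the
    # same row ids are dropped from both.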
if len(filters) > 0:
output_filters_file = output_file_prefix + "_filters.txt"
output_filters = [filter for idx, filter in enumerate(filters) if idx not in vector_ids_to_remove]
with open(output_filters_file, "w") as f:
for output_filter in output_filters:
f.write(output_filter + "\n")
        print(f"Removed {len(vector_ids_to_remove)} filters. Output written to {output_filters_file}")


if __name__ == "__main__":
if len(sys.argv) != 5 and len(sys.argv) != 6:
print("Usage: <program> <input_vector_file> <data_type_format(float|uint8|int8)> <ids_to_remove_file> <output_file_prefix (program adds _vecs.bin for the vector file and _filters.txt for the filter file)> [<filter_file_for_vectors>]")
sys.exit(1)
else:
main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5] if len(sys.argv) > 5 else None)
189 changes: 189 additions & 0 deletions scripts/datamanip/convert/AdsMultiFilterAnalyzer.py
@@ -0,0 +1,189 @@
import sys
import pandas as pd

base_unique_filters = set()
query_unique_filters = set()

base_joined_filters_inv_indx = {}
query_joined_filters_inv_indx = {}

joined_filters_of_point = []
joined_filters_of_query_point = []


def create_joined_filters(per_category_label_lists):
assert(len(per_category_label_lists) == 3)
joined_filters = []
for l1 in per_category_label_lists[0]:
for l2 in per_category_label_lists[1]:
for l3 in per_category_label_lists[2]:
joined_filters.append(f"{l1}_{l2}_{l3}")
return joined_filters


def parse_filter_line(line, separator, line_number, base_filter_file):
line = line.strip()
if line == '':
print(f"Empty line at line number: {line_number} in {base_filter_file}")
parts = line.split(separator)
if len(parts) != 3:
print(f"line: {line} at line number: {line_number} in {base_filter_file} does not have 3 parts when split by {separator}")

return parts

def append_category_name_to_labels(category_id, labels):
category_name = f"C{category_id+1}"
named_labels = [f"{category_name}={part}" for part in labels]
return named_labels
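# Illustrative example of the two helpers above: the input line "a,b|c|d" becomes the
# per-category label lists ["C1=a", "C1=b"], ["C2=c"], ["C3=d"], which join into the
# filters "C1=a_C2=c_C3=d" and "C1=b_C2=c_C3=d".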



def load_base_file_filters(base_filter_file):
with open(base_filter_file, mode='r', encoding='UTF-8') as f:
count = 0
for line in f:
count += 1
cs_filters_per_category = parse_filter_line(line, '|', count, base_filter_file)
per_category_label_lists = []
for i in range(3):
cat_labels = append_category_name_to_labels(i, cs_filters_per_category[i].split(','))
assert(len(cat_labels) > 0)
base_unique_filters.update(cat_labels)
per_category_label_lists.append(cat_labels)

joined_filters = create_joined_filters(per_category_label_lists)
joined_filters_of_point.append([])
for joined_filter in joined_filters:
joined_filters_of_point[count-1].append(joined_filter)
if joined_filter not in base_joined_filters_inv_indx:
base_joined_filters_inv_indx[joined_filter] = []
base_joined_filters_inv_indx[joined_filter].append(count)

if count % 500000 == 0:
print(f"Processed {count} lines in {base_filter_file}")

    print(f"Obtained {len(base_unique_filters)} distinct filters from {base_filter_file}, line count: {count}")
    print(f"After joining, the number of filters is: {len(base_joined_filters_inv_indx)}")



def load_query_file_filters(query_filter_file):
with open(query_filter_file, mode='r', encoding='UTF-8') as f:
count = 0
for line in f:
count += 1
cs_filters_per_category = parse_filter_line(line, '|', count, query_filter_file)
per_category_label_lists = []
for i in range(3):
cat_labels = append_category_name_to_labels(i, cs_filters_per_category[i].split(','))
assert(len(cat_labels) > 0)
query_unique_filters.update(cat_labels)
per_category_label_lists.append(cat_labels)

joined_filters = create_joined_filters(per_category_label_lists)
joined_filters_of_query_point.append([])
for joined_filter in joined_filters:
joined_filters_of_query_point[count-1].append(joined_filter)
if joined_filter not in query_joined_filters_inv_indx:
query_joined_filters_inv_indx[joined_filter] = []
query_joined_filters_inv_indx[joined_filter].append(count)

    print(f"Obtained {len(query_unique_filters)} distinct filters from {query_filter_file}, line count = {count}")
    print(f"After joining, the number of filters is: {len(query_joined_filters_inv_indx)}")



def analyze():
missing_query_filters = query_unique_filters.difference(base_unique_filters)
if len(missing_query_filters) > 0:
print(f"Warning: found the following query filters not in base:{missing_query_filters}")


bujf = set()
bujf.update(base_joined_filters_inv_indx.keys())
qujf = set()
qujf.update(query_joined_filters_inv_indx.keys())
missing_joined_filters = qujf.difference(bujf)
if len(missing_joined_filters) > 0:
print(f"Warning: found the following joined query filters not in base:{missing_joined_filters}")


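    # Map each missing joined filter to the query ids whose first joined filter matches it;
    # later joined filters of a query point are not checked here.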
mjqf_query_ids_map = {}
for filter in missing_joined_filters:
mjqf_query_ids_map[filter] = []
for index, filters_of_point in enumerate(joined_filters_of_query_point):
if filter == filters_of_point[0]:
mjqf_query_ids_map[filter].append(index)

with open('missing_joined_query_filters.txt', mode='w', encoding='UTF-8') as f:
for filter in mjqf_query_ids_map:
f.write(f"{filter}\t{len(mjqf_query_ids_map[filter])}\t{mjqf_query_ids_map[filter]}\n")


print(f"Number of unique base filters: {len(base_unique_filters)}" )
print(f"Number of unique query filters: {len(query_unique_filters)}" )
print(f"Number of joined base filters: {len(base_joined_filters_inv_indx)}" )
print(f"Number of joined query filters: {len(query_joined_filters_inv_indx)}" )

def write_joined_filters(output_file_prefix):
base_joined_filters_file = output_file_prefix + '_base_joined_filters.txt'

with open(base_joined_filters_file, mode='w', encoding='UTF-8') as f:
for filters_of_point in joined_filters_of_point:
            line_str = ','.join(filters_of_point)
            f.write(f"{line_str}\n")
print(f"Base joined filters written to {base_joined_filters_file}")

query_joined_filters_file = output_file_prefix + '_query_joined_filters.txt'
with open(query_joined_filters_file, mode='w', encoding='UTF-8') as f:
for filters_of_point in joined_filters_of_query_point:
            line_str = ','.join(filters_of_point)
            f.write(f"{line_str}\n")
print(f"Query joined filters written to {query_joined_filters_file}")

base_unique_filters_file = output_file_prefix + '_base_unique_filters.txt'
with open(base_unique_filters_file , mode='w', encoding='UTF-8') as f:
sorted_list = sorted(base_unique_filters)
for filter in sorted_list:
f.write(f"{filter}\n")
print(f"Base unique filters written to {base_unique_filters_file}")

query_unique_filters_file = output_file_prefix + '_query_unique_filters.txt'
with open(query_unique_filters_file, mode='w', encoding='UTF-8') as f:
sorted_list = sorted(query_unique_filters)
for filter in sorted_list:
f.write(f"{filter}\n")
print(f"Query unique filters written to {query_unique_filters_file}")

base_joined_unique_filters_file = output_file_prefix + '_base_joined_unique_filters.txt'
with open(base_joined_unique_filters_file, mode='w', encoding='UTF-8') as f:
sorted_list = sorted(base_joined_filters_inv_indx.keys())
for filter in sorted_list:
f.write(f"{filter}\t{len(base_joined_filters_inv_indx[filter])}\n")
print(f"Base joined unique filters written to {base_joined_unique_filters_file}")

query_unique_joined_filters_file = output_file_prefix + '_query_joined_unique_filters.txt'
with open(query_unique_joined_filters_file, mode='w', encoding='UTF-8') as f:
sorted_list = sorted(query_joined_filters_inv_indx.keys())
for filter in sorted_list:
f.write(f"{filter}\t{len(query_joined_filters_inv_indx[filter])}\n")
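    print(f"Query joined unique filters written to {query_unique_joined_filters_file}")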





def main(base_filter_file, query_filter_file, output_path_prefix):
load_base_file_filters(base_filter_file)
load_query_file_filters(query_filter_file)
analyze()
write_joined_filters(output_path_prefix)


if __name__ == "__main__":
if len(sys.argv) != 4:
print("Usage: AdsMultiFilterAnalyzer.py <base_filter_file> <query_filter_file> <output_file_prefix>")
        print("The base file should have 3 label categories separated by '|', with labels within a category separated by commas")
        print("The query file should have the same format")
sys.exit(1)
else:
main(sys.argv[1], sys.argv[2], sys.argv[3])
