Skip to content

Commit

Permalink
Added remove_rows() and save functionality in parse_common.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Oct 10, 2024
1 parent cdcd805 commit 8211f9e
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion scripts/IndexParser/parse_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import array
import struct
import numpy as np

"""Constants"""
SECTOR_LEN = 4096
Expand All @@ -19,7 +20,6 @@ def get_data_type_code_and_size(data_type):
data_type_size = 1
else:
raise Exception("Unsupported data type. Supported data types are float, int8 and uint8")

return data_type_code, data_type_size

class Node:
Expand Down Expand Up @@ -81,11 +81,47 @@ def load_data_only_from_opened_file(self, file, num_rows, num_cols, file_offset_
self.data = array.array(self.data_format_specifier)
self.data.fromfile(file, self.num_rows*self.num_cols)

def remove_rows(self, rows_to_remove):
for row in rows_to_remove:
if row < 0 or row >= self.num_rows:
raise Exception("Invalid row index: " + str(row))

new_data = array.array(self.data_format_specifier)

# index = 0
# rows_to_remove.sort()
# while index < len(rows_to_remove):
# if rows_to_remove[index+1] - rows_to_remove[index] == 1:
# index += 1
# else:
# #have to add everything after row
# new_data.extend(self.data[(rows_to_remove[index]+1)*self.num_cols:(rows_to_remove[index+1])*self.num_cols])

for i in range(self.num_rows):
if i not in rows_to_remove:
new_data.extend(self.data[i*self.num_cols:(i+1)*self.num_cols])

self.data = new_data
self.num_rows -= len(rows_to_remove)

def save_bin(self, file_name):
with open(file_name, "wb") as file:
self.save_bin_to_opened_file(file)

def save_bin_to_opened_file(self, file):
file.write(struct.pack('I', self.num_rows))
file.write(struct.pack('I', self.num_cols))
self.data.tofile(file)


def __len__(self):
return self.num_rows

def get_vector(self, row):
return np.array(self.data[row*self.num_cols:(row+1)*self.num_cols])

def __getitem__(self, key):
return self.data[key*self.num_cols:(key+1)*self.num_cols]



0 comments on commit 8211f9e

Please sign in to comment.