diff --git a/scripts/IndexParser/parse_common.py b/scripts/IndexParser/parse_common.py index eb0b4f6b9..50d047d4f 100644 --- a/scripts/IndexParser/parse_common.py +++ b/scripts/IndexParser/parse_common.py @@ -1,5 +1,6 @@ import array import struct +import numpy as np """Constants""" SECTOR_LEN = 4096 @@ -19,7 +20,6 @@ def get_data_type_code_and_size(data_type): data_type_size = 1 else: raise Exception("Unsupported data type. Supported data types are float, int8 and uint8") - return data_type_code, data_type_size class Node: @@ -81,11 +81,47 @@ def load_data_only_from_opened_file(self, file, num_rows, num_cols, file_offset_ self.data = array.array(self.data_format_specifier) self.data.fromfile(file, self.num_rows*self.num_cols) + def remove_rows(self, rows_to_remove): + for row in rows_to_remove: + if row < 0 or row >= self.num_rows: + raise Exception("Invalid row index: " + str(row)) + + new_data = array.array(self.data_format_specifier) + + # index = 0 + # rows_to_remove.sort() + # while index < len(rows_to_remove): + # if rows_to_remove[index+1] - rows_to_remove[index] == 1: + # index += 1 + # else: + # #have to add everything after row + # new_data.extend(self.data[(rows_to_remove[index]+1)*self.num_cols:(rows_to_remove[index+1])*self.num_cols]) + + for i in range(self.num_rows): + if i not in rows_to_remove: + new_data.extend(self.data[i*self.num_cols:(i+1)*self.num_cols]) + + self.data = new_data + self.num_rows -= len(rows_to_remove) + + def save_bin(self, file_name): + with open(file_name, "wb") as file: + self.save_bin_to_opened_file(file) + + def save_bin_to_opened_file(self, file): + file.write(struct.pack('I', self.num_rows)) + file.write(struct.pack('I', self.num_cols)) + self.data.tofile(file) + + def __len__(self): return self.num_rows + def get_vector(self, row): + return np.array(self.data[row*self.num_cols:(row+1)*self.num_cols]) def __getitem__(self, key): return self.data[key*self.num_cols:(key+1)*self.num_cols] +