Skip to content

Commit 8211f9e

Browse files
committed
Added remove_rows() and save functionality in parse_common.py
1 parent cdcd805 commit 8211f9e

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

scripts/IndexParser/parse_common.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import array
22
import struct
3+
import numpy as np
34

45
"""Constants"""
56
SECTOR_LEN = 4096
@@ -19,7 +20,6 @@ def get_data_type_code_and_size(data_type):
1920
data_type_size = 1
2021
else:
2122
raise Exception("Unsupported data type. Supported data types are float, int8 and uint8")
22-
2323
return data_type_code, data_type_size
2424

2525
class Node:
@@ -81,11 +81,47 @@ def load_data_only_from_opened_file(self, file, num_rows, num_cols, file_offset_
8181
self.data = array.array(self.data_format_specifier)
8282
self.data.fromfile(file, self.num_rows*self.num_cols)
8383

84+
def remove_rows(self, rows_to_remove):
    """Delete the given rows from the matrix in place.

    Args:
        rows_to_remove: iterable of zero-based row indices. Order does not
            matter and duplicates are counted once.

    Raises:
        Exception: if any index is outside ``[0, num_rows)``. All indices
            are validated before anything is modified, so a failed call
            leaves the object untouched.
    """
    # Collect into a set: gives O(1) membership tests in the copy loop and
    # collapses duplicates so the row count below stays consistent with
    # the data that is actually kept.
    doomed = set()
    for row in rows_to_remove:
        if row < 0 or row >= self.num_rows:
            raise Exception("Invalid row index: " + str(row))
        doomed.add(row)

    width = self.num_cols
    new_data = array.array(self.data_format_specifier)
    for i in range(self.num_rows):
        if i not in doomed:
            # Rows are stored back to back in one flat buffer.
            new_data.extend(self.data[i * width:(i + 1) * width])

    self.data = new_data
    self.num_rows -= len(doomed)
106+
107+
def save_bin(self, file_name):
    """Write this object to ``file_name`` in binary form.

    The file is opened in ``"wb"`` mode (created or truncated) and the
    actual serialization is delegated to ``save_bin_to_opened_file``.
    """
    with open(file_name, "wb") as out_file:
        self.save_bin_to_opened_file(out_file)
110+
111+
def save_bin_to_opened_file(self, file):
    """Serialize the matrix to an already opened binary file object.

    Layout written, in order: ``num_rows`` and ``num_cols`` each packed
    with struct format ``'I'``, followed by the raw contents of
    ``self.data``.
    """
    # Header: the two dimensions, one write per field.
    for dimension in (self.num_rows, self.num_cols):
        file.write(struct.pack('I', dimension))
    # Payload: the flat data buffer as-is.
    self.data.tofile(file)
115+
116+
84117
def __len__(self):
    """Return the number of rows, as tracked by ``self.num_rows``."""
    row_count = self.num_rows
    return row_count
86119

120+
def get_vector(self, row):
    """Return row ``row`` of the flat data buffer as a NumPy array.

    The row occupies ``num_cols`` consecutive elements of ``self.data``
    starting at ``row * num_cols``.
    """
    start = row * self.num_cols
    stop = start + self.num_cols
    row_values = self.data[start:stop]
    return np.array(row_values)
87122

88123
def __getitem__(self, key):
    """Return row ``key`` as a slice of the flat data buffer.

    Args:
        key: integer row index. Negative values count from the end, as
            with built-in sequences.

    Raises:
        IndexError: if ``key`` is outside the valid row range. Raising
            here (instead of returning an empty slice) is what lets
            ``for row in obj`` stop when iteration falls back to
            ``__getitem__``.
    """
    row = key
    if row < 0:
        row += self.num_rows
    if row < 0 or row >= self.num_rows:
        raise IndexError("Row index out of range: " + str(key))
    return self.data[row * self.num_cols:(row + 1) * self.num_cols]
125+
90126

91127

0 commit comments

Comments
 (0)