1
1
import array
2
2
import struct
3
+ import numpy as np
3
4
4
5
"""Constants"""
5
6
SECTOR_LEN = 4096
@@ -19,7 +20,6 @@ def get_data_type_code_and_size(data_type):
19
20
data_type_size = 1
20
21
else :
21
22
raise Exception ("Unsupported data type. Supported data types are float, int8 and uint8" )
22
-
23
23
return data_type_code , data_type_size
24
24
25
25
class Node :
@@ -81,11 +81,47 @@ def load_data_only_from_opened_file(self, file, num_rows, num_cols, file_offset_
81
81
self .data = array .array (self .data_format_specifier )
82
82
self .data .fromfile (file , self .num_rows * self .num_cols )
83
83
84
+ def remove_rows (self , rows_to_remove ):
85
+ for row in rows_to_remove :
86
+ if row < 0 or row >= self .num_rows :
87
+ raise Exception ("Invalid row index: " + str (row ))
88
+
89
+ new_data = array .array (self .data_format_specifier )
90
+
91
+ # index = 0
92
+ # rows_to_remove.sort()
93
+ # while index < len(rows_to_remove):
94
+ # if rows_to_remove[index+1] - rows_to_remove[index] == 1:
95
+ # index += 1
96
+ # else:
97
+ # #have to add everything after row
98
+ # new_data.extend(self.data[(rows_to_remove[index]+1)*self.num_cols:(rows_to_remove[index+1])*self.num_cols])
99
+
100
+ for i in range (self .num_rows ):
101
+ if i not in rows_to_remove :
102
+ new_data .extend (self .data [i * self .num_cols :(i + 1 )* self .num_cols ])
103
+
104
+ self .data = new_data
105
+ self .num_rows -= len (rows_to_remove )
106
+
107
+ def save_bin (self , file_name ):
108
+ with open (file_name , "wb" ) as file :
109
+ self .save_bin_to_opened_file (file )
110
+
111
+ def save_bin_to_opened_file (self , file ):
112
+ file .write (struct .pack ('I' , self .num_rows ))
113
+ file .write (struct .pack ('I' , self .num_cols ))
114
+ self .data .tofile (file )
115
+
116
+
84
117
def __len__ (self ):
85
118
return self .num_rows
86
119
120
+ def get_vector (self , row ):
121
+ return np .array (self .data [row * self .num_cols :(row + 1 )* self .num_cols ])
87
122
88
123
def __getitem__ (self , key ):
89
124
return self .data [key * self .num_cols :(key + 1 )* self .num_cols ]
125
+
90
126
91
127
0 commit comments