diff --git a/src/Core/Storage/Database.py b/src/Core/Storage/Database.py index 5c9884d..2f95da2 100644 --- a/src/Core/Storage/Database.py +++ b/src/Core/Storage/Database.py @@ -422,13 +422,15 @@ def __add_data(self, def update(self, table_name: str, data: Dict[str, Any], - line_id: int = -1): + line_id: int = -1, + create_fields: bool = False): """ Update a line of a Table. :param table_name: Name of the Table on which to perform the query. :param data: Updated data of the line. :param line_id: Index of the line to update. + :param create_fields: Create missing fields. """ # Check table existence @@ -443,6 +445,9 @@ def update(self, # Define the line index nb_line = self.nb_lines(table_name=table_name) + if nb_line == 0: + self.add_data(table_name=table_name, data=data) + return if line_id < 0: line_id += nb_line + 1 elif line_id > nb_line: @@ -450,6 +455,10 @@ def update(self, # Check fields existence undefined_fields = set(fields_names) - set(table.fields()) + if create_fields: + fields_to_create = [(undef, type(data[undef])) for undef in undefined_fields] + self.create_fields(table_name=table_name, fields=fields_to_create) + undefined_fields = set(fields_names) - set(table.fields()) if len(undefined_fields) > 0: raise ValueError(f"[{self.__class__.__name__}] Some fields where not defined in table {table}." f" As table {table} is non-empty, please define first the following fields :" @@ -476,15 +485,13 @@ def update(self, def get_line(self, table_name: str, fields: Optional[Union[str, List[str]]] = None, - line_id: int = -1, - joins: Optional[Union[str, List[str]]] = None): + line_id: int = -1): """ Get a line of a Table. :param table_name: Name of the Table on which to perform the query. :param fields: Name(s) of the Field(s) to request. :param line_id: Index of the line to get. - :param joins: Name(s) of Table(s) to join to the selection. """ # Check the Table existence @@ -495,19 +502,13 @@ def get_line(self, # Define the fields to select fields_selection = () - if fields is not None: - fields_selection += (table.id,) - fields = [fields] if type(fields) == str else fields - for field in fields: - if field in table.fields(): - fields_selection += (table.fields(only_names=False)[field],) - if joins is not None: - joins = [joins] if type(joins) == str else joins - for j in joins: - if j in self.__fk[table_name].values() and j not in fields: - field_name = list(self.__fk[table_name].keys())[ - list(self.__fk[table_name].values()).index(j)] - fields_selection += (table.fields(only_names=False)[field_name],) + if fields is None: + fields = table.fields() + fields_selection += (table.id,) + fields = [fields] if type(fields) == str else fields + for field in fields: + if field in table.fields(): + fields_selection += (table.fields(only_names=False)[field],) # Define the index of the line to select nb_line = self.nb_lines(table_name=table_name) @@ -520,16 +521,11 @@ def get_line(self, data = table.select(*fields_selection).where(table.id == line_id).dicts()[0] # Join - if joins is not None: - joins = [joins] if type(joins) == str else joins - for j in joins: - if j in self.__fk[table_name].values(): - field_name = list(self.__fk[table_name].keys())[list(self.__fk[table_name].values()).index(j)] - if field_name in data: - data[field_name] = self.get_line(table_name=j, - fields=fields, - line_id=data[field_name], - joins=j) + for field in fields: + if field in self.__fk[table_name].keys(): + data[field] = self.get_line(table_name=self.__fk[table_name][field], + fields=fields, + line_id=data[field]) return data @@ -538,7 +534,6 @@ def get_lines(self, fields: Optional[Union[str, List[str]]] = None, lines_id: Optional[List[int]] = None, lines_range: Optional[List[int]] = None, - joins: Optional[Union[str, List[str]]] = None, batched: bool = False): """ Get a set of lines of a Table. @@ -547,7 +542,6 @@ def get_lines(self, :param fields: Name(s) of the Field(s) to select. :param lines_id: Indices of the lines to get. If not specified, 'lines_range' value will be used. :param lines_range: Range of indices of the lines to get. If not specified, all lines will be selected. - :param joins: Name(s) of Table(s) to join to the selection. :param batched: If True, data is returned as one batch per field. Otherwise, data is returned as list of lines. """ @@ -559,19 +553,13 @@ def get_lines(self, # Define the fields to select fields_selection = () - if fields is not None: - fields_selection += (table.id,) - fields = [fields] if type(fields) == str else fields - for field in fields: - if field in table.fields(): - fields_selection += (table.fields(only_names=False)[field],) - if joins is not None: - joins = [joins] if type(joins) == str else joins - for j in joins: - if j in self.__fk[table_name].values() and j not in fields: - field_name = list(self.__fk[table_name].keys())[ - list(self.__fk[table_name].values()).index(j)] - fields_selection += (table.fields(only_names=False)[field_name],) + if fields is None: + fields = table.fields() + fields_selection += (table.id,) + fields = [fields] if type(fields) == str else fields + for field in fields: + if field in table.fields(): + fields_selection += (table.fields(only_names=False)[field],) # Define the indices of lines to select if lines_id is None: @@ -601,27 +589,17 @@ def get_lines(self, lines = [line for line in query] # Join - if joins is not None: - joins = [joins] if type(joins) == str else joins - for j in joins: - if j in self.__fk[table_name].values(): - field_name = list(self.__fk[table_name].keys())[ - list(self.__fk[table_name].values()).index(j)] - dict_keys = lines.keys() if batched else lines[0].keys() - if field_name in dict_keys: - lines_id = lines[field_name] if batched else [line[field_name] for line in lines] - data = self.get_lines(table_name=j, - fields=fields, - lines_id=lines_id, - joins=joins, - batched=batched) - - if batched: - lines[field_name] = data - else: - for i, l in enumerate(data): - lines[i][field_name] = l - + for field in fields: + if field in self.__fk[table_name].keys(): + data = self.get_lines(table_name=self.__fk[table_name][field], + fields=fields, + lines_id=lines[field] if batched else [line[field] for line in lines], + batched=batched) + if batched: + lines[field] = data + else: + for i, l in enumerate(data): + lines[i][field] = l return lines def nb_lines(self,