diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index 6332dc26..faec65cd 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -531,7 +531,9 @@ def executescript(self, sql: str) -> sqlite3.Cursor: self._tracer(sql, None) return self.conn.executescript(sql) - def table(self, table_name: str, **kwargs) -> Union["Table", "View"]: + def table( + self, table_name: str, schema: str = "", **kwargs + ) -> Union["Table", "View"]: """ Return a table object, optionally configured with default options. @@ -540,10 +542,10 @@ def table(self, table_name: str, **kwargs) -> Union["Table", "View"]: :param table_name: Name of the table """ if table_name in self.view_names(): - return View(self, table_name, **kwargs) + return View(self, table_name, schema_name=schema, **kwargs) else: kwargs.setdefault("strict", self.strict) - return Table(self, table_name, **kwargs) + return Table(self, table_name, schema_name=schema, **kwargs) def quote(self, value: str) -> str: """ @@ -600,27 +602,46 @@ def quote_default_value(self, value: str) -> str: return self.quote(value) - def table_names(self, fts4: bool = False, fts5: bool = False) -> List[str]: + def schema_names(self) -> List[str]: + """List of string database schemas available in this connection. + + Unless other databases are ATTACHed using `attach`, this will only return + `['main']` or `['main', 'temp']`. See https://www.sqlite.org/lang_attach.html """ - List of string table names in this database. + return [r[1] for r in self.execute("PRAGMA database_list").fetchall()] + + def _from_schema(self, schema: str) -> str: + if schema and schema != "main": + return f"{schema}.sqlite_master" + return "sqlite_master" # keep SQL simple for the standard case. + + def table_names( + self, fts4: bool = False, fts5: bool = False, schema: str = "" + ) -> List[str]: + """List of string table names in the specified database schema. :param fts4: Only return tables that are part of FTS4 indexes :param fts5: Only return tables that are part of FTS5 indexes + :param schema: By default, the `main` schema is queried, but a different, + attached database can be queried instead. """ where = ["type = 'table'"] if fts4: where.append("sql like '%USING FTS4%'") if fts5: where.append("sql like '%USING FTS5%'") - sql = "select name from sqlite_master where {}".format(" AND ".join(where)) + + sql = "select name from {} where {}".format( + self._from_schema(schema), " AND ".join(where) + ) return [r[0] for r in self.execute(sql).fetchall()] - def view_names(self) -> List[str]: + def view_names(self, schema: str = "") -> List[str]: "List of string view names in this database." return [ r[0] for r in self.execute( - "select name from sqlite_master where type = 'view'" + f"select name from {self._from_schema(schema)} where type = 'view'" ).fetchall() ] @@ -1270,14 +1291,25 @@ def init_spatialite(self, path: Optional[str] = None) -> bool: return result and bool(result[0]) +def _fullname(schema_name: str, table_name: str) -> str: + if schema_name: + return f"{schema_name}.[{table_name}]" + return "[" + table_name + "]" + + class Queryable: def exists(self) -> bool: "Does this table or view exist yet?" return False - def __init__(self, db, name): + def __init__(self, db, name: str, schema_name: str = ""): self.db = db self.name = name + self.schema_name = schema_name # default is empty string, a.k.a. 'main' + + @property + def _fullname(self) -> str: + return _fullname(self.schema_name, self.name) def count_where( self, @@ -1291,7 +1323,7 @@ def count_where( :param where_args: Parameters to use with that fragment - an iterable for ``id > ?`` parameters, or a dictionary for ``id > :id`` """ - sql = "select count(*) from [{}]".format(self.name) + sql = "select count(*) from {}".format(self._fullname) if where is not None: sql += " where " + where return self.db.execute(sql, where_args or []).fetchone()[0] @@ -1334,7 +1366,7 @@ def rows_where( """ if not self.exists(): return - sql = "select {} from [{}]".format(select, self.name) + sql = "select {} from {}".format(select, self._fullname) if where is not None: sql += " where " + where if order_by is not None: @@ -1386,12 +1418,24 @@ def pks_and_rows_where( row_pk = row_pk[0] yield row_pk, row + @property + def is_attached(self) -> bool: + return self.schema_name not in {"", "main"} + + @property + def _pragma_name(self) -> Tuple[str, str]: + if self.schema_name: + return self.schema_name + ".", self.name + return "", self.name + @property def columns(self) -> List["Column"]: "List of :ref:`Columns ` representing the columns in this table or view." if not self.exists(): return [] - rows = self.db.execute("PRAGMA table_info([{}])".format(self.name)).fetchall() + rows = self.db.execute( + "PRAGMA {}table_info([{}])".format(*self._pragma_name) + ).fetchall() return [Column(*row) for row in rows] @property @@ -1402,8 +1446,9 @@ def columns_dict(self) -> Dict[str, Any]: @property def schema(self) -> str: "SQL schema for this table or view." + db, name = self._pragma_name return self.db.execute( - "select sql from sqlite_master where name = ?", (self.name,) + f"select sql from {db}sqlite_master where name = ?", (name,) ).fetchone()[0] @@ -1457,8 +1502,9 @@ def __init__( conversions: Optional[dict] = None, columns: Optional[Dict[str, Any]] = None, strict: bool = False, + schema_name: str = "", ): - super().__init__(db, name) + super().__init__(db, name, schema_name=schema_name) self._defaults = dict( pk=pk, foreign_keys=foreign_keys, @@ -1497,7 +1543,7 @@ def count(self) -> int: return self.count_where() def exists(self) -> bool: - return self.name in self.db.table_names() + return self.name in self.db.table_names(schema=self.schema_name) @property def pks(self) -> List[str]: @@ -1545,7 +1591,7 @@ def foreign_keys(self) -> List["ForeignKey"]: "List of foreign keys defined on this table." fks = [] for row in self.db.execute( - "PRAGMA foreign_key_list([{}])".format(self.name) + "PRAGMA {}foreign_key_list([{}])".format(*self._pragma_name) ).fetchall(): if row is not None: id, seq, table_name, from_, to_, on_update, on_delete, match = row @@ -1570,7 +1616,8 @@ def virtual_table_using(self) -> Optional[str]: @property def indexes(self) -> List[Index]: "List of indexes defined on this table." - sql = 'PRAGMA index_list("{}")'.format(self.name) + db, table_name = self._pragma_name + sql = 'PRAGMA {}index_list("{}")'.format(db, table_name) indexes = [] for row in self.db.execute_returning_dicts(sql): index_name = row["name"] @@ -1579,7 +1626,7 @@ def indexes(self) -> List[Index]: if not index_name.startswith('"') else index_name ) - column_sql = "PRAGMA index_info({})".format(index_name_quoted) + column_sql = "PRAGMA {}index_info({})".format(db, index_name_quoted) columns = [] for seqno, cid, name in self.db.execute(column_sql).fetchall(): columns.append(name) @@ -1594,7 +1641,8 @@ def indexes(self) -> List[Index]: @property def xindexes(self) -> List[XIndex]: "List of indexes defined on this table using the more detailed ``XIndex`` format." - sql = 'PRAGMA index_list("{}")'.format(self.name) + db, table_name = self._pragma_name + sql = 'PRAGMA {}index_list("{}")'.format(db, table_name) indexes = [] for row in self.db.execute_returning_dicts(sql): index_name = row["name"] @@ -1603,7 +1651,7 @@ def xindexes(self) -> List[XIndex]: if not index_name.startswith('"') else index_name ) - column_sql = "PRAGMA index_xinfo({})".format(index_name_quoted) + column_sql = "PRAGMA {}index_xinfo({})".format(db, index_name_quoted) index_columns = [] for info in self.db.execute(column_sql).fetchall(): index_columns.append(XIndexColumn(*info)) @@ -1613,12 +1661,13 @@ def xindexes(self) -> List[XIndex]: @property def triggers(self) -> List[Trigger]: "List of triggers defined on this table." + db, table_name = self._pragma_name return [ Trigger(*r) for r in self.db.execute( - "select name, tbl_name, sql from sqlite_master where type = 'trigger'" + f"select name, tbl_name, sql from {db}sqlite_master where type = 'trigger'" " and tbl_name = ?", - (self.name,), + (table_name,), ).fetchall() ] @@ -1710,9 +1759,9 @@ def duplicate(self, new_name: str) -> "Table": if not self.exists(): raise NoTable(f"Table {self.name} does not exist") with self.db.conn: - sql = "CREATE TABLE [{new_table}] AS SELECT * FROM [{table}];".format( + sql = "CREATE TABLE {new_table} AS SELECT * FROM {table};".format( new_table=new_name, - table=self.name, + table=self._fullname, ) self.db.execute(sql) return self.db[new_name] @@ -1766,21 +1815,22 @@ def transform( column_order=column_order, keep_table=keep_table, ) - pragma_foreign_keys_was_on = self.db.execute("PRAGMA foreign_keys").fetchone()[ - 0 - ] + db, _ = self._pragma_name + pragma_foreign_keys_was_on = self.db.execute( + f"PRAGMA {db}foreign_keys" + ).fetchone()[0] try: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=0;") + self.db.execute(f"PRAGMA {db}foreign_keys=0;") with self.db.conn: for sql in sqls: self.db.execute(sql) # Run the foreign_key_check before we commit if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_key_check;") + self.db.execute(f"PRAGMA {db}foreign_key_check;") finally: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=1;") + self.db.execute(f"PRAGMA {db}foreign_keys=1;") return self def transform_sql( @@ -1945,9 +1995,12 @@ def transform_sql( if "rowid" not in new_cols: new_cols.insert(0, "rowid") old_cols.insert(0, "rowid") - copy_sql = "INSERT INTO [{new_table}] ({new_cols})\n SELECT {old_cols} FROM [{old_table}];".format( - new_table=new_table_name, - old_table=self.name, + + old_fullname = _fullname(self.schema_name, self.name) + new_fullname = _fullname(self.schema_name, new_table_name) + copy_sql = "INSERT INTO {new_table} ({new_cols})\n SELECT {old_cols} FROM {old_table};".format( + new_table=new_fullname, + old_table=old_fullname, old_cols=", ".join("[{}]".format(col) for col in old_cols), new_cols=", ".join("[{}]".format(col) for col in new_cols), ) @@ -1955,14 +2008,14 @@ def transform_sql( # Drop (or keep) the old table if keep_table: sqls.append( - "ALTER TABLE [{}] RENAME TO [{}];".format(self.name, keep_table) + "ALTER TABLE {} RENAME TO {};".format( + old_fullname, _fullname(self.schema_name, keep_table) + ) ) else: - sqls.append("DROP TABLE [{}];".format(self.name)) + sqls.append("DROP TABLE {};".format(old_fullname)) # Rename the new one - sqls.append( - "ALTER TABLE [{}] RENAME TO [{}];".format(new_table_name, self.name) - ) + sqls.append("ALTER TABLE {} RENAME TO {};".format(new_fullname, old_fullname)) return sqls def extract( @@ -2024,11 +2077,11 @@ def extract( lookup_columns = [(rename.get(col) or col) for col in columns] lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True) self.db.execute( - "INSERT OR IGNORE INTO [{lookup_table}] ({lookup_columns}) SELECT DISTINCT {table_cols} FROM [{table}]".format( - lookup_table=table, + "INSERT OR IGNORE INTO {lookup_table} ({lookup_columns}) SELECT DISTINCT {table_cols} FROM {table}".format( + lookup_table=_fullname(self.schema_name, table), lookup_columns=", ".join("[{}]".format(c) for c in lookup_columns), table_cols=", ".join("[{}]".format(c) for c in columns), - table=self.name, + table=self._fullname, ) ) @@ -2036,15 +2089,16 @@ def extract( self.add_column(magic_lookup_column, int) # And populate it + lookup_table_full = _fullname(self.schema_name, table) self.db.execute( - "UPDATE [{table}] SET [{magic_lookup_column}] = (SELECT id FROM [{lookup_table}] WHERE {where})".format( - table=self.name, + "UPDATE {table} SET [{magic_lookup_column}] = (SELECT id FROM {lookup_table} WHERE {where})".format( + table=self._fullname, magic_lookup_column=magic_lookup_column, - lookup_table=table, + lookup_table=lookup_table_full, where=" AND ".join( - "[{table}].[{column}] IS [{lookup_table}].[{lookup_column}]".format( - table=self.name, - lookup_table=table, + "{table}.[{column}] IS {lookup_table}.[{lookup_column}]".format( + table=self._fullname, + lookup_table=lookup_table_full, column=column, lookup_column=rename.get(column) or column, ) @@ -2118,13 +2172,13 @@ def create_index( textwrap.dedent( """ CREATE {unique}INDEX {if_not_exists}[{index_name}] - ON [{table_name}] ({columns}); + ON {table_name} ({columns}); """ ) .strip() .format( index_name=created_index_name, - table_name=self.name, + table_name=self._fullname, columns=", ".join(columns_sql), unique="UNIQUE " if unique else "", if_not_exists="IF NOT EXISTS " if if_not_exists else "", @@ -2172,7 +2226,7 @@ def add_column( fk_col_type = None if fk is not None: # fk must be a valid table - if fk not in self.db.table_names(): + if fk not in self.db.table_names(schema=self.schema_name): raise AlterError("table '{}' does not exist".format(fk)) # if fk_col specified, must be a valid column if fk_col is not None: @@ -2194,8 +2248,8 @@ def add_column( not_null_sql = "NOT NULL DEFAULT {}".format( self.db.quote_default_value(not_null_default) ) - sql = "ALTER TABLE [{table}] ADD COLUMN [{col_name}] {col_type}{not_null_default};".format( - table=self.name, + sql = "ALTER TABLE {table} ADD COLUMN [{col_name}] {col_type}{not_null_default};".format( + table=self._fullname, col_name=col_name, col_type=fk_col_type or COLUMN_TYPE_MAPPING[col_type], not_null_default=(" " + not_null_sql) if not_null_sql else "", @@ -2212,7 +2266,7 @@ def drop(self, ignore: bool = False): :param ignore: Set to ``True`` to ignore the error if the table does not exist """ try: - self.db.execute("DROP TABLE [{}]".format(self.name)) + self.db.execute("DROP TABLE {}".format(self._fullname)) except sqlite3.OperationalError: if not ignore: raise @@ -2238,7 +2292,9 @@ def guess_foreign_table(self, column: str) -> str: possibilities.append(column_without_id + "s") elif not column.endswith("s"): possibilities.append(column + "s") - existing_tables = {t.lower(): t for t in self.db.table_names()} + existing_tables = { + t.lower(): t for t in self.db.table_names(schema=self.schema_name) + } for table in possibilities: if table in existing_tables: return existing_tables[table] @@ -2379,6 +2435,9 @@ def enable_fts( """ Enable SQLite full-text search against the specified columns. + Creates the FTS virtual table(s) in the `main` database, even if the + source table is in an attached database. + See :ref:`python_api_fts` for more details. :param columns: List of column names to include in the search index. @@ -2387,6 +2446,7 @@ def enable_fts( :param tokenize: Custom SQLite tokenizer to use, for example ``"porter"`` to enable Porter stemming. :param replace: Should any existing FTS index for this table be replaced by the new one? """ + table_name = self.name create_fts_sql = ( textwrap.dedent( """ @@ -2398,19 +2458,21 @@ def enable_fts( ) .strip() .format( - table=self.name, + table=table_name, columns=", ".join("[{}]".format(c) for c in columns), fts_version=fts_version, tokenize="\n tokenize='{}',".format(tokenize) if tokenize else "", ) ) should_recreate = False - if replace and self.db["{}_fts".format(self.name)].exists(): + if replace and self.db["{}_fts".format(table_name)].exists(): # Does the table need to be recreated? - fts_schema = self.db["{}_fts".format(self.name)].schema + fts_schema = self.db["{}_fts".format(table_name)].schema if fts_schema != create_fts_sql: should_recreate = True - expected_triggers = {self.name + suffix for suffix in ("_ai", "_ad", "_au")} + expected_triggers = { + table_name + suffix for suffix in ("_ai", "_ad", "_au") + } existing_triggers = {t.name for t in self.triggers} has_triggers = existing_triggers.issuperset(expected_triggers) if has_triggers != create_triggers: @@ -2445,7 +2507,7 @@ def enable_fts( ) .strip() .format( - table=self.name, + table=table_name, columns=", ".join("[{}]".format(c) for c in columns), old_cols=old_cols, new_cols=new_cols, @@ -2506,11 +2568,9 @@ def rebuild_fts(self): fts_table = self.detect_fts() if fts_table is None: # Assume this is itself an FTS table - fts_table = self.name + fts_table = self._fullname self.db.execute( - "INSERT INTO [{table}]([{table}]) VALUES('rebuild');".format( - table=fts_table - ) + "INSERT INTO {table}({table}) VALUES('rebuild');".format(table=fts_table) ) return self @@ -2530,10 +2590,11 @@ def detect_fts(self) -> Optional[str]: ) """ ).strip() + table_name = self.name args = { - "like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(self.name), - "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(self.name), - "table": self.name, + "like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(table_name), + "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(table_name), + "table": table_name, } rows = self.db.execute(sql, args).fetchall() if len(rows) == 0: @@ -2593,7 +2654,7 @@ def search_sql( select rowid, {columns} - from [{dbtable}]{where_clause} + from {dbtable}{where_clause} ) select {columns_with_prefix} @@ -2622,7 +2683,7 @@ def search_sql( if offset is not None: limit_offset += " offset {}".format(offset) return sql.format( - dbtable=self.name, + dbtable=self._fullname, where_clause="\n where {}".format(where) if where else "", original=original, columns=columns_sql, @@ -2693,8 +2754,8 @@ def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table": pk_values = [pk_values] self.get(pk_values) wheres = ["[{}] = ?".format(pk_name) for pk_name in self.pks] - sql = "delete from [{table}] where {wheres}".format( - table=self.name, wheres=" and ".join(wheres) + sql = "delete from {table} where {wheres}".format( + table=self._fullname, wheres=" and ".join(wheres) ) with self.db.conn: self.db.execute(sql, pk_values) @@ -2718,7 +2779,7 @@ def delete_where( """ if not self.exists(): return self - sql = "delete from [{}]".format(self.name) + sql = f"delete from {self._fullname}" if where is not None: sql += " where " + where self.db.execute(sql, where_args or []) @@ -2763,8 +2824,8 @@ def update( args.append(jsonify_if_needed(value)) wheres = ["[{}] = ?".format(pk_name) for pk_name in pks] args.extend(pk_values) - sql = "update [{table}] set {sets} where {wheres}".format( - table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres) + sql = "update {table} set {sets} where {wheres}".format( + table=self._fullname, sets=", ".join(sets), wheres=" and ".join(wheres) ) with self.db.conn: try: @@ -2844,8 +2905,8 @@ def convert_value(v): if fn_name == "": fn_name = f"lambda_{abs(hash(fn))}" self.db.register_function(convert_value, name=fn_name) - sql = "update [{table}] set {sets}{where};".format( - table=self.name, + sql = "update {table} set {sets}{where};".format( + table=self._fullname, sets=", ".join( [ "[{output_column}] = {fn_name}([{column}])".format( @@ -2968,8 +3029,8 @@ def build_insert_queries_and_params( # them since it ignores the resulting integrity errors if not_null: placeholders.extend(not_null) - sql = "INSERT OR IGNORE INTO [{table}]({cols}) VALUES({placeholders});".format( - table=self.name, + sql = "INSERT OR IGNORE INTO {table}({cols}) VALUES({placeholders});".format( + table=self._fullname, cols=", ".join(["[{}]".format(p) for p in placeholders]), placeholders=", ".join(["?" for p in placeholders]), ) @@ -2979,8 +3040,8 @@ def build_insert_queries_and_params( # UPDATE [book] SET [name] = 'Programming' WHERE [id] = 1001; set_cols = [col for col in all_columns if col not in pks] if set_cols: - sql2 = "UPDATE [{table}] SET {pairs} WHERE {wheres}".format( - table=self.name, + sql2 = "UPDATE {table} SET {pairs} WHERE {wheres}".format( + table=self._fullname, pairs=", ".join( "[{}] = {}".format(col, conversions.get(col, "?")) for col in set_cols @@ -3007,10 +3068,10 @@ def build_insert_queries_and_params( elif ignore: or_what = "OR IGNORE " sql = """ - INSERT {or_what}INTO [{table}] ({columns}) VALUES {rows}; + INSERT {or_what}INTO {table} ({columns}) VALUES {rows}; """.strip().format( or_what=or_what, - table=self.name, + table=self._fullname, columns=", ".join("[{}]".format(c) for c in all_columns), rows=", ".join( "({placeholders})".format( @@ -3268,7 +3329,7 @@ def insert_all( self.last_rowid = None self.last_pk = None if truncate and self.exists(): - self.db.execute("DELETE FROM [{}];".format(self.name)) + self.db.execute("DELETE FROM {};".format(self._fullname)) for chunk in chunks(itertools.chain([first_record], records), batch_size): chunk = list(chunk) num_records_processed += len(chunk) @@ -3752,7 +3813,9 @@ def create_spatial_index(self, column_name) -> bool: :param column_name: Geometry column to create the spatial index against """ - if f"idx_{self.name}_{column_name}" in self.db.table_names(): + if f"idx_{self.name}_{column_name}" in self.db.table_names( + schema=self.schema_name + ): return False cursor = self.db.execute( @@ -3779,7 +3842,7 @@ def drop(self, ignore=False): """ try: - self.db.execute("DROP VIEW [{}]".format(self.name)) + self.db.execute("DROP VIEW {}".format(self._fullname)) except sqlite3.OperationalError: if not ignore: raise