diff --git a/benchmark_test.go b/benchmark_test.go index d8a423a..597c9c1 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "io" + "net/http" + "net/http/httptest" "os" "os/exec" "path/filepath" @@ -122,3 +124,30 @@ func BenchmarkReadCompressedSQLite(b *testing.B) { } } } + +func BenchmarkReadCompressedHTTPSQLite(b *testing.B) { + _, zstPath, cleanup := setupDB(b) + defer cleanup() + + zstDir := filepath.Dir(zstPath) + + server := httptest.NewServer(http.FileServer(http.Dir(zstDir))) + defer server.Close() + + client, err := sql.Open("sqlite3", fmt.Sprintf("%s/%s?vfs=zstd", server.URL, filepath.Base(zstPath))) + if err != nil { + b.Fatalf("Query failed: %v", err) + } + defer client.Close() + + b.ResetTimer() // Start timing now. + + for i := 0; i < b.N; i++ { + var count int + + err = client.QueryRow("SELECT MAX(value) FROM entries").Scan(&count) + if err != nil { + b.Fatalf("Query failed: %v", err) + } + } +} diff --git a/file.go b/file.go index 1534f2d..3bc7fb9 100644 --- a/file.go +++ b/file.go @@ -2,7 +2,6 @@ package sqlitezstd import ( "io" - "os" seekable "github.com/SaveTheRbtz/zstd-seekable-format-go" "github.com/klauspost/compress/zstd" @@ -11,7 +10,7 @@ import ( type ZstdFile struct { decoder *zstd.Decoder - file *os.File + closer io.Closer seekable seekable.Reader } @@ -23,7 +22,7 @@ func (z *ZstdFile) CheckReservedLock() (bool, error) { func (z *ZstdFile) Close() error { _ = z.seekable.Close() - _ = z.file.Close() + _ = z.closer.Close() return nil } diff --git a/go.mod b/go.mod index 029c623..b3ab59d 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/onsi/ginkgo/v2 v2.16.0 github.com/onsi/gomega v1.31.1 github.com/psanford/sqlite3vfs v0.0.0-20231201192653-4c99abef8114 + howett.net/ranger v0.0.0-20171016084633-e2e137620847 ) require ( diff --git a/go.sum b/go.sum index 32d2fca..306e414 100644 --- a/go.sum +++ b/go.sum @@ -61,3 +61,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +howett.net/ranger v0.0.0-20171016084633-e2e137620847 h1:jHX+2Sv8rQNb1nG2G9pcQLdVoWXBLW5fcd2gnfHluCU= +howett.net/ranger v0.0.0-20171016084633-e2e137620847/go.mod h1:ZWGIG4mR6Ck+CdmFqK2/jox5vO2OAS8Qpb33HB8z0og= diff --git a/sqlite_zstd_suite_test.go b/sqlite_zstd_suite_test.go index 2b734a1..934d4e8 100644 --- a/sqlite_zstd_suite_test.go +++ b/sqlite_zstd_suite_test.go @@ -3,6 +3,8 @@ package sqlitezstd_test import ( "database/sql" "fmt" + "net/http" + "net/http/httptest" "os" "os/exec" "path/filepath" @@ -114,4 +116,23 @@ var _ = Describe("SqliteZSTD", func() { Expect(row.Err()).To(HaveOccurred()) }) }) + + It("does something", func() { + zstPath := createDatabase() + zstDir := filepath.Dir(zstPath) + server := httptest.NewServer(http.FileServer(http.Dir(zstDir))) + defer server.Close() + + client, err := sql.Open("sqlite3", fmt.Sprintf("%s/%s?vfs=zstd", server.URL, filepath.Base(zstPath))) + Expect(err).ToNot(HaveOccurred()) + defer client.Close() + + row := client.QueryRow("SELECT COUNT(*) FROM entries;") + Expect(row.Err()).ToNot(HaveOccurred()) + + var count int64 + err = row.Scan(&count) + Expect(err).ToNot(HaveOccurred()) + Expect(count).To(BeEquivalentTo(1000)) + }) }) diff --git a/vfs.go b/vfs.go index 11cf5c1..71533b3 100644 --- a/vfs.go +++ b/vfs.go @@ -2,6 +2,8 @@ package sqlitezstd import ( "fmt" + "io" + "net/url" "os" "strings" @@ -9,6 +11,7 @@ import ( "github.com/klauspost/compress/zstd" _ "github.com/mattn/go-sqlite3" "github.com/psanford/sqlite3vfs" + "howett.net/ranger" ) type ZstdVFS struct{} @@ -32,9 +35,32 @@ func (z *ZstdVFS) FullPathname(name string) string { } func (z *ZstdVFS) Open(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, sqlite3vfs.OpenFlag, error) { - file, err := os.Open(name) - if err != nil { - return nil, 0, sqlite3vfs.CantOpenError + var ( + err error + reader io.ReadSeeker + closer io.Closer + ) + + if strings.HasPrefix(name, "http://") || strings.HasPrefix(name, "https://") { + uri, err := url.Parse(name) + if err != nil { + return nil, 0, sqlite3vfs.CantOpenError + } + + reader, err = ranger.NewReader(&ranger.HTTPRanger{URL: uri}) + if err != nil { + return nil, 0, sqlite3vfs.CantOpenError + } + + closer = io.NopCloser(reader) + } else { + reader, err = os.Open(name) + if err != nil { + return nil, 0, sqlite3vfs.CantOpenError + } + + //nolint: forcetypeassert + closer = reader.(io.Closer) } decoder, err := zstd.NewReader(nil) @@ -42,14 +68,14 @@ func (z *ZstdVFS) Open(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, return nil, 0, sqlite3vfs.CantOpenError } - seekable, err := seekable.NewReader(file, decoder) + seekable, err := seekable.NewReader(reader, decoder) if err != nil { return nil, 0, sqlite3vfs.CantOpenError } return &ZstdFile{ decoder: decoder, - file: file, + closer: closer, seekable: seekable, }, flags | sqlite3vfs.OpenReadOnly, nil }