From 3d4b31f059fefffb0a83e1556d833da4cfc2c088 Mon Sep 17 00:00:00 2001 From: Naji Obeid Date: Mon, 26 Feb 2024 16:19:26 +0000 Subject: [PATCH 1/6] change update backend - updater now uses the new metadata endpoint to check for edition updates. - updater now uses the new download endpoint to directly download editions. --- cmd/geoipupdate/main.go | 6 +- pkg/geoipupdate/database/http_reader.go | 166 ----------- pkg/geoipupdate/database/http_reader_test.go | 160 ---------- .../database/local_file_writer_test.go | 124 -------- pkg/geoipupdate/database/reader.go | 76 ----- pkg/geoipupdate/database/writer.go | 11 - pkg/geoipupdate/download/download.go | 149 ++++++++++ pkg/geoipupdate/download/download_test.go | 36 +++ pkg/geoipupdate/download/metadata.go | 98 ++++++ pkg/geoipupdate/download/metadata_test.go | 99 +++++++ pkg/geoipupdate/download/output.go | 40 +++ pkg/geoipupdate/download/output_test.go | 39 +++ .../writer.go} | 181 ++++++------ pkg/geoipupdate/download/writer_test.go | 278 ++++++++++++++++++ pkg/geoipupdate/geoip_updater.go | 132 ++++----- pkg/geoipupdate/geoip_updater_test.go | 234 ++++++++------- pkg/geoipupdate/internal/errors.go | 15 +- pkg/geoipupdate/internal/errors_test.go | 4 +- 18 files changed, 1024 insertions(+), 824 deletions(-) delete mode 100644 pkg/geoipupdate/database/http_reader.go delete mode 100644 pkg/geoipupdate/database/http_reader_test.go delete mode 100644 pkg/geoipupdate/database/local_file_writer_test.go delete mode 100644 pkg/geoipupdate/database/reader.go delete mode 100644 pkg/geoipupdate/database/writer.go create mode 100644 pkg/geoipupdate/download/download.go create mode 100644 pkg/geoipupdate/download/download_test.go create mode 100644 pkg/geoipupdate/download/metadata.go create mode 100644 pkg/geoipupdate/download/metadata_test.go create mode 100644 pkg/geoipupdate/download/output.go create mode 100644 pkg/geoipupdate/download/output_test.go rename pkg/geoipupdate/{database/local_file_writer.go => download/writer.go} 
(54%) create mode 100644 pkg/geoipupdate/download/writer_test.go diff --git a/cmd/geoipupdate/main.go b/cmd/geoipupdate/main.go index f8e4a81a..d25fb71e 100644 --- a/cmd/geoipupdate/main.go +++ b/cmd/geoipupdate/main.go @@ -49,7 +49,11 @@ func main() { log.Printf("Using database directory %s", config.DatabaseDirectory) } - client := geoipupdate.NewClient(config) + client, err := geoipupdate.NewClient(config) + if err != nil { + log.Fatalf("Error initializing download client: %s", err) + } + if err = client.Run(context.Background()); err != nil { log.Fatalf("Error retrieving updates: %s", err) } diff --git a/pkg/geoipupdate/database/http_reader.go b/pkg/geoipupdate/database/http_reader.go deleted file mode 100644 index 7cb234fe..00000000 --- a/pkg/geoipupdate/database/http_reader.go +++ /dev/null @@ -1,166 +0,0 @@ -// Package database provides an abstraction over getting and writing a -// database file. -package database - -import ( - "compress/gzip" - "context" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/url" - "strconv" - "time" - - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars" -) - -const urlFormat = "%s/geoip/databases/%s/update?db_md5=%s" - -// HTTPReader is a Reader that uses an HTTP client to retrieve -// databases. -type HTTPReader struct { - // client is an http client responsible of fetching database updates. - client *http.Client - // path is the request path. - path string - // accountID is used for request auth. - accountID int - // licenseKey is used for request auth. - licenseKey string - // verbose turns on/off debug logs. - verbose bool -} - -// NewHTTPReader creates a Reader that downloads database updates via -// HTTP. 
-func NewHTTPReader( - proxy *url.URL, - path string, - accountID int, - licenseKey string, - verbose bool, -) Reader { - transport := http.DefaultTransport - if proxy != nil { - proxyFunc := http.ProxyURL(proxy) - transport.(*http.Transport).Proxy = proxyFunc - } - - return &HTTPReader{ - client: &http.Client{Transport: transport}, - path: path, - accountID: accountID, - licenseKey: licenseKey, - verbose: verbose, - } -} - -// Read attempts to fetch database updates for a specific editionID. -// It takes an editionID and it's previously downloaded hash if available -// as arguments and returns a ReadResult struct as a response. -// It's the responsibility of the Writer to close the io.ReadCloser -// included in the response after consumption. -func (r *HTTPReader) Read(ctx context.Context, editionID, hash string) (*ReadResult, error) { - result, err := r.get(ctx, editionID, hash) - if err != nil { - return nil, fmt.Errorf("getting update for %s: %w", editionID, err) - } - - return result, nil -} - -// get makes an http request to fetch updates for a specific editionID if any. -func (r *HTTPReader) get( - ctx context.Context, - editionID string, - hash string, -) (result *ReadResult, err error) { - requestURL := fmt.Sprintf( - urlFormat, - r.path, - url.PathEscape(editionID), - url.QueryEscape(hash), - ) - - if r.verbose { - log.Printf("Requesting updates for %s: %s", editionID, requestURL) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) - req.SetBasicAuth(strconv.Itoa(r.accountID), r.licenseKey) - - response, err := r.client.Do(req) - if err != nil { - return nil, fmt.Errorf("performing HTTP request: %w", err) - } - // It is safe to close the response body reader as it wouldn't be - // consumed in case this function returns an error. 
- defer func() { - if err != nil { - response.Body.Close() - } - }() - - switch response.StatusCode { - case http.StatusNotModified: - if r.verbose { - log.Printf("No new updates available for %s", editionID) - } - return &ReadResult{EditionID: editionID, OldHash: hash, NewHash: hash}, nil - case http.StatusOK: - default: - //nolint:errcheck // we are already returning an error. - buf, _ := io.ReadAll(io.LimitReader(response.Body, 256)) - httpErr := internal.HTTPError{ - Body: string(buf), - StatusCode: response.StatusCode, - } - return nil, fmt.Errorf("unexpected HTTP status code: %w", httpErr) - } - - newHash := response.Header.Get("X-Database-MD5") - if newHash == "" { - return nil, errors.New("no X-Database-MD5 header found") - } - - modifiedAt, err := parseTime(response.Header.Get("Last-Modified")) - if err != nil { - return nil, fmt.Errorf("reading Last-Modified header: %w", err) - } - - gzReader, err := gzip.NewReader(response.Body) - if err != nil { - return nil, fmt.Errorf("encountered an error creating GZIP reader: %w", err) - } - - if r.verbose { - log.Printf("Updates available for %s", editionID) - } - - return &ReadResult{ - reader: gzReader, - EditionID: editionID, - OldHash: hash, - NewHash: newHash, - ModifiedAt: modifiedAt, - }, nil -} - -// parseTime parses a string representation of a time into time.Time according to the -// RFC1123 format. 
-func parseTime(s string) (time.Time, error) { - t, err := time.ParseInLocation(time.RFC1123, s, time.UTC) - if err != nil { - return time.Time{}, fmt.Errorf("parsing time: %w", err) - } - - return t, nil -} diff --git a/pkg/geoipupdate/database/http_reader_test.go b/pkg/geoipupdate/database/http_reader_test.go deleted file mode 100644 index 6f68dd7d..00000000 --- a/pkg/geoipupdate/database/http_reader_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package database - -import ( - "bytes" - "compress/gzip" - "context" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -// TestHTTPReader tests the functionality of the HTTPReader.Read method. -func TestHTTPReader(t *testing.T) { - testTime := time.Date(2023, 4, 10, 12, 47, 31, 0, time.UTC) - - tests := []struct { - description string - checkErr func(require.TestingT, error, ...interface{}) //nolint:revive // support older versions - requestEdition string - requestHash string - responseStatus int - responseBody string - responseHash string - responseTime string - result *ReadResult - }{ - { - description: "success", - checkErr: require.NoError, - requestEdition: "GeoIP2-City", - requestHash: "fbe1786bfd80e1db9dc42ddaff868f38", - responseStatus: http.StatusOK, - responseBody: "database content", - responseHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - responseTime: testTime.Format(time.RFC1123), - result: &ReadResult{ - reader: getReader(t, "database content"), - EditionID: "GeoIP2-City", - OldHash: "fbe1786bfd80e1db9dc42ddaff868f38", - NewHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - ModifiedAt: testTime, - }, - }, { - description: "no new update", - checkErr: require.NoError, - requestEdition: "GeoIP2-City", - requestHash: "fbe1786bfd80e1db9dc42ddaff868f38", - responseStatus: http.StatusNotModified, - responseBody: "", - responseHash: "", - responseTime: "", - result: &ReadResult{ - reader: nil, - EditionID: "GeoIP2-City", - OldHash: "fbe1786bfd80e1db9dc42ddaff868f38", - 
NewHash: "fbe1786bfd80e1db9dc42ddaff868f38", - ModifiedAt: time.Time{}, - }, - }, { - description: "bad request", - checkErr: require.Error, - requestEdition: "GeoIP2-City", - requestHash: "fbe1786bfd80e1db9dc42ddaff868f38", - responseStatus: http.StatusBadRequest, - responseBody: "", - responseHash: "", - responseTime: "", - }, { - description: "missing hash header", - checkErr: require.Error, - requestEdition: "GeoIP2-City", - requestHash: "fbe1786bfd80e1db9dc42ddaff868f38", - responseStatus: http.StatusOK, - responseBody: "database content", - responseHash: "", - responseTime: testTime.Format(time.RFC1123), - }, { - description: "modified time header wrong format", - checkErr: require.Error, - requestEdition: "GeoIP2-City", - requestHash: "fbe1786bfd80e1db9dc42ddaff868f38", - responseStatus: http.StatusOK, - responseBody: "database content", - responseHash: "fbe1786bfd80e1db9dc42ddaff868f38", - responseTime: testTime.Format(time.Kitchen), - }, - } - - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - server := httptest.NewServer( - http.HandlerFunc( - func(w http.ResponseWriter, _ *http.Request) { - if test.responseStatus != http.StatusOK { - w.WriteHeader(test.responseStatus) - return - } - - w.Header().Set("X-Database-MD5", test.responseHash) - w.Header().Set("Last-Modified", test.responseTime) - - buf := &bytes.Buffer{} - gzWriter := gzip.NewWriter(buf) - _, err := gzWriter.Write([]byte(test.responseBody)) - require.NoError(t, err) - require.NoError(t, gzWriter.Flush()) - require.NoError(t, gzWriter.Close()) - _, err = w.Write(buf.Bytes()) - require.NoError(t, err) - }, - ), - ) - defer server.Close() - - reader := NewHTTPReader( - nil, // request proxy. - server.URL, // fixed, as the server is mocked above. - 10, // fixed, as it's not valuable for the purpose of the test. - "license", // fixed, as it's not valuable for the purpose of the test. 
- false, // verbose - ) - - result, err := reader.Read(context.Background(), test.requestEdition, test.requestHash) - test.checkErr(t, err) - if err == nil { - require.Equal(t, result.EditionID, test.result.EditionID) - require.Equal(t, result.OldHash, test.result.OldHash) - require.Equal(t, result.NewHash, test.result.NewHash) - require.Equal(t, result.ModifiedAt, test.result.ModifiedAt) - - if test.result.reader != nil && result.reader != nil { - defer result.reader.Close() - defer test.result.reader.Close() - resultDatabase, err := io.ReadAll(test.result.reader) - require.NoError(t, err) - expectedDatabase, err := io.ReadAll(result.reader) - require.NoError(t, err) - require.Equal(t, expectedDatabase, resultDatabase) - } - } - }) - } -} - -//nolint:unparam // complains that it always receives the same string to encode. ridiculous. -func getReader(t *testing.T, s string) io.ReadCloser { - var buf bytes.Buffer - gz := gzip.NewWriter(&buf) - _, err := gz.Write([]byte(s)) - require.NoError(t, err) - require.NoError(t, gz.Close()) - require.NoError(t, gz.Flush()) - r, err := gzip.NewReader(&buf) - require.NoError(t, err) - return r -} diff --git a/pkg/geoipupdate/database/local_file_writer_test.go b/pkg/geoipupdate/database/local_file_writer_test.go deleted file mode 100644 index c4ea63c4..00000000 --- a/pkg/geoipupdate/database/local_file_writer_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package database - -import ( - "os" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -// TestLocalFileWriterWrite tests functionality of the LocalFileWriter.Write method. 
-func TestLocalFileWriterWrite(t *testing.T) { - testTime := time.Date(2023, 4, 10, 12, 47, 31, 0, time.UTC) - - tests := []struct { - description string - //nolint:revive // support older versions - checkErr func(require.TestingT, error, ...interface{}) - preserveFileTime bool - //nolint:revive // support older versions - checkTime func(require.TestingT, interface{}, interface{}, ...interface{}) - result *ReadResult - }{ - { - description: "success", - checkErr: require.NoError, - preserveFileTime: true, - checkTime: require.Equal, - result: &ReadResult{ - reader: getReader(t, "database content"), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - ModifiedAt: testTime, - }, - }, { - description: "hash does not match", - checkErr: require.Error, - preserveFileTime: true, - checkTime: require.Equal, - result: &ReadResult{ - reader: getReader(t, "database content"), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "badhash", - ModifiedAt: testTime, - }, - }, { - description: "hash case does not matter", - checkErr: require.NoError, - preserveFileTime: true, - checkTime: require.Equal, - result: &ReadResult{ - reader: getReader(t, "database content"), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - ModifiedAt: testTime, - }, - }, { - description: "do not preserve file modification time", - checkErr: require.NoError, - preserveFileTime: false, - checkTime: require.NotEqual, - result: &ReadResult{ - reader: getReader(t, "database content"), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "CFA36DDC8279B5483A5AA25E9A6151F4", - ModifiedAt: testTime, - }, - }, - } - - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - tempDir := t.TempDir() - defer test.result.reader.Close() - - fw, err := NewLocalFileWriter(tempDir, test.preserveFileTime, false) - require.NoError(t, err) - - err = fw.Write(test.result) - test.checkErr(t, err) - if err == nil { - 
database, err := os.Stat(fw.getFilePath(test.result.EditionID)) - require.NoError(t, err) - - test.checkTime(t, database.ModTime().UTC(), testTime) - } - }) - } -} - -// TestLocalFileWriterGetHash tests functionality of the LocalFileWriter.GetHash method. -func TestLocalFileWriterGetHash(t *testing.T) { - result := &ReadResult{ - reader: getReader(t, "database content"), - EditionID: "GeoIP2-City", - OldHash: "", - NewHash: "cfa36ddc8279b5483a5aa25e9a6151f4", - ModifiedAt: time.Time{}, - } - - tempDir := t.TempDir() - - defer result.reader.Close() - - fw, err := NewLocalFileWriter(tempDir, false, false) - require.NoError(t, err) - - err = fw.Write(result) - require.NoError(t, err) - - // returns the correct hash for an existing database. - hash, err := fw.GetHash(result.EditionID) - require.NoError(t, err) - require.Equal(t, hash, result.NewHash) - - // returns a zero hash for a non existing edition. - hash, err = fw.GetHash("NewEdition") - require.NoError(t, err) - require.Equal(t, ZeroMD5, hash) -} diff --git a/pkg/geoipupdate/database/reader.go b/pkg/geoipupdate/database/reader.go deleted file mode 100644 index 17f49353..00000000 --- a/pkg/geoipupdate/database/reader.go +++ /dev/null @@ -1,76 +0,0 @@ -package database - -import ( - "context" - "encoding/json" - "fmt" - "io" - "time" -) - -// Reader provides an interface for retrieving a database update and copying it -// into place. -type Reader interface { - Read(context.Context, string, string) (*ReadResult, error) -} - -// ReadResult is the struct returned by a Reader's Get method. -type ReadResult struct { - reader io.ReadCloser - EditionID string `json:"edition_id"` - OldHash string `json:"old_hash"` - NewHash string `json:"new_hash"` - ModifiedAt time.Time `json:"modified_at"` - CheckedAt time.Time `json:"checked_at"` -} - -// MarshalJSON is a custom json marshaler that strips out zero time fields. 
-func (r ReadResult) MarshalJSON() ([]byte, error) { - type partialResult ReadResult - s := &struct { - partialResult - ModifiedAt int64 `json:"modified_at,omitempty"` - CheckedAt int64 `json:"checked_at,omitempty"` - }{ - partialResult: partialResult(r), - ModifiedAt: 0, - CheckedAt: 0, - } - - if !r.ModifiedAt.IsZero() { - s.ModifiedAt = r.ModifiedAt.Unix() - } - - if !r.CheckedAt.IsZero() { - s.CheckedAt = r.CheckedAt.Unix() - } - - res, err := json.Marshal(s) - if err != nil { - return nil, fmt.Errorf("marshaling ReadResult: %w", err) - } - return res, nil -} - -// UnmarshalJSON is a custom json unmarshaler that converts timestamps to go -// time fields. -func (r *ReadResult) UnmarshalJSON(data []byte) error { - type partialResult ReadResult - s := &struct { - partialResult - ModifiedAt int64 `json:"modified_at,omitempty"` - CheckedAt int64 `json:"checked_at,omitempty"` - }{} - - err := json.Unmarshal(data, &s) - if err != nil { - return fmt.Errorf("unmarshaling json into ReadResult: %w", err) - } - - result := ReadResult(s.partialResult) - result.ModifiedAt = time.Unix(s.ModifiedAt, 0).In(time.UTC) - result.CheckedAt = time.Unix(s.CheckedAt, 0).In(time.UTC) - *r = result - - return nil -} diff --git a/pkg/geoipupdate/database/writer.go b/pkg/geoipupdate/database/writer.go deleted file mode 100644 index e5f40158..00000000 --- a/pkg/geoipupdate/database/writer.go +++ /dev/null @@ -1,11 +0,0 @@ -package database - -// ZeroMD5 is the default value provided as an MD5 hash for a non-existent -// database. -const ZeroMD5 = "00000000000000000000000000000000" - -// Writer provides an interface for writing a database to a target location. 
-type Writer interface { - Write(*ReadResult) error - GetHash(editionID string) (string, error) -} diff --git a/pkg/geoipupdate/download/download.go b/pkg/geoipupdate/download/download.go new file mode 100644 index 00000000..ac110f97 --- /dev/null +++ b/pkg/geoipupdate/download/download.go @@ -0,0 +1,149 @@ +// Package download provides a library for checking/downloading/updating mmdb files. +package download + +import ( + "context" + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "path/filepath" + "time" +) + +const ( + Extension = ".mmdb" + + // zeroMD5 is the default value provided as an MD5 hash for a non-existent + // database. + zeroMD5 = "00000000000000000000000000000000" +) + +type Downloader interface { + GetOutdatedEditions(ctx context.Context) ([]Metadata, error) + DownloadEdition(ctx context.Context, edition Metadata) error + MakeOutput() ([]byte, error) +} + +// Download exposes methods needed to check for and perform update to a set of mmdb editions. +type Download struct { + // accountID is the requester's account ID. + accountID int + // client is an http client responsible of fetching database updates. + client *http.Client + // databaseDir is the database download path. + databaseDir string + // editionIDs is the list of editions to be updated. + editionIDs []string + // licenseKey is the requester's license key. + licenseKey string + // oldEditionsHash holds the hashes of the previously downloaded mmdb editions. + oldEditionsHash map[string]string + // metadata holds the metadata pulled for each edition. + metadata []Metadata + // preserveFileTimes sets whether database modification times are preserved across downloads. + preserveFileTimes bool + // url points to maxmind servers. + url string + + now func() time.Time + verbose bool +} + +// New initializes a new Downloader struct. 
+func New( + accountID int, + licenseKey string, + url string, + proxy *url.URL, + databaseDir string, + preserveFileTimes bool, + editionIDs []string, + verbose bool, +) (*Download, error) { + transport := http.DefaultTransport + if proxy != nil { + proxyFunc := http.ProxyURL(proxy) + transport.(*http.Transport).Proxy = proxyFunc + } + + d := Download{ + accountID: accountID, + client: &http.Client{Transport: transport}, + databaseDir: databaseDir, + editionIDs: editionIDs, + licenseKey: licenseKey, + oldEditionsHash: map[string]string{}, + preserveFileTimes: preserveFileTimes, + url: url, + now: time.Now, + verbose: verbose, + } + + for _, e := range editionIDs { + hash, err := d.getHash(e) + if err != nil { + return nil, fmt.Errorf("getting existing %q database hash: %w", e, err) + } + d.oldEditionsHash[e] = hash + } + + return &d, nil +} + +// getHash returns the hash of a certain database file. +func (d *Download) getHash(editionID string) (string, error) { + databaseFilePath := d.getFilePath(editionID) + //nolint:gosec // we really need to read this file. + database, err := os.Open(databaseFilePath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + if d.verbose { + log.Print("Database does not exist, returning zeroed hash") + } + return zeroMD5, nil + } + return "", fmt.Errorf("opening database: %w", err) + } + + defer func() { + if err := database.Close(); err != nil { + log.Println(fmt.Errorf("closing database: %w", err)) + } + }() + + md5Hash := md5.New() + if _, err := io.Copy(md5Hash, database); err != nil { + return "", fmt.Errorf("calculating database hash: %w", err) + } + + result := byteToString(md5Hash.Sum(nil)) + if d.verbose { + log.Printf("Calculated MD5 sum for %s: %s", databaseFilePath, result) + } + return result, nil +} + +// getFilePath construct the file path for a database edition. 
+func (d *Download) getFilePath(editionID string) string { + return filepath.Join(d.databaseDir, editionID) + Extension +} + +// byteToString returns the base16 representation of a byte array. +func byteToString(b []byte) string { + return hex.EncodeToString(b) +} + +// ParseTime parses the date returned in a metadata response to time.Time. +func ParseTime(dateString string) (time.Time, error) { + t, err := time.ParseInLocation("2006-01-02", dateString, time.Local) + if err != nil { + return time.Time{}, fmt.Errorf("parsing edition date: %w", err) + } + return t, nil +} diff --git a/pkg/geoipupdate/download/download_test.go b/pkg/geoipupdate/download/download_test.go new file mode 100644 index 00000000..4b4f5642 --- /dev/null +++ b/pkg/geoipupdate/download/download_test.go @@ -0,0 +1,36 @@ +package download + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestGetHash tests functionality of the LocalFileWriter.GetHash method. +func TestGetHash(t *testing.T) { + tempDir, err := os.MkdirTemp("", "db") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + d := Download{ + databaseDir: tempDir, + } + + // returns a zero hash for a non existing edition. + hash, err := d.getHash("NewEdition") + require.NoError(t, err) + require.Equal(t, zeroMD5, hash) + + // returns the correct md5 for an existing edition. 
+ edition := "edition-1" + dbFile := filepath.Join(tempDir, edition+Extension) + + err = os.WriteFile(dbFile, []byte("edition-1 content"), os.ModePerm) + require.NoError(t, err) + + hash, err = d.getHash(edition) + require.NoError(t, err) + require.Equal(t, "618dd27a10de24809ec160d6807f363f", hash) +} diff --git a/pkg/geoipupdate/download/metadata.go b/pkg/geoipupdate/download/metadata.go new file mode 100644 index 00000000..cf24f40b --- /dev/null +++ b/pkg/geoipupdate/download/metadata.go @@ -0,0 +1,98 @@ +package download + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars" +) + +const ( + metadataEndpoint = "%s/geoip/updates/metadata?%s" +) + +// metadataResponse represents a successful response returned by the metadata endpoint. +type metadataResponse struct { + Databases []Metadata `json:"databases"` +} + +// databaseResponse represents the metadata content for a certain database returned by the +// metadata endpoint. +type Metadata struct { + Date string `json:"date"` + EditionID string `json:"edition_id"` + MD5 string `json:"md5"` +} + +// GetOutdatedEditions returns the list of outdated database editions. 
+func (d *Download) GetOutdatedEditions(ctx context.Context) ([]Metadata, error) { + var editionsQuery []string + for _, e := range d.editionIDs { + editionsQuery = append(editionsQuery, "edition_id="+url.QueryEscape(e)) + } + + requestURL := fmt.Sprintf(metadataEndpoint, d.url, strings.Join(editionsQuery, "&")) + if d.verbose { + log.Printf("Requesting edition metadata: %s", requestURL) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) + req.SetBasicAuth(strconv.Itoa(d.accountID), d.licenseKey) + + response, err := d.client.Do(req) + if err != nil { + return nil, fmt.Errorf("performing HTTP request: %w", err) + } + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("reading response body: %w", err) + } + + if response.StatusCode != http.StatusOK { + errResponse := internal.ResponseError{ + StatusCode: response.StatusCode, + } + + if err := json.Unmarshal(responseBody, &errResponse); err != nil { + errResponse.Message = err.Error() + } + + return nil, fmt.Errorf("requesting metadata: %w", errResponse) + } + + var metadata metadataResponse + if err := json.Unmarshal(responseBody, &metadata); err != nil { + return nil, fmt.Errorf("parsing body: %w", err) + } + + var outdatedEditions []Metadata + for _, m := range metadata.Databases { + oldMD5 := d.oldEditionsHash[m.EditionID] + if oldMD5 != m.MD5 { + outdatedEditions = append(outdatedEditions, m) + continue + } + + if d.verbose { + log.Printf("Database %s up to date", m.EditionID) + } + } + + d.metadata = metadata.Databases + + return outdatedEditions, nil +} diff --git a/pkg/geoipupdate/download/metadata_test.go b/pkg/geoipupdate/download/metadata_test.go new file mode 100644 index 00000000..82e297e9 --- /dev/null +++ b/pkg/geoipupdate/download/metadata_test.go @@ -0,0 +1,99 
@@ +package download + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestGetOutdatedEditions checks the metadata fetching functionality. +func TestGetOutdatedEditions(t *testing.T) { + tempDir, err := os.MkdirTemp("", "db") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + edition := "edition-1" + dbFile := filepath.Join(tempDir, edition+Extension) + // equivalent MD5: 618dd27a10de24809ec160d6807f363f + err = os.WriteFile(dbFile, []byte("edition-1 content"), os.ModePerm) + require.NoError(t, err) + + edition = "edition-2" + dbFile = filepath.Join(tempDir, edition+Extension) + err = os.WriteFile(dbFile, []byte("edition-2 content"), os.ModePerm) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonData := ` +{ + "databases": [ + { + "edition_id": "edition-1", + "md5": "618dd27a10de24809ec160d6807f363f", + "date": "2024-02-23" + }, + { + "edition_id": "edition-2", + "md5": "abc123", + "date": "2024-02-23" + }, + { + "edition_id": "edition-3", + "md5": "def456", + "date": "2024-02-02" + } + ] +} +` + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(jsonData)) + })) + defer server.Close() + + ctx := context.Background() + d, err := New( + 0, // accountID is not relevant for this test. + "", // licenseKey is not relevant for this test. + server.URL, + nil, // proxy is not relevant for this test. + tempDir, + false, // preserveFileTimes is not relevant for this test. + []string{"edition-1", "edition-2", "edition-3"}, + false, // verbose is not relevant for this test. 
+ ) + require.NoError(t, err) + + // edition-1 md5 hasn't changed + expectedOutdatedEditions := []Metadata{ + { + EditionID: "edition-2", + MD5: "abc123", + Date: "2024-02-23", + }, + { + EditionID: "edition-3", + MD5: "def456", + Date: "2024-02-02", + }, + } + + outdatedEditions, err := d.GetOutdatedEditions(ctx) + require.NoError(t, err) + require.ElementsMatch(t, expectedOutdatedEditions, outdatedEditions) + + expectedDatabases := append( + expectedOutdatedEditions, + Metadata{ + EditionID: "edition-1", + MD5: "618dd27a10de24809ec160d6807f363f", + Date: "2024-02-23", + }, + ) + require.ElementsMatch(t, expectedDatabases, d.metadata) +} diff --git a/pkg/geoipupdate/download/output.go b/pkg/geoipupdate/download/output.go new file mode 100644 index 00000000..5a8424e5 --- /dev/null +++ b/pkg/geoipupdate/download/output.go @@ -0,0 +1,40 @@ +package download + +import ( + "encoding/json" + "fmt" +) + +// ReadResult is the struct returned by a Reader's Get method. +type output struct { + EditionID string `json:"edition_id"` + OldHash string `json:"old_hash"` + NewHash string `json:"new_hash"` + ModifiedAt int64 `json:"modified_at"` + CheckedAt int64 `json:"checked_at"` +} + +// MakeOutput returns a json formatted summary about the current download attempt. 
+func (d *Download) MakeOutput() ([]byte, error) { + var out []output + now := d.now() + for _, m := range d.metadata { + modifiedAt, err := ParseTime(m.Date) + if err != nil { + return nil, err + } + o := output{ + EditionID: m.EditionID, + OldHash: d.oldEditionsHash[m.EditionID], + NewHash: m.MD5, + ModifiedAt: modifiedAt.Unix(), + CheckedAt: now.Unix(), + } + out = append(out, o) + } + res, err := json.Marshal(out) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + return res, nil +} diff --git a/pkg/geoipupdate/download/output_test.go b/pkg/geoipupdate/download/output_test.go new file mode 100644 index 00000000..96f786e3 --- /dev/null +++ b/pkg/geoipupdate/download/output_test.go @@ -0,0 +1,39 @@ +package download + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOutputFormat(t *testing.T) { + now, err := ParseTime("2024-02-23") + require.NoError(t, err) + + d := Download{ + oldEditionsHash: map[string]string{ + "edition-1": "1", + "edition-2": "10", + }, + metadata: []Metadata{ + { + EditionID: "edition-1", + MD5: "1", + Date: "2024-01-01", + }, + { + EditionID: "edition-2", + MD5: "11", + Date: "2024-02-01", + }, + }, + now: func() time.Time { return now }, + } + + expectedOutput := `[{"edition_id":"edition-1","old_hash":"1","new_hash":"1","modified_at":1704067200,"checked_at":1708646400},{"edition_id":"edition-2","old_hash":"10","new_hash":"11","modified_at":1706745600,"checked_at":1708646400}]` + + output, err := d.MakeOutput() + require.NoError(t, err) + require.Equal(t, expectedOutput, string(output)) +} diff --git a/pkg/geoipupdate/database/local_file_writer.go b/pkg/geoipupdate/download/writer.go similarity index 54% rename from pkg/geoipupdate/database/local_file_writer.go rename to pkg/geoipupdate/download/writer.go index d76043e0..d634ab8b 100644 --- a/pkg/geoipupdate/database/local_file_writer.go +++ b/pkg/geoipupdate/download/writer.go @@ -1,75 +1,75 @@ -package database +package 
download import ( + "archive/tar" + "compress/gzip" + "context" "crypto/md5" - "encoding/hex" + "encoding/json" "errors" "fmt" "hash" "io" + "io/ioutil" "log" + "net/http" "os" "path/filepath" + "strconv" "strings" - "time" + + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars" ) const ( - extension = ".mmdb" - tempExtension = ".temporary" + downloadEndpoint = "%s/geoip/databases/%s/download?date=%s&suffix=tar.gz" ) -// LocalFileWriter is a database.Writer that stores the database to the -// local file system. -type LocalFileWriter struct { - dir string - preserveFileTime bool - verbose bool -} +// DownloadEdition downloads and writes an edition to a database file. +func (d *Download) DownloadEdition(ctx context.Context, edition Metadata) error { + date := strings.ReplaceAll(edition.Date, "-", "") + requestURL := fmt.Sprintf(downloadEndpoint, d.url, edition.EditionID, date) + if d.verbose { + log.Printf("Downloading: %s", requestURL) + } -// NewLocalFileWriter create a LocalFileWriter. -func NewLocalFileWriter( - databaseDir string, - preserveFileTime bool, - verbose bool, -) (*LocalFileWriter, error) { - err := os.MkdirAll(filepath.Dir(databaseDir), 0o750) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) if err != nil { - return nil, fmt.Errorf("creating database directory: %w", err) + return fmt.Errorf("creating request: %w", err) } + req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) + req.SetBasicAuth(strconv.Itoa(d.accountID), d.licenseKey) - return &LocalFileWriter{ - dir: databaseDir, - preserveFileTime: preserveFileTime, - verbose: verbose, - }, nil -} + response, err := d.client.Do(req) + if err != nil { + return fmt.Errorf("performing HTTP request: %w", err) + } + defer response.Body.Close() -// Write writes the result struct returned by a Reader to a database file. 
-func (w *LocalFileWriter) Write(result *ReadResult) (err error) { - // exit early if we've got the latest database version. - if strings.EqualFold(result.OldHash, result.NewHash) { - if w.verbose { - log.Printf("Database %s up to date", result.EditionID) + if response.StatusCode != http.StatusOK { + responseBody, err := ioutil.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("reading response body: %w", err) } - return nil - } - defer func() { - if closeErr := result.reader.Close(); closeErr != nil { - err = errors.Join( - err, - fmt.Errorf("closing reader for %s: %w", result.EditionID, closeErr), - ) + errResponse := internal.ResponseError{ + StatusCode: response.StatusCode, + } + + if err := json.Unmarshal(responseBody, &errResponse); err != nil { + errResponse.Message = err.Error() } - }() - databaseFilePath := w.getFilePath(result.EditionID) + return fmt.Errorf("requesting edition: %w", errResponse) + } + + databaseFilePath := d.getFilePath(edition.EditionID) - // write the Reader's result into a temporary file. - fw, err := newFileWriter(databaseFilePath + tempExtension) + // write the result into a temporary file. 
+ fw, err := newFileWriter(databaseFilePath + ".temporary") if err != nil { - return fmt.Errorf("setting up database writer for %s: %w", result.EditionID, err) + return fmt.Errorf("setting up database writer for %s: %w", edition.EditionID, err) } defer func() { if closeErr := fw.close(); closeErr != nil { @@ -80,13 +80,36 @@ func (w *LocalFileWriter) Write(result *ReadResult) (err error) { } }() - if err = fw.write(result.reader); err != nil { - return fmt.Errorf("writing to the temp file for %s: %w", result.EditionID, err) + gzReader, err := gzip.NewReader(response.Body) + if err != nil { + return fmt.Errorf("encountered an error creating GZIP reader: %w", err) + } + defer gzReader.Close() + + tarReader := tar.NewReader(gzReader) + + // iterate through the tar archive to extract the mmdb file + for { + header, err := tarReader.Next() + if err == io.EOF { + return errors.New("tar archive does not contain an mmdb file") + } + if err != nil { + return fmt.Errorf("reading tar archive: %w", err) + } + + if strings.HasSuffix(header.Name, Extension) { + break + } + } + + if err = fw.write(tarReader); err != nil { + return fmt.Errorf("writing to the temp file for %s: %w", edition.EditionID, err) } // make sure the hash of the temp file matches the expected hash. 
- if err = fw.validateHash(result.NewHash); err != nil { - return fmt.Errorf("validating hash for %s: %w", result.EditionID, err) + if err = fw.validateHash(edition.MD5); err != nil { + return fmt.Errorf("validating hash for %s: %w", edition.EditionID, err) } // move the temoporary database file into it's final location and @@ -101,57 +124,19 @@ func (w *LocalFileWriter) Write(result *ReadResult) (err error) { } // check if we need to set the file's modified at time - if w.preserveFileTime { - if err = setModifiedAtTime(databaseFilePath, result.ModifiedAt); err != nil { + if d.preserveFileTimes { + if err = setModifiedAtTime(databaseFilePath, edition.Date); err != nil { return err } } - if w.verbose { - log.Printf("Database %s successfully updated: %+v", result.EditionID, result.NewHash) + if d.verbose { + log.Printf("Database %s successfully updated: %+v", edition.EditionID, edition.MD5) } return nil } -// GetHash returns the hash of the current database file. -func (w *LocalFileWriter) GetHash(editionID string) (string, error) { - databaseFilePath := w.getFilePath(editionID) - //nolint:gosec // we really need to read this file. - database, err := os.Open(databaseFilePath) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - if w.verbose { - log.Print("Database does not exist, returning zeroed hash") - } - return ZeroMD5, nil - } - return "", fmt.Errorf("opening database: %w", err) - } - - defer func() { - if err := database.Close(); err != nil { - log.Println(fmt.Errorf("closing database: %w", err)) - } - }() - - md5Hash := md5.New() - if _, err := io.Copy(md5Hash, database); err != nil { - return "", fmt.Errorf("calculating database hash: %w", err) - } - - result := byteToString(md5Hash.Sum(nil)) - if w.verbose { - log.Printf("Calculated MD5 sum for %s: %s", databaseFilePath, result) - } - return result, nil -} - -// getFilePath construct the file path for a database edition. 
-func (w *LocalFileWriter) getFilePath(editionID string) string { - return filepath.Join(w.dir, editionID) + extension -} - // fileWriter is used to write the content of a Reader's response // into a file. type fileWriter struct { @@ -247,14 +232,14 @@ func syncDir(path string) error { } // setModifiedAtTime sets the times for a database file to a certain value. -func setModifiedAtTime(path string, t time.Time) error { - if err := os.Chtimes(path, t, t); err != nil { +func setModifiedAtTime(path string, dateString string) error { + releaseDate, err := ParseTime(dateString) + if err != nil { + return err + } + + if err := os.Chtimes(path, releaseDate, releaseDate); err != nil { return fmt.Errorf("setting times on file %s: %w", path, err) } return nil } - -// byteToString returns the base16 representation of a byte array. -func byteToString(b []byte) string { - return hex.EncodeToString(b) -} diff --git a/pkg/geoipupdate/download/writer_test.go b/pkg/geoipupdate/download/writer_test.go new file mode 100644 index 00000000..ffd12d76 --- /dev/null +++ b/pkg/geoipupdate/download/writer_test.go @@ -0,0 +1,278 @@ +package download + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// TestDownloadEdition checks the database download functionality. 
+func TestDownloadEdition(t *testing.T) { + tempDir, err := os.MkdirTemp("", "db") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + ctx := context.Background() + now, err := ParseTime("2024-02-23") + require.NoError(t, err) + + edition := Metadata{ + EditionID: "edition-1", + Date: "2024-02-02", + MD5: "618dd27a10de24809ec160d6807f363f", + } + + dbContent := "edition-1 content" + + tests := []struct { + description string + preserveFileTime bool + server func(t *testing.T) *httptest.Server + checkResult func(t *testing.T, err error) + }{ + { + description: "successful download", + preserveFileTime: false, + server: func(t *testing.T) *httptest.Server { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + header := &tar.Header{ + Name: "edition-1" + Extension, + Size: int64(len(dbContent)), + } + + err = tw.WriteHeader(header) + require.NoError(t, err) + _, err = tw.Write([]byte(dbContent)) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") + _, err := io.Copy(w, &buf) + require.NoError(t, err) + })) + + return server + }, + checkResult: func(t *testing.T, err error) { + require.NoError(t, err) + + dbFile := filepath.Join(tempDir, edition.EditionID+Extension) + + fileContent, err := os.ReadFile(dbFile) + require.NoError(t, err) + require.Equal(t, dbContent, string(fileContent)) + + database, err := os.Stat(dbFile) + require.NoError(t, err) + require.GreaterOrEqual(t, database.ModTime(), now) + }, + }, + { + description: "successful download - preserve time", + preserveFileTime: true, + server: func(t *testing.T) *httptest.Server { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + header := &tar.Header{ + Name: "edition-1" + 
Extension, + Size: int64(len(dbContent)), + } + + err = tw.WriteHeader(header) + require.NoError(t, err) + _, err = tw.Write([]byte(dbContent)) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") + _, err := io.Copy(w, &buf) + require.NoError(t, err) + })) + + return server + }, + checkResult: func(t *testing.T, err error) { + require.NoError(t, err) + + dbFile := filepath.Join(tempDir, edition.EditionID+Extension) + + fileContent, err := os.ReadFile(dbFile) + require.NoError(t, err) + require.Equal(t, dbContent, string(fileContent)) + + modTime, err := ParseTime(edition.Date) + require.NoError(t, err) + + database, err := os.Stat(dbFile) + require.NoError(t, err) + require.Equal(t, modTime, database.ModTime()) + }, + }, + { + description: "server error", + preserveFileTime: false, + server: func(t *testing.T) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + return server + }, + checkResult: func(t *testing.T, err error) { + require.Error(t, err) + require.Regexp(t, "^requesting edition: received: HTTP status code '500'", err.Error()) + }, + }, + { + description: "wrong file format", + preserveFileTime: false, + server: func(t *testing.T) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonData := `{"message": "Hello, world!", "status": "ok"}` + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(jsonData)) + })) + return server + }, + checkResult: func(t *testing.T, err error) { + require.Error(t, err) + require.Regexp(t, "^encountered an error creating GZIP reader: 
gzip: invalid header", err.Error()) + }, + }, + { + description: "empty tar archive", + preserveFileTime: false, + server: func(t *testing.T) *httptest.Server { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") + _, err := io.Copy(w, &buf) + require.NoError(t, err) + })) + + return server + }, + checkResult: func(t *testing.T, err error) { + require.Error(t, err) + require.Regexp(t, "^tar archive does not contain an mmdb file", err.Error()) + }, + }, + { + description: "tar does not contain an mmdb file", + preserveFileTime: false, + server: func(t *testing.T) *httptest.Server { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + header := &tar.Header{ + Name: "edition-1.zip", + Size: int64(len(dbContent)), + } + + err = tw.WriteHeader(header) + require.NoError(t, err) + _, err = tw.Write([]byte(dbContent)) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") + _, err := io.Copy(w, &buf) + require.NoError(t, err) + })) + + return server + }, + checkResult: func(t *testing.T, err error) { + require.Error(t, err) + require.Regexp(t, "^tar archive does not contain an mmdb file", err.Error()) + }, + }, + { + description: "mmdb hash does not match metadata", + preserveFileTime: false, + server: func(t *testing.T) *httptest.Server { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + header := &tar.Header{ + Name: "edition-1" + 
Extension, + Size: int64(len(dbContent) - 1), + } + + err = tw.WriteHeader(header) + require.NoError(t, err) + _, err = tw.Write([]byte(dbContent[:len(dbContent)-1])) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") + _, err := io.Copy(w, &buf) + require.NoError(t, err) + })) + + return server + }, + checkResult: func(t *testing.T, err error) { + require.Error(t, err) + require.Regexp(t, "^validating hash for edition-1: md5 of new database .* does not match expected md5", err.Error()) + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + server := test.server(t) + defer server.Close() + + d := Download{ + client: http.DefaultClient, + databaseDir: tempDir, + now: func() time.Time { return now }, + preserveFileTimes: test.preserveFileTime, + url: server.URL, + } + + err = d.DownloadEdition(ctx, edition) + test.checkResult(t, err) + }) + } +} diff --git a/pkg/geoipupdate/geoip_updater.go b/pkg/geoipupdate/geoip_updater.go index 88f11754..fe962c86 100644 --- a/pkg/geoipupdate/geoip_updater.go +++ b/pkg/geoipupdate/geoip_updater.go @@ -4,54 +4,46 @@ package geoipupdate import ( "context" - "encoding/json" "fmt" "log" "os" - "sync" "time" "github.com/cenkalti/backoff/v4" - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/database" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/download" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" ) // Client uses config data to initiate a download or update // process for GeoIP databases. 
type Client struct { - config *Config - getReader func() (database.Reader, error) - getWriter func() (database.Writer, error) - output *log.Logger + config *Config + downloader download.Downloader + output *log.Logger } // NewClient initialized a new Client struct. -func NewClient(config *Config) *Client { - getReader := func() (database.Reader, error) { - return database.NewHTTPReader( - config.Proxy, - config.URL, - config.AccountID, - config.LicenseKey, - config.Verbose, - ), nil - } - - getWriter := func() (database.Writer, error) { - return database.NewLocalFileWriter( - config.DatabaseDirectory, - config.PreserveFileTimes, - config.Verbose, - ) +func NewClient(config *Config) (*Client, error) { + d, err := download.New( + config.AccountID, + config.LicenseKey, + config.URL, + config.Proxy, + config.DatabaseDirectory, + config.PreserveFileTimes, + config.EditionIDs, + config.Verbose, + ) + if err != nil { + return nil, fmt.Errorf("initializing database downloader: %w", err) } return &Client{ - config: config, - getReader: getReader, - getWriter: getWriter, - output: log.New(os.Stdout, "", 0), - } + config: config, + downloader: d, + output: log.New(os.Stdout, "", 0), + }, nil } // Run starts the download or update process. 
@@ -69,36 +61,40 @@ func (c *Client) Run(ctx context.Context) error { } }() - jobProcessor := internal.NewJobProcessor(ctx, c.config.Parallelism) + var outdatedEditions []download.Metadata + getOutdatedEditions := func() (err error) { + outdatedEditions, err = c.downloader.GetOutdatedEditions(ctx) + if err != nil { + return fmt.Errorf("getting outdated database editions: %w", err) + } + return nil + } - reader, err := c.getReader() - if err != nil { - return fmt.Errorf("initializing database reader: %w", err) + if err := c.retry(ctx, getOutdatedEditions, "Couldn't get download metadata"); err != nil { + return fmt.Errorf("getting download metadata: %w", err) } - writer, err := c.getWriter() - if err != nil { - return fmt.Errorf("initializing database writer: %w", err) + downloadEdition := func(edition download.Metadata) error { + if err := c.downloader.DownloadEdition(ctx, edition); err != nil { + return fmt.Errorf("downloading edition '%s': %w", edition.EditionID, err) + } + return nil } - var editions []database.ReadResult - var mu sync.Mutex - for _, editionID := range c.config.EditionIDs { - editionID := editionID + jobProcessor := internal.NewJobProcessor(ctx, c.config.Parallelism) + for _, edition := range outdatedEditions { + edition := edition processFunc := func(ctx context.Context) error { - edition, err := c.downloadEdition(ctx, editionID, reader, writer) + err := c.retry( + ctx, + func() error { return downloadEdition(edition) }, + fmt.Sprintf("Couldn't download %s", edition.EditionID), + ) if err != nil { return err } - - edition.CheckedAt = time.Now().In(time.UTC) - - mu.Lock() - editions = append(editions, *edition) - mu.Unlock() return nil } - jobProcessor.Add(processFunc) } @@ -109,7 +105,7 @@ func (c *Client) Run(ctx context.Context) error { } if c.config.Output { - result, err := json.Marshal(editions) + result, err := c.downloader.MakeOutput() if err != nil { return fmt.Errorf("marshaling result log: %w", err) } @@ -119,18 +115,12 @@ func (c 
*Client) Run(ctx context.Context) error {
 	return nil
 }
 
-// downloadEdition downloads the file with retries.
-func (c *Client) downloadEdition(
+// retry implements retry functionality for downloads for non-permanent errors.
+func (c *Client) retry(
 	ctx context.Context,
-	editionID string,
-	r database.Reader,
-	w database.Writer,
-) (*database.ReadResult, error) {
-	editionHash, err := w.GetHash(editionID)
-	if err != nil {
-		return nil, err
-	}
-
+	f func() error,
+	logMsg string,
+) error {
 	// RetryFor value of 0 means that no retries should be performed.
 	// Max zero retries has to be set to achieve that
 	// because the backoff never stops if MaxElapsedTime is zero.
@@ -141,37 +131,21 @@ func (c *Client) downloadEdition(
 		b = backoff.WithMaxRetries(exp, 0)
 	}
 
-	var edition *database.ReadResult
-	err = backoff.RetryNotify(
+	return backoff.RetryNotify(
 		func() error {
-			if edition, err = r.Read(ctx, editionID, editionHash); err != nil {
+			if err := f(); err != nil {
 				if internal.IsPermanentError(err) {
 					return backoff.Permanent(err)
 				}
-
 				return err
 			}
-
-			if err = w.Write(edition); err != nil {
-				if internal.IsPermanentError(err) {
-					return backoff.Permanent(err)
-				}
-
-				return err
-			}
-
 			return nil
 		},
 		b,
 		func(err error, d time.Duration) {
 			if c.config.Verbose {
-				log.Printf("Couldn't download %s, retrying in %v: %v", editionID, d, err)
+				log.Printf("%s, retrying in %v: %v", logMsg, d, err)
 			}
 		},
 	)
-	if err != nil {
-		return nil, err
-	}
-
-	return edition, nil
 }
diff --git a/pkg/geoipupdate/geoip_updater_test.go b/pkg/geoipupdate/geoip_updater_test.go
index 876bb897..c0aec5ef 100644
--- a/pkg/geoipupdate/geoip_updater_test.go
+++ b/pkg/geoipupdate/geoip_updater_test.go
@@ -1,128 +1,162 @@
 package geoipupdate
 
 import (
+	"archive/tar"
 	"bytes"
+	"compress/gzip"
 	"context"
-	"encoding/json"
-	"errors"
-	"log"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
 	"path/filepath"
+	"strings"
 	"testing"
 	"time"
 
+	"github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/download"
"github.com/stretchr/testify/require" - "golang.org/x/net/http2" - - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/database" ) -// TestClientOutput makes sure that the client outputs the result of it's -// operation to stdout in json format. -func TestClientOutput(t *testing.T) { - now := time.Now().Truncate(time.Second).In(time.UTC) - testTime := time.Date(2023, 4, 27, 12, 4, 48, 0, time.UTC) - databases := []database.ReadResult{ - { - EditionID: "GeoLite2-City", - OldHash: "A", - NewHash: "B", - ModifiedAt: testTime, - }, { - EditionID: "GeoIP2-Country", - OldHash: "C", - NewHash: "D", - ModifiedAt: testTime, - }, - } - - tempDir := t.TempDir() - - config := &Config{ - EditionIDs: []string{"GeoLite2-City", "GeoLite2-Country"}, - LockFile: filepath.Join(tempDir, ".geoipupdate.lock"), - Output: true, - Parallelism: 1, - } +func TestFullDownload(t *testing.T) { + testDate := time.Now().Truncate(24 * time.Hour) - // capture the output of the `output` logger. - logOutput := &bytes.Buffer{} - - // create a fake client with a mocked database reader and writer. - c := &Client{ - config: config, - getReader: func() (database.Reader, error) { - return &mockReader{i: 0, result: databases}, nil - }, - getWriter: func() (database.Writer, error) { - return &mockWriter{}, nil - }, - output: log.New(logOutput, "", 0), - } + // mock existing databases. + tempDir, err := os.MkdirTemp("", "db") + require.NoError(t, err) + defer os.RemoveAll(tempDir) - // run the client - err := c.Run(context.Background()) + edition := "edition-1" + dbFile := filepath.Join(tempDir, edition+download.Extension) + // equivalent MD5: 618dd27a10de24809ec160d6807f363f + err = os.WriteFile(dbFile, []byte("edition-1 content"), os.ModePerm) require.NoError(t, err) - // make sure the expected output matches the input. 
- var outputDatabases []database.ReadResult - err = json.Unmarshal(logOutput.Bytes(), &outputDatabases) + edition = "edition-2" + dbFile = filepath.Join(tempDir, edition+download.Extension) + err = os.WriteFile(dbFile, []byte("edition-2 content"), os.ModePerm) require.NoError(t, err) - require.Equal(t, len(outputDatabases), len(databases)) - - for i := 0; i < len(databases); i++ { - require.Equal(t, databases[i].EditionID, outputDatabases[i].EditionID) - require.Equal(t, databases[i].OldHash, outputDatabases[i].OldHash) - require.Equal(t, databases[i].NewHash, outputDatabases[i].NewHash) - require.Equal(t, databases[i].ModifiedAt, outputDatabases[i].ModifiedAt) - // comparing time wasn't supported with require in older go versions. - if !afterOrEqual(outputDatabases[i].CheckedAt, now) { - t.Errorf("database %s was not updated", outputDatabases[i].EditionID) - } - } - streamErr := http2.StreamError{ - Code: http2.ErrCodeInternal, - } - c.getWriter = func() (database.Writer, error) { - w := mockWriter{ - WriteFunc: func(_ *database.ReadResult) error { - return streamErr - }, + // mock metadata handler. + metadataHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonData := ` +{ + "databases": [ + { + "edition_id": "edition-1", + "md5": "618dd27a10de24809ec160d6807f363f", + "date": "2024-02-23" + }, + { + "edition_id": "edition-2", + "md5": "9960e83daa34d69e9b58b375616e145b", + "date": "2024-02-23" + }, + { + "edition_id": "edition-3", + "md5": "08628247c1e8c1aa6d05ffc578fa09a8", + "date": "2024-02-02" + } + ] +} +` + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(jsonData)) + }) + + // mock download handler. + downloadHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + edition := strings.Split(r.URL.Path, "/")[3] // extract the edition-id. 
+ + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + content := "new " + edition + " content" + header := &tar.Header{ + Name: "edition-1" + download.Extension, + Size: int64(len(content)), } - return &w, nil - } - err = c.Run(context.Background()) - require.ErrorIs(t, err, streamErr) -} + err = tw.WriteHeader(header) + require.NoError(t, err) + _, err = tw.Write([]byte(content)) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + w.Header().Set("Content-Type", "application/gzip") + w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") + _, err := io.Copy(w, &buf) + require.NoError(t, err) + }) + + // create test server. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/geoip/updates/metadata") { + metadataHandler.ServeHTTP(w, r) + return + } -type mockReader struct { - i int - result []database.ReadResult -} + if strings.HasPrefix(r.URL.Path, "/geoip/databases") { + downloadHandler.ServeHTTP(w, r) + return + } -func (mr *mockReader) Read(_ context.Context, _, _ string) (*database.ReadResult, error) { - if mr.i >= len(mr.result) { - return nil, errors.New("out of bounds") + http.NotFound(w, r) + })) + defer server.Close() + + ctx := context.Background() + conf := &Config{ + AccountID: 0, // AccountID is not relevant for this test. + LicenseKey: "000000000001", // LicenseKey is not relevant for this test. 
+ DatabaseDirectory: tempDir, + EditionIDs: []string{"edition-1", "edition-2", "edition-3"}, + LockFile: filepath.Clean(filepath.Join(tempDir, ".geoipupdate.lock")), + URL: server.URL, + RetryFor: 0, + Parallelism: 1, + PreserveFileTimes: true, } - res := mr.result[mr.i] - mr.i++ - return &res, nil -} -type mockWriter struct { - WriteFunc func(*database.ReadResult) error -} + client, err := NewClient(conf) + require.NoError(t, err) -func (w *mockWriter) Write(r *database.ReadResult) error { - if w.WriteFunc != nil { - return w.WriteFunc(r) - } + // download updates. + err = client.Run(ctx) + require.NoError(t, err) - return nil -} -func (w mockWriter) GetHash(_ string) (string, error) { return "", nil } + // edition-1 file hasn't been modified. + dbFile = filepath.Join(tempDir, "edition-1"+download.Extension) + fileContent, err := os.ReadFile(dbFile) + require.NoError(t, err) + require.Equal(t, "edition-1 content", string(fileContent)) + database, err := os.Stat(dbFile) + require.NoError(t, err) + require.LessOrEqual(t, testDate, database.ModTime()) -func afterOrEqual(t1, t2 time.Time) bool { - return t1.After(t2) || t1.Equal(t2) + // edition-2 file has been updated. + dbFile = filepath.Join(tempDir, "edition-2"+download.Extension) + fileContent, err = os.ReadFile(dbFile) + require.NoError(t, err) + require.Equal(t, "new edition-2 content", string(fileContent)) + modTime, err := download.ParseTime("2024-02-23") + require.NoError(t, err) + database, err = os.Stat(dbFile) + require.NoError(t, err) + require.Equal(t, modTime, database.ModTime()) + + // edition-3 file has been downloaded. 
+ dbFile = filepath.Join(tempDir, "edition-3"+download.Extension) + fileContent, err = os.ReadFile(dbFile) + require.NoError(t, err) + require.Equal(t, "new edition-3 content", string(fileContent)) + modTime, err = download.ParseTime("2024-02-02") + require.NoError(t, err) + database, err = os.Stat(dbFile) + require.NoError(t, err) + require.Equal(t, modTime, database.ModTime()) } diff --git a/pkg/geoipupdate/internal/errors.go b/pkg/geoipupdate/internal/errors.go index 8d487c64..0fa5cae6 100644 --- a/pkg/geoipupdate/internal/errors.go +++ b/pkg/geoipupdate/internal/errors.go @@ -6,20 +6,21 @@ import ( "fmt" ) -// HTTPError is an error from performing an HTTP request. -type HTTPError struct { - Body string +// ResponseError represents an error response returned by the geoip servers. +type ResponseError struct { StatusCode int + Code string `json:"code"` + Message string `json:"error"` } -func (h HTTPError) Error() string { - return fmt.Sprintf("received HTTP status code: %d: %s", h.StatusCode, h.Body) +func (e ResponseError) Error() string { + return fmt.Sprintf("received: HTTP status code '%d' - Error code '%s' - Message '%s'", e.StatusCode, e.Code, e.Message) } // IsPermanentError returns true if the error is non-retriable. 
func IsPermanentError(err error) bool {
-	var httpErr HTTPError
-	if errors.As(err, &httpErr) && httpErr.StatusCode >= 400 && httpErr.StatusCode < 500 {
+	var r ResponseError
+	if errors.As(err, &r) && r.StatusCode >= 400 && r.StatusCode < 500 {
 		return true
 	}
 
diff --git a/pkg/geoipupdate/internal/errors_test.go b/pkg/geoipupdate/internal/errors_test.go
index f5ea2109..efa2ad46 100644
--- a/pkg/geoipupdate/internal/errors_test.go
+++ b/pkg/geoipupdate/internal/errors_test.go
@@ -19,13 +19,13 @@ func TestIsPermanentError(t *testing.T) {
 			want: false,
 		},
 		"bad gateway": {
-			err: HTTPError{
+			err: ResponseError{
 				StatusCode: http.StatusBadGateway,
 			},
 			want: false,
 		},
 		"bad request": {
-			err: HTTPError{
+			err: ResponseError{
 				StatusCode: http.StatusBadRequest,
 			},
 			want: true,

From 45d3e0efc2b2d74444dd9286fb9e339c682b2b8e Mon Sep 17 00:00:00 2001
From: Naji Obeid
Date: Mon, 26 Feb 2024 16:30:51 +0000
Subject: [PATCH 2/6] mention new download package in changelog

---
 CHANGELOG.md | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f1cba468..d7b2025d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,11 +4,14 @@
 
 * `geoipupdate` now supports retrying on more types of errors such as HTTP2
   INTERNAL_ERROR.
-* `HTTPReader` no longer retries on HTTP errors and therefore
-  `retryFor` was removed from `NewHTTPReader`.
 * Now `geoipupdate` doesn't requires the user to specify the config file
   even if all the other arguments are set via the environment variables.
   Reported by jsf84ksnf. GitHub #284.
+* `pkg/geoipupdate/database` has been rewritten into `pkg/geoipupdate/download`.
+  This new package changes the update behaviour of the tool:
+  1. It will first request edition information from a newly introduced metadata endpoint.
+  2. Compare the result with existing editions and decide which edition needs to be updated.
+  3. Then download individual editions from a newly introduced download endpoint.
## 6.1.0 (2024-01-09) From 2d903c1e9cf32aacadb59b3c51e88a4cbc6d9730 Mon Sep 17 00:00:00 2001 From: Naji Obeid Date: Mon, 26 Feb 2024 16:42:20 +0000 Subject: [PATCH 3/6] remove e2e test that has been rewritten in another package. --- cmd/geoipupdate/end_to_end_test.go | 95 --------------------------- pkg/geoipupdate/geoip_updater_test.go | 8 +++ 2 files changed, 8 insertions(+), 95 deletions(-) delete mode 100644 cmd/geoipupdate/end_to_end_test.go diff --git a/cmd/geoipupdate/end_to_end_test.go b/cmd/geoipupdate/end_to_end_test.go deleted file mode 100644 index d8994956..00000000 --- a/cmd/geoipupdate/end_to_end_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "bytes" - "compress/gzip" - "context" - "crypto/md5" - "encoding/hex" - "io" - "log" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMultipleDatabaseDownload(t *testing.T) { - databaseContent := "database content goes here" - - server := httptest.NewServer( - http.HandlerFunc( - func(rw http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - require.NoError(t, err, "parse form") - - if strings.HasPrefix(r.URL.Path, "/geoip/databases") { - buf := &bytes.Buffer{} - gzWriter := gzip.NewWriter(buf) - md5Writer := md5.New() - multiWriter := io.MultiWriter(gzWriter, md5Writer) - _, err := multiWriter.Write([]byte( - databaseContent + " " + r.URL.Path, - )) - require.NoError(t, err) - err = gzWriter.Close() - require.NoError(t, err) - - rw.Header().Set( - "X-Database-MD5", - hex.EncodeToString(md5Writer.Sum(nil)), - ) - rw.Header().Set("Last-Modified", time.Now().Format(time.RFC1123)) - - _, err = rw.Write(buf.Bytes()) - require.NoError(t, err) - - return - } - - rw.WriteHeader(http.StatusBadRequest) - }, - ), - ) - defer server.Close() - - tempDir := t.TempDir() - - config := &geoipupdate.Config{ - 
AccountID: 123, - DatabaseDirectory: tempDir, - EditionIDs: []string{"GeoLite2-City", "GeoLite2-Country"}, - LicenseKey: "testing", - LockFile: filepath.Join(tempDir, ".geoipupdate.lock"), - URL: server.URL, - Parallelism: 1, - } - - logOutput := &bytes.Buffer{} - log.SetOutput(logOutput) - - client := geoipupdate.NewClient(config) - err := client.Run(context.Background()) - require.NoError(t, err, "run successfully") - - assert.Equal(t, "", logOutput.String(), "no logged output") - - for _, editionID := range config.EditionIDs { - path := filepath.Join(config.DatabaseDirectory, editionID+".mmdb") - buf, err := os.ReadFile(filepath.Clean(path)) - require.NoError(t, err, "read file") - assert.Equal( - t, - databaseContent+" /geoip/databases/"+editionID+"/update", - string(buf), - "correct database", - ) - } -} diff --git a/pkg/geoipupdate/geoip_updater_test.go b/pkg/geoipupdate/geoip_updater_test.go index c0aec5ef..e7b20512 100644 --- a/pkg/geoipupdate/geoip_updater_test.go +++ b/pkg/geoipupdate/geoip_updater_test.go @@ -6,6 +6,7 @@ import ( "compress/gzip" "context" "io" + "log" "net/http" "net/http/httptest" "os" @@ -15,9 +16,11 @@ import ( "time" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/download" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// TestFullDownload runs an end to end test simulation. func TestFullDownload(t *testing.T) { testDate := time.Now().Truncate(24 * time.Hour) @@ -122,6 +125,9 @@ func TestFullDownload(t *testing.T) { PreserveFileTimes: true, } + logOutput := &bytes.Buffer{} + log.SetOutput(logOutput) + client, err := NewClient(conf) require.NoError(t, err) @@ -129,6 +135,8 @@ func TestFullDownload(t *testing.T) { err = client.Run(ctx) require.NoError(t, err) + assert.Equal(t, "", logOutput.String(), "no logged output") + // edition-1 file hasn't been modified. 
dbFile = filepath.Join(tempDir, "edition-1"+download.Extension) fileContent, err := os.ReadFile(dbFile) From 9180d729157382506711a9dfed924524831044b6 Mon Sep 17 00:00:00 2001 From: Naji Obeid Date: Mon, 26 Feb 2024 17:06:39 +0000 Subject: [PATCH 4/6] fix linting: --- pkg/geoipupdate/download/download.go | 7 +++++-- pkg/geoipupdate/download/metadata.go | 6 +++--- pkg/geoipupdate/download/metadata_test.go | 10 ++++++---- pkg/geoipupdate/download/output_test.go | 1 + pkg/geoipupdate/download/writer.go | 5 ++--- pkg/geoipupdate/download/writer_test.go | 22 +++++++++++++--------- pkg/geoipupdate/geoip_updater.go | 8 +++----- pkg/geoipupdate/geoip_updater_test.go | 14 +++++++++----- pkg/geoipupdate/internal/errors.go | 7 ++++++- 9 files changed, 48 insertions(+), 32 deletions(-) diff --git a/pkg/geoipupdate/download/download.go b/pkg/geoipupdate/download/download.go index ac110f97..2e692714 100644 --- a/pkg/geoipupdate/download/download.go +++ b/pkg/geoipupdate/download/download.go @@ -17,6 +17,7 @@ import ( ) const ( + // Extension is the typical extension used for database files. Extension = ".mmdb" // zeroMD5 is the default value provided as an MD5 hash for a non-existent @@ -24,6 +25,8 @@ const ( zeroMD5 = "00000000000000000000000000000000" ) +// Downloader represents common methods required to implement the functionality +// required to download/update mmdb files. 
type Downloader interface { GetOutdatedEditions(ctx context.Context) ([]Metadata, error) DownloadEdition(ctx context.Context, edition Metadata) error @@ -59,7 +62,7 @@ type Download struct { func New( accountID int, licenseKey string, - url string, + serverURL string, proxy *url.URL, databaseDir string, preserveFileTimes bool, @@ -80,7 +83,7 @@ func New( licenseKey: licenseKey, oldEditionsHash: map[string]string{}, preserveFileTimes: preserveFileTimes, - url: url, + url: serverURL, now: time.Now, verbose: verbose, } diff --git a/pkg/geoipupdate/download/metadata.go b/pkg/geoipupdate/download/metadata.go index cf24f40b..4f27ebfd 100644 --- a/pkg/geoipupdate/download/metadata.go +++ b/pkg/geoipupdate/download/metadata.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "log" "net/http" "net/url" @@ -24,7 +24,7 @@ type metadataResponse struct { Databases []Metadata `json:"databases"` } -// databaseResponse represents the metadata content for a certain database returned by the +// Metadata represents the metadata content for a certain database returned by the // metadata endpoint. 
type Metadata struct { Date string `json:"date"` @@ -57,7 +57,7 @@ func (d *Download) GetOutdatedEditions(ctx context.Context) ([]Metadata, error) } defer response.Body.Close() - responseBody, err := ioutil.ReadAll(response.Body) + responseBody, err := io.ReadAll(response.Body) if err != nil { return nil, fmt.Errorf("reading response body: %w", err) } diff --git a/pkg/geoipupdate/download/metadata_test.go b/pkg/geoipupdate/download/metadata_test.go index 82e297e9..941ec187 100644 --- a/pkg/geoipupdate/download/metadata_test.go +++ b/pkg/geoipupdate/download/metadata_test.go @@ -28,7 +28,7 @@ func TestGetOutdatedEditions(t *testing.T) { err = os.WriteFile(dbFile, []byte("edition-2 content"), os.ModePerm) require.NoError(t, err) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { jsonData := ` { "databases": [ @@ -52,7 +52,8 @@ func TestGetOutdatedEditions(t *testing.T) { ` w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write([]byte(jsonData)) + _, err := w.Write([]byte(jsonData)) + require.NoError(t, err) })) defer server.Close() @@ -87,8 +88,9 @@ func TestGetOutdatedEditions(t *testing.T) { require.NoError(t, err) require.ElementsMatch(t, expectedOutdatedEditions, outdatedEditions) - expectedDatabases := append( - expectedOutdatedEditions, + expectedDatabases := expectedOutdatedEditions + expectedDatabases = append( + expectedDatabases, Metadata{ EditionID: "edition-1", MD5: "618dd27a10de24809ec160d6807f363f", diff --git a/pkg/geoipupdate/download/output_test.go b/pkg/geoipupdate/download/output_test.go index 96f786e3..fad8a025 100644 --- a/pkg/geoipupdate/download/output_test.go +++ b/pkg/geoipupdate/download/output_test.go @@ -31,6 +31,7 @@ func TestOutputFormat(t *testing.T) { now: func() time.Time { return now }, } + //nolint:lll expectedOutput := 
`[{"edition_id":"edition-1","old_hash":"1","new_hash":"1","modified_at":1704067200,"checked_at":1708646400},{"edition_id":"edition-2","old_hash":"10","new_hash":"11","modified_at":1706745600,"checked_at":1708646400}]` output, err := d.MakeOutput() diff --git a/pkg/geoipupdate/download/writer.go b/pkg/geoipupdate/download/writer.go index d634ab8b..274749d9 100644 --- a/pkg/geoipupdate/download/writer.go +++ b/pkg/geoipupdate/download/writer.go @@ -10,7 +10,6 @@ import ( "fmt" "hash" "io" - "io/ioutil" "log" "net/http" "os" @@ -48,7 +47,7 @@ func (d *Download) DownloadEdition(ctx context.Context, edition Metadata) error defer response.Body.Close() if response.StatusCode != http.StatusOK { - responseBody, err := ioutil.ReadAll(response.Body) + responseBody, err := io.ReadAll(response.Body) if err != nil { return fmt.Errorf("reading response body: %w", err) } @@ -232,7 +231,7 @@ func syncDir(path string) error { } // setModifiedAtTime sets the times for a database file to a certain value. -func setModifiedAtTime(path string, dateString string) error { +func setModifiedAtTime(path, dateString string) error { releaseDate, err := ParseTime(dateString) if err != nil { return err diff --git a/pkg/geoipupdate/download/writer_test.go b/pkg/geoipupdate/download/writer_test.go index ffd12d76..f4f51888 100644 --- a/pkg/geoipupdate/download/writer_test.go +++ b/pkg/geoipupdate/download/writer_test.go @@ -61,7 +61,7 @@ func TestDownloadEdition(t *testing.T) { require.NoError(t, tw.Close()) require.NoError(t, gw.Close()) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/gzip") w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") _, err := io.Copy(w, &buf) @@ -75,6 +75,7 @@ func TestDownloadEdition(t *testing.T) { dbFile := filepath.Join(tempDir, edition.EditionID+Extension) + 
//nolint:gosec // we need to read the content of the file in this test. fileContent, err := os.ReadFile(dbFile) require.NoError(t, err) require.Equal(t, dbContent, string(fileContent)) @@ -105,7 +106,7 @@ func TestDownloadEdition(t *testing.T) { require.NoError(t, tw.Close()) require.NoError(t, gw.Close()) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/gzip") w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") _, err := io.Copy(w, &buf) @@ -119,6 +120,7 @@ func TestDownloadEdition(t *testing.T) { dbFile := filepath.Join(tempDir, edition.EditionID+Extension) + //nolint:gosec // we need to read the content of the file in this test. fileContent, err := os.ReadFile(dbFile) require.NoError(t, err) require.Equal(t, dbContent, string(fileContent)) @@ -134,8 +136,8 @@ func TestDownloadEdition(t *testing.T) { { description: "server error", preserveFileTime: false, - server: func(t *testing.T) *httptest.Server { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server: func(_ *testing.T) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) return server @@ -149,11 +151,12 @@ func TestDownloadEdition(t *testing.T) { description: "wrong file format", preserveFileTime: false, server: func(t *testing.T) *httptest.Server { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { jsonData := `{"message": "Hello, world!", "status": "ok"}` w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write([]byte(jsonData)) + _, err := w.Write([]byte(jsonData)) + 
require.NoError(t, err) })) return server }, @@ -172,7 +175,7 @@ func TestDownloadEdition(t *testing.T) { require.NoError(t, tw.Close()) require.NoError(t, gw.Close()) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/gzip") w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") _, err := io.Copy(w, &buf) @@ -207,7 +210,7 @@ func TestDownloadEdition(t *testing.T) { require.NoError(t, tw.Close()) require.NoError(t, gw.Close()) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/gzip") w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") _, err := io.Copy(w, &buf) @@ -242,7 +245,7 @@ func TestDownloadEdition(t *testing.T) { require.NoError(t, tw.Close()) require.NoError(t, gw.Close()) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/gzip") w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") _, err := io.Copy(w, &buf) @@ -253,6 +256,7 @@ func TestDownloadEdition(t *testing.T) { }, checkResult: func(t *testing.T, err error) { require.Error(t, err) + //nolint:lll require.Regexp(t, "^validating hash for edition-1: md5 of new database .* does not match expected md5", err.Error()) }, }, diff --git a/pkg/geoipupdate/geoip_updater.go b/pkg/geoipupdate/geoip_updater.go index fe962c86..55be212e 100644 --- a/pkg/geoipupdate/geoip_updater.go +++ b/pkg/geoipupdate/geoip_updater.go @@ -70,7 +70,7 @@ func (c *Client) Run(ctx context.Context) error { return nil } - if err := 
c.retry(ctx, getOutdatedEditions, "Couldn't get download metadata"); err != nil { + if err := c.retry(getOutdatedEditions, "Couldn't get download metadata"); err != nil { return fmt.Errorf("getting download metadata: %w", err) } @@ -84,11 +84,10 @@ func (c *Client) Run(ctx context.Context) error { jobProcessor := internal.NewJobProcessor(ctx, c.config.Parallelism) for _, edition := range outdatedEditions { edition := edition - processFunc := func(ctx context.Context) error { + processFunc := func(_ context.Context) error { err := c.retry( - ctx, func() error { return downloadEdition(edition) }, - fmt.Sprintf("Couldn't download %s", edition.EditionID), + "Couldn't download "+edition.EditionID, ) if err != nil { return err @@ -117,7 +116,6 @@ func (c *Client) Run(ctx context.Context) error { // retry implements a retry functionality for downloads for non permanent errors. func (c *Client) retry( - ctx context.Context, f func() error, logMsg string, ) error { diff --git a/pkg/geoipupdate/geoip_updater_test.go b/pkg/geoipupdate/geoip_updater_test.go index e7b20512..2c4d4565 100644 --- a/pkg/geoipupdate/geoip_updater_test.go +++ b/pkg/geoipupdate/geoip_updater_test.go @@ -41,7 +41,7 @@ func TestFullDownload(t *testing.T) { require.NoError(t, err) // mock metadata handler. - metadataHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + metadataHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { jsonData := ` { "databases": [ @@ -65,20 +65,21 @@ func TestFullDownload(t *testing.T) { ` w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write([]byte(jsonData)) + _, err := w.Write([]byte(jsonData)) + require.NoError(t, err) }) // mock download handler. downloadHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - edition := strings.Split(r.URL.Path, "/")[3] // extract the edition-id. + name := strings.Split(r.URL.Path, "/")[3] // extract the edition-id. 
var buf bytes.Buffer gw := gzip.NewWriter(&buf) tw := tar.NewWriter(gw) - content := "new " + edition + " content" + content := "new " + name + " content" header := &tar.Header{ - Name: "edition-1" + download.Extension, + Name: name + download.Extension, Size: int64(len(content)), } @@ -139,6 +140,7 @@ func TestFullDownload(t *testing.T) { // edition-1 file hasn't been modified. dbFile = filepath.Join(tempDir, "edition-1"+download.Extension) + //nolint:gosec // we need to read the content of the file in this test. fileContent, err := os.ReadFile(dbFile) require.NoError(t, err) require.Equal(t, "edition-1 content", string(fileContent)) @@ -148,6 +150,7 @@ func TestFullDownload(t *testing.T) { // edition-2 file has been updated. dbFile = filepath.Join(tempDir, "edition-2"+download.Extension) + //nolint:gosec // we need to read the content of the file in this test. fileContent, err = os.ReadFile(dbFile) require.NoError(t, err) require.Equal(t, "new edition-2 content", string(fileContent)) @@ -159,6 +162,7 @@ func TestFullDownload(t *testing.T) { // edition-3 file has been downloaded. dbFile = filepath.Join(tempDir, "edition-3"+download.Extension) + //nolint:gosec // we need to read the content of the file in this test. fileContent, err = os.ReadFile(dbFile) require.NoError(t, err) require.Equal(t, "new edition-3 content", string(fileContent)) diff --git a/pkg/geoipupdate/internal/errors.go b/pkg/geoipupdate/internal/errors.go index 0fa5cae6..e2381148 100644 --- a/pkg/geoipupdate/internal/errors.go +++ b/pkg/geoipupdate/internal/errors.go @@ -14,7 +14,12 @@ type ResponseError struct { } func (e ResponseError) Error() string { - return fmt.Sprintf("received: HTTP status code '%d' - Error code '%s' - Message '%s'", e.StatusCode, e.Code, e.Message) + return fmt.Sprintf("received: "+ + "HTTP status code '%d' - "+ + "Error code '%s' - "+ + "Message '%s'", + e.StatusCode, e.Code, e.Message, + ) } // IsPermanentError returns true if the error is non-retriable. 
From c5a5c639df8b891fd2ca47524c9d71bf76f5a4d8 Mon Sep 17 00:00:00 2001 From: Naji Obeid Date: Mon, 26 Feb 2024 17:30:27 +0000 Subject: [PATCH 5/6] check output in new e2e test --- pkg/geoipupdate/geoip_updater_test.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pkg/geoipupdate/geoip_updater_test.go b/pkg/geoipupdate/geoip_updater_test.go index 2c4d4565..28b6fb9c 100644 --- a/pkg/geoipupdate/geoip_updater_test.go +++ b/pkg/geoipupdate/geoip_updater_test.go @@ -16,7 +16,6 @@ import ( "time" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/download" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -120,23 +119,26 @@ func TestFullDownload(t *testing.T) { DatabaseDirectory: tempDir, EditionIDs: []string{"edition-1", "edition-2", "edition-3"}, LockFile: filepath.Clean(filepath.Join(tempDir, ".geoipupdate.lock")), - URL: server.URL, - RetryFor: 0, + Output: true, Parallelism: 1, PreserveFileTimes: true, + RetryFor: 0, + URL: server.URL, } logOutput := &bytes.Buffer{} - log.SetOutput(logOutput) client, err := NewClient(conf) require.NoError(t, err) + client.output = log.New(logOutput, "", 0) // download updates. err = client.Run(ctx) require.NoError(t, err) - assert.Equal(t, "", logOutput.String(), "no logged output") + //nolint:lll + expectedOutput := `[{"edition_id":"edition\-1","old_hash":"618dd27a10de24809ec160d6807f363f","new_hash":"618dd27a10de24809ec160d6807f363f","modified_at":1708646400,"checked_at":\d+},{"edition_id":"edition\-2","old_hash":"c9bbf7cb507370339633b44001bae038","new_hash":"9960e83daa34d69e9b58b375616e145b","modified_at":1708646400,"checked_at":\d+},{"edition_id":"edition\-3","old_hash":"00000000000000000000000000000000","new_hash":"08628247c1e8c1aa6d05ffc578fa09a8","modified_at":1706832000,"checked_at":\d+}]` + require.Regexp(t, expectedOutput, logOutput.String()) // edition-1 file hasn't been modified. 
dbFile = filepath.Join(tempDir, "edition-1"+download.Extension) From 7fa63731164e8db0a3e73b89e60bdc74d0d9122e Mon Sep 17 00:00:00 2001 From: Naji Obeid Date: Fri, 1 Mar 2024 15:15:55 +0000 Subject: [PATCH 6/6] reorganize packages to make it more flexible for additions or changes --- .github/workflows/go.yml | 2 +- CHANGELOG.md | 14 +- cmd/geoipupdate/main.go | 65 ++++- go.mod | 2 +- pkg/geoipupdate/api/api.go | 33 +++ pkg/geoipupdate/api/http.go | 182 ++++++++++++++ .../writer_test.go => api/http_test.go} | 226 ++++++++---------- pkg/geoipupdate/{ => config}/config.go | 4 +- pkg/geoipupdate/{ => config}/config_test.go | 2 +- pkg/geoipupdate/download/download.go | 152 ------------ pkg/geoipupdate/download/download_test.go | 36 --- pkg/geoipupdate/download/metadata.go | 98 -------- pkg/geoipupdate/download/metadata_test.go | 101 -------- pkg/geoipupdate/download/output_test.go | 40 ---- pkg/geoipupdate/geoip_updater.go | 111 +++++---- pkg/geoipupdate/geoip_updater_test.go | 36 ++- .../{internal => lock}/file_lock.go | 33 +-- .../{internal => lock}/file_lock_test.go | 4 +- pkg/geoipupdate/lock/lock.go | 9 + pkg/geoipupdate/{download => }/output.go | 20 +- pkg/geoipupdate/output_test.go | 35 +++ .../{download/writer.go => writer/disk.go} | 142 +++++------ pkg/geoipupdate/writer/disk_test.go | 122 ++++++++++ pkg/geoipupdate/writer/writer.go | 21 ++ 24 files changed, 758 insertions(+), 732 deletions(-) create mode 100644 pkg/geoipupdate/api/api.go create mode 100644 pkg/geoipupdate/api/http.go rename pkg/geoipupdate/{download/writer_test.go => api/http_test.go} (57%) rename pkg/geoipupdate/{ => config}/config.go (98%) rename pkg/geoipupdate/{ => config}/config_test.go (99%) delete mode 100644 pkg/geoipupdate/download/download.go delete mode 100644 pkg/geoipupdate/download/download_test.go delete mode 100644 pkg/geoipupdate/download/metadata.go delete mode 100644 pkg/geoipupdate/download/metadata_test.go delete mode 100644 pkg/geoipupdate/download/output_test.go rename 
pkg/geoipupdate/{internal => lock}/file_lock.go (60%) rename pkg/geoipupdate/{internal => lock}/file_lock_test.go (96%) create mode 100644 pkg/geoipupdate/lock/lock.go rename pkg/geoipupdate/{download => }/output.go (55%) create mode 100644 pkg/geoipupdate/output_test.go rename pkg/geoipupdate/{download/writer.go => writer/disk.go} (59%) create mode 100644 pkg/geoipupdate/writer/disk_test.go create mode 100644 pkg/geoipupdate/writer/writer.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 249058f1..8d728099 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,7 +11,7 @@ jobs: build: strategy: matrix: - go-version: [1.20.x, 1.21.x] + go-version: [1.21.x] platform: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.platform }} name: "Build ${{ matrix.go-version }} test on ${{ matrix.platform }}" diff --git a/CHANGELOG.md b/CHANGELOG.md index d7b2025d..24361623 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,11 +7,15 @@ * Now `geoipupdate` doesn't requires the user to specify the config file even if all the other arguments are set via the environment variables. Reported by jsf84ksnf. GitHub #284. -* `pkg/goipupdate/database` has been rewritten into `pkg/geoipupdate/download`. - This new package changes the update behaviour of the tool: - 1. It will first request edition information from a newly introduced metadata endpoint. - 2. Compare the result with existing editions and decide which edition needs to be updated. - 3. Then download individual editions from a newly introduced download endpoint. +* `pkg/geoipupdate/database` has been divided and rewritten into multiple packages, + which change the behaviour of the tool: + + * `pkg/geoipupdate/api` which is responsible for handling api calls to maxmind + servers and includes two new api calls to the `metadata` and `download` endpoints. + * `pkg/geoipupdate/writer` which is responsible for writing databases to various targets.
+ It only supports writing to disk out of the box. + * `pkg/geoipupdate/lock` which is responsible for synchronizing concurrent access to the tool. + It only supports file locks out of the box. ## 6.1.0 (2024-01-09) diff --git a/cmd/geoipupdate/main.go b/cmd/geoipupdate/main.go index d25fb71e..046242fe 100644 --- a/cmd/geoipupdate/main.go +++ b/cmd/geoipupdate/main.go @@ -4,9 +4,16 @@ package main import ( "context" "log" + "log/slog" + "net/http" + "os" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/config" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/lock" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars" + geoipupdatewriter "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/writer" ) const unknownVersion = "unknown" @@ -32,29 +39,61 @@ func main() { args := getArgs() - config, err := geoipupdate.NewConfig( - geoipupdate.WithConfigFile(args.ConfigFile), - geoipupdate.WithDatabaseDirectory(args.DatabaseDirectory), - geoipupdate.WithParallelism(args.Parallelism), - geoipupdate.WithVerbose(args.Verbose), - geoipupdate.WithOutput(args.Output), + conf, err := config.NewConfig( + config.WithConfigFile(args.ConfigFile), + config.WithDatabaseDirectory(args.DatabaseDirectory), + config.WithParallelism(args.Parallelism), + config.WithVerbose(args.Verbose), + config.WithOutput(args.Output), ) if err != nil { log.Fatalf("Error loading configuration: %s", err) } - if config.Verbose { - log.Printf("geoipupdate version %s", version) - log.Printf("Using config file %s", args.ConfigFile) - log.Printf("Using database directory %s", config.DatabaseDirectory) + options := &slog.HandlerOptions{Level: slog.LevelInfo} + if conf.Verbose { + options = &slog.HandlerOptions{Level: slog.LevelDebug} } + handler := slog.NewTextHandler(os.Stderr, options) + slog.SetDefault(slog.New(handler)) + + slog.Debug("geoipupdate", 
"version", version) + slog.Debug("config file", "path", args.ConfigFile) + slog.Debug("database directory", "path", conf.DatabaseDirectory) + + transport := http.DefaultTransport + if conf.Proxy != nil { + proxyFunc := http.ProxyURL(conf.Proxy) + transport.(*http.Transport).Proxy = proxyFunc + } + + downloader := api.NewHTTPDownloader( + conf.AccountID, + conf.LicenseKey, + &http.Client{Transport: transport}, + conf.URL, + ) + + writer := geoipupdatewriter.NewDiskWriter( + conf.DatabaseDirectory, + conf.PreserveFileTimes, + ) + + slog.Debug("initializing file lock", "path", conf.LockFile) + locker, err := lock.NewFileLock(conf.LockFile) if err != nil { - log.Fatalf("Error initializing download client: %s", err) + slog.Error("setting up file lock", "error", err) + os.Exit(1) } + client := geoipupdate.NewClient( + conf, + downloader, + locker, + writer, + ) if err = client.Run(context.Background()); err != nil { - log.Fatalf("Error retrieving updates: %s", err) + slog.Error("retrieving updates", "error", err) + os.Exit(1) } } diff --git a/go.mod b/go.mod index 9b521727..96a2d2c2 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/maxmind/geoipupdate/v6 -go 1.20 +go 1.21 require ( github.com/cenkalti/backoff/v4 v4.2.1 diff --git a/pkg/geoipupdate/api/api.go b/pkg/geoipupdate/api/api.go new file mode 100644 index 00000000..e3f7f781 --- /dev/null +++ b/pkg/geoipupdate/api/api.go @@ -0,0 +1,33 @@ +// Package api provides a wrapper around the maxmind api responsible for +// checking/downloading mmdb files. +package api + +import ( + "context" + "fmt" + "io" + "time" +) + +// DownloadAPI presents common methods required to download mmdb files. +type DownloadAPI interface { + GetMetadata(ctx context.Context, editions []string) ([]Metadata, error) + GetEdition(ctx context.Context, edition Metadata) (io.Reader, func(), error) +} + +// Metadata represents the metadata content for a certain database returned by the +// metadata endpoint. 
+type Metadata struct { + Date string `json:"date"` + EditionID string `json:"edition_id"` + MD5 string `json:"md5"` +} + +// ParseTime parses the date returned in a metadata response to time.Time. +func ParseTime(dateString string) (time.Time, error) { + t, err := time.ParseInLocation("2006-01-02", dateString, time.Local) + if err != nil { + return time.Time{}, fmt.Errorf("parsing edition date: %w", err) + } + return t, nil +} diff --git a/pkg/geoipupdate/api/http.go b/pkg/geoipupdate/api/http.go new file mode 100644 index 00000000..9cb9bc41 --- /dev/null +++ b/pkg/geoipupdate/api/http.go @@ -0,0 +1,182 @@ +package api + +import ( + "archive/tar" + "compress/gzip" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars" +) + +const ( + metadataEndpoint = "%s/geoip/updates/metadata?%s" + downloadEndpoint = "%s/geoip/databases/%s/download?date=%s&suffix=tar.gz" +) + +// httpDownloader is an http implementation of the DownloadAPI interface. +type httpDownloader struct { + // accountID is the requester's account ID. + accountID int + // client is an http client responsible for fetching database updates. + client *http.Client + // host points to maxmind servers. + host string + // licenseKey is the requester's license key. + licenseKey string +} + +// NewHTTPDownloader initializes a new httpDownloader struct. + +//nolint:revive // unexported type httpDownloader is not meant to be used as a standalone type. +func NewHTTPDownloader( + accountID int, + licenseKey string, + client *http.Client, + host string, +) *httpDownloader { + return &httpDownloader{ + accountID: accountID, + client: client, + host: host, + licenseKey: licenseKey, + } +} + +// GetMetadata makes an http request to retrieve metadata about the provided database editions. 
+func (h *httpDownloader) GetMetadata(ctx context.Context, editions []string) ([]Metadata, error) { + var editionsQuery []string + for _, e := range editions { + editionsQuery = append(editionsQuery, "edition_id="+url.QueryEscape(e)) + } + + requestURL := fmt.Sprintf(metadataEndpoint, h.host, strings.Join(editionsQuery, "&")) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) + req.SetBasicAuth(strconv.Itoa(h.accountID), h.licenseKey) + + response, err := h.client.Do(req) + if err != nil { + return nil, fmt.Errorf("performing HTTP request: %w", err) + } + defer response.Body.Close() + + responseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("reading response body: %w", err) + } + + if response.StatusCode != http.StatusOK { + errResponse := internal.ResponseError{ + StatusCode: response.StatusCode, + } + + if err := json.Unmarshal(responseBody, &errResponse); err != nil { + errResponse.Message = err.Error() + } + + return nil, fmt.Errorf("requesting metadata: %w", errResponse) + } + + var metadataResponse struct { + Databases []Metadata `json:"databases"` + } + + if err := json.Unmarshal(responseBody, &metadataResponse); err != nil { + return nil, fmt.Errorf("parsing body: %w", err) + } + + return metadataResponse.Databases, nil +} + +// GetEdition makes an http request to download the requested database edition. +// It returns an io.Reader that points to the content of the database file. 
+func (h *httpDownloader) GetEdition( + ctx context.Context, + edition Metadata, +) (reader io.Reader, cleanupCallback func(), err error) { + date := strings.ReplaceAll(edition.Date, "-", "") + requestURL := fmt.Sprintf(downloadEndpoint, h.host, edition.EditionID, date) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + return nil, nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) + req.SetBasicAuth(strconv.Itoa(h.accountID), h.licenseKey) + + response, err := h.client.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("performing HTTP request: %w", err) + } + // It is safe to close the response body reader as it wouldn't be + // consumed in case this function returns an error. + defer func() { + if err != nil { + response.Body.Close() + } + }() + + if response.StatusCode != http.StatusOK { + responseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, nil, fmt.Errorf("reading error response body: %w", err) + } + + errResponse := internal.ResponseError{ + StatusCode: response.StatusCode, + } + + if err := json.Unmarshal(responseBody, &errResponse); err != nil { + errResponse.Message = err.Error() + } + + return nil, nil, fmt.Errorf("requesting edition: %w", errResponse) + } + + gzReader, err := gzip.NewReader(response.Body) + if err != nil { + return nil, nil, fmt.Errorf("encountered an error creating GZIP reader: %w", err) + } + defer func() { + if err != nil { + gzReader.Close() + } + }() + + tarReader := tar.NewReader(gzReader) + + // iterate through the tar archive to extract the mmdb file + for { + header, err := tarReader.Next() + if err == io.EOF { + return nil, nil, errors.New("tar archive does not contain an mmdb file") + } + if err != nil { + return nil, nil, fmt.Errorf("reading tar archive: %w", err) + } + + if strings.HasSuffix(header.Name, ".mmdb") { + break + } + } + + cleanupCallback = func() { + gzReader.Close() + 
response.Body.Close() + } + + return tarReader, cleanupCallback, nil +} diff --git a/pkg/geoipupdate/download/writer_test.go b/pkg/geoipupdate/api/http_test.go similarity index 57% rename from pkg/geoipupdate/download/writer_test.go rename to pkg/geoipupdate/api/http_test.go index f4f51888..8765f5ae 100644 --- a/pkg/geoipupdate/download/writer_test.go +++ b/pkg/geoipupdate/api/http_test.go @@ -1,4 +1,4 @@ -package download +package api import ( "archive/tar" @@ -8,24 +8,90 @@ import ( "io" "net/http" "net/http/httptest" - "os" - "path/filepath" "testing" - "time" "github.com/stretchr/testify/require" ) -// TestDownloadEdition checks the database download functionality. -func TestDownloadEdition(t *testing.T) { - tempDir, err := os.MkdirTemp("", "db") - require.NoError(t, err) - defer os.RemoveAll(tempDir) +// TestGetMetadata checks the metadata fetching functionality. +func TestGetMetadata(t *testing.T) { + tests := []struct { + description string + preserveFileTime bool + server func(t *testing.T) *httptest.Server + checkResult func(t *testing.T, metadata []Metadata, err error) + }{ + { + description: "successful request", + preserveFileTime: false, + server: func(t *testing.T) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + jsonData := ` +{ + "databases": [ + { "edition_id": "edition-1", "md5": "123456", "date": "2024-02-23" }, + { "edition_id": "edition-2", "md5": "abc123", "date": "2024-02-23" }, + { "edition_id": "edition-3", "md5": "def456", "date": "2024-02-02" } + ] +} +` + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(jsonData)) + require.NoError(t, err) + })) + return server + }, + checkResult: func(t *testing.T, metadata []Metadata, err error) { + require.NoError(t, err) + + expectedMetadata := []Metadata{ + {EditionID: "edition-1", MD5: "123456", Date: "2024-02-23"}, + {EditionID: "edition-2", MD5: "abc123", Date: 
"2024-02-23"}, + {EditionID: "edition-3", MD5: "def456", Date: "2024-02-02"}, + } + require.ElementsMatch(t, expectedMetadata, metadata) + }, + }, + { + description: "server error", + preserveFileTime: false, + server: func(_ *testing.T) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + return server + }, + checkResult: func(t *testing.T, metadata []Metadata, err error) { + require.Empty(t, metadata) + require.Error(t, err) + require.Regexp(t, "^requesting metadata: received: HTTP status code '500'", err.Error()) + }, + }, + } ctx := context.Background() - now, err := ParseTime("2024-02-23") - require.NoError(t, err) + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + server := test.server(t) + defer server.Close() + + h := NewHTTPDownloader( + 0, // accountID is not relevant for this test. + "", // licenseKey is not relevant for this test. + http.DefaultClient, + server.URL, + ) + + metadata, err := h.GetMetadata(ctx, []string{"edition-1", "edition-2", "edition-3"}) + test.checkResult(t, metadata, err) + }) + } +} + +// TestGetEdition checks the database download functionality. 
+func TestGetEdition(t *testing.T) { edition := Metadata{ EditionID: "edition-1", Date: "2024-02-02", @@ -38,7 +104,7 @@ func TestDownloadEdition(t *testing.T) { description string preserveFileTime bool server func(t *testing.T) *httptest.Server - checkResult func(t *testing.T, err error) + checkResult func(t *testing.T, reader io.Reader, err error) }{ { description: "successful download", @@ -49,11 +115,11 @@ func TestDownloadEdition(t *testing.T) { tw := tar.NewWriter(gw) header := &tar.Header{ - Name: "edition-1" + Extension, + Name: "edition-1.mmdb", Size: int64(len(dbContent)), } - err = tw.WriteHeader(header) + err := tw.WriteHeader(header) require.NoError(t, err) _, err = tw.Write([]byte(dbContent)) require.NoError(t, err) @@ -70,67 +136,11 @@ func TestDownloadEdition(t *testing.T) { return server }, - checkResult: func(t *testing.T, err error) { - require.NoError(t, err) - - dbFile := filepath.Join(tempDir, edition.EditionID+Extension) - - //nolint:gosec // we need to read the content of the file in this test. 
- fileContent, err := os.ReadFile(dbFile) + checkResult: func(t *testing.T, reader io.Reader, err error) { require.NoError(t, err) - require.Equal(t, dbContent, string(fileContent)) - - database, err := os.Stat(dbFile) - require.NoError(t, err) - require.GreaterOrEqual(t, database.ModTime(), now) - }, - }, - { - description: "successful download - preserve time", - preserveFileTime: true, - server: func(t *testing.T) *httptest.Server { - var buf bytes.Buffer - gw := gzip.NewWriter(&buf) - tw := tar.NewWriter(gw) - - header := &tar.Header{ - Name: "edition-1" + Extension, - Size: int64(len(dbContent)), - } - - err = tw.WriteHeader(header) - require.NoError(t, err) - _, err = tw.Write([]byte(dbContent)) - require.NoError(t, err) - - require.NoError(t, tw.Close()) - require.NoError(t, gw.Close()) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/gzip") - w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") - _, err := io.Copy(w, &buf) - require.NoError(t, err) - })) - - return server - }, - checkResult: func(t *testing.T, err error) { - require.NoError(t, err) - - dbFile := filepath.Join(tempDir, edition.EditionID+Extension) - - //nolint:gosec // we need to read the content of the file in this test. 
- fileContent, err := os.ReadFile(dbFile) - require.NoError(t, err) - require.Equal(t, dbContent, string(fileContent)) - - modTime, err := ParseTime(edition.Date) - require.NoError(t, err) - - database, err := os.Stat(dbFile) - require.NoError(t, err) - require.Equal(t, modTime, database.ModTime()) + c, rerr := io.ReadAll(reader) + require.NoError(t, rerr) + require.Equal(t, dbContent, string(c)) }, }, { @@ -142,7 +152,8 @@ func TestDownloadEdition(t *testing.T) { })) return server }, - checkResult: func(t *testing.T, err error) { + checkResult: func(t *testing.T, reader io.Reader, err error) { + require.Nil(t, reader) require.Error(t, err) require.Regexp(t, "^requesting edition: received: HTTP status code '500'", err.Error()) }, @@ -160,7 +171,8 @@ func TestDownloadEdition(t *testing.T) { })) return server }, - checkResult: func(t *testing.T, err error) { + checkResult: func(t *testing.T, reader io.Reader, err error) { + require.Nil(t, reader) require.Error(t, err) require.Regexp(t, "^encountered an error creating GZIP reader: gzip: invalid header", err.Error()) }, @@ -184,7 +196,8 @@ func TestDownloadEdition(t *testing.T) { return server }, - checkResult: func(t *testing.T, err error) { + checkResult: func(t *testing.T, reader io.Reader, err error) { + require.Nil(t, reader) require.Error(t, err) require.Regexp(t, "^tar archive does not contain an mmdb file", err.Error()) }, @@ -202,7 +215,7 @@ func TestDownloadEdition(t *testing.T) { Size: int64(len(dbContent)), } - err = tw.WriteHeader(header) + err := tw.WriteHeader(header) require.NoError(t, err) _, err = tw.Write([]byte(dbContent)) require.NoError(t, err) @@ -219,64 +232,29 @@ func TestDownloadEdition(t *testing.T) { return server }, - checkResult: func(t *testing.T, err error) { + checkResult: func(t *testing.T, reader io.Reader, err error) { + require.Nil(t, reader) require.Error(t, err) require.Regexp(t, "^tar archive does not contain an mmdb file", err.Error()) }, }, - { - description: "mmdb hash does 
not match metadata", - preserveFileTime: false, - server: func(t *testing.T) *httptest.Server { - var buf bytes.Buffer - gw := gzip.NewWriter(&buf) - tw := tar.NewWriter(gw) - - header := &tar.Header{ - Name: "edition-1" + Extension, - Size: int64(len(dbContent) - 1), - } - - err = tw.WriteHeader(header) - require.NoError(t, err) - _, err = tw.Write([]byte(dbContent[:len(dbContent)-1])) - require.NoError(t, err) - - require.NoError(t, tw.Close()) - require.NoError(t, gw.Close()) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/gzip") - w.Header().Set("Content-Disposition", "attachment; filename=test.tar.gz") - _, err := io.Copy(w, &buf) - require.NoError(t, err) - })) - - return server - }, - checkResult: func(t *testing.T, err error) { - require.Error(t, err) - //nolint:lll - require.Regexp(t, "^validating hash for edition-1: md5 of new database .* does not match expected md5", err.Error()) - }, - }, } + ctx := context.Background() for _, test := range tests { t.Run(test.description, func(t *testing.T) { server := test.server(t) defer server.Close() - d := Download{ - client: http.DefaultClient, - databaseDir: tempDir, - now: func() time.Time { return now }, - preserveFileTimes: test.preserveFileTime, - url: server.URL, - } + h := NewHTTPDownloader( + 0, // accountID is not relevant for this test. + "", // licenseKey is not relevant for this test. 
+ http.DefaultClient, + server.URL, + ) - err = d.DownloadEdition(ctx, edition) - test.checkResult(t, err) + reader, _, err := h.GetEdition(ctx, edition) + test.checkResult(t, reader, err) }) } } diff --git a/pkg/geoipupdate/config.go b/pkg/geoipupdate/config/config.go similarity index 98% rename from pkg/geoipupdate/config.go rename to pkg/geoipupdate/config/config.go index 5655ab4b..a585b185 100644 --- a/pkg/geoipupdate/config.go +++ b/pkg/geoipupdate/config/config.go @@ -1,4 +1,6 @@ -package geoipupdate +// Package config deals reading and setting the client configuration from multiple +// source: config file, environment variables and cli flags. +package config import ( "bufio" diff --git a/pkg/geoipupdate/config_test.go b/pkg/geoipupdate/config/config_test.go similarity index 99% rename from pkg/geoipupdate/config_test.go rename to pkg/geoipupdate/config/config_test.go index 54817572..8c427e34 100644 --- a/pkg/geoipupdate/config_test.go +++ b/pkg/geoipupdate/config/config_test.go @@ -1,4 +1,4 @@ -package geoipupdate +package config import ( "fmt" diff --git a/pkg/geoipupdate/download/download.go b/pkg/geoipupdate/download/download.go deleted file mode 100644 index 2e692714..00000000 --- a/pkg/geoipupdate/download/download.go +++ /dev/null @@ -1,152 +0,0 @@ -// Package download provides a library for checking/downloading/updating mmdb files. -package download - -import ( - "context" - "crypto/md5" - "encoding/hex" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/url" - "os" - "path/filepath" - "time" -) - -const ( - // Extension is the typical extension used for database files. - Extension = ".mmdb" - - // zeroMD5 is the default value provided as an MD5 hash for a non-existent - // database. - zeroMD5 = "00000000000000000000000000000000" -) - -// Downloader represents common methods required to implement the functionality -// required to download/update mmdb files. 
-type Downloader interface { - GetOutdatedEditions(ctx context.Context) ([]Metadata, error) - DownloadEdition(ctx context.Context, edition Metadata) error - MakeOutput() ([]byte, error) -} - -// Download exposes methods needed to check for and perform update to a set of mmdb editions. -type Download struct { - // accountID is the requester's account ID. - accountID int - // client is an http client responsible of fetching database updates. - client *http.Client - // databaseDir is the database download path. - databaseDir string - // editionIDs is the list of editions to be updated. - editionIDs []string - // licenseKey is the requester's license key. - licenseKey string - // oldEditionsHash holds the hashes of the previously downloaded mmdb editions. - oldEditionsHash map[string]string - // metadata holds the metadata pulled for each edition. - metadata []Metadata - // preserveFileTimes sets whether database modification times are preserved across downloads. - preserveFileTimes bool - // url points to maxmind servers. - url string - - now func() time.Time - verbose bool -} - -// New initializes a new Downloader struct. 
-func New( - accountID int, - licenseKey string, - serverURL string, - proxy *url.URL, - databaseDir string, - preserveFileTimes bool, - editionIDs []string, - verbose bool, -) (*Download, error) { - transport := http.DefaultTransport - if proxy != nil { - proxyFunc := http.ProxyURL(proxy) - transport.(*http.Transport).Proxy = proxyFunc - } - - d := Download{ - accountID: accountID, - client: &http.Client{Transport: transport}, - databaseDir: databaseDir, - editionIDs: editionIDs, - licenseKey: licenseKey, - oldEditionsHash: map[string]string{}, - preserveFileTimes: preserveFileTimes, - url: serverURL, - now: time.Now, - verbose: verbose, - } - - for _, e := range editionIDs { - hash, err := d.getHash(e) - if err != nil { - return nil, fmt.Errorf("getting existing %q database hash: %w", e, err) - } - d.oldEditionsHash[e] = hash - } - - return &d, nil -} - -// getHash returns the hash of a certain database file. -func (d *Download) getHash(editionID string) (string, error) { - databaseFilePath := d.getFilePath(editionID) - //nolint:gosec // we really need to read this file. - database, err := os.Open(databaseFilePath) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - if d.verbose { - log.Print("Database does not exist, returning zeroed hash") - } - return zeroMD5, nil - } - return "", fmt.Errorf("opening database: %w", err) - } - - defer func() { - if err := database.Close(); err != nil { - log.Println(fmt.Errorf("closing database: %w", err)) - } - }() - - md5Hash := md5.New() - if _, err := io.Copy(md5Hash, database); err != nil { - return "", fmt.Errorf("calculating database hash: %w", err) - } - - result := byteToString(md5Hash.Sum(nil)) - if d.verbose { - log.Printf("Calculated MD5 sum for %s: %s", databaseFilePath, result) - } - return result, nil -} - -// getFilePath construct the file path for a database edition. 
-func (d *Download) getFilePath(editionID string) string { - return filepath.Join(d.databaseDir, editionID) + Extension -} - -// byteToString returns the base16 representation of a byte array. -func byteToString(b []byte) string { - return hex.EncodeToString(b) -} - -// ParseTime parses the date returned in a metadata response to time.Time. -func ParseTime(dateString string) (time.Time, error) { - t, err := time.ParseInLocation("2006-01-02", dateString, time.Local) - if err != nil { - return time.Time{}, fmt.Errorf("parsing edition date: %w", err) - } - return t, nil -} diff --git a/pkg/geoipupdate/download/download_test.go b/pkg/geoipupdate/download/download_test.go deleted file mode 100644 index 4b4f5642..00000000 --- a/pkg/geoipupdate/download/download_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package download - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" -) - -// TestGetHash tests functionality of the LocalFileWriter.GetHash method. -func TestGetHash(t *testing.T) { - tempDir, err := os.MkdirTemp("", "db") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - d := Download{ - databaseDir: tempDir, - } - - // returns a zero hash for a non existing edition. - hash, err := d.getHash("NewEdition") - require.NoError(t, err) - require.Equal(t, zeroMD5, hash) - - // returns the correct md5 for an existing edition. 
- edition := "edition-1" - dbFile := filepath.Join(tempDir, edition+Extension) - - err = os.WriteFile(dbFile, []byte("edition-1 content"), os.ModePerm) - require.NoError(t, err) - - hash, err = d.getHash(edition) - require.NoError(t, err) - require.Equal(t, "618dd27a10de24809ec160d6807f363f", hash) -} diff --git a/pkg/geoipupdate/download/metadata.go b/pkg/geoipupdate/download/metadata.go deleted file mode 100644 index 4f27ebfd..00000000 --- a/pkg/geoipupdate/download/metadata.go +++ /dev/null @@ -1,98 +0,0 @@ -package download - -import ( - "context" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "net/url" - "strconv" - "strings" - - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars" -) - -const ( - metadataEndpoint = "%s/geoip/updates/metadata?%s" -) - -// metadataResponse represents a successful response returned by the metadata endpoint. -type metadataResponse struct { - Databases []Metadata `json:"databases"` -} - -// Metadata represents the metadata content for a certain database returned by the -// metadata endpoint. -type Metadata struct { - Date string `json:"date"` - EditionID string `json:"edition_id"` - MD5 string `json:"md5"` -} - -// GetOutdatedEditions returns the list of outdated database editions. 
-func (d *Download) GetOutdatedEditions(ctx context.Context) ([]Metadata, error) { - var editionsQuery []string - for _, e := range d.editionIDs { - editionsQuery = append(editionsQuery, "edition_id="+url.QueryEscape(e)) - } - - requestURL := fmt.Sprintf(metadataEndpoint, d.url, strings.Join(editionsQuery, "&")) - if d.verbose { - log.Printf("Requesting edition metadata: %s", requestURL) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - req.Header.Add("User-Agent", "geoipupdate/"+vars.Version) - req.SetBasicAuth(strconv.Itoa(d.accountID), d.licenseKey) - - response, err := d.client.Do(req) - if err != nil { - return nil, fmt.Errorf("performing HTTP request: %w", err) - } - defer response.Body.Close() - - responseBody, err := io.ReadAll(response.Body) - if err != nil { - return nil, fmt.Errorf("reading response body: %w", err) - } - - if response.StatusCode != http.StatusOK { - errResponse := internal.ResponseError{ - StatusCode: response.StatusCode, - } - - if err := json.Unmarshal(responseBody, &errResponse); err != nil { - errResponse.Message = err.Error() - } - - return nil, fmt.Errorf("requesting metadata: %w", errResponse) - } - - var metadata metadataResponse - if err := json.Unmarshal(responseBody, &metadata); err != nil { - return nil, fmt.Errorf("parsing body: %w", err) - } - - var outdatedEditions []Metadata - for _, m := range metadata.Databases { - oldMD5 := d.oldEditionsHash[m.EditionID] - if oldMD5 != m.MD5 { - outdatedEditions = append(outdatedEditions, m) - continue - } - - if d.verbose { - log.Printf("Database %s up to date", m.EditionID) - } - } - - d.metadata = metadata.Databases - - return outdatedEditions, nil -} diff --git a/pkg/geoipupdate/download/metadata_test.go b/pkg/geoipupdate/download/metadata_test.go deleted file mode 100644 index 941ec187..00000000 --- a/pkg/geoipupdate/download/metadata_test.go +++ /dev/null @@ -1,101 
+0,0 @@ -package download - -import ( - "context" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" -) - -// TestGetOutdatedEditions checks the metadata fetching functionality. -func TestGetOutdatedEditions(t *testing.T) { - tempDir, err := os.MkdirTemp("", "db") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - edition := "edition-1" - dbFile := filepath.Join(tempDir, edition+Extension) - // equivalent MD5: 618dd27a10de24809ec160d6807f363f - err = os.WriteFile(dbFile, []byte("edition-1 content"), os.ModePerm) - require.NoError(t, err) - - edition = "edition-2" - dbFile = filepath.Join(tempDir, edition+Extension) - err = os.WriteFile(dbFile, []byte("edition-2 content"), os.ModePerm) - require.NoError(t, err) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - jsonData := ` -{ - "databases": [ - { - "edition_id": "edition-1", - "md5": "618dd27a10de24809ec160d6807f363f", - "date": "2024-02-23" - }, - { - "edition_id": "edition-2", - "md5": "abc123", - "date": "2024-02-23" - }, - { - "edition_id": "edition-3", - "md5": "def456", - "date": "2024-02-02" - } - ] -} -` - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(jsonData)) - require.NoError(t, err) - })) - defer server.Close() - - ctx := context.Background() - d, err := New( - 0, // accountID is not relevant for this test. - "", // licenseKey is not relevant for this test. - server.URL, - nil, // proxy is not relevant for this test. - tempDir, - false, // preserveFileTimes is not relevant for this test. - []string{"edition-1", "edition-2", "edition-3"}, - false, // verbose is not relevant for this test. 
- ) - require.NoError(t, err) - - // edition-1 md5 hasn't changed - expectedOutdatedEditions := []Metadata{ - { - EditionID: "edition-2", - MD5: "abc123", - Date: "2024-02-23", - }, - { - EditionID: "edition-3", - MD5: "def456", - Date: "2024-02-02", - }, - } - - outdatedEditions, err := d.GetOutdatedEditions(ctx) - require.NoError(t, err) - require.ElementsMatch(t, expectedOutdatedEditions, outdatedEditions) - - expectedDatabases := expectedOutdatedEditions - expectedDatabases = append( - expectedDatabases, - Metadata{ - EditionID: "edition-1", - MD5: "618dd27a10de24809ec160d6807f363f", - Date: "2024-02-23", - }, - ) - require.ElementsMatch(t, expectedDatabases, d.metadata) -} diff --git a/pkg/geoipupdate/download/output_test.go b/pkg/geoipupdate/download/output_test.go deleted file mode 100644 index fad8a025..00000000 --- a/pkg/geoipupdate/download/output_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package download - -import ( - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestOutputFormat(t *testing.T) { - now, err := ParseTime("2024-02-23") - require.NoError(t, err) - - d := Download{ - oldEditionsHash: map[string]string{ - "edition-1": "1", - "edition-2": "10", - }, - metadata: []Metadata{ - { - EditionID: "edition-1", - MD5: "1", - Date: "2024-01-01", - }, - { - EditionID: "edition-2", - MD5: "11", - Date: "2024-02-01", - }, - }, - now: func() time.Time { return now }, - } - - //nolint:lll - expectedOutput := `[{"edition_id":"edition-1","old_hash":"1","new_hash":"1","modified_at":1704067200,"checked_at":1708646400},{"edition_id":"edition-2","old_hash":"10","new_hash":"11","modified_at":1706745600,"checked_at":1708646400}]` - - output, err := d.MakeOutput() - require.NoError(t, err) - require.Equal(t, expectedOutput, string(output)) -} diff --git a/pkg/geoipupdate/geoip_updater.go b/pkg/geoipupdate/geoip_updater.go index 55be212e..c8055654 100644 --- a/pkg/geoipupdate/geoip_updater.go +++ b/pkg/geoipupdate/geoip_updater.go @@ -6,82 
+6,115 @@ import ( "context" "fmt" "log" + "log/slog" "os" "time" "github.com/cenkalti/backoff/v4" - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/download" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/config" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/lock" + geoipupdatewriter "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/writer" ) // Client uses config data to initiate a download or update // process for GeoIP databases. type Client struct { - config *Config - downloader download.Downloader + editionIDs []string + parallelism int + retryFor time.Duration + printOutput bool + + downloader api.DownloadAPI + locker lock.Lock output *log.Logger + writer geoipupdatewriter.Writer } // NewClient initialized a new Client struct. -func NewClient(config *Config) (*Client, error) { - d, err := download.New( - config.AccountID, - config.LicenseKey, - config.URL, - config.Proxy, - config.DatabaseDirectory, - config.PreserveFileTimes, - config.EditionIDs, - config.Verbose, - ) - if err != nil { - return nil, fmt.Errorf("initializing database downloader: %w", err) - } - +func NewClient( + conf *config.Config, + downloader api.DownloadAPI, + locker lock.Lock, + writer geoipupdatewriter.Writer, +) *Client { return &Client{ - config: config, - downloader: d, + editionIDs: conf.EditionIDs, + parallelism: conf.Parallelism, + printOutput: conf.Output, + retryFor: conf.RetryFor, + + downloader: downloader, + locker: locker, output: log.New(os.Stdout, "", 0), - }, nil + writer: writer, + } } // Run starts the download or update process. 
func (c *Client) Run(ctx context.Context) error { - fileLock, err := internal.NewFileLock(c.config.LockFile, c.config.Verbose) - if err != nil { - return fmt.Errorf("initializing file lock: %w", err) - } - if err := fileLock.Acquire(); err != nil { + if err := c.locker.Acquire(); err != nil { return fmt.Errorf("acquiring file lock: %w", err) } + slog.Debug("file lock acquired") defer func() { - if err := fileLock.Release(); err != nil { + if err := c.locker.Release(); err != nil { log.Printf("releasing file lock: %s", err) } + slog.Debug("file lock successfully released") }() - var outdatedEditions []download.Metadata - getOutdatedEditions := func() (err error) { - outdatedEditions, err = c.downloader.GetOutdatedEditions(ctx) + oldEditionsHash := map[string]string{} + for _, e := range c.editionIDs { + hash, err := c.writer.GetHash(e) + if err != nil { + return fmt.Errorf("getting existing %q database hash: %w", e, err) + } + oldEditionsHash[e] = hash + slog.Debug("existing database md5", "edition", e, "md5", hash) + } + + var allEditions []api.Metadata + getMetadata := func() (err error) { + slog.Debug("requesting metadata") + allEditions, err = c.downloader.GetMetadata(ctx, c.editionIDs) if err != nil { return fmt.Errorf("getting outdated database editions: %w", err) } return nil } - if err := c.retry(getOutdatedEditions, "Couldn't get download metadata"); err != nil { + if err := c.retry(getMetadata, "Couldn't get download metadata"); err != nil { return fmt.Errorf("getting download metadata: %w", err) } - downloadEdition := func(edition download.Metadata) error { - if err := c.downloader.DownloadEdition(ctx, edition); err != nil { + var outdatedEditions []api.Metadata + for _, m := range allEditions { + if m.MD5 != oldEditionsHash[m.EditionID] { + outdatedEditions = append(outdatedEditions, m) + continue + } + slog.Debug("database up to date", "edition", m.EditionID) + } + + downloadEdition := func(edition api.Metadata) error { + slog.Debug("downloading", 
"edition", edition.EditionID) + reader, cleanupCallback, err := c.downloader.GetEdition(ctx, edition) + if err != nil { return fmt.Errorf("downloading edition '%s': %w", edition.EditionID, err) } + slog.Debug("writing", "edition", edition.EditionID) + if err := c.writer.Write(edition, reader); err != nil { + return fmt.Errorf("writing edition '%s': %w", edition.EditionID, err) + } + cleanupCallback() + slog.Debug("database successfully downloaded", "edition", edition.EditionID, "md5", edition.MD5) return nil } - jobProcessor := internal.NewJobProcessor(ctx, c.config.Parallelism) + jobProcessor := internal.NewJobProcessor(ctx, c.parallelism) for _, edition := range outdatedEditions { edition := edition processFunc := func(_ context.Context) error { @@ -103,8 +136,8 @@ func (c *Client) Run(ctx context.Context) error { return fmt.Errorf("running the job processor: %w", err) } - if c.config.Output { - result, err := c.downloader.MakeOutput() + if c.printOutput { + result, err := makeOutput(allEditions, oldEditionsHash) if err != nil { return fmt.Errorf("marshaling result log: %w", err) } @@ -123,7 +156,7 @@ func (c *Client) retry( // Max zero retries has to be set to achieve that // because the backoff never stops if MaxElapsedTime is zero. 
exp := backoff.NewExponentialBackOff() - exp.MaxElapsedTime = c.config.RetryFor + exp.MaxElapsedTime = c.retryFor b := backoff.BackOff(exp) if exp.MaxElapsedTime == 0 { b = backoff.WithMaxRetries(exp, 0) @@ -141,9 +174,7 @@ func (c *Client) retry( }, b, func(err error, d time.Duration) { - if c.config.Verbose { - log.Printf("%s, retrying in %v: %v", logMsg, d, err) - } + slog.Debug(logMsg, "retrying-in", d, "error", err) }, ) } diff --git a/pkg/geoipupdate/geoip_updater_test.go b/pkg/geoipupdate/geoip_updater_test.go index 28b6fb9c..e3a71fe7 100644 --- a/pkg/geoipupdate/geoip_updater_test.go +++ b/pkg/geoipupdate/geoip_updater_test.go @@ -15,7 +15,10 @@ import ( "testing" "time" - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/download" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/config" + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/lock" + geoipupdatewriter "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/writer" "github.com/stretchr/testify/require" ) @@ -29,13 +32,13 @@ func TestFullDownload(t *testing.T) { defer os.RemoveAll(tempDir) edition := "edition-1" - dbFile := filepath.Join(tempDir, edition+download.Extension) + dbFile := filepath.Join(tempDir, edition+".mmdb") // equivalent MD5: 618dd27a10de24809ec160d6807f363f err = os.WriteFile(dbFile, []byte("edition-1 content"), os.ModePerm) require.NoError(t, err) edition = "edition-2" - dbFile = filepath.Join(tempDir, edition+download.Extension) + dbFile = filepath.Join(tempDir, edition+".mmdb") err = os.WriteFile(dbFile, []byte("edition-2 content"), os.ModePerm) require.NoError(t, err) @@ -78,7 +81,7 @@ func TestFullDownload(t *testing.T) { content := "new " + name + " content" header := &tar.Header{ - Name: name + download.Extension, + Name: name + ".mmdb", Size: int64(len(content)), } @@ -113,7 +116,7 @@ func TestFullDownload(t *testing.T) { defer server.Close() ctx := context.Background() - conf := &Config{ + conf := 
&config.Config{ AccountID: 0, // AccountID is not relevant for this test. LicenseKey: "000000000001", // LicenseKey is not relevant for this test. DatabaseDirectory: tempDir, @@ -128,8 +131,19 @@ func TestFullDownload(t *testing.T) { logOutput := &bytes.Buffer{} - client, err := NewClient(conf) + downloader := api.NewHTTPDownloader( + conf.AccountID, + conf.LicenseKey, + http.DefaultClient, + conf.URL, + ) + + writer := geoipupdatewriter.NewDiskWriter(conf.DatabaseDirectory, conf.PreserveFileTimes) + + locker, err := lock.NewFileLock(conf.LockFile) require.NoError(t, err) + + client := NewClient(conf, downloader, locker, writer) client.output = log.New(logOutput, "", 0) // download updates. @@ -141,7 +155,7 @@ func TestFullDownload(t *testing.T) { require.Regexp(t, expectedOutput, logOutput.String()) // edition-1 file hasn't been modified. - dbFile = filepath.Join(tempDir, "edition-1"+download.Extension) + dbFile = filepath.Join(tempDir, "edition-1.mmdb") //nolint:gosec // we need to read the content of the file in this test. fileContent, err := os.ReadFile(dbFile) require.NoError(t, err) @@ -151,24 +165,24 @@ func TestFullDownload(t *testing.T) { require.LessOrEqual(t, testDate, database.ModTime()) // edition-2 file has been updated. - dbFile = filepath.Join(tempDir, "edition-2"+download.Extension) + dbFile = filepath.Join(tempDir, "edition-2.mmdb") //nolint:gosec // we need to read the content of the file in this test. fileContent, err = os.ReadFile(dbFile) require.NoError(t, err) require.Equal(t, "new edition-2 content", string(fileContent)) - modTime, err := download.ParseTime("2024-02-23") + modTime, err := api.ParseTime("2024-02-23") require.NoError(t, err) database, err = os.Stat(dbFile) require.NoError(t, err) require.Equal(t, modTime, database.ModTime()) // edition-3 file has been downloaded. 
- dbFile = filepath.Join(tempDir, "edition-3"+download.Extension) + dbFile = filepath.Join(tempDir, "edition-3.mmdb") //nolint:gosec // we need to read the content of the file in this test. fileContent, err = os.ReadFile(dbFile) require.NoError(t, err) require.Equal(t, "new edition-3 content", string(fileContent)) - modTime, err = download.ParseTime("2024-02-02") + modTime, err = api.ParseTime("2024-02-02") require.NoError(t, err) database, err = os.Stat(dbFile) require.NoError(t, err) diff --git a/pkg/geoipupdate/internal/file_lock.go b/pkg/geoipupdate/lock/file_lock.go similarity index 60% rename from pkg/geoipupdate/internal/file_lock.go rename to pkg/geoipupdate/lock/file_lock.go index 598d9696..66bdd99d 100644 --- a/pkg/geoipupdate/internal/file_lock.go +++ b/pkg/geoipupdate/lock/file_lock.go @@ -1,45 +1,37 @@ -package internal +package lock import ( "fmt" - "log" "os" "path/filepath" "github.com/gofrs/flock" ) -// FileLock provides a file lock mechanism based on flock. -type FileLock struct { - lock *flock.Flock - verbose bool +// fileLock provides a file lock mechanism based on flock. +type fileLock struct { + lock *flock.Flock } // NewFileLock creates a new instance of FileLock. -func NewFileLock(path string, verbose bool) (*FileLock, error) { +// +//nolint:revive // unexported type fileLock is not meant to be used as a standalone type. +func NewFileLock(path string) (*fileLock, error) { err := os.MkdirAll(filepath.Dir(path), 0o750) if err != nil { return nil, fmt.Errorf("creating lock file directory: %w", err) } - if verbose { - log.Printf("Initializing file lock at %s", path) - } - - return &FileLock{ - lock: flock.New(path), - verbose: verbose, + return &fileLock{ + lock: flock.New(path), }, nil } // Release unlocks the file lock. 
-func (f *FileLock) Release() error { +func (f *fileLock) Release() error { if err := f.lock.Unlock(); err != nil { return fmt.Errorf("releasing file lock at %s: %w", f.lock.Path(), err) } - if f.verbose { - log.Printf("Lock file %s successfully released", f.lock.Path()) - } return nil } @@ -47,7 +39,7 @@ func (f *FileLock) Release() error { // It is possible for multiple goroutines within the same process // to acquire the same lock, so acquireLock is not thread safe in // that sense, but protects access across different processes. -func (f *FileLock) Acquire() error { +func (f *fileLock) Acquire() error { ok, err := f.lock.TryLock() if err != nil { return fmt.Errorf("acquiring file lock at %s: %w", f.lock.Path(), err) @@ -55,8 +47,5 @@ func (f *FileLock) Acquire() error { if !ok { return fmt.Errorf("lock %s already acquired by another process", f.lock.Path()) } - if f.verbose { - log.Printf("Acquired lock file at %s", f.lock.Path()) - } return nil } diff --git a/pkg/geoipupdate/internal/file_lock_test.go b/pkg/geoipupdate/lock/file_lock_test.go similarity index 96% rename from pkg/geoipupdate/internal/file_lock_test.go rename to pkg/geoipupdate/lock/file_lock_test.go index a6d3d2e4..8a64bdd8 100644 --- a/pkg/geoipupdate/internal/file_lock_test.go +++ b/pkg/geoipupdate/lock/file_lock_test.go @@ -1,4 +1,4 @@ -package internal +package lock import ( "path/filepath" @@ -12,7 +12,7 @@ import ( func TestAcquireFileLock(t *testing.T) { tempDir := t.TempDir() - fl, err := NewFileLock(filepath.Join(tempDir, ".geoipupdate.lock"), false) + fl, err := NewFileLock(filepath.Join(tempDir, ".geoipupdate.lock")) require.NoError(t, err) defer func() { err := fl.Release() diff --git a/pkg/geoipupdate/lock/lock.go b/pkg/geoipupdate/lock/lock.go new file mode 100644 index 00000000..e19f4d5f --- /dev/null +++ b/pkg/geoipupdate/lock/lock.go @@ -0,0 +1,9 @@ +// Package lock is responsible for synchronizing concurrent access to the client. 
+package lock + +// Lock presents common methods required to prevent concurrent access to +// the download client. +type Lock interface { + Acquire() error + Release() error +} diff --git a/pkg/geoipupdate/download/output.go b/pkg/geoipupdate/output.go similarity index 55% rename from pkg/geoipupdate/download/output.go rename to pkg/geoipupdate/output.go index 5a8424e5..8f626e96 100644 --- a/pkg/geoipupdate/download/output.go +++ b/pkg/geoipupdate/output.go @@ -1,11 +1,15 @@ -package download +package geoipupdate import ( "encoding/json" "fmt" + "time" + + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api" ) -// ReadResult is the struct returned by a Reader's Get method. +// output contains information collected about a certain database edition +// collected during a download attempt. type output struct { EditionID string `json:"edition_id"` OldHash string `json:"old_hash"` @@ -14,18 +18,18 @@ type output struct { CheckedAt int64 `json:"checked_at"` } -// MakeOutput returns a json formatted summary about the current download attempt. -func (d *Download) MakeOutput() ([]byte, error) { +// makeOutput returns a json formatted summary about the current download attempt. 
+func makeOutput(allEditions []api.Metadata, oldHashes map[string]string) ([]byte, error) { var out []output - now := d.now() - for _, m := range d.metadata { - modifiedAt, err := ParseTime(m.Date) + now := time.Now() + for _, m := range allEditions { + modifiedAt, err := api.ParseTime(m.Date) if err != nil { return nil, err } o := output{ EditionID: m.EditionID, - OldHash: d.oldEditionsHash[m.EditionID], + OldHash: oldHashes[m.EditionID], NewHash: m.MD5, ModifiedAt: modifiedAt.Unix(), CheckedAt: now.Unix(), diff --git a/pkg/geoipupdate/output_test.go b/pkg/geoipupdate/output_test.go new file mode 100644 index 00000000..2ef4f41d --- /dev/null +++ b/pkg/geoipupdate/output_test.go @@ -0,0 +1,35 @@ +package geoipupdate + +import ( + "testing" + + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api" + "github.com/stretchr/testify/require" +) + +func TestOutputFormat(t *testing.T) { + oldHashes := map[string]string{ + "edition-1": "1", + "edition-2": "10", + } + + metadata := []api.Metadata{ + { + EditionID: "edition-1", + MD5: "1", + Date: "2024-01-01", + }, + { + EditionID: "edition-2", + MD5: "11", + Date: "2024-02-01", + }, + } + + //nolint:lll + expectedOutput := `[{"edition_id":"edition\-1","old_hash":"1","new_hash":"1","modified_at":1704067200,"checked_at":\d+},{"edition_id":"edition\-2","old_hash":"10","new_hash":"11","modified_at":1706745600,"checked_at":\d+}]` + + output, err := makeOutput(metadata, oldHashes) + require.NoError(t, err) + require.Regexp(t, expectedOutput, string(output)) +} diff --git a/pkg/geoipupdate/download/writer.go b/pkg/geoipupdate/writer/disk.go similarity index 59% rename from pkg/geoipupdate/download/writer.go rename to pkg/geoipupdate/writer/disk.go index 274749d9..09aa8263 100644 --- a/pkg/geoipupdate/download/writer.go +++ b/pkg/geoipupdate/writer/disk.go @@ -1,75 +1,82 @@ -package download +package writer import ( - "archive/tar" - "compress/gzip" - "context" "crypto/md5" - "encoding/json" + "encoding/hex" "errors" "fmt" "hash" 
"io"
	"log"
-	"net/http"
	"os"
	"path/filepath"
-	"strconv"
	"strings"

-	"github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal"
-	"github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars"
+	"github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api"
 )
 
 const (
-	downloadEndpoint = "%s/geoip/databases/%s/download?date=%s&suffix=tar.gz"
+	extension = ".mmdb"
 )
 
-// DownloadEdition downloads and writes an edition to a database file.
-func (d *Download) DownloadEdition(ctx context.Context, edition Metadata) error {
-	date := strings.ReplaceAll(edition.Date, "-", "")
-	requestURL := fmt.Sprintf(downloadEndpoint, d.url, edition.EditionID, date)
-	if d.verbose {
-		log.Printf("Downloading: %s", requestURL)
-	}
+// diskWriter is used to write mmdb databases into files.
+type diskWriter struct {
+	// databaseDir is the database download path.
+	databaseDir string
+	// preserveFileTimes sets whether database modification times are preserved across downloads.
+	preserveFileTimes bool
+}
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
-	if err != nil {
-		return fmt.Errorf("creating request: %w", err)
+// NewDiskWriter initializes a new diskWriter struct.
+//
+//nolint:revive // unexported type diskWriter is not meant to be used as a standalone type.
+func NewDiskWriter(
+	databaseDir string,
+	preserveFileTimes bool,
+) *diskWriter {
+	return &diskWriter{
+		databaseDir:       databaseDir,
+		preserveFileTimes: preserveFileTimes,
 	}
-	req.Header.Add("User-Agent", "geoipupdate/"+vars.Version)
-	req.SetBasicAuth(strconv.Itoa(d.accountID), d.licenseKey)
+}
 
-	response, err := d.client.Do(req)
+// GetHash returns the hash of a certain database file.
+func (w *diskWriter) GetHash(editionID string) (string, error) {
+	databaseFilePath := w.getFilePath(editionID)
+	//nolint:gosec // we really need to read this file.
+ database, err := os.Open(databaseFilePath) if err != nil { - return fmt.Errorf("performing HTTP request: %w", err) - } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - responseBody, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("reading response body: %w", err) - } - - errResponse := internal.ResponseError{ - StatusCode: response.StatusCode, + if errors.Is(err, os.ErrNotExist) { + return zeroMD5, nil } + return "", fmt.Errorf("opening database: %w", err) + } - if err := json.Unmarshal(responseBody, &errResponse); err != nil { - errResponse.Message = err.Error() + defer func() { + if err := database.Close(); err != nil { + log.Println(fmt.Errorf("closing database: %w", err)) } + }() - return fmt.Errorf("requesting edition: %w", errResponse) + md5Hash := md5.New() + if _, err := io.Copy(md5Hash, database); err != nil { + return "", fmt.Errorf("calculating database hash: %w", err) } - databaseFilePath := d.getFilePath(edition.EditionID) + result := byteToString(md5Hash.Sum(nil)) + return result, nil +} + +// Write writes the content of a mmdb database to a file. +func (w *diskWriter) Write(metadata api.Metadata, reader io.Reader) error { + databaseFilePath := w.getFilePath(metadata.EditionID) // write the result into a temporary file. 
fw, err := newFileWriter(databaseFilePath + ".temporary") if err != nil { - return fmt.Errorf("setting up database writer for %s: %w", edition.EditionID, err) + return fmt.Errorf("setting up database writer for %s: %w", metadata.EditionID, err) } + defer func() { if closeErr := fw.close(); closeErr != nil { err = errors.Join( @@ -79,36 +86,13 @@ func (d *Download) DownloadEdition(ctx context.Context, edition Metadata) error } }() - gzReader, err := gzip.NewReader(response.Body) - if err != nil { - return fmt.Errorf("encountered an error creating GZIP reader: %w", err) - } - defer gzReader.Close() - - tarReader := tar.NewReader(gzReader) - - // iterate through the tar archive to extract the mmdb file - for { - header, err := tarReader.Next() - if err == io.EOF { - return errors.New("tar archive does not contain an mmdb file") - } - if err != nil { - return fmt.Errorf("reading tar archive: %w", err) - } - - if strings.HasSuffix(header.Name, Extension) { - break - } - } - - if err = fw.write(tarReader); err != nil { - return fmt.Errorf("writing to the temp file for %s: %w", edition.EditionID, err) + if err = fw.write(reader); err != nil { + return fmt.Errorf("writing to the temp file for %s: %w", metadata.EditionID, err) } // make sure the hash of the temp file matches the expected hash. 
-	if err = fw.validateHash(edition.MD5); err != nil {
-		return fmt.Errorf("validating hash for %s: %w", edition.EditionID, err)
+	if err = fw.validateHash(metadata.MD5); err != nil {
+		return fmt.Errorf("validating hash for %s: %w", metadata.EditionID, err)
 	}
 
 	// move the temoporary database file into it's final location and
@@ -123,21 +107,22 @@ func (d *Download) DownloadEdition(ctx context.Context, edition Metadata) error
 	}
 
 	// check if we need to set the file's modified at time
-	if d.preserveFileTimes {
-		if err = setModifiedAtTime(databaseFilePath, edition.Date); err != nil {
+	if w.preserveFileTimes {
+		if err = setModifiedAtTime(databaseFilePath, metadata.Date); err != nil {
 			return err
 		}
 	}
 
-	if d.verbose {
-		log.Printf("Database %s successfully updated: %+v", edition.EditionID, edition.MD5)
-	}
-
 	return nil
 }
 
-// fileWriter is used to write the content of a Reader's response
-// into a file.
+// getFilePath constructs the file path for a database edition.
+func (w *diskWriter) getFilePath(editionID string) string {
+	return filepath.Join(w.databaseDir, editionID) + extension
+}
+
+// fileWriter writes a mmdb file into a file and verifies its integrity
+// by comparing hashes.
 type fileWriter struct {
 	// file is used for writing the Reader's response.
 	file *os.File
@@ -232,7 +217,7 @@ func syncDir(path string) error {
 
 // setModifiedAtTime sets the times for a database file to a certain value.
 func setModifiedAtTime(path, dateString string) error {
-	releaseDate, err := ParseTime(dateString)
+	releaseDate, err := api.ParseTime(dateString)
 	if err != nil {
 		return err
 	}
@@ -242,3 +227,8 @@ func setModifiedAtTime(path, dateString string) error {
 	}
 	return nil
 }
+
+// byteToString returns the base16 representation of a byte array.
+func byteToString(b []byte) string { + return hex.EncodeToString(b) +} diff --git a/pkg/geoipupdate/writer/disk_test.go b/pkg/geoipupdate/writer/disk_test.go new file mode 100644 index 00000000..03e2214a --- /dev/null +++ b/pkg/geoipupdate/writer/disk_test.go @@ -0,0 +1,122 @@ +package writer + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api" + "github.com/stretchr/testify/require" +) + +// TestGetHash checks the GetHash method of a diskWriter. +func TestGetHash(t *testing.T) { + tempDir, err := os.MkdirTemp("", "db") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + d := NewDiskWriter(tempDir, false) + + // returns a zero hash for a non existing edition. + hash, err := d.GetHash("NewEdition") + require.NoError(t, err) + require.Equal(t, zeroMD5, hash) + + // returns the correct md5 for an existing edition. + edition := "edition-1" + dbFile := filepath.Join(tempDir, edition+extension) + + err = os.WriteFile(dbFile, []byte("edition-1 content"), os.ModePerm) + require.NoError(t, err) + + hash, err = d.GetHash(edition) + require.NoError(t, err) + require.Equal(t, "618dd27a10de24809ec160d6807f363f", hash) +} + +// TestDiskWriter checks the Write method of a diskWriter. 
+func TestDiskWriter(t *testing.T) {
+	tempDir, err := os.MkdirTemp("", "db")
+	require.NoError(t, err)
+	defer os.RemoveAll(tempDir)
+
+	now := time.Now()
+
+	edition := api.Metadata{
+		EditionID: "edition-1",
+		Date:      "2024-02-02",
+		MD5:       "618dd27a10de24809ec160d6807f363f",
+	}
+
+	dbContent := "edition-1 content"
+	tests := []struct {
+		description      string
+		preserveFileTime bool
+		reader           io.Reader
+		checkResult      func(t *testing.T, err error)
+	}{
+		{
+			description:      "successful write",
+			preserveFileTime: false,
+			reader:           strings.NewReader(dbContent),
+			checkResult: func(t *testing.T, err error) {
+				require.NoError(t, err)
+
+				dbFile := filepath.Join(tempDir, edition.EditionID+extension)
+
+				//nolint:gosec // we need to read the content of the file in this test.
+				fileContent, err := os.ReadFile(dbFile)
+				require.NoError(t, err)
+				require.Equal(t, dbContent, string(fileContent))
+
+				database, err := os.Stat(dbFile)
+				require.NoError(t, err)
+				// accommodate time drift
+				require.GreaterOrEqual(t, database.ModTime(), now.Add(-1*time.Hour))
+			},
+		},
+		{
+			description:      "successful write with db time preserved",
+			preserveFileTime: true,
+			reader:           strings.NewReader(dbContent),
+			checkResult: func(t *testing.T, err error) {
+				require.NoError(t, err)
+
+				dbFile := filepath.Join(tempDir, edition.EditionID+extension)
+
+				//nolint:gosec // we need to read the content of the file in this test.
+				fileContent, err := os.ReadFile(dbFile)
+				require.NoError(t, err)
+				require.Equal(t, dbContent, string(fileContent))
+
+				modTime, err := api.ParseTime(edition.Date)
+				require.NoError(t, err)
+
+				database, err := os.Stat(dbFile)
+				require.NoError(t, err)
+				require.Equal(t, modTime, database.ModTime())
+			},
+		},
+		{
+			description:      "file hash does not match metadata",
+			preserveFileTime: true,
+			reader:           strings.NewReader("malformed content"),
+			checkResult: func(t *testing.T, err error) {
+				require.Error(t, err)
+				//nolint:lll
+				require.Regexp(t, "^validating hash for edition-1: md5 of new database .* does not match expected md5", err.Error())
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) {
+			d := NewDiskWriter(tempDir, test.preserveFileTime)
+			err = d.Write(edition, test.reader)
+			test.checkResult(t, err)
+		})
+	}
+}
diff --git a/pkg/geoipupdate/writer/writer.go b/pkg/geoipupdate/writer/writer.go
new file mode 100644
index 00000000..6e572ce9
--- /dev/null
+++ b/pkg/geoipupdate/writer/writer.go
@@ -0,0 +1,21 @@
+// Package writer is responsible for writing database editions to various
+// destinations.
+package writer
+
+import (
+	"io"
+
+	"github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/api"
+)
+
+const (
+	// zeroMD5 is the default value provided as an MD5 hash for a non-existent
+	// database.
+	zeroMD5 = "00000000000000000000000000000000"
+)
+
+// Writer defines the common methods required to write mmdb databases.
+type Writer interface {
+	GetHash(editionID string) (string, error)
+	Write(metadata api.Metadata, reader io.Reader) error
+}