Skip to content

Commit

Permalink
Move updater logic to tools package
Browse files Browse the repository at this point in the history
  • Loading branch information
vapopov committed Oct 16, 2024
1 parent d5dd770 commit 5f5688e
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 45 deletions.
6 changes: 3 additions & 3 deletions integration/autoupdate/tools/updater/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"time"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/autoupdate"
"github.com/gravitational/teleport/lib/autoupdate/tools"
)

var (
Expand All @@ -41,11 +41,11 @@ func main() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

updater := autoupdate.NewClientUpdater(
updater := tools.NewUpdater(
clientTools(),
toolsDir,
version,
autoupdate.WithBaseURL(baseUrl),
tools.WithBaseURL(baseUrl),
)
toolsVersion, reExec := updater.CheckLocal()
if reExec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/autoupdate"
"github.com/gravitational/teleport/lib/autoupdate/tools"
)

var (
Expand All @@ -50,11 +50,11 @@ func TestUpdate(t *testing.T) {
defer cancel()

// Fetch compiled test binary with updater logic and install to $TELEPORT_HOME.
updater := autoupdate.NewClientUpdater(
updater := tools.NewUpdater(
clientTools(),
toolsDir,
testVersions[0],
autoupdate.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
tools.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
)
err := updater.Update(ctx, testVersions[0])
require.NoError(t, err)
Expand Down Expand Up @@ -92,11 +92,11 @@ func TestParallelUpdate(t *testing.T) {
defer cancel()

// Initial fetch the updater binary un-archive and replace.
updater := autoupdate.NewClientUpdater(
updater := tools.NewUpdater(
clientTools(),
toolsDir,
testVersions[0],
autoupdate.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
tools.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
)
err := updater.Update(ctx, testVersions[0])
require.NoError(t, err)
Expand Down Expand Up @@ -166,11 +166,11 @@ func TestUpdateInterruptSignal(t *testing.T) {
defer cancel()

// Initial fetch the updater binary un-archive and replace.
updater := autoupdate.NewClientUpdater(
updater := tools.NewUpdater(
clientTools(),
toolsDir,
testVersions[0],
autoupdate.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
tools.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
)
err := updater.Update(ctx, testVersions[0])
require.NoError(t, err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate
package tools

func init() {
featureFlag |= FlagEnt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate
package tools

func init() {
featureFlag |= FlagFips
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate
package tools

import (
"fmt"
Expand Down
54 changes: 26 additions & 28 deletions lib/autoupdate/client_update.go → lib/autoupdate/tools/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate
package tools

import (
"bytes"
Expand Down Expand Up @@ -51,8 +51,6 @@ const (
teleportToolsVersionEnv = "TELEPORT_TOOLS_VERSION"
// baseURL is CDN URL for downloading official Teleport packages.
baseURL = "https://cdn.teleport.dev"
// checksumHexLen is length of the hash sum.
checksumHexLen = 64
// reservedFreeDisk is the predefined amount of free disk space (in bytes) required
// to remain available after downloading archives.
reservedFreeDisk = 10 * 1024 * 1024 // 10 Mb
Expand All @@ -67,25 +65,25 @@ var (
pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`)
)

// ClientOption applies an option value for the ClientUpdater.
type ClientOption func(u *ClientUpdater)
// Option applies an option value for the Updater.
type Option func(u *Updater)

// WithBaseURL defines custom base url for the updater.
func WithBaseURL(baseUrl string) ClientOption {
return func(u *ClientUpdater) {
func WithBaseURL(baseUrl string) Option {
return func(u *Updater) {
u.baseUrl = baseUrl
}
}

// WithClient defines custom http client for the ClientUpdater.
func WithClient(client *http.Client) ClientOption {
return func(u *ClientUpdater) {
// WithClient defines custom http client for the Updater.
func WithClient(client *http.Client) Option {
return func(u *Updater) {
u.client = client
}
}

// ClientUpdater is updater implementation for the client tools auto updates.
type ClientUpdater struct {
// Updater is updater implementation for the client tools auto updates.
type Updater struct {
toolsDir string
localVersion string
tools []string
Expand All @@ -94,9 +92,9 @@ type ClientUpdater struct {
client *http.Client
}

// NewClientUpdater initiate updater for the client tools auto updates.
func NewClientUpdater(tools []string, toolsDir string, localVersion string, options ...ClientOption) *ClientUpdater {
updater := &ClientUpdater{
// NewUpdater initiate updater for the client tools auto updates.
func NewUpdater(tools []string, toolsDir string, localVersion string, options ...Option) *Updater {
updater := &Updater{
tools: tools,
toolsDir: toolsDir,
localVersion: localVersion,
Expand All @@ -111,7 +109,7 @@ func NewClientUpdater(tools []string, toolsDir string, localVersion string, opti
}

// CheckLocal is run at client tool startup and will only perform local checks.
func (u *ClientUpdater) CheckLocal() (string, bool) {
func (u *Updater) CheckLocal() (string, bool) {
// Check if the user has requested a specific version of client tools.
requestedVersion := os.Getenv(teleportToolsVersionEnv)
switch {
Expand All @@ -125,7 +123,7 @@ func (u *ClientUpdater) CheckLocal() (string, bool) {

// If a version of client tools has already been downloaded to
// tools directory, return that.
toolsVersion, err := checkClientToolVersion(u.toolsDir)
toolsVersion, err := checkToolVersion(u.toolsDir)
if err != nil {
return "", false
}
Expand All @@ -139,7 +137,7 @@ func (u *ClientUpdater) CheckLocal() (string, bool) {

// CheckRemote will check against Proxy Service if client tools need to be
// updated.
func (u *ClientUpdater) CheckRemote(ctx context.Context, proxyAddr string) (string, bool, error) {
func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (string, bool, error) {
// Check if the user has requested a specific version of client tools.
requestedVersion := os.Getenv(teleportToolsVersionEnv)
switch {
Expand Down Expand Up @@ -167,7 +165,7 @@ func (u *ClientUpdater) CheckRemote(ctx context.Context, proxyAddr string) (stri

// If a version of client tools has already been downloaded to
// tools directory, return that.
toolsVersion, err := checkClientToolVersion(u.toolsDir)
toolsVersion, err := checkToolVersion(u.toolsDir)
if err != nil {
return "", false, trace.Wrap(err)
}
Expand All @@ -187,7 +185,7 @@ func (u *ClientUpdater) CheckRemote(ctx context.Context, proxyAddr string) (stri
}

// UpdateWithLock acquires filesystem lock, downloads requested version package, unarchive and replace existing one.
func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string) (err error) {
func (u *Updater) UpdateWithLock(ctx context.Context, toolsVersion string) (err error) {
// Create tools directory if it does not exist.
if err := os.MkdirAll(u.toolsDir, 0o755); err != nil {
return trace.Wrap(err)
Expand All @@ -204,7 +202,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string)
// If the version of the running binary or the version downloaded to
// tools directory is the same as the requested version of client tools,
// nothing to be done, exit early.
teleportVersion, err := checkClientToolVersion(u.toolsDir)
teleportVersion, err := checkToolVersion(u.toolsDir)
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)

Expand All @@ -223,7 +221,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string)

// Update downloads requested version and replace it with existing one and cleanups the previous downloads
// with defined updater directory suffix.
func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error {
func (u *Updater) Update(ctx context.Context, toolsVersion string) error {
// Get platform specific download URLs.
packages, err := teleportPackageURLs(u.baseUrl, toolsVersion)
if err != nil {
Expand All @@ -244,7 +242,7 @@ func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error {

// update downloads the archive and validate against the hash. Download to a
// temporary path within tools directory.
func (u *ClientUpdater) update(ctx context.Context, pkg packageURL) error {
func (u *Updater) update(ctx context.Context, pkg packageURL) error {
hash, err := u.downloadHash(ctx, pkg.Hash)
if pkg.Optional && trace.IsNotFound(err) {
return nil
Expand Down Expand Up @@ -289,7 +287,7 @@ func (u *ClientUpdater) update(ctx context.Context, pkg packageURL) error {
}

// Exec re-executes tool command with same arguments and environ variables.
func (u *ClientUpdater) Exec() (int, error) {
func (u *Updater) Exec() (int, error) {
path, err := toolName(u.toolsDir)
if err != nil {
return 0, trace.Wrap(err)
Expand Down Expand Up @@ -317,7 +315,7 @@ func (u *ClientUpdater) Exec() (int, error) {
return 0, nil
}

func (u *ClientUpdater) downloadHash(ctx context.Context, url string) (string, error) {
func (u *Updater) downloadHash(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", trace.Wrap(err)
Expand All @@ -335,7 +333,7 @@ func (u *ClientUpdater) downloadHash(ctx context.Context, url string) (string, e
}

var buf bytes.Buffer
_, err = io.CopyN(&buf, resp.Body, checksumHexLen)
_, err = io.CopyN(&buf, resp.Body, sha256.Size*2)
if err != nil {
return "", trace.Wrap(err)
}
Expand All @@ -346,7 +344,7 @@ func (u *ClientUpdater) downloadHash(ctx context.Context, url string) (string, e
return raw, nil
}

func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, url string) (string, string, error) {
func (u *Updater) downloadArchive(ctx context.Context, downloadDir string, url string) (string, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", "", trace.Wrap(err)
Expand Down Expand Up @@ -388,5 +386,5 @@ func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string,
return "", "", trace.Wrap(err)
}

return f.Name(), fmt.Sprintf("%x", h.Sum(nil)), nil
return f.Name(), hex.EncodeToString(h.Sum(nil)), nil
}
8 changes: 4 additions & 4 deletions lib/autoupdate/utils.go → lib/autoupdate/tools/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate
package tools

import (
"bufio"
Expand Down Expand Up @@ -49,8 +49,8 @@ var (
featureFlag int
)

// ToolsDir returns the path to {tsh, tctl} in $TELEPORT_HOME/bin.
func ToolsDir() (string, error) {
// Dir returns the path to client tools in $TELEPORT_HOME/bin.
func Dir() (string, error) {
home := os.Getenv(types.HomeEnvVar)
if home == "" {
var err error
Expand All @@ -63,7 +63,7 @@ func ToolsDir() (string, error) {
return filepath.Join(filepath.Clean(home), ".tsh", "bin"), nil
}

func checkClientToolVersion(toolsDir string) (string, error) {
func checkToolVersion(toolsDir string) (string, error) {
// Find the path to the current executable.
path, err := toolName(toolsDir)
if err != nil {
Expand Down

0 comments on commit 5f5688e

Please sign in to comment.