Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring: Remove global variables & add more tests #92

Merged
merged 1 commit into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 114 additions & 17 deletions cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"

Expand All @@ -25,28 +26,62 @@ var (

var SESSION_PATH string

func inProgressUpdates(ci bool) bool {
return !(ci)
func (r *Root) updateContext(cmd string, args []string) error {
r.ctx.Cmd = cmd // Get the command name

targetQuery, err := parseTargetQuery(cmd, args)
if err != nil {
return err
}

r.ctx.Target = targetQuery.Target

if targetQuery.From != "" {
r.ctx.From = targetQuery.From
}

if targetQuery.Resolver != "" {
r.ctx.Resolver = targetQuery.Resolver
}

// Check env for CI
if os.Getenv("CI") != "" {
r.ctx.CIMode = true
}

// Check if it is a terminal or being piped/redirected
// We want to disable realtime updates if that is the case
f, ok := r.printer.OutWriter.(*os.File)
if ok {
stdoutFileInfo, err := f.Stat()
if err != nil {
return fmt.Errorf("stdout stat failed: %s", err)
}
if (stdoutFileInfo.Mode() & os.ModeCharDevice) == 0 {
// stdout is piped, run in ci mode
r.ctx.CIMode = true
}
} else {
r.ctx.CIMode = true
}

return nil
}

func createLocations(from string) ([]globalping.Locations, bool, error) {
fromArr := strings.Split(from, ",")
if len(fromArr) == 1 {
mId, err := mapToMeasurementID(fromArr[0])
mId, err := mapFromHistory(fromArr[0])
if err != nil {
return nil, false, err
}
isPreviousMeasurementId := false
isFromHistory := false
if mId == "" {
mId = strings.TrimSpace(fromArr[0])
} else {
isPreviousMeasurementId = true
isFromHistory = true
}
return []globalping.Locations{
{
Magic: mId,
},
}, isPreviousMeasurementId, nil
return []globalping.Locations{{Magic: mId}}, isFromHistory, nil
}
locations := make([]globalping.Locations, len(fromArr))
for i, v := range fromArr {
Expand All @@ -57,8 +92,70 @@ func createLocations(from string) ([]globalping.Locations, bool, error) {
return locations, false, nil
}

// Maps a location to a measurement ID if possible
func mapToMeasurementID(location string) (string, error) {
type TargetQuery struct {
Target string
From string
Resolver string
}

var commandsWithResolver = []string{
"dns",
"http",
}

func parseTargetQuery(cmd string, args []string) (*TargetQuery, error) {
targetQuery := &TargetQuery{}
if len(args) == 0 {
return nil, errors.New("provided target is empty")
}

resolver, argsWithoutResolver := findAndRemoveResolver(args)
if resolver != "" {
// resolver was found
if !slices.Contains(commandsWithResolver, cmd) {
return nil, fmt.Errorf("command %s does not accept a resolver argument. @%s was provided", cmd, resolver)
}

targetQuery.Resolver = resolver
}

targetQuery.Target = argsWithoutResolver[0]

if len(argsWithoutResolver) > 1 {
if argsWithoutResolver[1] == "from" {
targetQuery.From = strings.TrimSpace(strings.Join(argsWithoutResolver[2:], " "))
} else {
return nil, errors.New("invalid command format")
}
}

return targetQuery, nil
}

func findAndRemoveResolver(args []string) (string, []string) {
var resolver string
resolverIndex := -1
for i := 0; i < len(args); i++ {
if len(args[i]) > 0 && args[i][0] == '@' && args[i-1] != "from" {
resolver = args[i][1:]
resolverIndex = i
break
}
}

if resolverIndex == -1 {
// resolver was not found
return "", args
}

argsClone := slices.Clone(args)
argsWithoutResolver := slices.Delete(argsClone, resolverIndex, resolverIndex+1)

return resolver, argsWithoutResolver
}

// Maps a location to a measurement ID from history, if possible.
func mapFromHistory(location string) (string, error) {
if location == "" {
return "", nil
}
Expand All @@ -67,19 +164,19 @@ func mapToMeasurementID(location string) (string, error) {
if err != nil {
return "", ErrInvalidIndex
}
return getMeasurementID(index)
return getIdFromHistory(index)
}
if location == "first" {
return getMeasurementID(1)
return getIdFromHistory(1)
}
if location == "last" || location == "previous" {
return getMeasurementID(-1)
return getIdFromHistory(-1)
}
return "", nil
}

// Returns the measurement ID at the given index from the session history
func getMeasurementID(index int) (string, error) {
func getIdFromHistory(index int) (string, error) {
if index == 0 {
return "", ErrInvalidIndex
}
Expand Down Expand Up @@ -130,7 +227,7 @@ func getMeasurementID(index int) (string, error) {
}

// Saves the measurement ID to the session history
func saveMeasurementID(id string) error {
func saveIdToHistory(id string) error {
_, err := os.Stat(getSessionPath())
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
Expand Down
Loading