Skip to content

Commit

Permalink
Only resolve programs that are in shell (#668)
Browse files Browse the repository at this point in the history
The CLI is broken when any non-shell tasks/cells are being run. The
problem is that everything's treated as Shell and goes through the
program resolver (env var extraction/prompting etc).
  • Loading branch information
sourishkrout committed Sep 13, 2024
1 parent 57248b0 commit a64a67b
Show file tree
Hide file tree
Showing 24 changed files with 542 additions and 777 deletions.
2 changes: 1 addition & 1 deletion internal/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (c *base) lookPath(path string) (string, error) {

func (c *base) findDefaultProgram(name string, args []string) (string, []string, error) {
name, normArgs := normalizeProgramName(name)
if isShellLanguage(name) {
if IsShellLanguage(name) {
globalShell := shellFromShellPath(c.globalShellPath())
res, err := c.lookPath(globalShell)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions internal/command/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ func redactConfig(cfg *ProgramConfig) *ProgramConfig {
}

func isShell(cfg *ProgramConfig) bool {
return isShellProgram(filepath.Base(cfg.ProgramName)) || isShellLanguage(cfg.LanguageId)
return IsShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId)
}

func isShellProgram(programName string) bool {
func IsShellProgram(programName string) bool {
switch strings.ToLower(programName) {
case "sh", "bash", "zsh", "ksh", "shell":
return true
Expand All @@ -40,7 +40,7 @@ func isShellProgram(programName string) bool {
}

// TODO(adamb): this function is used for two quite different inputs: program name and language ID.
func isShellLanguage(languageID string) bool {
func IsShellLanguage(languageID string) bool {
switch strings.ToLower(languageID) {
// shellscripts
// TODO(adamb): breaking change: shellscript was removed to indicate
Expand Down
2 changes: 1 addition & 1 deletion internal/command/config_code_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (b *configBuilder) programPath() (programPath string) {
language := b.block.Language()

// If the language is a shell language, check frontmatter for shell overwrite.
if isShellLanguage(language) {
if IsShellLanguage(language) {
doc := b.block.Document()
fmtr, err := doc.FrontmatterWithError()
if err == nil && fmtr != nil && fmtr.Shell != "" {
Expand Down
1 change: 1 addition & 0 deletions internal/runner/client/client_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ func (r *LocalRunner) ResolveProgram(ctx context.Context, mode runnerv1.ResolveP
Mode: mode,
SessionStrategy: r.sessionStrategy,
Project: ConvertToRunnerProject(r.project),
LanguageId: language,
Source: &runnerv1.ResolveProgramRequest_Script{
Script: script,
},
Expand Down
1 change: 1 addition & 0 deletions internal/runner/client/client_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ func (r *RemoteRunner) ResolveProgram(ctx context.Context, mode runnerv1.Resolve
Env: envs,
Mode: mode,
Project: ConvertToRunnerProject(r.project),
LanguageId: language,
Source: &runnerv1.ResolveProgramRequest_Script{
Script: script,
},
Expand Down
35 changes: 23 additions & 12 deletions internal/runner/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,33 +672,44 @@ func runnerConformsOpinionatedEnvVarNaming(knownName string) bool {
func (r *runnerService) ResolveProgram(ctx context.Context, req *runnerv1.ResolveProgramRequest) (*runnerv1.ResolveProgramResponse, error) {
r.logger.Info("running ResolveProgram in runnerService")

// todo(sebastian): reenable once extension includes it in request
// if req.GetLanguageId() == "" {
// return nil, status.Error(codes.InvalidArgument, "language id is required")
// }

resolver, err := r.getProgramResolverFromReq(req)
if err != nil {
return nil, err
}

var (
result *commandpkg.ProgramResolverResult
modifiedScriptBuf bytes.Buffer
)

if script := req.GetScript(); script != "" {
result, err = resolver.Resolve(strings.NewReader(script), &modifiedScriptBuf)
} else if commands := req.GetCommands(); commands != nil && len(commands.Lines) > 0 {
result, err = resolver.Resolve(strings.NewReader(strings.Join(commands.Lines, "\n")), &modifiedScriptBuf)
} else {
err = status.Error(codes.InvalidArgument, "either script or commands must be provided")
}
if err != nil {
return nil, err
script := req.GetScript()
if commands := req.GetCommands(); script == "" && len(commands.Lines) > 0 {
script = strings.Join(commands.Lines, "\n")
}

modifiedScript := modifiedScriptBuf.String()
if script == "" {
return nil, status.Error(codes.InvalidArgument, "either script or commands must be provided")
}

// todo(sebastian): figure out how to return commands
response := &runnerv1.ResolveProgramResponse{
Script: modifiedScript,
Script: script,
}

// return early if it's not a shell language
if !IsShellLanguage(req.GetLanguageId()) {
return response, nil
}

result, err := resolver.Resolve(strings.NewReader(script), &modifiedScriptBuf)
if err != nil {
return nil, err
}
response.Script = modifiedScriptBuf.String()

for _, item := range result.Variables {
ritem := &runnerv1.ResolveProgramResponse_VarResult{
Expand Down
36 changes: 23 additions & 13 deletions internal/runnerv2service/service_resolve_program.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,44 @@ import (
func (r *runnerService) ResolveProgram(ctx context.Context, req *runnerv2.ResolveProgramRequest) (*runnerv2.ResolveProgramResponse, error) {
r.logger.Info("running ResolveProgram in runnerService")

// todo(sebastian): reenable once extension includes it in request
// if req.GetLanguageId() == "" {
// return nil, status.Error(codes.InvalidArgument, "language id is required")
// }

resolver, err := r.getProgramResolverFromReq(req)
if err != nil {
return nil, err
}

var (
result *command.ProgramResolverResult
modifiedScriptBuf bytes.Buffer
)

if script := req.GetScript(); script != "" {
result, err = resolver.Resolve(strings.NewReader(script), &modifiedScriptBuf)
} else if commands := req.GetCommands(); commands != nil && len(commands.Lines) > 0 {
script := strings.Join(commands.Lines, "\n")
result, err = resolver.Resolve(strings.NewReader(script), &modifiedScriptBuf)
} else {
err = status.Error(codes.InvalidArgument, "either script or commands must be provided")
}
if err != nil {
return nil, err
script := req.GetScript()
if commands := req.GetCommands(); script == "" && len(commands.Lines) > 0 {
script = strings.Join(commands.Lines, "\n")
}

modifiedScript := modifiedScriptBuf.String()
if script == "" {
return nil, status.Error(codes.InvalidArgument, "either script or commands must be provided")
}

// todo(sebastian): figure out how to return commands
response := &runnerv2.ResolveProgramResponse{
Script: modifiedScript,
Script: script,
}

// return early if it's not a shell language
if !command.IsShellLanguage(req.GetLanguageId()) {
return response, nil
}

result, err := resolver.Resolve(strings.NewReader(script), &modifiedScriptBuf)
if err != nil {
return nil, err
}
response.Script = modifiedScriptBuf.String()

for _, item := range result.Variables {
ritem := &runnerv2.ResolveProgramResponse_VarResult{
Expand Down
49 changes: 45 additions & 4 deletions internal/runnerv2service/service_resolve_program_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ func TestRunnerServiceResolveProgram(t *testing.T) {
{
name: "WithScript",
request: &runnerv2.ResolveProgramRequest{
Env: []string{"TEST_RESOLVED=value"},
Env: []string{"TEST_RESOLVED=value"},
LanguageId: "bash",
Source: &runnerv2.ResolveProgramRequest_Script{
Script: "export TEST_RESOLVED=default\nexport TEST_UNRESOLVED",
},
Expand All @@ -34,7 +35,8 @@ func TestRunnerServiceResolveProgram(t *testing.T) {
{
name: "WithCommands",
request: &runnerv2.ResolveProgramRequest{
Env: []string{"TEST_RESOLVED=value"},
Env: []string{"TEST_RESOLVED=value"},
LanguageId: "bash",
Source: &runnerv2.ResolveProgramRequest_Commands{
Commands: &runnerv2.ResolveProgramCommandList{
Lines: []string{"export TEST_RESOLVED=default", "export TEST_UNRESOLVED"},
Expand All @@ -45,7 +47,8 @@ func TestRunnerServiceResolveProgram(t *testing.T) {
{
name: "WithAdditionalEnv",
request: &runnerv2.ResolveProgramRequest{
Env: []string{"TEST_RESOLVED=value", "TEST_EXTRA=value"},
Env: []string{"TEST_RESOLVED=value", "TEST_EXTRA=value"},
LanguageId: "bash",
Source: &runnerv2.ResolveProgramRequest_Commands{
Commands: &runnerv2.ResolveProgramCommandList{
Lines: []string{"export TEST_RESOLVED=default", "export TEST_UNRESOLVED"},
Expand Down Expand Up @@ -93,7 +96,8 @@ func TestRunnerResolveProgram_CommandsWithNewLines(t *testing.T) {
_, client := testCreateRunnerServiceClient(t, lis)

request := &runnerv2.ResolveProgramRequest{
Env: []string{"FILE_NAME=my-file.txt"},
Env: []string{"FILE_NAME=my-file.txt"},
LanguageId: "bash",
Source: &runnerv2.ResolveProgramRequest_Commands{
Commands: &runnerv2.ResolveProgramCommandList{
Lines: []string{
Expand Down Expand Up @@ -137,3 +141,40 @@ func TestRunnerResolveProgram_CommandsWithNewLines(t *testing.T) {
resp.Commands.Lines,
)
}

func TestRunnerResolveProgram_OnlyShellLanguages(t *testing.T) {
lis, stop := testStartRunnerServiceServer(t)
t.Cleanup(stop)
_, client := testCreateRunnerServiceClient(t, lis)

t.Run("Javascript passed as script", func(t *testing.T) {
script := "console.log('test');"
request := &runnerv2.ResolveProgramRequest{
Env: []string{"TEST_RESOLVED=value"},
LanguageId: "javascript",
Source: &runnerv2.ResolveProgramRequest_Script{
Script: script,
},
}

resp, err := client.ResolveProgram(context.Background(), request)
require.NoError(t, err)
require.Len(t, resp.Vars, 0)
require.Equal(t, script, resp.Script)
})

t.Run("Python passed as commands", func(t *testing.T) {
script := "print('test')"
request := &runnerv2.ResolveProgramRequest{
LanguageId: "py",
Source: &runnerv2.ResolveProgramRequest_Commands{
Commands: &runnerv2.ResolveProgramCommandList{Lines: []string{script}},
},
}

resp, err := client.ResolveProgram(context.Background(), request)
require.NoError(t, err)
require.Len(t, resp.Vars, 0)
require.Equal(t, script, resp.Script)
})
}
Loading

0 comments on commit a64a67b

Please sign in to comment.