Skip to content

Commit

Permalink
terraform: switch to Context for stop, Stoppable provisioners
Browse files Browse the repository at this point in the history
This switches to the Go "context" package for cancellation and threads
the context through all the way to evaluation to allow behavior based on
stopping deep within graph execution.

This also adds the Stop API to provisioners so they can quickly exit
when stop is called.
  • Loading branch information
mitchellh committed Jan 26, 2017
1 parent a9d799c commit f8c7b63
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 65 deletions.
154 changes: 96 additions & 58 deletions terraform/context.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package terraform

import (
"context"
"fmt"
"log"
"sort"
Expand Down Expand Up @@ -91,8 +92,9 @@ type Context struct {
l sync.Mutex // Lock acquired during any task
parallelSem Semaphore
providerInputConfig map[string]map[string]interface{}
runCh <-chan struct{}
stopCh chan struct{}
runLock sync.Mutex
runContext context.Context
runContextCancel context.CancelFunc
shadowErr error
}

Expand Down Expand Up @@ -339,8 +341,7 @@ func (c *Context) Interpolater() *Interpolater {
// This modifies the configuration in-place, so asking for Input twice
// may result in different UI output showing different current values.
func (c *Context) Input(mode InputMode) error {
v := c.acquireRun("input")
defer c.releaseRun(v)
defer c.acquireRun("input")()

if mode&InputModeVar != 0 {
// Walk the variables first for the root module. We walk them in
Expand Down Expand Up @@ -459,8 +460,7 @@ func (c *Context) Input(mode InputMode) error {
// In addition to returning the resulting state, this context is updated
// with the latest state.
func (c *Context) Apply() (*State, error) {
v := c.acquireRun("apply")
defer c.releaseRun(v)
defer c.acquireRun("apply")()

// Copy our own state
c.state = c.state.DeepCopy()
Expand Down Expand Up @@ -504,8 +504,7 @@ func (c *Context) Apply() (*State, error) {
// Plan also updates the diff of this context to be the diff generated
// by the plan, so Apply can be called after.
func (c *Context) Plan() (*Plan, error) {
v := c.acquireRun("plan")
defer c.releaseRun(v)
defer c.acquireRun("plan")()

p := &Plan{
Module: c.module,
Expand Down Expand Up @@ -600,8 +599,7 @@ func (c *Context) Plan() (*Plan, error) {
// Even in the case an error is returned, the state will be returned and
// will potentially be partially updated.
func (c *Context) Refresh() (*State, error) {
v := c.acquireRun("refresh")
defer c.releaseRun(v)
defer c.acquireRun("refresh")()

// Copy our own state
c.state = c.state.DeepCopy()
Expand Down Expand Up @@ -635,29 +633,32 @@ func (c *Context) Refresh() (*State, error) {
// Stop will block until the task completes.
func (c *Context) Stop() {
c.l.Lock()
ch := c.runCh

// If we aren't running, then just return
if ch == nil {
c.l.Unlock()
return
}
// If we're running, then stop
if c.runContextCancel != nil {
// Tell the hook we want to stop
c.sh.Stop()

// Tell the hook we want to stop
c.sh.Stop()
// Stop the context
c.runContextCancel()
c.runContextCancel = nil
}

// Close the stop channel
close(c.stopCh)
// Grab the context before we unlock
ctx := c.runContext

// Wait for us to stop
// Unlock
c.l.Unlock()
<-ch

// Wait if we have a context
if ctx != nil {
<-ctx.Done()
}
}

// Validate validates the configuration and returns any warnings or errors.
func (c *Context) Validate() ([]string, []error) {
v := c.acquireRun("validate")
defer c.releaseRun(v)
defer c.acquireRun("validate")()

var errs error

Expand Down Expand Up @@ -718,37 +719,38 @@ func (c *Context) SetVariable(k string, v interface{}) {
c.variables[k] = v
}

func (c *Context) acquireRun(phase string) chan<- struct{} {
func (c *Context) acquireRun(phase string) func() {
// Acquire the runlock first. This is the lock that is held for
// the duration of a run to prevent multiple runs.
c.runLock.Lock()

// With the run lock held, grab the context lock to make changes
// to the run context.
c.l.Lock()
defer c.l.Unlock()

// Setup debugging
dbug.SetPhase(phase)

// Wait for no channel to exist
for c.runCh != nil {
c.l.Unlock()
ch := c.runCh
<-ch
c.l.Lock()
// runContext should never be non-nil, check that here
if c.runContext != nil {
panic("acquireRun called with runContext != nil")
}

// Create the new channel
ch := make(chan struct{})
c.runCh = ch

// Reset the stop channel so we can watch that
c.stopCh = make(chan struct{})
// Create a new run context
c.runContext, c.runContextCancel = context.WithCancel(context.Background())

// Reset the stop hook so we're not stopped
c.sh.Reset()

// Reset the shadow errors
c.shadowErr = nil

return ch
return c.releaseRun
}

func (c *Context) releaseRun(ch chan<- struct{}) {
func (c *Context) releaseRun() {
// Grab the context lock so that we can make modifications to fields
c.l.Lock()
defer c.l.Unlock()

Expand All @@ -757,9 +759,17 @@ func (c *Context) releaseRun(ch chan<- struct{}) {
// phase
dbug.SetPhase("INVALID")

close(ch)
c.runCh = nil
c.stopCh = nil
// End our run. We check if runContextCancel is non-nil because Stop()
// sets it to nil after it has cancelled the run context.
if c.runContextCancel != nil {
c.runContextCancel()
}

// Unset the context
c.runContext = nil

// Unlock the run lock
c.runLock.Unlock()
}

func (c *Context) walk(
Expand Down Expand Up @@ -791,13 +801,14 @@ func (c *Context) walk(
log.Printf("[DEBUG] Starting graph walk: %s", operation.String())

walker := &ContextGraphWalker{
Context: realCtx,
Operation: operation,
Context: realCtx,
Operation: operation,
StopContext: c.runContext,
}

// Watch for a stop so we can call the provider Stop() API.
doneCh := make(chan struct{})
go c.watchStop(walker, c.stopCh, doneCh)
go c.watchStop(walker, doneCh)

// Walk the real graph, this will block until it completes
realErr := graph.Walk(walker)
Expand Down Expand Up @@ -892,7 +903,15 @@ func (c *Context) walk(
return walker, realErr
}

func (c *Context) watchStop(walker *ContextGraphWalker, stopCh, doneCh <-chan struct{}) {
func (c *Context) watchStop(walker *ContextGraphWalker, doneCh <-chan struct{}) {
// Get the stop channel. runContext might be nil only during tests.
// If this is called during a proper run operation, this will never
// be nil.
var stopCh <-chan struct{}
if ctx := c.runContext; ctx != nil {
stopCh = ctx.Done()
}

// Wait for a stop or completion
select {
case <-stopCh:
Expand All @@ -904,20 +923,39 @@ func (c *Context) watchStop(walker *ContextGraphWalker, stopCh, doneCh <-chan st

// If we're here, we're stopped, trigger the call.

// Copy the providers so that a misbehaved blocking Stop doesn't
// completely hang Terraform.
walker.providerLock.Lock()
ps := make([]ResourceProvider, 0, len(walker.providerCache))
for _, p := range walker.providerCache {
ps = append(ps, p)
{
// Copy the providers so that a misbehaved blocking Stop doesn't
// completely hang Terraform.
walker.providerLock.Lock()
ps := make([]ResourceProvider, 0, len(walker.providerCache))
for _, p := range walker.providerCache {
ps = append(ps, p)
}
defer walker.providerLock.Unlock()

for _, p := range ps {
// We ignore the error for now since there isn't any reasonable
// action to take if there is an error here, since the stop is still
// advisory: Terraform will exit once the graph node completes.
p.Stop()
}
}
defer walker.providerLock.Unlock()

for _, p := range ps {
// We ignore the error for now since there isn't any reasonable
// action to take if there is an error here, since the stop is still
// advisory: Terraform will exit once the graph node completes.
p.Stop()
{
// Call stop on all the provisioners
walker.provisionerLock.Lock()
ps := make([]ResourceProvisioner, 0, len(walker.provisionerCache))
for _, p := range walker.provisionerCache {
ps = append(ps, p)
}
defer walker.provisionerLock.Unlock()

for _, p := range ps {
// We ignore the error for now since there isn't any reasonable
// action to take if there is an error here, since the stop is still
// advisory: Terraform will exit once the graph node completes.
p.Stop()
}
}
}

Expand Down
63 changes: 63 additions & 0 deletions terraform/context_apply_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,69 @@ func TestContext2Apply_cancel(t *testing.T) {
}
}

// TestContext2Apply_cancelProvisioner checks that calling Stop on the
// context while a provisioner is applying invokes the provisioner's stop
// function, and that Apply still returns without an error and with the
// expected (tainted) state.
func TestContext2Apply_cancelProvisioner(t *testing.T) {
	mod := testModule(t, "apply-cancel-provisioner")

	provider := testProvider("aws")
	provider.ApplyFn = testApplyFn
	provider.DiffFn = testDiffFn

	provisioner := testProvisioner()

	ctx := testContext2(t, &ContextOpts{
		Module: mod,
		Providers: map[string]ResourceProviderFactory{
			"aws": testProviderFuncFixed(provider),
		},
		Provisioners: map[string]ResourceProvisionerFactory{
			"shell": testProvisionerFuncFixed(provisioner),
		},
	})

	// stopSeen is closed by the provisioner's stop function; the apply
	// function blocks on it so the provisioner only finishes once it has
	// been told to stop.
	stopSeen := make(chan struct{})
	provisioner.StopFn = func() error {
		close(stopSeen)
		return nil
	}
	provisioner.ApplyFn = func(rs *InstanceState, c *ResourceConfig) error {
		// Kick off the stop from a separate goroutine, then wait until the
		// provisioner has been asked to stop.
		go ctx.Stop()

		<-stopSeen
		return nil
	}

	_, planErr := ctx.Plan()
	if planErr != nil {
		t.Fatalf("err: %s", planErr)
	}

	// Run Apply in the background and hand its outcome back over a channel.
	type applyResult struct {
		state *State
		err   error
	}
	resultCh := make(chan applyResult)
	go func() {
		s, err := ctx.Apply()
		resultCh <- applyResult{state: s, err: err}
	}()

	// Block until Apply has finished.
	result := <-resultCh
	if result.err != nil {
		t.Fatalf("err: %s", result.err)
	}

	checkStateString(t, result.state, `
aws_instance.foo: (tainted)
ID = foo
num = 2
type = aws_instance
`)

	if !provisioner.StopCalled {
		t.Fatal("stop should be called")
	}
}

func TestContext2Apply_compute(t *testing.T) {
m := testModule(t, "apply-compute")
p := testProvider("aws")
Expand Down
3 changes: 1 addition & 2 deletions terraform/context_import.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ type ImportTarget struct {
// imported.
func (c *Context) Import(opts *ImportOpts) (*State, error) {
// Hold a lock since we can modify our own state here
v := c.acquireRun("import")
defer c.releaseRun(v)
defer c.acquireRun("import")()

// Copy our own state
c.state = c.state.DeepCopy()
Expand Down
4 changes: 4 additions & 0 deletions terraform/eval_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (

// EvalContext is the interface that is given to eval nodes to execute.
type EvalContext interface {
// Stopped returns a channel that is closed when evaluation is stopped
// via Terraform.Context.Stop()
Stopped() <-chan struct{}

// Path is the current module path.
Path() []string

Expand Down
13 changes: 13 additions & 0 deletions terraform/eval_context_builtin.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package terraform

import (
"context"
"fmt"
"log"
"strings"
Expand All @@ -12,6 +13,9 @@ import (
// BuiltinEvalContext is an EvalContext implementation that is used by
// Terraform by default.
type BuiltinEvalContext struct {
// StopContext is the context used to track whether we're complete
StopContext context.Context

// PathValue is the Path that this context is operating within.
PathValue []string

Expand Down Expand Up @@ -43,6 +47,15 @@ type BuiltinEvalContext struct {
once sync.Once
}

// Stopped returns the channel used to signal that evaluation was stopped.
//
// When StopContext is set, this is that context's Done channel. StopContext
// can be nil during tests; in that case a nil channel is returned, and since
// a receive from a nil channel blocks forever, callers never see a stop.
func (ctx *BuiltinEvalContext) Stopped() <-chan struct{} {
	if stopCtx := ctx.StopContext; stopCtx != nil {
		return stopCtx.Done()
	}

	return nil
}

func (ctx *BuiltinEvalContext) Hook(fn func(Hook) (HookAction, error)) error {
for _, h := range ctx.Hooks {
action, err := fn(h)
Expand Down
Loading

0 comments on commit f8c7b63

Please sign in to comment.