diff --git a/aws/ebs.go b/aws/ebs.go index 8f701933..cdc665f8 100644 --- a/aws/ebs.go +++ b/aws/ebs.go @@ -1,9 +1,9 @@ package aws import ( - "strings" "time" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/gruntwork-io/aws-nuke/logging" @@ -47,16 +47,19 @@ func nukeAllEbsVolumes(session *session.Session, volumeIds []*string) error { _, err := svc.DeleteVolume(params) if err != nil { - // Ignore not found errors, some volumes are deleted along with EC2 Instances - if !strings.Contains(err.Error(), "InvalidVolume.NotFound") { - logging.Logger.Errorf("[Failed] %s", err) - return errors.WithStackTrace(err) + if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == "VolumeInUse" { + logging.Logger.Warnf("EBS volume %s can't be deleted, it is still attached to an active resource", *volumeID) + return nil + } else if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == "InvalidVolume.NotFound" { + logging.Logger.Infof("EBS volume %s has already been deleted", *volumeID) + return nil } - logging.Logger.Infof("EBS volume %s has already been deleted", *volumeID) - } else { - logging.Logger.Infof("Deleted EBS Volume: %s", *volumeID) + logging.Logger.Errorf("[Failed] %s", err) + return errors.WithStackTrace(err) } + + logging.Logger.Infof("Deleted EBS Volume: %s", *volumeID) } err := svc.WaitUntilVolumeDeleted(&ec2.DescribeVolumesInput{ diff --git a/aws/ebs_test.go b/aws/ebs_test.go index 88faaa16..b8cd871f 100644 --- a/aws/ebs_test.go +++ b/aws/ebs_test.go @@ -108,7 +108,7 @@ func TestNukeEBSVolumes(t *testing.T) { } uniqueTestID := "aws-nuke-test-" + util.UniqueID() - createTestEC2Instance(t, session, uniqueTestID, false) + volume := createTestEBSVolume(t, session, uniqueTestID) output, err := ec2.New(session).DescribeVolumes(&ec2.DescribeVolumesInput{}) if err != nil { @@ -117,16 +117,76 @@ func TestNukeEBSVolumes(t *testing.T) { volumeIds := findEBSVolumesByNameTag(output, uniqueTestID) + assert.Len(t, volumeIds, 1) + assert.Equal(t, awsgo.StringValue(volume.VolumeId), awsgo.StringValue(volumeIds[0])) + if err := nukeAllEbsVolumes(session, volumeIds); err != nil { assert.Fail(t, errors.WithStackTrace(err).Error()) } - volumes, err := getAllEbsVolumes(session, region, time.Now().Add(1*time.Hour)) + + volumeIds, err = getAllEbsVolumes(session, region, time.Now().Add(1*time.Hour)) + if err != nil { + assert.Fail(t, "Unable to fetch list of EBS Volumes") + } + + assert.NotContains(t, awsgo.StringValueSlice(volumeIds), awsgo.StringValue(volume.VolumeId)) +} + +func TestNukeEBSVolumesInUse(t *testing.T) { + t.Parallel() + + region := getRandomRegion() + session, err := session.NewSession(&awsgo.Config{ + Region: awsgo.String(region)}, + ) if err != nil { - assert.Fail(t, "Unable to fetch list of EC2 Instances") + assert.Fail(t, errors.WithStackTrace(err).Error()) } - for _, volumeID := range volumeIds { - assert.NotContains(t, volumes, *volumeID) + svc := ec2.New(session) + + uniqueTestID := "aws-nuke-test-" + util.UniqueID() + volume := createTestEBSVolume(t, session, uniqueTestID) + instance := createTestEC2Instance(t, session, uniqueTestID, true) + + defer nukeAllEbsVolumes(session, []*string{volume.VolumeId}) + defer nukeAllEc2Instances(session, []*string{instance.InstanceId}) + + // attach volume to protected instance + svc.AttachVolume(&ec2.AttachVolumeInput{ + Device: awsgo.String("/dev/sdf"), + InstanceId: instance.InstanceId, + VolumeId: volume.VolumeId, + }) + + svc.WaitUntilVolumeInUse(&ec2.DescribeVolumesInput{ + VolumeIds: []*string{volume.VolumeId}, + }) + + output, err := svc.DescribeVolumes(&ec2.DescribeVolumesInput{}) + if err != nil { + assert.Fail(t, errors.WithStackTrace(err).Error()) + } + + volumeIds := findEBSVolumesByNameTag(output, uniqueTestID) + + assert.Len(t, volumeIds, 1) + assert.Equal(t, awsgo.StringValue(volume.VolumeId), awsgo.StringValue(volumeIds[0])) + + if err := nukeAllEbsVolumes(session, volumeIds); err != nil { + assert.Fail(t, errors.WithStackTrace(err).Error()) + } + + volumeIds, err = getAllEbsVolumes(session, region, time.Now().Add(1*time.Hour)) + if err != nil { + assert.Fail(t, "Unable to fetch list of EBS Volumes") + } + + // Volumes should still be in returned slice + assert.Contains(t, awsgo.StringValueSlice(volumeIds), awsgo.StringValue(volume.VolumeId)) + // remove protection so instance can be cleaned up + if err = removeEC2InstanceProtection(svc, &instance); err != nil { + assert.Fail(t, errors.WithStackTrace(err).Error()) } } diff --git a/aws/ec2_test.go b/aws/ec2_test.go index 5c33346d..cfae3dcc 100644 --- a/aws/ec2_test.go +++ b/aws/ec2_test.go @@ -93,6 +93,18 @@ func createTestEC2Instance(t *testing.T, session *session.Session, name string, return *runResult.Instances[0] } +func removeEC2InstanceProtection(svc *ec2.EC2, instance *ec2.Instance) error { + // make instance unprotected so it can be cleaned up + _, err := svc.ModifyInstanceAttribute(&ec2.ModifyInstanceAttributeInput{ + DisableApiTermination: &ec2.AttributeBooleanValue{ + Value: awsgo.Bool(false), + }, + InstanceId: instance.InstanceId, + }) + + return err +} + func findEC2InstancesByNameTag(output *ec2.DescribeInstancesOutput, name string) []*string { var instanceIds []*string for _, reservation := range output.Reservations { @@ -147,6 +159,10 @@ func TestListInstances(t *testing.T) { assert.Contains(t, instanceIds, instance.InstanceId) assert.NotContains(t, instanceIds, protectedInstance.InstanceId) + + if err = removeEC2InstanceProtection(ec2.New(session), &protectedInstance); err != nil { + assert.Fail(t, errors.WithStackTrace(err).Error()) + } } func TestNukeInstances(t *testing.T) {