Skip to content

Commit

Permalink
Install cancellation handler (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
scudette authored Jul 11, 2024
1 parent acdf919 commit 80f6ecb
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 8 deletions.
5 changes: 4 additions & 1 deletion go-winpmem/cmd/acquire.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ func doAcquire() error {
}
defer closer()

return imager.WriteTo(compressed_writer)
ctx, cancel := install_sig_handler()
defer cancel()

return imager.WriteTo(ctx, compressed_writer)
}

func init() {
Expand Down
5 changes: 4 additions & 1 deletion go-winpmem/cmd/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ func doDecompress() error {
}
defer out_fd.Close()

ctx, cancel := install_sig_handler()
defer cancel()

logger := &DecompressionLogger{Logger: winpmem.NewLogger(*verbose)}
return winpmem.CopyAndLog(decompressed_fd, out_fd, logger)
return winpmem.CopyAndLog(ctx, decompressed_fd, out_fd, logger)
}

func init() {
Expand Down
31 changes: 31 additions & 0 deletions go-winpmem/cmd/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package main

import (
"context"
"os"
"os/signal"
"syscall"
)

func install_sig_handler() (context.Context, context.CancelFunc) {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT)

ctx, cancel := context.WithCancel(context.Background())

go func() {
select {
case <-quit:
// Ordered shutdown now.
cancel()

case <-ctx.Done():
return
}
}()

return ctx, cancel
}
10 changes: 9 additions & 1 deletion go-winpmem/compressor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package winpmem

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -66,7 +67,8 @@ func GetDecompressor(header []byte, r io.Reader) (io.Reader, error) {
return nil, errors.New("Unknown compression scheme")
}

func CopyAndLog(in io.Reader, out io.Writer, logger Logger) error {
func CopyAndLog(
ctx context.Context, in io.Reader, out io.Writer, logger Logger) error {
buff := make([]byte, 1024*PAGE_SIZE)
for {
n, err := in.Read(buff)
Expand All @@ -83,5 +85,11 @@ func CopyAndLog(in io.Reader, out io.Writer, logger Logger) error {
if err != nil {
return err
}

select {
case <-ctx.Done():
return errors.New("Cancelled!")
default:
}
}
}
27 changes: 22 additions & 5 deletions go-winpmem/imager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package winpmem

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"os"
"sync"
Expand Down Expand Up @@ -173,7 +175,9 @@ func (self *Imager) SetSparse() {
self.sparse_output = true
}

func (self *Imager) pad(size uint64, w io.Writer) error {
func (self *Imager) pad(
ctx context.Context,
size uint64, w io.Writer) error {
if self.sparse_output {
write_seeker, ok := w.(io.WriteSeeker)
if ok {
Expand All @@ -200,14 +204,22 @@ func (self *Imager) pad(size uint64, w io.Writer) error {
self.logger.Progress(int(to_write / PAGE_SIZE))

offset += uint64(n)

select {
case <-ctx.Done():
return errors.New("Cancelled!")
default:
}
}

return nil
}

// copyRange copies a range from the base_addr to the writer. We
// assume size is a multiple of PAGE_SIZE
func (self *Imager) copyRange(base_addr, size uint64, w io.Writer) error {
func (self *Imager) copyRange(
ctx context.Context,
base_addr, size uint64, w io.Writer) error {
buff := make([]byte, BUFSIZE)
pad := make([]byte, PAGE_SIZE)
end := base_addr + size
Expand Down Expand Up @@ -275,12 +287,17 @@ func (self *Imager) copyRange(base_addr, size uint64, w io.Writer) error {
offset += uint64(actual_read)
}

select {
case <-ctx.Done():
return errors.New("Cancelled!")
default:
}
}

return nil
}

func (self *Imager) WriteTo(w io.Writer) error {
func (self *Imager) WriteTo(ctx context.Context, w io.Writer) error {
var offset uint64
for _, r := range self.stats.Run {
base_addr := uint64(r.BaseAddress)
Expand All @@ -293,7 +310,7 @@ func (self *Imager) WriteTo(w io.Writer) error {
self.logger.Info(
"Padding %v pages from %#x", pad_size/PAGE_SIZE, offset)

err := self.pad(pad_size, w)
err := self.pad(ctx, pad_size, w)
if err != nil {
return err
}
Expand All @@ -305,7 +322,7 @@ func (self *Imager) WriteTo(w io.Writer) error {
"Copying %v pages (%#x) from %#x", number_of_bytes/PAGE_SIZE,
number_of_bytes, offset)

err := self.copyRange(base_addr, number_of_bytes, w)
err := self.copyRange(ctx, base_addr, number_of_bytes, w)
if err != nil {
return err
}
Expand Down

0 comments on commit 80f6ecb

Please sign in to comment.