From adcd5f158f9188487a93a1222284ee162a5ed277 Mon Sep 17 00:00:00 2001 From: "felix.niederwanger@suse.com" Date: Thu, 25 Mar 2021 09:09:46 +0100 Subject: [PATCH] Improve termination signal handling Add handler for termination signal. --- cmd/disko-san/chunk.go | 4 ++ cmd/disko-san/disko-san.go | 105 +++++++++++++++++++++++++++++-------- 2 files changed, 88 insertions(+), 21 deletions(-) diff --git a/cmd/disko-san/chunk.go b/cmd/disko-san/chunk.go index b830b37..7acef11 100644 --- a/cmd/disko-san/chunk.go +++ b/cmd/disko-san/chunk.go @@ -92,5 +92,9 @@ func (cf *ChunkFactory) Read(buf []byte) error { } func (cf *ChunkFactory) Stop() { + if !cf.running { + return + } cf.running = false + cf.sig <- 1 } diff --git a/cmd/disko-san/disko-san.go b/cmd/disko-san/disko-san.go index fa278cf..88f5f3d 100644 --- a/cmd/disko-san/disko-san.go +++ b/cmd/disko-san/disko-san.go @@ -3,6 +3,8 @@ package main import ( "fmt" "os" + "os/signal" + "syscall" "time" ) @@ -15,7 +17,9 @@ type conf struct { } var cf conf -var avg float32 // average for average smoothing +var avg float32 // average for average smoothing +var running bool // running flag +var done chan bool // Signal for when the main thread is completed func (cf *conf) CheckValid() error { @@ -217,6 +221,9 @@ func WriteCheck(disk *Disk, progress *Progress, statsFile string) error { fmt.Printf("\033[s") // save cursor position for progress.Pos < progress.Size { + if !running { + return fmt.Errorf("interrupted") + } // Determine size of current chunk - at the end of the disk this might not be the full size anymore size := int64(CHUNKSIZE) if progress.Pos+CHUNKSIZE > progress.Size { @@ -286,6 +293,9 @@ func ReadCheck(disk *Disk, progress *Progress) error { // Read chunks one by one and verify them fmt.Printf("\033[s") // save cursor position for progress.Pos < progress.Size { + if !running { + return fmt.Errorf("interrupted") + } // Read and verify chunk runtime := time.Now().UnixNano() n, err := disk.Read(chunk) @@ -323,8 +333,60 @@ func ReadCheck(disk *Disk, progress *Progress) error { return nil } +func printUsage() { + fmt.Printf("Usage: %s DISK [PROGRESS] [SPEEDLOG]\n", os.Args[0]) + fmt.Println(" DISK: Disk file under test") + fmt.Println(" PROGRESS: Progress file, required for job continuation") + fmt.Println(" SPEEDLOG: Performance metrics log") +} + +func parseArgs(args []string, cf *conf) error { + if len(args) < 2 { + printUsage() + os.Stdout.Sync() // Ensure usage is flushed to stdout before returning with an error + return fmt.Errorf("Missing arguments") + } + if len(args) >= 2 { + cf.disk = args[1] + if cf.disk == "-h" || cf.disk == "--help" { + printUsage() + os.Exit(0) + } + } + if len(args) >= 3 { + cf.progress = args[2] + } + if len(args) >= 4 { + cf.stats = args[3] + } + if len(args) > 4 { + return fmt.Errorf("too many arguments") + } + return nil +} + +func terminationSignalHandler() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigs + fmt.Println(sig) + running = false + // Wait for termination signal but quit after 2 seconds unconditionally + select { + case <-done: + os.Exit(1) + case <-time.After(2 * time.Second): + fmt.Fprintf(os.Stderr, "Termination timeout. Forcefully quiting.\n") + os.Exit(1) + } + // Just to be sure + os.Exit(1) +} + func main() { var progress Progress + done = make(chan bool, 1) + running = true // Default settings cf.disk = "" @@ -332,23 +394,10 @@ func main() { cf.stats = "" cf.verbose = false - // TOOD: Better argument handling - if len(os.Args) < 2 { - fmt.Printf("Usage: %s DISK [PROGRESS] [SPEEDLOG]\n", os.Args[0]) - fmt.Println(" DISK: Disk file under test") - fmt.Println(" PROGRESS: Progress file, required for job continuation") - fmt.Println(" SPEEDLOG: Performance metrics log") + if err := parseArgs(os.Args, &cf); err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1) } - if len(os.Args) >= 2 { - cf.disk = os.Args[1] - } - if len(os.Args) > 2 { - cf.progress = os.Args[2] - } - if len(os.Args) > 3 { - cf.stats = os.Args[3] - } // Check configuration for validity if err := cf.CheckValid(); err != nil { @@ -446,12 +495,15 @@ func main() { progress.Size = disk.Size() } + // Termination signal handler + go terminationSignalHandler() + // Preparation step if progress.State == 0 { // Prepare disk if err := disk.Prepare(); err != nil { fmt.Fprintf(os.Stderr, "Disk preparation error: %s\n", err) - os.Exit(1) + os.Exit(10) } progress.State = 1 progress.Pos = 0 @@ -464,8 +516,13 @@ func main() { // Write step if progress.State == 1 { if err := WriteCheck(&disk, &progress, cf.stats); err != nil { - fmt.Fprintf(os.Stderr, "Write check failed: %s\n", err) - os.Exit(2) + if err.Error() == "interrupted" { + done <- true + fmt.Fprintf(os.Stderr, "Cancelled\n") + } else { + fmt.Fprintf(os.Stderr, "Write check failed: %s\n", err) + } + os.Exit(11) } progress.State = 2 progress.Pos = 0 @@ -478,8 +535,13 @@ func main() { // Read step if progress.State == 2 { if err := ReadCheck(&disk, &progress); err != nil { - fmt.Fprintf(os.Stderr, "Read check failed: %s\n", err) - os.Exit(2) + if err.Error() == "interrupted" { + done <- true + fmt.Fprintf(os.Stderr, "Cancelled\n") + } else { + fmt.Fprintf(os.Stderr, "Read check failed: %s\n", err) + } + os.Exit(12) } progress.State = 3 if err := progress.WriteIfOpen(); err != nil { @@ -489,5 +551,6 @@ func main() { } // All good + done <- true fmt.Println("Done") }