Skip to content

Commit

Permalink
optimize error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ysmood committed Dec 26, 2023
1 parent a840f3a commit 8a55c96
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 73 deletions.
22 changes: 17 additions & 5 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"encoding/base64"
"fmt"
"io"
"log"
"os"
Expand Down Expand Up @@ -50,7 +51,11 @@ func startAgent() {
log.Println("background whisper agent started")
}

func callAgent(decrypt bool, publicKey string, conf whisper.Config, inFile, outFile string) bool {
func agentCheckPassphrase(prv whisper.PrivateKey) bool {
return whisper.IsPassphraseRight(WHISPER_AGENT_ADDR, prv)
}

func agentWhisper(decrypt bool, addPubKey string, conf whisper.Config, inFile, outFile string) {
in := getInput(inFile)
defer func() { _ = in.Close() }()

Expand All @@ -65,18 +70,25 @@ func callAgent(decrypt bool, publicKey string, conf whisper.Config, inFile, outF
req.Config.Public = append(req.Config.Public, pub)
}
} else {
req.PublicKey = prefixPublicKey(publicKey, out)
req.PublicKey = prefixPublicKey(addPubKey, out)
}

return whisper.CallAgent(WHISPER_AGENT_ADDR, req, in, out)
defer func() {
if err := recover(); err != nil {
fmt.Fprintln(out)
fmt.Fprintln(out, err)
}
}()

whisper.CallAgent(WHISPER_AGENT_ADDR, req, in, out)
}

// If there's no public key, the output will be prefixed with "_".
// If the public key is remote, the output will be prefixed with "@", the prefix will end with space.
// If the public key is local, the output will be prefixed with ".", the prefix will end with space.
func prefixPublicKey(publicKey string, out io.Writer) secure.KeyWithFilter {
if publicKey == "." {
publicKey = DEFAULT_KEY_NAME + PUB_SUFFIX
publicKey = pubKeyName(DEFAULT_KEY_NAME)
}

if publicKey == "" {
Expand Down Expand Up @@ -153,7 +165,7 @@ func extractPublicKey(in io.Reader) secure.KeyWithFilter {
}
default:
return secure.KeyWithFilter{
Key: getKey(DEFAULT_KEY_NAME + PUB_SUFFIX),
Key: getKey(pubKeyName(DEFAULT_KEY_NAME)),
}
}
}
70 changes: 52 additions & 18 deletions lib/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ import (
)

type AgentReq struct {
Version string
Decrypt bool
Version string
Decrypt bool
CheckPassphrase bool

PublicKey secure.KeyWithFilter
Config Config
}

type AgentRes struct {
Running bool
WrongPassphrase bool
PassphraseRight bool
WrongPublicKey bool
}

Expand Down Expand Up @@ -96,7 +98,7 @@ func (a *AgentServer) Listen(l net.Listener) {
}
}

func (a *AgentServer) Handle(s io.ReadWriteCloser) error { //nolint: cyclop,funlen
func (a *AgentServer) Handle(s io.ReadWriteCloser) error {
b, err := byframe.NewScanner(s).Next()
if err != nil {
return err
Expand All @@ -108,25 +110,48 @@ func (a *AgentServer) Handle(s io.ReadWriteCloser) error { //nolint: cyclop,funl
}

if req.Version != "" {
if req.Version == Version() {
return a.res(s, AgentRes{Running: true})
}

a.Logger.Warn("version mismatch, close server", "server", Version(), "client", req.Version)
return a.listener.Close()
return a.handleCheckVersion(s, req.Version)
}

a.cacheLoadPrivate(&req.Config)

wsp, err := New(req.Config)
if req.CheckPassphrase {
return a.handleCheckPassphrase(s, req.Config.Private)
}

return a.handleWhisper(s, req)
}

func (a *AgentServer) handleCheckVersion(s io.ReadWriteCloser, version string) error {
if version == Version() {
return a.res(s, AgentRes{Running: true})
}

a.Logger.Warn("version mismatch, close server", "server", Version(), "client", version)
return a.listener.Close()
}

func (a *AgentServer) handleCheckPassphrase(s io.ReadWriteCloser, prv PrivateKey) error {
_, err := secure.SSHPrvKey(prv.Data, prv.Passphrase)
if err != nil {
if secure.IsAuthErr(err) {
return a.res(s, AgentRes{WrongPassphrase: true})
return a.res(s, AgentRes{})
}

return err
}

a.cachePrivate(prv)

return a.res(s, AgentRes{PassphraseRight: true})
}

func (a *AgentServer) handleWhisper(s io.ReadWriteCloser, req AgentReq) error {
wsp, err := New(req.Config)
if err != nil {
return err
}

a.cachePrivate(req.Config.Private)

if req.PublicKey.Key != nil &&
Expand Down Expand Up @@ -193,18 +218,14 @@ func (a *AgentServer) res(s io.Writer, res AgentRes) error {
}

// Return true if the passphrase is correct.
func CallAgent(addr string, req AgentReq, in io.Reader, out io.Writer) bool {
func CallAgent(addr string, req AgentReq, in io.Reader, out io.Writer) {
res, stream, err := agentReq(addr, req)
if err != nil {
panic(err)
}

defer func() { _ = stream.Close() }()

if res.WrongPassphrase {
return false
}

if res.WrongPublicKey {
panic("the public key from option -a doesn't belong to the private key")
}
Expand All @@ -225,8 +246,21 @@ func CallAgent(addr string, req AgentReq, in io.Reader, out io.Writer) bool {
if err != nil {
panic("your key might be wrong or data is corrupted: " + err.Error())
}
}

func IsPassphraseRight(addr string, prv PrivateKey) bool {
res, stream, err := agentReq(addr, AgentReq{CheckPassphrase: true, Config: Config{Private: prv}})
if err != nil {
if stream == nil {
return false
}

panic(err)
}

_ = stream.Close()

return true
return res.PassphraseRight
}

func IsAgentRunning(addr, version string) bool {
Expand Down
78 changes: 37 additions & 41 deletions lib/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestAgentEncode(t *testing.T) {
g.Eq(str, "hello")
}

func TestAgentDecode(t *testing.T) { //nolint:funlen
func TestAgentPassphrase(t *testing.T) {
g := got.T(t)

s := whisper.NewAgentServer()
Expand All @@ -88,57 +88,53 @@ func TestAgentDecode(t *testing.T) { //nolint:funlen

go s.Listen(l)

prv, pub := whisper.PrivateKey{read("id_ecdsa"), ""}, read("id_ecdsa.pub")
prv, _ := whisper.PrivateKey{read("id_ecdsa"), ""}, read("id_ecdsa.pub")

conf := whisper.Config{
GzipLevel: gzip.DefaultCompression,
Base64: true,
Private: prv,
Public: []secure.KeyWithFilter{{Key: pub}},
}
// no passphrase
g.False(whisper.IsPassphraseRight(addr, whisper.PrivateKey{}))

encoded := []byte("AQDCRtKH43W_QilOxCmrm5Ew_jv7UKDyyaNc8558QKgFydkAIRiurj1K2SvvH-LKhA")
// right passphrase
prv.Passphrase = "test"
g.True(whisper.IsPassphraseRight(addr, prv))

{ // no passphrase
decoded := bytes.NewBuffer(nil)
// cache passphrase
prv.Passphrase = ""
g.True(whisper.IsPassphraseRight(addr, prv))

g.False(whisper.CallAgent(addr, whisper.AgentReq{
Decrypt: true,
Config: conf,
}, bytes.NewReader(encoded), decoded))
}
// wrong passphrase
prv.Passphrase = "123"
g.False(whisper.IsPassphraseRight(addr, prv))
}

{
conf.Private.Passphrase = "test"
decoded := bytes.NewBuffer(nil)
func TestAgentDecode(t *testing.T) {
g := got.T(t)

g.True(whisper.CallAgent(addr, whisper.AgentReq{
Decrypt: true,
Config: conf,
}, bytes.NewReader(encoded), decoded))
s := whisper.NewAgentServer()

g.Eq(decoded.String(), "hello")
}
l, err := net.Listen("tcp", ":0")
g.E(err)
addr := l.Addr().String()

{ // cache passphrase
conf.Private.Passphrase = ""
decoded := bytes.NewBuffer(nil)
go s.Listen(l)

g.True(whisper.CallAgent(addr, whisper.AgentReq{
Decrypt: true,
Config: conf,
}, bytes.NewReader(encoded), decoded))
prv, pub := whisper.PrivateKey{read("id_ecdsa"), ""}, read("id_ecdsa.pub")

g.Eq(decoded.String(), "hello")
conf := whisper.Config{
GzipLevel: gzip.DefaultCompression,
Base64: true,
Private: prv,
Public: []secure.KeyWithFilter{{Key: pub}},
}

{ // wrong passphrase
conf.Private.Passphrase = "123"
decoded := bytes.NewBuffer(nil)
encoded := []byte("AQDCRtKH43W_QilOxCmrm5Ew_jv7UKDyyaNc8558QKgFydkAIRiurj1K2SvvH-LKhA")

conf.Private.Passphrase = "test"
decoded := bytes.NewBuffer(nil)

g.False(whisper.CallAgent(addr, whisper.AgentReq{
Decrypt: true,
Config: conf,
}, bytes.NewReader(encoded), decoded))
}
whisper.CallAgent(addr, whisper.AgentReq{
Decrypt: true,
Config: conf,
}, bytes.NewReader(encoded), decoded)

g.Eq(decoded.String(), "hello")
}
15 changes: 9 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@ func main() { //nolint: funlen

startAgent()

if !*decryptMode && publicKeys == nil {
publicKeys = publicKeysFlag{pubKeyName(DEFAULT_KEY_NAME)}
if publicKeys == nil {
if *decryptMode {
DEFAULT_KEY_NAME = *privateKey
} else {
publicKeys = publicKeysFlag{pubKeyName(*privateKey)}
}
}

conf := whisper.Config{
Expand All @@ -74,14 +78,13 @@ func main() { //nolint: funlen
in := flags.Arg(0)
out := *outputFile

if !callAgent(*decryptMode, *addPublicKey, conf, in, out) {
if !agentCheckPassphrase(conf.Private) {
if in == "" {
panic("stdin is used for piping, can't read passphrase from it, please specify the input file path in cli arg")
}

conf.Private.Passphrase = readPassphrase()
if !callAgent(*decryptMode, *addPublicKey, conf, in, out) {
panic("wrong passphrase")
}
}

agentWhisper(*decryptMode, *addPublicKey, conf, in, out)
}
4 changes: 1 addition & 3 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"golang.org/x/term"
)

const PUB_SUFFIX = ".pub"

func getPublicKeys(paths []string) []secure.KeyWithFilter {
list := []secure.KeyWithFilter{}
for _, p := range paths {
Expand Down Expand Up @@ -113,7 +111,7 @@ func getPublicKey(p string) secure.KeyWithFilter {
}

func pubKeyName(prv string) string {
return prv + PUB_SUFFIX
return prv + ".pub"
}

type publicKeysFlag []string
Expand Down

0 comments on commit 8a55c96

Please sign in to comment.