Skip to content

Commit

Permalink
refactor: rec -> tail-rec
Browse files Browse the repository at this point in the history
  • Loading branch information
jcouyang committed Feb 12, 2023
1 parent 42f3a7e commit c32a293
Show file tree
Hide file tree
Showing 5 changed files with 396 additions and 79 deletions.
18 changes: 12 additions & 6 deletions app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
{-# LANGUAGE QuasiQuotes #-}
module Main where

import qualified Data.Text as T
import qualified Data.Text.IO as TIO
import Data.Void (Void)
import Dhall (EvaluateSettings (..), InputSettings (..),
inputExpr, inputExprWithSettings)
import Dhall.Core (pretty)
import Dhall.Import (load)
import Dhall.Core (Expr, pretty)
import Dhall.Secret
import Dhall.Secret.IO (parseExpr, prettyExpr, version)
import Dhall.Secret.Type (secretTypes)
import Dhall.Src
import Dhall.Src (Src)
import Options.Applicative
data EncryptOpts = EncryptOpts
{ eo'file :: Maybe String
Expand All @@ -32,14 +28,17 @@ data GenTypesOpts = GenTypesOpts { gt'output :: Maybe String }

data Command = Encrypt EncryptOpts | Decrypt DecryptOpts | GenTypes GenTypesOpts

versionOpt :: Parser (a -> a)
versionOpt = infoOption version (long "version" <> short 'v' <> help "print version")

genTypesOpt :: Parser GenTypesOpts
genTypesOpt = GenTypesOpts <$> optional (strOption
(long "output"
<> short 'o'
<> metavar "FILE"
<> help "Output types into FILE"))

encryptOpt :: Parser EncryptOpts
encryptOpt = EncryptOpts
<$> optional (strOption
(long "file"
Expand All @@ -54,6 +53,7 @@ encryptOpt = EncryptOpts
<> help "Write result to a file instead of stdout"))


decryptOpt :: Parser DecryptOpts
decryptOpt = DecryptOpts
<$> optional (strOption
(long "file"
Expand All @@ -68,12 +68,17 @@ decryptOpt = DecryptOpts
<> help "Write result to a file instead of stdout"))
<*> switch (long "plain-text" <> short 'p' <> help "decrypt into plain text without types")

encryptCmdParser :: Parser EncryptOpts
encryptCmdParser = hsubparser $ command "encrypt" (info encryptOpt (progDesc "Encrypt a Dhall expression")) <> metavar "encrypt"

decryptCmdParser :: Parser DecryptOpts
decryptCmdParser = hsubparser $ command "decrypt" (info decryptOpt (progDesc "Decrypt a Dhall expression")) <> metavar "decrypt"

genTypesCmdParser :: Parser GenTypesOpts
genTypesCmdParser = hsubparser $ command "gen-types" (info genTypesOpt (progDesc "generate types")) <> metavar "gen-types"


commands :: Parser Command
commands = Encrypt <$> encryptCmdParser
<|> Decrypt <$>decryptCmdParser
<|> GenTypes <$> genTypesCmdParser
Expand All @@ -90,6 +95,7 @@ exec (GenTypes GenTypesOpts {gt'output}) = do
let a = pretty secretTypes
maybe (TIO.putStrLn a) (`TIO.writeFile` a) gt'output

ioDhallExpr :: Maybe FilePath -> Maybe FilePath -> Bool -> (Expr Src Void -> IO (Expr Src Void)) -> IO ()
ioDhallExpr input output inplace op = do
text <- maybe TIO.getContents TIO.readFile input
expr <- parseExpr text
Expand Down
132 changes: 72 additions & 60 deletions src/Dhall/Secret/Age.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ module Dhall.Secret.Age
generateX25519Identity,
parseRecipient,
parseIdentity,
toRecipient
) where
toRecipient,
)
where

import qualified Codec.Binary.Bech32 as Bech32
import qualified Crypto.Cipher.ChaChaPoly1305 as CC
import Crypto.Error (CryptoError (..),
Expand All @@ -30,15 +32,19 @@ import Data.Text (Text)
import qualified Data.Text as T

data Stanza = Stanza
{ stzType:: ByteString
, stzArgs :: [ByteString]
, stzBody :: ByteString
} deriving Show
{ stzType :: ByteString,
stzArgs :: [ByteString],
stzBody :: ByteString
}
deriving (Show)

data X25519Recipient = X25519Recipient X25519.PublicKey

instance Show X25519Recipient where
show (X25519Recipient pub) = T.unpack $ b32 "age" pub

data X25519Identity = X25519Identity X25519.PublicKey X25519.SecretKey

instance Show X25519Identity where
show (X25519Identity _ sec) = T.unpack $ T.toUpper $ b32 "AGE-SECRET-KEY-" sec

Expand All @@ -51,8 +57,8 @@ encrypt recipients msg = do
fileKey <- getRandomBytes 16 :: IO ByteString
nonce <- getRandomBytes 16 :: IO ByteString
stanzas <- traverse (mkStanza fileKey) recipients
body <- encryptChunks (payloadKey nonce fileKey) (zeroNonceOf 11) msg
pure $ pemWriteBS $ PEM { pemName ="AGE ENCRYPTED FILE", pemHeader = [], pemContent = mkHeader fileKey stanzas <> nonce <> body}
body <- encryptChunks BS.empty (payloadKey nonce fileKey) (zeroNonceOf 11) msg
pure $ pemWriteBS $ PEM {pemName = "AGE ENCRYPTED FILE", pemHeader = [], pemContent = mkHeader fileKey stanzas <> nonce <> body}

decrypt :: ByteString -> [X25519Identity] -> IO ByteString
decrypt ciphertext identities = do
Expand All @@ -62,10 +68,10 @@ decrypt ciphertext identities = do
case find isRight $ possibleKeys of
Just (Right key) -> do
let (headerNoMac, macGot) = mkHeaderMac key stz
if macGot == mac then
decryptChunks (payloadKey nonce key) (zeroNonceOf 11) body
else error $ show $ "Header MAC not match" <> headerNoMac <> "\n" <> macGot
_ -> error "No file key found"
if macGot == mac
then decryptChunks BS.empty (payloadKey nonce key) (zeroNonceOf 11) body
else error $ show $ "Header MAC not match" <> headerNoMac <> "\n" <> macGot
_ -> error "No file key found"

generateX25519Identity :: IO X25519Identity
generateX25519Identity = do
Expand All @@ -80,20 +86,38 @@ parseIdentity i = throwCryptoErrorIO $ do
key <- X25519.secretKey (b32dec i)
pure $ X25519Identity (X25519.toPublic key) key

decryptChunks :: ByteString -> ByteString -> ByteString -> IO ByteString
decryptChunks key nonce body = case BS.splitAt (64 * 1024) body of
(head', tail') | tail' == BS.empty -> decryptChunk key nonce head' (BS.pack [1])
(head', tail') -> decryptChunk key nonce head' (BS.pack [0]) <> decryptChunks key (incNonce nonce) tail'
decryptChunks :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
decryptChunks acc key nonce body = case BS.splitAt (64 * 1024 + 16) body of
(head', tail') | tail' == BS.empty -> (acc <>) <$> decryptChunk key nonce head' (BS.pack [1])
(head', tail') -> do
decrypted <- decryptChunk key nonce head' (BS.pack [0])
decryptChunks (acc <> decrypted) key (incNonce nonce) tail'

encryptChunks :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
encryptChunks acc key nonce msg = case BS.splitAt (64 * 1024) msg of
(head', tail') | tail' == BS.empty -> (acc <>) <$> encryptChunk key nonce head' (BS.pack [1])
(head', tail') -> do
encrypted <- encryptChunk key nonce head' (BS.pack [0])
encryptChunks (acc <> encrypted) key (incNonce nonce) tail'

encryptChunk :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
encryptChunk key nonce msg isFinal = do
st <- throwCryptoErrorIO $ do
payloadNonce <- CC.nonce12 $ (nonce <> isFinal)
CC.finalizeAAD <$> CC.initialize key payloadNonce
let (e, st1) = CC.encrypt msg st
let tag = CC.finalize st1
return $ e <> (convert tag)

decryptChunk :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
decryptChunk key nonce cipherblob isFinal = do
st1 <- throwCryptoErrorIO $ do
payloadNonce <- CC.nonce12 $ (nonce <> isFinal)
CC.finalizeAAD <$> CC.initialize key payloadNonce
let (msg, tag) = BS.splitAt (BS.length cipherblob - 16) cipherblob
let (d, st2) = CC.decrypt msg st1
let authtag = CC.finalize st2
if (convert authtag) == tag then pure d else error "Invalid auth tag"
st1 <- throwCryptoErrorIO $ do
payloadNonce <- CC.nonce12 $ (nonce <> isFinal)
CC.finalizeAAD <$> CC.initialize key payloadNonce
let (msg, tag) = BS.splitAt (BS.length cipherblob - 16) cipherblob
let (d, st2) = CC.decrypt msg st1
let authtag = CC.finalize st2
if (convert authtag) == tag then pure d else error "Invalid auth tag"

parseCipher :: ByteString -> Either String CipherBlock
parseCipher ct = do
Expand All @@ -103,19 +127,19 @@ parseCipher ct = do
let (nonce, body) = BS.splitAt 16 rest2
pure $ Cipher header nonce body

parseHeader :: Header -> ByteString -> Either String (Header, ByteString)
parseHeader :: Header -> ByteString -> Either String (Header, ByteString)
parseHeader (Header stz mac) content = do
case BS.take 3 content of
"---" ->
let (mac', body) = BS.break isLF $ content in
Right $ (Header (reverse stz) (BS.decodeBase64Lenient $ BS.drop 4 mac'), BS.drop 1 body)
let (mac', body) = BS.break isLF $ content
in Right $ (Header (reverse stz) (BS.decodeBase64Lenient $ BS.drop 4 mac'), BS.drop 1 body)
"-> " ->
let (recipients, rest1) = BS.break isLF $ BS.drop 3 content
(fileKey, rest2) = BS.break isLF $ BS.drop 1 rest1
(stztype, rest11) = BS.break isSpace recipients
stzarg = BS.drop 1 rest11
st = Stanza {stzType = stztype, stzArgs = [stzarg], stzBody = BS.decodeBase64Lenient fileKey} in
parseHeader (Header (st:stz) mac) (BS.drop 1 rest2)
st = Stanza {stzType = stztype, stzArgs = [stzarg], stzBody = BS.decodeBase64Lenient fileKey}
in parseHeader (Header (st : stz) mac) (BS.drop 1 rest2)
_ -> Left "invalid headers"
where
isLF = (== 0x0a)
Expand All @@ -139,41 +163,27 @@ findFileKey identities (Header stanza _mac) = hasKey <$> identities <*> stanza
let dtag = CC.finalize st1
if (convert dtag) == tag then pure d else CryptoFailed CryptoError_AuthenticationTagSizeInvalid

encryptChunks :: ByteString -> ByteString -> ByteString -> IO ByteString
encryptChunks key nonce msg = case BS.splitAt (64 * 1024) msg of
(head', tail') | tail' == BS.empty -> encryptChunk key nonce head' (BS.pack [1])
(head', tail') -> encryptChunk key nonce head' (BS.pack [0]) <> encryptChunks key (incNonce nonce) tail'

encryptChunk :: ByteString -> ByteString -> ByteString -> ByteString -> IO ByteString
encryptChunk key nonce msg isFinal = do
st <- throwCryptoErrorIO $ do
payloadNonce <- CC.nonce12 $ (nonce <> isFinal)
CC.finalizeAAD <$> CC.initialize key payloadNonce
let (e, st1) = CC.encrypt msg st
let tag = CC.finalize st1
return $ e <> (convert tag)

toRecipient :: X25519Identity -> X25519Recipient
toRecipient (X25519Identity pub _) = X25519Recipient pub

b32 :: (ByteArrayAccess b) => Text -> b -> Text
b32 header b = case Bech32.humanReadablePartFromText header of
Left e -> T.pack $ show e
Right header' -> case Bech32.encode header' (Bech32.dataPartFromBytes (convert b)) of
Left e -> T.pack $ show e
Right t -> t
Left e -> T.pack $ show e
Right header' -> case Bech32.encode header' (Bech32.dataPartFromBytes (convert b)) of
Left e -> T.pack $ show e
Right t -> t

b32dec :: Text -> ByteString
b32dec r = case Bech32.decode r of
Left _ -> error "Cannot decode bech32"
Left _ -> error "Cannot decode bech32"
Right (_, d) -> fromMaybe (error "Cannot extract bech32 data") $ Bech32.dataPartToBytes d

mkStanza :: ByteString -> X25519Recipient -> IO Stanza
mkStanza :: ByteString -> X25519Recipient -> IO Stanza
mkStanza fileKey (X25519Recipient theirPK) = do
ourKey <- X25519.generateSecretKey
let ourPK = X25519.toPublic ourKey
let shareKey = X25519.dh theirPK ourKey
let salt = (convert ourPK) <> (convert theirPK) :: ByteString
let salt = (convert ourPK) <> (convert theirPK) :: ByteString
let wrappingKey = hkdf "age-encryption.org/v1/X25519" (convert shareKey) salt
body <- throwCryptoErrorIO $ do
nonce <- CC.nonce12 (BS.pack $ take 12 $ repeat 0)
Expand All @@ -187,41 +197,43 @@ marshalStanza stanza =
let prefix = "-> " :: ByteString
body = BS.encodeBase64Unpadded' $ stzBody stanza
argLine = prefix <> stzType stanza <> " " <> intercalate " " (stzArgs stanza) <> "\n"
in argLine <>
wrap64b body <> "\n"
in argLine
<> wrap64b body
<> "\n"

mkHeader :: ByteString -> [Stanza] -> ByteString
mkHeader fileKey recipients =
let (headerNoMac, mac) = mkHeaderMac fileKey recipients
in headerNoMac <> " " <> (BS.encodeBase64Unpadded' mac) <> "\n"
in headerNoMac <> " " <> (BS.encodeBase64Unpadded' mac) <> "\n"

mkHeaderMac :: ByteString -> [Stanza] -> (ByteString, ByteString)
mkHeaderMac fileKey recipients =
let intro = "age-encryption.org/v1\n" :: ByteString
macKey = hkdf "header" fileKey ""
footer = "---" :: ByteString
stanza = BS.concat (marshalStanza <$> recipients)
headerNoMac = intro <> stanza <> footer
headerNoMac = intro <> stanza <> footer
mac = convert (hmac macKey headerNoMac :: HMAC SHA256) :: ByteString
in (headerNoMac, mac)
in (headerNoMac, mac)

hkdf :: ByteString -> ByteString -> ByteString -> ByteString
hkdf info key salt = HKDF.expand (HKDF.extract salt key ::PRK SHA256) info 32
hkdf info key salt = HKDF.expand (HKDF.extract salt key :: PRK SHA256) info 32

incNonce :: ByteString -> ByteString
incNonce n = BS.pack . snd $ foldr inc1 (True, []) (BS.unpack n)
where
inc1 cur (True, acc) = (cur + 1 == 0, (cur + 1) : acc)
inc1 cur (True, acc) = (cur + 1 == 0, (cur + 1) : acc)
inc1 cur (False, acc) = (False, cur : acc)

zeroNonceOf :: Int -> ByteString
zeroNonceOf n = BS.pack (take n $ repeat 0)

wrap64b :: ByteString -> ByteString
wrap64b :: ByteString -> ByteString
wrap64b bs =
let (head', tail') = BS.splitAt 64 bs
in if (BS.length tail' == 0) then head'
else head' <> "\n" <> wrap64b tail'
in if (BS.length tail' == 0)
then head'
else head' <> "\n" <> wrap64b tail'

payloadKey :: ByteString -> ByteString -> ByteString
payloadKey nonce filekey = HKDF.expand (HKDF.extract nonce filekey ::PRK SHA256) ("payload" :: ByteString) 32
payloadKey nonce filekey = HKDF.expand (HKDF.extract nonce filekey :: PRK SHA256) ("payload" :: ByteString) 32
13 changes: 2 additions & 11 deletions test/Age.hs
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
{-# LANGUAGE OverloadedStrings #-}
module Age where
import qualified Crypto.Cipher.ChaChaPoly1305 as CC
import Crypto.Error (throwCryptoErrorIO)
import Data.ByteArray (ByteArray, ByteArrayAccess,
Bytes, convert, pack)
import Data.ByteString (ByteString, empty)
import qualified Data.ByteString as BS
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.IO as TIO
import qualified Data.ByteString as BS
import Dhall.Secret.Age
import Test.HUnit

Expand All @@ -17,8 +9,7 @@ testAgeEncryption = TestCase $ do
i2 <- generateX25519Identity
let r = toRecipient i
let r2 = toRecipient i2
plaintext <- BS.readFile "./README.md"
plaintext <- BS.readFile "./test/age.md"
encrypted <- encrypt [r, r2] plaintext
print encrypted
decrypted <- decrypt encrypted [i]
assertEqual "age encryption" plaintext decrypted
3 changes: 1 addition & 2 deletions test/Spec.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
module Main where
import Age
import qualified Data.Text.IO as TIO
import Dhall
import Dhall.Core (pretty)
import qualified Dhall.Secret as Lib
import Dhall.Secret.IO
import System.Environment (setEnv)
import System.Environment.Blank (getEnv)
import Test.HUnit

testKms = "encrypt decrypt with KMS" ~: snapshot "./test/example01.dhall" "./test/example01.encrypted.dhall"
testAge = "encrypt decrypt with Age Algo" ~: snapshot "./test/example02.dhall" "./test/example02.encrypted.dhall"

Expand Down
Loading

0 comments on commit c32a293

Please sign in to comment.