diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index 4d78d1eaaa6be..0f882b08e9e66 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -33,6 +33,7 @@ import ( "crypto/rand" "crypto/subtle" "errors" + "fmt" "hash" "io" "math" @@ -272,7 +273,162 @@ func (priv *PrivateKey) Validate() error { // returned key does not depend deterministically on the bytes read from rand, // and may change between calls and/or between versions. func GenerateKey(random io.Reader, bits int) (*PrivateKey, error) { - return GenerateMultiPrimeKey(random, 2, bits) + if bits < 2048 { + // Fall back to old implementation for smaller keys. + return GenerateMultiPrimeKey(random, 2, bits) + } + + priv := new(PrivateKey) + priv.E = 65537 + // p and q + primes := make([]*big.Int, 2) + priv.Primes = primes + + if priv.rsaFipsGeneratePrimeFactors(bits) != nil { + return nil, errors.New("crypto/rsa: could not generate prime factors p,q") + } + n := new(big.Int).Set(bigOne) + totient := new(big.Int).Set(bigOne) + pminus1 := new(big.Int) + for _, prime := range primes { + n.Mul(n, prime) + pminus1.Sub(prime, bigOne) + totient.Mul(totient, pminus1) + } + priv.D = new(big.Int) + e := big.NewInt(int64(priv.E)) + ok := priv.D.ModInverse(e, totient) + + if ok != nil { + priv.N = n + } else { + return nil, errors.New("crypto/rsa: modulus error with public key exponent") + } + priv.N = n + priv.Precomputed = PrecomputedValues{Dp: nil, Dq: nil, Qinv: nil, CRTValues: make([]CRTValue, 0), n: nil, p: nil, q: nil} + priv.Precompute() + return priv, nil +} + +func diffCheck(p *big.Int, q *big.Int, bits int) bool { + // 2^(nlen/2)-100 + limit := new(big.Int).Lsh(bigOne, (uint)(bits>>1)-99) + z := new(big.Int).Sub(p, q) + z = z.Abs(z) + + return z.Cmp(limit) <= 0 +} + +func rsaFipsAuxPrimeMRRounds(bits int) int { + switch { + case bits >= 4096: + return 44 + case bits >= 3072: + return 41 + case bits >= 2048: + return 38 + default: + return 0 + } +} + +func (priv *PrivateKey) rsaFipsGeneratePrimeFactors(bits int) error { + rounds := rsaFipsAuxPrimeMRRounds(bits) + bytes := ((bits >> 1) + 7) >> 3 + + E := new(big.Int).SetInt64(int64(priv.E)) + + // 1/sqrt(2) * 2^256 + base, ok := new(big.Int).SetString("0xB504F333F9DE6484597D89B3754ABE9F1D6F60BA893BA84CED17AC8583339916", 0) + if !ok { + panic("crypto/rsa: overflow of static constant sqrt2inv") + } + if (bits >> 1) < 257 { + return errors.New("crypto/rsa: Number of bits too small") + } + sqrtinv := new(big.Int).Lsh(base, (uint)((bits>>1)-257)) + + i := 0 + pbuf := make([]byte, bytes) + var p, q *big.Int + for { + // Generate p + if _, err := rand.Read(pbuf); err != nil { + panic("crypto/rsa: RNG failure") + } + pbuf[bytes-1] |= 1 + pbuf[0] |= 0xe0 + pbuf[1] |= 0xa0 + p = new(big.Int).SetBytes(pbuf) + + // check if p < 1/sqrt(2)*(2^(bits/2)-1) + for p.Cmp(sqrtinv) < 0 { + if _, err := rand.Read(pbuf); err != nil { + return fmt.Errorf("crypto/rsa: error reading from random number generator: %s", err) + } + pbuf[bytes-1] |= 1 + pbuf[0] |= 0xe0 + pbuf[1] |= 0xa0 + + p = new(big.Int).SetBytes(pbuf) + } + diff := new(big.Int).Sub(p, bigOne) + ret := new(big.Int).GCD(nil, nil, diff, E) + if ret.Cmp(bigOne) == 0 { + isPrime := p.ProbablyPrime(rounds) + if isPrime { + goto genq + } + } + i++ + if i >= 5*bits { + priv.Primes[0] = nil + priv.Primes[1] = nil + return errors.New("crypto/rsa: number of tries to find prime factor exceeded limit") + } + } +genq: + // Generate q + i = 0 + for { + if _, err := rand.Read(pbuf); err != nil { + return fmt.Errorf("crypto/rsa: error reading from random number generator: %s", err) + } + pbuf[bytes-1] |= 1 + pbuf[0] |= 0xe0 + pbuf[1] |= 0x80 + + q = new(big.Int).SetBytes(pbuf) + + // check if q < 1/sqrt(2)*(2^(bits/2)-1) + for q.Cmp(sqrtinv) < 0 || diffCheck(p, q, bits) { + if _, err := rand.Read(pbuf); err != nil { + return fmt.Errorf("crypto/rsa: error reading from random number generator: %s", err) + } + pbuf[bytes-1] |= 1 + pbuf[0] |= 0xe0 + pbuf[1] |= 0x80 + + q = new(big.Int).SetBytes(pbuf) + } + diff := new(big.Int).Sub(q, bigOne) + ret := new(big.Int).GCD(nil, nil, diff, E) + if ret.Cmp(bigOne) == 0 { + isPrime := q.ProbablyPrime(rounds) + if isPrime { + break + } + } + i++ + if i >= 10*bits { + priv.Primes[0] = nil + priv.Primes[1] = nil + return errors.New("crypto/rsa: number of tries to find prime factor exceeded limit") + } + } + priv.Primes[0] = p + priv.Primes[1] = q + return nil } // GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit diff --git a/src/crypto/rsa/rsa_test.go b/src/crypto/rsa/rsa_test.go index 2afa045a3a0bd..83ffb2cac45f4 100644 --- a/src/crypto/rsa/rsa_test.go +++ b/src/crypto/rsa/rsa_test.go @@ -24,7 +24,7 @@ import ( ) func TestKeyGeneration(t *testing.T) { - for _, size := range []int{128, 1024, 2048, 3072} { + for _, size := range []int{128, 1024, 2048, 3072, 4096, 8192} { priv, err := GenerateKey(rand.Reader, size) if err != nil { t.Errorf("GenerateKey(%d): %v", size, err)