src/pkg/crypto/rsa/rsa.go - The Go Programming Language

Golang

Source file src/pkg/crypto/rsa/rsa.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// Package rsa implements RSA encryption as specified in PKCS#1.
     6	package rsa
     7	
     8	// TODO(agl): Add support for PSS padding.
     9	
import (
	"crypto/rand"
	"crypto/subtle"
	"errors"
	"hash"
	"io"
	"math"
	"math/big"
)
    18	
    19	var bigZero = big.NewInt(0)
    20	var bigOne = big.NewInt(1)
    21	
    22	// A PublicKey represents the public part of an RSA key.
    23	type PublicKey struct {
    24		N *big.Int // modulus
    25		E int      // public exponent
    26	}
    27	
    28	// A PrivateKey represents an RSA key
    29	type PrivateKey struct {
    30		PublicKey            // public part.
    31		D         *big.Int   // private exponent
    32		Primes    []*big.Int // prime factors of N, has >= 2 elements.
    33	
    34		// Precomputed contains precomputed values that speed up private
    35		// operations, if available.
    36		Precomputed PrecomputedValues
    37	}
    38	
    39	type PrecomputedValues struct {
    40		Dp, Dq *big.Int // D mod (P-1) (or mod Q-1) 
    41		Qinv   *big.Int // Q^-1 mod Q
    42	
    43		// CRTValues is used for the 3rd and subsequent primes. Due to a
    44		// historical accident, the CRT for the first two primes is handled
    45		// differently in PKCS#1 and interoperability is sufficiently
    46		// important that we mirror this.
    47		CRTValues []CRTValue
    48	}
    49	
    50	// CRTValue contains the precomputed chinese remainder theorem values.
    51	type CRTValue struct {
    52		Exp   *big.Int // D mod (prime-1).
    53		Coeff *big.Int // R·Coeff ≡ 1 mod Prime.
    54		R     *big.Int // product of primes prior to this (inc p and q).
    55	}
    56	
    57	// Validate performs basic sanity checks on the key.
    58	// It returns nil if the key is valid, or else an error describing a problem.
    59	func (priv *PrivateKey) Validate() error {
    60		// Check that the prime factors are actually prime. Note that this is
    61		// just a sanity check. Since the random witnesses chosen by
    62		// ProbablyPrime are deterministic, given the candidate number, it's
    63		// easy for an attack to generate composites that pass this test.
    64		for _, prime := range priv.Primes {
    65			if !prime.ProbablyPrime(20) {
    66				return errors.New("prime factor is composite")
    67			}
    68		}
    69	
    70		// Check that Πprimes == n.
    71		modulus := new(big.Int).Set(bigOne)
    72		for _, prime := range priv.Primes {
    73			modulus.Mul(modulus, prime)
    74		}
    75		if modulus.Cmp(priv.N) != 0 {
    76			return errors.New("invalid modulus")
    77		}
    78		// Check that e and totient(Πprimes) are coprime.
    79		totient := new(big.Int).Set(bigOne)
    80		for _, prime := range priv.Primes {
    81			pminus1 := new(big.Int).Sub(prime, bigOne)
    82			totient.Mul(totient, pminus1)
    83		}
    84		e := big.NewInt(int64(priv.E))
    85		gcd := new(big.Int)
    86		x := new(big.Int)
    87		y := new(big.Int)
    88		gcd.GCD(x, y, totient, e)
    89		if gcd.Cmp(bigOne) != 0 {
    90			return errors.New("invalid public exponent E")
    91		}
    92		// Check that de ≡ 1 (mod totient(Πprimes))
    93		de := new(big.Int).Mul(priv.D, e)
    94		de.Mod(de, totient)
    95		if de.Cmp(bigOne) != 0 {
    96			return errors.New("invalid private exponent D")
    97		}
    98		return nil
    99	}
   100	
   101	// GenerateKey generates an RSA keypair of the given bit size.
   102	func GenerateKey(random io.Reader, bits int) (priv *PrivateKey, err error) {
   103		return GenerateMultiPrimeKey(random, 2, bits)
   104	}
   105	
   106	// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit
   107	// size, as suggested in [1]. Although the public keys are compatible
   108	// (actually, indistinguishable) from the 2-prime case, the private keys are
   109	// not. Thus it may not be possible to export multi-prime private keys in
   110	// certain formats or to subsequently import them into other code.
   111	//
   112	// Table 1 in [2] suggests maximum numbers of primes for a given size.
   113	//
   114	// [1] US patent 4405829 (1972, expired)
   115	// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
   116	func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (priv *PrivateKey, err error) {
   117		priv = new(PrivateKey)
   118		priv.E = 65537
   119	
   120		if nprimes < 2 {
   121			return nil, errors.New("rsa.GenerateMultiPrimeKey: nprimes must be >= 2")
   122		}
   123	
   124		primes := make([]*big.Int, nprimes)
   125	
   126	NextSetOfPrimes:
   127		for {
   128			todo := bits
   129			for i := 0; i < nprimes; i++ {
   130				primes[i], err = rand.Prime(random, todo/(nprimes-i))
   131				if err != nil {
   132					return nil, err
   133				}
   134				todo -= primes[i].BitLen()
   135			}
   136	
   137			// Make sure that primes is pairwise unequal.
   138			for i, prime := range primes {
   139				for j := 0; j < i; j++ {
   140					if prime.Cmp(primes[j]) == 0 {
   141						continue NextSetOfPrimes
   142					}
   143				}
   144			}
   145	
   146			n := new(big.Int).Set(bigOne)
   147			totient := new(big.Int).Set(bigOne)
   148			pminus1 := new(big.Int)
   149			for _, prime := range primes {
   150				n.Mul(n, prime)
   151				pminus1.Sub(prime, bigOne)
   152				totient.Mul(totient, pminus1)
   153			}
   154	
   155			g := new(big.Int)
   156			priv.D = new(big.Int)
   157			y := new(big.Int)
   158			e := big.NewInt(int64(priv.E))
   159			g.GCD(priv.D, y, e, totient)
   160	
   161			if g.Cmp(bigOne) == 0 {
   162				priv.D.Add(priv.D, totient)
   163				priv.Primes = primes
   164				priv.N = n
   165	
   166				break
   167			}
   168		}
   169	
   170		priv.Precompute()
   171		return
   172	}
   173	
   174	// incCounter increments a four byte, big-endian counter.
   175	func incCounter(c *[4]byte) {
   176		if c[3]++; c[3] != 0 {
   177			return
   178		}
   179		if c[2]++; c[2] != 0 {
   180			return
   181		}
   182		if c[1]++; c[1] != 0 {
   183			return
   184		}
   185		c[0]++
   186	}
   187	
   188	// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function
   189	// specified in PKCS#1 v2.1.
   190	func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
   191		var counter [4]byte
   192		var digest []byte
   193	
   194		done := 0
   195		for done < len(out) {
   196			hash.Write(seed)
   197			hash.Write(counter[0:4])
   198			digest = hash.Sum(digest[:0])
   199			hash.Reset()
   200	
   201			for i := 0; i < len(digest) && done < len(out); i++ {
   202				out[done] ^= digest[i]
   203				done++
   204			}
   205			incCounter(&counter)
   206		}
   207	}
   208	
   209	// ErrMessageTooLong is returned when attempting to encrypt a message which is
   210	// too large for the size of the public key.
   211	var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA public key size")
   212	
   213	func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int {
   214		e := big.NewInt(int64(pub.E))
   215		c.Exp(m, e, pub.N)
   216		return c
   217	}
   218	
   219	// EncryptOAEP encrypts the given message with RSA-OAEP.
   220	// The message must be no longer than the length of the public modulus less
   221	// twice the hash length plus 2.
   222	func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) (out []byte, err error) {
   223		hash.Reset()
   224		k := (pub.N.BitLen() + 7) / 8
   225		if len(msg) > k-2*hash.Size()-2 {
   226			err = ErrMessageTooLong
   227			return
   228		}
   229	
   230		hash.Write(label)
   231		lHash := hash.Sum(nil)
   232		hash.Reset()
   233	
   234		em := make([]byte, k)
   235		seed := em[1 : 1+hash.Size()]
   236		db := em[1+hash.Size():]
   237	
   238		copy(db[0:hash.Size()], lHash)
   239		db[len(db)-len(msg)-1] = 1
   240		copy(db[len(db)-len(msg):], msg)
   241	
   242		_, err = io.ReadFull(random, seed)
   243		if err != nil {
   244			return
   245		}
   246	
   247		mgf1XOR(db, hash, seed)
   248		mgf1XOR(seed, hash, db)
   249	
   250		m := new(big.Int)
   251		m.SetBytes(em)
   252		c := encrypt(new(big.Int), pub, m)
   253		out = c.Bytes()
   254	
   255		if len(out) < k {
   256			// If the output is too small, we need to left-pad with zeros.
   257			t := make([]byte, k)
   258			copy(t[k-len(out):], out)
   259			out = t
   260		}
   261	
   262		return
   263	}
   264	
   265	// ErrDecryption represents a failure to decrypt a message.
   266	// It is deliberately vague to avoid adaptive attacks.
   267	var ErrDecryption = errors.New("crypto/rsa: decryption error")
   268	
   269	// ErrVerification represents a failure to verify a signature.
   270	// It is deliberately vague to avoid adaptive attacks.
   271	var ErrVerification = errors.New("crypto/rsa: verification error")
   272	
   273	// modInverse returns ia, the inverse of a in the multiplicative group of prime
   274	// order n. It requires that a be a member of the group (i.e. less than n).
   275	func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
   276		g := new(big.Int)
   277		x := new(big.Int)
   278		y := new(big.Int)
   279		g.GCD(x, y, a, n)
   280		if g.Cmp(bigOne) != 0 {
   281			// In this case, a and n aren't coprime and we cannot calculate
   282			// the inverse. This happens because the values of n are nearly
   283			// prime (being the product of two primes) rather than truly
   284			// prime.
   285			return
   286		}
   287	
   288		if x.Cmp(bigOne) < 0 {
   289			// 0 is not the multiplicative inverse of any element so, if x
   290			// < 1, then x is negative.
   291			x.Add(x, n)
   292		}
   293	
   294		return x, true
   295	}
   296	
   297	// Precompute performs some calculations that speed up private key operations
   298	// in the future.
   299	func (priv *PrivateKey) Precompute() {
   300		if priv.Precomputed.Dp != nil {
   301			return
   302		}
   303	
   304		priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
   305		priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp)
   306	
   307		priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
   308		priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq)
   309	
   310		priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
   311	
   312		r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
   313		priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
   314		for i := 2; i < len(priv.Primes); i++ {
   315			prime := priv.Primes[i]
   316			values := &priv.Precomputed.CRTValues[i-2]
   317	
   318			values.Exp = new(big.Int).Sub(prime, bigOne)
   319			values.Exp.Mod(priv.D, values.Exp)
   320	
   321			values.R = new(big.Int).Set(r)
   322			values.Coeff = new(big.Int).ModInverse(r, prime)
   323	
   324			r.Mul(r, prime)
   325		}
   326	}
   327	
   328	// decrypt performs an RSA decryption, resulting in a plaintext integer. If a
   329	// random source is given, RSA blinding is used.
   330	func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) {
   331		// TODO(agl): can we get away with reusing blinds?
   332		if c.Cmp(priv.N) > 0 {
   333			err = ErrDecryption
   334			return
   335		}
   336	
   337		var ir *big.Int
   338		if random != nil {
   339			// Blinding enabled. Blinding involves multiplying c by r^e.
   340			// Then the decryption operation performs (m^e * r^e)^d mod n
   341			// which equals mr mod n. The factor of r can then be removed
   342			// by multiplying by the multiplicative inverse of r.
   343	
   344			var r *big.Int
   345	
   346			for {
   347				r, err = rand.Int(random, priv.N)
   348				if err != nil {
   349					return
   350				}
   351				if r.Cmp(bigZero) == 0 {
   352					r = bigOne
   353				}
   354				var ok bool
   355				ir, ok = modInverse(r, priv.N)
   356				if ok {
   357					break
   358				}
   359			}
   360			bigE := big.NewInt(int64(priv.E))
   361			rpowe := new(big.Int).Exp(r, bigE, priv.N)
   362			cCopy := new(big.Int).Set(c)
   363			cCopy.Mul(cCopy, rpowe)
   364			cCopy.Mod(cCopy, priv.N)
   365			c = cCopy
   366		}
   367	
   368		if priv.Precomputed.Dp == nil {
   369			m = new(big.Int).Exp(c, priv.D, priv.N)
   370		} else {
   371			// We have the precalculated values needed for the CRT.
   372			m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
   373			m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
   374			m.Sub(m, m2)
   375			if m.Sign() < 0 {
   376				m.Add(m, priv.Primes[0])
   377			}
   378			m.Mul(m, priv.Precomputed.Qinv)
   379			m.Mod(m, priv.Primes[0])
   380			m.Mul(m, priv.Primes[1])
   381			m.Add(m, m2)
   382	
   383			for i, values := range priv.Precomputed.CRTValues {
   384				prime := priv.Primes[2+i]
   385				m2.Exp(c, values.Exp, prime)
   386				m2.Sub(m2, m)
   387				m2.Mul(m2, values.Coeff)
   388				m2.Mod(m2, prime)
   389				if m2.Sign() < 0 {
   390					m2.Add(m2, prime)
   391				}
   392				m2.Mul(m2, values.R)
   393				m.Add(m, m2)
   394			}
   395		}
   396	
   397		if ir != nil {
   398			// Unblind.
   399			m.Mul(m, ir)
   400			m.Mod(m, priv.N)
   401		}
   402	
   403		return
   404	}
   405	
   406	// DecryptOAEP decrypts ciphertext using RSA-OAEP.
   407	// If random != nil, DecryptOAEP uses RSA blinding to avoid timing side-channel attacks.
   408	func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err error) {
   409		k := (priv.N.BitLen() + 7) / 8
   410		if len(ciphertext) > k ||
   411			k < hash.Size()*2+2 {
   412			err = ErrDecryption
   413			return
   414		}
   415	
   416		c := new(big.Int).SetBytes(ciphertext)
   417	
   418		m, err := decrypt(random, priv, c)
   419		if err != nil {
   420			return
   421		}
   422	
   423		hash.Write(label)
   424		lHash := hash.Sum(nil)
   425		hash.Reset()
   426	
   427		// Converting the plaintext number to bytes will strip any
   428		// leading zeros so we may have to left pad. We do this unconditionally
   429		// to avoid leaking timing information. (Although we still probably
   430		// leak the number of leading zeros. It's not clear that we can do
   431		// anything about this.)
   432		em := leftPad(m.Bytes(), k)
   433	
   434		firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
   435	
   436		seed := em[1 : hash.Size()+1]
   437		db := em[hash.Size()+1:]
   438	
   439		mgf1XOR(seed, hash, db)
   440		mgf1XOR(db, hash, seed)
   441	
   442		lHash2 := db[0:hash.Size()]
   443	
   444		// We have to validate the plaintext in constant time in order to avoid
   445		// attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal
   446		// Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1
   447		// v2.0. In J. Kilian, editor, Advances in Cryptology.
   448		lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)
   449	
   450		// The remainder of the plaintext must be zero or more 0x00, followed
   451		// by 0x01, followed by the message.
   452		//   lookingForIndex: 1 iff we are still looking for the 0x01
   453		//   index: the offset of the first 0x01 byte
   454		//   invalid: 1 iff we saw a non-zero byte before the 0x01.
   455		var lookingForIndex, index, invalid int
   456		lookingForIndex = 1
   457		rest := db[hash.Size():]
   458	
   459		for i := 0; i < len(rest); i++ {
   460			equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
   461			equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
   462			index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
   463			lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
   464			invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
   465		}
   466	
   467		if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
   468			err = ErrDecryption
   469			return
   470		}
   471	
   472		msg = rest[index+1:]
   473		return
   474	}
   475	
   476	// leftPad returns a new slice of length size. The contents of input are right
   477	// aligned in the new slice.
   478	func leftPad(input []byte, size int) (out []byte) {
   479		n := len(input)
   480		if n > size {
   481			n = size
   482		}
   483		out = make([]byte, size)
   484		copy(out[len(out)-n:], input)
   485		return
   486	}