forked from toolshed/abra
		
	
		
			
				
	
	
		
			364 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			364 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package libtrust
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto"
 | |
| 	"crypto/elliptic"
 | |
| 	"crypto/tls"
 | |
| 	"crypto/x509"
 | |
| 	"encoding/base32"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/binary"
 | |
| 	"encoding/pem"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"math/big"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
| // LoadOrCreateTrustKey will load a PrivateKey from the specified path
 | |
| func LoadOrCreateTrustKey(trustKeyPath string) (PrivateKey, error) {
 | |
| 	if err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	trustKey, err := LoadKeyFile(trustKeyPath)
 | |
| 	if err == ErrKeyFileDoesNotExist {
 | |
| 		trustKey, err = GenerateECP256PrivateKey()
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("error generating key: %s", err)
 | |
| 		}
 | |
| 
 | |
| 		if err := SaveKey(trustKeyPath, trustKey); err != nil {
 | |
| 			return nil, fmt.Errorf("error saving key file: %s", err)
 | |
| 		}
 | |
| 
 | |
| 		dir, file := filepath.Split(trustKeyPath)
 | |
| 		if err := SavePublicKey(filepath.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil {
 | |
| 			return nil, fmt.Errorf("error saving public key file: %s", err)
 | |
| 		}
 | |
| 	} else if err != nil {
 | |
| 		return nil, fmt.Errorf("error loading key file: %s", err)
 | |
| 	}
 | |
| 	return trustKey, nil
 | |
| }
 | |
| 
 | |
| // NewIdentityAuthTLSClientConfig returns a tls.Config configured to use identity
 | |
| // based authentication from the specified dockerUrl, the rootConfigPath and
 | |
| // the server name to which it is connecting.
 | |
| // If trustUnknownHosts is true it will automatically add the host to the
 | |
| // known-hosts.json in rootConfigPath.
 | |
| func NewIdentityAuthTLSClientConfig(dockerUrl string, trustUnknownHosts bool, rootConfigPath string, serverName string) (*tls.Config, error) {
 | |
| 	tlsConfig := newTLSConfig()
 | |
| 
 | |
| 	trustKeyPath := filepath.Join(rootConfigPath, "key.json")
 | |
| 	knownHostsPath := filepath.Join(rootConfigPath, "known-hosts.json")
 | |
| 
 | |
| 	u, err := url.Parse(dockerUrl)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("unable to parse machine url")
 | |
| 	}
 | |
| 
 | |
| 	if u.Scheme == "unix" {
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 
 | |
| 	addr := u.Host
 | |
| 	proto := "tcp"
 | |
| 
 | |
| 	trustKey, err := LoadOrCreateTrustKey(trustKeyPath)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("unable to load trust key: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	knownHosts, err := LoadKeySetFile(knownHostsPath)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("could not load trusted hosts file: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	allowedHosts, err := FilterByHosts(knownHosts, addr, false)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("error filtering hosts: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	certPool, err := GenerateCACertPool(trustKey, allowedHosts)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("Could not create CA pool: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	tlsConfig.ServerName = serverName
 | |
| 	tlsConfig.RootCAs = certPool
 | |
| 
 | |
| 	x509Cert, err := GenerateSelfSignedClientCert(trustKey)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("certificate generation error: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	tlsConfig.Certificates = []tls.Certificate{{
 | |
| 		Certificate: [][]byte{x509Cert.Raw},
 | |
| 		PrivateKey:  trustKey.CryptoPrivateKey(),
 | |
| 		Leaf:        x509Cert,
 | |
| 	}}
 | |
| 
 | |
| 	tlsConfig.InsecureSkipVerify = true
 | |
| 
 | |
| 	testConn, err := tls.Dial(proto, addr, tlsConfig)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("tls Handshake error: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	opts := x509.VerifyOptions{
 | |
| 		Roots:         tlsConfig.RootCAs,
 | |
| 		CurrentTime:   time.Now(),
 | |
| 		DNSName:       tlsConfig.ServerName,
 | |
| 		Intermediates: x509.NewCertPool(),
 | |
| 	}
 | |
| 
 | |
| 	certs := testConn.ConnectionState().PeerCertificates
 | |
| 	for i, cert := range certs {
 | |
| 		if i == 0 {
 | |
| 			continue
 | |
| 		}
 | |
| 		opts.Intermediates.AddCert(cert)
 | |
| 	}
 | |
| 
 | |
| 	if _, err := certs[0].Verify(opts); err != nil {
 | |
| 		if _, ok := err.(x509.UnknownAuthorityError); ok {
 | |
| 			if trustUnknownHosts {
 | |
| 				pubKey, err := FromCryptoPublicKey(certs[0].PublicKey)
 | |
| 				if err != nil {
 | |
| 					return nil, fmt.Errorf("error extracting public key from cert: %s", err)
 | |
| 				}
 | |
| 
 | |
| 				pubKey.AddExtendedField("hosts", []string{addr})
 | |
| 
 | |
| 				if err := AddKeySetFile(knownHostsPath, pubKey); err != nil {
 | |
| 					return nil, fmt.Errorf("error adding machine to known hosts: %s", err)
 | |
| 				}
 | |
| 			} else {
 | |
| 				return nil, fmt.Errorf("unable to connect.  unknown host: %s", addr)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	testConn.Close()
 | |
| 	tlsConfig.InsecureSkipVerify = false
 | |
| 
 | |
| 	return tlsConfig, nil
 | |
| }
 | |
| 
 | |
| // joseBase64UrlEncode encodes the given data using the standard base64 url
 | |
| // encoding format but with all trailing '=' characters omitted in accordance
 | |
| // with the jose specification.
 | |
| // http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
 | |
| func joseBase64UrlEncode(b []byte) string {
 | |
| 	return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
 | |
| }
 | |
| 
 | |
| // joseBase64UrlDecode decodes the given string using the standard base64 url
 | |
| // decoder but first adds the appropriate number of trailing '=' characters in
 | |
| // accordance with the jose specification.
 | |
| // http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
 | |
| func joseBase64UrlDecode(s string) ([]byte, error) {
 | |
| 	s = strings.Replace(s, "\n", "", -1)
 | |
| 	s = strings.Replace(s, " ", "", -1)
 | |
| 	switch len(s) % 4 {
 | |
| 	case 0:
 | |
| 	case 2:
 | |
| 		s += "=="
 | |
| 	case 3:
 | |
| 		s += "="
 | |
| 	default:
 | |
| 		return nil, errors.New("illegal base64url string")
 | |
| 	}
 | |
| 	return base64.URLEncoding.DecodeString(s)
 | |
| }
 | |
| 
 | |
| func keyIDEncode(b []byte) string {
 | |
| 	s := strings.TrimRight(base32.StdEncoding.EncodeToString(b), "=")
 | |
| 	var buf bytes.Buffer
 | |
| 	var i int
 | |
| 	for i = 0; i < len(s)/4-1; i++ {
 | |
| 		start := i * 4
 | |
| 		end := start + 4
 | |
| 		buf.WriteString(s[start:end] + ":")
 | |
| 	}
 | |
| 	buf.WriteString(s[i*4:])
 | |
| 	return buf.String()
 | |
| }
 | |
| 
 | |
| func keyIDFromCryptoKey(pubKey PublicKey) string {
 | |
| 	// Generate and return a 'libtrust' fingerprint of the public key.
 | |
| 	// For an RSA key this should be:
 | |
| 	//   SHA256(DER encoded ASN1)
 | |
| 	// Then truncated to 240 bits and encoded into 12 base32 groups like so:
 | |
| 	//   ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
 | |
| 	derBytes, err := x509.MarshalPKIXPublicKey(pubKey.CryptoPublicKey())
 | |
| 	if err != nil {
 | |
| 		return ""
 | |
| 	}
 | |
| 	hasher := crypto.SHA256.New()
 | |
| 	hasher.Write(derBytes)
 | |
| 	return keyIDEncode(hasher.Sum(nil)[:30])
 | |
| }
 | |
| 
 | |
| func stringFromMap(m map[string]interface{}, key string) (string, error) {
 | |
| 	val, ok := m[key]
 | |
| 	if !ok {
 | |
| 		return "", fmt.Errorf("%q value not specified", key)
 | |
| 	}
 | |
| 
 | |
| 	str, ok := val.(string)
 | |
| 	if !ok {
 | |
| 		return "", fmt.Errorf("%q value must be a string", key)
 | |
| 	}
 | |
| 	delete(m, key)
 | |
| 
 | |
| 	return str, nil
 | |
| }
 | |
| 
 | |
| func parseECCoordinate(cB64Url string, curve elliptic.Curve) (*big.Int, error) {
 | |
| 	curveByteLen := (curve.Params().BitSize + 7) >> 3
 | |
| 
 | |
| 	cBytes, err := joseBase64UrlDecode(cB64Url)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
 | |
| 	}
 | |
| 	cByteLength := len(cBytes)
 | |
| 	if cByteLength != curveByteLen {
 | |
| 		return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", cByteLength, curveByteLen)
 | |
| 	}
 | |
| 	return new(big.Int).SetBytes(cBytes), nil
 | |
| }
 | |
| 
 | |
| func parseECPrivateParam(dB64Url string, curve elliptic.Curve) (*big.Int, error) {
 | |
| 	dBytes, err := joseBase64UrlDecode(dB64Url)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	// The length of this octet string MUST be ceiling(log-base-2(n)/8)
 | |
| 	// octets (where n is the order of the curve). This is because the private
 | |
| 	// key d must be in the interval [1, n-1] so the bitlength of d should be
 | |
| 	// no larger than the bitlength of n-1. The easiest way to find the octet
 | |
| 	// length is to take bitlength(n-1), add 7 to force a carry, and shift this
 | |
| 	// bit sequence right by 3, which is essentially dividing by 8 and adding
 | |
| 	// 1 if there is any remainder. Thus, the private key value d should be
 | |
| 	// output to (bitlength(n-1)+7)>>3 octets.
 | |
| 	n := curve.Params().N
 | |
| 	octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
 | |
| 	dByteLength := len(dBytes)
 | |
| 
 | |
| 	if dByteLength != octetLength {
 | |
| 		return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", dByteLength, octetLength)
 | |
| 	}
 | |
| 
 | |
| 	return new(big.Int).SetBytes(dBytes), nil
 | |
| }
 | |
| 
 | |
| func parseRSAModulusParam(nB64Url string) (*big.Int, error) {
 | |
| 	nBytes, err := joseBase64UrlDecode(nB64Url)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	return new(big.Int).SetBytes(nBytes), nil
 | |
| }
 | |
| 
 | |
| func serializeRSAPublicExponentParam(e int) []byte {
 | |
| 	// We MUST use the minimum number of octets to represent E.
 | |
| 	// E is supposed to be 65537 for performance and security reasons
 | |
| 	// and is what golang's rsa package generates, but it might be
 | |
| 	// different if imported from some other generator.
 | |
| 	buf := make([]byte, 4)
 | |
| 	binary.BigEndian.PutUint32(buf, uint32(e))
 | |
| 	var i int
 | |
| 	for i = 0; i < 8; i++ {
 | |
| 		if buf[i] != 0 {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	return buf[i:]
 | |
| }
 | |
| 
 | |
| func parseRSAPublicExponentParam(eB64Url string) (int, error) {
 | |
| 	eBytes, err := joseBase64UrlDecode(eB64Url)
 | |
| 	if err != nil {
 | |
| 		return 0, fmt.Errorf("invalid base64 URL encoding: %s", err)
 | |
| 	}
 | |
| 	// Only the minimum number of bytes were used to represent E, but
 | |
| 	// binary.BigEndian.Uint32 expects at least 4 bytes, so we need
 | |
| 	// to add zero padding if necassary.
 | |
| 	byteLen := len(eBytes)
 | |
| 	buf := make([]byte, 4-byteLen, 4)
 | |
| 	eBytes = append(buf, eBytes...)
 | |
| 
 | |
| 	return int(binary.BigEndian.Uint32(eBytes)), nil
 | |
| }
 | |
| 
 | |
| func parseRSAPrivateKeyParamFromMap(m map[string]interface{}, key string) (*big.Int, error) {
 | |
| 	b64Url, err := stringFromMap(m, key)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	paramBytes, err := joseBase64UrlDecode(b64Url)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invaled base64 URL encoding: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	return new(big.Int).SetBytes(paramBytes), nil
 | |
| }
 | |
| 
 | |
| func createPemBlock(name string, derBytes []byte, headers map[string]interface{}) (*pem.Block, error) {
 | |
| 	pemBlock := &pem.Block{Type: name, Bytes: derBytes, Headers: map[string]string{}}
 | |
| 	for k, v := range headers {
 | |
| 		switch val := v.(type) {
 | |
| 		case string:
 | |
| 			pemBlock.Headers[k] = val
 | |
| 		case []string:
 | |
| 			if k == "hosts" {
 | |
| 				pemBlock.Headers[k] = strings.Join(val, ",")
 | |
| 			} else {
 | |
| 				// Return error, non-encodable type
 | |
| 			}
 | |
| 		default:
 | |
| 			// Return error, non-encodable type
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return pemBlock, nil
 | |
| }
 | |
| 
 | |
| func pubKeyFromPEMBlock(pemBlock *pem.Block) (PublicKey, error) {
 | |
| 	cryptoPublicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("unable to decode Public Key PEM data: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	addPEMHeadersToKey(pemBlock, pubKey)
 | |
| 
 | |
| 	return pubKey, nil
 | |
| }
 | |
| 
 | |
| func addPEMHeadersToKey(pemBlock *pem.Block, pubKey PublicKey) {
 | |
| 	for key, value := range pemBlock.Headers {
 | |
| 		var safeVal interface{}
 | |
| 		if key == "hosts" {
 | |
| 			safeVal = strings.Split(value, ",")
 | |
| 		} else {
 | |
| 			safeVal = value
 | |
| 		}
 | |
| 		pubKey.AddExtendedField(key, safeVal)
 | |
| 	}
 | |
| }
 |