mirror of https://github.com/ginuerzh/gost
66 changed files with 3778 additions and 853 deletions
@ -0,0 +1,219 @@ |
|||||
|
package quic |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"crypto/tls" |
||||
|
"errors" |
||||
|
"net" |
||||
|
"strings" |
||||
|
"sync/atomic" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/protocol" |
||||
|
"github.com/lucas-clemente/quic-go/qerr" |
||||
|
"github.com/lucas-clemente/quic-go/utils" |
||||
|
) |
||||
|
|
||||
|
// A Client of QUIC
|
||||
|
type Client struct { |
||||
|
addr *net.UDPAddr |
||||
|
conn *net.UDPConn |
||||
|
hostname string |
||||
|
|
||||
|
connectionID protocol.ConnectionID |
||||
|
version protocol.VersionNumber |
||||
|
versionNegotiated bool |
||||
|
closed uint32 // atomic bool
|
||||
|
|
||||
|
tlsConfig *tls.Config |
||||
|
cryptoChangeCallback CryptoChangeCallback |
||||
|
versionNegotiateCallback VersionNegotiateCallback |
||||
|
|
||||
|
session packetHandler |
||||
|
} |
||||
|
|
||||
|
// VersionNegotiateCallback is called once the client has a negotiated version
|
||||
|
type VersionNegotiateCallback func() error |
||||
|
|
||||
|
var errHostname = errors.New("Invalid hostname") |
||||
|
|
||||
|
var ( |
||||
|
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") |
||||
|
) |
||||
|
|
||||
|
// NewClient makes a new client
|
||||
|
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { |
||||
|
udpAddr, err := net.ResolveUDPAddr("udp", host) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
connectionID, err := utils.GenerateConnectionID() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
hostname, _, err := net.SplitHostPort(host) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
client := &Client{ |
||||
|
addr: udpAddr, |
||||
|
conn: conn, |
||||
|
hostname: hostname, |
||||
|
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
|
||||
|
connectionID: connectionID, |
||||
|
tlsConfig: tlsConfig, |
||||
|
cryptoChangeCallback: cryptoChangeCallback, |
||||
|
versionNegotiateCallback: versionNegotiateCallback, |
||||
|
} |
||||
|
|
||||
|
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) |
||||
|
|
||||
|
err = client.createNewSession(nil) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
return client, nil |
||||
|
} |
||||
|
|
||||
|
// Listen listens
|
||||
|
func (c *Client) Listen() error { |
||||
|
for { |
||||
|
data := getPacketBuffer() |
||||
|
data = data[:protocol.MaxPacketSize] |
||||
|
|
||||
|
n, _, err := c.conn.ReadFromUDP(data) |
||||
|
if err != nil { |
||||
|
if strings.HasSuffix(err.Error(), "use of closed network connection") { |
||||
|
return nil |
||||
|
} |
||||
|
return err |
||||
|
} |
||||
|
data = data[:n] |
||||
|
|
||||
|
err = c.handlePacket(data) |
||||
|
if err != nil { |
||||
|
utils.Errorf("error handling packet: %s", err.Error()) |
||||
|
c.session.Close(err) |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
|
||||
|
func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) { |
||||
|
return c.session.OpenStream(id) |
||||
|
} |
||||
|
|
||||
|
// Close closes the connection
|
||||
|
func (c *Client) Close(e error) error { |
||||
|
// Only close once
|
||||
|
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
_ = c.session.Close(e) |
||||
|
return c.conn.Close() |
||||
|
} |
||||
|
|
||||
|
func (c *Client) handlePacket(packet []byte) error { |
||||
|
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { |
||||
|
return qerr.PacketTooLarge |
||||
|
} |
||||
|
|
||||
|
rcvTime := time.Now() |
||||
|
|
||||
|
r := bytes.NewReader(packet) |
||||
|
|
||||
|
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) |
||||
|
if err != nil { |
||||
|
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) |
||||
|
} |
||||
|
hdr.Raw = packet[:len(packet)-r.Len()] |
||||
|
|
||||
|
// ignore delayed / duplicated version negotiation packets
|
||||
|
if c.versionNegotiated && hdr.VersionFlag { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// this is the first packet after the client sent a packet with the VersionFlag set
|
||||
|
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
||||
|
if !hdr.VersionFlag && !c.versionNegotiated { |
||||
|
c.versionNegotiated = true |
||||
|
err = c.versionNegotiateCallback() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if hdr.VersionFlag { |
||||
|
var hasCommonVersion bool // check if we're supporting any of the offered versions
|
||||
|
for _, v := range hdr.SupportedVersions { |
||||
|
// check if the server sent the offered version in supported versions
|
||||
|
if v == c.version { |
||||
|
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.") |
||||
|
} |
||||
|
if v != protocol.VersionUnsupported { |
||||
|
hasCommonVersion = true |
||||
|
} |
||||
|
} |
||||
|
if !hasCommonVersion { |
||||
|
utils.Infof("No common version found.") |
||||
|
return qerr.InvalidVersion |
||||
|
} |
||||
|
|
||||
|
ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions) |
||||
|
if !ok { |
||||
|
return qerr.VersionNegotiationMismatch |
||||
|
} |
||||
|
|
||||
|
utils.Infof("Switching to QUIC version %d", highestSupportedVersion) |
||||
|
c.version = highestSupportedVersion |
||||
|
c.versionNegotiated = true |
||||
|
|
||||
|
c.session.Close(errCloseSessionForNewVersion) |
||||
|
err = c.createNewSession(hdr.SupportedVersions) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
err = c.versionNegotiateCallback() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
return nil // version negotiation packets have no payload
|
||||
|
} |
||||
|
|
||||
|
c.session.handlePacket(&receivedPacket{ |
||||
|
remoteAddr: c.addr, |
||||
|
publicHeader: hdr, |
||||
|
data: packet[len(packet)-r.Len():], |
||||
|
rcvTime: rcvTime, |
||||
|
}) |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { |
||||
|
var err error |
||||
|
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
go c.session.run() |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (c *Client) streamCallback(session *Session, stream utils.Stream) {} |
||||
|
|
||||
|
func (c *Client) closeCallback(id protocol.ConnectionID) { |
||||
|
utils.Infof("Connection %x closed.", id) |
||||
|
} |
||||
@ -0,0 +1,84 @@ |
|||||
|
package crypto |
||||
|
|
||||
|
import ( |
||||
|
"crypto/tls" |
||||
|
"errors" |
||||
|
"strings" |
||||
|
) |
||||
|
|
||||
|
// A CertChain holds a certificate and a private key
|
||||
|
type CertChain interface { |
||||
|
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) |
||||
|
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error) |
||||
|
GetLeafCert(sni string) ([]byte, error) |
||||
|
} |
||||
|
|
||||
|
// proofSource stores a key and a certificate for the server proof
|
||||
|
type certChain struct { |
||||
|
config *tls.Config |
||||
|
} |
||||
|
|
||||
|
var _ CertChain = &certChain{} |
||||
|
|
||||
|
var errNoMatchingCertificate = errors.New("no matching certificate found") |
||||
|
|
||||
|
// NewCertChain loads the key and cert from files
|
||||
|
func NewCertChain(tlsConfig *tls.Config) CertChain { |
||||
|
return &certChain{config: tlsConfig} |
||||
|
} |
||||
|
|
||||
|
// SignServerProof signs CHLO and server config for use in the server proof
|
||||
|
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) { |
||||
|
cert, err := c.getCertForSNI(sni) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
return signServerProof(cert, chlo, serverConfigData) |
||||
|
} |
||||
|
|
||||
|
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
||||
|
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { |
||||
|
cert, err := c.getCertForSNI(sni) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes) |
||||
|
} |
||||
|
|
||||
|
// GetLeafCert gets the leaf certificate
|
||||
|
func (c *certChain) GetLeafCert(sni string) ([]byte, error) { |
||||
|
cert, err := c.getCertForSNI(sni) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return cert.Certificate[0], nil |
||||
|
} |
||||
|
|
||||
|
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { |
||||
|
if c.config.GetCertificate != nil { |
||||
|
cert, err := c.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
if cert != nil { |
||||
|
return cert, nil |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if len(c.config.NameToCertificate) != 0 { |
||||
|
if cert, ok := c.config.NameToCertificate[sni]; ok { |
||||
|
return cert, nil |
||||
|
} |
||||
|
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) |
||||
|
if cert, ok := c.config.NameToCertificate[wildcardSNI]; ok { |
||||
|
return cert, nil |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if len(c.config.Certificates) != 0 { |
||||
|
return &c.config.Certificates[0], nil |
||||
|
} |
||||
|
|
||||
|
return nil, errNoMatchingCertificate |
||||
|
} |
||||
@ -0,0 +1,131 @@ |
|||||
|
package crypto |
||||
|
|
||||
|
import ( |
||||
|
"crypto/tls" |
||||
|
"crypto/x509" |
||||
|
"errors" |
||||
|
"hash/fnv" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/qerr" |
||||
|
) |
||||
|
|
||||
|
// CertManager manages the certificates sent by the server
|
||||
|
type CertManager interface { |
||||
|
SetData([]byte) error |
||||
|
GetCommonCertificateHashes() []byte |
||||
|
GetLeafCert() []byte |
||||
|
GetLeafCertHash() (uint64, error) |
||||
|
VerifyServerProof(proof, chlo, serverConfigData []byte) bool |
||||
|
Verify(hostname string) error |
||||
|
} |
||||
|
|
||||
|
type certManager struct { |
||||
|
chain []*x509.Certificate |
||||
|
config *tls.Config |
||||
|
} |
||||
|
|
||||
|
var _ CertManager = &certManager{} |
||||
|
|
||||
|
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded") |
||||
|
|
||||
|
// NewCertManager creates a new CertManager
|
||||
|
func NewCertManager(tlsConfig *tls.Config) CertManager { |
||||
|
return &certManager{config: tlsConfig} |
||||
|
} |
||||
|
|
||||
|
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
|
||||
|
func (c *certManager) SetData(data []byte) error { |
||||
|
byteChain, err := decompressChain(data) |
||||
|
if err != nil { |
||||
|
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid") |
||||
|
} |
||||
|
|
||||
|
chain := make([]*x509.Certificate, len(byteChain), len(byteChain)) |
||||
|
for i, data := range byteChain { |
||||
|
cert, err := x509.ParseCertificate(data) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
chain[i] = cert |
||||
|
} |
||||
|
|
||||
|
c.chain = chain |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (c *certManager) GetCommonCertificateHashes() []byte { |
||||
|
return getCommonCertificateHashes() |
||||
|
} |
||||
|
|
||||
|
// GetLeafCert returns the leaf certificate of the certificate chain
|
||||
|
// it returns nil if the certificate chain has not yet been set
|
||||
|
func (c *certManager) GetLeafCert() []byte { |
||||
|
if len(c.chain) == 0 { |
||||
|
return nil |
||||
|
} |
||||
|
return c.chain[0].Raw |
||||
|
} |
||||
|
|
||||
|
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
|
||||
|
func (c *certManager) GetLeafCertHash() (uint64, error) { |
||||
|
leafCert := c.GetLeafCert() |
||||
|
if leafCert == nil { |
||||
|
return 0, errNoCertificateChain |
||||
|
} |
||||
|
|
||||
|
h := fnv.New64a() |
||||
|
_, err := h.Write(leafCert) |
||||
|
if err != nil { |
||||
|
return 0, err |
||||
|
} |
||||
|
return h.Sum64(), nil |
||||
|
} |
||||
|
|
||||
|
// VerifyServerProof verifies the signature of the server config
|
||||
|
// it should only be called after the certificate chain has been set, otherwise it returns false
|
||||
|
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool { |
||||
|
if len(c.chain) == 0 { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData) |
||||
|
} |
||||
|
|
||||
|
// Verify verifies the certificate chain
|
||||
|
func (c *certManager) Verify(hostname string) error { |
||||
|
if len(c.chain) == 0 { |
||||
|
return errNoCertificateChain |
||||
|
} |
||||
|
|
||||
|
if c.config != nil && c.config.InsecureSkipVerify { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
leafCert := c.chain[0] |
||||
|
|
||||
|
var opts x509.VerifyOptions |
||||
|
if c.config != nil { |
||||
|
opts.Roots = c.config.RootCAs |
||||
|
opts.DNSName = c.config.ServerName |
||||
|
if c.config.Time == nil { |
||||
|
opts.CurrentTime = time.Now() |
||||
|
} else { |
||||
|
opts.CurrentTime = c.config.Time() |
||||
|
} |
||||
|
} else { |
||||
|
opts.DNSName = hostname |
||||
|
} |
||||
|
|
||||
|
// the first certificate is the leaf certificate, all others are intermediates
|
||||
|
if len(c.chain) > 1 { |
||||
|
intermediates := x509.NewCertPool() |
||||
|
for i := 1; i < len(c.chain); i++ { |
||||
|
intermediates.AddCert(c.chain[i]) |
||||
|
} |
||||
|
opts.Intermediates = intermediates |
||||
|
} |
||||
|
|
||||
|
_, err := leafCert.Verify(opts) |
||||
|
return err |
||||
|
} |
||||
@ -1,71 +0,0 @@ |
|||||
// +build ignore
|
|
||||
|
|
||||
package crypto |
|
||||
|
|
||||
import ( |
|
||||
"crypto/rand" |
|
||||
|
|
||||
. "github.com/onsi/ginkgo" |
|
||||
. "github.com/onsi/gomega" |
|
||||
) |
|
||||
|
|
||||
var _ = Describe("Chacha20poly1305", func() { |
|
||||
var ( |
|
||||
alice, bob AEAD |
|
||||
keyAlice, keyBob, ivAlice, ivBob []byte |
|
||||
) |
|
||||
|
|
||||
BeforeEach(func() { |
|
||||
keyAlice = make([]byte, 32) |
|
||||
keyBob = make([]byte, 32) |
|
||||
ivAlice = make([]byte, 4) |
|
||||
ivBob = make([]byte, 4) |
|
||||
rand.Reader.Read(keyAlice) |
|
||||
rand.Reader.Read(keyBob) |
|
||||
rand.Reader.Read(ivAlice) |
|
||||
rand.Reader.Read(ivBob) |
|
||||
var err error |
|
||||
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice) |
|
||||
Expect(err).ToNot(HaveOccurred()) |
|
||||
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob) |
|
||||
Expect(err).ToNot(HaveOccurred()) |
|
||||
}) |
|
||||
|
|
||||
It("seals and opens", func() { |
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|
||||
text, err := bob.Open(nil, b, 42, []byte("aad")) |
|
||||
Expect(err).ToNot(HaveOccurred()) |
|
||||
Expect(text).To(Equal([]byte("foobar"))) |
|
||||
}) |
|
||||
|
|
||||
It("seals and opens reverse", func() { |
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|
||||
text, err := alice.Open(nil, b, 42, []byte("aad")) |
|
||||
Expect(err).ToNot(HaveOccurred()) |
|
||||
Expect(text).To(Equal([]byte("foobar"))) |
|
||||
}) |
|
||||
|
|
||||
It("has the proper length", func() { |
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|
||||
Expect(b).To(HaveLen(6 + 12)) |
|
||||
}) |
|
||||
|
|
||||
It("fails with wrong aad", func() { |
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|
||||
_, err := bob.Open(nil, b, 42, []byte("aad2")) |
|
||||
Expect(err).To(HaveOccurred()) |
|
||||
}) |
|
||||
|
|
||||
It("rejects wrong key and iv sizes", func() { |
|
||||
var err error |
|
||||
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs" |
|
||||
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice) |
|
||||
Expect(err).To(MatchError(e)) |
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice) |
|
||||
Expect(err).To(MatchError(e)) |
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice) |
|
||||
Expect(err).To(MatchError(e)) |
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:]) |
|
||||
Expect(err).To(MatchError(e)) |
|
||||
}) |
|
||||
}) |
|
||||
@ -1,92 +0,0 @@ |
|||||
package crypto |
|
||||
|
|
||||
import ( |
|
||||
"crypto" |
|
||||
"crypto/rand" |
|
||||
"crypto/rsa" |
|
||||
"crypto/sha256" |
|
||||
"crypto/tls" |
|
||||
"errors" |
|
||||
"strings" |
|
||||
) |
|
||||
|
|
||||
// proofSource stores a key and a certificate for the server proof
|
|
||||
type proofSource struct { |
|
||||
config *tls.Config |
|
||||
} |
|
||||
|
|
||||
// NewProofSource loads the key and cert from files
|
|
||||
func NewProofSource(tlsConfig *tls.Config) (Signer, error) { |
|
||||
return &proofSource{config: tlsConfig}, nil |
|
||||
} |
|
||||
|
|
||||
// SignServerProof signs CHLO and server config for use in the server proof
|
|
||||
func (ps *proofSource) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) { |
|
||||
cert, err := ps.getCertForSNI(sni) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
|
|
||||
hash := sha256.New() |
|
||||
hash.Write([]byte("QUIC CHLO and server config signature\x00")) |
|
||||
chloHash := sha256.Sum256(chlo) |
|
||||
hash.Write([]byte{32, 0, 0, 0}) |
|
||||
hash.Write(chloHash[:]) |
|
||||
hash.Write(serverConfigData) |
|
||||
|
|
||||
key, ok := cert.PrivateKey.(crypto.Signer) |
|
||||
if !ok { |
|
||||
return nil, errors.New("expected PrivateKey to implement crypto.Signer") |
|
||||
} |
|
||||
|
|
||||
opts := crypto.SignerOpts(crypto.SHA256) |
|
||||
|
|
||||
if _, ok = key.(*rsa.PrivateKey); ok { |
|
||||
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} |
|
||||
} |
|
||||
|
|
||||
return key.Sign(rand.Reader, hash.Sum(nil), opts) |
|
||||
} |
|
||||
|
|
||||
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
|
||||
func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { |
|
||||
cert, err := ps.getCertForSNI(sni) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes) |
|
||||
} |
|
||||
|
|
||||
// GetLeafCert gets the leaf certificate
|
|
||||
func (ps *proofSource) GetLeafCert(sni string) ([]byte, error) { |
|
||||
cert, err := ps.getCertForSNI(sni) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
return cert.Certificate[0], nil |
|
||||
} |
|
||||
|
|
||||
func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) { |
|
||||
if ps.config.GetCertificate != nil { |
|
||||
cert, err := ps.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
if cert != nil { |
|
||||
return cert, nil |
|
||||
} |
|
||||
} |
|
||||
if len(ps.config.NameToCertificate) != 0 { |
|
||||
if cert, ok := ps.config.NameToCertificate[sni]; ok { |
|
||||
return cert, nil |
|
||||
} |
|
||||
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) |
|
||||
if cert, ok := ps.config.NameToCertificate[wildcardSNI]; ok { |
|
||||
return cert, nil |
|
||||
} |
|
||||
} |
|
||||
if len(ps.config.Certificates) != 0 { |
|
||||
return &ps.config.Certificates[0], nil |
|
||||
} |
|
||||
return nil, errors.New("no matching certificate found") |
|
||||
} |
|
||||
@ -0,0 +1,66 @@ |
|||||
|
package crypto |
||||
|
|
||||
|
import ( |
||||
|
"crypto" |
||||
|
"crypto/ecdsa" |
||||
|
"crypto/rand" |
||||
|
"crypto/rsa" |
||||
|
"crypto/sha256" |
||||
|
"crypto/tls" |
||||
|
"crypto/x509" |
||||
|
"encoding/asn1" |
||||
|
"errors" |
||||
|
"math/big" |
||||
|
) |
||||
|
|
||||
|
type ecdsaSignature struct { |
||||
|
R, S *big.Int |
||||
|
} |
||||
|
|
||||
|
// signServerProof signs CHLO and server config for use in the server proof
|
||||
|
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) { |
||||
|
hash := sha256.New() |
||||
|
hash.Write([]byte("QUIC CHLO and server config signature\x00")) |
||||
|
chloHash := sha256.Sum256(chlo) |
||||
|
hash.Write([]byte{32, 0, 0, 0}) |
||||
|
hash.Write(chloHash[:]) |
||||
|
hash.Write(serverConfigData) |
||||
|
|
||||
|
key, ok := cert.PrivateKey.(crypto.Signer) |
||||
|
if !ok { |
||||
|
return nil, errors.New("expected PrivateKey to implement crypto.Signer") |
||||
|
} |
||||
|
|
||||
|
opts := crypto.SignerOpts(crypto.SHA256) |
||||
|
|
||||
|
if _, ok = key.(*rsa.PrivateKey); ok { |
||||
|
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} |
||||
|
} |
||||
|
|
||||
|
return key.Sign(rand.Reader, hash.Sum(nil), opts) |
||||
|
} |
||||
|
|
||||
|
// verifyServerProof verifies the server proof signature
|
||||
|
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool { |
||||
|
hash := sha256.New() |
||||
|
hash.Write([]byte("QUIC CHLO and server config signature\x00")) |
||||
|
chloHash := sha256.Sum256(chlo) |
||||
|
hash.Write([]byte{32, 0, 0, 0}) |
||||
|
hash.Write(chloHash[:]) |
||||
|
hash.Write(serverConfigData) |
||||
|
|
||||
|
// RSA
|
||||
|
if cert.PublicKeyAlgorithm == x509.RSA { |
||||
|
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} |
||||
|
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts) |
||||
|
return err == nil |
||||
|
} |
||||
|
|
||||
|
// ECDSA
|
||||
|
signature := &ecdsaSignature{} |
||||
|
rest, err := asn1.Unmarshal(proof, signature) |
||||
|
if err != nil || len(rest) != 0 { |
||||
|
return false |
||||
|
} |
||||
|
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S) |
||||
|
} |
||||
@ -1,8 +0,0 @@ |
|||||
package crypto |
|
||||
|
|
||||
// A Signer holds a certificate and a private key
|
|
||||
type Signer interface { |
|
||||
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) |
|
||||
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error) |
|
||||
GetLeafCert(sni string) ([]byte, error) |
|
||||
} |
|
||||
@ -0,0 +1,293 @@ |
|||||
|
package h2quic |
||||
|
|
||||
|
import ( |
||||
|
"crypto/tls" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"net" |
||||
|
"net/http" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
|
||||
|
"golang.org/x/net/http2" |
||||
|
"golang.org/x/net/http2/hpack" |
||||
|
"golang.org/x/net/idna" |
||||
|
|
||||
|
quic "github.com/lucas-clemente/quic-go" |
||||
|
"github.com/lucas-clemente/quic-go/protocol" |
||||
|
"github.com/lucas-clemente/quic-go/qerr" |
||||
|
"github.com/lucas-clemente/quic-go/utils" |
||||
|
) |
||||
|
|
||||
|
type quicClient interface { |
||||
|
OpenStream(protocol.StreamID) (utils.Stream, error) |
||||
|
Close(error) error |
||||
|
Listen() error |
||||
|
} |
||||
|
|
||||
|
// Client is a HTTP2 client doing QUIC requests
|
||||
|
type Client struct { |
||||
|
mutex sync.RWMutex |
||||
|
cryptoChangedCond sync.Cond |
||||
|
|
||||
|
t *QuicRoundTripper |
||||
|
|
||||
|
hostname string |
||||
|
encryptionLevel protocol.EncryptionLevel |
||||
|
|
||||
|
client quicClient |
||||
|
headerStream utils.Stream |
||||
|
headerErr *qerr.QuicError |
||||
|
highestOpenedStream protocol.StreamID |
||||
|
requestWriter *requestWriter |
||||
|
|
||||
|
responses map[protocol.StreamID]chan *http.Response |
||||
|
} |
||||
|
|
||||
|
var _ h2quicClient = &Client{} |
||||
|
|
||||
|
// NewClient creates a new client
|
||||
|
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) { |
||||
|
c := &Client{ |
||||
|
t: t, |
||||
|
hostname: authorityAddr("https", hostname), |
||||
|
highestOpenedStream: 3, |
||||
|
responses: make(map[protocol.StreamID]chan *http.Response), |
||||
|
} |
||||
|
c.cryptoChangedCond = sync.Cond{L: &c.mutex} |
||||
|
|
||||
|
var err error |
||||
|
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
go c.client.Listen() |
||||
|
return c, nil |
||||
|
} |
||||
|
|
||||
|
func (c *Client) handleStreamCb(session *quic.Session, stream utils.Stream) { |
||||
|
utils.Debugf("Handling stream %d", stream.StreamID()) |
||||
|
} |
||||
|
|
||||
|
func (c *Client) cryptoChangeCallback(isForwardSecure bool) { |
||||
|
c.cryptoChangedCond.L.Lock() |
||||
|
defer c.cryptoChangedCond.L.Unlock() |
||||
|
|
||||
|
if isForwardSecure { |
||||
|
c.encryptionLevel = protocol.EncryptionForwardSecure |
||||
|
utils.Debugf("is forward secure") |
||||
|
} else { |
||||
|
c.encryptionLevel = protocol.EncryptionSecure |
||||
|
utils.Debugf("is secure") |
||||
|
} |
||||
|
c.cryptoChangedCond.Broadcast() |
||||
|
} |
||||
|
|
||||
|
func (c *Client) versionNegotiateCallback() error { |
||||
|
var err error |
||||
|
// once the version has been negotiated, open the header stream
|
||||
|
c.headerStream, err = c.client.OpenStream(3) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
c.requestWriter = newRequestWriter(c.headerStream) |
||||
|
go c.handleHeaderStream() |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (c *Client) handleHeaderStream() { |
||||
|
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) |
||||
|
h2framer := http2.NewFramer(nil, c.headerStream) |
||||
|
|
||||
|
var lastStream protocol.StreamID |
||||
|
|
||||
|
for { |
||||
|
frame, err := h2framer.ReadFrame() |
||||
|
if err != nil { |
||||
|
c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame") |
||||
|
break |
||||
|
} |
||||
|
lastStream = protocol.StreamID(frame.Header().StreamID) |
||||
|
hframe, ok := frame.(*http2.HeadersFrame) |
||||
|
if !ok { |
||||
|
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") |
||||
|
break |
||||
|
} |
||||
|
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} |
||||
|
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) |
||||
|
if err != nil { |
||||
|
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") |
||||
|
break |
||||
|
} |
||||
|
|
||||
|
c.mutex.RLock() |
||||
|
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] |
||||
|
c.mutex.RUnlock() |
||||
|
if !ok { |
||||
|
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) |
||||
|
break |
||||
|
} |
||||
|
|
||||
|
rsp, err := responseFromHeaders(mhframe) |
||||
|
if err != nil { |
||||
|
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) |
||||
|
} |
||||
|
headerChan <- rsp |
||||
|
} |
||||
|
|
||||
|
// stop all running request
|
||||
|
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) |
||||
|
c.mutex.Lock() |
||||
|
for _, responseChan := range c.responses { |
||||
|
responseChan <- nil |
||||
|
} |
||||
|
c.mutex.Unlock() |
||||
|
} |
||||
|
|
||||
|
// Do executes a request and returns a response
|
||||
|
func (c *Client) Do(req *http.Request) (*http.Response, error) { |
||||
|
// TODO: add port to address, if it doesn't have one
|
||||
|
if req.URL.Scheme != "https" { |
||||
|
return nil, errors.New("quic http2: unsupported scheme") |
||||
|
} |
||||
|
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { |
||||
|
utils.Debugf("%s vs %s", req.Host, c.hostname) |
||||
|
return nil, errors.New("h2quic Client BUG: Do called for the wrong client") |
||||
|
} |
||||
|
|
||||
|
hasBody := (req.Body != nil) |
||||
|
|
||||
|
c.mutex.Lock() |
||||
|
c.highestOpenedStream += 2 |
||||
|
dataStreamID := c.highestOpenedStream |
||||
|
for c.encryptionLevel != protocol.EncryptionForwardSecure { |
||||
|
c.cryptoChangedCond.Wait() |
||||
|
} |
||||
|
hdrChan := make(chan *http.Response) |
||||
|
c.responses[dataStreamID] = hdrChan |
||||
|
c.mutex.Unlock() |
||||
|
|
||||
|
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
|
||||
|
dataStream, err := c.client.OpenStream(dataStreamID) |
||||
|
if err != nil { |
||||
|
c.Close(err) |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
var requestedGzip bool |
||||
|
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { |
||||
|
requestedGzip = true |
||||
|
} |
||||
|
// TODO: add support for trailers
|
||||
|
endStream := !hasBody |
||||
|
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip) |
||||
|
if err != nil { |
||||
|
c.Close(err) |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
resc := make(chan error, 1) |
||||
|
if hasBody { |
||||
|
go func() { |
||||
|
resc <- c.writeRequestBody(dataStream, req.Body) |
||||
|
}() |
||||
|
} |
||||
|
|
||||
|
var res *http.Response |
||||
|
|
||||
|
var receivedResponse bool |
||||
|
var bodySent bool |
||||
|
|
||||
|
if !hasBody { |
||||
|
bodySent = true |
||||
|
} |
||||
|
|
||||
|
for !(bodySent && receivedResponse) { |
||||
|
select { |
||||
|
case res = <-hdrChan: |
||||
|
receivedResponse = true |
||||
|
c.mutex.Lock() |
||||
|
delete(c.responses, dataStreamID) |
||||
|
c.mutex.Unlock() |
||||
|
if res == nil { // an error occured on the header stream
|
||||
|
c.Close(c.headerErr) |
||||
|
return nil, c.headerErr |
||||
|
} |
||||
|
case err := <-resc: |
||||
|
bodySent = true |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TODO: correctly set this variable
|
||||
|
var streamEnded bool |
||||
|
isHead := (req.Method == "HEAD") |
||||
|
|
||||
|
res = setLength(res, isHead, streamEnded) |
||||
|
|
||||
|
if streamEnded || isHead { |
||||
|
res.Body = noBody |
||||
|
} else { |
||||
|
res.Body = dataStream |
||||
|
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { |
||||
|
res.Header.Del("Content-Encoding") |
||||
|
res.Header.Del("Content-Length") |
||||
|
res.ContentLength = -1 |
||||
|
res.Body = &gzipReader{body: res.Body} |
||||
|
setUncompressed(res) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
res.Request = req |
||||
|
|
||||
|
return res, nil |
||||
|
} |
||||
|
|
||||
|
func (c *Client) writeRequestBody(dataStream utils.Stream, body io.ReadCloser) (err error) { |
||||
|
defer func() { |
||||
|
cerr := body.Close() |
||||
|
if err == nil { |
||||
|
// TODO: what to do with dataStream here? Maybe reset it?
|
||||
|
err = cerr |
||||
|
} |
||||
|
}() |
||||
|
|
||||
|
_, err = io.Copy(dataStream, body) |
||||
|
if err != nil { |
||||
|
// TODO: what to do with dataStream here? Maybe reset it?
|
||||
|
return err |
||||
|
} |
||||
|
return dataStream.Close() |
||||
|
} |
||||
|
|
||||
|
// Close closes the client
|
||||
|
func (c *Client) Close(e error) { |
||||
|
_ = c.client.Close(e) |
||||
|
} |
||||
|
|
||||
|
// copied from net/transport.go
|
||||
|
|
||||
|
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
|
// and returns a host:port. The port 443 is added if needed.
|
||||
|
func authorityAddr(scheme string, authority string) (addr string) { |
||||
|
host, port, err := net.SplitHostPort(authority) |
||||
|
if err != nil { // authority didn't have a port
|
||||
|
port = "443" |
||||
|
if scheme == "http" { |
||||
|
port = "80" |
||||
|
} |
||||
|
host = authority |
||||
|
} |
||||
|
if a, err := idna.ToASCII(host); err == nil { |
||||
|
host = a |
||||
|
} |
||||
|
// IPv6 address literal, without a port:
|
||||
|
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { |
||||
|
return host + ":" + port |
||||
|
} |
||||
|
return net.JoinHostPort(host, port) |
||||
|
} |
||||
@ -0,0 +1,35 @@ |
|||||
|
package h2quic |
||||
|
|
||||
|
// copied from net/transport.go
|
||||
|
|
||||
|
// gzipReader wraps a response body so it can lazily
|
||||
|
// call gzip.NewReader on the first call to Read
|
||||
|
import ( |
||||
|
"compress/gzip" |
||||
|
"io" |
||||
|
) |
||||
|
|
||||
|
// call gzip.NewReader on the first call to Read
|
||||
|
type gzipReader struct { |
||||
|
body io.ReadCloser // underlying Response.Body
|
||||
|
zr *gzip.Reader // lazily-initialized gzip reader
|
||||
|
zerr error // sticky error
|
||||
|
} |
||||
|
|
||||
|
func (gz *gzipReader) Read(p []byte) (n int, err error) { |
||||
|
if gz.zerr != nil { |
||||
|
return 0, gz.zerr |
||||
|
} |
||||
|
if gz.zr == nil { |
||||
|
gz.zr, err = gzip.NewReader(gz.body) |
||||
|
if err != nil { |
||||
|
gz.zerr = err |
||||
|
return 0, err |
||||
|
} |
||||
|
} |
||||
|
return gz.zr.Read(p) |
||||
|
} |
||||
|
|
||||
|
func (gz *gzipReader) Close() error { |
||||
|
return gz.body.Close() |
||||
|
} |
||||
@ -0,0 +1,29 @@ |
|||||
|
package h2quic |
||||
|
|
||||
|
import ( |
||||
|
"io" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/utils" |
||||
|
) |
||||
|
|
||||
|
type requestBody struct { |
||||
|
requestRead bool |
||||
|
dataStream utils.Stream |
||||
|
} |
||||
|
|
||||
|
// make sure the requestBody can be used as a http.Request.Body
|
||||
|
var _ io.ReadCloser = &requestBody{} |
||||
|
|
||||
|
func newRequestBody(stream utils.Stream) *requestBody { |
||||
|
return &requestBody{dataStream: stream} |
||||
|
} |
||||
|
|
||||
|
func (b *requestBody) Read(p []byte) (int, error) { |
||||
|
b.requestRead = true |
||||
|
return b.dataStream.Read(p) |
||||
|
} |
||||
|
|
||||
|
func (b *requestBody) Close() error { |
||||
|
// stream's Close() closes the write side, not the read side
|
||||
|
return nil |
||||
|
} |
||||
@ -0,0 +1,200 @@ |
|||||
|
package h2quic |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"fmt" |
||||
|
"net/http" |
||||
|
"strconv" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
|
||||
|
"golang.org/x/net/http2" |
||||
|
"golang.org/x/net/http2/hpack" |
||||
|
"golang.org/x/net/lex/httplex" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/protocol" |
||||
|
"github.com/lucas-clemente/quic-go/utils" |
||||
|
) |
||||
|
|
||||
|
type requestWriter struct { |
||||
|
mutex sync.Mutex |
||||
|
headerStream utils.Stream |
||||
|
|
||||
|
henc *hpack.Encoder |
||||
|
hbuf bytes.Buffer // HPACK encoder writes into this
|
||||
|
} |
||||
|
|
||||
|
const defaultUserAgent = "quic-go" |
||||
|
|
||||
|
func newRequestWriter(headerStream utils.Stream) *requestWriter { |
||||
|
rw := &requestWriter{ |
||||
|
headerStream: headerStream, |
||||
|
} |
||||
|
rw.henc = hpack.NewEncoder(&rw.hbuf) |
||||
|
return rw |
||||
|
} |
||||
|
|
||||
|
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error { |
||||
|
// TODO: add support for trailers
|
||||
|
// TODO: add support for gzip compression
|
||||
|
// TODO: write continuation frames, if the header frame is too long
|
||||
|
|
||||
|
w.mutex.Lock() |
||||
|
defer w.mutex.Unlock() |
||||
|
|
||||
|
w.encodeHeaders(req, requestGzip, "", actualContentLength(req)) |
||||
|
h2framer := http2.NewFramer(w.headerStream, nil) |
||||
|
return h2framer.WriteHeaders(http2.HeadersFrameParam{ |
||||
|
StreamID: uint32(dataStreamID), |
||||
|
EndHeaders: true, |
||||
|
EndStream: endStream, |
||||
|
BlockFragment: w.hbuf.Bytes(), |
||||
|
Priority: http2.PriorityParam{Weight: 0xff}, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// the rest of this files is copied from http2.Transport
|
||||
|
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { |
||||
|
w.hbuf.Reset() |
||||
|
|
||||
|
host := req.Host |
||||
|
if host == "" { |
||||
|
host = req.URL.Host |
||||
|
} |
||||
|
host, err := httplex.PunycodeHostPort(host) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
var path string |
||||
|
if req.Method != "CONNECT" { |
||||
|
path = req.URL.RequestURI() |
||||
|
if !validPseudoPath(path) { |
||||
|
orig := path |
||||
|
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) |
||||
|
if !validPseudoPath(path) { |
||||
|
if req.URL.Opaque != "" { |
||||
|
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) |
||||
|
} else { |
||||
|
return nil, fmt.Errorf("invalid request :path %q", orig) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Check for any invalid headers and return an error before we
|
||||
|
// potentially pollute our hpack state. (We want to be able to
|
||||
|
// continue to reuse the hpack encoder for future requests)
|
||||
|
for k, vv := range req.Header { |
||||
|
if !httplex.ValidHeaderFieldName(k) { |
||||
|
return nil, fmt.Errorf("invalid HTTP header name %q", k) |
||||
|
} |
||||
|
for _, v := range vv { |
||||
|
if !httplex.ValidHeaderFieldValue(v) { |
||||
|
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// 8.1.2.3 Request Pseudo-Header Fields
|
||||
|
// The :path pseudo-header field includes the path and query parts of the
|
||||
|
// target URI (the path-absolute production and optionally a '?' character
|
||||
|
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||
|
// [RFC3986]).
|
||||
|
w.writeHeader(":authority", host) |
||||
|
w.writeHeader(":method", req.Method) |
||||
|
if req.Method != "CONNECT" { |
||||
|
w.writeHeader(":path", path) |
||||
|
w.writeHeader(":scheme", req.URL.Scheme) |
||||
|
} |
||||
|
if trailers != "" { |
||||
|
w.writeHeader("trailer", trailers) |
||||
|
} |
||||
|
|
||||
|
var didUA bool |
||||
|
for k, vv := range req.Header { |
||||
|
lowKey := strings.ToLower(k) |
||||
|
switch lowKey { |
||||
|
case "host", "content-length": |
||||
|
// Host is :authority, already sent.
|
||||
|
// Content-Length is automatic, set below.
|
||||
|
continue |
||||
|
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": |
||||
|
// Per 8.1.2.2 Connection-Specific Header
|
||||
|
// Fields, don't send connection-specific
|
||||
|
// fields. We have already checked if any
|
||||
|
// are error-worthy so just ignore the rest.
|
||||
|
continue |
||||
|
case "user-agent": |
||||
|
// Match Go's http1 behavior: at most one
|
||||
|
// User-Agent. If set to nil or empty string,
|
||||
|
// then omit it. Otherwise if not mentioned,
|
||||
|
// include the default (below).
|
||||
|
didUA = true |
||||
|
if len(vv) < 1 { |
||||
|
continue |
||||
|
} |
||||
|
vv = vv[:1] |
||||
|
if vv[0] == "" { |
||||
|
continue |
||||
|
} |
||||
|
} |
||||
|
for _, v := range vv { |
||||
|
w.writeHeader(lowKey, v) |
||||
|
} |
||||
|
} |
||||
|
if shouldSendReqContentLength(req.Method, contentLength) { |
||||
|
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) |
||||
|
} |
||||
|
if addGzipHeader { |
||||
|
w.writeHeader("accept-encoding", "gzip") |
||||
|
} |
||||
|
if !didUA { |
||||
|
w.writeHeader("user-agent", defaultUserAgent) |
||||
|
} |
||||
|
return w.hbuf.Bytes(), nil |
||||
|
} |
||||
|
|
||||
|
func (w *requestWriter) writeHeader(name, value string) { |
||||
|
utils.Debugf("http2: Transport encoding header %q = %q", name, value) |
||||
|
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) |
||||
|
} |
||||
|
|
||||
|
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||
|
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||
|
// transferWriter.shouldSendContentLength.
|
||||
|
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||
|
// -1 means unknown.
|
||||
|
func shouldSendReqContentLength(method string, contentLength int64) bool { |
||||
|
if contentLength > 0 { |
||||
|
return true |
||||
|
} |
||||
|
if contentLength < 0 { |
||||
|
return false |
||||
|
} |
||||
|
// For zero bodies, whether we send a content-length depends on the method.
|
||||
|
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||
|
switch method { |
||||
|
case "POST", "PUT", "PATCH": |
||||
|
return true |
||||
|
default: |
||||
|
return false |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func validPseudoPath(v string) bool { |
||||
|
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" |
||||
|
} |
||||
|
|
||||
|
// actualContentLength returns a sanitized version of
|
||||
|
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||
|
// means unknown.
|
||||
|
func actualContentLength(req *http.Request) int64 { |
||||
|
if req.Body == nil { |
||||
|
return 0 |
||||
|
} |
||||
|
if req.ContentLength != 0 { |
||||
|
return req.ContentLength |
||||
|
} |
||||
|
return -1 |
||||
|
} |
||||
@ -0,0 +1,111 @@ |
|||||
|
package h2quic |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"errors" |
||||
|
"io" |
||||
|
"io/ioutil" |
||||
|
"net/http" |
||||
|
"net/textproto" |
||||
|
"strconv" |
||||
|
"strings" |
||||
|
|
||||
|
"golang.org/x/net/http2" |
||||
|
) |
||||
|
|
||||
|
// copied from net/http2/transport.go
|
||||
|
|
||||
|
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") |
||||
|
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) |
||||
|
|
||||
|
// from the handleResponse function
|
||||
|
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { |
||||
|
if f.Truncated { |
||||
|
return nil, errResponseHeaderListSize |
||||
|
} |
||||
|
|
||||
|
status := f.PseudoValue("status") |
||||
|
if status == "" { |
||||
|
return nil, errors.New("missing status pseudo header") |
||||
|
} |
||||
|
statusCode, err := strconv.Atoi(status) |
||||
|
if err != nil { |
||||
|
return nil, errors.New("malformed non-numeric status pseudo header") |
||||
|
} |
||||
|
|
||||
|
if statusCode == 100 { |
||||
|
// TODO: handle this
|
||||
|
|
||||
|
// traceGot100Continue(cs.trace)
|
||||
|
// if cs.on100 != nil {
|
||||
|
// cs.on100() // forces any write delay timer to fire
|
||||
|
// }
|
||||
|
// cs.pastHeaders = false // do it all again
|
||||
|
// return nil, nil
|
||||
|
} |
||||
|
|
||||
|
header := make(http.Header) |
||||
|
res := &http.Response{ |
||||
|
Proto: "HTTP/2.0", |
||||
|
ProtoMajor: 2, |
||||
|
Header: header, |
||||
|
StatusCode: statusCode, |
||||
|
Status: status + " " + http.StatusText(statusCode), |
||||
|
} |
||||
|
for _, hf := range f.RegularFields() { |
||||
|
key := http.CanonicalHeaderKey(hf.Name) |
||||
|
if key == "Trailer" { |
||||
|
t := res.Trailer |
||||
|
if t == nil { |
||||
|
t = make(http.Header) |
||||
|
res.Trailer = t |
||||
|
} |
||||
|
foreachHeaderElement(hf.Value, func(v string) { |
||||
|
t[http.CanonicalHeaderKey(v)] = nil |
||||
|
}) |
||||
|
} else { |
||||
|
header[key] = append(header[key], hf.Value) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return res, nil |
||||
|
} |
||||
|
|
||||
|
// continuation of the handleResponse function
|
||||
|
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { |
||||
|
if !streamEnded || isHead { |
||||
|
res.ContentLength = -1 |
||||
|
if clens := res.Header["Content-Length"]; len(clens) == 1 { |
||||
|
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { |
||||
|
res.ContentLength = clen64 |
||||
|
} else { |
||||
|
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
||||
|
// more safe smuggling-wise to ignore.
|
||||
|
} |
||||
|
} else if len(clens) > 1 { |
||||
|
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
||||
|
// more safe smuggling-wise to ignore.
|
||||
|
} |
||||
|
} |
||||
|
return res |
||||
|
} |
||||
|
|
||||
|
// copied from net/http/server.go
|
||||
|
|
||||
|
// foreachHeaderElement splits v according to the "#rule" construction
|
||||
|
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
|
||||
|
func foreachHeaderElement(v string, fn func(string)) { |
||||
|
v = textproto.TrimString(v) |
||||
|
if v == "" { |
||||
|
return |
||||
|
} |
||||
|
if !strings.Contains(v, ",") { |
||||
|
fn(v) |
||||
|
return |
||||
|
} |
||||
|
for _, f := range strings.Split(v, ",") { |
||||
|
if f = textproto.TrimString(f); f != "" { |
||||
|
fn(f) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,9 @@ |
|||||
|
// +build go1.7
|
||||
|
|
||||
|
package h2quic |
||||
|
|
||||
|
import "net/http" |
||||
|
|
||||
|
func setUncompressed(res *http.Response) { |
||||
|
res.Uncompressed = true |
||||
|
} |
||||
@ -0,0 +1,9 @@ |
|||||
|
// +build !go1.7
|
||||
|
|
||||
|
package h2quic |
||||
|
|
||||
|
import "net/http" |
||||
|
|
||||
|
func setUncompressed(res *http.Response) { |
||||
|
// http.Response.Uncompressed was introduced in go 1.7
|
||||
|
} |
||||
@ -0,0 +1,135 @@ |
|||||
|
package h2quic |
||||
|
|
||||
|
import ( |
||||
|
"crypto/tls" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"net/http" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
|
||||
|
"golang.org/x/net/lex/httplex" |
||||
|
) |
||||
|
|
||||
|
type h2quicClient interface { |
||||
|
Do(*http.Request) (*http.Response, error) |
||||
|
} |
||||
|
|
||||
|
// QuicRoundTripper implements the http.RoundTripper interface
|
||||
|
type QuicRoundTripper struct { |
||||
|
mutex sync.Mutex |
||||
|
|
||||
|
// DisableCompression, if true, prevents the Transport from
|
||||
|
// requesting compression with an "Accept-Encoding: gzip"
|
||||
|
// request header when the Request contains no existing
|
||||
|
// Accept-Encoding value. If the Transport requests gzip on
|
||||
|
// its own and gets a gzipped response, it's transparently
|
||||
|
// decoded in the Response.Body. However, if the user
|
||||
|
// explicitly requested gzip it is not automatically
|
||||
|
// uncompressed.
|
||||
|
DisableCompression bool |
||||
|
|
||||
|
// TLSClientConfig specifies the TLS configuration to use with
|
||||
|
// tls.Client. If nil, the default configuration is used.
|
||||
|
TLSClientConfig *tls.Config |
||||
|
|
||||
|
clients map[string]h2quicClient |
||||
|
} |
||||
|
|
||||
|
var _ http.RoundTripper = &QuicRoundTripper{} |
||||
|
|
||||
|
// RoundTrip does a round trip
|
||||
|
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
||||
|
if req.URL == nil { |
||||
|
closeRequestBody(req) |
||||
|
return nil, errors.New("quic: nil Request.URL") |
||||
|
} |
||||
|
if req.URL.Host == "" { |
||||
|
closeRequestBody(req) |
||||
|
return nil, errors.New("quic: no Host in request URL") |
||||
|
} |
||||
|
if req.Header == nil { |
||||
|
closeRequestBody(req) |
||||
|
return nil, errors.New("quic: nil Request.Header") |
||||
|
} |
||||
|
|
||||
|
if req.URL.Scheme == "https" { |
||||
|
for k, vv := range req.Header { |
||||
|
if !httplex.ValidHeaderFieldName(k) { |
||||
|
return nil, fmt.Errorf("quic: invalid http header field name %q", k) |
||||
|
} |
||||
|
for _, v := range vv { |
||||
|
if !httplex.ValidHeaderFieldValue(v) { |
||||
|
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
} else { |
||||
|
closeRequestBody(req) |
||||
|
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) |
||||
|
} |
||||
|
|
||||
|
if req.Method != "" && !validMethod(req.Method) { |
||||
|
closeRequestBody(req) |
||||
|
return nil, fmt.Errorf("quic: invalid method %q", req.Method) |
||||
|
} |
||||
|
|
||||
|
hostname := authorityAddr("https", hostnameFromRequest(req)) |
||||
|
client, err := r.getClient(hostname) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return client.Do(req) |
||||
|
} |
||||
|
|
||||
|
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { |
||||
|
r.mutex.Lock() |
||||
|
defer r.mutex.Unlock() |
||||
|
|
||||
|
if r.clients == nil { |
||||
|
r.clients = make(map[string]h2quicClient) |
||||
|
} |
||||
|
|
||||
|
client, ok := r.clients[hostname] |
||||
|
if !ok { |
||||
|
var err error |
||||
|
client, err = NewClient(r, r.TLSClientConfig, hostname) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
r.clients[hostname] = client |
||||
|
} |
||||
|
return client, nil |
||||
|
} |
||||
|
|
||||
|
func (r *QuicRoundTripper) disableCompression() bool { |
||||
|
return r.DisableCompression |
||||
|
} |
||||
|
|
||||
|
func closeRequestBody(req *http.Request) { |
||||
|
if req.Body != nil { |
||||
|
req.Body.Close() |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func validMethod(method string) bool { |
||||
|
/* |
||||
|
Method = "OPTIONS" ; Section 9.2 |
||||
|
| "GET" ; Section 9.3 |
||||
|
| "HEAD" ; Section 9.4 |
||||
|
| "POST" ; Section 9.5 |
||||
|
| "PUT" ; Section 9.6 |
||||
|
| "DELETE" ; Section 9.7 |
||||
|
| "TRACE" ; Section 9.8 |
||||
|
| "CONNECT" ; Section 9.9 |
||||
|
| extension-method |
||||
|
extension-method = token |
||||
|
token = 1*<any CHAR except CTLs or separators> |
||||
|
*/ |
||||
|
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 |
||||
|
} |
||||
|
|
||||
|
// copied from net/http/http.go
|
||||
|
func isNotToken(r rune) bool { |
||||
|
return !httplex.IsTokenRune(r) |
||||
|
} |
||||
@ -0,0 +1,485 @@ |
|||||
|
package handshake |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"crypto/rand" |
||||
|
"crypto/tls" |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"sync" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/crypto" |
||||
|
"github.com/lucas-clemente/quic-go/protocol" |
||||
|
"github.com/lucas-clemente/quic-go/qerr" |
||||
|
"github.com/lucas-clemente/quic-go/utils" |
||||
|
) |
||||
|
|
||||
|
type cryptoSetupClient struct { |
||||
|
mutex sync.RWMutex |
||||
|
|
||||
|
hostname string |
||||
|
connID protocol.ConnectionID |
||||
|
version protocol.VersionNumber |
||||
|
negotiatedVersions []protocol.VersionNumber |
||||
|
|
||||
|
cryptoStream utils.Stream |
||||
|
|
||||
|
serverConfig *serverConfigClient |
||||
|
|
||||
|
stk []byte |
||||
|
sno []byte |
||||
|
nonc []byte |
||||
|
proof []byte |
||||
|
diversificationNonce []byte |
||||
|
chloForSignature []byte |
||||
|
lastSentCHLO []byte |
||||
|
certManager crypto.CertManager |
||||
|
|
||||
|
clientHelloCounter int |
||||
|
serverVerified bool // has the certificate chain and the proof already been verified
|
||||
|
keyDerivation KeyDerivationFunction |
||||
|
|
||||
|
receivedSecurePacket bool |
||||
|
secureAEAD crypto.AEAD |
||||
|
forwardSecureAEAD crypto.AEAD |
||||
|
aeadChanged chan struct{} |
||||
|
|
||||
|
connectionParameters ConnectionParametersManager |
||||
|
} |
||||
|
|
||||
|
var _ crypto.AEAD = &cryptoSetupClient{} |
||||
|
var _ CryptoSetup = &cryptoSetupClient{} |
||||
|
|
||||
|
var ( |
||||
|
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available") |
||||
|
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated") |
||||
|
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces") |
||||
|
) |
||||
|
|
||||
|
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
||||
|
func NewCryptoSetupClient( |
||||
|
hostname string, |
||||
|
connID protocol.ConnectionID, |
||||
|
version protocol.VersionNumber, |
||||
|
cryptoStream utils.Stream, |
||||
|
tlsConfig *tls.Config, |
||||
|
connectionParameters ConnectionParametersManager, |
||||
|
aeadChanged chan struct{}, |
||||
|
negotiatedVersions []protocol.VersionNumber, |
||||
|
) (CryptoSetup, error) { |
||||
|
return &cryptoSetupClient{ |
||||
|
hostname: hostname, |
||||
|
connID: connID, |
||||
|
version: version, |
||||
|
cryptoStream: cryptoStream, |
||||
|
certManager: crypto.NewCertManager(tlsConfig), |
||||
|
connectionParameters: connectionParameters, |
||||
|
keyDerivation: crypto.DeriveKeysAESGCM, |
||||
|
aeadChanged: aeadChanged, |
||||
|
negotiatedVersions: negotiatedVersions, |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) HandleCryptoStream() error { |
||||
|
for { |
||||
|
err := h.maybeUpgradeCrypto() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// send CHLOs until the forward secure encryption is established
|
||||
|
if h.forwardSecureAEAD == nil { |
||||
|
err = h.sendCHLO() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
var shloData bytes.Buffer |
||||
|
|
||||
|
messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &shloData)) |
||||
|
if err != nil { |
||||
|
return qerr.HandshakeFailed |
||||
|
} |
||||
|
|
||||
|
if messageTag != TagSHLO && messageTag != TagREJ { |
||||
|
return qerr.InvalidCryptoMessageType |
||||
|
} |
||||
|
|
||||
|
if messageTag == TagSHLO { |
||||
|
utils.Debugf("Got SHLO:\n%s", printHandshakeMessage(cryptoData)) |
||||
|
err = h.handleSHLOMessage(cryptoData) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if messageTag == TagREJ { |
||||
|
err = h.handleREJMessage(cryptoData) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { |
||||
|
utils.Debugf("Got REJ:\n%s", printHandshakeMessage(cryptoData)) |
||||
|
|
||||
|
var err error |
||||
|
|
||||
|
if stk, ok := cryptoData[TagSTK]; ok { |
||||
|
h.stk = stk |
||||
|
} |
||||
|
|
||||
|
if sno, ok := cryptoData[TagSNO]; ok { |
||||
|
h.sno = sno |
||||
|
} |
||||
|
|
||||
|
// TODO: what happens if the server sends a different server config in two packets?
|
||||
|
if scfg, ok := cryptoData[TagSCFG]; ok { |
||||
|
h.serverConfig, err = parseServerConfig(scfg) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
if h.serverConfig.IsExpired() { |
||||
|
return qerr.CryptoServerConfigExpired |
||||
|
} |
||||
|
|
||||
|
// now that we have a server config, we can use its OBIT value to generate a client nonce
|
||||
|
if len(h.nonc) == 0 { |
||||
|
err = h.generateClientNonce() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if proof, ok := cryptoData[TagPROF]; ok { |
||||
|
h.proof = proof |
||||
|
h.chloForSignature = h.lastSentCHLO |
||||
|
} |
||||
|
|
||||
|
if crt, ok := cryptoData[TagCERT]; ok { |
||||
|
err := h.certManager.SetData(crt) |
||||
|
if err != nil { |
||||
|
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid") |
||||
|
} |
||||
|
|
||||
|
err = h.certManager.Verify(h.hostname) |
||||
|
if err != nil { |
||||
|
utils.Infof("Certificate validation failed: %s", err.Error()) |
||||
|
return qerr.ProofInvalid |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil { |
||||
|
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get()) |
||||
|
if !validProof { |
||||
|
utils.Infof("Server proof verification failed") |
||||
|
return qerr.ProofInvalid |
||||
|
} |
||||
|
|
||||
|
h.serverVerified = true |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { |
||||
|
h.mutex.Lock() |
||||
|
defer h.mutex.Unlock() |
||||
|
|
||||
|
if !h.receivedSecurePacket { |
||||
|
return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") |
||||
|
} |
||||
|
|
||||
|
if sno, ok := cryptoData[TagSNO]; ok { |
||||
|
h.sno = sno |
||||
|
} |
||||
|
|
||||
|
serverPubs, ok := cryptoData[TagPUBS] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") |
||||
|
} |
||||
|
|
||||
|
verTag, ok := cryptoData[TagVER] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") |
||||
|
} |
||||
|
if !h.validateVersionList(verTag) { |
||||
|
return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") |
||||
|
} |
||||
|
|
||||
|
nonce := append(h.nonc, h.sno...) |
||||
|
|
||||
|
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
leafCert := h.certManager.GetLeafCert() |
||||
|
|
||||
|
h.forwardSecureAEAD, err = h.keyDerivation( |
||||
|
true, |
||||
|
ephermalSharedSecret, |
||||
|
nonce, |
||||
|
h.connID, |
||||
|
h.lastSentCHLO, |
||||
|
h.serverConfig.Get(), |
||||
|
leafCert, |
||||
|
nil, |
||||
|
protocol.PerspectiveClient, |
||||
|
) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
err = h.connectionParameters.SetFromMap(cryptoData) |
||||
|
if err != nil { |
||||
|
return qerr.InvalidCryptoMessageParameter |
||||
|
} |
||||
|
|
||||
|
h.aeadChanged <- struct{}{} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { |
||||
|
if len(h.negotiatedVersions) == 0 { |
||||
|
return true |
||||
|
} |
||||
|
if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
b := bytes.NewReader(verTags) |
||||
|
for _, negotiatedVersion := range h.negotiatedVersions { |
||||
|
verTag, err := utils.ReadUint32(b) |
||||
|
if err != nil { // should never occur, since the length was already checked
|
||||
|
return false |
||||
|
} |
||||
|
ver := protocol.VersionTagToNumber(verTag) |
||||
|
if !protocol.IsSupportedVersion(ver) { |
||||
|
ver = protocol.VersionUnsupported |
||||
|
} |
||||
|
if ver != negotiatedVersion { |
||||
|
return false |
||||
|
} |
||||
|
} |
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { |
||||
|
if h.forwardSecureAEAD != nil { |
||||
|
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) |
||||
|
if err == nil { |
||||
|
return data, nil |
||||
|
} |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
if h.secureAEAD != nil { |
||||
|
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) |
||||
|
if err == nil { |
||||
|
h.receivedSecurePacket = true |
||||
|
return data, nil |
||||
|
} |
||||
|
if h.receivedSecurePacket { |
||||
|
return nil, err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return (&crypto.NullAEAD{}).Open(dst, src, packetNumber, associatedData) |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { |
||||
|
if h.forwardSecureAEAD != nil { |
||||
|
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) |
||||
|
} |
||||
|
if h.secureAEAD != nil { |
||||
|
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) |
||||
|
} |
||||
|
return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData) |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) DiversificationNonce() []byte { |
||||
|
panic("not needed for cryptoSetupClient") |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error { |
||||
|
if len(h.diversificationNonce) == 0 { |
||||
|
h.diversificationNonce = data |
||||
|
return h.maybeUpgradeCrypto() |
||||
|
} |
||||
|
if !bytes.Equal(h.diversificationNonce, data) { |
||||
|
return errConflictingDiversificationNonces |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) LockForSealing() { |
||||
|
|
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) UnlockForSealing() { |
||||
|
|
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) HandshakeComplete() bool { |
||||
|
h.mutex.RLock() |
||||
|
complete := h.forwardSecureAEAD != nil |
||||
|
h.mutex.RUnlock() |
||||
|
return complete |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) sendCHLO() error { |
||||
|
h.clientHelloCounter++ |
||||
|
if h.clientHelloCounter > protocol.MaxClientHellos { |
||||
|
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos)) |
||||
|
} |
||||
|
|
||||
|
b := &bytes.Buffer{} |
||||
|
|
||||
|
tags, err := h.getTags() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
h.addPadding(tags) |
||||
|
|
||||
|
utils.Debugf("Sending CHLO:\n%s", printHandshakeMessage(tags)) |
||||
|
WriteHandshakeMessage(b, TagCHLO, tags) |
||||
|
|
||||
|
_, err = h.cryptoStream.Write(b.Bytes()) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
h.lastSentCHLO = b.Bytes() |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { |
||||
|
tags, err := h.connectionParameters.GetHelloMap() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
tags[TagSNI] = []byte(h.hostname) |
||||
|
tags[TagPDMD] = []byte("X509") |
||||
|
|
||||
|
ccs := h.certManager.GetCommonCertificateHashes() |
||||
|
if len(ccs) > 0 { |
||||
|
tags[TagCCS] = ccs |
||||
|
} |
||||
|
|
||||
|
versionTag := make([]byte, 4, 4) |
||||
|
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) |
||||
|
tags[TagVER] = versionTag |
||||
|
|
||||
|
if len(h.stk) > 0 { |
||||
|
tags[TagSTK] = h.stk |
||||
|
} |
||||
|
|
||||
|
if len(h.sno) > 0 { |
||||
|
tags[TagSNO] = h.sno |
||||
|
} |
||||
|
|
||||
|
if h.serverConfig != nil { |
||||
|
tags[TagSCID] = h.serverConfig.ID |
||||
|
|
||||
|
leafCert := h.certManager.GetLeafCert() |
||||
|
if leafCert != nil { |
||||
|
certHash, _ := h.certManager.GetLeafCertHash() |
||||
|
xlct := make([]byte, 8, 8) |
||||
|
binary.LittleEndian.PutUint64(xlct, certHash) |
||||
|
|
||||
|
tags[TagNONC] = h.nonc |
||||
|
tags[TagXLCT] = xlct |
||||
|
tags[TagKEXS] = []byte("C255") |
||||
|
tags[TagAEAD] = []byte("AESG") |
||||
|
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
|
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return tags, nil |
||||
|
} |
||||
|
|
||||
|
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
|
||||
|
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) { |
||||
|
var size int |
||||
|
for _, tag := range tags { |
||||
|
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
||||
|
} |
||||
|
paddingSize := protocol.ClientHelloMinimumSize - size |
||||
|
if paddingSize > 0 { |
||||
|
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) maybeUpgradeCrypto() error { |
||||
|
if !h.serverVerified { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
h.mutex.Lock() |
||||
|
defer h.mutex.Unlock() |
||||
|
|
||||
|
leafCert := h.certManager.GetLeafCert() |
||||
|
|
||||
|
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) { |
||||
|
var err error |
||||
|
var nonce []byte |
||||
|
if h.sno == nil { |
||||
|
nonce = h.nonc |
||||
|
} else { |
||||
|
nonce = append(h.nonc, h.sno...) |
||||
|
} |
||||
|
|
||||
|
h.secureAEAD, err = h.keyDerivation( |
||||
|
false, |
||||
|
h.serverConfig.sharedSecret, |
||||
|
nonce, |
||||
|
h.connID, |
||||
|
h.lastSentCHLO, |
||||
|
h.serverConfig.Get(), |
||||
|
leafCert, |
||||
|
h.diversificationNonce, |
||||
|
protocol.PerspectiveClient, |
||||
|
) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
h.aeadChanged <- struct{}{} |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (h *cryptoSetupClient) generateClientNonce() error { |
||||
|
if len(h.nonc) > 0 { |
||||
|
return errClientNonceAlreadyExists |
||||
|
} |
||||
|
|
||||
|
nonc := make([]byte, 32) |
||||
|
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix())) |
||||
|
|
||||
|
if len(h.serverConfig.obit) != 8 { |
||||
|
return errNoObitForClientNonce |
||||
|
} |
||||
|
|
||||
|
copy(nonc[4:12], h.serverConfig.obit) |
||||
|
|
||||
|
_, err := rand.Read(nonc[12:]) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
h.nonc = nonc |
||||
|
return nil |
||||
|
} |
||||
@ -0,0 +1,16 @@ |
|||||
|
package handshake |
||||
|
|
||||
|
import "github.com/lucas-clemente/quic-go/protocol" |
||||
|
|
||||
|
// CryptoSetup is a crypto setup
|
||||
|
type CryptoSetup interface { |
||||
|
HandleCryptoStream() error |
||||
|
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) |
||||
|
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte |
||||
|
LockForSealing() |
||||
|
UnlockForSealing() |
||||
|
HandshakeComplete() bool |
||||
|
// TODO: clean up this interface
|
||||
|
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
|
SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient
|
||||
|
} |
||||
@ -0,0 +1,148 @@ |
|||||
|
package handshake |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"encoding/binary" |
||||
|
"errors" |
||||
|
"math" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/crypto" |
||||
|
"github.com/lucas-clemente/quic-go/qerr" |
||||
|
"github.com/lucas-clemente/quic-go/utils" |
||||
|
) |
||||
|
|
||||
|
type serverConfigClient struct { |
||||
|
raw []byte |
||||
|
ID []byte |
||||
|
obit []byte |
||||
|
expiry time.Time |
||||
|
|
||||
|
kex crypto.KeyExchange |
||||
|
sharedSecret []byte |
||||
|
} |
||||
|
|
||||
|
var ( |
||||
|
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG") |
||||
|
) |
||||
|
|
||||
|
// parseServerConfig parses a server config
|
||||
|
func parseServerConfig(data []byte) (*serverConfigClient, error) { |
||||
|
tag, tagMap, err := ParseHandshakeMessage(bytes.NewReader(data)) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
if tag != TagSCFG { |
||||
|
return nil, errMessageNotServerConfig |
||||
|
} |
||||
|
|
||||
|
scfg := &serverConfigClient{raw: data} |
||||
|
err = scfg.parseValues(tagMap) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
return scfg, nil |
||||
|
} |
||||
|
|
||||
|
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error { |
||||
|
// SCID
|
||||
|
scfgID, ok := tagMap[TagSCID] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID") |
||||
|
} |
||||
|
if len(scfgID) != 16 { |
||||
|
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID") |
||||
|
} |
||||
|
s.ID = scfgID |
||||
|
|
||||
|
// KEXS
|
||||
|
// TODO: allow for P256 in the list
|
||||
|
// TODO: setup Key Exchange
|
||||
|
kexs, ok := tagMap[TagKEXS] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS") |
||||
|
} |
||||
|
if len(kexs)%4 != 0 { |
||||
|
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS") |
||||
|
} |
||||
|
if !bytes.Equal(kexs, []byte("C255")) { |
||||
|
return qerr.Error(qerr.CryptoNoSupport, "KEXS") |
||||
|
} |
||||
|
|
||||
|
// AEAD
|
||||
|
aead, ok := tagMap[TagAEAD] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD") |
||||
|
} |
||||
|
if len(aead)%4 != 0 { |
||||
|
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD") |
||||
|
} |
||||
|
var aesgFound bool |
||||
|
for i := 0; i < len(aead)/4; i++ { |
||||
|
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) { |
||||
|
aesgFound = true |
||||
|
break |
||||
|
} |
||||
|
} |
||||
|
if !aesgFound { |
||||
|
return qerr.Error(qerr.CryptoNoSupport, "AEAD") |
||||
|
} |
||||
|
|
||||
|
// PUBS
|
||||
|
// TODO: save this value
|
||||
|
pubs, ok := tagMap[TagPUBS] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") |
||||
|
} |
||||
|
if len(pubs) != 35 { |
||||
|
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") |
||||
|
} |
||||
|
|
||||
|
var err error |
||||
|
s.kex, err = crypto.NewCurve25519KEX() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// the PUBS value is always prepended by []byte{0x20, 0x00, 0x00}
|
||||
|
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs[3:]) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// OBIT
|
||||
|
obit, ok := tagMap[TagOBIT] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT") |
||||
|
} |
||||
|
if len(obit) != 8 { |
||||
|
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT") |
||||
|
} |
||||
|
s.obit = obit |
||||
|
|
||||
|
// EXPY
|
||||
|
expy, ok := tagMap[TagEXPY] |
||||
|
if !ok { |
||||
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY") |
||||
|
} |
||||
|
if len(expy) != 8 { |
||||
|
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY") |
||||
|
} |
||||
|
// make sure that the value doesn't overflow an int64
|
||||
|
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
|
||||
|
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2) |
||||
|
s.expiry = time.Unix(int64(expyTimestamp), 0) |
||||
|
|
||||
|
// TODO: implement VER
|
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (s *serverConfigClient) IsExpired() bool { |
||||
|
return s.expiry.Before(time.Now()) |
||||
|
} |
||||
|
|
||||
|
func (s *serverConfigClient) Get() []byte { |
||||
|
return s.raw |
||||
|
} |
||||
@ -0,0 +1,14 @@ |
|||||
|
package protocol |
||||
|
|
||||
|
// EncryptionLevel is the encryption level
|
||||
|
// Default value is Unencrypted
|
||||
|
type EncryptionLevel int |
||||
|
|
||||
|
const ( |
||||
|
// Unencrypted is not encrypted
|
||||
|
Unencrypted EncryptionLevel = iota |
||||
|
// EncryptionSecure is encrypted, but not forward secure
|
||||
|
EncryptionSecure |
||||
|
// EncryptionForwardSecure is forward secure
|
||||
|
EncryptionForwardSecure |
||||
|
) |
||||
@ -0,0 +1,10 @@ |
|||||
|
package protocol |
||||
|
|
||||
|
// Perspective determines if we're acting as a server or a client
|
||||
|
type Perspective int |
||||
|
|
||||
|
// the perspectives
|
||||
|
const ( |
||||
|
PerspectiveServer Perspective = 1 |
||||
|
PerspectiveClient Perspective = 2 |
||||
|
) |
||||
@ -0,0 +1,27 @@ |
|||||
|
package quic |
||||
|
|
||||
|
import "github.com/lucas-clemente/quic-go/frames" |
||||
|
|
||||
|
type unpackedPacket struct { |
||||
|
frames []frames.Frame |
||||
|
} |
||||
|
|
||||
|
func (u *unpackedPacket) IsRetransmittable() bool { |
||||
|
for _, f := range u.frames { |
||||
|
switch f.(type) { |
||||
|
case *frames.StreamFrame: |
||||
|
return true |
||||
|
case *frames.RstStreamFrame: |
||||
|
return true |
||||
|
case *frames.WindowUpdateFrame: |
||||
|
return true |
||||
|
case *frames.BlockedFrame: |
||||
|
return true |
||||
|
case *frames.PingFrame: |
||||
|
return true |
||||
|
case *frames.GoawayFrame: |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
return false |
||||
|
} |
||||
@ -0,0 +1,22 @@ |
|||||
|
package utils |
||||
|
|
||||
|
import "sync/atomic" |
||||
|
|
||||
|
// An AtomicBool is an atomic bool
|
||||
|
type AtomicBool struct { |
||||
|
v int32 |
||||
|
} |
||||
|
|
||||
|
// Set sets the value
|
||||
|
func (a *AtomicBool) Set(value bool) { |
||||
|
var n int32 |
||||
|
if value { |
||||
|
n = 1 |
||||
|
} |
||||
|
atomic.StoreInt32(&a.v, n) |
||||
|
} |
||||
|
|
||||
|
// Get gets the value
|
||||
|
func (a *AtomicBool) Get() bool { |
||||
|
return atomic.LoadInt32(&a.v) != 0 |
||||
|
} |
||||
@ -0,0 +1,18 @@ |
|||||
|
package utils |
||||
|
|
||||
|
import ( |
||||
|
"crypto/rand" |
||||
|
"encoding/binary" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/protocol" |
||||
|
) |
||||
|
|
||||
|
// GenerateConnectionID generates a connection ID using cryptographic random
|
||||
|
func GenerateConnectionID() (protocol.ConnectionID, error) { |
||||
|
b := make([]byte, 8, 8) |
||||
|
_, err := rand.Read(b) |
||||
|
if err != nil { |
||||
|
return 0, err |
||||
|
} |
||||
|
return protocol.ConnectionID(binary.LittleEndian.Uint64(b)), nil |
||||
|
} |
||||
@ -0,0 +1,27 @@ |
|||||
|
package utils |
||||
|
|
||||
|
import ( |
||||
|
"net/url" |
||||
|
"strings" |
||||
|
) |
||||
|
|
||||
|
// HostnameFromAddr determines the hostname in an address string
|
||||
|
func HostnameFromAddr(addr string) (string, error) { |
||||
|
p, err := url.Parse(addr) |
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
h := p.Host |
||||
|
|
||||
|
// copied from https://golang.org/src/net/http/transport.go
|
||||
|
if hasPort(h) { |
||||
|
h = h[:strings.LastIndex(h, ":")] |
||||
|
} |
||||
|
|
||||
|
return h, nil |
||||
|
} |
||||
|
|
||||
|
// copied from https://golang.org/src/net/http/http.go
|
||||
|
func hasPort(s string) bool { |
||||
|
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") |
||||
|
} |
||||
@ -0,0 +1,17 @@ |
|||||
|
package utils |
||||
|
|
||||
|
import ( |
||||
|
"io" |
||||
|
|
||||
|
"github.com/lucas-clemente/quic-go/protocol" |
||||
|
) |
||||
|
|
||||
|
// Stream is the interface for QUIC streams
|
||||
|
type Stream interface { |
||||
|
io.Reader |
||||
|
io.Writer |
||||
|
io.Closer |
||||
|
StreamID() protocol.StreamID |
||||
|
CloseRemote(offset protocol.ByteCount) |
||||
|
Reset(error) |
||||
|
} |
||||
Loading…
Reference in new issue