mirror of https://github.com/ginuerzh/gost
66 changed files with 3778 additions and 853 deletions
@ -0,0 +1,219 @@ |
|||
package quic |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto/tls" |
|||
"errors" |
|||
"net" |
|||
"strings" |
|||
"sync/atomic" |
|||
"time" |
|||
|
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
"github.com/lucas-clemente/quic-go/utils" |
|||
) |
|||
|
|||
// A Client of QUIC
|
|||
type Client struct { |
|||
addr *net.UDPAddr |
|||
conn *net.UDPConn |
|||
hostname string |
|||
|
|||
connectionID protocol.ConnectionID |
|||
version protocol.VersionNumber |
|||
versionNegotiated bool |
|||
closed uint32 // atomic bool
|
|||
|
|||
tlsConfig *tls.Config |
|||
cryptoChangeCallback CryptoChangeCallback |
|||
versionNegotiateCallback VersionNegotiateCallback |
|||
|
|||
session packetHandler |
|||
} |
|||
|
|||
// VersionNegotiateCallback is called once the client has a negotiated version
|
|||
type VersionNegotiateCallback func() error |
|||
|
|||
var errHostname = errors.New("Invalid hostname") |
|||
|
|||
var ( |
|||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") |
|||
) |
|||
|
|||
// NewClient makes a new client
|
|||
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { |
|||
udpAddr, err := net.ResolveUDPAddr("udp", host) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
connectionID, err := utils.GenerateConnectionID() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
hostname, _, err := net.SplitHostPort(host) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
client := &Client{ |
|||
addr: udpAddr, |
|||
conn: conn, |
|||
hostname: hostname, |
|||
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
|
|||
connectionID: connectionID, |
|||
tlsConfig: tlsConfig, |
|||
cryptoChangeCallback: cryptoChangeCallback, |
|||
versionNegotiateCallback: versionNegotiateCallback, |
|||
} |
|||
|
|||
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) |
|||
|
|||
err = client.createNewSession(nil) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return client, nil |
|||
} |
|||
|
|||
// Listen listens
|
|||
func (c *Client) Listen() error { |
|||
for { |
|||
data := getPacketBuffer() |
|||
data = data[:protocol.MaxPacketSize] |
|||
|
|||
n, _, err := c.conn.ReadFromUDP(data) |
|||
if err != nil { |
|||
if strings.HasSuffix(err.Error(), "use of closed network connection") { |
|||
return nil |
|||
} |
|||
return err |
|||
} |
|||
data = data[:n] |
|||
|
|||
err = c.handlePacket(data) |
|||
if err != nil { |
|||
utils.Errorf("error handling packet: %s", err.Error()) |
|||
c.session.Close(err) |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
|
|||
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
|
|||
func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) { |
|||
return c.session.OpenStream(id) |
|||
} |
|||
|
|||
// Close closes the connection
|
|||
func (c *Client) Close(e error) error { |
|||
// Only close once
|
|||
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { |
|||
return nil |
|||
} |
|||
|
|||
_ = c.session.Close(e) |
|||
return c.conn.Close() |
|||
} |
|||
|
|||
func (c *Client) handlePacket(packet []byte) error { |
|||
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { |
|||
return qerr.PacketTooLarge |
|||
} |
|||
|
|||
rcvTime := time.Now() |
|||
|
|||
r := bytes.NewReader(packet) |
|||
|
|||
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) |
|||
if err != nil { |
|||
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) |
|||
} |
|||
hdr.Raw = packet[:len(packet)-r.Len()] |
|||
|
|||
// ignore delayed / duplicated version negotiation packets
|
|||
if c.versionNegotiated && hdr.VersionFlag { |
|||
return nil |
|||
} |
|||
|
|||
// this is the first packet after the client sent a packet with the VersionFlag set
|
|||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
|||
if !hdr.VersionFlag && !c.versionNegotiated { |
|||
c.versionNegotiated = true |
|||
err = c.versionNegotiateCallback() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
if hdr.VersionFlag { |
|||
var hasCommonVersion bool // check if we're supporting any of the offered versions
|
|||
for _, v := range hdr.SupportedVersions { |
|||
// check if the server sent the offered version in supported versions
|
|||
if v == c.version { |
|||
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.") |
|||
} |
|||
if v != protocol.VersionUnsupported { |
|||
hasCommonVersion = true |
|||
} |
|||
} |
|||
if !hasCommonVersion { |
|||
utils.Infof("No common version found.") |
|||
return qerr.InvalidVersion |
|||
} |
|||
|
|||
ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions) |
|||
if !ok { |
|||
return qerr.VersionNegotiationMismatch |
|||
} |
|||
|
|||
utils.Infof("Switching to QUIC version %d", highestSupportedVersion) |
|||
c.version = highestSupportedVersion |
|||
c.versionNegotiated = true |
|||
|
|||
c.session.Close(errCloseSessionForNewVersion) |
|||
err = c.createNewSession(hdr.SupportedVersions) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
err = c.versionNegotiateCallback() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
return nil // version negotiation packets have no payload
|
|||
} |
|||
|
|||
c.session.handlePacket(&receivedPacket{ |
|||
remoteAddr: c.addr, |
|||
publicHeader: hdr, |
|||
data: packet[len(packet)-r.Len():], |
|||
rcvTime: rcvTime, |
|||
}) |
|||
return nil |
|||
} |
|||
|
|||
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { |
|||
var err error |
|||
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
go c.session.run() |
|||
return nil |
|||
} |
|||
|
|||
func (c *Client) streamCallback(session *Session, stream utils.Stream) {} |
|||
|
|||
func (c *Client) closeCallback(id protocol.ConnectionID) { |
|||
utils.Infof("Connection %x closed.", id) |
|||
} |
|||
@ -0,0 +1,84 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"errors" |
|||
"strings" |
|||
) |
|||
|
|||
// A CertChain holds a certificate and a private key
|
|||
type CertChain interface { |
|||
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) |
|||
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error) |
|||
GetLeafCert(sni string) ([]byte, error) |
|||
} |
|||
|
|||
// proofSource stores a key and a certificate for the server proof
|
|||
type certChain struct { |
|||
config *tls.Config |
|||
} |
|||
|
|||
var _ CertChain = &certChain{} |
|||
|
|||
var errNoMatchingCertificate = errors.New("no matching certificate found") |
|||
|
|||
// NewCertChain loads the key and cert from files
|
|||
func NewCertChain(tlsConfig *tls.Config) CertChain { |
|||
return &certChain{config: tlsConfig} |
|||
} |
|||
|
|||
// SignServerProof signs CHLO and server config for use in the server proof
|
|||
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) { |
|||
cert, err := c.getCertForSNI(sni) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return signServerProof(cert, chlo, serverConfigData) |
|||
} |
|||
|
|||
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
|||
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { |
|||
cert, err := c.getCertForSNI(sni) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes) |
|||
} |
|||
|
|||
// GetLeafCert gets the leaf certificate
|
|||
func (c *certChain) GetLeafCert(sni string) ([]byte, error) { |
|||
cert, err := c.getCertForSNI(sni) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return cert.Certificate[0], nil |
|||
} |
|||
|
|||
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { |
|||
if c.config.GetCertificate != nil { |
|||
cert, err := c.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if cert != nil { |
|||
return cert, nil |
|||
} |
|||
} |
|||
|
|||
if len(c.config.NameToCertificate) != 0 { |
|||
if cert, ok := c.config.NameToCertificate[sni]; ok { |
|||
return cert, nil |
|||
} |
|||
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) |
|||
if cert, ok := c.config.NameToCertificate[wildcardSNI]; ok { |
|||
return cert, nil |
|||
} |
|||
} |
|||
|
|||
if len(c.config.Certificates) != 0 { |
|||
return &c.config.Certificates[0], nil |
|||
} |
|||
|
|||
return nil, errNoMatchingCertificate |
|||
} |
|||
@ -0,0 +1,131 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"crypto/x509" |
|||
"errors" |
|||
"hash/fnv" |
|||
"time" |
|||
|
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
) |
|||
|
|||
// CertManager manages the certificates sent by the server
|
|||
type CertManager interface { |
|||
SetData([]byte) error |
|||
GetCommonCertificateHashes() []byte |
|||
GetLeafCert() []byte |
|||
GetLeafCertHash() (uint64, error) |
|||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool |
|||
Verify(hostname string) error |
|||
} |
|||
|
|||
type certManager struct { |
|||
chain []*x509.Certificate |
|||
config *tls.Config |
|||
} |
|||
|
|||
var _ CertManager = &certManager{} |
|||
|
|||
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded") |
|||
|
|||
// NewCertManager creates a new CertManager
|
|||
func NewCertManager(tlsConfig *tls.Config) CertManager { |
|||
return &certManager{config: tlsConfig} |
|||
} |
|||
|
|||
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
|
|||
func (c *certManager) SetData(data []byte) error { |
|||
byteChain, err := decompressChain(data) |
|||
if err != nil { |
|||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid") |
|||
} |
|||
|
|||
chain := make([]*x509.Certificate, len(byteChain), len(byteChain)) |
|||
for i, data := range byteChain { |
|||
cert, err := x509.ParseCertificate(data) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
chain[i] = cert |
|||
} |
|||
|
|||
c.chain = chain |
|||
return nil |
|||
} |
|||
|
|||
func (c *certManager) GetCommonCertificateHashes() []byte { |
|||
return getCommonCertificateHashes() |
|||
} |
|||
|
|||
// GetLeafCert returns the leaf certificate of the certificate chain
|
|||
// it returns nil if the certificate chain has not yet been set
|
|||
func (c *certManager) GetLeafCert() []byte { |
|||
if len(c.chain) == 0 { |
|||
return nil |
|||
} |
|||
return c.chain[0].Raw |
|||
} |
|||
|
|||
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
|
|||
func (c *certManager) GetLeafCertHash() (uint64, error) { |
|||
leafCert := c.GetLeafCert() |
|||
if leafCert == nil { |
|||
return 0, errNoCertificateChain |
|||
} |
|||
|
|||
h := fnv.New64a() |
|||
_, err := h.Write(leafCert) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
return h.Sum64(), nil |
|||
} |
|||
|
|||
// VerifyServerProof verifies the signature of the server config
|
|||
// it should only be called after the certificate chain has been set, otherwise it returns false
|
|||
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool { |
|||
if len(c.chain) == 0 { |
|||
return false |
|||
} |
|||
|
|||
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData) |
|||
} |
|||
|
|||
// Verify verifies the certificate chain
|
|||
func (c *certManager) Verify(hostname string) error { |
|||
if len(c.chain) == 0 { |
|||
return errNoCertificateChain |
|||
} |
|||
|
|||
if c.config != nil && c.config.InsecureSkipVerify { |
|||
return nil |
|||
} |
|||
|
|||
leafCert := c.chain[0] |
|||
|
|||
var opts x509.VerifyOptions |
|||
if c.config != nil { |
|||
opts.Roots = c.config.RootCAs |
|||
opts.DNSName = c.config.ServerName |
|||
if c.config.Time == nil { |
|||
opts.CurrentTime = time.Now() |
|||
} else { |
|||
opts.CurrentTime = c.config.Time() |
|||
} |
|||
} else { |
|||
opts.DNSName = hostname |
|||
} |
|||
|
|||
// the first certificate is the leaf certificate, all others are intermediates
|
|||
if len(c.chain) > 1 { |
|||
intermediates := x509.NewCertPool() |
|||
for i := 1; i < len(c.chain); i++ { |
|||
intermediates.AddCert(c.chain[i]) |
|||
} |
|||
opts.Intermediates = intermediates |
|||
} |
|||
|
|||
_, err := leafCert.Verify(opts) |
|||
return err |
|||
} |
|||
@ -1,71 +0,0 @@ |
|||
// +build ignore
|
|||
|
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto/rand" |
|||
|
|||
. "github.com/onsi/ginkgo" |
|||
. "github.com/onsi/gomega" |
|||
) |
|||
|
|||
var _ = Describe("Chacha20poly1305", func() { |
|||
var ( |
|||
alice, bob AEAD |
|||
keyAlice, keyBob, ivAlice, ivBob []byte |
|||
) |
|||
|
|||
BeforeEach(func() { |
|||
keyAlice = make([]byte, 32) |
|||
keyBob = make([]byte, 32) |
|||
ivAlice = make([]byte, 4) |
|||
ivBob = make([]byte, 4) |
|||
rand.Reader.Read(keyAlice) |
|||
rand.Reader.Read(keyBob) |
|||
rand.Reader.Read(ivAlice) |
|||
rand.Reader.Read(ivBob) |
|||
var err error |
|||
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice) |
|||
Expect(err).ToNot(HaveOccurred()) |
|||
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob) |
|||
Expect(err).ToNot(HaveOccurred()) |
|||
}) |
|||
|
|||
It("seals and opens", func() { |
|||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|||
text, err := bob.Open(nil, b, 42, []byte("aad")) |
|||
Expect(err).ToNot(HaveOccurred()) |
|||
Expect(text).To(Equal([]byte("foobar"))) |
|||
}) |
|||
|
|||
It("seals and opens reverse", func() { |
|||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|||
text, err := alice.Open(nil, b, 42, []byte("aad")) |
|||
Expect(err).ToNot(HaveOccurred()) |
|||
Expect(text).To(Equal([]byte("foobar"))) |
|||
}) |
|||
|
|||
It("has the proper length", func() { |
|||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|||
Expect(b).To(HaveLen(6 + 12)) |
|||
}) |
|||
|
|||
It("fails with wrong aad", func() { |
|||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) |
|||
_, err := bob.Open(nil, b, 42, []byte("aad2")) |
|||
Expect(err).To(HaveOccurred()) |
|||
}) |
|||
|
|||
It("rejects wrong key and iv sizes", func() { |
|||
var err error |
|||
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs" |
|||
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice) |
|||
Expect(err).To(MatchError(e)) |
|||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice) |
|||
Expect(err).To(MatchError(e)) |
|||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice) |
|||
Expect(err).To(MatchError(e)) |
|||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:]) |
|||
Expect(err).To(MatchError(e)) |
|||
}) |
|||
}) |
|||
@ -1,92 +0,0 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"crypto/sha256" |
|||
"crypto/tls" |
|||
"errors" |
|||
"strings" |
|||
) |
|||
|
|||
// proofSource stores a key and a certificate for the server proof
|
|||
type proofSource struct { |
|||
config *tls.Config |
|||
} |
|||
|
|||
// NewProofSource loads the key and cert from files
|
|||
func NewProofSource(tlsConfig *tls.Config) (Signer, error) { |
|||
return &proofSource{config: tlsConfig}, nil |
|||
} |
|||
|
|||
// SignServerProof signs CHLO and server config for use in the server proof
|
|||
func (ps *proofSource) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) { |
|||
cert, err := ps.getCertForSNI(sni) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
hash := sha256.New() |
|||
hash.Write([]byte("QUIC CHLO and server config signature\x00")) |
|||
chloHash := sha256.Sum256(chlo) |
|||
hash.Write([]byte{32, 0, 0, 0}) |
|||
hash.Write(chloHash[:]) |
|||
hash.Write(serverConfigData) |
|||
|
|||
key, ok := cert.PrivateKey.(crypto.Signer) |
|||
if !ok { |
|||
return nil, errors.New("expected PrivateKey to implement crypto.Signer") |
|||
} |
|||
|
|||
opts := crypto.SignerOpts(crypto.SHA256) |
|||
|
|||
if _, ok = key.(*rsa.PrivateKey); ok { |
|||
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} |
|||
} |
|||
|
|||
return key.Sign(rand.Reader, hash.Sum(nil), opts) |
|||
} |
|||
|
|||
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
|||
func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { |
|||
cert, err := ps.getCertForSNI(sni) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes) |
|||
} |
|||
|
|||
// GetLeafCert gets the leaf certificate
|
|||
func (ps *proofSource) GetLeafCert(sni string) ([]byte, error) { |
|||
cert, err := ps.getCertForSNI(sni) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return cert.Certificate[0], nil |
|||
} |
|||
|
|||
func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) { |
|||
if ps.config.GetCertificate != nil { |
|||
cert, err := ps.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if cert != nil { |
|||
return cert, nil |
|||
} |
|||
} |
|||
if len(ps.config.NameToCertificate) != 0 { |
|||
if cert, ok := ps.config.NameToCertificate[sni]; ok { |
|||
return cert, nil |
|||
} |
|||
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) |
|||
if cert, ok := ps.config.NameToCertificate[wildcardSNI]; ok { |
|||
return cert, nil |
|||
} |
|||
} |
|||
if len(ps.config.Certificates) != 0 { |
|||
return &ps.config.Certificates[0], nil |
|||
} |
|||
return nil, errors.New("no matching certificate found") |
|||
} |
|||
@ -0,0 +1,66 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto" |
|||
"crypto/ecdsa" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"crypto/sha256" |
|||
"crypto/tls" |
|||
"crypto/x509" |
|||
"encoding/asn1" |
|||
"errors" |
|||
"math/big" |
|||
) |
|||
|
|||
type ecdsaSignature struct { |
|||
R, S *big.Int |
|||
} |
|||
|
|||
// signServerProof signs CHLO and server config for use in the server proof
|
|||
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) { |
|||
hash := sha256.New() |
|||
hash.Write([]byte("QUIC CHLO and server config signature\x00")) |
|||
chloHash := sha256.Sum256(chlo) |
|||
hash.Write([]byte{32, 0, 0, 0}) |
|||
hash.Write(chloHash[:]) |
|||
hash.Write(serverConfigData) |
|||
|
|||
key, ok := cert.PrivateKey.(crypto.Signer) |
|||
if !ok { |
|||
return nil, errors.New("expected PrivateKey to implement crypto.Signer") |
|||
} |
|||
|
|||
opts := crypto.SignerOpts(crypto.SHA256) |
|||
|
|||
if _, ok = key.(*rsa.PrivateKey); ok { |
|||
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} |
|||
} |
|||
|
|||
return key.Sign(rand.Reader, hash.Sum(nil), opts) |
|||
} |
|||
|
|||
// verifyServerProof verifies the server proof signature
|
|||
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool { |
|||
hash := sha256.New() |
|||
hash.Write([]byte("QUIC CHLO and server config signature\x00")) |
|||
chloHash := sha256.Sum256(chlo) |
|||
hash.Write([]byte{32, 0, 0, 0}) |
|||
hash.Write(chloHash[:]) |
|||
hash.Write(serverConfigData) |
|||
|
|||
// RSA
|
|||
if cert.PublicKeyAlgorithm == x509.RSA { |
|||
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} |
|||
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts) |
|||
return err == nil |
|||
} |
|||
|
|||
// ECDSA
|
|||
signature := &ecdsaSignature{} |
|||
rest, err := asn1.Unmarshal(proof, signature) |
|||
if err != nil || len(rest) != 0 { |
|||
return false |
|||
} |
|||
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S) |
|||
} |
|||
@ -1,8 +0,0 @@ |
|||
package crypto |
|||
|
|||
// A Signer holds a certificate and a private key
|
|||
type Signer interface { |
|||
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) |
|||
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error) |
|||
GetLeafCert(sni string) ([]byte, error) |
|||
} |
|||
@ -0,0 +1,293 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"net" |
|||
"net/http" |
|||
"strings" |
|||
"sync" |
|||
|
|||
"golang.org/x/net/http2" |
|||
"golang.org/x/net/http2/hpack" |
|||
"golang.org/x/net/idna" |
|||
|
|||
quic "github.com/lucas-clemente/quic-go" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
"github.com/lucas-clemente/quic-go/utils" |
|||
) |
|||
|
|||
type quicClient interface { |
|||
OpenStream(protocol.StreamID) (utils.Stream, error) |
|||
Close(error) error |
|||
Listen() error |
|||
} |
|||
|
|||
// Client is a HTTP2 client doing QUIC requests
|
|||
type Client struct { |
|||
mutex sync.RWMutex |
|||
cryptoChangedCond sync.Cond |
|||
|
|||
t *QuicRoundTripper |
|||
|
|||
hostname string |
|||
encryptionLevel protocol.EncryptionLevel |
|||
|
|||
client quicClient |
|||
headerStream utils.Stream |
|||
headerErr *qerr.QuicError |
|||
highestOpenedStream protocol.StreamID |
|||
requestWriter *requestWriter |
|||
|
|||
responses map[protocol.StreamID]chan *http.Response |
|||
} |
|||
|
|||
var _ h2quicClient = &Client{} |
|||
|
|||
// NewClient creates a new client
|
|||
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) { |
|||
c := &Client{ |
|||
t: t, |
|||
hostname: authorityAddr("https", hostname), |
|||
highestOpenedStream: 3, |
|||
responses: make(map[protocol.StreamID]chan *http.Response), |
|||
} |
|||
c.cryptoChangedCond = sync.Cond{L: &c.mutex} |
|||
|
|||
var err error |
|||
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
go c.client.Listen() |
|||
return c, nil |
|||
} |
|||
|
|||
func (c *Client) handleStreamCb(session *quic.Session, stream utils.Stream) { |
|||
utils.Debugf("Handling stream %d", stream.StreamID()) |
|||
} |
|||
|
|||
func (c *Client) cryptoChangeCallback(isForwardSecure bool) { |
|||
c.cryptoChangedCond.L.Lock() |
|||
defer c.cryptoChangedCond.L.Unlock() |
|||
|
|||
if isForwardSecure { |
|||
c.encryptionLevel = protocol.EncryptionForwardSecure |
|||
utils.Debugf("is forward secure") |
|||
} else { |
|||
c.encryptionLevel = protocol.EncryptionSecure |
|||
utils.Debugf("is secure") |
|||
} |
|||
c.cryptoChangedCond.Broadcast() |
|||
} |
|||
|
|||
func (c *Client) versionNegotiateCallback() error { |
|||
var err error |
|||
// once the version has been negotiated, open the header stream
|
|||
c.headerStream, err = c.client.OpenStream(3) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
c.requestWriter = newRequestWriter(c.headerStream) |
|||
go c.handleHeaderStream() |
|||
return nil |
|||
} |
|||
|
|||
func (c *Client) handleHeaderStream() { |
|||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) |
|||
h2framer := http2.NewFramer(nil, c.headerStream) |
|||
|
|||
var lastStream protocol.StreamID |
|||
|
|||
for { |
|||
frame, err := h2framer.ReadFrame() |
|||
if err != nil { |
|||
c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame") |
|||
break |
|||
} |
|||
lastStream = protocol.StreamID(frame.Header().StreamID) |
|||
hframe, ok := frame.(*http2.HeadersFrame) |
|||
if !ok { |
|||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") |
|||
break |
|||
} |
|||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} |
|||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) |
|||
if err != nil { |
|||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") |
|||
break |
|||
} |
|||
|
|||
c.mutex.RLock() |
|||
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] |
|||
c.mutex.RUnlock() |
|||
if !ok { |
|||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) |
|||
break |
|||
} |
|||
|
|||
rsp, err := responseFromHeaders(mhframe) |
|||
if err != nil { |
|||
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) |
|||
} |
|||
headerChan <- rsp |
|||
} |
|||
|
|||
// stop all running request
|
|||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) |
|||
c.mutex.Lock() |
|||
for _, responseChan := range c.responses { |
|||
responseChan <- nil |
|||
} |
|||
c.mutex.Unlock() |
|||
} |
|||
|
|||
// Do executes a request and returns a response
|
|||
func (c *Client) Do(req *http.Request) (*http.Response, error) { |
|||
// TODO: add port to address, if it doesn't have one
|
|||
if req.URL.Scheme != "https" { |
|||
return nil, errors.New("quic http2: unsupported scheme") |
|||
} |
|||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { |
|||
utils.Debugf("%s vs %s", req.Host, c.hostname) |
|||
return nil, errors.New("h2quic Client BUG: Do called for the wrong client") |
|||
} |
|||
|
|||
hasBody := (req.Body != nil) |
|||
|
|||
c.mutex.Lock() |
|||
c.highestOpenedStream += 2 |
|||
dataStreamID := c.highestOpenedStream |
|||
for c.encryptionLevel != protocol.EncryptionForwardSecure { |
|||
c.cryptoChangedCond.Wait() |
|||
} |
|||
hdrChan := make(chan *http.Response) |
|||
c.responses[dataStreamID] = hdrChan |
|||
c.mutex.Unlock() |
|||
|
|||
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
|
|||
dataStream, err := c.client.OpenStream(dataStreamID) |
|||
if err != nil { |
|||
c.Close(err) |
|||
return nil, err |
|||
} |
|||
|
|||
var requestedGzip bool |
|||
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { |
|||
requestedGzip = true |
|||
} |
|||
// TODO: add support for trailers
|
|||
endStream := !hasBody |
|||
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip) |
|||
if err != nil { |
|||
c.Close(err) |
|||
return nil, err |
|||
} |
|||
|
|||
resc := make(chan error, 1) |
|||
if hasBody { |
|||
go func() { |
|||
resc <- c.writeRequestBody(dataStream, req.Body) |
|||
}() |
|||
} |
|||
|
|||
var res *http.Response |
|||
|
|||
var receivedResponse bool |
|||
var bodySent bool |
|||
|
|||
if !hasBody { |
|||
bodySent = true |
|||
} |
|||
|
|||
for !(bodySent && receivedResponse) { |
|||
select { |
|||
case res = <-hdrChan: |
|||
receivedResponse = true |
|||
c.mutex.Lock() |
|||
delete(c.responses, dataStreamID) |
|||
c.mutex.Unlock() |
|||
if res == nil { // an error occured on the header stream
|
|||
c.Close(c.headerErr) |
|||
return nil, c.headerErr |
|||
} |
|||
case err := <-resc: |
|||
bodySent = true |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
} |
|||
|
|||
// TODO: correctly set this variable
|
|||
var streamEnded bool |
|||
isHead := (req.Method == "HEAD") |
|||
|
|||
res = setLength(res, isHead, streamEnded) |
|||
|
|||
if streamEnded || isHead { |
|||
res.Body = noBody |
|||
} else { |
|||
res.Body = dataStream |
|||
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { |
|||
res.Header.Del("Content-Encoding") |
|||
res.Header.Del("Content-Length") |
|||
res.ContentLength = -1 |
|||
res.Body = &gzipReader{body: res.Body} |
|||
setUncompressed(res) |
|||
} |
|||
} |
|||
|
|||
res.Request = req |
|||
|
|||
return res, nil |
|||
} |
|||
|
|||
func (c *Client) writeRequestBody(dataStream utils.Stream, body io.ReadCloser) (err error) { |
|||
defer func() { |
|||
cerr := body.Close() |
|||
if err == nil { |
|||
// TODO: what to do with dataStream here? Maybe reset it?
|
|||
err = cerr |
|||
} |
|||
}() |
|||
|
|||
_, err = io.Copy(dataStream, body) |
|||
if err != nil { |
|||
// TODO: what to do with dataStream here? Maybe reset it?
|
|||
return err |
|||
} |
|||
return dataStream.Close() |
|||
} |
|||
|
|||
// Close closes the client
|
|||
func (c *Client) Close(e error) { |
|||
_ = c.client.Close(e) |
|||
} |
|||
|
|||
// copied from net/transport.go
|
|||
|
|||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
|||
// and returns a host:port. The port 443 is added if needed.
|
|||
func authorityAddr(scheme string, authority string) (addr string) { |
|||
host, port, err := net.SplitHostPort(authority) |
|||
if err != nil { // authority didn't have a port
|
|||
port = "443" |
|||
if scheme == "http" { |
|||
port = "80" |
|||
} |
|||
host = authority |
|||
} |
|||
if a, err := idna.ToASCII(host); err == nil { |
|||
host = a |
|||
} |
|||
// IPv6 address literal, without a port:
|
|||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { |
|||
return host + ":" + port |
|||
} |
|||
return net.JoinHostPort(host, port) |
|||
} |
|||
@ -0,0 +1,35 @@ |
|||
package h2quic |
|||
|
|||
// copied from net/transport.go
|
|||
|
|||
// gzipReader wraps a response body so it can lazily
|
|||
// call gzip.NewReader on the first call to Read
|
|||
import ( |
|||
"compress/gzip" |
|||
"io" |
|||
) |
|||
|
|||
// call gzip.NewReader on the first call to Read
|
|||
type gzipReader struct { |
|||
body io.ReadCloser // underlying Response.Body
|
|||
zr *gzip.Reader // lazily-initialized gzip reader
|
|||
zerr error // sticky error
|
|||
} |
|||
|
|||
func (gz *gzipReader) Read(p []byte) (n int, err error) { |
|||
if gz.zerr != nil { |
|||
return 0, gz.zerr |
|||
} |
|||
if gz.zr == nil { |
|||
gz.zr, err = gzip.NewReader(gz.body) |
|||
if err != nil { |
|||
gz.zerr = err |
|||
return 0, err |
|||
} |
|||
} |
|||
return gz.zr.Read(p) |
|||
} |
|||
|
|||
func (gz *gzipReader) Close() error { |
|||
return gz.body.Close() |
|||
} |
|||
@ -0,0 +1,29 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"io" |
|||
|
|||
"github.com/lucas-clemente/quic-go/utils" |
|||
) |
|||
|
|||
type requestBody struct { |
|||
requestRead bool |
|||
dataStream utils.Stream |
|||
} |
|||
|
|||
// make sure the requestBody can be used as a http.Request.Body
|
|||
var _ io.ReadCloser = &requestBody{} |
|||
|
|||
func newRequestBody(stream utils.Stream) *requestBody { |
|||
return &requestBody{dataStream: stream} |
|||
} |
|||
|
|||
func (b *requestBody) Read(p []byte) (int, error) { |
|||
b.requestRead = true |
|||
return b.dataStream.Read(p) |
|||
} |
|||
|
|||
func (b *requestBody) Close() error { |
|||
// stream's Close() closes the write side, not the read side
|
|||
return nil |
|||
} |
|||
@ -0,0 +1,200 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"net/http" |
|||
"strconv" |
|||
"strings" |
|||
"sync" |
|||
|
|||
"golang.org/x/net/http2" |
|||
"golang.org/x/net/http2/hpack" |
|||
"golang.org/x/net/lex/httplex" |
|||
|
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/utils" |
|||
) |
|||
|
|||
type requestWriter struct { |
|||
mutex sync.Mutex |
|||
headerStream utils.Stream |
|||
|
|||
henc *hpack.Encoder |
|||
hbuf bytes.Buffer // HPACK encoder writes into this
|
|||
} |
|||
|
|||
const defaultUserAgent = "quic-go" |
|||
|
|||
func newRequestWriter(headerStream utils.Stream) *requestWriter { |
|||
rw := &requestWriter{ |
|||
headerStream: headerStream, |
|||
} |
|||
rw.henc = hpack.NewEncoder(&rw.hbuf) |
|||
return rw |
|||
} |
|||
|
|||
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error { |
|||
// TODO: add support for trailers
|
|||
// TODO: add support for gzip compression
|
|||
// TODO: write continuation frames, if the header frame is too long
|
|||
|
|||
w.mutex.Lock() |
|||
defer w.mutex.Unlock() |
|||
|
|||
w.encodeHeaders(req, requestGzip, "", actualContentLength(req)) |
|||
h2framer := http2.NewFramer(w.headerStream, nil) |
|||
return h2framer.WriteHeaders(http2.HeadersFrameParam{ |
|||
StreamID: uint32(dataStreamID), |
|||
EndHeaders: true, |
|||
EndStream: endStream, |
|||
BlockFragment: w.hbuf.Bytes(), |
|||
Priority: http2.PriorityParam{Weight: 0xff}, |
|||
}) |
|||
} |
|||
|
|||
// the rest of this files is copied from http2.Transport
|
|||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { |
|||
w.hbuf.Reset() |
|||
|
|||
host := req.Host |
|||
if host == "" { |
|||
host = req.URL.Host |
|||
} |
|||
host, err := httplex.PunycodeHostPort(host) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
var path string |
|||
if req.Method != "CONNECT" { |
|||
path = req.URL.RequestURI() |
|||
if !validPseudoPath(path) { |
|||
orig := path |
|||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) |
|||
if !validPseudoPath(path) { |
|||
if req.URL.Opaque != "" { |
|||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) |
|||
} else { |
|||
return nil, fmt.Errorf("invalid request :path %q", orig) |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Check for any invalid headers and return an error before we
|
|||
// potentially pollute our hpack state. (We want to be able to
|
|||
// continue to reuse the hpack encoder for future requests)
|
|||
for k, vv := range req.Header { |
|||
if !httplex.ValidHeaderFieldName(k) { |
|||
return nil, fmt.Errorf("invalid HTTP header name %q", k) |
|||
} |
|||
for _, v := range vv { |
|||
if !httplex.ValidHeaderFieldValue(v) { |
|||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// 8.1.2.3 Request Pseudo-Header Fields
|
|||
// The :path pseudo-header field includes the path and query parts of the
|
|||
// target URI (the path-absolute production and optionally a '?' character
|
|||
// followed by the query production (see Sections 3.3 and 3.4 of
|
|||
// [RFC3986]).
|
|||
w.writeHeader(":authority", host) |
|||
w.writeHeader(":method", req.Method) |
|||
if req.Method != "CONNECT" { |
|||
w.writeHeader(":path", path) |
|||
w.writeHeader(":scheme", req.URL.Scheme) |
|||
} |
|||
if trailers != "" { |
|||
w.writeHeader("trailer", trailers) |
|||
} |
|||
|
|||
var didUA bool |
|||
for k, vv := range req.Header { |
|||
lowKey := strings.ToLower(k) |
|||
switch lowKey { |
|||
case "host", "content-length": |
|||
// Host is :authority, already sent.
|
|||
// Content-Length is automatic, set below.
|
|||
continue |
|||
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": |
|||
// Per 8.1.2.2 Connection-Specific Header
|
|||
// Fields, don't send connection-specific
|
|||
// fields. We have already checked if any
|
|||
// are error-worthy so just ignore the rest.
|
|||
continue |
|||
case "user-agent": |
|||
// Match Go's http1 behavior: at most one
|
|||
// User-Agent. If set to nil or empty string,
|
|||
// then omit it. Otherwise if not mentioned,
|
|||
// include the default (below).
|
|||
didUA = true |
|||
if len(vv) < 1 { |
|||
continue |
|||
} |
|||
vv = vv[:1] |
|||
if vv[0] == "" { |
|||
continue |
|||
} |
|||
} |
|||
for _, v := range vv { |
|||
w.writeHeader(lowKey, v) |
|||
} |
|||
} |
|||
if shouldSendReqContentLength(req.Method, contentLength) { |
|||
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) |
|||
} |
|||
if addGzipHeader { |
|||
w.writeHeader("accept-encoding", "gzip") |
|||
} |
|||
if !didUA { |
|||
w.writeHeader("user-agent", defaultUserAgent) |
|||
} |
|||
return w.hbuf.Bytes(), nil |
|||
} |
|||
|
|||
func (w *requestWriter) writeHeader(name, value string) { |
|||
utils.Debugf("http2: Transport encoding header %q = %q", name, value) |
|||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) |
|||
} |
|||
|
|||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
|||
// a "content-length" request header. This logic is basically a copy of the net/http
|
|||
// transferWriter.shouldSendContentLength.
|
|||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
|||
// -1 means unknown.
|
|||
func shouldSendReqContentLength(method string, contentLength int64) bool { |
|||
if contentLength > 0 { |
|||
return true |
|||
} |
|||
if contentLength < 0 { |
|||
return false |
|||
} |
|||
// For zero bodies, whether we send a content-length depends on the method.
|
|||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
|||
switch method { |
|||
case "POST", "PUT", "PATCH": |
|||
return true |
|||
default: |
|||
return false |
|||
} |
|||
} |
|||
|
|||
func validPseudoPath(v string) bool { |
|||
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" |
|||
} |
|||
|
|||
// actualContentLength returns a sanitized version of
|
|||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
|||
// means unknown.
|
|||
func actualContentLength(req *http.Request) int64 { |
|||
if req.Body == nil { |
|||
return 0 |
|||
} |
|||
if req.ContentLength != 0 { |
|||
return req.ContentLength |
|||
} |
|||
return -1 |
|||
} |
|||
@ -0,0 +1,111 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"bytes" |
|||
"errors" |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/textproto" |
|||
"strconv" |
|||
"strings" |
|||
|
|||
"golang.org/x/net/http2" |
|||
) |
|||
|
|||
// copied from net/http2/transport.go
|
|||
|
|||
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") |
|||
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) |
|||
|
|||
// from the handleResponse function
|
|||
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { |
|||
if f.Truncated { |
|||
return nil, errResponseHeaderListSize |
|||
} |
|||
|
|||
status := f.PseudoValue("status") |
|||
if status == "" { |
|||
return nil, errors.New("missing status pseudo header") |
|||
} |
|||
statusCode, err := strconv.Atoi(status) |
|||
if err != nil { |
|||
return nil, errors.New("malformed non-numeric status pseudo header") |
|||
} |
|||
|
|||
if statusCode == 100 { |
|||
// TODO: handle this
|
|||
|
|||
// traceGot100Continue(cs.trace)
|
|||
// if cs.on100 != nil {
|
|||
// cs.on100() // forces any write delay timer to fire
|
|||
// }
|
|||
// cs.pastHeaders = false // do it all again
|
|||
// return nil, nil
|
|||
} |
|||
|
|||
header := make(http.Header) |
|||
res := &http.Response{ |
|||
Proto: "HTTP/2.0", |
|||
ProtoMajor: 2, |
|||
Header: header, |
|||
StatusCode: statusCode, |
|||
Status: status + " " + http.StatusText(statusCode), |
|||
} |
|||
for _, hf := range f.RegularFields() { |
|||
key := http.CanonicalHeaderKey(hf.Name) |
|||
if key == "Trailer" { |
|||
t := res.Trailer |
|||
if t == nil { |
|||
t = make(http.Header) |
|||
res.Trailer = t |
|||
} |
|||
foreachHeaderElement(hf.Value, func(v string) { |
|||
t[http.CanonicalHeaderKey(v)] = nil |
|||
}) |
|||
} else { |
|||
header[key] = append(header[key], hf.Value) |
|||
} |
|||
} |
|||
|
|||
return res, nil |
|||
} |
|||
|
|||
// continuation of the handleResponse function
|
|||
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { |
|||
if !streamEnded || isHead { |
|||
res.ContentLength = -1 |
|||
if clens := res.Header["Content-Length"]; len(clens) == 1 { |
|||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { |
|||
res.ContentLength = clen64 |
|||
} else { |
|||
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
|||
// more safe smuggling-wise to ignore.
|
|||
} |
|||
} else if len(clens) > 1 { |
|||
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
|||
// more safe smuggling-wise to ignore.
|
|||
} |
|||
} |
|||
return res |
|||
} |
|||
|
|||
// copied from net/http/server.go
|
|||
|
|||
// foreachHeaderElement splits v according to the "#rule" construction
|
|||
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
|
|||
func foreachHeaderElement(v string, fn func(string)) { |
|||
v = textproto.TrimString(v) |
|||
if v == "" { |
|||
return |
|||
} |
|||
if !strings.Contains(v, ",") { |
|||
fn(v) |
|||
return |
|||
} |
|||
for _, f := range strings.Split(v, ",") { |
|||
if f = textproto.TrimString(f); f != "" { |
|||
fn(f) |
|||
} |
|||
} |
|||
} |
|||
@ -0,0 +1,9 @@ |
|||
// +build go1.7
|
|||
|
|||
package h2quic |
|||
|
|||
import "net/http" |
|||
|
|||
func setUncompressed(res *http.Response) { |
|||
res.Uncompressed = true |
|||
} |
|||
@ -0,0 +1,9 @@ |
|||
// +build !go1.7
|
|||
|
|||
package h2quic |
|||
|
|||
import "net/http" |
|||
|
|||
func setUncompressed(res *http.Response) { |
|||
// http.Response.Uncompressed was introduced in go 1.7
|
|||
} |
|||
@ -0,0 +1,135 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
"net/http" |
|||
"strings" |
|||
"sync" |
|||
|
|||
"golang.org/x/net/lex/httplex" |
|||
) |
|||
|
|||
type h2quicClient interface { |
|||
Do(*http.Request) (*http.Response, error) |
|||
} |
|||
|
|||
// QuicRoundTripper implements the http.RoundTripper interface
|
|||
type QuicRoundTripper struct { |
|||
mutex sync.Mutex |
|||
|
|||
// DisableCompression, if true, prevents the Transport from
|
|||
// requesting compression with an "Accept-Encoding: gzip"
|
|||
// request header when the Request contains no existing
|
|||
// Accept-Encoding value. If the Transport requests gzip on
|
|||
// its own and gets a gzipped response, it's transparently
|
|||
// decoded in the Response.Body. However, if the user
|
|||
// explicitly requested gzip it is not automatically
|
|||
// uncompressed.
|
|||
DisableCompression bool |
|||
|
|||
// TLSClientConfig specifies the TLS configuration to use with
|
|||
// tls.Client. If nil, the default configuration is used.
|
|||
TLSClientConfig *tls.Config |
|||
|
|||
clients map[string]h2quicClient |
|||
} |
|||
|
|||
var _ http.RoundTripper = &QuicRoundTripper{} |
|||
|
|||
// RoundTrip does a round trip
|
|||
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
|||
if req.URL == nil { |
|||
closeRequestBody(req) |
|||
return nil, errors.New("quic: nil Request.URL") |
|||
} |
|||
if req.URL.Host == "" { |
|||
closeRequestBody(req) |
|||
return nil, errors.New("quic: no Host in request URL") |
|||
} |
|||
if req.Header == nil { |
|||
closeRequestBody(req) |
|||
return nil, errors.New("quic: nil Request.Header") |
|||
} |
|||
|
|||
if req.URL.Scheme == "https" { |
|||
for k, vv := range req.Header { |
|||
if !httplex.ValidHeaderFieldName(k) { |
|||
return nil, fmt.Errorf("quic: invalid http header field name %q", k) |
|||
} |
|||
for _, v := range vv { |
|||
if !httplex.ValidHeaderFieldValue(v) { |
|||
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) |
|||
} |
|||
} |
|||
} |
|||
} else { |
|||
closeRequestBody(req) |
|||
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) |
|||
} |
|||
|
|||
if req.Method != "" && !validMethod(req.Method) { |
|||
closeRequestBody(req) |
|||
return nil, fmt.Errorf("quic: invalid method %q", req.Method) |
|||
} |
|||
|
|||
hostname := authorityAddr("https", hostnameFromRequest(req)) |
|||
client, err := r.getClient(hostname) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return client.Do(req) |
|||
} |
|||
|
|||
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { |
|||
r.mutex.Lock() |
|||
defer r.mutex.Unlock() |
|||
|
|||
if r.clients == nil { |
|||
r.clients = make(map[string]h2quicClient) |
|||
} |
|||
|
|||
client, ok := r.clients[hostname] |
|||
if !ok { |
|||
var err error |
|||
client, err = NewClient(r, r.TLSClientConfig, hostname) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
r.clients[hostname] = client |
|||
} |
|||
return client, nil |
|||
} |
|||
|
|||
func (r *QuicRoundTripper) disableCompression() bool { |
|||
return r.DisableCompression |
|||
} |
|||
|
|||
func closeRequestBody(req *http.Request) { |
|||
if req.Body != nil { |
|||
req.Body.Close() |
|||
} |
|||
} |
|||
|
|||
func validMethod(method string) bool { |
|||
/* |
|||
Method = "OPTIONS" ; Section 9.2 |
|||
| "GET" ; Section 9.3 |
|||
| "HEAD" ; Section 9.4 |
|||
| "POST" ; Section 9.5 |
|||
| "PUT" ; Section 9.6 |
|||
| "DELETE" ; Section 9.7 |
|||
| "TRACE" ; Section 9.8 |
|||
| "CONNECT" ; Section 9.9 |
|||
| extension-method |
|||
extension-method = token |
|||
token = 1*<any CHAR except CTLs or separators> |
|||
*/ |
|||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 |
|||
} |
|||
|
|||
// copied from net/http/http.go
|
|||
func isNotToken(r rune) bool { |
|||
return !httplex.IsTokenRune(r) |
|||
} |
|||
@ -0,0 +1,485 @@ |
|||
package handshake |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto/rand" |
|||
"crypto/tls" |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/lucas-clemente/quic-go/crypto" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
"github.com/lucas-clemente/quic-go/utils" |
|||
) |
|||
|
|||
type cryptoSetupClient struct { |
|||
mutex sync.RWMutex |
|||
|
|||
hostname string |
|||
connID protocol.ConnectionID |
|||
version protocol.VersionNumber |
|||
negotiatedVersions []protocol.VersionNumber |
|||
|
|||
cryptoStream utils.Stream |
|||
|
|||
serverConfig *serverConfigClient |
|||
|
|||
stk []byte |
|||
sno []byte |
|||
nonc []byte |
|||
proof []byte |
|||
diversificationNonce []byte |
|||
chloForSignature []byte |
|||
lastSentCHLO []byte |
|||
certManager crypto.CertManager |
|||
|
|||
clientHelloCounter int |
|||
serverVerified bool // has the certificate chain and the proof already been verified
|
|||
keyDerivation KeyDerivationFunction |
|||
|
|||
receivedSecurePacket bool |
|||
secureAEAD crypto.AEAD |
|||
forwardSecureAEAD crypto.AEAD |
|||
aeadChanged chan struct{} |
|||
|
|||
connectionParameters ConnectionParametersManager |
|||
} |
|||
|
|||
var _ crypto.AEAD = &cryptoSetupClient{} |
|||
var _ CryptoSetup = &cryptoSetupClient{} |
|||
|
|||
var ( |
|||
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available") |
|||
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated") |
|||
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces") |
|||
) |
|||
|
|||
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
|||
func NewCryptoSetupClient( |
|||
hostname string, |
|||
connID protocol.ConnectionID, |
|||
version protocol.VersionNumber, |
|||
cryptoStream utils.Stream, |
|||
tlsConfig *tls.Config, |
|||
connectionParameters ConnectionParametersManager, |
|||
aeadChanged chan struct{}, |
|||
negotiatedVersions []protocol.VersionNumber, |
|||
) (CryptoSetup, error) { |
|||
return &cryptoSetupClient{ |
|||
hostname: hostname, |
|||
connID: connID, |
|||
version: version, |
|||
cryptoStream: cryptoStream, |
|||
certManager: crypto.NewCertManager(tlsConfig), |
|||
connectionParameters: connectionParameters, |
|||
keyDerivation: crypto.DeriveKeysAESGCM, |
|||
aeadChanged: aeadChanged, |
|||
negotiatedVersions: negotiatedVersions, |
|||
}, nil |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) HandleCryptoStream() error { |
|||
for { |
|||
err := h.maybeUpgradeCrypto() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
// send CHLOs until the forward secure encryption is established
|
|||
if h.forwardSecureAEAD == nil { |
|||
err = h.sendCHLO() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
var shloData bytes.Buffer |
|||
|
|||
messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &shloData)) |
|||
if err != nil { |
|||
return qerr.HandshakeFailed |
|||
} |
|||
|
|||
if messageTag != TagSHLO && messageTag != TagREJ { |
|||
return qerr.InvalidCryptoMessageType |
|||
} |
|||
|
|||
if messageTag == TagSHLO { |
|||
utils.Debugf("Got SHLO:\n%s", printHandshakeMessage(cryptoData)) |
|||
err = h.handleSHLOMessage(cryptoData) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
if messageTag == TagREJ { |
|||
err = h.handleREJMessage(cryptoData) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { |
|||
utils.Debugf("Got REJ:\n%s", printHandshakeMessage(cryptoData)) |
|||
|
|||
var err error |
|||
|
|||
if stk, ok := cryptoData[TagSTK]; ok { |
|||
h.stk = stk |
|||
} |
|||
|
|||
if sno, ok := cryptoData[TagSNO]; ok { |
|||
h.sno = sno |
|||
} |
|||
|
|||
// TODO: what happens if the server sends a different server config in two packets?
|
|||
if scfg, ok := cryptoData[TagSCFG]; ok { |
|||
h.serverConfig, err = parseServerConfig(scfg) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if h.serverConfig.IsExpired() { |
|||
return qerr.CryptoServerConfigExpired |
|||
} |
|||
|
|||
// now that we have a server config, we can use its OBIT value to generate a client nonce
|
|||
if len(h.nonc) == 0 { |
|||
err = h.generateClientNonce() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
|
|||
if proof, ok := cryptoData[TagPROF]; ok { |
|||
h.proof = proof |
|||
h.chloForSignature = h.lastSentCHLO |
|||
} |
|||
|
|||
if crt, ok := cryptoData[TagCERT]; ok { |
|||
err := h.certManager.SetData(crt) |
|||
if err != nil { |
|||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid") |
|||
} |
|||
|
|||
err = h.certManager.Verify(h.hostname) |
|||
if err != nil { |
|||
utils.Infof("Certificate validation failed: %s", err.Error()) |
|||
return qerr.ProofInvalid |
|||
} |
|||
} |
|||
|
|||
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil { |
|||
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get()) |
|||
if !validProof { |
|||
utils.Infof("Server proof verification failed") |
|||
return qerr.ProofInvalid |
|||
} |
|||
|
|||
h.serverVerified = true |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { |
|||
h.mutex.Lock() |
|||
defer h.mutex.Unlock() |
|||
|
|||
if !h.receivedSecurePacket { |
|||
return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") |
|||
} |
|||
|
|||
if sno, ok := cryptoData[TagSNO]; ok { |
|||
h.sno = sno |
|||
} |
|||
|
|||
serverPubs, ok := cryptoData[TagPUBS] |
|||
if !ok { |
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") |
|||
} |
|||
|
|||
verTag, ok := cryptoData[TagVER] |
|||
if !ok { |
|||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") |
|||
} |
|||
if !h.validateVersionList(verTag) { |
|||
return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") |
|||
} |
|||
|
|||
nonce := append(h.nonc, h.sno...) |
|||
|
|||
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
leafCert := h.certManager.GetLeafCert() |
|||
|
|||
h.forwardSecureAEAD, err = h.keyDerivation( |
|||
true, |
|||
ephermalSharedSecret, |
|||
nonce, |
|||
h.connID, |
|||
h.lastSentCHLO, |
|||
h.serverConfig.Get(), |
|||
leafCert, |
|||
nil, |
|||
protocol.PerspectiveClient, |
|||
) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
err = h.connectionParameters.SetFromMap(cryptoData) |
|||
if err != nil { |
|||
return qerr.InvalidCryptoMessageParameter |
|||
} |
|||
|
|||
h.aeadChanged <- struct{}{} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { |
|||
if len(h.negotiatedVersions) == 0 { |
|||
return true |
|||
} |
|||
if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) { |
|||
return false |
|||
} |
|||
|
|||
b := bytes.NewReader(verTags) |
|||
for _, negotiatedVersion := range h.negotiatedVersions { |
|||
verTag, err := utils.ReadUint32(b) |
|||
if err != nil { // should never occur, since the length was already checked
|
|||
return false |
|||
} |
|||
ver := protocol.VersionTagToNumber(verTag) |
|||
if !protocol.IsSupportedVersion(ver) { |
|||
ver = protocol.VersionUnsupported |
|||
} |
|||
if ver != negotiatedVersion { |
|||
return false |
|||
} |
|||
} |
|||
return true |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { |
|||
if h.forwardSecureAEAD != nil { |
|||
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) |
|||
if err == nil { |
|||
return data, nil |
|||
} |
|||
return nil, err |
|||
} |
|||
|
|||
if h.secureAEAD != nil { |
|||
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) |
|||
if err == nil { |
|||
h.receivedSecurePacket = true |
|||
return data, nil |
|||
} |
|||
if h.receivedSecurePacket { |
|||
return nil, err |
|||
} |
|||
} |
|||
|
|||
return (&crypto.NullAEAD{}).Open(dst, src, packetNumber, associatedData) |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { |
|||
if h.forwardSecureAEAD != nil { |
|||
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) |
|||
} |
|||
if h.secureAEAD != nil { |
|||
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) |
|||
} |
|||
return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData) |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) DiversificationNonce() []byte { |
|||
panic("not needed for cryptoSetupClient") |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error { |
|||
if len(h.diversificationNonce) == 0 { |
|||
h.diversificationNonce = data |
|||
return h.maybeUpgradeCrypto() |
|||
} |
|||
if !bytes.Equal(h.diversificationNonce, data) { |
|||
return errConflictingDiversificationNonces |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) LockForSealing() { |
|||
|
|||
} |
|||
|
|||
func (h *cryptoSetupClient) UnlockForSealing() { |
|||
|
|||
} |
|||
|
|||
func (h *cryptoSetupClient) HandshakeComplete() bool { |
|||
h.mutex.RLock() |
|||
complete := h.forwardSecureAEAD != nil |
|||
h.mutex.RUnlock() |
|||
return complete |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) sendCHLO() error { |
|||
h.clientHelloCounter++ |
|||
if h.clientHelloCounter > protocol.MaxClientHellos { |
|||
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos)) |
|||
} |
|||
|
|||
b := &bytes.Buffer{} |
|||
|
|||
tags, err := h.getTags() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
h.addPadding(tags) |
|||
|
|||
utils.Debugf("Sending CHLO:\n%s", printHandshakeMessage(tags)) |
|||
WriteHandshakeMessage(b, TagCHLO, tags) |
|||
|
|||
_, err = h.cryptoStream.Write(b.Bytes()) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
h.lastSentCHLO = b.Bytes() |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { |
|||
tags, err := h.connectionParameters.GetHelloMap() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
tags[TagSNI] = []byte(h.hostname) |
|||
tags[TagPDMD] = []byte("X509") |
|||
|
|||
ccs := h.certManager.GetCommonCertificateHashes() |
|||
if len(ccs) > 0 { |
|||
tags[TagCCS] = ccs |
|||
} |
|||
|
|||
versionTag := make([]byte, 4, 4) |
|||
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) |
|||
tags[TagVER] = versionTag |
|||
|
|||
if len(h.stk) > 0 { |
|||
tags[TagSTK] = h.stk |
|||
} |
|||
|
|||
if len(h.sno) > 0 { |
|||
tags[TagSNO] = h.sno |
|||
} |
|||
|
|||
if h.serverConfig != nil { |
|||
tags[TagSCID] = h.serverConfig.ID |
|||
|
|||
leafCert := h.certManager.GetLeafCert() |
|||
if leafCert != nil { |
|||
certHash, _ := h.certManager.GetLeafCertHash() |
|||
xlct := make([]byte, 8, 8) |
|||
binary.LittleEndian.PutUint64(xlct, certHash) |
|||
|
|||
tags[TagNONC] = h.nonc |
|||
tags[TagXLCT] = xlct |
|||
tags[TagKEXS] = []byte("C255") |
|||
tags[TagAEAD] = []byte("AESG") |
|||
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
|
|||
} |
|||
} |
|||
|
|||
return tags, nil |
|||
} |
|||
|
|||
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
|
|||
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) { |
|||
var size int |
|||
for _, tag := range tags { |
|||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
|||
} |
|||
paddingSize := protocol.ClientHelloMinimumSize - size |
|||
if paddingSize > 0 { |
|||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize) |
|||
} |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) maybeUpgradeCrypto() error { |
|||
if !h.serverVerified { |
|||
return nil |
|||
} |
|||
|
|||
h.mutex.Lock() |
|||
defer h.mutex.Unlock() |
|||
|
|||
leafCert := h.certManager.GetLeafCert() |
|||
|
|||
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) { |
|||
var err error |
|||
var nonce []byte |
|||
if h.sno == nil { |
|||
nonce = h.nonc |
|||
} else { |
|||
nonce = append(h.nonc, h.sno...) |
|||
} |
|||
|
|||
h.secureAEAD, err = h.keyDerivation( |
|||
false, |
|||
h.serverConfig.sharedSecret, |
|||
nonce, |
|||
h.connID, |
|||
h.lastSentCHLO, |
|||
h.serverConfig.Get(), |
|||
leafCert, |
|||
h.diversificationNonce, |
|||
protocol.PerspectiveClient, |
|||
) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
h.aeadChanged <- struct{}{} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (h *cryptoSetupClient) generateClientNonce() error { |
|||
if len(h.nonc) > 0 { |
|||
return errClientNonceAlreadyExists |
|||
} |
|||
|
|||
nonc := make([]byte, 32) |
|||
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix())) |
|||
|
|||
if len(h.serverConfig.obit) != 8 { |
|||
return errNoObitForClientNonce |
|||
} |
|||
|
|||
copy(nonc[4:12], h.serverConfig.obit) |
|||
|
|||
_, err := rand.Read(nonc[12:]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
h.nonc = nonc |
|||
return nil |
|||
} |
|||
@ -0,0 +1,16 @@ |
|||
package handshake |
|||
|
|||
import "github.com/lucas-clemente/quic-go/protocol" |
|||
|
|||
// CryptoSetup is a crypto setup
|
|||
type CryptoSetup interface { |
|||
HandleCryptoStream() error |
|||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) |
|||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte |
|||
LockForSealing() |
|||
UnlockForSealing() |
|||
HandshakeComplete() bool |
|||
// TODO: clean up this interface
|
|||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
|||
SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient
|
|||
} |
|||
@ -0,0 +1,148 @@ |
|||
package handshake |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/binary" |
|||
"errors" |
|||
"math" |
|||
"time" |
|||
|
|||
"github.com/lucas-clemente/quic-go/crypto" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
"github.com/lucas-clemente/quic-go/utils" |
|||
) |
|||
|
|||
type serverConfigClient struct { |
|||
raw []byte |
|||
ID []byte |
|||
obit []byte |
|||
expiry time.Time |
|||
|
|||
kex crypto.KeyExchange |
|||
sharedSecret []byte |
|||
} |
|||
|
|||
var ( |
|||
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG") |
|||
) |
|||
|
|||
// parseServerConfig parses a server config
|
|||
func parseServerConfig(data []byte) (*serverConfigClient, error) { |
|||
tag, tagMap, err := ParseHandshakeMessage(bytes.NewReader(data)) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if tag != TagSCFG { |
|||
return nil, errMessageNotServerConfig |
|||
} |
|||
|
|||
scfg := &serverConfigClient{raw: data} |
|||
err = scfg.parseValues(tagMap) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return scfg, nil |
|||
} |
|||
|
|||
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error { |
|||
// SCID
|
|||
scfgID, ok := tagMap[TagSCID] |
|||
if !ok { |
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID") |
|||
} |
|||
if len(scfgID) != 16 { |
|||
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID") |
|||
} |
|||
s.ID = scfgID |
|||
|
|||
// KEXS
|
|||
// TODO: allow for P256 in the list
|
|||
// TODO: setup Key Exchange
|
|||
kexs, ok := tagMap[TagKEXS] |
|||
if !ok { |
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS") |
|||
} |
|||
if len(kexs)%4 != 0 { |
|||
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS") |
|||
} |
|||
if !bytes.Equal(kexs, []byte("C255")) { |
|||
return qerr.Error(qerr.CryptoNoSupport, "KEXS") |
|||
} |
|||
|
|||
// AEAD
|
|||
aead, ok := tagMap[TagAEAD] |
|||
if !ok { |
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD") |
|||
} |
|||
if len(aead)%4 != 0 { |
|||
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD") |
|||
} |
|||
var aesgFound bool |
|||
for i := 0; i < len(aead)/4; i++ { |
|||
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) { |
|||
aesgFound = true |
|||
break |
|||
} |
|||
} |
|||
if !aesgFound { |
|||
return qerr.Error(qerr.CryptoNoSupport, "AEAD") |
|||
} |
|||
|
|||
// PUBS
|
|||
// TODO: save this value
|
|||
pubs, ok := tagMap[TagPUBS] |
|||
if !ok { |
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") |
|||
} |
|||
if len(pubs) != 35 { |
|||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") |
|||
} |
|||
|
|||
var err error |
|||
s.kex, err = crypto.NewCurve25519KEX() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
// the PUBS value is always prepended by []byte{0x20, 0x00, 0x00}
|
|||
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs[3:]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
// OBIT
|
|||
obit, ok := tagMap[TagOBIT] |
|||
if !ok { |
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT") |
|||
} |
|||
if len(obit) != 8 { |
|||
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT") |
|||
} |
|||
s.obit = obit |
|||
|
|||
// EXPY
|
|||
expy, ok := tagMap[TagEXPY] |
|||
if !ok { |
|||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY") |
|||
} |
|||
if len(expy) != 8 { |
|||
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY") |
|||
} |
|||
// make sure that the value doesn't overflow an int64
|
|||
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
|
|||
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2) |
|||
s.expiry = time.Unix(int64(expyTimestamp), 0) |
|||
|
|||
// TODO: implement VER
|
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (s *serverConfigClient) IsExpired() bool { |
|||
return s.expiry.Before(time.Now()) |
|||
} |
|||
|
|||
func (s *serverConfigClient) Get() []byte { |
|||
return s.raw |
|||
} |
|||
@ -0,0 +1,14 @@ |
|||
package protocol |
|||
|
|||
// EncryptionLevel is the encryption level
|
|||
// Default value is Unencrypted
|
|||
type EncryptionLevel int |
|||
|
|||
const ( |
|||
// Unencrypted is not encrypted
|
|||
Unencrypted EncryptionLevel = iota |
|||
// EncryptionSecure is encrypted, but not forward secure
|
|||
EncryptionSecure |
|||
// EncryptionForwardSecure is forward secure
|
|||
EncryptionForwardSecure |
|||
) |
|||
@ -0,0 +1,10 @@ |
|||
package protocol |
|||
|
|||
// Perspective determines if we're acting as a server or a client
|
|||
type Perspective int |
|||
|
|||
// the perspectives
|
|||
const ( |
|||
PerspectiveServer Perspective = 1 |
|||
PerspectiveClient Perspective = 2 |
|||
) |
|||
@ -0,0 +1,27 @@ |
|||
package quic |
|||
|
|||
import "github.com/lucas-clemente/quic-go/frames" |
|||
|
|||
type unpackedPacket struct { |
|||
frames []frames.Frame |
|||
} |
|||
|
|||
func (u *unpackedPacket) IsRetransmittable() bool { |
|||
for _, f := range u.frames { |
|||
switch f.(type) { |
|||
case *frames.StreamFrame: |
|||
return true |
|||
case *frames.RstStreamFrame: |
|||
return true |
|||
case *frames.WindowUpdateFrame: |
|||
return true |
|||
case *frames.BlockedFrame: |
|||
return true |
|||
case *frames.PingFrame: |
|||
return true |
|||
case *frames.GoawayFrame: |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
@ -0,0 +1,22 @@ |
|||
package utils |
|||
|
|||
import "sync/atomic" |
|||
|
|||
// An AtomicBool is an atomic bool
|
|||
type AtomicBool struct { |
|||
v int32 |
|||
} |
|||
|
|||
// Set sets the value
|
|||
func (a *AtomicBool) Set(value bool) { |
|||
var n int32 |
|||
if value { |
|||
n = 1 |
|||
} |
|||
atomic.StoreInt32(&a.v, n) |
|||
} |
|||
|
|||
// Get gets the value
|
|||
func (a *AtomicBool) Get() bool { |
|||
return atomic.LoadInt32(&a.v) != 0 |
|||
} |
|||
@ -0,0 +1,18 @@ |
|||
package utils |
|||
|
|||
import ( |
|||
"crypto/rand" |
|||
"encoding/binary" |
|||
|
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
// GenerateConnectionID generates a connection ID using cryptographic random
|
|||
func GenerateConnectionID() (protocol.ConnectionID, error) { |
|||
b := make([]byte, 8, 8) |
|||
_, err := rand.Read(b) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
return protocol.ConnectionID(binary.LittleEndian.Uint64(b)), nil |
|||
} |
|||
@ -0,0 +1,27 @@ |
|||
package utils |
|||
|
|||
import ( |
|||
"net/url" |
|||
"strings" |
|||
) |
|||
|
|||
// HostnameFromAddr determines the hostname in an address string
|
|||
func HostnameFromAddr(addr string) (string, error) { |
|||
p, err := url.Parse(addr) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
h := p.Host |
|||
|
|||
// copied from https://golang.org/src/net/http/transport.go
|
|||
if hasPort(h) { |
|||
h = h[:strings.LastIndex(h, ":")] |
|||
} |
|||
|
|||
return h, nil |
|||
} |
|||
|
|||
// copied from https://golang.org/src/net/http/http.go
|
|||
func hasPort(s string) bool { |
|||
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") |
|||
} |
|||
@ -0,0 +1,17 @@ |
|||
package utils |
|||
|
|||
import ( |
|||
"io" |
|||
|
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
// Stream is the interface for QUIC streams
|
|||
type Stream interface { |
|||
io.Reader |
|||
io.Writer |
|||
io.Closer |
|||
StreamID() protocol.StreamID |
|||
CloseRemote(offset protocol.ByteCount) |
|||
Reset(error) |
|||
} |
|||
Loading…
Reference in new issue