mirror of https://github.com/ginuerzh/gost
172 changed files with 13597 additions and 3895 deletions
@ -0,0 +1,21 @@ |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2016 Andreas Auernhammer |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy |
|||
of this software and associated documentation files (the "Software"), to deal |
|||
in the Software without restriction, including without limitation the rights |
|||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|||
copies of the Software, and to permit persons to whom the Software is |
|||
furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in all |
|||
copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|||
SOFTWARE. |
|||
@ -0,0 +1,79 @@ |
|||
[](https://godoc.org/github.com/aead/chacha20) |
|||
|
|||
## The ChaCha20 stream cipher |
|||
|
|||
ChaCha is a stream cipher family created by Daniel J. Bernstein. |
|||
The most common ChaCha cipher is ChaCha20 (20 rounds). ChaCha20 is standardized in [RFC 7539](https://tools.ietf.org/html/rfc7539 "RFC 7539"). |
|||
|
|||
This package provides implementations of three ChaCha versions: |
|||
- ChaCha20 with a 64 bit nonce (can en/decrypt up to 2^64 * 64 bytes for one key-nonce combination) |
|||
- ChaCha20 with a 96 bit nonce (can en/decrypt up to 2^32 * 64 bytes ~ 256 GB for one key-nonce combination) |
|||
- XChaCha20 with a 192 bit nonce (can en/decrypt up to 2^64 * 64 bytes for one key-nonce combination) |
|||
|
|||
Furthermore the chacha subpackage implements ChaCha20/12 and ChaCha20/8. |
|||
These versions use 12 or 8 rounds instead of 20. |
|||
But it's recommended to use ChaCha20 (with 20 rounds) - it will be fast enough for almost all purposes. |
|||
|
|||
### Installation |
|||
Install in your GOPATH: `go get -u github.com/aead/chacha20` |
|||
|
|||
### Requirements |
|||
All go versions >= 1.5.3 are supported. |
|||
Please notice, that the amd64 AVX2 asm implementation requires go1.7 or newer. |
|||
|
|||
### Performance |
|||
|
|||
#### AMD64 |
|||
Hardware: Intel i7-6500U 2.50GHz x 2 |
|||
System: Linux Ubuntu 16.04 - kernel: 4.4.0-62-generic |
|||
Go version: 1.8.0 |
|||
``` |
|||
AVX2 |
|||
name speed cpb |
|||
ChaCha20_64-4 573MB/s ± 0% 4.16 |
|||
ChaCha20_1K-4 2.19GB/s ± 0% 1.06 |
|||
XChaCha20_64-4 261MB/s ± 0% 9.13 |
|||
XChaCha20_1K-4 1.69GB/s ± 4% 1.37 |
|||
XORKeyStream64-4 474MB/s ± 2% 5.02 |
|||
XORKeyStream1K-4 2.09GB/s ± 1% 1.11 |
|||
XChaCha20_XORKeyStream64-4 262MB/s ± 0% 9.09 |
|||
XChaCha20_XORKeyStream1K-4 1.71GB/s ± 1% 1.36 |
|||
|
|||
SSSE3 |
|||
name speed cpb |
|||
ChaCha20_64-4 583MB/s ± 0% 4.08 |
|||
ChaCha20_1K-4 1.15GB/s ± 1% 2.02 |
|||
XChaCha20_64-4 267MB/s ± 0% 8.92 |
|||
XChaCha20_1K-4 984MB/s ± 5% 2.42 |
|||
XORKeyStream64-4 492MB/s ± 1% 4.84 |
|||
XORKeyStream1K-4 1.10GB/s ± 5% 2.11 |
|||
XChaCha20_XORKeyStream64-4 266MB/s ± 0% 8.96 |
|||
XChaCha20_XORKeyStream1K-4 1.00GB/s ± 2% 2.32 |
|||
``` |
|||
#### 386 |
|||
Hardware: Intel i7-6500U 2.50GHz x 2 |
|||
System: Linux Ubuntu 16.04 - kernel: 4.4.0-62-generic |
|||
Go version: 1.8.0 |
|||
``` |
|||
SSSE3 |
|||
name speed cpb |
|||
ChaCha20_64-4 570MB/s ± 0% 4.18 |
|||
ChaCha20_1K-4 650MB/s ± 0% 3.66 |
|||
XChaCha20_64-4 223MB/s ± 0% 10.69 |
|||
XChaCha20_1K-4 584MB/s ± 1% 4.08 |
|||
XORKeyStream64-4 392MB/s ± 1% 6.08 |
|||
XORKeyStream1K-4 629MB/s ± 1% 3.79 |
|||
XChaCha20_XORKeyStream64-4 222MB/s ± 0% 10.73 |
|||
XChaCha20_XORKeyStream1K-4 585MB/s ± 0% 4.07 |
|||
|
|||
SSE2 |
|||
name speed cpb |
|||
ChaCha20_64-4 509MB/s ± 0% 4.68 |
|||
ChaCha20_1K-4 553MB/s ± 2% 4.31 |
|||
XChaCha20_64-4 201MB/s ± 0% 11.86 |
|||
XChaCha20_1K-4 498MB/s ± 4% 4.78 |
|||
XORKeyStream64-4 359MB/s ± 1% 6.64 |
|||
XORKeyStream1K-4 545MB/s ± 0% 4.37 |
|||
XChaCha20_XORKeyStream64-4 201MB/s ± 1% 11.86 |
|||
XChaCha20_XORKeyStream1K-4 507MB/s ± 0% 4.70 |
|||
``` |
|||
@ -0,0 +1,176 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
|
|||
// Use of this source code is governed by a license that can be
|
|||
// found in the LICENSE file.
|
|||
|
|||
// Package chacha implements some low-level functions of the
|
|||
// ChaCha cipher family.
|
|||
package chacha // import "github.com/aead/chacha20/chacha"
|
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"errors" |
|||
) |
|||
|
|||
const ( |
|||
// NonceSize is the size of the ChaCha20 nonce in bytes.
|
|||
NonceSize = 8 |
|||
|
|||
// INonceSize is the size of the IETF-ChaCha20 nonce in bytes.
|
|||
INonceSize = 12 |
|||
|
|||
// XNonceSize is the size of the XChaCha20 nonce in bytes.
|
|||
XNonceSize = 24 |
|||
|
|||
// KeySize is the size of the key in bytes.
|
|||
KeySize = 32 |
|||
) |
|||
|
|||
var ( |
|||
useSSE2 bool |
|||
useSSSE3 bool |
|||
useAVX2 bool |
|||
) |
|||
|
|||
var ( |
|||
errKeySize = errors.New("chacha20/chacha: bad key length") |
|||
errInvalidNonce = errors.New("chacha20/chacha: bad nonce length") |
|||
) |
|||
|
|||
func setup(state *[64]byte, nonce, key []byte) (err error) { |
|||
if len(key) != KeySize { |
|||
err = errKeySize |
|||
return |
|||
} |
|||
var Nonce [16]byte |
|||
switch len(nonce) { |
|||
case NonceSize: |
|||
copy(Nonce[8:], nonce) |
|||
initialize(state, key, &Nonce) |
|||
case INonceSize: |
|||
copy(Nonce[4:], nonce) |
|||
initialize(state, key, &Nonce) |
|||
case XNonceSize: |
|||
var tmpKey [32]byte |
|||
var hNonce [16]byte |
|||
|
|||
copy(hNonce[:], nonce[:16]) |
|||
copy(tmpKey[:], key) |
|||
hChaCha20(&tmpKey, &hNonce, &tmpKey) |
|||
copy(Nonce[8:], nonce[16:]) |
|||
initialize(state, tmpKey[:], &Nonce) |
|||
|
|||
// BUG(aead): A "good" compiler will remove this (optimizations)
|
|||
// But using the provided key instead of tmpKey,
|
|||
// will change the key (-> probably confuses users)
|
|||
for i := range tmpKey { |
|||
tmpKey[i] = 0 |
|||
} |
|||
default: |
|||
err = errInvalidNonce |
|||
} |
|||
return |
|||
} |
|||
|
|||
// XORKeyStream crypts bytes from src to dst using the given nonce and key.
|
|||
// The length of the nonce determinds the version of ChaCha20:
|
|||
// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period.
|
|||
// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period.
|
|||
// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period.
|
|||
// The rounds argument specifies the number of rounds performed for keystream
|
|||
// generation - valid values are 8, 12 or 20. The src and dst may be the same slice
|
|||
// but otherwise should not overlap. If len(dst) < len(src) this function panics.
|
|||
// If the nonce is neither 64, 96 nor 192 bits long, this function panics.
|
|||
func XORKeyStream(dst, src, nonce, key []byte, rounds int) { |
|||
if rounds != 20 && rounds != 12 && rounds != 8 { |
|||
panic("chacha20/chacha: bad number of rounds") |
|||
} |
|||
if len(dst) < len(src) { |
|||
panic("chacha20/chacha: dst buffer is to small") |
|||
} |
|||
if len(nonce) == INonceSize && uint64(len(src)) > (1<<38) { |
|||
panic("chacha20/chacha: src is too large") |
|||
} |
|||
|
|||
var block, state [64]byte |
|||
if err := setup(&state, nonce, key); err != nil { |
|||
panic(err) |
|||
} |
|||
xorKeyStream(dst, src, &block, &state, rounds) |
|||
} |
|||
|
|||
// Cipher implements ChaCha20/r (XChaCha20/r) for a given number of rounds r.
|
|||
type Cipher struct { |
|||
state, block [64]byte |
|||
off int |
|||
rounds int // 20 for ChaCha20
|
|||
noncesize int |
|||
} |
|||
|
|||
// NewCipher returns a new *chacha.Cipher implementing the ChaCha20/r or XChaCha20/r
|
|||
// (r = 8, 12 or 20) stream cipher. The nonce must be unique for one key for all time.
|
|||
// The length of the nonce determinds the version of ChaCha20:
|
|||
// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period.
|
|||
// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period.
|
|||
// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period.
|
|||
// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned.
|
|||
func NewCipher(nonce, key []byte, rounds int) (*Cipher, error) { |
|||
if rounds != 20 && rounds != 12 && rounds != 8 { |
|||
panic("chacha20/chacha: bad number of rounds") |
|||
} |
|||
|
|||
c := new(Cipher) |
|||
if err := setup(&(c.state), nonce, key); err != nil { |
|||
return nil, err |
|||
} |
|||
c.rounds = rounds |
|||
|
|||
if len(nonce) == INonceSize { |
|||
c.noncesize = INonceSize |
|||
} else { |
|||
c.noncesize = NonceSize |
|||
} |
|||
|
|||
return c, nil |
|||
} |
|||
|
|||
// XORKeyStream crypts bytes from src to dst. Src and dst may be the same slice
|
|||
// but otherwise should not overlap. If len(dst) < len(src) the function panics.
|
|||
func (c *Cipher) XORKeyStream(dst, src []byte) { |
|||
if len(dst) < len(src) { |
|||
panic("chacha20/chacha: dst buffer is to small") |
|||
} |
|||
|
|||
if c.off > 0 { |
|||
n := len(c.block[c.off:]) |
|||
if len(src) <= n { |
|||
for i, v := range src { |
|||
dst[i] = v ^ c.block[c.off] |
|||
c.off++ |
|||
} |
|||
if c.off == 64 { |
|||
c.off = 0 |
|||
} |
|||
return |
|||
} |
|||
|
|||
for i, v := range c.block[c.off:] { |
|||
dst[i] = src[i] ^ v |
|||
} |
|||
src = src[n:] |
|||
dst = dst[n:] |
|||
c.off = 0 |
|||
} |
|||
|
|||
c.off += xorKeyStream(dst, src, &(c.block), &(c.state), c.rounds) |
|||
} |
|||
|
|||
// SetCounter skips ctr * 64 byte blocks. SetCounter(0) resets the cipher.
|
|||
// This function always skips the unused keystream of the current 64 byte block.
|
|||
func (c *Cipher) SetCounter(ctr uint64) { |
|||
if c.noncesize == INonceSize { |
|||
binary.LittleEndian.PutUint32(c.state[48:], uint32(ctr)) |
|||
} else { |
|||
binary.LittleEndian.PutUint64(c.state[48:], ctr) |
|||
} |
|||
c.off = 0 |
|||
} |
|||
@ -0,0 +1,542 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. |
|||
// Use of this source code is governed by a license that can be |
|||
// found in the LICENSE file. |
|||
|
|||
// +build go1.7,amd64,!gccgo,!appengine,!nacl |
|||
|
|||
#include "textflag.h" |
|||
|
|||
DATA ·sigma_AVX<>+0x00(SB)/4, $0x61707865 |
|||
DATA ·sigma_AVX<>+0x04(SB)/4, $0x3320646e |
|||
DATA ·sigma_AVX<>+0x08(SB)/4, $0x79622d32 |
|||
DATA ·sigma_AVX<>+0x0C(SB)/4, $0x6b206574 |
|||
GLOBL ·sigma_AVX<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·one_AVX<>+0x00(SB)/8, $1 |
|||
DATA ·one_AVX<>+0x08(SB)/8, $0 |
|||
GLOBL ·one_AVX<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·one_AVX2<>+0x00(SB)/8, $0 |
|||
DATA ·one_AVX2<>+0x08(SB)/8, $0 |
|||
DATA ·one_AVX2<>+0x10(SB)/8, $1 |
|||
DATA ·one_AVX2<>+0x18(SB)/8, $0 |
|||
GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 |
|||
|
|||
DATA ·two_AVX2<>+0x00(SB)/8, $2 |
|||
DATA ·two_AVX2<>+0x08(SB)/8, $0 |
|||
DATA ·two_AVX2<>+0x10(SB)/8, $2 |
|||
DATA ·two_AVX2<>+0x18(SB)/8, $0 |
|||
GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32 |
|||
|
|||
DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302 |
|||
DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A |
|||
DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302 |
|||
DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A |
|||
GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 |
|||
|
|||
DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003 |
|||
DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B |
|||
DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003 |
|||
DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B |
|||
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 |
|||
|
|||
#define ROTL(n, t, v) \ |
|||
VPSLLD $n, v, t; \ |
|||
VPSRLD $(32-n), v, v; \ |
|||
VPXOR v, t, v |
|||
|
|||
#define CHACHA_QROUND(v0, v1, v2, v3, t, c16, c8) \ |
|||
VPADDD v0, v1, v0; \ |
|||
VPXOR v3, v0, v3; \ |
|||
VPSHUFB c16, v3, v3; \ |
|||
VPADDD v2, v3, v2; \ |
|||
VPXOR v1, v2, v1; \ |
|||
ROTL(12, t, v1); \ |
|||
VPADDD v0, v1, v0; \ |
|||
VPXOR v3, v0, v3; \ |
|||
VPSHUFB c8, v3, v3; \ |
|||
VPADDD v2, v3, v2; \ |
|||
VPXOR v1, v2, v1; \ |
|||
ROTL(7, t, v1) |
|||
|
|||
#define CHACHA_SHUFFLE(v1, v2, v3) \ |
|||
VPSHUFD $0x39, v1, v1; \ |
|||
VPSHUFD $0x4E, v2, v2; \ |
|||
VPSHUFD $-109, v3, v3 |
|||
|
|||
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ |
|||
VMOVDQU (0+off)(src), t0; \ |
|||
VPERM2I128 $32, v1, v0, t1; \ |
|||
VPXOR t0, t1, t0; \ |
|||
VMOVDQU t0, (0+off)(dst); \ |
|||
VMOVDQU (32+off)(src), t0; \ |
|||
VPERM2I128 $32, v3, v2, t1; \ |
|||
VPXOR t0, t1, t0; \ |
|||
VMOVDQU t0, (32+off)(dst); \ |
|||
VMOVDQU (64+off)(src), t0; \ |
|||
VPERM2I128 $49, v1, v0, t1; \ |
|||
VPXOR t0, t1, t0; \ |
|||
VMOVDQU t0, (64+off)(dst); \ |
|||
VMOVDQU (96+off)(src), t0; \ |
|||
VPERM2I128 $49, v3, v2, t1; \ |
|||
VPXOR t0, t1, t0; \ |
|||
VMOVDQU t0, (96+off)(dst) |
|||
|
|||
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ |
|||
VMOVDQU (0+off)(src), t0; \ |
|||
VPERM2I128 $32, v1, v0, t1; \ |
|||
VPXOR t0, t1, t0; \ |
|||
VMOVDQU t0, (0+off)(dst); \ |
|||
VMOVDQU (32+off)(src), t0; \ |
|||
VPERM2I128 $32, v3, v2, t1; \ |
|||
VPXOR t0, t1, t0; \ |
|||
VMOVDQU t0, (32+off)(dst); \ |
|||
|
|||
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \ |
|||
VPERM2I128 $49, v1, v0, t0; \ |
|||
VMOVDQU t0, 0(dst); \ |
|||
VPERM2I128 $49, v3, v2, t0; \ |
|||
VMOVDQU t0, 32(dst) |
|||
|
|||
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t0) \ |
|||
VPXOR 0+off(src), v0, t0; \ |
|||
VMOVDQU t0, 0+off(dst); \ |
|||
VPXOR 16+off(src), v1, t0; \ |
|||
VMOVDQU t0, 16+off(dst); \ |
|||
VPXOR 32+off(src), v2, t0; \ |
|||
VMOVDQU t0, 32+off(dst); \ |
|||
VPXOR 48+off(src), v3, t0; \ |
|||
VMOVDQU t0, 48+off(dst) |
|||
|
|||
#define TWO 0(SP) |
|||
#define C16 32(SP) |
|||
#define C8 64(SP) |
|||
#define STATE_0 96(SP) |
|||
#define STATE_1 128(SP) |
|||
#define STATE_2 160(SP) |
|||
#define STATE_3 192(SP) |
|||
#define TMP_0 224(SP) |
|||
#define TMP_1 256(SP) |
|||
|
|||
// func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
TEXT ·xorKeyStreamAVX2(SB), 4, $320-80 |
|||
MOVQ dst_base+0(FP), DI |
|||
MOVQ src_base+24(FP), SI |
|||
MOVQ src_len+32(FP), CX |
|||
MOVQ block+48(FP), BX |
|||
MOVQ state+56(FP), AX |
|||
MOVQ rounds+64(FP), DX |
|||
|
|||
MOVQ SP, R8 |
|||
ADDQ $32, SP |
|||
ANDQ $-32, SP |
|||
|
|||
VMOVDQU 0(AX), Y2 |
|||
VMOVDQU 32(AX), Y3 |
|||
VPERM2I128 $0x22, Y2, Y0, Y0 |
|||
VPERM2I128 $0x33, Y2, Y1, Y1 |
|||
VPERM2I128 $0x22, Y3, Y2, Y2 |
|||
VPERM2I128 $0x33, Y3, Y3, Y3 |
|||
|
|||
TESTQ CX, CX |
|||
JZ done |
|||
|
|||
VMOVDQU ·one_AVX2<>(SB), Y4 |
|||
VPADDD Y4, Y3, Y3 |
|||
|
|||
VMOVDQA Y0, STATE_0 |
|||
VMOVDQA Y1, STATE_1 |
|||
VMOVDQA Y2, STATE_2 |
|||
VMOVDQA Y3, STATE_3 |
|||
|
|||
VMOVDQU ·rol16_AVX2<>(SB), Y4 |
|||
VMOVDQU ·rol8_AVX2<>(SB), Y5 |
|||
VMOVDQU ·two_AVX2<>(SB), Y6 |
|||
VMOVDQA Y4, Y14 |
|||
VMOVDQA Y5, Y15 |
|||
VMOVDQA Y4, C16 |
|||
VMOVDQA Y5, C8 |
|||
VMOVDQA Y6, TWO |
|||
|
|||
CMPQ CX, $64 |
|||
JBE between_0_and_64 |
|||
CMPQ CX, $192 |
|||
JBE between_64_and_192 |
|||
CMPQ CX, $320 |
|||
JBE between_192_and_320 |
|||
CMPQ CX, $448 |
|||
JBE between_320_and_448 |
|||
|
|||
at_least_512: |
|||
VMOVDQA Y0, Y4 |
|||
VMOVDQA Y1, Y5 |
|||
VMOVDQA Y2, Y6 |
|||
VPADDQ TWO, Y3, Y7 |
|||
VMOVDQA Y0, Y8 |
|||
VMOVDQA Y1, Y9 |
|||
VMOVDQA Y2, Y10 |
|||
VPADDQ TWO, Y7, Y11 |
|||
VMOVDQA Y0, Y12 |
|||
VMOVDQA Y1, Y13 |
|||
VMOVDQA Y2, Y14 |
|||
VPADDQ TWO, Y11, Y15 |
|||
|
|||
MOVQ DX, R9 |
|||
|
|||
chacha_loop_512: |
|||
VMOVDQA Y8, TMP_0 |
|||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8) |
|||
VMOVDQA TMP_0, Y8 |
|||
VMOVDQA Y0, TMP_0 |
|||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8) |
|||
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8) |
|||
CHACHA_SHUFFLE(Y1, Y2, Y3) |
|||
CHACHA_SHUFFLE(Y5, Y6, Y7) |
|||
CHACHA_SHUFFLE(Y9, Y10, Y11) |
|||
CHACHA_SHUFFLE(Y13, Y14, Y15) |
|||
|
|||
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8) |
|||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8) |
|||
VMOVDQA TMP_0, Y0 |
|||
VMOVDQA Y8, TMP_0 |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8) |
|||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) |
|||
VMOVDQA TMP_0, Y8 |
|||
CHACHA_SHUFFLE(Y3, Y2, Y1) |
|||
CHACHA_SHUFFLE(Y7, Y6, Y5) |
|||
CHACHA_SHUFFLE(Y11, Y10, Y9) |
|||
CHACHA_SHUFFLE(Y15, Y14, Y13) |
|||
SUBQ $2, R9 |
|||
JA chacha_loop_512 |
|||
|
|||
VMOVDQA Y12, TMP_0 |
|||
VMOVDQA Y13, TMP_1 |
|||
VPADDD STATE_0, Y0, Y0 |
|||
VPADDD STATE_1, Y1, Y1 |
|||
VPADDD STATE_2, Y2, Y2 |
|||
VPADDD STATE_3, Y3, Y3 |
|||
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13) |
|||
VMOVDQA STATE_0, Y0 |
|||
VMOVDQA STATE_1, Y1 |
|||
VMOVDQA STATE_2, Y2 |
|||
VMOVDQA STATE_3, Y3 |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
VPADDD Y0, Y4, Y4 |
|||
VPADDD Y1, Y5, Y5 |
|||
VPADDD Y2, Y6, Y6 |
|||
VPADDD Y3, Y7, Y7 |
|||
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13) |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
VPADDD Y0, Y8, Y8 |
|||
VPADDD Y1, Y9, Y9 |
|||
VPADDD Y2, Y10, Y10 |
|||
VPADDD Y3, Y11, Y11 |
|||
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
VPADDD TMP_0, Y0, Y12 |
|||
VPADDD TMP_1, Y1, Y13 |
|||
VPADDD Y2, Y14, Y14 |
|||
VPADDD Y3, Y15, Y15 |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
CMPQ CX, $512 |
|||
JB less_than_512 |
|||
|
|||
XOR_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5) |
|||
VMOVDQA Y3, STATE_3 |
|||
ADDQ $512, SI |
|||
ADDQ $512, DI |
|||
SUBQ $512, CX |
|||
CMPQ CX, $448 |
|||
JA at_least_512 |
|||
|
|||
TESTQ CX, CX |
|||
JZ done |
|||
|
|||
VMOVDQA C16, Y14 |
|||
VMOVDQA C8, Y15 |
|||
|
|||
CMPQ CX, $64 |
|||
JBE between_0_and_64 |
|||
CMPQ CX, $192 |
|||
JBE between_64_and_192 |
|||
CMPQ CX, $320 |
|||
JBE between_192_and_320 |
|||
JMP between_320_and_448 |
|||
|
|||
less_than_512: |
|||
XOR_UPPER_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5) |
|||
EXTRACT_LOWER(BX, Y12, Y13, Y14, Y15, Y4) |
|||
ADDQ $448, SI |
|||
ADDQ $448, DI |
|||
SUBQ $448, CX |
|||
JMP finalize |
|||
|
|||
between_320_and_448: |
|||
VMOVDQA Y0, Y4 |
|||
VMOVDQA Y1, Y5 |
|||
VMOVDQA Y2, Y6 |
|||
VPADDQ TWO, Y3, Y7 |
|||
VMOVDQA Y0, Y8 |
|||
VMOVDQA Y1, Y9 |
|||
VMOVDQA Y2, Y10 |
|||
VPADDQ TWO, Y7, Y11 |
|||
|
|||
MOVQ DX, R9 |
|||
|
|||
chacha_loop_384: |
|||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) |
|||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) |
|||
CHACHA_SHUFFLE(Y1, Y2, Y3) |
|||
CHACHA_SHUFFLE(Y5, Y6, Y7) |
|||
CHACHA_SHUFFLE(Y9, Y10, Y11) |
|||
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) |
|||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) |
|||
CHACHA_SHUFFLE(Y3, Y2, Y1) |
|||
CHACHA_SHUFFLE(Y7, Y6, Y5) |
|||
CHACHA_SHUFFLE(Y11, Y10, Y9) |
|||
SUBQ $2, R9 |
|||
JA chacha_loop_384 |
|||
|
|||
VPADDD STATE_0, Y0, Y0 |
|||
VPADDD STATE_1, Y1, Y1 |
|||
VPADDD STATE_2, Y2, Y2 |
|||
VPADDD STATE_3, Y3, Y3 |
|||
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13) |
|||
VMOVDQA STATE_0, Y0 |
|||
VMOVDQA STATE_1, Y1 |
|||
VMOVDQA STATE_2, Y2 |
|||
VMOVDQA STATE_3, Y3 |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
VPADDD Y0, Y4, Y4 |
|||
VPADDD Y1, Y5, Y5 |
|||
VPADDD Y2, Y6, Y6 |
|||
VPADDD Y3, Y7, Y7 |
|||
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13) |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
VPADDD Y0, Y8, Y8 |
|||
VPADDD Y1, Y9, Y9 |
|||
VPADDD Y2, Y10, Y10 |
|||
VPADDD Y3, Y11, Y11 |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
CMPQ CX, $384 |
|||
JB less_than_384 |
|||
|
|||
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) |
|||
SUBQ $384, CX |
|||
TESTQ CX, CX |
|||
JE done |
|||
|
|||
ADDQ $384, SI |
|||
ADDQ $384, DI |
|||
JMP between_0_and_64 |
|||
|
|||
less_than_384: |
|||
XOR_UPPER_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) |
|||
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12) |
|||
ADDQ $320, SI |
|||
ADDQ $320, DI |
|||
SUBQ $320, CX |
|||
JMP finalize |
|||
|
|||
between_192_and_320: |
|||
VMOVDQA Y0, Y4 |
|||
VMOVDQA Y1, Y5 |
|||
VMOVDQA Y2, Y6 |
|||
VMOVDQA Y3, Y7 |
|||
VMOVDQA Y0, Y8 |
|||
VMOVDQA Y1, Y9 |
|||
VMOVDQA Y2, Y10 |
|||
VPADDQ TWO, Y3, Y11 |
|||
|
|||
MOVQ DX, R9 |
|||
|
|||
chacha_loop_256: |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) |
|||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) |
|||
CHACHA_SHUFFLE(Y5, Y6, Y7) |
|||
CHACHA_SHUFFLE(Y9, Y10, Y11) |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) |
|||
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) |
|||
CHACHA_SHUFFLE(Y7, Y6, Y5) |
|||
CHACHA_SHUFFLE(Y11, Y10, Y9) |
|||
SUBQ $2, R9 |
|||
JA chacha_loop_256 |
|||
|
|||
VPADDD Y0, Y4, Y4 |
|||
VPADDD Y1, Y5, Y5 |
|||
VPADDD Y2, Y6, Y6 |
|||
VPADDD Y3, Y7, Y7 |
|||
VPADDQ TWO, Y3, Y3 |
|||
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) |
|||
VPADDD Y0, Y8, Y8 |
|||
VPADDD Y1, Y9, Y9 |
|||
VPADDD Y2, Y10, Y10 |
|||
VPADDD Y3, Y11, Y11 |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
CMPQ CX, $256 |
|||
JB less_than_256 |
|||
|
|||
XOR_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13) |
|||
SUBQ $256, CX |
|||
TESTQ CX, CX |
|||
JE done |
|||
|
|||
ADDQ $256, SI |
|||
ADDQ $256, DI |
|||
JMP between_0_and_64 |
|||
|
|||
less_than_256: |
|||
XOR_UPPER_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13) |
|||
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12) |
|||
ADDQ $192, SI |
|||
ADDQ $192, DI |
|||
SUBQ $192, CX |
|||
JMP finalize |
|||
|
|||
between_64_and_192: |
|||
VMOVDQA Y0, Y4 |
|||
VMOVDQA Y1, Y5 |
|||
VMOVDQA Y2, Y6 |
|||
VMOVDQA Y3, Y7 |
|||
|
|||
MOVQ DX, R9 |
|||
|
|||
chacha_loop_128: |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) |
|||
CHACHA_SHUFFLE(Y5, Y6, Y7) |
|||
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) |
|||
CHACHA_SHUFFLE(Y7, Y6, Y5) |
|||
SUBQ $2, R9 |
|||
JA chacha_loop_128 |
|||
|
|||
VPADDD Y0, Y4, Y4 |
|||
VPADDD Y1, Y5, Y5 |
|||
VPADDD Y2, Y6, Y6 |
|||
VPADDD Y3, Y7, Y7 |
|||
VPADDQ TWO, Y3, Y3 |
|||
|
|||
CMPQ CX, $128 |
|||
JB less_than_128 |
|||
|
|||
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) |
|||
SUBQ $128, CX |
|||
TESTQ CX, CX |
|||
JE done |
|||
|
|||
ADDQ $128, SI |
|||
ADDQ $128, DI |
|||
JMP between_0_and_64 |
|||
|
|||
less_than_128: |
|||
XOR_UPPER_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) |
|||
EXTRACT_LOWER(BX, Y4, Y5, Y6, Y7, Y13) |
|||
ADDQ $64, SI |
|||
ADDQ $64, DI |
|||
SUBQ $64, CX |
|||
JMP finalize |
|||
|
|||
between_0_and_64: |
|||
VMOVDQA X0, X4 |
|||
VMOVDQA X1, X5 |
|||
VMOVDQA X2, X6 |
|||
VMOVDQA X3, X7 |
|||
|
|||
MOVQ DX, R9 |
|||
|
|||
chacha_loop_64: |
|||
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
SUBQ $2, R9 |
|||
JA chacha_loop_64 |
|||
|
|||
VPADDD X0, X4, X4 |
|||
VPADDD X1, X5, X5 |
|||
VPADDD X2, X6, X6 |
|||
VPADDD X3, X7, X7 |
|||
VMOVDQU ·one_AVX<>(SB), X0 |
|||
VPADDQ X0, X3, X3 |
|||
|
|||
CMPQ CX, $64 |
|||
JB less_than_64 |
|||
|
|||
XOR_AVX(DI, SI, 0, X4, X5, X6, X7, X13) |
|||
SUBQ $64, CX |
|||
JMP done |
|||
|
|||
less_than_64: |
|||
VMOVDQU X4, 0(BX) |
|||
VMOVDQU X5, 16(BX) |
|||
VMOVDQU X6, 32(BX) |
|||
VMOVDQU X7, 48(BX) |
|||
|
|||
finalize: |
|||
XORQ R11, R11 |
|||
XORQ R12, R12 |
|||
MOVQ CX, BP |
|||
|
|||
xor_loop: |
|||
MOVB 0(SI), R11 |
|||
MOVB 0(BX), R12 |
|||
XORQ R11, R12 |
|||
MOVB R12, 0(DI) |
|||
INCQ SI |
|||
INCQ BX |
|||
INCQ DI |
|||
DECQ BP |
|||
JA xor_loop |
|||
|
|||
done: |
|||
VMOVDQU X3, 48(AX) |
|||
VZEROUPPER |
|||
MOVQ R8, SP |
|||
MOVQ CX, ret+72(FP) |
|||
RET |
|||
|
|||
// func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
TEXT ·hChaCha20AVX(SB), 4, $0-24 |
|||
MOVQ out+0(FP), DI |
|||
MOVQ nonce+8(FP), AX |
|||
MOVQ key+16(FP), BX |
|||
|
|||
VMOVDQU ·sigma_AVX<>(SB), X0 |
|||
VMOVDQU 0(BX), X1 |
|||
VMOVDQU 16(BX), X2 |
|||
VMOVDQU 0(AX), X3 |
|||
VMOVDQU ·rol16_AVX2<>(SB), X5 |
|||
VMOVDQU ·rol8_AVX2<>(SB), X6 |
|||
|
|||
MOVQ $20, CX |
|||
|
|||
chacha_loop: |
|||
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) |
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) |
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
SUBQ $2, CX |
|||
JNZ chacha_loop |
|||
|
|||
VMOVDQU X0, 0(DI) |
|||
VMOVDQU X3, 16(DI) |
|||
VZEROUPPER |
|||
RET |
|||
|
|||
// func supportsAVX2() bool |
|||
TEXT ·supportsAVX2(SB), 4, $0-1 |
|||
MOVQ runtime·support_avx(SB), AX |
|||
MOVQ runtime·support_avx2(SB), BX |
|||
ANDQ AX, BX |
|||
MOVB BX, ret+0(FP) |
|||
RET |
|||
@ -0,0 +1,67 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
|
|||
// Use of this source code is governed by a license that can be
|
|||
// found in the LICENSE file.
|
|||
|
|||
// +build 386,!gccgo,!appengine,!nacl
|
|||
|
|||
package chacha |
|||
|
|||
import "encoding/binary" |
|||
|
|||
func init() { |
|||
useSSE2 = supportsSSE2() |
|||
useSSSE3 = supportsSSSE3() |
|||
useAVX2 = false |
|||
} |
|||
|
|||
func initialize(state *[64]byte, key []byte, nonce *[16]byte) { |
|||
binary.LittleEndian.PutUint32(state[0:], sigma[0]) |
|||
binary.LittleEndian.PutUint32(state[4:], sigma[1]) |
|||
binary.LittleEndian.PutUint32(state[8:], sigma[2]) |
|||
binary.LittleEndian.PutUint32(state[12:], sigma[3]) |
|||
copy(state[16:], key[:]) |
|||
copy(state[48:], nonce[:]) |
|||
} |
|||
|
|||
// This function is implemented in chacha_386.s
|
|||
//go:noescape
|
|||
func supportsSSE2() bool |
|||
|
|||
// This function is implemented in chacha_386.s
|
|||
//go:noescape
|
|||
func supportsSSSE3() bool |
|||
|
|||
// This function is implemented in chacha_386.s
|
|||
//go:noescape
|
|||
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
|
|||
// This function is implemented in chacha_386.s
|
|||
//go:noescape
|
|||
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
|
|||
// This function is implemented in chacha_386.s
|
|||
//go:noescape
|
|||
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
|
|||
// This function is implemented in chacha_386.s
|
|||
//go:noescape
|
|||
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
|
|||
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { |
|||
if useSSSE3 { |
|||
hChaCha20SSSE3(out, nonce, key) |
|||
} else if useSSE2 { |
|||
hChaCha20SSE2(out, nonce, key) |
|||
} else { |
|||
hChaCha20Generic(out, nonce, key) |
|||
} |
|||
} |
|||
|
|||
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { |
|||
if useSSSE3 { |
|||
return xorKeyStreamSSSE3(dst, src, block, state, rounds) |
|||
} else if useSSE2 { |
|||
return xorKeyStreamSSE2(dst, src, block, state, rounds) |
|||
} |
|||
return xorKeyStreamGeneric(dst, src, block, state, rounds) |
|||
} |
|||
@ -0,0 +1,311 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. |
|||
// Use of this source code is governed by a license that can be |
|||
// found in the LICENSE file. |
|||
|
|||
// +build 386,!gccgo,!appengine,!nacl |
|||
|
|||
#include "textflag.h" |
|||
|
|||
DATA ·sigma<>+0x00(SB)/4, $0x61707865 |
|||
DATA ·sigma<>+0x04(SB)/4, $0x3320646e |
|||
DATA ·sigma<>+0x08(SB)/4, $0x79622d32 |
|||
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574 |
|||
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·one<>+0x00(SB)/8, $1 |
|||
DATA ·one<>+0x08(SB)/8, $0 |
|||
GLOBL ·one<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 |
|||
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A |
|||
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 |
|||
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B |
|||
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
#define ROTL_SSE2(n, t, v) \ |
|||
MOVO v, t; \ |
|||
PSLLL $n, t; \ |
|||
PSRLL $(32-n), v; \ |
|||
PXOR t, v |
|||
|
|||
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
ROTL_SSE2(16, t0, v3); \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(12, t0, v1); \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
ROTL_SSE2(8, t0, v3); \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(7, t0, v1) |
|||
|
|||
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
PSHUFB r16, v3; \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(12, t0, v1); \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
PSHUFB r8, v3; \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(7, t0, v1) |
|||
|
|||
#define CHACHA_SHUFFLE(v1, v2, v3) \ |
|||
PSHUFL $0x39, v1, v1; \ |
|||
PSHUFL $0x4E, v2, v2; \ |
|||
PSHUFL $0x93, v3, v3 |
|||
|
|||
#define XOR(dst, src, off, v0, v1, v2, v3, t0) \ |
|||
MOVOU 0+off(src), t0; \ |
|||
PXOR v0, t0; \ |
|||
MOVOU t0, 0+off(dst); \ |
|||
MOVOU 16+off(src), t0; \ |
|||
PXOR v1, t0; \ |
|||
MOVOU t0, 16+off(dst); \ |
|||
MOVOU 32+off(src), t0; \ |
|||
PXOR v2, t0; \ |
|||
MOVOU t0, 32+off(dst); \ |
|||
MOVOU 48+off(src), t0; \ |
|||
PXOR v3, t0; \ |
|||
MOVOU t0, 48+off(dst) |
|||
|
|||
#define FINALIZE(dst, src, block, len, t0, t1) \ |
|||
XORL t0, t0; \ |
|||
XORL t1, t1; \ |
|||
finalize: \ |
|||
MOVB 0(src), t0; \ |
|||
MOVB 0(block), t1; \ |
|||
XORL t0, t1; \ |
|||
MOVB t1, 0(dst); \ |
|||
INCL src; \ |
|||
INCL block; \ |
|||
INCL dst; \ |
|||
DECL len; \ |
|||
JA finalize \ |
|||
|
|||
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
TEXT ·xorKeyStreamSSE2(SB), 4, $0-40 |
|||
MOVL dst_base+0(FP), DI |
|||
MOVL src_base+12(FP), SI |
|||
MOVL src_len+16(FP), CX |
|||
MOVL state+28(FP), AX |
|||
MOVL rounds+32(FP), DX |
|||
|
|||
MOVOU 0(AX), X0 |
|||
MOVOU 16(AX), X1 |
|||
MOVOU 32(AX), X2 |
|||
MOVOU 48(AX), X3 |
|||
|
|||
TESTL CX, CX |
|||
JZ done |
|||
|
|||
at_least_64: |
|||
MOVO X0, X4 |
|||
MOVO X1, X5 |
|||
MOVO X2, X6 |
|||
MOVO X3, X7 |
|||
|
|||
MOVL DX, BX |
|||
|
|||
chacha_loop: |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
SUBL $2, BX |
|||
JA chacha_loop |
|||
|
|||
MOVOU 0(AX), X0 |
|||
PADDL X0, X4 |
|||
PADDL X1, X5 |
|||
PADDL X2, X6 |
|||
PADDL X3, X7 |
|||
MOVOU ·one<>(SB), X0 |
|||
PADDQ X0, X3 |
|||
|
|||
CMPL CX, $64 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 0, X4, X5, X6, X7, X0) |
|||
MOVOU 0(AX), X0 |
|||
ADDL $64, SI |
|||
ADDL $64, DI |
|||
SUBL $64, CX |
|||
JNZ at_least_64 |
|||
|
|||
less_than_64: |
|||
MOVL CX, BP |
|||
TESTL BP, BP |
|||
JZ done |
|||
|
|||
MOVL block+24(FP), BX |
|||
MOVOU X4, 0(BX) |
|||
MOVOU X5, 16(BX) |
|||
MOVOU X6, 32(BX) |
|||
MOVOU X7, 48(BX) |
|||
FINALIZE(DI, SI, BX, BP, AX, DX) |
|||
|
|||
done: |
|||
MOVL state+28(FP), AX |
|||
MOVOU X3, 48(AX) |
|||
MOVL CX, ret+36(FP) |
|||
RET |
|||
|
|||
// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
TEXT ·xorKeyStreamSSSE3(SB), 4, $64-40 |
|||
MOVL dst_base+0(FP), DI |
|||
MOVL src_base+12(FP), SI |
|||
MOVL src_len+16(FP), CX |
|||
MOVL state+28(FP), AX |
|||
MOVL rounds+32(FP), DX |
|||
|
|||
MOVOU 48(AX), X3 |
|||
TESTL CX, CX |
|||
JZ done |
|||
|
|||
MOVL SP, BP |
|||
ADDL $16, SP |
|||
ANDL $-16, SP |
|||
|
|||
MOVOU ·one<>(SB), X0 |
|||
MOVOU 16(AX), X1 |
|||
MOVOU 32(AX), X2 |
|||
MOVO X0, 0(SP) |
|||
MOVO X1, 16(SP) |
|||
MOVO X2, 32(SP) |
|||
|
|||
MOVOU 0(AX), X0 |
|||
MOVOU ·rol16<>(SB), X1 |
|||
MOVOU ·rol8<>(SB), X2 |
|||
|
|||
at_least_64: |
|||
MOVO X0, X4 |
|||
MOVO 16(SP), X5 |
|||
MOVO 32(SP), X6 |
|||
MOVO X3, X7 |
|||
|
|||
MOVL DX, BX |
|||
|
|||
chacha_loop: |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
SUBL $2, BX |
|||
JA chacha_loop |
|||
|
|||
MOVOU 0(AX), X0 |
|||
PADDL X0, X4 |
|||
PADDL 16(SP), X5 |
|||
PADDL 32(SP), X6 |
|||
PADDL X3, X7 |
|||
PADDQ 0(SP), X3 |
|||
|
|||
CMPL CX, $64 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 0, X4, X5, X6, X7, X0) |
|||
MOVOU 0(AX), X0 |
|||
ADDL $64, SI |
|||
ADDL $64, DI |
|||
SUBL $64, CX |
|||
JNZ at_least_64 |
|||
|
|||
less_than_64: |
|||
MOVL BP, SP |
|||
MOVL CX, BP |
|||
TESTL BP, BP |
|||
JE done |
|||
|
|||
MOVL block+24(FP), BX |
|||
MOVOU X4, 0(BX) |
|||
MOVOU X5, 16(BX) |
|||
MOVOU X6, 32(BX) |
|||
MOVOU X7, 48(BX) |
|||
FINALIZE(DI, SI, BX, BP, AX, DX) |
|||
|
|||
done: |
|||
MOVL state+28(FP), AX |
|||
MOVOU X3, 48(AX) |
|||
MOVL CX, ret+36(FP) |
|||
RET |
|||
|
|||
// func supportsSSE2() bool |
|||
TEXT ·supportsSSE2(SB), NOSPLIT, $0-1 |
|||
XORL AX, AX |
|||
INCL AX |
|||
CPUID |
|||
SHRL $26, DX |
|||
ANDL $1, DX |
|||
MOVB DX, ret+0(FP) |
|||
RET |
|||
|
|||
// func supportsSSSE3() bool |
|||
TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1 |
|||
XORL AX, AX |
|||
INCL AX |
|||
CPUID |
|||
SHRL $9, CX |
|||
ANDL $1, CX |
|||
MOVB CX, ret+0(FP) |
|||
RET |
|||
|
|||
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
TEXT ·hChaCha20SSE2(SB), 4, $0-12 |
|||
MOVL out+0(FP), DI |
|||
MOVL nonce+4(FP), AX |
|||
MOVL key+8(FP), BX |
|||
|
|||
MOVOU ·sigma<>(SB), X0 |
|||
MOVOU 0(BX), X1 |
|||
MOVOU 16(BX), X2 |
|||
MOVOU 0(AX), X3 |
|||
|
|||
MOVL $20, CX |
|||
|
|||
chacha_loop: |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) |
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) |
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
SUBL $2, CX |
|||
JNZ chacha_loop |
|||
|
|||
MOVOU X0, 0(DI) |
|||
MOVOU X3, 16(DI) |
|||
RET |
|||
|
|||
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
TEXT ·hChaCha20SSSE3(SB), 4, $0-12 |
|||
MOVL out+0(FP), DI |
|||
MOVL nonce+4(FP), AX |
|||
MOVL key+8(FP), BX |
|||
|
|||
MOVOU ·sigma<>(SB), X0 |
|||
MOVOU 0(BX), X1 |
|||
MOVOU 16(BX), X2 |
|||
MOVOU 0(AX), X3 |
|||
MOVOU ·rol16<>(SB), X5 |
|||
MOVOU ·rol8<>(SB), X6 |
|||
|
|||
MOVL $20, CX |
|||
|
|||
chacha_loop: |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) |
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) |
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
SUBL $2, CX |
|||
JNZ chacha_loop |
|||
|
|||
MOVOU X0, 0(DI) |
|||
MOVOU X3, 16(DI) |
|||
RET |
|||
@ -0,0 +1,788 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. |
|||
// Use of this source code is governed by a license that can be |
|||
// found in the LICENSE file. |
|||
|
|||
// +build amd64,!gccgo,!appengine,!nacl |
|||
|
|||
#include "textflag.h" |
|||
|
|||
DATA ·sigma<>+0x00(SB)/4, $0x61707865 |
|||
DATA ·sigma<>+0x04(SB)/4, $0x3320646e |
|||
DATA ·sigma<>+0x08(SB)/4, $0x79622d32 |
|||
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574 |
|||
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·one<>+0x00(SB)/8, $1 |
|||
DATA ·one<>+0x08(SB)/8, $0 |
|||
GLOBL ·one<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 |
|||
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A |
|||
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 |
|||
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B |
|||
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 |
|||
|
|||
#define ROTL_SSE2(n, t, v) \ |
|||
MOVO v, t; \ |
|||
PSLLL $n, t; \ |
|||
PSRLL $(32-n), v; \ |
|||
PXOR t, v |
|||
|
|||
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
ROTL_SSE2(16, t0, v3); \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(12, t0, v1); \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
ROTL_SSE2(8, t0, v3); \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(7, t0, v1) |
|||
|
|||
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
PSHUFB r16, v3; \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(12, t0, v1); \ |
|||
PADDL v1, v0; \ |
|||
PXOR v0, v3; \ |
|||
PSHUFB r8, v3; \ |
|||
PADDL v3, v2; \ |
|||
PXOR v2, v1; \ |
|||
ROTL_SSE2(7, t0, v1) |
|||
|
|||
#define CHACHA_SHUFFLE(v1, v2, v3) \ |
|||
PSHUFL $0x39, v1, v1; \ |
|||
PSHUFL $0x4E, v2, v2; \ |
|||
PSHUFL $0x93, v3, v3 |
|||
|
|||
#define XOR(dst, src, off, v0, v1, v2, v3, t0) \ |
|||
MOVOU 0+off(src), t0; \ |
|||
PXOR v0, t0; \ |
|||
MOVOU t0, 0+off(dst); \ |
|||
MOVOU 16+off(src), t0; \ |
|||
PXOR v1, t0; \ |
|||
MOVOU t0, 16+off(dst); \ |
|||
MOVOU 32+off(src), t0; \ |
|||
PXOR v2, t0; \ |
|||
MOVOU t0, 32+off(dst); \ |
|||
MOVOU 48+off(src), t0; \ |
|||
PXOR v3, t0; \ |
|||
MOVOU t0, 48+off(dst) |
|||
|
|||
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
TEXT ·xorKeyStreamSSE2(SB), 4, $112-80 |
|||
MOVQ dst_base+0(FP), DI |
|||
MOVQ src_base+24(FP), SI |
|||
MOVQ src_len+32(FP), CX |
|||
MOVQ block+48(FP), BX |
|||
MOVQ state+56(FP), AX |
|||
MOVQ rounds+64(FP), DX |
|||
|
|||
MOVQ SP, R9 |
|||
ADDQ $16, SP |
|||
ANDQ $-16, SP |
|||
|
|||
MOVOU 0(AX), X0 |
|||
MOVOU 16(AX), X1 |
|||
MOVOU 32(AX), X2 |
|||
MOVOU 48(AX), X3 |
|||
MOVOU ·one<>(SB), X15 |
|||
|
|||
TESTQ CX, CX |
|||
JZ done |
|||
|
|||
CMPQ CX, $64 |
|||
JBE between_0_and_64 |
|||
|
|||
CMPQ CX, $128 |
|||
JBE between_64_and_128 |
|||
|
|||
MOVO X0, 0(SP) |
|||
MOVO X1, 16(SP) |
|||
MOVO X2, 32(SP) |
|||
MOVO X3, 48(SP) |
|||
MOVO X15, 64(SP) |
|||
|
|||
CMPQ CX, $192 |
|||
JBE between_128_and_192 |
|||
|
|||
MOVQ $192, R14 |
|||
|
|||
at_least_256: |
|||
MOVO X0, X4 |
|||
MOVO X1, X5 |
|||
MOVO X2, X6 |
|||
MOVO X3, X7 |
|||
PADDQ 64(SP), X7 |
|||
MOVO X0, X12 |
|||
MOVO X1, X13 |
|||
MOVO X2, X14 |
|||
MOVO X7, X15 |
|||
PADDQ 64(SP), X15 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X15, X11 |
|||
PADDQ 64(SP), X11 |
|||
|
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_256: |
|||
MOVO X8, 80(SP) |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8) |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) |
|||
MOVO 80(SP), X8 |
|||
|
|||
MOVO X0, 80(SP) |
|||
CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) |
|||
MOVO 80(SP), X0 |
|||
|
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_SHUFFLE(X13, X14, X15) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
|
|||
MOVO X8, 80(SP) |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8) |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) |
|||
MOVO 80(SP), X8 |
|||
|
|||
MOVO X0, 80(SP) |
|||
CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) |
|||
MOVO 80(SP), X0 |
|||
|
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
CHACHA_SHUFFLE(X15, X14, X13) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_256 |
|||
|
|||
MOVO X8, 80(SP) |
|||
|
|||
PADDL 0(SP), X0 |
|||
PADDL 16(SP), X1 |
|||
PADDL 32(SP), X2 |
|||
PADDL 48(SP), X3 |
|||
XOR(DI, SI, 0, X0, X1, X2, X3, X8) |
|||
|
|||
MOVO 0(SP), X0 |
|||
MOVO 16(SP), X1 |
|||
MOVO 32(SP), X2 |
|||
MOVO 48(SP), X3 |
|||
PADDQ 64(SP), X3 |
|||
|
|||
PADDL X0, X4 |
|||
PADDL X1, X5 |
|||
PADDL X2, X6 |
|||
PADDL X3, X7 |
|||
PADDQ 64(SP), X3 |
|||
XOR(DI, SI, 64, X4, X5, X6, X7, X8) |
|||
|
|||
MOVO 64(SP), X5 |
|||
MOVO 80(SP), X8 |
|||
|
|||
PADDL X0, X12 |
|||
PADDL X1, X13 |
|||
PADDL X2, X14 |
|||
PADDL X3, X15 |
|||
PADDQ X5, X3 |
|||
XOR(DI, SI, 128, X12, X13, X14, X15, X4) |
|||
|
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X5, X3 |
|||
|
|||
CMPQ CX, $256 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 192, X8, X9, X10, X11, X4) |
|||
MOVO X3, 48(SP) |
|||
ADDQ $256, SI |
|||
ADDQ $256, DI |
|||
SUBQ $256, CX |
|||
CMPQ CX, $192 |
|||
JA at_least_256 |
|||
|
|||
TESTQ CX, CX |
|||
JZ done |
|||
MOVO 64(SP), X15 |
|||
CMPQ CX, $64 |
|||
JBE between_0_and_64 |
|||
CMPQ CX, $128 |
|||
JBE between_64_and_128 |
|||
|
|||
between_128_and_192: |
|||
MOVQ $128, R14 |
|||
MOVO X0, X4 |
|||
MOVO X1, X5 |
|||
MOVO X2, X6 |
|||
MOVO X3, X7 |
|||
PADDQ X15, X7 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X7, X11 |
|||
PADDQ X15, X11 |
|||
|
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_192: |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12) |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) |
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12) |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) |
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_192 |
|||
|
|||
PADDL 0(SP), X0 |
|||
PADDL 16(SP), X1 |
|||
PADDL 32(SP), X2 |
|||
PADDL 48(SP), X3 |
|||
XOR(DI, SI, 0, X0, X1, X2, X3, X12) |
|||
|
|||
MOVO 0(SP), X0 |
|||
MOVO 16(SP), X1 |
|||
MOVO 32(SP), X2 |
|||
MOVO 48(SP), X3 |
|||
PADDQ X15, X3 |
|||
|
|||
PADDL X0, X4 |
|||
PADDL X1, X5 |
|||
PADDL X2, X6 |
|||
PADDL X3, X7 |
|||
PADDQ X15, X3 |
|||
XOR(DI, SI, 64, X4, X5, X6, X7, X12) |
|||
|
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X15, X3 |
|||
|
|||
CMPQ CX, $192 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 128, X8, X9, X10, X11, X12) |
|||
SUBQ $192, CX |
|||
JMP done |
|||
|
|||
between_64_and_128: |
|||
MOVQ $64, R14 |
|||
MOVO X0, X4 |
|||
MOVO X1, X5 |
|||
MOVO X2, X6 |
|||
MOVO X3, X7 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X3, X11 |
|||
PADDQ X15, X11 |
|||
|
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_128: |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_128 |
|||
|
|||
PADDL X0, X4 |
|||
PADDL X1, X5 |
|||
PADDL X2, X6 |
|||
PADDL X3, X7 |
|||
PADDQ X15, X3 |
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X15, X3 |
|||
XOR(DI, SI, 0, X4, X5, X6, X7, X12) |
|||
|
|||
CMPQ CX, $128 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 64, X8, X9, X10, X11, X12) |
|||
SUBQ $128, CX |
|||
JMP done |
|||
|
|||
between_0_and_64: |
|||
MOVQ $0, R14 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X3, X11 |
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_64: |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_64 |
|||
|
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X15, X3 |
|||
CMPQ CX, $64 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 0, X8, X9, X10, X11, X12) |
|||
SUBQ $64, CX |
|||
JMP done |
|||
|
|||
less_than_64: |
|||
// R14 contains the num of bytes already xor'd |
|||
ADDQ R14, SI |
|||
ADDQ R14, DI |
|||
SUBQ R14, CX |
|||
MOVOU X8, 0(BX) |
|||
MOVOU X9, 16(BX) |
|||
MOVOU X10, 32(BX) |
|||
MOVOU X11, 48(BX) |
|||
XORQ R11, R11 |
|||
XORQ R12, R12 |
|||
MOVQ CX, BP |
|||
|
|||
xor_loop: |
|||
MOVB 0(SI), R11 |
|||
MOVB 0(BX), R12 |
|||
XORQ R11, R12 |
|||
MOVB R12, 0(DI) |
|||
INCQ SI |
|||
INCQ BX |
|||
INCQ DI |
|||
DECQ BP |
|||
JA xor_loop |
|||
|
|||
done: |
|||
MOVOU X3, 48(AX) |
|||
MOVQ R9, SP |
|||
MOVQ CX, ret+72(FP) |
|||
RET |
|||
|
|||
// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
TEXT ·xorKeyStreamSSSE3(SB), 4, $144-80 |
|||
MOVQ dst_base+0(FP), DI |
|||
MOVQ src_base+24(FP), SI |
|||
MOVQ src_len+32(FP), CX |
|||
MOVQ block+48(FP), BX |
|||
MOVQ state+56(FP), AX |
|||
MOVQ rounds+64(FP), DX |
|||
|
|||
MOVQ SP, R9 |
|||
ADDQ $16, SP |
|||
ANDQ $-16, SP |
|||
|
|||
MOVOU 0(AX), X0 |
|||
MOVOU 16(AX), X1 |
|||
MOVOU 32(AX), X2 |
|||
MOVOU 48(AX), X3 |
|||
MOVOU ·rol16<>(SB), X13 |
|||
MOVOU ·rol8<>(SB), X14 |
|||
MOVOU ·one<>(SB), X15 |
|||
|
|||
TESTQ CX, CX |
|||
JZ done |
|||
|
|||
CMPQ CX, $64 |
|||
JBE between_0_and_64 |
|||
|
|||
CMPQ CX, $128 |
|||
JBE between_64_and_128 |
|||
|
|||
MOVO X0, 0(SP) |
|||
MOVO X1, 16(SP) |
|||
MOVO X2, 32(SP) |
|||
MOVO X3, 48(SP) |
|||
MOVO X15, 64(SP) |
|||
|
|||
CMPQ CX, $192 |
|||
JBE between_128_and_192 |
|||
|
|||
MOVO X13, 96(SP) |
|||
MOVO X14, 112(SP) |
|||
MOVQ $192, R14 |
|||
|
|||
at_least_256: |
|||
MOVO X0, X4 |
|||
MOVO X1, X5 |
|||
MOVO X2, X6 |
|||
MOVO X3, X7 |
|||
PADDQ 64(SP), X7 |
|||
MOVO X0, X12 |
|||
MOVO X1, X13 |
|||
MOVO X2, X14 |
|||
MOVO X7, X15 |
|||
PADDQ 64(SP), X15 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X15, X11 |
|||
PADDQ 64(SP), X11 |
|||
|
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_256: |
|||
MOVO X8, 80(SP) |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP)) |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, 96(SP), 112(SP)) |
|||
MOVO 80(SP), X8 |
|||
|
|||
MOVO X0, 80(SP) |
|||
CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, 96(SP), 112(SP)) |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 96(SP), 112(SP)) |
|||
MOVO 80(SP), X0 |
|||
|
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_SHUFFLE(X13, X14, X15) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
|
|||
MOVO X8, 80(SP) |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP)) |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, 96(SP), 112(SP)) |
|||
MOVO 80(SP), X8 |
|||
|
|||
MOVO X0, 80(SP) |
|||
CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, 96(SP), 112(SP)) |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 96(SP), 112(SP)) |
|||
MOVO 80(SP), X0 |
|||
|
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
CHACHA_SHUFFLE(X15, X14, X13) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_256 |
|||
|
|||
MOVO X8, 80(SP) |
|||
|
|||
PADDL 0(SP), X0 |
|||
PADDL 16(SP), X1 |
|||
PADDL 32(SP), X2 |
|||
PADDL 48(SP), X3 |
|||
XOR(DI, SI, 0, X0, X1, X2, X3, X8) |
|||
MOVO 0(SP), X0 |
|||
MOVO 16(SP), X1 |
|||
MOVO 32(SP), X2 |
|||
MOVO 48(SP), X3 |
|||
PADDQ 64(SP), X3 |
|||
|
|||
PADDL X0, X4 |
|||
PADDL X1, X5 |
|||
PADDL X2, X6 |
|||
PADDL X3, X7 |
|||
PADDQ 64(SP), X3 |
|||
XOR(DI, SI, 64, X4, X5, X6, X7, X8) |
|||
|
|||
MOVO 64(SP), X5 |
|||
MOVO 80(SP), X8 |
|||
|
|||
PADDL X0, X12 |
|||
PADDL X1, X13 |
|||
PADDL X2, X14 |
|||
PADDL X3, X15 |
|||
PADDQ X5, X3 |
|||
XOR(DI, SI, 128, X12, X13, X14, X15, X4) |
|||
|
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X5, X3 |
|||
|
|||
CMPQ CX, $256 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 192, X8, X9, X10, X11, X4) |
|||
MOVO X3, 48(SP) |
|||
ADDQ $256, SI |
|||
ADDQ $256, DI |
|||
SUBQ $256, CX |
|||
CMPQ CX, $192 |
|||
JA at_least_256 |
|||
|
|||
TESTQ CX, CX |
|||
JZ done |
|||
MOVOU ·rol16<>(SB), X13 |
|||
MOVOU ·rol8<>(SB), X14 |
|||
MOVO 64(SP), X15 |
|||
CMPQ CX, $64 |
|||
JBE between_0_and_64 |
|||
CMPQ CX, $128 |
|||
JBE between_64_and_128 |
|||
|
|||
between_128_and_192: |
|||
MOVQ $128, R14 |
|||
MOVO X0, X4 |
|||
MOVO X1, X5 |
|||
MOVO X2, X6 |
|||
MOVO X3, X7 |
|||
PADDQ X15, X7 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X7, X11 |
|||
PADDQ X15, X11 |
|||
|
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_192: |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14) |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) |
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14) |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) |
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_192 |
|||
|
|||
PADDL 0(SP), X0 |
|||
PADDL 16(SP), X1 |
|||
PADDL 32(SP), X2 |
|||
PADDL 48(SP), X3 |
|||
XOR(DI, SI, 0, X0, X1, X2, X3, X12) |
|||
|
|||
MOVO 0(SP), X0 |
|||
MOVO 16(SP), X1 |
|||
MOVO 32(SP), X2 |
|||
MOVO 48(SP), X3 |
|||
PADDQ X15, X3 |
|||
|
|||
PADDL X0, X4 |
|||
PADDL X1, X5 |
|||
PADDL X2, X6 |
|||
PADDL X3, X7 |
|||
PADDQ X15, X3 |
|||
XOR(DI, SI, 64, X4, X5, X6, X7, X12) |
|||
|
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X15, X3 |
|||
|
|||
CMPQ CX, $192 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 128, X8, X9, X10, X11, X12) |
|||
SUBQ $192, CX |
|||
JMP done |
|||
|
|||
between_64_and_128: |
|||
MOVQ $64, R14 |
|||
MOVO X0, X4 |
|||
MOVO X1, X5 |
|||
MOVO X2, X6 |
|||
MOVO X3, X7 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X3, X11 |
|||
PADDQ X15, X11 |
|||
|
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_128: |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) |
|||
CHACHA_SHUFFLE(X5, X6, X7) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) |
|||
CHACHA_SHUFFLE(X7, X6, X5) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_128 |
|||
|
|||
PADDL X0, X4 |
|||
PADDL X1, X5 |
|||
PADDL X2, X6 |
|||
PADDL X3, X7 |
|||
PADDQ X15, X3 |
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X15, X3 |
|||
XOR(DI, SI, 0, X4, X5, X6, X7, X12) |
|||
|
|||
CMPQ CX, $128 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 64, X8, X9, X10, X11, X12) |
|||
SUBQ $128, CX |
|||
JMP done |
|||
|
|||
between_0_and_64: |
|||
MOVQ $0, R14 |
|||
MOVO X0, X8 |
|||
MOVO X1, X9 |
|||
MOVO X2, X10 |
|||
MOVO X3, X11 |
|||
MOVQ DX, R8 |
|||
|
|||
chacha_loop_64: |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) |
|||
CHACHA_SHUFFLE(X9, X10, X11) |
|||
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) |
|||
CHACHA_SHUFFLE(X11, X10, X9) |
|||
SUBQ $2, R8 |
|||
JA chacha_loop_64 |
|||
|
|||
PADDL X0, X8 |
|||
PADDL X1, X9 |
|||
PADDL X2, X10 |
|||
PADDL X3, X11 |
|||
PADDQ X15, X3 |
|||
CMPQ CX, $64 |
|||
JB less_than_64 |
|||
|
|||
XOR(DI, SI, 0, X8, X9, X10, X11, X12) |
|||
SUBQ $64, CX |
|||
JMP done |
|||
|
|||
less_than_64: |
|||
// R14 contains the num of bytes already xor'd |
|||
ADDQ R14, SI |
|||
ADDQ R14, DI |
|||
SUBQ R14, CX |
|||
MOVOU X8, 0(BX) |
|||
MOVOU X9, 16(BX) |
|||
MOVOU X10, 32(BX) |
|||
MOVOU X11, 48(BX) |
|||
XORQ R11, R11 |
|||
XORQ R12, R12 |
|||
MOVQ CX, BP |
|||
|
|||
xor_loop: |
|||
MOVB 0(SI), R11 |
|||
MOVB 0(BX), R12 |
|||
XORQ R11, R12 |
|||
MOVB R12, 0(DI) |
|||
INCQ SI |
|||
INCQ BX |
|||
INCQ DI |
|||
DECQ BP |
|||
JA xor_loop |
|||
|
|||
done: |
|||
MOVQ R9, SP |
|||
MOVOU X3, 48(AX) |
|||
MOVQ CX, ret+72(FP) |
|||
RET |
|||
|
|||
// func supportsSSSE3() bool |
|||
TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1 |
|||
XORQ AX, AX |
|||
INCQ AX |
|||
CPUID |
|||
SHRQ $9, CX |
|||
ANDQ $1, CX |
|||
MOVB CX, ret+0(FP) |
|||
RET |
|||
|
|||
// func initialize(state *[64]byte, key []byte, nonce *[16]byte) |
|||
TEXT ·initialize(SB), 4, $0-40 |
|||
MOVQ state+0(FP), DI |
|||
MOVQ key+8(FP), AX |
|||
MOVQ nonce+32(FP), BX |
|||
|
|||
MOVOU ·sigma<>(SB), X0 |
|||
MOVOU 0(AX), X1 |
|||
MOVOU 16(AX), X2 |
|||
MOVOU 0(BX), X3 |
|||
|
|||
MOVOU X0, 0(DI) |
|||
MOVOU X1, 16(DI) |
|||
MOVOU X2, 32(DI) |
|||
MOVOU X3, 48(DI) |
|||
RET |
|||
|
|||
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
TEXT ·hChaCha20SSE2(SB), 4, $0-24 |
|||
MOVQ out+0(FP), DI |
|||
MOVQ nonce+8(FP), AX |
|||
MOVQ key+16(FP), BX |
|||
|
|||
MOVOU ·sigma<>(SB), X0 |
|||
MOVOU 0(BX), X1 |
|||
MOVOU 16(BX), X2 |
|||
MOVOU 0(AX), X3 |
|||
|
|||
MOVQ $20, CX |
|||
|
|||
chacha_loop: |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) |
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) |
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
SUBQ $2, CX |
|||
JNZ chacha_loop |
|||
|
|||
MOVOU X0, 0(DI) |
|||
MOVOU X3, 16(DI) |
|||
RET |
|||
|
|||
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
TEXT ·hChaCha20SSSE3(SB), 4, $0-24 |
|||
MOVQ out+0(FP), DI |
|||
MOVQ nonce+8(FP), AX |
|||
MOVQ key+16(FP), BX |
|||
|
|||
MOVOU ·sigma<>(SB), X0 |
|||
MOVOU 0(BX), X1 |
|||
MOVOU 16(BX), X2 |
|||
MOVOU 0(AX), X3 |
|||
MOVOU ·rol16<>(SB), X5 |
|||
MOVOU ·rol8<>(SB), X6 |
|||
|
|||
MOVQ $20, CX |
|||
|
|||
chacha_loop: |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) |
|||
CHACHA_SHUFFLE(X1, X2, X3) |
|||
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) |
|||
CHACHA_SHUFFLE(X3, X2, X1) |
|||
SUBQ $2, CX |
|||
JNZ chacha_loop |
|||
|
|||
MOVOU X0, 0(DI) |
|||
MOVOU X3, 16(DI) |
|||
RET |
|||
@ -0,0 +1,319 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
|
|||
// Use of this source code is governed by a license that can be
|
|||
// found in the LICENSE file.
|
|||
|
|||
package chacha |
|||
|
|||
import "encoding/binary" |
|||
|
|||
var sigma = [4]uint32{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574} |
|||
|
|||
func xorKeyStreamGeneric(dst, src []byte, block, state *[64]byte, rounds int) int { |
|||
for len(src) >= 64 { |
|||
chachaGeneric(block, state, rounds) |
|||
|
|||
for i, v := range block { |
|||
dst[i] = src[i] ^ v |
|||
} |
|||
src = src[64:] |
|||
dst = dst[64:] |
|||
} |
|||
|
|||
n := len(src) |
|||
if n > 0 { |
|||
chachaGeneric(block, state, rounds) |
|||
for i, v := range src { |
|||
dst[i] = v ^ block[i] |
|||
} |
|||
} |
|||
return n |
|||
} |
|||
|
|||
func chachaGeneric(dst *[64]byte, state *[64]byte, rounds int) { |
|||
v00 := binary.LittleEndian.Uint32(state[0:]) |
|||
v01 := binary.LittleEndian.Uint32(state[4:]) |
|||
v02 := binary.LittleEndian.Uint32(state[8:]) |
|||
v03 := binary.LittleEndian.Uint32(state[12:]) |
|||
v04 := binary.LittleEndian.Uint32(state[16:]) |
|||
v05 := binary.LittleEndian.Uint32(state[20:]) |
|||
v06 := binary.LittleEndian.Uint32(state[24:]) |
|||
v07 := binary.LittleEndian.Uint32(state[28:]) |
|||
v08 := binary.LittleEndian.Uint32(state[32:]) |
|||
v09 := binary.LittleEndian.Uint32(state[36:]) |
|||
v10 := binary.LittleEndian.Uint32(state[40:]) |
|||
v11 := binary.LittleEndian.Uint32(state[44:]) |
|||
v12 := binary.LittleEndian.Uint32(state[48:]) |
|||
v13 := binary.LittleEndian.Uint32(state[52:]) |
|||
v14 := binary.LittleEndian.Uint32(state[56:]) |
|||
v15 := binary.LittleEndian.Uint32(state[60:]) |
|||
|
|||
s00, s01, s02, s03, s04, s05, s06, s07 := v00, v01, v02, v03, v04, v05, v06, v07 |
|||
s08, s09, s10, s11, s12, s13, s14, s15 := v08, v09, v10, v11, v12, v13, v14, v15 |
|||
|
|||
for i := 0; i < rounds; i += 2 { |
|||
v00 += v04 |
|||
v12 ^= v00 |
|||
v12 = (v12 << 16) | (v12 >> 16) |
|||
v08 += v12 |
|||
v04 ^= v08 |
|||
v04 = (v04 << 12) | (v04 >> 20) |
|||
v00 += v04 |
|||
v12 ^= v00 |
|||
v12 = (v12 << 8) | (v12 >> 24) |
|||
v08 += v12 |
|||
v04 ^= v08 |
|||
v04 = (v04 << 7) | (v04 >> 25) |
|||
v01 += v05 |
|||
v13 ^= v01 |
|||
v13 = (v13 << 16) | (v13 >> 16) |
|||
v09 += v13 |
|||
v05 ^= v09 |
|||
v05 = (v05 << 12) | (v05 >> 20) |
|||
v01 += v05 |
|||
v13 ^= v01 |
|||
v13 = (v13 << 8) | (v13 >> 24) |
|||
v09 += v13 |
|||
v05 ^= v09 |
|||
v05 = (v05 << 7) | (v05 >> 25) |
|||
v02 += v06 |
|||
v14 ^= v02 |
|||
v14 = (v14 << 16) | (v14 >> 16) |
|||
v10 += v14 |
|||
v06 ^= v10 |
|||
v06 = (v06 << 12) | (v06 >> 20) |
|||
v02 += v06 |
|||
v14 ^= v02 |
|||
v14 = (v14 << 8) | (v14 >> 24) |
|||
v10 += v14 |
|||
v06 ^= v10 |
|||
v06 = (v06 << 7) | (v06 >> 25) |
|||
v03 += v07 |
|||
v15 ^= v03 |
|||
v15 = (v15 << 16) | (v15 >> 16) |
|||
v11 += v15 |
|||
v07 ^= v11 |
|||
v07 = (v07 << 12) | (v07 >> 20) |
|||
v03 += v07 |
|||
v15 ^= v03 |
|||
v15 = (v15 << 8) | (v15 >> 24) |
|||
v11 += v15 |
|||
v07 ^= v11 |
|||
v07 = (v07 << 7) | (v07 >> 25) |
|||
v00 += v05 |
|||
v15 ^= v00 |
|||
v15 = (v15 << 16) | (v15 >> 16) |
|||
v10 += v15 |
|||
v05 ^= v10 |
|||
v05 = (v05 << 12) | (v05 >> 20) |
|||
v00 += v05 |
|||
v15 ^= v00 |
|||
v15 = (v15 << 8) | (v15 >> 24) |
|||
v10 += v15 |
|||
v05 ^= v10 |
|||
v05 = (v05 << 7) | (v05 >> 25) |
|||
v01 += v06 |
|||
v12 ^= v01 |
|||
v12 = (v12 << 16) | (v12 >> 16) |
|||
v11 += v12 |
|||
v06 ^= v11 |
|||
v06 = (v06 << 12) | (v06 >> 20) |
|||
v01 += v06 |
|||
v12 ^= v01 |
|||
v12 = (v12 << 8) | (v12 >> 24) |
|||
v11 += v12 |
|||
v06 ^= v11 |
|||
v06 = (v06 << 7) | (v06 >> 25) |
|||
v02 += v07 |
|||
v13 ^= v02 |
|||
v13 = (v13 << 16) | (v13 >> 16) |
|||
v08 += v13 |
|||
v07 ^= v08 |
|||
v07 = (v07 << 12) | (v07 >> 20) |
|||
v02 += v07 |
|||
v13 ^= v02 |
|||
v13 = (v13 << 8) | (v13 >> 24) |
|||
v08 += v13 |
|||
v07 ^= v08 |
|||
v07 = (v07 << 7) | (v07 >> 25) |
|||
v03 += v04 |
|||
v14 ^= v03 |
|||
v14 = (v14 << 16) | (v14 >> 16) |
|||
v09 += v14 |
|||
v04 ^= v09 |
|||
v04 = (v04 << 12) | (v04 >> 20) |
|||
v03 += v04 |
|||
v14 ^= v03 |
|||
v14 = (v14 << 8) | (v14 >> 24) |
|||
v09 += v14 |
|||
v04 ^= v09 |
|||
v04 = (v04 << 7) | (v04 >> 25) |
|||
} |
|||
|
|||
v00 += s00 |
|||
v01 += s01 |
|||
v02 += s02 |
|||
v03 += s03 |
|||
v04 += s04 |
|||
v05 += s05 |
|||
v06 += s06 |
|||
v07 += s07 |
|||
v08 += s08 |
|||
v09 += s09 |
|||
v10 += s10 |
|||
v11 += s11 |
|||
v12 += s12 |
|||
v13 += s13 |
|||
v14 += s14 |
|||
v15 += s15 |
|||
|
|||
s12++ |
|||
binary.LittleEndian.PutUint32(state[48:], s12) |
|||
if s12 == 0 { // indicates overflow
|
|||
s13++ |
|||
binary.LittleEndian.PutUint32(state[52:], s13) |
|||
} |
|||
|
|||
binary.LittleEndian.PutUint32(dst[0:], v00) |
|||
binary.LittleEndian.PutUint32(dst[4:], v01) |
|||
binary.LittleEndian.PutUint32(dst[8:], v02) |
|||
binary.LittleEndian.PutUint32(dst[12:], v03) |
|||
binary.LittleEndian.PutUint32(dst[16:], v04) |
|||
binary.LittleEndian.PutUint32(dst[20:], v05) |
|||
binary.LittleEndian.PutUint32(dst[24:], v06) |
|||
binary.LittleEndian.PutUint32(dst[28:], v07) |
|||
binary.LittleEndian.PutUint32(dst[32:], v08) |
|||
binary.LittleEndian.PutUint32(dst[36:], v09) |
|||
binary.LittleEndian.PutUint32(dst[40:], v10) |
|||
binary.LittleEndian.PutUint32(dst[44:], v11) |
|||
binary.LittleEndian.PutUint32(dst[48:], v12) |
|||
binary.LittleEndian.PutUint32(dst[52:], v13) |
|||
binary.LittleEndian.PutUint32(dst[56:], v14) |
|||
binary.LittleEndian.PutUint32(dst[60:], v15) |
|||
} |
|||
|
|||
func hChaCha20Generic(out *[32]byte, nonce *[16]byte, key *[32]byte) { |
|||
v00 := sigma[0] |
|||
v01 := sigma[1] |
|||
v02 := sigma[2] |
|||
v03 := sigma[3] |
|||
v04 := binary.LittleEndian.Uint32(key[0:]) |
|||
v05 := binary.LittleEndian.Uint32(key[4:]) |
|||
v06 := binary.LittleEndian.Uint32(key[8:]) |
|||
v07 := binary.LittleEndian.Uint32(key[12:]) |
|||
v08 := binary.LittleEndian.Uint32(key[16:]) |
|||
v09 := binary.LittleEndian.Uint32(key[20:]) |
|||
v10 := binary.LittleEndian.Uint32(key[24:]) |
|||
v11 := binary.LittleEndian.Uint32(key[28:]) |
|||
v12 := binary.LittleEndian.Uint32(nonce[0:]) |
|||
v13 := binary.LittleEndian.Uint32(nonce[4:]) |
|||
v14 := binary.LittleEndian.Uint32(nonce[8:]) |
|||
v15 := binary.LittleEndian.Uint32(nonce[12:]) |
|||
|
|||
for i := 0; i < 20; i += 2 { |
|||
v00 += v04 |
|||
v12 ^= v00 |
|||
v12 = (v12 << 16) | (v12 >> 16) |
|||
v08 += v12 |
|||
v04 ^= v08 |
|||
v04 = (v04 << 12) | (v04 >> 20) |
|||
v00 += v04 |
|||
v12 ^= v00 |
|||
v12 = (v12 << 8) | (v12 >> 24) |
|||
v08 += v12 |
|||
v04 ^= v08 |
|||
v04 = (v04 << 7) | (v04 >> 25) |
|||
v01 += v05 |
|||
v13 ^= v01 |
|||
v13 = (v13 << 16) | (v13 >> 16) |
|||
v09 += v13 |
|||
v05 ^= v09 |
|||
v05 = (v05 << 12) | (v05 >> 20) |
|||
v01 += v05 |
|||
v13 ^= v01 |
|||
v13 = (v13 << 8) | (v13 >> 24) |
|||
v09 += v13 |
|||
v05 ^= v09 |
|||
v05 = (v05 << 7) | (v05 >> 25) |
|||
v02 += v06 |
|||
v14 ^= v02 |
|||
v14 = (v14 << 16) | (v14 >> 16) |
|||
v10 += v14 |
|||
v06 ^= v10 |
|||
v06 = (v06 << 12) | (v06 >> 20) |
|||
v02 += v06 |
|||
v14 ^= v02 |
|||
v14 = (v14 << 8) | (v14 >> 24) |
|||
v10 += v14 |
|||
v06 ^= v10 |
|||
v06 = (v06 << 7) | (v06 >> 25) |
|||
v03 += v07 |
|||
v15 ^= v03 |
|||
v15 = (v15 << 16) | (v15 >> 16) |
|||
v11 += v15 |
|||
v07 ^= v11 |
|||
v07 = (v07 << 12) | (v07 >> 20) |
|||
v03 += v07 |
|||
v15 ^= v03 |
|||
v15 = (v15 << 8) | (v15 >> 24) |
|||
v11 += v15 |
|||
v07 ^= v11 |
|||
v07 = (v07 << 7) | (v07 >> 25) |
|||
v00 += v05 |
|||
v15 ^= v00 |
|||
v15 = (v15 << 16) | (v15 >> 16) |
|||
v10 += v15 |
|||
v05 ^= v10 |
|||
v05 = (v05 << 12) | (v05 >> 20) |
|||
v00 += v05 |
|||
v15 ^= v00 |
|||
v15 = (v15 << 8) | (v15 >> 24) |
|||
v10 += v15 |
|||
v05 ^= v10 |
|||
v05 = (v05 << 7) | (v05 >> 25) |
|||
v01 += v06 |
|||
v12 ^= v01 |
|||
v12 = (v12 << 16) | (v12 >> 16) |
|||
v11 += v12 |
|||
v06 ^= v11 |
|||
v06 = (v06 << 12) | (v06 >> 20) |
|||
v01 += v06 |
|||
v12 ^= v01 |
|||
v12 = (v12 << 8) | (v12 >> 24) |
|||
v11 += v12 |
|||
v06 ^= v11 |
|||
v06 = (v06 << 7) | (v06 >> 25) |
|||
v02 += v07 |
|||
v13 ^= v02 |
|||
v13 = (v13 << 16) | (v13 >> 16) |
|||
v08 += v13 |
|||
v07 ^= v08 |
|||
v07 = (v07 << 12) | (v07 >> 20) |
|||
v02 += v07 |
|||
v13 ^= v02 |
|||
v13 = (v13 << 8) | (v13 >> 24) |
|||
v08 += v13 |
|||
v07 ^= v08 |
|||
v07 = (v07 << 7) | (v07 >> 25) |
|||
v03 += v04 |
|||
v14 ^= v03 |
|||
v14 = (v14 << 16) | (v14 >> 16) |
|||
v09 += v14 |
|||
v04 ^= v09 |
|||
v04 = (v04 << 12) | (v04 >> 20) |
|||
v03 += v04 |
|||
v14 ^= v03 |
|||
v14 = (v14 << 8) | (v14 >> 24) |
|||
v09 += v14 |
|||
v04 ^= v09 |
|||
v04 = (v04 << 7) | (v04 >> 25) |
|||
} |
|||
|
|||
binary.LittleEndian.PutUint32(out[0:], v00) |
|||
binary.LittleEndian.PutUint32(out[4:], v01) |
|||
binary.LittleEndian.PutUint32(out[8:], v02) |
|||
binary.LittleEndian.PutUint32(out[12:], v03) |
|||
binary.LittleEndian.PutUint32(out[16:], v12) |
|||
binary.LittleEndian.PutUint32(out[20:], v13) |
|||
binary.LittleEndian.PutUint32(out[24:], v14) |
|||
binary.LittleEndian.PutUint32(out[28:], v15) |
|||
} |
|||
@ -0,0 +1,56 @@ |
|||
// Copyright (c) 2017 Andreas Auernhammer. All rights reserved.
|
|||
// Use of this source code is governed by a license that can be
|
|||
// found in the LICENSE file.
|
|||
|
|||
// +build amd64,!gccgo,!appengine,!nacl,!go1.7
|
|||
|
|||
package chacha |
|||
|
|||
func init() { |
|||
useSSE2 = true |
|||
useSSSE3 = supportsSSSE3() |
|||
useAVX2 = false |
|||
} |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func initialize(state *[64]byte, key []byte, nonce *[16]byte) |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func supportsSSSE3() bool |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
|
|||
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { |
|||
if useSSSE3 { |
|||
hChaCha20SSSE3(out, nonce, key) |
|||
} else if useSSE2 { // on amd64 this is always true - used to test generic on amd64
|
|||
hChaCha20SSE2(out, nonce, key) |
|||
} else { |
|||
hChaCha20Generic(out, nonce, key) |
|||
} |
|||
} |
|||
|
|||
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { |
|||
if useSSSE3 { |
|||
return xorKeyStreamSSSE3(dst, src, block, state, rounds) |
|||
} else if useSSE2 { // on amd64 this is always true - used to test generic on amd64
|
|||
return xorKeyStreamSSE2(dst, src, block, state, rounds) |
|||
} |
|||
return xorKeyStreamGeneric(dst, src, block, state, rounds) |
|||
} |
|||
@ -0,0 +1,72 @@ |
|||
// Copyright (c) 2017 Andreas Auernhammer. All rights reserved.
|
|||
// Use of this source code is governed by a license that can be
|
|||
// found in the LICENSE file.
|
|||
|
|||
// +build go1.7,amd64,!gccgo,!appengine,!nacl
|
|||
|
|||
package chacha |
|||
|
|||
func init() { |
|||
useSSE2 = true |
|||
useSSSE3 = supportsSSSE3() |
|||
useAVX2 = supportsAVX2() |
|||
} |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func initialize(state *[64]byte, key []byte, nonce *[16]byte) |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func supportsSSSE3() bool |
|||
|
|||
// This function is implemented in chachaAVX2_amd64.s
|
|||
//go:noescape
|
|||
func supportsAVX2() bool |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
|
|||
// This function is implemented in chachaAVX2_amd64.s
|
|||
//go:noescape
|
|||
func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte) |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
|
|||
// This function is implemented in chacha_amd64.s
|
|||
//go:noescape
|
|||
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
|
|||
// This function is implemented in chachaAVX2_amd64.s
|
|||
//go:noescape
|
|||
func xorKeyStreamAVX2(dst, src []byte, block, state *[64]byte, rounds int) int |
|||
|
|||
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { |
|||
if useAVX2 { |
|||
hChaCha20AVX(out, nonce, key) |
|||
} else if useSSSE3 { |
|||
hChaCha20SSSE3(out, nonce, key) |
|||
} else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64
|
|||
hChaCha20SSE2(out, nonce, key) |
|||
} else { |
|||
hChaCha20Generic(out, nonce, key) |
|||
} |
|||
} |
|||
|
|||
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { |
|||
if useAVX2 { |
|||
return xorKeyStreamAVX2(dst, src, block, state, rounds) |
|||
} else if useSSSE3 { |
|||
return xorKeyStreamSSSE3(dst, src, block, state, rounds) |
|||
} else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64
|
|||
return xorKeyStreamSSE2(dst, src, block, state, rounds) |
|||
} |
|||
return xorKeyStreamGeneric(dst, src, block, state, rounds) |
|||
} |
|||
@ -0,0 +1,26 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
|
|||
// Use of this source code is governed by a license that can be
|
|||
// found in the LICENSE file.
|
|||
|
|||
// +build !amd64,!386 gccgo appengine nacl
|
|||
|
|||
package chacha |
|||
|
|||
import "encoding/binary" |
|||
|
|||
func initialize(state *[64]byte, key []byte, nonce *[16]byte) { |
|||
binary.LittleEndian.PutUint32(state[0:], sigma[0]) |
|||
binary.LittleEndian.PutUint32(state[4:], sigma[1]) |
|||
binary.LittleEndian.PutUint32(state[8:], sigma[2]) |
|||
binary.LittleEndian.PutUint32(state[12:], sigma[3]) |
|||
copy(state[16:], key[:]) |
|||
copy(state[48:], nonce[:]) |
|||
} |
|||
|
|||
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { |
|||
return xorKeyStreamGeneric(dst, src, block, state, rounds) |
|||
} |
|||
|
|||
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { |
|||
hChaCha20Generic(out, nonce, key) |
|||
} |
|||
@ -0,0 +1,41 @@ |
|||
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
|
|||
// Use of this source code is governed by a license that can be
|
|||
// found in the LICENSE file.
|
|||
|
|||
// Package chacha20 implements the ChaCha20 / XChaCha20 stream chipher.
|
|||
// Notice that one specific key-nonce combination must be unique for all time.
|
|||
//
|
|||
// There are three versions of ChaCha20:
|
|||
// - ChaCha20 with a 64 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
|
|||
// - ChaCha20 with a 96 bit nonce (en/decrypt up to 2^32 * 64 bytes (~256 GB) for one key-nonce combination)
|
|||
// - XChaCha20 with a 192 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
|
|||
package chacha20 // import "github.com/aead/chacha20"
|
|||
|
|||
import ( |
|||
"crypto/cipher" |
|||
|
|||
"github.com/aead/chacha20/chacha" |
|||
) |
|||
|
|||
// XORKeyStream crypts bytes from src to dst using the given nonce and key.
|
|||
// The length of the nonce determinds the version of ChaCha20:
|
|||
// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period.
|
|||
// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period.
|
|||
// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period.
|
|||
// Src and dst may be the same slice but otherwise should not overlap.
|
|||
// If len(dst) < len(src) this function panics.
|
|||
// If the nonce is neither 64, 96 nor 192 bits long, this function panics.
|
|||
func XORKeyStream(dst, src, nonce, key []byte) { |
|||
chacha.XORKeyStream(dst, src, nonce, key, 20) |
|||
} |
|||
|
|||
// NewCipher returns a new cipher.Stream implementing a ChaCha20 version.
|
|||
// The nonce must be unique for one key for all time.
|
|||
// The length of the nonce determinds the version of ChaCha20:
|
|||
// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period.
|
|||
// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period.
|
|||
// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period.
|
|||
// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned.
|
|||
func NewCipher(nonce, key []byte) (cipher.Stream, error) { |
|||
return chacha.NewCipher(nonce, key, 20) |
|||
} |
|||
@ -0,0 +1,21 @@ |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2016 Richard Barnes |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy |
|||
of this software and associated documentation files (the "Software"), to deal |
|||
in the Software without restriction, including without limitation the rights |
|||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|||
copies of the Software, and to permit persons to whom the Software is |
|||
furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in |
|||
all copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
|||
THE SOFTWARE. |
|||
@ -0,0 +1,88 @@ |
|||
 |
|||
|
|||
mint - A Minimal TLS 1.3 stack |
|||
============================== |
|||
|
|||
[](https://circleci.com/gh/bifurcation/mint) |
|||
|
|||
This project is primarily a learning effort for me to understand the [TLS |
|||
1.3](http://tlswg.github.io/tls13-spec/) protocol. The goal is to arrive at a |
|||
pretty complete implementation of TLS 1.3, with minimal, elegant code that |
|||
demonstrates how things work. Testing is a priority to ensure correctness, but |
|||
otherwise, the quality of the software engineering might not be at a level where |
|||
it makes sense to integrate this with other libraries. Backward compatibility |
|||
is not an objective. |
|||
|
|||
We borrow liberally from the [Go TLS |
|||
library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns |
|||
with earlier TLS versions. However, unnecessary parts will be ruthlessly cut |
|||
off. |
|||
|
|||
## Quickstart |
|||
|
|||
Installation is the same as for any other Go package: |
|||
|
|||
``` |
|||
go get github.com/bifurcation/mint |
|||
``` |
|||
|
|||
The API is pretty much the same as for the TLS module, with `Dial` and `Listen` |
|||
methods wrapping the underlying socket APIs. |
|||
|
|||
``` |
|||
conn, err := mint.Dial("tcp", "localhost:4430", &mint.Config{...}) |
|||
... |
|||
listener, err := mint.Listen("tcp", "localhost:4430", &mint.Config{...}) |
|||
``` |
|||
|
|||
Documentation is available on |
|||
[godoc.org](https://godoc.org/github.com/bifurcation/mint) |
|||
|
|||
|
|||
## Interoperability testing |
|||
|
|||
The `mint-client` and `mint-server` executables are included to make it easy to |
|||
do basic interoperability tests with other TLS 1.3 implementations. The steps |
|||
for testing against NSS are as follows. |
|||
|
|||
``` |
|||
# Install mint |
|||
go get github.com/bifurcation/mint |
|||
|
|||
# Environment for NSS (you'll probably want a new directory) |
|||
NSS_ROOT=<whereever you want to put NSS> |
|||
mkdir $NSS_ROOT |
|||
cd $NSS_ROOT |
|||
export USE_64=1 |
|||
export ENABLE_TLS_1_3=1 |
|||
export HOST=localhost |
|||
export DOMSUF=localhost |
|||
|
|||
# Build NSS |
|||
hg clone https://hg.mozilla.org/projects/nss |
|||
hg clone https://hg.mozilla.org/projects/nspr |
|||
cd nss |
|||
make nss_build_all |
|||
|
|||
export PLATFORM=`cat $NSS_ROOT/dist/latest` |
|||
export DYLD_LIBRARY_PATH=$NSS_ROOT/dist/$PLATFORM/lib |
|||
export LD_LIBRARY_PATH=$NSS_ROOT/dist/$PLATFORM/lib |
|||
|
|||
# Run NSS tests (this creates data for the server to use) |
|||
cd tests/ssl_gtests |
|||
./ssl_gtests.sh |
|||
|
|||
# Test with client=mint server=NSS |
|||
cd $NSS_ROOT |
|||
./dist/$PLATFORM/bin/selfserv -d tests_results/security/$HOST.1/ssl_gtests/ -n rsa -p 4430 |
|||
# if you get `NSS_Init failed.`, check the path above, particularly around $HOST |
|||
# ... |
|||
go run $GOPATH/src/github.com/bifurcation/mint/bin/mint-client/main.go |
|||
|
|||
# Test with client=NSS server=mint |
|||
go run $GOPATH/src/github.com/bifurcation/mint/bin/mint-server/main.go |
|||
# ... |
|||
cd $NSS_ROOT |
|||
dist/$PLATFORM/bin/tstclnt -d tests_results/security/$HOST/ssl_gtests/ -V tls1.3:tls1.3 -h 127.0.0.1 -p 4430 -o |
|||
``` |
|||
|
|||
@ -0,0 +1,99 @@ |
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package mint |
|||
|
|||
import "strconv" |
|||
|
|||
type Alert uint8 |
|||
|
|||
const ( |
|||
// alert level
|
|||
AlertLevelWarning = 1 |
|||
AlertLevelError = 2 |
|||
) |
|||
|
|||
const ( |
|||
AlertCloseNotify Alert = 0 |
|||
AlertUnexpectedMessage Alert = 10 |
|||
AlertBadRecordMAC Alert = 20 |
|||
AlertDecryptionFailed Alert = 21 |
|||
AlertRecordOverflow Alert = 22 |
|||
AlertDecompressionFailure Alert = 30 |
|||
AlertHandshakeFailure Alert = 40 |
|||
AlertBadCertificate Alert = 42 |
|||
AlertUnsupportedCertificate Alert = 43 |
|||
AlertCertificateRevoked Alert = 44 |
|||
AlertCertificateExpired Alert = 45 |
|||
AlertCertificateUnknown Alert = 46 |
|||
AlertIllegalParameter Alert = 47 |
|||
AlertUnknownCA Alert = 48 |
|||
AlertAccessDenied Alert = 49 |
|||
AlertDecodeError Alert = 50 |
|||
AlertDecryptError Alert = 51 |
|||
AlertProtocolVersion Alert = 70 |
|||
AlertInsufficientSecurity Alert = 71 |
|||
AlertInternalError Alert = 80 |
|||
AlertInappropriateFallback Alert = 86 |
|||
AlertUserCanceled Alert = 90 |
|||
AlertNoRenegotiation Alert = 100 |
|||
AlertMissingExtension Alert = 109 |
|||
AlertUnsupportedExtension Alert = 110 |
|||
AlertCertificateUnobtainable Alert = 111 |
|||
AlertUnrecognizedName Alert = 112 |
|||
AlertBadCertificateStatsResponse Alert = 113 |
|||
AlertBadCertificateHashValue Alert = 114 |
|||
AlertUnknownPSKIdentity Alert = 115 |
|||
AlertNoApplicationProtocol Alert = 120 |
|||
AlertWouldBlock Alert = 254 |
|||
AlertNoAlert Alert = 255 |
|||
) |
|||
|
|||
var alertText = map[Alert]string{ |
|||
AlertCloseNotify: "close notify", |
|||
AlertUnexpectedMessage: "unexpected message", |
|||
AlertBadRecordMAC: "bad record MAC", |
|||
AlertDecryptionFailed: "decryption failed", |
|||
AlertRecordOverflow: "record overflow", |
|||
AlertDecompressionFailure: "decompression failure", |
|||
AlertHandshakeFailure: "handshake failure", |
|||
AlertBadCertificate: "bad certificate", |
|||
AlertUnsupportedCertificate: "unsupported certificate", |
|||
AlertCertificateRevoked: "revoked certificate", |
|||
AlertCertificateExpired: "expired certificate", |
|||
AlertCertificateUnknown: "unknown certificate", |
|||
AlertIllegalParameter: "illegal parameter", |
|||
AlertUnknownCA: "unknown certificate authority", |
|||
AlertAccessDenied: "access denied", |
|||
AlertDecodeError: "error decoding message", |
|||
AlertDecryptError: "error decrypting message", |
|||
AlertProtocolVersion: "protocol version not supported", |
|||
AlertInsufficientSecurity: "insufficient security level", |
|||
AlertInternalError: "internal error", |
|||
AlertInappropriateFallback: "inappropriate fallback", |
|||
AlertUserCanceled: "user canceled", |
|||
AlertMissingExtension: "missing extension", |
|||
AlertUnsupportedExtension: "unsupported extension", |
|||
AlertCertificateUnobtainable: "certificate unobtainable", |
|||
AlertUnrecognizedName: "unrecognized name", |
|||
AlertBadCertificateStatsResponse: "bad certificate status response", |
|||
AlertBadCertificateHashValue: "bad certificate hash value", |
|||
AlertUnknownPSKIdentity: "unknown PSK identity", |
|||
AlertNoApplicationProtocol: "no application protocol", |
|||
AlertNoRenegotiation: "no renegotiation", |
|||
AlertWouldBlock: "would have blocked", |
|||
AlertNoAlert: "no alert", |
|||
} |
|||
|
|||
func (e Alert) String() string { |
|||
s, ok := alertText[e] |
|||
if ok { |
|||
return s |
|||
} |
|||
return "alert(" + strconv.Itoa(int(e)) + ")" |
|||
} |
|||
|
|||
func (e Alert) Error() string { |
|||
return e.String() |
|||
} |
|||
@ -0,0 +1,942 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto" |
|||
"hash" |
|||
"time" |
|||
) |
|||
|
|||
// Client State Machine
|
|||
//
|
|||
// START <----+
|
|||
// Send ClientHello | | Recv HelloRetryRequest
|
|||
// / v |
|
|||
// | WAIT_SH ---+
|
|||
// Can | | Recv ServerHello
|
|||
// send | V
|
|||
// early | WAIT_EE
|
|||
// data | | Recv EncryptedExtensions
|
|||
// | +--------+--------+
|
|||
// | Using | | Using certificate
|
|||
// | PSK | v
|
|||
// | | WAIT_CERT_CR
|
|||
// | | Recv | | Recv CertificateRequest
|
|||
// | | Certificate | v
|
|||
// | | | WAIT_CERT
|
|||
// | | | | Recv Certificate
|
|||
// | | v v
|
|||
// | | WAIT_CV
|
|||
// | | | Recv CertificateVerify
|
|||
// | +> WAIT_FINISHED <+
|
|||
// | | Recv Finished
|
|||
// \ |
|
|||
// | [Send EndOfEarlyData]
|
|||
// | [Send Certificate [+ CertificateVerify]]
|
|||
// | Send Finished
|
|||
// Can send v
|
|||
// app data --> CONNECTED
|
|||
// after
|
|||
// here
|
|||
//
|
|||
// State Instructions
|
|||
// START Send(CH); [RekeyOut; SendEarlyData]
|
|||
// WAIT_SH Send(CH) || RekeyIn
|
|||
// WAIT_EE {}
|
|||
// WAIT_CERT_CR {}
|
|||
// WAIT_CERT {}
|
|||
// WAIT_CV {}
|
|||
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
|||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
|||
|
|||
type ClientStateStart struct { |
|||
Caps Capabilities |
|||
Opts ConnectionOptions |
|||
Params ConnectionParameters |
|||
|
|||
cookie []byte |
|||
firstClientHello *HandshakeMessage |
|||
helloRetryRequest *HandshakeMessage |
|||
} |
|||
|
|||
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
// key_shares
|
|||
offeredDH := map[NamedGroup][]byte{} |
|||
ks := KeyShareExtension{ |
|||
HandshakeType: HandshakeTypeClientHello, |
|||
Shares: make([]KeyShareEntry, len(state.Caps.Groups)), |
|||
} |
|||
for i, group := range state.Caps.Groups { |
|||
pub, priv, err := newKeyShare(group) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
ks.Shares[i].Group = group |
|||
ks.Shares[i].KeyExchange = pub |
|||
offeredDH[group] = priv |
|||
} |
|||
|
|||
logf(logTypeHandshake, "opts: %+v", state.Opts) |
|||
|
|||
// supported_versions, supported_groups, signature_algorithms, server_name
|
|||
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} |
|||
sni := ServerNameExtension(state.Opts.ServerName) |
|||
sg := SupportedGroupsExtension{Groups: state.Caps.Groups} |
|||
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} |
|||
|
|||
state.Params.ServerName = state.Opts.ServerName |
|||
|
|||
// Application Layer Protocol Negotiation
|
|||
var alpn *ALPNExtension |
|||
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { |
|||
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} |
|||
} |
|||
|
|||
// Construct base ClientHello
|
|||
ch := &ClientHelloBody{ |
|||
CipherSuites: state.Caps.CipherSuites, |
|||
} |
|||
_, err := prng.Read(ch.Random[:]) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} { |
|||
err := ch.Extensions.Add(ext) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
// XXX: These optional extensions can't be folded into the above because Go
|
|||
// interface-typed values are never reported as nil
|
|||
if alpn != nil { |
|||
err := ch.Extensions.Add(alpn) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
if state.cookie != nil { |
|||
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
// Run the external extension handler.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
// Handle PSK and EarlyData just before transmitting, so that we can
|
|||
// calculate the PSK binder value
|
|||
var psk *PreSharedKeyExtension |
|||
var ed *EarlyDataExtension |
|||
var offeredPSK PreSharedKey |
|||
var earlyHash crypto.Hash |
|||
var earlySecret []byte |
|||
var clientEarlyTrafficKeys keySet |
|||
var clientHello *HandshakeMessage |
|||
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok { |
|||
offeredPSK = key |
|||
|
|||
// Narrow ciphersuites to ones that match PSK hash
|
|||
params, ok := cipherSuiteMap[key.CipherSuite] |
|||
if !ok { |
|||
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite") |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
compatibleSuites := []CipherSuite{} |
|||
for _, suite := range ch.CipherSuites { |
|||
if cipherSuiteMap[suite].Hash == params.Hash { |
|||
compatibleSuites = append(compatibleSuites, suite) |
|||
} |
|||
} |
|||
ch.CipherSuites = compatibleSuites |
|||
|
|||
// Signal early data if we're going to do it
|
|||
if len(state.Opts.EarlyData) > 0 { |
|||
state.Params.ClientSendingEarlyData = true |
|||
ed = &EarlyDataExtension{} |
|||
err = ch.Extensions.Add(ed) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "Error adding early data extension: %v", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
// Signal supported PSK key exchange modes
|
|||
if len(state.Caps.PSKModes) == 0 { |
|||
logf(logTypeHandshake, "PSK selected, but no PSKModes") |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes} |
|||
err = ch.Extensions.Add(kem) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
// Add the shim PSK extension to the ClientHello
|
|||
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity) |
|||
psk = &PreSharedKeyExtension{ |
|||
HandshakeType: HandshakeTypeClientHello, |
|||
Identities: []PSKIdentity{ |
|||
{ |
|||
Identity: key.Identity, |
|||
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd, |
|||
}, |
|||
}, |
|||
Binders: []PSKBinderEntry{ |
|||
// Note: Stub to get the length fields right
|
|||
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())}, |
|||
}, |
|||
} |
|||
ch.Extensions.Add(psk) |
|||
|
|||
// Compute the binder key
|
|||
h0 := params.Hash.New().Sum(nil) |
|||
zero := bytes.Repeat([]byte{0}, params.Hash.Size()) |
|||
|
|||
earlyHash = params.Hash |
|||
earlySecret = HkdfExtract(params.Hash, zero, key.Key) |
|||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) |
|||
|
|||
binderLabel := labelExternalBinder |
|||
if key.IsResumption { |
|||
binderLabel = labelResumptionBinder |
|||
} |
|||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0) |
|||
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey) |
|||
|
|||
// Compute the binder value
|
|||
trunc, err := ch.Truncated() |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
truncHash := params.Hash.New() |
|||
truncHash.Write(trunc) |
|||
|
|||
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil)) |
|||
|
|||
// Replace the PSK extension
|
|||
psk.Binders[0].Binder = binder |
|||
ch.Extensions.Add(psk) |
|||
|
|||
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
|||
// this one should too.
|
|||
clientHello, _ = HandshakeMessageFromBody(ch) |
|||
|
|||
// Compute early traffic keys
|
|||
h := params.Hash.New() |
|||
h.Write(clientHello.Marshal()) |
|||
chHash := h.Sum(nil) |
|||
|
|||
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) |
|||
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) |
|||
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) |
|||
} else if len(state.Opts.EarlyData) > 0 { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") |
|||
return nil, nil, AlertInternalError |
|||
} else { |
|||
clientHello, err = HandshakeMessageFromBody(ch) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") |
|||
nextState := ClientStateWaitSH{ |
|||
Caps: state.Caps, |
|||
Opts: state.Opts, |
|||
Params: state.Params, |
|||
OfferedDH: offeredDH, |
|||
OfferedPSK: offeredPSK, |
|||
|
|||
earlySecret: earlySecret, |
|||
earlyHash: earlyHash, |
|||
|
|||
firstClientHello: state.firstClientHello, |
|||
helloRetryRequest: state.helloRetryRequest, |
|||
clientHello: clientHello, |
|||
} |
|||
|
|||
toSend := []HandshakeAction{ |
|||
SendHandshakeMessage{clientHello}, |
|||
} |
|||
if state.Params.ClientSendingEarlyData { |
|||
toSend = append(toSend, []HandshakeAction{ |
|||
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys}, |
|||
SendEarlyData{}, |
|||
}...) |
|||
} |
|||
|
|||
return nextState, toSend, AlertNoAlert |
|||
} |
|||
|
|||
type ClientStateWaitSH struct { |
|||
Caps Capabilities |
|||
Opts ConnectionOptions |
|||
Params ConnectionParameters |
|||
OfferedDH map[NamedGroup][]byte |
|||
OfferedPSK PreSharedKey |
|||
PSK []byte |
|||
|
|||
earlySecret []byte |
|||
earlyHash crypto.Hash |
|||
|
|||
firstClientHello *HandshakeMessage |
|||
helloRetryRequest *HandshakeMessage |
|||
clientHello *HandshakeMessage |
|||
} |
|||
|
|||
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
bodyGeneric, err := hm.ToBody() |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
switch body := bodyGeneric.(type) { |
|||
case *HelloRetryRequestBody: |
|||
hrr := body |
|||
|
|||
if state.helloRetryRequest != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
// Check that the version sent by the server is the one we support
|
|||
if hrr.Version != supportedVersion { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version) |
|||
return nil, nil, AlertProtocolVersion |
|||
} |
|||
|
|||
// Check that the server provided a supported ciphersuite
|
|||
supportedCipherSuite := false |
|||
for _, suite := range state.Caps.CipherSuites { |
|||
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite) |
|||
} |
|||
if !supportedCipherSuite { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite) |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
// Narrow the supported ciphersuites to the server-provided one
|
|||
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} |
|||
|
|||
// Handle external extensions.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
// The only thing we know how to respond to in an HRR is the Cookie
|
|||
// extension, so if there is either no Cookie extension or anything other
|
|||
// than a Cookie extension, we have to fail.
|
|||
serverCookie := new(CookieExtension) |
|||
foundCookie := hrr.Extensions.Find(serverCookie) |
|||
if !foundCookie || len(hrr.Extensions) != 1 { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) |
|||
return nil, nil, AlertIllegalParameter |
|||
} |
|||
|
|||
// Hash the body into a pseudo-message
|
|||
// XXX: Ignoring some errors here
|
|||
params := cipherSuiteMap[hrr.CipherSuite] |
|||
h := params.Hash.New() |
|||
h.Write(state.clientHello.Marshal()) |
|||
firstClientHello := &HandshakeMessage{ |
|||
msgType: HandshakeTypeMessageHash, |
|||
body: h.Sum(nil), |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") |
|||
return ClientStateStart{ |
|||
Caps: state.Caps, |
|||
Opts: state.Opts, |
|||
cookie: serverCookie.Cookie, |
|||
firstClientHello: firstClientHello, |
|||
helloRetryRequest: hm, |
|||
}.Next(nil) |
|||
|
|||
case *ServerHelloBody: |
|||
sh := body |
|||
|
|||
// Check that the version sent by the server is the one we support
|
|||
if sh.Version != supportedVersion { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version) |
|||
return nil, nil, AlertProtocolVersion |
|||
} |
|||
|
|||
// Check that the server provided a supported ciphersuite
|
|||
supportedCipherSuite := false |
|||
for _, suite := range state.Caps.CipherSuites { |
|||
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) |
|||
} |
|||
if !supportedCipherSuite { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
// Handle external extensions.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
// Do PSK or key agreement depending on extensions
|
|||
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} |
|||
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} |
|||
|
|||
foundPSK := sh.Extensions.Find(&serverPSK) |
|||
foundKeyShare := sh.Extensions.Find(&serverKeyShare) |
|||
|
|||
if foundPSK && (serverPSK.SelectedIdentity == 0) { |
|||
state.Params.UsingPSK = true |
|||
} |
|||
|
|||
var dhSecret []byte |
|||
if foundKeyShare { |
|||
sks := serverKeyShare.Shares[0] |
|||
priv, ok := state.OfferedDH[sks.Group] |
|||
if !ok { |
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") |
|||
return nil, nil, AlertIllegalParameter |
|||
} |
|||
|
|||
state.Params.UsingDH = true |
|||
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) |
|||
} |
|||
|
|||
suite := sh.CipherSuite |
|||
state.Params.CipherSuite = suite |
|||
|
|||
params, ok := cipherSuiteMap[suite] |
|||
if !ok { |
|||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
// Start up the handshake hash
|
|||
handshakeHash := params.Hash.New() |
|||
handshakeHash.Write(state.firstClientHello.Marshal()) |
|||
handshakeHash.Write(state.helloRetryRequest.Marshal()) |
|||
handshakeHash.Write(state.clientHello.Marshal()) |
|||
handshakeHash.Write(hm.Marshal()) |
|||
|
|||
// Compute handshake secrets
|
|||
zero := bytes.Repeat([]byte{0}, params.Hash.Size()) |
|||
|
|||
var earlySecret []byte |
|||
if state.Params.UsingPSK { |
|||
if params.Hash != state.earlyHash { |
|||
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", |
|||
state.earlyHash, suite, params.Hash) |
|||
} |
|||
|
|||
earlySecret = state.earlySecret |
|||
} else { |
|||
earlySecret = HkdfExtract(params.Hash, zero, zero) |
|||
} |
|||
|
|||
if dhSecret == nil { |
|||
dhSecret = zero |
|||
} |
|||
|
|||
h0 := params.Hash.New().Sum(nil) |
|||
h2 := handshakeHash.Sum(nil) |
|||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) |
|||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) |
|||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) |
|||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) |
|||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) |
|||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) |
|||
|
|||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) |
|||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) |
|||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) |
|||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) |
|||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) |
|||
|
|||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") |
|||
nextState := ClientStateWaitEE{ |
|||
Caps: state.Caps, |
|||
Params: state.Params, |
|||
cryptoParams: params, |
|||
handshakeHash: handshakeHash, |
|||
certificates: state.Caps.Certificates, |
|||
masterSecret: masterSecret, |
|||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, |
|||
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, |
|||
} |
|||
toSend := []HandshakeAction{ |
|||
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys}, |
|||
} |
|||
return nextState, toSend, AlertNoAlert |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
type ClientStateWaitEE struct { |
|||
Caps Capabilities |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
handshakeHash hash.Hash |
|||
certificates []*Certificate |
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
serverHandshakeTrafficSecret []byte |
|||
} |
|||
|
|||
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { |
|||
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
ee := EncryptedExtensionsBody{} |
|||
_, err := ee.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
// Handle external extensions.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
serverALPN := ALPNExtension{} |
|||
serverEarlyData := EarlyDataExtension{} |
|||
|
|||
gotALPN := ee.Extensions.Find(&serverALPN) |
|||
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) |
|||
|
|||
if gotALPN && len(serverALPN.Protocols) > 0 { |
|||
state.Params.NextProto = serverALPN.Protocols[0] |
|||
} |
|||
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
|
|||
if state.Params.UsingPSK { |
|||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") |
|||
nextState := ClientStateWaitFinished{ |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
certificates: state.certificates, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") |
|||
nextState := ClientStateWaitCertCR{ |
|||
AuthCertificate: state.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
certificates: state.certificates, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
type ClientStateWaitCertCR struct { |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
handshakeHash hash.Hash |
|||
certificates []*Certificate |
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
serverHandshakeTrafficSecret []byte |
|||
} |
|||
|
|||
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
bodyGeneric, err := hm.ToBody() |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
|
|||
switch body := bodyGeneric.(type) { |
|||
case *CertificateBody: |
|||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") |
|||
nextState := ClientStateWaitCV{ |
|||
AuthCertificate: state.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
certificates: state.certificates, |
|||
serverCertificate: body, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
|
|||
case *CertificateRequestBody: |
|||
// A certificate request in the handshake should have a zero-length context
|
|||
if len(body.CertificateRequestContext) > 0 { |
|||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err) |
|||
return nil, nil, AlertIllegalParameter |
|||
} |
|||
|
|||
state.Params.UsingClientAuth = true |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") |
|||
nextState := ClientStateWaitCert{ |
|||
AuthCertificate: state.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
certificates: state.certificates, |
|||
serverCertificateRequest: body, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
type ClientStateWaitCert struct { |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
handshakeHash hash.Hash |
|||
|
|||
certificates []*Certificate |
|||
serverCertificateRequest *CertificateRequestBody |
|||
|
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
serverHandshakeTrafficSecret []byte |
|||
} |
|||
|
|||
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeCertificate { |
|||
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
cert := &CertificateBody{} |
|||
_, err := cert.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") |
|||
nextState := ClientStateWaitCV{ |
|||
AuthCertificate: state.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
certificates: state.certificates, |
|||
serverCertificate: cert, |
|||
serverCertificateRequest: state.serverCertificateRequest, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
type ClientStateWaitCV struct { |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
handshakeHash hash.Hash |
|||
|
|||
certificates []*Certificate |
|||
serverCertificate *CertificateBody |
|||
serverCertificateRequest *CertificateRequestBody |
|||
|
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
serverHandshakeTrafficSecret []byte |
|||
} |
|||
|
|||
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { |
|||
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
certVerify := CertificateVerifyBody{} |
|||
_, err := certVerify.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
hcv := state.handshakeHash.Sum(nil) |
|||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) |
|||
|
|||
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey |
|||
if err := certVerify.Verify(serverPublicKey, hcv); err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
if state.AuthCertificate != nil { |
|||
err := state.AuthCertificate(state.serverCertificate.CertificateList) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate") |
|||
return nil, nil, AlertBadCertificate |
|||
} |
|||
} else { |
|||
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate") |
|||
} |
|||
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") |
|||
nextState := ClientStateWaitFinished{ |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
certificates: state.certificates, |
|||
serverCertificateRequest: state.serverCertificateRequest, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
type ClientStateWaitFinished struct { |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
handshakeHash hash.Hash |
|||
|
|||
certificates []*Certificate |
|||
serverCertificateRequest *CertificateRequestBody |
|||
|
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
serverHandshakeTrafficSecret []byte |
|||
} |
|||
|
|||
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeFinished { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
// Verify server's Finished
|
|||
h3 := state.handshakeHash.Sum(nil) |
|||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) |
|||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) |
|||
|
|||
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3) |
|||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) |
|||
|
|||
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} |
|||
_, err := fin.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
if !bytes.Equal(fin.VerifyData, serverFinishedData) { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]", |
|||
fin.VerifyData, serverFinishedData) |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
// Update the handshake hash with the Finished
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal()) |
|||
h4 := state.handshakeHash.Sum(nil) |
|||
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4) |
|||
|
|||
// Compute traffic secrets and keys
|
|||
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4) |
|||
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4) |
|||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) |
|||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) |
|||
|
|||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret) |
|||
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret) |
|||
|
|||
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4) |
|||
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret) |
|||
|
|||
// Assemble client's second flight
|
|||
toSend := []HandshakeAction{} |
|||
|
|||
if state.Params.UsingEarlyData { |
|||
// Note: We only send EOED if the server is actually going to use the early
|
|||
// data. Otherwise, it will never see it, and the transcripts will
|
|||
// mismatch.
|
|||
// EOED marshal is infallible
|
|||
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{}) |
|||
toSend = append(toSend, SendHandshakeMessage{eoedm}) |
|||
state.handshakeHash.Write(eoedm.Marshal()) |
|||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) |
|||
} |
|||
|
|||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) |
|||
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys}) |
|||
|
|||
if state.Params.UsingClientAuth { |
|||
// Extract constraints from certicateRequest
|
|||
schemes := SignatureAlgorithmsExtension{} |
|||
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) |
|||
if !gotSchemes { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) |
|||
return nil, nil, AlertIllegalParameter |
|||
} |
|||
|
|||
// Select a certificate
|
|||
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates) |
|||
if err != nil { |
|||
// XXX: Signal this to the application layer?
|
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) |
|||
|
|||
certificate := &CertificateBody{} |
|||
certm, err := HandshakeMessageFromBody(certificate) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
toSend = append(toSend, SendHandshakeMessage{certm}) |
|||
state.handshakeHash.Write(certm.Marshal()) |
|||
} else { |
|||
// Create and send Certificate, CertificateVerify
|
|||
certificate := &CertificateBody{ |
|||
CertificateList: make([]CertificateEntry, len(cert.Chain)), |
|||
} |
|||
for i, entry := range cert.Chain { |
|||
certificate.CertificateList[i] = CertificateEntry{CertData: entry} |
|||
} |
|||
certm, err := HandshakeMessageFromBody(certificate) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
toSend = append(toSend, SendHandshakeMessage{certm}) |
|||
state.handshakeHash.Write(certm.Marshal()) |
|||
|
|||
hcv := state.handshakeHash.Sum(nil) |
|||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) |
|||
|
|||
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme} |
|||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash) |
|||
|
|||
err = certificateVerify.Sign(cert.PrivateKey, hcv) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
certvm, err := HandshakeMessageFromBody(certificateVerify) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
toSend = append(toSend, SendHandshakeMessage{certvm}) |
|||
state.handshakeHash.Write(certvm.Marshal()) |
|||
} |
|||
} |
|||
|
|||
// Compute the client's Finished message
|
|||
h5 := state.handshakeHash.Sum(nil) |
|||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) |
|||
|
|||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) |
|||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) |
|||
|
|||
fin = &FinishedBody{ |
|||
VerifyDataLen: len(clientFinishedData), |
|||
VerifyData: clientFinishedData, |
|||
} |
|||
finm, err := HandshakeMessageFromBody(fin) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
// Compute the resumption secret
|
|||
state.handshakeHash.Write(finm.Marshal()) |
|||
h6 := state.handshakeHash.Sum(nil) |
|||
|
|||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) |
|||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) |
|||
|
|||
toSend = append(toSend, []HandshakeAction{ |
|||
SendHandshakeMessage{finm}, |
|||
RekeyIn{Label: "application", KeySet: serverTrafficKeys}, |
|||
RekeyOut{Label: "application", KeySet: clientTrafficKeys}, |
|||
}...) |
|||
|
|||
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") |
|||
nextState := StateConnected{ |
|||
Params: state.Params, |
|||
isClient: true, |
|||
cryptoParams: state.cryptoParams, |
|||
resumptionSecret: resumptionSecret, |
|||
clientTrafficSecret: clientTrafficSecret, |
|||
serverTrafficSecret: serverTrafficSecret, |
|||
exporterSecret: exporterSecret, |
|||
} |
|||
return nextState, toSend, AlertNoAlert |
|||
} |
|||
@ -0,0 +1,152 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"fmt" |
|||
"strconv" |
|||
) |
|||
|
|||
var ( |
|||
supportedVersion uint16 = 0x7f15 // draft-21
|
|||
|
|||
// Flags for some minor compat issues
|
|||
allowWrongVersionNumber = true |
|||
allowPKCS1 = true |
|||
) |
|||
|
|||
// enum {...} ContentType;
|
|||
type RecordType byte |
|||
|
|||
const ( |
|||
RecordTypeAlert RecordType = 21 |
|||
RecordTypeHandshake RecordType = 22 |
|||
RecordTypeApplicationData RecordType = 23 |
|||
) |
|||
|
|||
// enum {...} HandshakeType;
|
|||
type HandshakeType byte |
|||
|
|||
const ( |
|||
// Omitted: *_RESERVED
|
|||
HandshakeTypeClientHello HandshakeType = 1 |
|||
HandshakeTypeServerHello HandshakeType = 2 |
|||
HandshakeTypeNewSessionTicket HandshakeType = 4 |
|||
HandshakeTypeEndOfEarlyData HandshakeType = 5 |
|||
HandshakeTypeHelloRetryRequest HandshakeType = 6 |
|||
HandshakeTypeEncryptedExtensions HandshakeType = 8 |
|||
HandshakeTypeCertificate HandshakeType = 11 |
|||
HandshakeTypeCertificateRequest HandshakeType = 13 |
|||
HandshakeTypeCertificateVerify HandshakeType = 15 |
|||
HandshakeTypeServerConfiguration HandshakeType = 17 |
|||
HandshakeTypeFinished HandshakeType = 20 |
|||
HandshakeTypeKeyUpdate HandshakeType = 24 |
|||
HandshakeTypeMessageHash HandshakeType = 254 |
|||
) |
|||
|
|||
// uint8 CipherSuite[2];
|
|||
type CipherSuite uint16 |
|||
|
|||
const ( |
|||
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
|
|||
// value for this type so that we can detect when a field is set.
|
|||
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000 |
|||
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301 |
|||
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302 |
|||
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303 |
|||
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304 |
|||
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305 |
|||
) |
|||
|
|||
func (c CipherSuite) String() string { |
|||
switch c { |
|||
case CIPHER_SUITE_UNKNOWN: |
|||
return "unknown" |
|||
case TLS_AES_128_GCM_SHA256: |
|||
return "TLS_AES_128_GCM_SHA256" |
|||
case TLS_AES_256_GCM_SHA384: |
|||
return "TLS_AES_256_GCM_SHA384" |
|||
case TLS_CHACHA20_POLY1305_SHA256: |
|||
return "TLS_CHACHA20_POLY1305_SHA256" |
|||
case TLS_AES_128_CCM_SHA256: |
|||
return "TLS_AES_128_CCM_SHA256" |
|||
case TLS_AES_256_CCM_8_SHA256: |
|||
return "TLS_AES_256_CCM_8_SHA256" |
|||
} |
|||
// cannot use %x here, since it calls String(), leading to infinite recursion
|
|||
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16)) |
|||
} |
|||
|
|||
// enum {...} SignatureScheme
|
|||
type SignatureScheme uint16 |
|||
|
|||
const ( |
|||
// RSASSA-PKCS1-v1_5 algorithms
|
|||
RSA_PKCS1_SHA1 SignatureScheme = 0x0201 |
|||
RSA_PKCS1_SHA256 SignatureScheme = 0x0401 |
|||
RSA_PKCS1_SHA384 SignatureScheme = 0x0501 |
|||
RSA_PKCS1_SHA512 SignatureScheme = 0x0601 |
|||
// ECDSA algorithms
|
|||
ECDSA_P256_SHA256 SignatureScheme = 0x0403 |
|||
ECDSA_P384_SHA384 SignatureScheme = 0x0503 |
|||
ECDSA_P521_SHA512 SignatureScheme = 0x0603 |
|||
// RSASSA-PSS algorithms
|
|||
RSA_PSS_SHA256 SignatureScheme = 0x0804 |
|||
RSA_PSS_SHA384 SignatureScheme = 0x0805 |
|||
RSA_PSS_SHA512 SignatureScheme = 0x0806 |
|||
// EdDSA algorithms
|
|||
Ed25519 SignatureScheme = 0x0807 |
|||
Ed448 SignatureScheme = 0x0808 |
|||
) |
|||
|
|||
// enum {...} ExtensionType
|
|||
type ExtensionType uint16 |
|||
|
|||
const ( |
|||
ExtensionTypeServerName ExtensionType = 0 |
|||
ExtensionTypeSupportedGroups ExtensionType = 10 |
|||
ExtensionTypeSignatureAlgorithms ExtensionType = 13 |
|||
ExtensionTypeALPN ExtensionType = 16 |
|||
ExtensionTypeKeyShare ExtensionType = 40 |
|||
ExtensionTypePreSharedKey ExtensionType = 41 |
|||
ExtensionTypeEarlyData ExtensionType = 42 |
|||
ExtensionTypeSupportedVersions ExtensionType = 43 |
|||
ExtensionTypeCookie ExtensionType = 44 |
|||
ExtensionTypePSKKeyExchangeModes ExtensionType = 45 |
|||
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 |
|||
) |
|||
|
|||
// enum {...} NamedGroup
|
|||
type NamedGroup uint16 |
|||
|
|||
const ( |
|||
// Elliptic Curve Groups.
|
|||
P256 NamedGroup = 23 |
|||
P384 NamedGroup = 24 |
|||
P521 NamedGroup = 25 |
|||
// ECDH functions.
|
|||
X25519 NamedGroup = 29 |
|||
X448 NamedGroup = 30 |
|||
// Finite field groups.
|
|||
FFDHE2048 NamedGroup = 256 |
|||
FFDHE3072 NamedGroup = 257 |
|||
FFDHE4096 NamedGroup = 258 |
|||
FFDHE6144 NamedGroup = 259 |
|||
FFDHE8192 NamedGroup = 260 |
|||
) |
|||
|
|||
// enum {...} PskKeyExchangeMode;
|
|||
type PSKKeyExchangeMode uint8 |
|||
|
|||
const ( |
|||
PSKModeKE PSKKeyExchangeMode = 0 |
|||
PSKModeDHEKE PSKKeyExchangeMode = 1 |
|||
) |
|||
|
|||
// enum {
|
|||
// update_not_requested(0), update_requested(1), (255)
|
|||
// } KeyUpdateRequest;
|
|||
type KeyUpdateRequest uint8 |
|||
|
|||
const ( |
|||
KeyUpdateNotRequested KeyUpdateRequest = 0 |
|||
KeyUpdateRequested KeyUpdateRequest = 1 |
|||
) |
|||
@ -0,0 +1,819 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"crypto" |
|||
"crypto/x509" |
|||
"encoding/hex" |
|||
"fmt" |
|||
"io" |
|||
"net" |
|||
"reflect" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
var WouldBlock = fmt.Errorf("Would have blocked") |
|||
|
|||
type Certificate struct { |
|||
Chain []*x509.Certificate |
|||
PrivateKey crypto.Signer |
|||
} |
|||
|
|||
type PreSharedKey struct { |
|||
CipherSuite CipherSuite |
|||
IsResumption bool |
|||
Identity []byte |
|||
Key []byte |
|||
NextProto string |
|||
ReceivedAt time.Time |
|||
ExpiresAt time.Time |
|||
TicketAgeAdd uint32 |
|||
} |
|||
|
|||
type PreSharedKeyCache interface { |
|||
Get(string) (PreSharedKey, bool) |
|||
Put(string, PreSharedKey) |
|||
Size() int |
|||
} |
|||
|
|||
type PSKMapCache map[string]PreSharedKey |
|||
|
|||
// A CookieHandler does two things:
|
|||
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
|||
// - validates this byte string echoed by the client in the ClientHello
|
|||
type CookieHandler interface { |
|||
Generate(*Conn) ([]byte, error) |
|||
Validate(*Conn, []byte) bool |
|||
} |
|||
|
|||
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { |
|||
psk, ok = cache[key] |
|||
return |
|||
} |
|||
|
|||
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) { |
|||
(*cache)[key] = psk |
|||
} |
|||
|
|||
func (cache PSKMapCache) Size() int { |
|||
return len(cache) |
|||
} |
|||
|
|||
// Config is the struct used to pass configuration settings to a TLS client or
|
|||
// server instance. The settings for client and server are pretty different,
|
|||
// but we just throw them all in here.
|
|||
type Config struct { |
|||
// Client fields
|
|||
ServerName string |
|||
|
|||
// Server fields
|
|||
SendSessionTickets bool |
|||
TicketLifetime uint32 |
|||
TicketLen int |
|||
EarlyDataLifetime uint32 |
|||
AllowEarlyData bool |
|||
// Require the client to echo a cookie.
|
|||
RequireCookie bool |
|||
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
|||
// The default cookie handler uses 32 random bytes as a cookie.
|
|||
CookieHandler CookieHandler |
|||
RequireClientAuth bool |
|||
|
|||
// Shared fields
|
|||
Certificates []*Certificate |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
CipherSuites []CipherSuite |
|||
Groups []NamedGroup |
|||
SignatureSchemes []SignatureScheme |
|||
NextProtos []string |
|||
PSKs PreSharedKeyCache |
|||
PSKModes []PSKKeyExchangeMode |
|||
NonBlocking bool |
|||
|
|||
// The same config object can be shared among different connections, so it
|
|||
// needs its own mutex
|
|||
mutex sync.RWMutex |
|||
} |
|||
|
|||
// Clone returns a shallow clone of c. It is safe to clone a Config that is
|
|||
// being used concurrently by a TLS client or server.
|
|||
func (c *Config) Clone() *Config { |
|||
c.mutex.Lock() |
|||
defer c.mutex.Unlock() |
|||
|
|||
return &Config{ |
|||
ServerName: c.ServerName, |
|||
|
|||
SendSessionTickets: c.SendSessionTickets, |
|||
TicketLifetime: c.TicketLifetime, |
|||
TicketLen: c.TicketLen, |
|||
EarlyDataLifetime: c.EarlyDataLifetime, |
|||
AllowEarlyData: c.AllowEarlyData, |
|||
RequireCookie: c.RequireCookie, |
|||
RequireClientAuth: c.RequireClientAuth, |
|||
|
|||
Certificates: c.Certificates, |
|||
AuthCertificate: c.AuthCertificate, |
|||
CipherSuites: c.CipherSuites, |
|||
Groups: c.Groups, |
|||
SignatureSchemes: c.SignatureSchemes, |
|||
NextProtos: c.NextProtos, |
|||
PSKs: c.PSKs, |
|||
PSKModes: c.PSKModes, |
|||
NonBlocking: c.NonBlocking, |
|||
} |
|||
} |
|||
|
|||
func (c *Config) Init(isClient bool) error { |
|||
c.mutex.Lock() |
|||
defer c.mutex.Unlock() |
|||
|
|||
// Set defaults
|
|||
if len(c.CipherSuites) == 0 { |
|||
c.CipherSuites = defaultSupportedCipherSuites |
|||
} |
|||
if len(c.Groups) == 0 { |
|||
c.Groups = defaultSupportedGroups |
|||
} |
|||
if len(c.SignatureSchemes) == 0 { |
|||
c.SignatureSchemes = defaultSignatureSchemes |
|||
} |
|||
if c.TicketLen == 0 { |
|||
c.TicketLen = defaultTicketLen |
|||
} |
|||
if !reflect.ValueOf(c.PSKs).IsValid() { |
|||
c.PSKs = &PSKMapCache{} |
|||
} |
|||
if len(c.PSKModes) == 0 { |
|||
c.PSKModes = defaultPSKModes |
|||
} |
|||
|
|||
// If there is no certificate, generate one
|
|||
if !isClient && len(c.Certificates) == 0 { |
|||
logf(logTypeHandshake, "Generating key name=%v", c.ServerName) |
|||
priv, err := newSigningKey(RSA_PSS_SHA256) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
c.Certificates = []*Certificate{ |
|||
{ |
|||
Chain: []*x509.Certificate{cert}, |
|||
PrivateKey: priv, |
|||
}, |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (c *Config) ValidForServer() bool { |
|||
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) || |
|||
(len(c.Certificates) > 0 && |
|||
len(c.Certificates[0].Chain) > 0 && |
|||
c.Certificates[0].PrivateKey != nil) |
|||
} |
|||
|
|||
func (c *Config) ValidForClient() bool { |
|||
return len(c.ServerName) > 0 |
|||
} |
|||
|
|||
var ( |
|||
defaultSupportedCipherSuites = []CipherSuite{ |
|||
TLS_AES_128_GCM_SHA256, |
|||
TLS_AES_256_GCM_SHA384, |
|||
} |
|||
|
|||
defaultSupportedGroups = []NamedGroup{ |
|||
P256, |
|||
P384, |
|||
FFDHE2048, |
|||
X25519, |
|||
} |
|||
|
|||
defaultSignatureSchemes = []SignatureScheme{ |
|||
RSA_PSS_SHA256, |
|||
RSA_PSS_SHA384, |
|||
RSA_PSS_SHA512, |
|||
ECDSA_P256_SHA256, |
|||
ECDSA_P384_SHA384, |
|||
ECDSA_P521_SHA512, |
|||
} |
|||
|
|||
defaultTicketLen = 16 |
|||
|
|||
defaultPSKModes = []PSKKeyExchangeMode{ |
|||
PSKModeKE, |
|||
PSKModeDHEKE, |
|||
} |
|||
) |
|||
|
|||
type ConnectionState struct { |
|||
HandshakeState string // string representation of the handshake state.
|
|||
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
|||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO([email protected]): implement
|
|||
NextProto string // Selected ALPN proto
|
|||
} |
|||
|
|||
// Conn implements the net.Conn interface, as with "crypto/tls"
|
|||
// * Read, Write, and Close are provided locally
|
|||
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
|
|||
type Conn struct { |
|||
config *Config |
|||
conn net.Conn |
|||
isClient bool |
|||
|
|||
EarlyData []byte |
|||
|
|||
state StateConnected |
|||
hState HandshakeState |
|||
handshakeMutex sync.Mutex |
|||
handshakeAlert Alert |
|||
handshakeComplete bool |
|||
|
|||
readBuffer []byte |
|||
in, out *RecordLayer |
|||
hIn, hOut *HandshakeLayer |
|||
|
|||
extHandler AppExtensionHandler |
|||
} |
|||
|
|||
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { |
|||
c := &Conn{conn: conn, config: config, isClient: isClient} |
|||
c.in = NewRecordLayer(c.conn) |
|||
c.out = NewRecordLayer(c.conn) |
|||
c.hIn = NewHandshakeLayer(c.in) |
|||
c.hIn.nonblocking = c.config.NonBlocking |
|||
c.hOut = NewHandshakeLayer(c.out) |
|||
return c |
|||
} |
|||
|
|||
// Read up
|
|||
func (c *Conn) consumeRecord() error { |
|||
pt, err := c.in.ReadRecord() |
|||
if pt == nil { |
|||
logf(logTypeIO, "extendBuffer returns error %v", err) |
|||
return err |
|||
} |
|||
|
|||
switch pt.contentType { |
|||
case RecordTypeHandshake: |
|||
logf(logTypeHandshake, "Received post-handshake message") |
|||
// We do not support fragmentation of post-handshake handshake messages.
|
|||
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
|||
start := 0 |
|||
for start < len(pt.fragment) { |
|||
if len(pt.fragment[start:]) < handshakeHeaderLen { |
|||
return fmt.Errorf("Post-handshake handshake message too short for header") |
|||
} |
|||
|
|||
hm := &HandshakeMessage{} |
|||
hm.msgType = HandshakeType(pt.fragment[start]) |
|||
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) |
|||
|
|||
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen { |
|||
return fmt.Errorf("Post-handshake handshake message too short for body") |
|||
} |
|||
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen] |
|||
|
|||
// Advance state machine
|
|||
state, actions, alert := c.state.Next(hm) |
|||
|
|||
if alert != AlertNoAlert { |
|||
logf(logTypeHandshake, "Error in state transition: %v", alert) |
|||
c.sendAlert(alert) |
|||
return io.EOF |
|||
} |
|||
|
|||
for _, action := range actions { |
|||
alert = c.takeAction(action) |
|||
if alert != AlertNoAlert { |
|||
logf(logTypeHandshake, "Error during handshake actions: %v", alert) |
|||
c.sendAlert(alert) |
|||
return io.EOF |
|||
} |
|||
} |
|||
|
|||
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
|||
// authentication, we'll need to allow transitions other than
|
|||
// Connected -> Connected
|
|||
var connected bool |
|||
c.state, connected = state.(StateConnected) |
|||
if !connected { |
|||
logf(logTypeHandshake, "Disconnected after state transition: %v", alert) |
|||
c.sendAlert(alert) |
|||
return io.EOF |
|||
} |
|||
|
|||
start += handshakeHeaderLen + hmLen |
|||
} |
|||
case RecordTypeAlert: |
|||
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) |
|||
if len(pt.fragment) != 2 { |
|||
c.sendAlert(AlertUnexpectedMessage) |
|||
return io.EOF |
|||
} |
|||
if Alert(pt.fragment[1]) == AlertCloseNotify { |
|||
return io.EOF |
|||
} |
|||
|
|||
switch pt.fragment[0] { |
|||
case AlertLevelWarning: |
|||
// drop on the floor
|
|||
case AlertLevelError: |
|||
return Alert(pt.fragment[1]) |
|||
default: |
|||
c.sendAlert(AlertUnexpectedMessage) |
|||
return io.EOF |
|||
} |
|||
|
|||
case RecordTypeApplicationData: |
|||
c.readBuffer = append(c.readBuffer, pt.fragment...) |
|||
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) |
|||
} |
|||
|
|||
return err |
|||
} |
|||
|
|||
// Read application data up to the size of buffer. Handshake and alert records
|
|||
// are consumed by the Conn object directly.
|
|||
func (c *Conn) Read(buffer []byte) (int, error) { |
|||
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) |
|||
if alert := c.Handshake(); alert != AlertNoAlert { |
|||
return 0, alert |
|||
} |
|||
|
|||
if len(buffer) == 0 { |
|||
return 0, nil |
|||
} |
|||
|
|||
// Lock the input channel
|
|||
c.in.Lock() |
|||
defer c.in.Unlock() |
|||
for len(c.readBuffer) == 0 { |
|||
err := c.consumeRecord() |
|||
|
|||
// err can be nil if consumeRecord processed a non app-data
|
|||
// record.
|
|||
if err != nil { |
|||
if c.config.NonBlocking || err != WouldBlock { |
|||
logf(logTypeIO, "conn.Read returns err=%v", err) |
|||
return 0, err |
|||
} |
|||
} |
|||
} |
|||
|
|||
var read int |
|||
n := len(buffer) |
|||
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) |
|||
if len(c.readBuffer) <= n { |
|||
buffer = buffer[:len(c.readBuffer)] |
|||
copy(buffer, c.readBuffer) |
|||
read = len(c.readBuffer) |
|||
c.readBuffer = c.readBuffer[:0] |
|||
} else { |
|||
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) |
|||
copy(buffer[:n], c.readBuffer[:n]) |
|||
c.readBuffer = c.readBuffer[n:] |
|||
read = n |
|||
} |
|||
|
|||
logf(logTypeVerbose, "Returning %v", string(buffer)) |
|||
return read, nil |
|||
} |
|||
|
|||
// Write application data
|
|||
func (c *Conn) Write(buffer []byte) (int, error) { |
|||
// Lock the output channel
|
|||
c.out.Lock() |
|||
defer c.out.Unlock() |
|||
|
|||
// Send full-size fragments
|
|||
var start int |
|||
sent := 0 |
|||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { |
|||
err := c.out.WriteRecord(&TLSPlaintext{ |
|||
contentType: RecordTypeApplicationData, |
|||
fragment: buffer[start : start+maxFragmentLen], |
|||
}) |
|||
|
|||
if err != nil { |
|||
return sent, err |
|||
} |
|||
sent += maxFragmentLen |
|||
} |
|||
|
|||
// Send a final partial fragment if necessary
|
|||
if start < len(buffer) { |
|||
err := c.out.WriteRecord(&TLSPlaintext{ |
|||
contentType: RecordTypeApplicationData, |
|||
fragment: buffer[start:], |
|||
}) |
|||
|
|||
if err != nil { |
|||
return sent, err |
|||
} |
|||
sent += len(buffer[start:]) |
|||
} |
|||
return sent, nil |
|||
} |
|||
|
|||
// sendAlert sends a TLS alert message.
|
|||
// c.out.Mutex <= L.
|
|||
func (c *Conn) sendAlert(err Alert) error { |
|||
c.handshakeMutex.Lock() |
|||
defer c.handshakeMutex.Unlock() |
|||
|
|||
var level int |
|||
switch err { |
|||
case AlertNoRenegotiation, AlertCloseNotify: |
|||
level = AlertLevelWarning |
|||
default: |
|||
level = AlertLevelError |
|||
} |
|||
|
|||
buf := []byte{byte(err), byte(level)} |
|||
c.out.WriteRecord(&TLSPlaintext{ |
|||
contentType: RecordTypeAlert, |
|||
fragment: buf, |
|||
}) |
|||
|
|||
// close_notify and end_of_early_data are not actually errors
|
|||
if level == AlertLevelWarning { |
|||
return &net.OpError{Op: "local error", Err: err} |
|||
} |
|||
|
|||
return c.Close() |
|||
} |
|||
|
|||
// Close closes the connection.
|
|||
func (c *Conn) Close() error { |
|||
// XXX crypto/tls has an interlock with Write here. Do we need that?
|
|||
|
|||
return c.conn.Close() |
|||
} |
|||
|
|||
// LocalAddr returns the local network address.
|
|||
func (c *Conn) LocalAddr() net.Addr { |
|||
return c.conn.LocalAddr() |
|||
} |
|||
|
|||
// RemoteAddr returns the remote network address.
|
|||
func (c *Conn) RemoteAddr() net.Addr { |
|||
return c.conn.RemoteAddr() |
|||
} |
|||
|
|||
// SetDeadline sets the read and write deadlines associated with the connection.
|
|||
// A zero value for t means Read and Write will not time out.
|
|||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
|||
func (c *Conn) SetDeadline(t time.Time) error { |
|||
return c.conn.SetDeadline(t) |
|||
} |
|||
|
|||
// SetReadDeadline sets the read deadline on the underlying connection.
|
|||
// A zero value for t means Read will not time out.
|
|||
func (c *Conn) SetReadDeadline(t time.Time) error { |
|||
return c.conn.SetReadDeadline(t) |
|||
} |
|||
|
|||
// SetWriteDeadline sets the write deadline on the underlying connection.
|
|||
// A zero value for t means Write will not time out.
|
|||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
|||
func (c *Conn) SetWriteDeadline(t time.Time) error { |
|||
return c.conn.SetWriteDeadline(t) |
|||
} |
|||
|
|||
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { |
|||
label := "[server]" |
|||
if c.isClient { |
|||
label = "[client]" |
|||
} |
|||
|
|||
switch action := actionGeneric.(type) { |
|||
case SendHandshakeMessage: |
|||
err := c.hOut.WriteMessage(action.Message) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) |
|||
return AlertInternalError |
|||
} |
|||
|
|||
case RekeyIn: |
|||
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet) |
|||
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) |
|||
return AlertInternalError |
|||
} |
|||
|
|||
case RekeyOut: |
|||
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet) |
|||
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) |
|||
return AlertInternalError |
|||
} |
|||
|
|||
case SendEarlyData: |
|||
logf(logTypeHandshake, "%s Sending early data...", label) |
|||
_, err := c.Write(c.EarlyData) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Error writing early data: %v", label, err) |
|||
return AlertInternalError |
|||
} |
|||
|
|||
case ReadPastEarlyData: |
|||
logf(logTypeHandshake, "%s Reading past early data...", label) |
|||
// Scan past all records that fail to decrypt
|
|||
_, err := c.in.PeekRecordType(!c.config.NonBlocking) |
|||
if err == nil { |
|||
break |
|||
} |
|||
_, ok := err.(DecryptError) |
|||
|
|||
for ok { |
|||
_, err = c.in.PeekRecordType(!c.config.NonBlocking) |
|||
if err == nil { |
|||
break |
|||
} |
|||
_, ok = err.(DecryptError) |
|||
} |
|||
|
|||
case ReadEarlyData: |
|||
logf(logTypeHandshake, "%s Reading early data...", label) |
|||
t, err := c.in.PeekRecordType(!c.config.NonBlocking) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) |
|||
return AlertInternalError |
|||
} |
|||
logf(logTypeHandshake, "%s Got record type(1): %v", label, t) |
|||
|
|||
for t == RecordTypeApplicationData { |
|||
// Read a record into the buffer. Note that this is safe
|
|||
// in blocking mode because we read the record in in
|
|||
// PeekRecordType.
|
|||
pt, err := c.in.ReadRecord() |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) |
|||
return AlertInternalError |
|||
} |
|||
|
|||
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) |
|||
c.EarlyData = append(c.EarlyData, pt.fragment...) |
|||
|
|||
t, err = c.in.PeekRecordType(!c.config.NonBlocking) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) |
|||
return AlertInternalError |
|||
} |
|||
logf(logTypeHandshake, "%s Got record type (2): %v", label, t) |
|||
} |
|||
logf(logTypeHandshake, "%s Done reading early data", label) |
|||
|
|||
case StorePSK: |
|||
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) |
|||
if c.isClient { |
|||
// Clients look up PSKs based on server name
|
|||
c.config.PSKs.Put(c.config.ServerName, action.PSK) |
|||
} else { |
|||
// Servers look them up based on the identity in the extension
|
|||
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK) |
|||
} |
|||
|
|||
default: |
|||
logf(logTypeHandshake, "%s Unknown actionuction type", label) |
|||
return AlertInternalError |
|||
} |
|||
|
|||
return AlertNoAlert |
|||
} |
|||
|
|||
func (c *Conn) HandshakeSetup() Alert { |
|||
var state HandshakeState |
|||
var actions []HandshakeAction |
|||
var alert Alert |
|||
|
|||
if err := c.config.Init(c.isClient); err != nil { |
|||
logf(logTypeHandshake, "Error initializing config: %v", err) |
|||
return AlertInternalError |
|||
} |
|||
|
|||
// Set things up
|
|||
caps := Capabilities{ |
|||
CipherSuites: c.config.CipherSuites, |
|||
Groups: c.config.Groups, |
|||
SignatureSchemes: c.config.SignatureSchemes, |
|||
PSKs: c.config.PSKs, |
|||
PSKModes: c.config.PSKModes, |
|||
AllowEarlyData: c.config.AllowEarlyData, |
|||
RequireCookie: c.config.RequireCookie, |
|||
CookieHandler: c.config.CookieHandler, |
|||
RequireClientAuth: c.config.RequireClientAuth, |
|||
NextProtos: c.config.NextProtos, |
|||
Certificates: c.config.Certificates, |
|||
ExtensionHandler: c.extHandler, |
|||
} |
|||
opts := ConnectionOptions{ |
|||
ServerName: c.config.ServerName, |
|||
NextProtos: c.config.NextProtos, |
|||
EarlyData: c.EarlyData, |
|||
} |
|||
|
|||
if caps.RequireCookie && caps.CookieHandler == nil { |
|||
caps.CookieHandler = &defaultCookieHandler{} |
|||
} |
|||
|
|||
if c.isClient { |
|||
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil) |
|||
if alert != AlertNoAlert { |
|||
logf(logTypeHandshake, "Error initializing client state: %v", alert) |
|||
return alert |
|||
} |
|||
|
|||
for _, action := range actions { |
|||
alert = c.takeAction(action) |
|||
if alert != AlertNoAlert { |
|||
logf(logTypeHandshake, "Error during handshake actions: %v", alert) |
|||
return alert |
|||
} |
|||
} |
|||
} else { |
|||
state = ServerStateStart{Caps: caps, conn: c} |
|||
} |
|||
|
|||
c.hState = state |
|||
|
|||
return AlertNoAlert |
|||
} |
|||
|
|||
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
|||
// determines whether a client or server handshake is performed. If a
|
|||
// handshake has already been performed, then its result will be returned.
|
|||
func (c *Conn) Handshake() Alert { |
|||
label := "[server]" |
|||
if c.isClient { |
|||
label = "[client]" |
|||
} |
|||
|
|||
// TODO Lock handshakeMutex
|
|||
// TODO Remove CloseNotify hack
|
|||
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify { |
|||
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert) |
|||
return c.handshakeAlert |
|||
} |
|||
if c.handshakeComplete { |
|||
return AlertNoAlert |
|||
} |
|||
|
|||
var alert Alert |
|||
if c.hState == nil { |
|||
logf(logTypeHandshake, "%s First time through handshake, setting up", label) |
|||
alert = c.HandshakeSetup() |
|||
if alert != AlertNoAlert { |
|||
return alert |
|||
} |
|||
} else { |
|||
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState) |
|||
} |
|||
|
|||
state := c.hState |
|||
_, connected := state.(StateConnected) |
|||
|
|||
var actions []HandshakeAction |
|||
|
|||
for !connected { |
|||
// Read a handshake message
|
|||
hm, err := c.hIn.ReadMessage() |
|||
if err == WouldBlock { |
|||
logf(logTypeHandshake, "%s Would block reading message: %v", label, err) |
|||
return AlertWouldBlock |
|||
} |
|||
if err != nil { |
|||
logf(logTypeHandshake, "%s Error reading message: %v", label, err) |
|||
c.sendAlert(AlertCloseNotify) |
|||
return AlertCloseNotify |
|||
} |
|||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType) |
|||
|
|||
// Advance the state machine
|
|||
state, actions, alert = state.Next(hm) |
|||
|
|||
if alert != AlertNoAlert { |
|||
logf(logTypeHandshake, "Error in state transition: %v", alert) |
|||
return alert |
|||
} |
|||
|
|||
for index, action := range actions { |
|||
logf(logTypeHandshake, "%s taking next action (%d)", label, index) |
|||
alert = c.takeAction(action) |
|||
if alert != AlertNoAlert { |
|||
logf(logTypeHandshake, "Error during handshake actions: %v", alert) |
|||
c.sendAlert(alert) |
|||
return alert |
|||
} |
|||
} |
|||
|
|||
c.hState = state |
|||
logf(logTypeHandshake, "state is now %s", c.GetHsState()) |
|||
|
|||
_, connected = state.(StateConnected) |
|||
} |
|||
|
|||
c.state = state.(StateConnected) |
|||
|
|||
// Send NewSessionTicket if acting as server
|
|||
if !c.isClient && c.config.SendSessionTickets { |
|||
actions, alert := c.state.NewSessionTicket( |
|||
c.config.TicketLen, |
|||
c.config.TicketLifetime, |
|||
c.config.EarlyDataLifetime) |
|||
|
|||
for _, action := range actions { |
|||
alert = c.takeAction(action) |
|||
if alert != AlertNoAlert { |
|||
logf(logTypeHandshake, "Error during handshake actions: %v", alert) |
|||
c.sendAlert(alert) |
|||
return alert |
|||
} |
|||
} |
|||
} |
|||
|
|||
c.handshakeComplete = true |
|||
return AlertNoAlert |
|||
} |
|||
|
|||
func (c *Conn) SendKeyUpdate(requestUpdate bool) error { |
|||
if !c.handshakeComplete { |
|||
return fmt.Errorf("Cannot update keys until after handshake") |
|||
} |
|||
|
|||
request := KeyUpdateNotRequested |
|||
if requestUpdate { |
|||
request = KeyUpdateRequested |
|||
} |
|||
|
|||
// Create the key update and update state
|
|||
actions, alert := c.state.KeyUpdate(request) |
|||
if alert != AlertNoAlert { |
|||
c.sendAlert(alert) |
|||
return fmt.Errorf("Alert while generating key update: %v", alert) |
|||
} |
|||
|
|||
// Take actions (send key update and rekey)
|
|||
for _, action := range actions { |
|||
alert = c.takeAction(action) |
|||
if alert != AlertNoAlert { |
|||
c.sendAlert(alert) |
|||
return fmt.Errorf("Alert during key update actions: %v", alert) |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (c *Conn) GetHsState() string { |
|||
return reflect.TypeOf(c.hState).Name() |
|||
} |
|||
|
|||
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { |
|||
_, connected := c.hState.(StateConnected) |
|||
if !connected { |
|||
return nil, fmt.Errorf("Cannot compute exporter when state is not connected") |
|||
} |
|||
|
|||
if c.state.exporterSecret == nil { |
|||
return nil, fmt.Errorf("Internal error: no exporter secret") |
|||
} |
|||
|
|||
h0 := c.state.cryptoParams.Hash.New().Sum(nil) |
|||
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0) |
|||
|
|||
hc := c.state.cryptoParams.Hash.New().Sum(context) |
|||
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil |
|||
} |
|||
|
|||
func (c *Conn) State() ConnectionState { |
|||
state := ConnectionState{ |
|||
HandshakeState: c.GetHsState(), |
|||
} |
|||
|
|||
if c.handshakeComplete { |
|||
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] |
|||
state.NextProto = c.state.Params.NextProto |
|||
} |
|||
|
|||
return state |
|||
} |
|||
|
|||
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error { |
|||
if c.hState != nil { |
|||
return fmt.Errorf("Can't set extension handler after setup") |
|||
} |
|||
|
|||
c.extHandler = h |
|||
return nil |
|||
} |
|||
@ -0,0 +1,654 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto" |
|||
"crypto/aes" |
|||
"crypto/cipher" |
|||
"crypto/ecdsa" |
|||
"crypto/elliptic" |
|||
"crypto/hmac" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"crypto/x509" |
|||
"crypto/x509/pkix" |
|||
"encoding/asn1" |
|||
"fmt" |
|||
"math/big" |
|||
"time" |
|||
|
|||
"golang.org/x/crypto/curve25519" |
|||
|
|||
// Blank includes to ensure hash support
|
|||
_ "crypto/sha1" |
|||
_ "crypto/sha256" |
|||
_ "crypto/sha512" |
|||
) |
|||
|
|||
var prng = rand.Reader |
|||
|
|||
type aeadFactory func(key []byte) (cipher.AEAD, error) |
|||
|
|||
type CipherSuiteParams struct { |
|||
Suite CipherSuite |
|||
Cipher aeadFactory // Cipher factory
|
|||
Hash crypto.Hash // Hash function
|
|||
KeyLen int // Key length in octets
|
|||
IvLen int // IV length in octets
|
|||
} |
|||
|
|||
type signatureAlgorithm uint8 |
|||
|
|||
const ( |
|||
signatureAlgorithmUnknown = iota |
|||
signatureAlgorithmRSA_PKCS1 |
|||
signatureAlgorithmRSA_PSS |
|||
signatureAlgorithmECDSA |
|||
) |
|||
|
|||
var ( |
|||
hashMap = map[SignatureScheme]crypto.Hash{ |
|||
RSA_PKCS1_SHA1: crypto.SHA1, |
|||
RSA_PKCS1_SHA256: crypto.SHA256, |
|||
RSA_PKCS1_SHA384: crypto.SHA384, |
|||
RSA_PKCS1_SHA512: crypto.SHA512, |
|||
ECDSA_P256_SHA256: crypto.SHA256, |
|||
ECDSA_P384_SHA384: crypto.SHA384, |
|||
ECDSA_P521_SHA512: crypto.SHA512, |
|||
RSA_PSS_SHA256: crypto.SHA256, |
|||
RSA_PSS_SHA384: crypto.SHA384, |
|||
RSA_PSS_SHA512: crypto.SHA512, |
|||
} |
|||
|
|||
sigMap = map[SignatureScheme]signatureAlgorithm{ |
|||
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1, |
|||
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1, |
|||
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1, |
|||
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1, |
|||
ECDSA_P256_SHA256: signatureAlgorithmECDSA, |
|||
ECDSA_P384_SHA384: signatureAlgorithmECDSA, |
|||
ECDSA_P521_SHA512: signatureAlgorithmECDSA, |
|||
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS, |
|||
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS, |
|||
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS, |
|||
} |
|||
|
|||
curveMap = map[SignatureScheme]NamedGroup{ |
|||
ECDSA_P256_SHA256: P256, |
|||
ECDSA_P384_SHA384: P384, |
|||
ECDSA_P521_SHA512: P521, |
|||
} |
|||
|
|||
newAESGCM = func(key []byte) (cipher.AEAD, error) { |
|||
block, err := aes.NewCipher(key) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// TLS always uses 12-byte nonces
|
|||
return cipher.NewGCMWithNonceSize(block, 12) |
|||
} |
|||
|
|||
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{ |
|||
TLS_AES_128_GCM_SHA256: { |
|||
Suite: TLS_AES_128_GCM_SHA256, |
|||
Cipher: newAESGCM, |
|||
Hash: crypto.SHA256, |
|||
KeyLen: 16, |
|||
IvLen: 12, |
|||
}, |
|||
TLS_AES_256_GCM_SHA384: { |
|||
Suite: TLS_AES_256_GCM_SHA384, |
|||
Cipher: newAESGCM, |
|||
Hash: crypto.SHA384, |
|||
KeyLen: 32, |
|||
IvLen: 12, |
|||
}, |
|||
} |
|||
|
|||
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{ |
|||
RSA_PKCS1_SHA1: x509.SHA1WithRSA, |
|||
RSA_PKCS1_SHA256: x509.SHA256WithRSA, |
|||
RSA_PKCS1_SHA384: x509.SHA384WithRSA, |
|||
RSA_PKCS1_SHA512: x509.SHA512WithRSA, |
|||
ECDSA_P256_SHA256: x509.ECDSAWithSHA256, |
|||
ECDSA_P384_SHA384: x509.ECDSAWithSHA384, |
|||
ECDSA_P521_SHA512: x509.ECDSAWithSHA512, |
|||
} |
|||
|
|||
defaultRSAKeySize = 2048 |
|||
) |
|||
|
|||
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) { |
|||
switch group { |
|||
case P256: |
|||
crv = elliptic.P256() |
|||
case P384: |
|||
crv = elliptic.P384() |
|||
case P521: |
|||
crv = elliptic.P521() |
|||
} |
|||
return |
|||
} |
|||
|
|||
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) { |
|||
switch key.Curve.Params().Name { |
|||
case elliptic.P256().Params().Name: |
|||
g = P256 |
|||
case elliptic.P384().Params().Name: |
|||
g = P384 |
|||
case elliptic.P521().Params().Name: |
|||
g = P521 |
|||
} |
|||
return |
|||
} |
|||
|
|||
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) { |
|||
size = 0 |
|||
switch group { |
|||
case X25519: |
|||
size = 32 |
|||
case P256: |
|||
size = 65 |
|||
case P384: |
|||
size = 97 |
|||
case P521: |
|||
size = 133 |
|||
case FFDHE2048: |
|||
size = 256 |
|||
case FFDHE3072: |
|||
size = 384 |
|||
case FFDHE4096: |
|||
size = 512 |
|||
case FFDHE6144: |
|||
size = 768 |
|||
case FFDHE8192: |
|||
size = 1024 |
|||
} |
|||
return |
|||
} |
|||
|
|||
func primeFromNamedGroup(group NamedGroup) (p *big.Int) { |
|||
switch group { |
|||
case FFDHE2048: |
|||
p = finiteFieldPrime2048 |
|||
case FFDHE3072: |
|||
p = finiteFieldPrime3072 |
|||
case FFDHE4096: |
|||
p = finiteFieldPrime4096 |
|||
case FFDHE6144: |
|||
p = finiteFieldPrime6144 |
|||
case FFDHE8192: |
|||
p = finiteFieldPrime8192 |
|||
} |
|||
return |
|||
} |
|||
|
|||
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool { |
|||
sigType := sigMap[alg] |
|||
switch key.(type) { |
|||
case *rsa.PrivateKey: |
|||
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS |
|||
case *ecdsa.PrivateKey: |
|||
return sigType == signatureAlgorithmECDSA |
|||
default: |
|||
return false |
|||
} |
|||
} |
|||
|
|||
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) { |
|||
primeLen := len(p.Bytes()) |
|||
for { |
|||
// g = 2 for all ffdhe groups
|
|||
priv, err = rand.Int(prng, p) |
|||
if err != nil { |
|||
return |
|||
} |
|||
|
|||
pub = big.NewInt(0) |
|||
pub.Exp(big.NewInt(2), priv, p) |
|||
|
|||
if len(pub.Bytes()) == primeLen { |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) { |
|||
switch group { |
|||
case P256, P384, P521: |
|||
var x, y *big.Int |
|||
crv := curveFromNamedGroup(group) |
|||
priv, x, y, err = elliptic.GenerateKey(crv, prng) |
|||
if err != nil { |
|||
return |
|||
} |
|||
|
|||
pub = elliptic.Marshal(crv, x, y) |
|||
return |
|||
|
|||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: |
|||
p := primeFromNamedGroup(group) |
|||
x, X, err2 := ffdheKeyShareFromPrime(p) |
|||
if err2 != nil { |
|||
err = err2 |
|||
return |
|||
} |
|||
|
|||
priv = x.Bytes() |
|||
pubBytes := X.Bytes() |
|||
|
|||
numBytes := keyExchangeSizeFromNamedGroup(group) |
|||
|
|||
pub = make([]byte, numBytes) |
|||
copy(pub[numBytes-len(pubBytes):], pubBytes) |
|||
|
|||
return |
|||
|
|||
case X25519: |
|||
var private, public [32]byte |
|||
_, err = prng.Read(private[:]) |
|||
if err != nil { |
|||
return |
|||
} |
|||
|
|||
curve25519.ScalarBaseMult(&public, &private) |
|||
priv = private[:] |
|||
pub = public[:] |
|||
return |
|||
|
|||
default: |
|||
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group) |
|||
} |
|||
} |
|||
|
|||
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) { |
|||
switch group { |
|||
case P256, P384, P521: |
|||
if len(pub) != keyExchangeSizeFromNamedGroup(group) { |
|||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") |
|||
} |
|||
|
|||
crv := curveFromNamedGroup(group) |
|||
pubX, pubY := elliptic.Unmarshal(crv, pub) |
|||
x, _ := crv.Params().ScalarMult(pubX, pubY, priv) |
|||
xBytes := x.Bytes() |
|||
|
|||
numBytes := len(crv.Params().P.Bytes()) |
|||
|
|||
ret := make([]byte, numBytes) |
|||
copy(ret[numBytes-len(xBytes):], xBytes) |
|||
|
|||
return ret, nil |
|||
|
|||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: |
|||
numBytes := keyExchangeSizeFromNamedGroup(group) |
|||
if len(pub) != numBytes { |
|||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") |
|||
} |
|||
p := primeFromNamedGroup(group) |
|||
x := big.NewInt(0).SetBytes(priv) |
|||
Y := big.NewInt(0).SetBytes(pub) |
|||
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes() |
|||
|
|||
ret := make([]byte, numBytes) |
|||
copy(ret[numBytes-len(ZBytes):], ZBytes) |
|||
|
|||
return ret, nil |
|||
|
|||
case X25519: |
|||
if len(pub) != keyExchangeSizeFromNamedGroup(group) { |
|||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") |
|||
} |
|||
|
|||
var private, public, ret [32]byte |
|||
copy(private[:], priv) |
|||
copy(public[:], pub) |
|||
curve25519.ScalarMult(&ret, &private, &public) |
|||
|
|||
return ret[:], nil |
|||
|
|||
default: |
|||
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group) |
|||
} |
|||
} |
|||
|
|||
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { |
|||
switch sig { |
|||
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256, |
|||
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, |
|||
RSA_PSS_SHA256, RSA_PSS_SHA384, |
|||
RSA_PSS_SHA512: |
|||
return rsa.GenerateKey(prng, defaultRSAKeySize) |
|||
case ECDSA_P256_SHA256: |
|||
return ecdsa.GenerateKey(elliptic.P256(), prng) |
|||
case ECDSA_P384_SHA384: |
|||
return ecdsa.GenerateKey(elliptic.P384(), prng) |
|||
case ECDSA_P521_SHA512: |
|||
return ecdsa.GenerateKey(elliptic.P521(), prng) |
|||
default: |
|||
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig) |
|||
} |
|||
} |
|||
|
|||
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) { |
|||
sigAlg, ok := x509AlgMap[alg] |
|||
if !ok { |
|||
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg) |
|||
} |
|||
if len(name) == 0 { |
|||
return nil, fmt.Errorf("tls.selfsigned: No name provided") |
|||
} |
|||
|
|||
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0)) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
template := &x509.Certificate{ |
|||
SerialNumber: serial, |
|||
NotBefore: time.Now(), |
|||
NotAfter: time.Now().AddDate(0, 0, 1), |
|||
SignatureAlgorithm: sigAlg, |
|||
Subject: pkix.Name{CommonName: name}, |
|||
DNSNames: []string{name}, |
|||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment, |
|||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, |
|||
} |
|||
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// It is safe to ignore the error here because we're parsing known-good data
|
|||
cert, _ := x509.ParseCertificate(der) |
|||
return cert, nil |
|||
} |
|||
|
|||
// XXX(rlb): Copied from crypto/x509
|
|||
type ecdsaSignature struct { |
|||
R, S *big.Int |
|||
} |
|||
|
|||
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) { |
|||
var opts crypto.SignerOpts |
|||
|
|||
hash := hashMap[alg] |
|||
if hash == crypto.SHA1 { |
|||
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") |
|||
} |
|||
|
|||
sigType := sigMap[alg] |
|||
var realInput []byte |
|||
switch key := privateKey.(type) { |
|||
case *rsa.PrivateKey: |
|||
switch { |
|||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: |
|||
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size()) |
|||
opts = hash |
|||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: |
|||
fallthrough |
|||
case sigType == signatureAlgorithmRSA_PSS: |
|||
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size()) |
|||
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} |
|||
default: |
|||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key") |
|||
} |
|||
|
|||
h := hash.New() |
|||
h.Write(sigInput) |
|||
realInput = h.Sum(nil) |
|||
case *ecdsa.PrivateKey: |
|||
if sigType != signatureAlgorithmECDSA { |
|||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key") |
|||
} |
|||
|
|||
algGroup := curveMap[alg] |
|||
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey)) |
|||
if algGroup != keyGroup { |
|||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination") |
|||
} |
|||
|
|||
h := hash.New() |
|||
h.Write(sigInput) |
|||
realInput = h.Sum(nil) |
|||
default: |
|||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type") |
|||
} |
|||
|
|||
sig, err := privateKey.Sign(prng, realInput, opts) |
|||
logf(logTypeCrypto, "signature: %x", sig) |
|||
return sig, err |
|||
} |
|||
|
|||
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error { |
|||
hash := hashMap[alg] |
|||
|
|||
if hash == crypto.SHA1 { |
|||
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") |
|||
} |
|||
|
|||
sigType := sigMap[alg] |
|||
switch pub := publicKey.(type) { |
|||
case *rsa.PublicKey: |
|||
switch { |
|||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: |
|||
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size()) |
|||
|
|||
h := hash.New() |
|||
h.Write(sigInput) |
|||
realInput := h.Sum(nil) |
|||
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig) |
|||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: |
|||
fallthrough |
|||
case sigType == signatureAlgorithmRSA_PSS: |
|||
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size()) |
|||
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} |
|||
|
|||
h := hash.New() |
|||
h.Write(sigInput) |
|||
realInput := h.Sum(nil) |
|||
return rsa.VerifyPSS(pub, hash, realInput, sig, opts) |
|||
default: |
|||
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key") |
|||
} |
|||
|
|||
case *ecdsa.PublicKey: |
|||
if sigType != signatureAlgorithmECDSA { |
|||
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key") |
|||
} |
|||
|
|||
if curveMap[alg] != namedGroupFromECDSAKey(pub) { |
|||
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key") |
|||
} |
|||
|
|||
ecdsaSig := new(ecdsaSignature) |
|||
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { |
|||
return err |
|||
} else if len(rest) != 0 { |
|||
return fmt.Errorf("tls.verify: trailing data after ECDSA signature") |
|||
} |
|||
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { |
|||
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values") |
|||
} |
|||
|
|||
h := hash.New() |
|||
h.Write(sigInput) |
|||
realInput := h.Sum(nil) |
|||
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) { |
|||
return fmt.Errorf("tls.verify: ECDSA verification failure") |
|||
} |
|||
return nil |
|||
default: |
|||
return fmt.Errorf("tls.verify: Unsupported key type") |
|||
} |
|||
} |
|||
|
|||
// 0
|
|||
// |
|
|||
// v
|
|||
// PSK -> HKDF-Extract = Early Secret
|
|||
// |
|
|||
// +-----> Derive-Secret(.,
|
|||
// | "ext binder" |
|
|||
// | "res binder",
|
|||
// | "")
|
|||
// | = binder_key
|
|||
// |
|
|||
// +-----> Derive-Secret(., "c e traffic",
|
|||
// | ClientHello)
|
|||
// | = client_early_traffic_secret
|
|||
// |
|
|||
// +-----> Derive-Secret(., "e exp master",
|
|||
// | ClientHello)
|
|||
// | = early_exporter_master_secret
|
|||
// v
|
|||
// Derive-Secret(., "derived", "")
|
|||
// |
|
|||
// v
|
|||
// (EC)DHE -> HKDF-Extract = Handshake Secret
|
|||
// |
|
|||
// +-----> Derive-Secret(., "c hs traffic",
|
|||
// | ClientHello...ServerHello)
|
|||
// | = client_handshake_traffic_secret
|
|||
// |
|
|||
// +-----> Derive-Secret(., "s hs traffic",
|
|||
// | ClientHello...ServerHello)
|
|||
// | = server_handshake_traffic_secret
|
|||
// v
|
|||
// Derive-Secret(., "derived", "")
|
|||
// |
|
|||
// v
|
|||
// 0 -> HKDF-Extract = Master Secret
|
|||
// |
|
|||
// +-----> Derive-Secret(., "c ap traffic",
|
|||
// | ClientHello...server Finished)
|
|||
// | = client_application_traffic_secret_0
|
|||
// |
|
|||
// +-----> Derive-Secret(., "s ap traffic",
|
|||
// | ClientHello...server Finished)
|
|||
// | = server_application_traffic_secret_0
|
|||
// |
|
|||
// +-----> Derive-Secret(., "exp master",
|
|||
// | ClientHello...server Finished)
|
|||
// | = exporter_master_secret
|
|||
// |
|
|||
// +-----> Derive-Secret(., "res master",
|
|||
// ClientHello...client Finished)
|
|||
// = resumption_master_secret
|
|||
|
|||
// From RFC 5869
|
|||
// PRK = HMAC-Hash(salt, IKM)
|
|||
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte { |
|||
salt := saltIn |
|||
|
|||
// if [salt is] not provided, it is set to a string of HashLen zeros
|
|||
if salt == nil { |
|||
salt = bytes.Repeat([]byte{0}, hash.Size()) |
|||
} |
|||
|
|||
h := hmac.New(hash.New, salt) |
|||
h.Write(input) |
|||
out := h.Sum(nil) |
|||
|
|||
logf(logTypeCrypto, "HKDF Extract:\n") |
|||
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt) |
|||
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input) |
|||
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out) |
|||
|
|||
return out |
|||
} |
|||
|
|||
const ( |
|||
labelExternalBinder = "ext binder" |
|||
labelResumptionBinder = "res binder" |
|||
labelEarlyTrafficSecret = "c e traffic" |
|||
labelEarlyExporterSecret = "e exp master" |
|||
labelClientHandshakeTrafficSecret = "c hs traffic" |
|||
labelServerHandshakeTrafficSecret = "s hs traffic" |
|||
labelClientApplicationTrafficSecret = "c ap traffic" |
|||
labelServerApplicationTrafficSecret = "s ap traffic" |
|||
labelExporterSecret = "exp master" |
|||
labelResumptionSecret = "res master" |
|||
labelDerived = "derived" |
|||
labelFinished = "finished" |
|||
labelResumption = "resumption" |
|||
) |
|||
|
|||
// struct HkdfLabel {
|
|||
// uint16 length;
|
|||
// opaque label<9..255>;
|
|||
// opaque hash_value<0..255>;
|
|||
// };
|
|||
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte { |
|||
label := "tls13 " + labelIn |
|||
|
|||
labelLen := len(label) |
|||
hashLen := len(hashValue) |
|||
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen) |
|||
hkdfLabel[0] = byte(outLen >> 8) |
|||
hkdfLabel[1] = byte(outLen) |
|||
hkdfLabel[2] = byte(labelLen) |
|||
copy(hkdfLabel[3:3+labelLen], []byte(label)) |
|||
hkdfLabel[3+labelLen] = byte(hashLen) |
|||
copy(hkdfLabel[3+labelLen+1:], hashValue) |
|||
|
|||
return hkdfLabel |
|||
} |
|||
|
|||
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte { |
|||
out := []byte{} |
|||
T := []byte{} |
|||
i := byte(1) |
|||
for len(out) < outLen { |
|||
block := append(T, info...) |
|||
block = append(block, i) |
|||
|
|||
h := hmac.New(hash.New, prk) |
|||
h.Write(block) |
|||
|
|||
T = h.Sum(nil) |
|||
out = append(out, T...) |
|||
i++ |
|||
} |
|||
return out[:outLen] |
|||
} |
|||
|
|||
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte { |
|||
info := hkdfEncodeLabel(label, hashValue, outLen) |
|||
derived := HkdfExpand(hash, secret, info, outLen) |
|||
|
|||
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen) |
|||
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret) |
|||
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue) |
|||
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info) |
|||
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived) |
|||
|
|||
return derived |
|||
} |
|||
|
|||
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte { |
|||
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size()) |
|||
} |
|||
|
|||
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte { |
|||
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size()) |
|||
mac := hmac.New(params.Hash.New, macKey) |
|||
mac.Write(input) |
|||
return mac.Sum(nil) |
|||
} |
|||
|
|||
type keySet struct { |
|||
cipher aeadFactory |
|||
key []byte |
|||
iv []byte |
|||
} |
|||
|
|||
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { |
|||
logf(logTypeCrypto, "making traffic keys: secret=%x", secret) |
|||
return keySet{ |
|||
cipher: params.Cipher, |
|||
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen), |
|||
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), |
|||
} |
|||
} |
|||
@ -0,0 +1,586 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
|
|||
"github.com/bifurcation/mint/syntax" |
|||
) |
|||
|
|||
type ExtensionBody interface { |
|||
Type() ExtensionType |
|||
Marshal() ([]byte, error) |
|||
Unmarshal(data []byte) (int, error) |
|||
} |
|||
|
|||
// struct {
|
|||
// ExtensionType extension_type;
|
|||
// opaque extension_data<0..2^16-1>;
|
|||
// } Extension;
|
|||
type Extension struct { |
|||
ExtensionType ExtensionType |
|||
ExtensionData []byte `tls:"head=2"` |
|||
} |
|||
|
|||
func (ext Extension) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(ext) |
|||
} |
|||
|
|||
func (ext *Extension) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, ext) |
|||
} |
|||
|
|||
type ExtensionList []Extension |
|||
|
|||
type extensionListInner struct { |
|||
List []Extension `tls:"head=2"` |
|||
} |
|||
|
|||
func (el ExtensionList) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(extensionListInner{el}) |
|||
} |
|||
|
|||
func (el *ExtensionList) Unmarshal(data []byte) (int, error) { |
|||
var list extensionListInner |
|||
read, err := syntax.Unmarshal(data, &list) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
*el = list.List |
|||
return read, nil |
|||
} |
|||
|
|||
func (el *ExtensionList) Add(src ExtensionBody) error { |
|||
data, err := src.Marshal() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if el == nil { |
|||
el = new(ExtensionList) |
|||
} |
|||
|
|||
// If one already exists with this type, replace it
|
|||
for i := range *el { |
|||
if (*el)[i].ExtensionType == src.Type() { |
|||
(*el)[i].ExtensionData = data |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
// Otherwise append
|
|||
*el = append(*el, Extension{ |
|||
ExtensionType: src.Type(), |
|||
ExtensionData: data, |
|||
}) |
|||
return nil |
|||
} |
|||
|
|||
func (el ExtensionList) Find(dst ExtensionBody) bool { |
|||
for _, ext := range el { |
|||
if ext.ExtensionType == dst.Type() { |
|||
_, err := dst.Unmarshal(ext.ExtensionData) |
|||
return err == nil |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// struct {
|
|||
// NameType name_type;
|
|||
// select (name_type) {
|
|||
// case host_name: HostName;
|
|||
// } name;
|
|||
// } ServerName;
|
|||
//
|
|||
// enum {
|
|||
// host_name(0), (255)
|
|||
// } NameType;
|
|||
//
|
|||
// opaque HostName<1..2^16-1>;
|
|||
//
|
|||
// struct {
|
|||
// ServerName server_name_list<1..2^16-1>
|
|||
// } ServerNameList;
|
|||
//
|
|||
// But we only care about the case where there's a single DNS hostname. We
|
|||
// will never create anything else, and throw if we receive something else
|
|||
//
|
|||
// 2 1 2
|
|||
// | listLen | NameType | nameLen | name |
|
|||
type ServerNameExtension string |
|||
|
|||
type serverNameInner struct { |
|||
NameType uint8 |
|||
HostName []byte `tls:"head=2,min=1"` |
|||
} |
|||
|
|||
type serverNameListInner struct { |
|||
ServerNameList []serverNameInner `tls:"head=2,min=1"` |
|||
} |
|||
|
|||
func (sni ServerNameExtension) Type() ExtensionType { |
|||
return ExtensionTypeServerName |
|||
} |
|||
|
|||
func (sni ServerNameExtension) Marshal() ([]byte, error) { |
|||
list := serverNameListInner{ |
|||
ServerNameList: []serverNameInner{{ |
|||
NameType: 0x00, // host_name
|
|||
HostName: []byte(sni), |
|||
}}, |
|||
} |
|||
|
|||
return syntax.Marshal(list) |
|||
} |
|||
|
|||
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) { |
|||
var list serverNameListInner |
|||
read, err := syntax.Unmarshal(data, &list) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
// Syntax requires at least one entry
|
|||
// Entries beyond the first are ignored
|
|||
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 { |
|||
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType) |
|||
} |
|||
|
|||
*sni = ServerNameExtension(list.ServerNameList[0].HostName) |
|||
return read, nil |
|||
} |
|||
|
|||
// struct {
|
|||
// NamedGroup group;
|
|||
// opaque key_exchange<1..2^16-1>;
|
|||
// } KeyShareEntry;
|
|||
//
|
|||
// struct {
|
|||
// select (Handshake.msg_type) {
|
|||
// case client_hello:
|
|||
// KeyShareEntry client_shares<0..2^16-1>;
|
|||
//
|
|||
// case hello_retry_request:
|
|||
// NamedGroup selected_group;
|
|||
//
|
|||
// case server_hello:
|
|||
// KeyShareEntry server_share;
|
|||
// };
|
|||
// } KeyShare;
|
|||
type KeyShareEntry struct { |
|||
Group NamedGroup |
|||
KeyExchange []byte `tls:"head=2,min=1"` |
|||
} |
|||
|
|||
func (kse KeyShareEntry) SizeValid() bool { |
|||
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group) |
|||
} |
|||
|
|||
type KeyShareExtension struct { |
|||
HandshakeType HandshakeType |
|||
SelectedGroup NamedGroup |
|||
Shares []KeyShareEntry |
|||
} |
|||
|
|||
type KeyShareClientHelloInner struct { |
|||
ClientShares []KeyShareEntry `tls:"head=2,min=0"` |
|||
} |
|||
type KeyShareHelloRetryInner struct { |
|||
SelectedGroup NamedGroup |
|||
} |
|||
type KeyShareServerHelloInner struct { |
|||
ServerShare KeyShareEntry |
|||
} |
|||
|
|||
func (ks KeyShareExtension) Type() ExtensionType { |
|||
return ExtensionTypeKeyShare |
|||
} |
|||
|
|||
func (ks KeyShareExtension) Marshal() ([]byte, error) { |
|||
switch ks.HandshakeType { |
|||
case HandshakeTypeClientHello: |
|||
for _, share := range ks.Shares { |
|||
if !share.SizeValid() { |
|||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") |
|||
} |
|||
} |
|||
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares}) |
|||
|
|||
case HandshakeTypeHelloRetryRequest: |
|||
if len(ks.Shares) > 0 { |
|||
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest") |
|||
} |
|||
|
|||
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup}) |
|||
|
|||
case HandshakeTypeServerHello: |
|||
if len(ks.Shares) != 1 { |
|||
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share") |
|||
} |
|||
|
|||
if !ks.Shares[0].SizeValid() { |
|||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") |
|||
} |
|||
|
|||
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]}) |
|||
|
|||
default: |
|||
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed") |
|||
} |
|||
} |
|||
|
|||
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) { |
|||
switch ks.HandshakeType { |
|||
case HandshakeTypeClientHello: |
|||
var inner KeyShareClientHelloInner |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
for _, share := range inner.ClientShares { |
|||
if !share.SizeValid() { |
|||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") |
|||
} |
|||
} |
|||
|
|||
ks.Shares = inner.ClientShares |
|||
return read, nil |
|||
|
|||
case HandshakeTypeHelloRetryRequest: |
|||
var inner KeyShareHelloRetryInner |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
ks.SelectedGroup = inner.SelectedGroup |
|||
return read, nil |
|||
|
|||
case HandshakeTypeServerHello: |
|||
var inner KeyShareServerHelloInner |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
if !inner.ServerShare.SizeValid() { |
|||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") |
|||
} |
|||
|
|||
ks.Shares = []KeyShareEntry{inner.ServerShare} |
|||
return read, nil |
|||
|
|||
default: |
|||
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed") |
|||
} |
|||
} |
|||
|
|||
// struct {
|
|||
// NamedGroup named_group_list<2..2^16-1>;
|
|||
// } NamedGroupList;
|
|||
type SupportedGroupsExtension struct { |
|||
Groups []NamedGroup `tls:"head=2,min=2"` |
|||
} |
|||
|
|||
func (sg SupportedGroupsExtension) Type() ExtensionType { |
|||
return ExtensionTypeSupportedGroups |
|||
} |
|||
|
|||
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(sg) |
|||
} |
|||
|
|||
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, sg) |
|||
} |
|||
|
|||
// struct {
|
|||
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
|
|||
// } SignatureSchemeList
|
|||
type SignatureAlgorithmsExtension struct { |
|||
Algorithms []SignatureScheme `tls:"head=2,min=2"` |
|||
} |
|||
|
|||
func (sa SignatureAlgorithmsExtension) Type() ExtensionType { |
|||
return ExtensionTypeSignatureAlgorithms |
|||
} |
|||
|
|||
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(sa) |
|||
} |
|||
|
|||
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, sa) |
|||
} |
|||
|
|||
// struct {
|
|||
// opaque identity<1..2^16-1>;
|
|||
// uint32 obfuscated_ticket_age;
|
|||
// } PskIdentity;
|
|||
//
|
|||
// opaque PskBinderEntry<32..255>;
|
|||
//
|
|||
// struct {
|
|||
// select (Handshake.msg_type) {
|
|||
// case client_hello:
|
|||
// PskIdentity identities<7..2^16-1>;
|
|||
// PskBinderEntry binders<33..2^16-1>;
|
|||
//
|
|||
// case server_hello:
|
|||
// uint16 selected_identity;
|
|||
// };
|
|||
//
|
|||
// } PreSharedKeyExtension;
|
|||
type PSKIdentity struct { |
|||
Identity []byte `tls:"head=2,min=1"` |
|||
ObfuscatedTicketAge uint32 |
|||
} |
|||
|
|||
type PSKBinderEntry struct { |
|||
Binder []byte `tls:"head=1,min=32"` |
|||
} |
|||
|
|||
type PreSharedKeyExtension struct { |
|||
HandshakeType HandshakeType |
|||
Identities []PSKIdentity |
|||
Binders []PSKBinderEntry |
|||
SelectedIdentity uint16 |
|||
} |
|||
|
|||
type preSharedKeyClientInner struct { |
|||
Identities []PSKIdentity `tls:"head=2,min=7"` |
|||
Binders []PSKBinderEntry `tls:"head=2,min=33"` |
|||
} |
|||
|
|||
type preSharedKeyServerInner struct { |
|||
SelectedIdentity uint16 |
|||
} |
|||
|
|||
func (psk PreSharedKeyExtension) Type() ExtensionType { |
|||
return ExtensionTypePreSharedKey |
|||
} |
|||
|
|||
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) { |
|||
switch psk.HandshakeType { |
|||
case HandshakeTypeClientHello: |
|||
return syntax.Marshal(preSharedKeyClientInner{ |
|||
Identities: psk.Identities, |
|||
Binders: psk.Binders, |
|||
}) |
|||
|
|||
case HandshakeTypeServerHello: |
|||
if len(psk.Identities) > 0 || len(psk.Binders) > 0 { |
|||
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index") |
|||
} |
|||
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity}) |
|||
|
|||
default: |
|||
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported") |
|||
} |
|||
} |
|||
|
|||
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) { |
|||
switch psk.HandshakeType { |
|||
case HandshakeTypeClientHello: |
|||
var inner preSharedKeyClientInner |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
if len(inner.Identities) != len(inner.Binders) { |
|||
return 0, fmt.Errorf("Lengths of identities and binders not equal") |
|||
} |
|||
|
|||
psk.Identities = inner.Identities |
|||
psk.Binders = inner.Binders |
|||
return read, nil |
|||
|
|||
case HandshakeTypeServerHello: |
|||
var inner preSharedKeyServerInner |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
psk.SelectedIdentity = inner.SelectedIdentity |
|||
return read, nil |
|||
|
|||
default: |
|||
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported") |
|||
} |
|||
} |
|||
|
|||
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) { |
|||
for i, localID := range psk.Identities { |
|||
if bytes.Equal(localID.Identity, id) { |
|||
return psk.Binders[i].Binder, true |
|||
} |
|||
} |
|||
return nil, false |
|||
} |
|||
|
|||
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
|
|||
//
|
|||
// struct {
|
|||
// PskKeyExchangeMode ke_modes<1..255>;
|
|||
// } PskKeyExchangeModes;
|
|||
type PSKKeyExchangeModesExtension struct { |
|||
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"` |
|||
} |
|||
|
|||
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType { |
|||
return ExtensionTypePSKKeyExchangeModes |
|||
} |
|||
|
|||
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(pkem) |
|||
} |
|||
|
|||
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, pkem) |
|||
} |
|||
|
|||
// struct {
|
|||
// } EarlyDataIndication;
|
|||
|
|||
type EarlyDataExtension struct{} |
|||
|
|||
func (ed EarlyDataExtension) Type() ExtensionType { |
|||
return ExtensionTypeEarlyData |
|||
} |
|||
|
|||
func (ed EarlyDataExtension) Marshal() ([]byte, error) { |
|||
return []byte{}, nil |
|||
} |
|||
|
|||
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) { |
|||
return 0, nil |
|||
} |
|||
|
|||
// struct {
|
|||
// uint32 max_early_data_size;
|
|||
// } TicketEarlyDataInfo;
|
|||
|
|||
type TicketEarlyDataInfoExtension struct { |
|||
MaxEarlyDataSize uint32 |
|||
} |
|||
|
|||
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType { |
|||
return ExtensionTypeTicketEarlyDataInfo |
|||
} |
|||
|
|||
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(tedi) |
|||
} |
|||
|
|||
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, tedi) |
|||
} |
|||
|
|||
// opaque ProtocolName<1..2^8-1>;
|
|||
//
|
|||
// struct {
|
|||
// ProtocolName protocol_name_list<2..2^16-1>
|
|||
// } ProtocolNameList;
|
|||
type ALPNExtension struct { |
|||
Protocols []string |
|||
} |
|||
|
|||
type protocolNameInner struct { |
|||
Name []byte `tls:"head=1,min=1"` |
|||
} |
|||
|
|||
type alpnExtensionInner struct { |
|||
Protocols []protocolNameInner `tls:"head=2,min=2"` |
|||
} |
|||
|
|||
func (alpn ALPNExtension) Type() ExtensionType { |
|||
return ExtensionTypeALPN |
|||
} |
|||
|
|||
func (alpn ALPNExtension) Marshal() ([]byte, error) { |
|||
protocols := make([]protocolNameInner, len(alpn.Protocols)) |
|||
for i, protocol := range alpn.Protocols { |
|||
protocols[i] = protocolNameInner{[]byte(protocol)} |
|||
} |
|||
return syntax.Marshal(alpnExtensionInner{protocols}) |
|||
} |
|||
|
|||
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { |
|||
var inner alpnExtensionInner |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
|
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
alpn.Protocols = make([]string, len(inner.Protocols)) |
|||
for i, protocol := range inner.Protocols { |
|||
alpn.Protocols[i] = string(protocol.Name) |
|||
} |
|||
return read, nil |
|||
} |
|||
|
|||
// struct {
|
|||
// ProtocolVersion versions<2..254>;
|
|||
// } SupportedVersions;
|
|||
type SupportedVersionsExtension struct { |
|||
Versions []uint16 `tls:"head=1,min=2,max=254"` |
|||
} |
|||
|
|||
func (sv SupportedVersionsExtension) Type() ExtensionType { |
|||
return ExtensionTypeSupportedVersions |
|||
} |
|||
|
|||
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(sv) |
|||
} |
|||
|
|||
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, sv) |
|||
} |
|||
|
|||
// struct {
|
|||
// opaque cookie<1..2^16-1>;
|
|||
// } Cookie;
|
|||
type CookieExtension struct { |
|||
Cookie []byte `tls:"head=2,min=1"` |
|||
} |
|||
|
|||
func (c CookieExtension) Type() ExtensionType { |
|||
return ExtensionTypeCookie |
|||
} |
|||
|
|||
func (c CookieExtension) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(c) |
|||
} |
|||
|
|||
func (c *CookieExtension) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, c) |
|||
} |
|||
|
|||
// defaultCookieLength is the default length of a cookie
|
|||
const defaultCookieLength = 32 |
|||
|
|||
type defaultCookieHandler struct { |
|||
data []byte |
|||
} |
|||
|
|||
var _ CookieHandler = &defaultCookieHandler{} |
|||
|
|||
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
|||
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) { |
|||
h.data = make([]byte, defaultCookieLength) |
|||
if _, err := prng.Read(h.data); err != nil { |
|||
return nil, err |
|||
} |
|||
return h.data, nil |
|||
} |
|||
|
|||
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool { |
|||
return bytes.Equal(h.data, data) |
|||
} |
|||
@ -0,0 +1,147 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"encoding/hex" |
|||
"math/big" |
|||
) |
|||
|
|||
var ( |
|||
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + |
|||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + |
|||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + |
|||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + |
|||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + |
|||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + |
|||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + |
|||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + |
|||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + |
|||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + |
|||
"886B423861285C97FFFFFFFFFFFFFFFF" |
|||
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex) |
|||
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes) |
|||
|
|||
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + |
|||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + |
|||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + |
|||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + |
|||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + |
|||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + |
|||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + |
|||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + |
|||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + |
|||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + |
|||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + |
|||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + |
|||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + |
|||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + |
|||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + |
|||
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF" |
|||
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex) |
|||
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes) |
|||
|
|||
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + |
|||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + |
|||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + |
|||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + |
|||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + |
|||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + |
|||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + |
|||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + |
|||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + |
|||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + |
|||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + |
|||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + |
|||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + |
|||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + |
|||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + |
|||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + |
|||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + |
|||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + |
|||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" + |
|||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + |
|||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" + |
|||
"FFFFFFFFFFFFFFFF" |
|||
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex) |
|||
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes) |
|||
|
|||
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + |
|||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + |
|||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + |
|||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + |
|||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + |
|||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + |
|||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + |
|||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + |
|||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + |
|||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + |
|||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + |
|||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + |
|||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + |
|||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + |
|||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + |
|||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + |
|||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + |
|||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + |
|||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" + |
|||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + |
|||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + |
|||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + |
|||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + |
|||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + |
|||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + |
|||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + |
|||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + |
|||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + |
|||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" + |
|||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + |
|||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + |
|||
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF" |
|||
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex) |
|||
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes) |
|||
|
|||
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + |
|||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + |
|||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + |
|||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + |
|||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + |
|||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + |
|||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + |
|||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + |
|||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + |
|||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + |
|||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + |
|||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + |
|||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + |
|||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + |
|||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + |
|||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + |
|||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + |
|||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + |
|||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" + |
|||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + |
|||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + |
|||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + |
|||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + |
|||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + |
|||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + |
|||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + |
|||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + |
|||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + |
|||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" + |
|||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + |
|||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + |
|||
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" + |
|||
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" + |
|||
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" + |
|||
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" + |
|||
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" + |
|||
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" + |
|||
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" + |
|||
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" + |
|||
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" + |
|||
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" + |
|||
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" + |
|||
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF" |
|||
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex) |
|||
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes) |
|||
) |
|||
@ -0,0 +1,98 @@ |
|||
// Read a generic "framed" packet consisting of a header and a
|
|||
// This is used for both TLS Records and TLS Handshake Messages
|
|||
package mint |
|||
|
|||
type framing interface { |
|||
headerLen() int |
|||
defaultReadLen() int |
|||
frameLen(hdr []byte) (int, error) |
|||
} |
|||
|
|||
const ( |
|||
kFrameReaderHdr = 0 |
|||
kFrameReaderBody = 1 |
|||
) |
|||
|
|||
type frameNextAction func(f *frameReader) error |
|||
|
|||
type frameReader struct { |
|||
details framing |
|||
state uint8 |
|||
header []byte |
|||
body []byte |
|||
working []byte |
|||
writeOffset int |
|||
remainder []byte |
|||
} |
|||
|
|||
func newFrameReader(d framing) *frameReader { |
|||
hdr := make([]byte, d.headerLen()) |
|||
return &frameReader{ |
|||
d, |
|||
kFrameReaderHdr, |
|||
hdr, |
|||
nil, |
|||
hdr, |
|||
0, |
|||
nil, |
|||
} |
|||
} |
|||
|
|||
func dup(a []byte) []byte { |
|||
r := make([]byte, len(a)) |
|||
copy(r, a) |
|||
return r |
|||
} |
|||
|
|||
func (f *frameReader) needed() int { |
|||
tmp := (len(f.working) - f.writeOffset) - len(f.remainder) |
|||
if tmp < 0 { |
|||
return 0 |
|||
} |
|||
return tmp |
|||
} |
|||
|
|||
func (f *frameReader) addChunk(in []byte) { |
|||
// Append to the buffer.
|
|||
logf(logTypeFrameReader, "Appending %v", len(in)) |
|||
f.remainder = append(f.remainder, in...) |
|||
} |
|||
|
|||
func (f *frameReader) process() (hdr []byte, body []byte, err error) { |
|||
for f.needed() == 0 { |
|||
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) |
|||
// Fill out our working block
|
|||
copied := copy(f.working[f.writeOffset:], f.remainder) |
|||
f.remainder = f.remainder[copied:] |
|||
f.writeOffset += copied |
|||
if f.writeOffset < len(f.working) { |
|||
logf(logTypeFrameReader, "Read would have blocked 1") |
|||
return nil, nil, WouldBlock |
|||
} |
|||
// Reset the write offset, because we are now full.
|
|||
f.writeOffset = 0 |
|||
|
|||
// We have read a full frame
|
|||
if f.state == kFrameReaderBody { |
|||
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) |
|||
f.state = kFrameReaderHdr |
|||
f.working = f.header |
|||
return dup(f.header), dup(f.body), nil |
|||
} |
|||
|
|||
// We have read the header
|
|||
bodyLen, err := f.details.frameLen(f.header) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) |
|||
|
|||
f.body = make([]byte, bodyLen) |
|||
f.working = f.body |
|||
f.writeOffset = 0 |
|||
f.state = kFrameReaderBody |
|||
} |
|||
|
|||
logf(logTypeFrameReader, "Read would have blocked 2") |
|||
return nil, nil, WouldBlock |
|||
} |
|||
@ -0,0 +1,253 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"fmt" |
|||
"io" |
|||
"net" |
|||
) |
|||
|
|||
const ( |
|||
handshakeHeaderLen = 4 // handshake message header length
|
|||
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
|||
) |
|||
|
|||
// struct {
|
|||
// HandshakeType msg_type; /* handshake type */
|
|||
// uint24 length; /* bytes in message */
|
|||
// select (HandshakeType) {
|
|||
// ...
|
|||
// } body;
|
|||
// } Handshake;
|
|||
//
|
|||
// We do the select{...} part in a different layer, so we treat the
|
|||
// actual message body as opaque:
|
|||
//
|
|||
// struct {
|
|||
// HandshakeType msg_type;
|
|||
// opaque msg<0..2^24-1>
|
|||
// } Handshake;
|
|||
//
|
|||
// TODO: File a spec bug
|
|||
type HandshakeMessage struct { |
|||
// Omitted: length
|
|||
msgType HandshakeType |
|||
body []byte |
|||
} |
|||
|
|||
// Note: This could be done with the `syntax` module, using the simplified
|
|||
// syntax as discussed above. However, since this is so simple, there's not
|
|||
// much benefit to doing so.
|
|||
func (hm *HandshakeMessage) Marshal() []byte { |
|||
if hm == nil { |
|||
return []byte{} |
|||
} |
|||
|
|||
msgLen := len(hm.body) |
|||
data := make([]byte, 4+len(hm.body)) |
|||
data[0] = byte(hm.msgType) |
|||
data[1] = byte(msgLen >> 16) |
|||
data[2] = byte(msgLen >> 8) |
|||
data[3] = byte(msgLen) |
|||
copy(data[4:], hm.body) |
|||
return data |
|||
} |
|||
|
|||
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { |
|||
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body) |
|||
|
|||
var body HandshakeMessageBody |
|||
switch hm.msgType { |
|||
case HandshakeTypeClientHello: |
|||
body = new(ClientHelloBody) |
|||
case HandshakeTypeServerHello: |
|||
body = new(ServerHelloBody) |
|||
case HandshakeTypeHelloRetryRequest: |
|||
body = new(HelloRetryRequestBody) |
|||
case HandshakeTypeEncryptedExtensions: |
|||
body = new(EncryptedExtensionsBody) |
|||
case HandshakeTypeCertificate: |
|||
body = new(CertificateBody) |
|||
case HandshakeTypeCertificateRequest: |
|||
body = new(CertificateRequestBody) |
|||
case HandshakeTypeCertificateVerify: |
|||
body = new(CertificateVerifyBody) |
|||
case HandshakeTypeFinished: |
|||
body = &FinishedBody{VerifyDataLen: len(hm.body)} |
|||
case HandshakeTypeNewSessionTicket: |
|||
body = new(NewSessionTicketBody) |
|||
case HandshakeTypeKeyUpdate: |
|||
body = new(KeyUpdateBody) |
|||
case HandshakeTypeEndOfEarlyData: |
|||
body = new(EndOfEarlyDataBody) |
|||
default: |
|||
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") |
|||
} |
|||
|
|||
_, err := body.Unmarshal(hm.body) |
|||
return body, err |
|||
} |
|||
|
|||
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { |
|||
data, err := body.Marshal() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return &HandshakeMessage{ |
|||
msgType: body.Type(), |
|||
body: data, |
|||
}, nil |
|||
} |
|||
|
|||
type HandshakeLayer struct { |
|||
nonblocking bool // Should we operate in nonblocking mode
|
|||
conn *RecordLayer // Used for reading/writing records
|
|||
frame *frameReader // The buffered frame reader
|
|||
} |
|||
|
|||
type handshakeLayerFrameDetails struct{} |
|||
|
|||
func (d handshakeLayerFrameDetails) headerLen() int { |
|||
return handshakeHeaderLen |
|||
} |
|||
|
|||
func (d handshakeLayerFrameDetails) defaultReadLen() int { |
|||
return handshakeHeaderLen + maxFragmentLen |
|||
} |
|||
|
|||
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { |
|||
logf(logTypeIO, "Header=%x", hdr) |
|||
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil |
|||
} |
|||
|
|||
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer { |
|||
h := HandshakeLayer{} |
|||
h.conn = r |
|||
h.frame = newFrameReader(&handshakeLayerFrameDetails{}) |
|||
return &h |
|||
} |
|||
|
|||
func (h *HandshakeLayer) readRecord() error { |
|||
logf(logTypeIO, "Trying to read record") |
|||
pt, err := h.conn.ReadRecord() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if pt.contentType != RecordTypeHandshake && |
|||
pt.contentType != RecordTypeAlert { |
|||
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) |
|||
} |
|||
|
|||
if pt.contentType == RecordTypeAlert { |
|||
logf(logTypeIO, "read alert %v", pt.fragment[1]) |
|||
if len(pt.fragment) < 2 { |
|||
h.sendAlert(AlertUnexpectedMessage) |
|||
return io.EOF |
|||
} |
|||
return Alert(pt.fragment[1]) |
|||
} |
|||
|
|||
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment)) |
|||
h.frame.addChunk(pt.fragment) |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// sendAlert sends a TLS alert message.
|
|||
func (h *HandshakeLayer) sendAlert(err Alert) error { |
|||
tmp := make([]byte, 2) |
|||
tmp[0] = AlertLevelError |
|||
tmp[1] = byte(err) |
|||
h.conn.WriteRecord(&TLSPlaintext{ |
|||
contentType: RecordTypeAlert, |
|||
fragment: tmp}, |
|||
) |
|||
|
|||
// closeNotify is a special case in that it isn't an error:
|
|||
if err != AlertCloseNotify { |
|||
return &net.OpError{Op: "local error", Err: err} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { |
|||
var hdr, body []byte |
|||
var err error |
|||
|
|||
for { |
|||
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder)) |
|||
if h.frame.needed() > 0 { |
|||
logf(logTypeHandshake, "Trying to read a new record") |
|||
err = h.readRecord() |
|||
} |
|||
if err != nil && (h.nonblocking || err != WouldBlock) { |
|||
return nil, err |
|||
} |
|||
|
|||
hdr, body, err = h.frame.process() |
|||
if err == nil { |
|||
break |
|||
} |
|||
if err != nil && (h.nonblocking || err != WouldBlock) { |
|||
return nil, err |
|||
} |
|||
} |
|||
|
|||
logf(logTypeHandshake, "read handshake message") |
|||
|
|||
hm := &HandshakeMessage{} |
|||
hm.msgType = HandshakeType(hdr[0]) |
|||
|
|||
hm.body = make([]byte, len(body)) |
|||
copy(hm.body, body) |
|||
|
|||
return hm, nil |
|||
} |
|||
|
|||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { |
|||
return h.WriteMessages([]*HandshakeMessage{hm}) |
|||
} |
|||
|
|||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { |
|||
for _, hm := range hms { |
|||
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) |
|||
} |
|||
|
|||
// Write out headers and bodies
|
|||
buffer := []byte{} |
|||
for _, msg := range hms { |
|||
msgLen := len(msg.body) |
|||
if msgLen > maxHandshakeMessageLen { |
|||
return fmt.Errorf("tls.handshakelayer: Message too large to send") |
|||
} |
|||
|
|||
buffer = append(buffer, msg.Marshal()...) |
|||
} |
|||
|
|||
// Send full-size fragments
|
|||
var start int |
|||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { |
|||
err := h.conn.WriteRecord(&TLSPlaintext{ |
|||
contentType: RecordTypeHandshake, |
|||
fragment: buffer[start : start+maxFragmentLen], |
|||
}) |
|||
|
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
// Send a final partial fragment if necessary
|
|||
if start < len(buffer) { |
|||
err := h.conn.WriteRecord(&TLSPlaintext{ |
|||
contentType: RecordTypeHandshake, |
|||
fragment: buffer[start:], |
|||
}) |
|||
|
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
@ -0,0 +1,450 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto" |
|||
"crypto/x509" |
|||
"encoding/binary" |
|||
"fmt" |
|||
|
|||
"github.com/bifurcation/mint/syntax" |
|||
) |
|||
|
|||
type HandshakeMessageBody interface { |
|||
Type() HandshakeType |
|||
Marshal() ([]byte, error) |
|||
Unmarshal(data []byte) (int, error) |
|||
} |
|||
|
|||
// struct {
|
|||
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
|||
// Random random;
|
|||
// opaque legacy_session_id<0..32>;
|
|||
// CipherSuite cipher_suites<2..2^16-2>;
|
|||
// opaque legacy_compression_methods<1..2^8-1>;
|
|||
// Extension extensions<0..2^16-1>;
|
|||
// } ClientHello;
|
|||
type ClientHelloBody struct { |
|||
// Omitted: clientVersion
|
|||
// Omitted: legacySessionID
|
|||
// Omitted: legacyCompressionMethods
|
|||
Random [32]byte |
|||
CipherSuites []CipherSuite |
|||
Extensions ExtensionList |
|||
} |
|||
|
|||
type clientHelloBodyInner struct { |
|||
LegacyVersion uint16 |
|||
Random [32]byte |
|||
LegacySessionID []byte `tls:"head=1,max=32"` |
|||
CipherSuites []CipherSuite `tls:"head=2,min=2"` |
|||
LegacyCompressionMethods []byte `tls:"head=1,min=1"` |
|||
Extensions []Extension `tls:"head=2"` |
|||
} |
|||
|
|||
func (ch ClientHelloBody) Type() HandshakeType { |
|||
return HandshakeTypeClientHello |
|||
} |
|||
|
|||
func (ch ClientHelloBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(clientHelloBodyInner{ |
|||
LegacyVersion: 0x0303, |
|||
Random: ch.Random, |
|||
LegacySessionID: []byte{}, |
|||
CipherSuites: ch.CipherSuites, |
|||
LegacyCompressionMethods: []byte{0}, |
|||
Extensions: ch.Extensions, |
|||
}) |
|||
} |
|||
|
|||
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { |
|||
var inner clientHelloBodyInner |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
// We are strict about these things because we only support 1.3
|
|||
if inner.LegacyVersion != 0x0303 { |
|||
return 0, fmt.Errorf("tls.clienthello: Incorrect version number") |
|||
} |
|||
|
|||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { |
|||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method") |
|||
} |
|||
|
|||
ch.Random = inner.Random |
|||
ch.CipherSuites = inner.CipherSuites |
|||
ch.Extensions = inner.Extensions |
|||
return read, nil |
|||
} |
|||
|
|||
// TODO: File a spec bug to clarify this
|
|||
func (ch ClientHelloBody) Truncated() ([]byte, error) { |
|||
if len(ch.Extensions) == 0 { |
|||
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions") |
|||
} |
|||
|
|||
pskExt := ch.Extensions[len(ch.Extensions)-1] |
|||
if pskExt.ExtensionType != ExtensionTypePreSharedKey { |
|||
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") |
|||
} |
|||
|
|||
chm, err := HandshakeMessageFromBody(&ch) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
chData := chm.Marshal() |
|||
|
|||
psk := PreSharedKeyExtension{ |
|||
HandshakeType: HandshakeTypeClientHello, |
|||
} |
|||
_, err = psk.Unmarshal(pskExt.ExtensionData) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// Marshal just the binders so that we know how much to truncate
|
|||
binders := struct { |
|||
Binders []PSKBinderEntry `tls:"head=2,min=33"` |
|||
}{Binders: psk.Binders} |
|||
binderData, _ := syntax.Marshal(binders) |
|||
binderLen := len(binderData) |
|||
|
|||
chLen := len(chData) |
|||
return chData[:chLen-binderLen], nil |
|||
} |
|||
|
|||
// struct {
|
|||
// ProtocolVersion server_version;
|
|||
// CipherSuite cipher_suite;
|
|||
// Extension extensions<2..2^16-1>;
|
|||
// } HelloRetryRequest;
|
|||
type HelloRetryRequestBody struct { |
|||
Version uint16 |
|||
CipherSuite CipherSuite |
|||
Extensions ExtensionList `tls:"head=2,min=2"` |
|||
} |
|||
|
|||
func (hrr HelloRetryRequestBody) Type() HandshakeType { |
|||
return HandshakeTypeHelloRetryRequest |
|||
} |
|||
|
|||
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(hrr) |
|||
} |
|||
|
|||
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, hrr) |
|||
} |
|||
|
|||
// struct {
|
|||
// ProtocolVersion version;
|
|||
// Random random;
|
|||
// CipherSuite cipher_suite;
|
|||
// Extension extensions<0..2^16-1>;
|
|||
// } ServerHello;
|
|||
type ServerHelloBody struct { |
|||
Version uint16 |
|||
Random [32]byte |
|||
CipherSuite CipherSuite |
|||
Extensions ExtensionList `tls:"head=2"` |
|||
} |
|||
|
|||
func (sh ServerHelloBody) Type() HandshakeType { |
|||
return HandshakeTypeServerHello |
|||
} |
|||
|
|||
func (sh ServerHelloBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(sh) |
|||
} |
|||
|
|||
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, sh) |
|||
} |
|||
|
|||
// struct {
|
|||
// opaque verify_data[verify_data_length];
|
|||
// } Finished;
|
|||
//
|
|||
// verifyDataLen is not a field in the TLS struct, but we add it here so
|
|||
// that calling code can tell us how much data to expect when we marshal /
|
|||
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
|
|||
// try to keep the signature consistent for now.)
|
|||
//
|
|||
// For similar reasons, we don't use the `syntax` module here, because this
|
|||
// struct doesn't map well to standard TLS presentation language concepts.
|
|||
//
|
|||
// TODO: File a spec bug
|
|||
type FinishedBody struct { |
|||
VerifyDataLen int |
|||
VerifyData []byte |
|||
} |
|||
|
|||
func (fin FinishedBody) Type() HandshakeType { |
|||
return HandshakeTypeFinished |
|||
} |
|||
|
|||
func (fin FinishedBody) Marshal() ([]byte, error) { |
|||
if len(fin.VerifyData) != fin.VerifyDataLen { |
|||
return nil, fmt.Errorf("tls.finished: data length mismatch") |
|||
} |
|||
|
|||
body := make([]byte, len(fin.VerifyData)) |
|||
copy(body, fin.VerifyData) |
|||
return body, nil |
|||
} |
|||
|
|||
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) { |
|||
if len(data) < fin.VerifyDataLen { |
|||
return 0, fmt.Errorf("tls.finished: Malformed finished; too short") |
|||
} |
|||
|
|||
fin.VerifyData = make([]byte, fin.VerifyDataLen) |
|||
copy(fin.VerifyData, data[:fin.VerifyDataLen]) |
|||
return fin.VerifyDataLen, nil |
|||
} |
|||
|
|||
// struct {
|
|||
// Extension extensions<0..2^16-1>;
|
|||
// } EncryptedExtensions;
|
|||
//
|
|||
// Marshal() and Unmarshal() are handled by ExtensionList
|
|||
type EncryptedExtensionsBody struct { |
|||
Extensions ExtensionList `tls:"head=2"` |
|||
} |
|||
|
|||
func (ee EncryptedExtensionsBody) Type() HandshakeType { |
|||
return HandshakeTypeEncryptedExtensions |
|||
} |
|||
|
|||
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(ee) |
|||
} |
|||
|
|||
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, ee) |
|||
} |
|||
|
|||
// opaque ASN1Cert<1..2^24-1>;
|
|||
//
|
|||
// struct {
|
|||
// ASN1Cert cert_data;
|
|||
// Extension extensions<0..2^16-1>
|
|||
// } CertificateEntry;
|
|||
//
|
|||
// struct {
|
|||
// opaque certificate_request_context<0..2^8-1>;
|
|||
// CertificateEntry certificate_list<0..2^24-1>;
|
|||
// } Certificate;
|
|||
type CertificateEntry struct { |
|||
CertData *x509.Certificate |
|||
Extensions ExtensionList |
|||
} |
|||
|
|||
type CertificateBody struct { |
|||
CertificateRequestContext []byte |
|||
CertificateList []CertificateEntry |
|||
} |
|||
|
|||
type certificateEntryInner struct { |
|||
CertData []byte `tls:"head=3,min=1"` |
|||
Extensions ExtensionList `tls:"head=2"` |
|||
} |
|||
|
|||
type certificateBodyInner struct { |
|||
CertificateRequestContext []byte `tls:"head=1"` |
|||
CertificateList []certificateEntryInner `tls:"head=3"` |
|||
} |
|||
|
|||
func (c CertificateBody) Type() HandshakeType { |
|||
return HandshakeTypeCertificate |
|||
} |
|||
|
|||
func (c CertificateBody) Marshal() ([]byte, error) { |
|||
inner := certificateBodyInner{ |
|||
CertificateRequestContext: c.CertificateRequestContext, |
|||
CertificateList: make([]certificateEntryInner, len(c.CertificateList)), |
|||
} |
|||
|
|||
for i, entry := range c.CertificateList { |
|||
inner.CertificateList[i] = certificateEntryInner{ |
|||
CertData: entry.CertData.Raw, |
|||
Extensions: entry.Extensions, |
|||
} |
|||
} |
|||
|
|||
return syntax.Marshal(inner) |
|||
} |
|||
|
|||
func (c *CertificateBody) Unmarshal(data []byte) (int, error) { |
|||
inner := certificateBodyInner{} |
|||
read, err := syntax.Unmarshal(data, &inner) |
|||
if err != nil { |
|||
return read, err |
|||
} |
|||
|
|||
c.CertificateRequestContext = inner.CertificateRequestContext |
|||
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) |
|||
|
|||
for i, entry := range inner.CertificateList { |
|||
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) |
|||
if err != nil { |
|||
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) |
|||
} |
|||
|
|||
c.CertificateList[i].Extensions = entry.Extensions |
|||
} |
|||
|
|||
return read, nil |
|||
} |
|||
|
|||
// struct {
|
|||
// SignatureScheme algorithm;
|
|||
// opaque signature<0..2^16-1>;
|
|||
// } CertificateVerify;
|
|||
type CertificateVerifyBody struct { |
|||
Algorithm SignatureScheme |
|||
Signature []byte `tls:"head=2"` |
|||
} |
|||
|
|||
func (cv CertificateVerifyBody) Type() HandshakeType { |
|||
return HandshakeTypeCertificateVerify |
|||
} |
|||
|
|||
func (cv CertificateVerifyBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(cv) |
|||
} |
|||
|
|||
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, cv) |
|||
} |
|||
|
|||
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte { |
|||
// TODO: Change context for client auth
|
|||
// TODO: Put this in a const
|
|||
const context = "TLS 1.3, server CertificateVerify" |
|||
sigInput := bytes.Repeat([]byte{0x20}, 64) |
|||
sigInput = append(sigInput, []byte(context)...) |
|||
sigInput = append(sigInput, []byte{0}...) |
|||
sigInput = append(sigInput, data...) |
|||
return sigInput |
|||
} |
|||
|
|||
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) { |
|||
sigInput := cv.EncodeSignatureInput(handshakeHash) |
|||
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput) |
|||
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) |
|||
return |
|||
} |
|||
|
|||
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error { |
|||
sigInput := cv.EncodeSignatureInput(handshakeHash) |
|||
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) |
|||
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature) |
|||
} |
|||
|
|||
// struct {
|
|||
// opaque certificate_request_context<0..2^8-1>;
|
|||
// Extension extensions<2..2^16-1>;
|
|||
// } CertificateRequest;
|
|||
type CertificateRequestBody struct { |
|||
CertificateRequestContext []byte `tls:"head=1"` |
|||
Extensions ExtensionList `tls:"head=2"` |
|||
} |
|||
|
|||
func (cr CertificateRequestBody) Type() HandshakeType { |
|||
return HandshakeTypeCertificateRequest |
|||
} |
|||
|
|||
func (cr CertificateRequestBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(cr) |
|||
} |
|||
|
|||
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, cr) |
|||
} |
|||
|
|||
// struct {
|
|||
// uint32 ticket_lifetime;
|
|||
// uint32 ticket_age_add;
|
|||
// opaque ticket_nonce<1..255>;
|
|||
// opaque ticket<1..2^16-1>;
|
|||
// Extension extensions<0..2^16-2>;
|
|||
// } NewSessionTicket;
|
|||
type NewSessionTicketBody struct { |
|||
TicketLifetime uint32 |
|||
TicketAgeAdd uint32 |
|||
TicketNonce []byte `tls:"head=1,min=1"` |
|||
Ticket []byte `tls:"head=2,min=1"` |
|||
Extensions ExtensionList `tls:"head=2"` |
|||
} |
|||
|
|||
const ticketNonceLen = 16 |
|||
|
|||
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) { |
|||
buf := make([]byte, 4+ticketNonceLen+ticketLen) |
|||
_, err := prng.Read(buf) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
tkt := &NewSessionTicketBody{ |
|||
TicketLifetime: ticketLifetime, |
|||
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]), |
|||
TicketNonce: buf[4 : 4+ticketNonceLen], |
|||
Ticket: buf[4+ticketNonceLen:], |
|||
} |
|||
|
|||
return tkt, err |
|||
} |
|||
|
|||
func (tkt NewSessionTicketBody) Type() HandshakeType { |
|||
return HandshakeTypeNewSessionTicket |
|||
} |
|||
|
|||
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(tkt) |
|||
} |
|||
|
|||
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, tkt) |
|||
} |
|||
|
|||
// enum {
|
|||
// update_not_requested(0), update_requested(1), (255)
|
|||
// } KeyUpdateRequest;
|
|||
//
|
|||
// struct {
|
|||
// KeyUpdateRequest request_update;
|
|||
// } KeyUpdate;
|
|||
type KeyUpdateBody struct { |
|||
KeyUpdateRequest KeyUpdateRequest |
|||
} |
|||
|
|||
func (ku KeyUpdateBody) Type() HandshakeType { |
|||
return HandshakeTypeKeyUpdate |
|||
} |
|||
|
|||
func (ku KeyUpdateBody) Marshal() ([]byte, error) { |
|||
return syntax.Marshal(ku) |
|||
} |
|||
|
|||
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) { |
|||
return syntax.Unmarshal(data, ku) |
|||
} |
|||
|
|||
// struct {} EndOfEarlyData;
|
|||
type EndOfEarlyDataBody struct{} |
|||
|
|||
func (eoed EndOfEarlyDataBody) Type() HandshakeType { |
|||
return HandshakeTypeEndOfEarlyData |
|||
} |
|||
|
|||
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) { |
|||
return []byte{}, nil |
|||
} |
|||
|
|||
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) { |
|||
return 0, nil |
|||
} |
|||
@ -0,0 +1,55 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"fmt" |
|||
"log" |
|||
"os" |
|||
"strings" |
|||
) |
|||
|
|||
// We use this environment variable to control logging. It should be a
|
|||
// comma-separated list of log tags (see below) or "*" to enable all logging.
|
|||
const logConfigVar = "MINT_LOG" |
|||
|
|||
// Pre-defined log types
|
|||
const ( |
|||
logTypeCrypto = "crypto" |
|||
logTypeHandshake = "handshake" |
|||
logTypeNegotiation = "negotiation" |
|||
logTypeIO = "io" |
|||
logTypeFrameReader = "frame" |
|||
logTypeVerbose = "verbose" |
|||
) |
|||
|
|||
var ( |
|||
logFunction = log.Printf |
|||
logAll = false |
|||
logSettings = map[string]bool{} |
|||
) |
|||
|
|||
func init() { |
|||
parseLogEnv(os.Environ()) |
|||
} |
|||
|
|||
func parseLogEnv(env []string) { |
|||
for _, stmt := range env { |
|||
if strings.HasPrefix(stmt, logConfigVar+"=") { |
|||
val := stmt[len(logConfigVar)+1:] |
|||
|
|||
if val == "*" { |
|||
logAll = true |
|||
} else { |
|||
for _, t := range strings.Split(val, ",") { |
|||
logSettings[t] = true |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func logf(tag string, format string, args ...interface{}) { |
|||
if logAll || logSettings[tag] { |
|||
fullFormat := fmt.Sprintf("[%s] %s", tag, format) |
|||
logFunction(fullFormat, args...) |
|||
} |
|||
} |
|||
|
After Width: | Height: | Size: 16 KiB |
@ -0,0 +1,217 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/hex" |
|||
"fmt" |
|||
"time" |
|||
) |
|||
|
|||
func VersionNegotiation(offered, supported []uint16) (bool, uint16) { |
|||
for _, offeredVersion := range offered { |
|||
for _, supportedVersion := range supported { |
|||
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion) |
|||
if offeredVersion == supportedVersion { |
|||
// XXX: Should probably be highest supported version, but for now, we
|
|||
// only support one version, so it doesn't really matter.
|
|||
return true, offeredVersion |
|||
} |
|||
} |
|||
} |
|||
|
|||
return false, 0 |
|||
} |
|||
|
|||
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) { |
|||
for _, share := range keyShares { |
|||
for _, group := range groups { |
|||
if group != share.Group { |
|||
continue |
|||
} |
|||
|
|||
pub, priv, err := newKeyShare(share.Group) |
|||
if err != nil { |
|||
// If we encounter an error, just keep looking
|
|||
continue |
|||
} |
|||
|
|||
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv) |
|||
if err != nil { |
|||
// If we encounter an error, just keep looking
|
|||
continue |
|||
} |
|||
|
|||
return true, group, pub, dhSecret |
|||
} |
|||
} |
|||
|
|||
return false, 0, nil, nil |
|||
} |
|||
|
|||
const ( |
|||
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
|
|||
) |
|||
|
|||
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) { |
|||
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size()) |
|||
for i, id := range identities { |
|||
identityHex := hex.EncodeToString(id.Identity) |
|||
|
|||
psk, ok := psks.Get(identityHex) |
|||
if !ok { |
|||
logf(logTypeNegotiation, "No PSK for identity %x", identityHex) |
|||
continue |
|||
} |
|||
|
|||
// For resumption, make sure the ticket age is correct
|
|||
if psk.IsResumption { |
|||
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd |
|||
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond) |
|||
ticketAgeDelta := knownTicketAge - extTicketAge |
|||
if knownTicketAge < extTicketAge { |
|||
ticketAgeDelta = extTicketAge - knownTicketAge |
|||
} |
|||
if ticketAgeDelta > ticketAgeTolerance { |
|||
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity) |
|||
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]", |
|||
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance) |
|||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity) |
|||
} |
|||
} |
|||
|
|||
params, ok := cipherSuiteMap[psk.CipherSuite] |
|||
if !ok { |
|||
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite) |
|||
return false, 0, nil, CipherSuiteParams{}, err |
|||
} |
|||
|
|||
// Compute binder
|
|||
binderLabel := labelExternalBinder |
|||
if psk.IsResumption { |
|||
binderLabel = labelResumptionBinder |
|||
} |
|||
|
|||
h0 := params.Hash.New().Sum(nil) |
|||
zero := bytes.Repeat([]byte{0}, params.Hash.Size()) |
|||
earlySecret := HkdfExtract(params.Hash, zero, psk.Key) |
|||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0) |
|||
|
|||
// context = ClientHello[truncated]
|
|||
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
|
|||
ctxHash := params.Hash.New() |
|||
ctxHash.Write(context) |
|||
|
|||
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil)) |
|||
if !bytes.Equal(binder, binders[i].Binder) { |
|||
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder) |
|||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity) |
|||
} |
|||
|
|||
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity) |
|||
return true, i, &psk, params, nil |
|||
} |
|||
|
|||
logf(logTypeNegotiation, "Failed to find a usable PSK") |
|||
return false, 0, nil, CipherSuiteParams{}, nil |
|||
} |
|||
|
|||
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) { |
|||
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes) |
|||
dhAllowed := false |
|||
dhRequired := true |
|||
for _, mode := range modes { |
|||
dhAllowed = dhAllowed || (mode == PSKModeDHEKE) |
|||
dhRequired = dhRequired && (mode == PSKModeDHEKE) |
|||
} |
|||
|
|||
// Use PSK if we can meet DH requirement and modes were provided
|
|||
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0) |
|||
|
|||
// Use DH if allowed
|
|||
usingDH := canDoDH && (dhAllowed || !usingPSK) |
|||
|
|||
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK) |
|||
return usingDH, usingPSK |
|||
} |
|||
|
|||
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) { |
|||
// Select for server name if provided
|
|||
candidates := certs |
|||
if serverName != nil { |
|||
candidatesByName := []*Certificate{} |
|||
for _, cert := range certs { |
|||
for _, name := range cert.Chain[0].DNSNames { |
|||
if len(*serverName) > 0 && name == *serverName { |
|||
candidatesByName = append(candidatesByName, cert) |
|||
} |
|||
} |
|||
} |
|||
|
|||
if len(candidatesByName) == 0 { |
|||
return nil, 0, fmt.Errorf("No certificates available for server name") |
|||
} |
|||
|
|||
candidates = candidatesByName |
|||
} |
|||
|
|||
// Select for signature scheme
|
|||
for _, cert := range candidates { |
|||
for _, scheme := range signatureSchemes { |
|||
if !schemeValidForKey(scheme, cert.PrivateKey) { |
|||
continue |
|||
} |
|||
|
|||
return cert, scheme, nil |
|||
} |
|||
} |
|||
|
|||
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") |
|||
} |
|||
|
|||
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { |
|||
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData |
|||
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) |
|||
return usingEarlyData |
|||
} |
|||
|
|||
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { |
|||
for _, s1 := range offered { |
|||
if psk != nil { |
|||
if s1 == psk.CipherSuite { |
|||
return s1, nil |
|||
} |
|||
continue |
|||
} |
|||
|
|||
for _, s2 := range supported { |
|||
if s1 == s2 { |
|||
return s1, nil |
|||
} |
|||
} |
|||
} |
|||
|
|||
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil) |
|||
} |
|||
|
|||
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) { |
|||
for _, p1 := range offered { |
|||
if psk != nil { |
|||
if p1 != psk.NextProto { |
|||
continue |
|||
} |
|||
} |
|||
|
|||
for _, p2 := range supported { |
|||
if p1 == p2 { |
|||
return p1, nil |
|||
} |
|||
} |
|||
} |
|||
|
|||
// If the client offers ALPN on resumption, it must match the earlier one
|
|||
var err error |
|||
if psk != nil && psk.IsResumption && (len(offered) > 0) { |
|||
err = fmt.Errorf("ALPN for PSK not provided") |
|||
} |
|||
return "", err |
|||
} |
|||
@ -0,0 +1,296 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto/cipher" |
|||
"fmt" |
|||
"io" |
|||
"sync" |
|||
) |
|||
|
|||
const ( |
|||
sequenceNumberLen = 8 // sequence number length
|
|||
recordHeaderLen = 5 // record header length
|
|||
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
|||
) |
|||
|
|||
type DecryptError string |
|||
|
|||
func (err DecryptError) Error() string { |
|||
return string(err) |
|||
} |
|||
|
|||
// struct {
|
|||
// ContentType type;
|
|||
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
|||
// uint16 length;
|
|||
// opaque fragment[TLSPlaintext.length];
|
|||
// } TLSPlaintext;
|
|||
type TLSPlaintext struct { |
|||
// Omitted: record_version (static)
|
|||
// Omitted: length (computed from fragment)
|
|||
contentType RecordType |
|||
fragment []byte |
|||
} |
|||
|
|||
type RecordLayer struct { |
|||
sync.Mutex |
|||
|
|||
conn io.ReadWriter // The underlying connection
|
|||
frame *frameReader // The buffered frame reader
|
|||
nextData []byte // The next record to send
|
|||
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
|||
cachedError error // Error on the last record read
|
|||
|
|||
ivLength int // Length of the seq and nonce fields
|
|||
seq []byte // Zero-padded sequence number
|
|||
nonce []byte // Buffer for per-record nonces
|
|||
cipher cipher.AEAD // AEAD cipher
|
|||
} |
|||
|
|||
type recordLayerFrameDetails struct{} |
|||
|
|||
func (d recordLayerFrameDetails) headerLen() int { |
|||
return recordHeaderLen |
|||
} |
|||
|
|||
func (d recordLayerFrameDetails) defaultReadLen() int { |
|||
return recordHeaderLen + maxFragmentLen |
|||
} |
|||
|
|||
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { |
|||
return (int(hdr[3]) << 8) | int(hdr[4]), nil |
|||
} |
|||
|
|||
func NewRecordLayer(conn io.ReadWriter) *RecordLayer { |
|||
r := RecordLayer{} |
|||
r.conn = conn |
|||
r.frame = newFrameReader(recordLayerFrameDetails{}) |
|||
r.ivLength = 0 |
|||
return &r |
|||
} |
|||
|
|||
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error { |
|||
var err error |
|||
r.cipher, err = cipher(key) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
r.ivLength = len(iv) |
|||
r.seq = bytes.Repeat([]byte{0}, r.ivLength) |
|||
r.nonce = make([]byte, r.ivLength) |
|||
copy(r.nonce, iv) |
|||
return nil |
|||
} |
|||
|
|||
func (r *RecordLayer) incrementSequenceNumber() { |
|||
if r.ivLength == 0 { |
|||
return |
|||
} |
|||
|
|||
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- { |
|||
r.seq[i]++ |
|||
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i] |
|||
if r.seq[i] != 0 { |
|||
return |
|||
} |
|||
} |
|||
|
|||
// Not allowed to let sequence number wrap.
|
|||
// Instead, must renegotiate before it does.
|
|||
// Not likely enough to bother.
|
|||
panic("TLS: sequence number wraparound") |
|||
} |
|||
|
|||
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext { |
|||
// Expand the fragment to hold contentType, padding, and overhead
|
|||
originalLen := len(pt.fragment) |
|||
plaintextLen := originalLen + 1 + padLen |
|||
ciphertextLen := plaintextLen + r.cipher.Overhead() |
|||
|
|||
// Assemble the revised plaintext
|
|||
out := &TLSPlaintext{ |
|||
contentType: RecordTypeApplicationData, |
|||
fragment: make([]byte, ciphertextLen), |
|||
} |
|||
copy(out.fragment, pt.fragment) |
|||
out.fragment[originalLen] = byte(pt.contentType) |
|||
for i := 1; i <= padLen; i++ { |
|||
out.fragment[originalLen+i] = 0 |
|||
} |
|||
|
|||
// Encrypt the fragment
|
|||
payload := out.fragment[:plaintextLen] |
|||
r.cipher.Seal(payload[:0], r.nonce, payload, nil) |
|||
return out |
|||
} |
|||
|
|||
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) { |
|||
if len(pt.fragment) < r.cipher.Overhead() { |
|||
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead()) |
|||
return nil, 0, DecryptError(msg) |
|||
} |
|||
|
|||
decryptLen := len(pt.fragment) - r.cipher.Overhead() |
|||
out := &TLSPlaintext{ |
|||
contentType: pt.contentType, |
|||
fragment: make([]byte, decryptLen), |
|||
} |
|||
|
|||
// Decrypt
|
|||
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil) |
|||
if err != nil { |
|||
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") |
|||
} |
|||
|
|||
// Find the padding boundary
|
|||
padLen := 0 |
|||
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ { |
|||
} |
|||
|
|||
// Transfer the content type
|
|||
newLen := decryptLen - padLen - 1 |
|||
out.contentType = RecordType(out.fragment[newLen]) |
|||
|
|||
// Truncate the message to remove contentType, padding, overhead
|
|||
out.fragment = out.fragment[:newLen] |
|||
return out, padLen, nil |
|||
} |
|||
|
|||
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { |
|||
var pt *TLSPlaintext |
|||
var err error |
|||
|
|||
for { |
|||
pt, err = r.nextRecord() |
|||
if err == nil { |
|||
break |
|||
} |
|||
if !block || err != WouldBlock { |
|||
return 0, err |
|||
} |
|||
} |
|||
return pt.contentType, nil |
|||
} |
|||
|
|||
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { |
|||
pt, err := r.nextRecord() |
|||
|
|||
// Consume the cached record if there was one
|
|||
r.cachedRecord = nil |
|||
r.cachedError = nil |
|||
|
|||
return pt, err |
|||
} |
|||
|
|||
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { |
|||
if r.cachedRecord != nil { |
|||
logf(logTypeIO, "Returning cached record") |
|||
return r.cachedRecord, r.cachedError |
|||
} |
|||
|
|||
// Loop until one of three things happens:
|
|||
//
|
|||
// 1. We get a frame
|
|||
// 2. We try to read off the socket and get nothing, in which case
|
|||
// return WouldBlock
|
|||
// 3. We get an error.
|
|||
err := WouldBlock |
|||
var header, body []byte |
|||
|
|||
for err != nil { |
|||
if r.frame.needed() > 0 { |
|||
buf := make([]byte, recordHeaderLen+maxFragmentLen) |
|||
n, err := r.conn.Read(buf) |
|||
if err != nil { |
|||
logf(logTypeIO, "Error reading, %v", err) |
|||
return nil, err |
|||
} |
|||
|
|||
if n == 0 { |
|||
return nil, WouldBlock |
|||
} |
|||
|
|||
logf(logTypeIO, "Read %v bytes", n) |
|||
|
|||
buf = buf[:n] |
|||
r.frame.addChunk(buf) |
|||
} |
|||
|
|||
header, body, err = r.frame.process() |
|||
// Loop around on WouldBlock to see if some
|
|||
// data is now available.
|
|||
if err != nil && err != WouldBlock { |
|||
return nil, err |
|||
} |
|||
} |
|||
|
|||
pt := &TLSPlaintext{} |
|||
// Validate content type
|
|||
switch RecordType(header[0]) { |
|||
default: |
|||
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) |
|||
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: |
|||
pt.contentType = RecordType(header[0]) |
|||
} |
|||
|
|||
// Validate version
|
|||
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) { |
|||
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2]) |
|||
} |
|||
|
|||
// Validate size < max
|
|||
size := (int(header[3]) << 8) + int(header[4]) |
|||
if size > maxFragmentLen+256 { |
|||
return nil, fmt.Errorf("tls.record: Ciphertext size too big") |
|||
} |
|||
|
|||
pt.fragment = make([]byte, size) |
|||
copy(pt.fragment, body) |
|||
|
|||
// Attempt to decrypt fragment
|
|||
if r.cipher != nil { |
|||
pt, _, err = r.decrypt(pt) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
|
|||
// Check that plaintext length is not too long
|
|||
if len(pt.fragment) > maxFragmentLen { |
|||
return nil, fmt.Errorf("tls.record: Plaintext size too big") |
|||
} |
|||
|
|||
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) |
|||
|
|||
r.cachedRecord = pt |
|||
r.incrementSequenceNumber() |
|||
return pt, nil |
|||
} |
|||
|
|||
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { |
|||
return r.WriteRecordWithPadding(pt, 0) |
|||
} |
|||
|
|||
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { |
|||
if r.cipher != nil { |
|||
pt = r.encrypt(pt, padLen) |
|||
} else if padLen > 0 { |
|||
return fmt.Errorf("tls.record: Padding can only be done on encrypted records") |
|||
} |
|||
|
|||
if len(pt.fragment) > maxFragmentLen { |
|||
return fmt.Errorf("tls.record: Record size too big") |
|||
} |
|||
|
|||
length := len(pt.fragment) |
|||
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} |
|||
record := append(header, pt.fragment...) |
|||
|
|||
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment) |
|||
|
|||
r.incrementSequenceNumber() |
|||
_, err := r.conn.Write(record) |
|||
return err |
|||
} |
|||
@ -0,0 +1,898 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"bytes" |
|||
"hash" |
|||
"reflect" |
|||
) |
|||
|
|||
// Server State Machine
|
|||
//
|
|||
// START <-----+
|
|||
// Recv ClientHello | | Send HelloRetryRequest
|
|||
// v |
|
|||
// RECVD_CH ----+
|
|||
// | Select parameters
|
|||
// | Send ServerHello
|
|||
// v
|
|||
// NEGOTIATED
|
|||
// | Send EncryptedExtensions
|
|||
// | [Send CertificateRequest]
|
|||
// Can send | [Send Certificate + CertificateVerify]
|
|||
// app data --> | Send Finished
|
|||
// after +--------+--------+
|
|||
// here No 0-RTT | | 0-RTT
|
|||
// | v
|
|||
// | WAIT_EOED <---+
|
|||
// | Recv | | | Recv
|
|||
// | EndOfEarlyData | | | early data
|
|||
// | | +-----+
|
|||
// +> WAIT_FLIGHT2 <-+
|
|||
// |
|
|||
// +--------+--------+
|
|||
// No auth | | Client auth
|
|||
// | |
|
|||
// | v
|
|||
// | WAIT_CERT
|
|||
// | Recv | | Recv Certificate
|
|||
// | empty | v
|
|||
// | Certificate | WAIT_CV
|
|||
// | | | Recv
|
|||
// | v | CertificateVerify
|
|||
// +-> WAIT_FINISHED <---+
|
|||
// | Recv Finished
|
|||
// v
|
|||
// CONNECTED
|
|||
//
|
|||
// NB: Not using state RECVD_CH
|
|||
//
|
|||
// State Instructions
|
|||
// START {}
|
|||
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
|
|||
// WAIT_EOED RekeyIn;
|
|||
// WAIT_FLIGHT2 {}
|
|||
// WAIT_CERT_CR {}
|
|||
// WAIT_CERT {}
|
|||
// WAIT_CV {}
|
|||
// WAIT_FINISHED RekeyIn; RekeyOut;
|
|||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
|||
|
|||
type ServerStateStart struct { |
|||
Caps Capabilities |
|||
conn *Conn |
|||
|
|||
cookieSent bool |
|||
firstClientHello *HandshakeMessage |
|||
helloRetryRequest *HandshakeMessage |
|||
} |
|||
|
|||
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeClientHello { |
|||
logf(logTypeHandshake, "[ServerStateStart] unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
ch := &ClientHelloBody{} |
|||
_, err := ch.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
clientHello := hm |
|||
connParams := ConnectionParameters{} |
|||
|
|||
supportedVersions := new(SupportedVersionsExtension) |
|||
serverName := new(ServerNameExtension) |
|||
supportedGroups := new(SupportedGroupsExtension) |
|||
signatureAlgorithms := new(SignatureAlgorithmsExtension) |
|||
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello} |
|||
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello} |
|||
clientEarlyData := &EarlyDataExtension{} |
|||
clientALPN := new(ALPNExtension) |
|||
clientPSKModes := new(PSKKeyExchangeModesExtension) |
|||
clientCookie := new(CookieExtension) |
|||
|
|||
// Handle external extensions.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
gotSupportedVersions := ch.Extensions.Find(supportedVersions) |
|||
gotServerName := ch.Extensions.Find(serverName) |
|||
gotSupportedGroups := ch.Extensions.Find(supportedGroups) |
|||
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms) |
|||
gotEarlyData := ch.Extensions.Find(clientEarlyData) |
|||
ch.Extensions.Find(clientKeyShares) |
|||
ch.Extensions.Find(clientPSK) |
|||
ch.Extensions.Find(clientALPN) |
|||
ch.Extensions.Find(clientPSKModes) |
|||
ch.Extensions.Find(clientCookie) |
|||
|
|||
if gotServerName { |
|||
connParams.ServerName = string(*serverName) |
|||
} |
|||
|
|||
// If the client didn't send supportedVersions or doesn't support 1.3,
|
|||
// then we're done here.
|
|||
if !gotSupportedVersions { |
|||
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") |
|||
return nil, nil, AlertProtocolVersion |
|||
} |
|||
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion}) |
|||
if !versionOK { |
|||
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version") |
|||
return nil, nil, AlertProtocolVersion |
|||
} |
|||
|
|||
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) { |
|||
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") |
|||
return nil, nil, AlertAccessDenied |
|||
} |
|||
|
|||
// Figure out if we can do DH
|
|||
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups) |
|||
|
|||
// Figure out if we can do PSK
|
|||
canDoPSK := false |
|||
var selectedPSK int |
|||
var psk *PreSharedKey |
|||
var params CipherSuiteParams |
|||
if len(clientPSK.Identities) > 0 { |
|||
contextBase := []byte{} |
|||
if state.helloRetryRequest != nil { |
|||
chBytes := state.firstClientHello.Marshal() |
|||
hrrBytes := state.helloRetryRequest.Marshal() |
|||
contextBase = append(chBytes, hrrBytes...) |
|||
} |
|||
|
|||
chTrunc, err := ch.Truncated() |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
context := append(contextBase, chTrunc...) |
|||
|
|||
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
// Figure out if we actually should do DH / PSK
|
|||
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) |
|||
|
|||
// Select a ciphersuite
|
|||
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
// Send a cookie if required
|
|||
// NB: Need to do this here because it's after ciphersuite selection, which
|
|||
// has to be after PSK selection.
|
|||
// XXX: Doing this statefully for now, could be stateless
|
|||
var cookieData []byte |
|||
if state.Caps.RequireCookie && !state.cookieSent { |
|||
var err error |
|||
cookieData, err = state.Caps.CookieHandler.Generate(state.conn) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
if cookieData != nil { |
|||
// Ignoring errors because everything here is newly constructed, so there
|
|||
// shouldn't be marshal errors
|
|||
hrr := &HelloRetryRequestBody{ |
|||
Version: supportedVersion, |
|||
CipherSuite: connParams.CipherSuite, |
|||
} |
|||
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData}) |
|||
|
|||
// Run the external extension handler.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
helloRetryRequest, err := HandshakeMessageFromBody(hrr) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
params := cipherSuiteMap[connParams.CipherSuite] |
|||
h := params.Hash.New() |
|||
h.Write(clientHello.Marshal()) |
|||
firstClientHello := &HandshakeMessage{ |
|||
msgType: HandshakeTypeMessageHash, |
|||
body: h.Sum(nil), |
|||
} |
|||
|
|||
nextState := ServerStateStart{ |
|||
Caps: state.Caps, |
|||
conn: state.conn, |
|||
cookieSent: true, |
|||
firstClientHello: firstClientHello, |
|||
helloRetryRequest: helloRetryRequest, |
|||
} |
|||
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}} |
|||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") |
|||
return nextState, toSend, AlertNoAlert |
|||
} |
|||
|
|||
// If we've got no entropy to make keys from, fail
|
|||
if !connParams.UsingDH && !connParams.UsingPSK { |
|||
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated") |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
var pskSecret []byte |
|||
var cert *Certificate |
|||
var certScheme SignatureScheme |
|||
if connParams.UsingPSK { |
|||
pskSecret = psk.Key |
|||
} else { |
|||
psk = nil |
|||
|
|||
// If we're not using a PSK mode, then we need to have certain extensions
|
|||
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms { |
|||
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)", |
|||
gotServerName, gotSupportedGroups, gotSignatureAlgorithms) |
|||
return nil, nil, AlertMissingExtension |
|||
} |
|||
|
|||
// Select a certificate
|
|||
name := string(*serverName) |
|||
var err error |
|||
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err) |
|||
return nil, nil, AlertAccessDenied |
|||
} |
|||
} |
|||
|
|||
if !connParams.UsingDH { |
|||
dhSecret = nil |
|||
} |
|||
|
|||
// Figure out if we're going to do early data
|
|||
var clientEarlyTrafficSecret []byte |
|||
connParams.ClientSendingEarlyData = gotEarlyData |
|||
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData) |
|||
if connParams.UsingEarlyData { |
|||
|
|||
h := params.Hash.New() |
|||
h.Write(clientHello.Marshal()) |
|||
chHash := h.Sum(nil) |
|||
|
|||
zero := bytes.Repeat([]byte{0}, params.Hash.Size()) |
|||
earlySecret := HkdfExtract(params.Hash, zero, pskSecret) |
|||
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) |
|||
} |
|||
|
|||
// Select a next protocol
|
|||
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err) |
|||
return nil, nil, AlertNoApplicationProtocol |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") |
|||
return ServerStateNegotiated{ |
|||
Caps: state.Caps, |
|||
Params: connParams, |
|||
|
|||
dhGroup: dhGroup, |
|||
dhPublic: dhPublic, |
|||
dhSecret: dhSecret, |
|||
pskSecret: pskSecret, |
|||
selectedPSK: selectedPSK, |
|||
cert: cert, |
|||
certScheme: certScheme, |
|||
clientEarlyTrafficSecret: clientEarlyTrafficSecret, |
|||
|
|||
firstClientHello: state.firstClientHello, |
|||
helloRetryRequest: state.helloRetryRequest, |
|||
clientHello: clientHello, |
|||
}.Next(nil) |
|||
} |
|||
|
|||
type ServerStateNegotiated struct { |
|||
Caps Capabilities |
|||
Params ConnectionParameters |
|||
|
|||
dhGroup NamedGroup |
|||
dhPublic []byte |
|||
dhSecret []byte |
|||
pskSecret []byte |
|||
clientEarlyTrafficSecret []byte |
|||
selectedPSK int |
|||
cert *Certificate |
|||
certScheme SignatureScheme |
|||
|
|||
firstClientHello *HandshakeMessage |
|||
helloRetryRequest *HandshakeMessage |
|||
clientHello *HandshakeMessage |
|||
} |
|||
|
|||
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
// Create the ServerHello
|
|||
sh := &ServerHelloBody{ |
|||
Version: supportedVersion, |
|||
CipherSuite: state.Params.CipherSuite, |
|||
} |
|||
_, err := prng.Read(sh.Random[:]) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
if state.Params.UsingDH { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") |
|||
err = sh.Extensions.Add(&KeyShareExtension{ |
|||
HandshakeType: HandshakeTypeServerHello, |
|||
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}}, |
|||
}) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
if state.Params.UsingPSK { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension") |
|||
err = sh.Extensions.Add(&PreSharedKeyExtension{ |
|||
HandshakeType: HandshakeTypeServerHello, |
|||
SelectedIdentity: uint16(state.selectedPSK), |
|||
}) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
// Run the external extension handler.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
serverHello, err := HandshakeMessageFromBody(sh) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
// Look up crypto params
|
|||
params, ok := cipherSuiteMap[sh.CipherSuite] |
|||
if !ok { |
|||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite) |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
// Start up the handshake hash
|
|||
handshakeHash := params.Hash.New() |
|||
handshakeHash.Write(state.firstClientHello.Marshal()) |
|||
handshakeHash.Write(state.helloRetryRequest.Marshal()) |
|||
handshakeHash.Write(state.clientHello.Marshal()) |
|||
handshakeHash.Write(serverHello.Marshal()) |
|||
|
|||
// Compute handshake secrets
|
|||
zero := bytes.Repeat([]byte{0}, params.Hash.Size()) |
|||
|
|||
var earlySecret []byte |
|||
if state.Params.UsingPSK { |
|||
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret) |
|||
} else { |
|||
earlySecret = HkdfExtract(params.Hash, zero, zero) |
|||
} |
|||
|
|||
if state.dhSecret == nil { |
|||
state.dhSecret = zero |
|||
} |
|||
|
|||
h0 := params.Hash.New().Sum(nil) |
|||
h2 := handshakeHash.Sum(nil) |
|||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) |
|||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret) |
|||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) |
|||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) |
|||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) |
|||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) |
|||
|
|||
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret) |
|||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) |
|||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) |
|||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) |
|||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) |
|||
|
|||
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret) |
|||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) |
|||
|
|||
// Send an EncryptedExtensions message (even if it's empty)
|
|||
eeList := ExtensionList{} |
|||
if state.Params.NextProto != "" { |
|||
logf(logTypeHandshake, "[server] sending ALPN extension") |
|||
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}}) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
if state.Params.UsingEarlyData { |
|||
logf(logTypeHandshake, "[server] sending EDI extension") |
|||
err = eeList.Add(&EarlyDataExtension{}) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
ee := &EncryptedExtensionsBody{eeList} |
|||
|
|||
// Run the external extension handler.
|
|||
if state.Caps.ExtensionHandler != nil { |
|||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
} |
|||
|
|||
eem, err := HandshakeMessageFromBody(ee) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
handshakeHash.Write(eem.Marshal()) |
|||
|
|||
toSend := []HandshakeAction{ |
|||
SendHandshakeMessage{serverHello}, |
|||
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys}, |
|||
SendHandshakeMessage{eem}, |
|||
} |
|||
|
|||
// Authenticate with a certificate if required
|
|||
if !state.Params.UsingPSK { |
|||
// Send a CertificateRequest message if we want client auth
|
|||
if state.Caps.RequireClientAuth { |
|||
state.Params.UsingClientAuth = true |
|||
|
|||
// XXX: We don't support sending any constraints besides a list of
|
|||
// supported signature algorithms
|
|||
cr := &CertificateRequestBody{} |
|||
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} |
|||
err := cr.Extensions.Add(schemes) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
crm, err := HandshakeMessageFromBody(cr) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
//TODO state.state.serverCertificateRequest = cr
|
|||
|
|||
toSend = append(toSend, SendHandshakeMessage{crm}) |
|||
handshakeHash.Write(crm.Marshal()) |
|||
} |
|||
|
|||
// Create and send Certificate, CertificateVerify
|
|||
certificate := &CertificateBody{ |
|||
CertificateList: make([]CertificateEntry, len(state.cert.Chain)), |
|||
} |
|||
for i, entry := range state.cert.Chain { |
|||
certificate.CertificateList[i] = CertificateEntry{CertData: entry} |
|||
} |
|||
certm, err := HandshakeMessageFromBody(certificate) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
toSend = append(toSend, SendHandshakeMessage{certm}) |
|||
handshakeHash.Write(certm.Marshal()) |
|||
|
|||
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} |
|||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) |
|||
|
|||
hcv := handshakeHash.Sum(nil) |
|||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) |
|||
|
|||
err = certificateVerify.Sign(state.cert.PrivateKey, hcv) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
certvm, err := HandshakeMessageFromBody(certificateVerify) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err) |
|||
return nil, nil, AlertInternalError |
|||
} |
|||
|
|||
toSend = append(toSend, SendHandshakeMessage{certvm}) |
|||
handshakeHash.Write(certvm.Marshal()) |
|||
} |
|||
|
|||
// Compute secrets resulting from the server's first flight
|
|||
h3 := handshakeHash.Sum(nil) |
|||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) |
|||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) |
|||
|
|||
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3) |
|||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) |
|||
|
|||
// Assemble the Finished message
|
|||
fin := &FinishedBody{ |
|||
VerifyDataLen: len(serverFinishedData), |
|||
VerifyData: serverFinishedData, |
|||
} |
|||
finm, _ := HandshakeMessageFromBody(fin) |
|||
|
|||
toSend = append(toSend, SendHandshakeMessage{finm}) |
|||
handshakeHash.Write(finm.Marshal()) |
|||
|
|||
// Compute traffic secrets
|
|||
h4 := handshakeHash.Sum(nil) |
|||
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4) |
|||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4) |
|||
|
|||
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4) |
|||
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4) |
|||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) |
|||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) |
|||
|
|||
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret) |
|||
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys}) |
|||
|
|||
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4) |
|||
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret) |
|||
|
|||
if state.Params.UsingEarlyData { |
|||
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret) |
|||
|
|||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]") |
|||
nextState := ServerStateWaitEOED{ |
|||
AuthCertificate: state.Caps.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: params, |
|||
handshakeHash: handshakeHash, |
|||
masterSecret: masterSecret, |
|||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, |
|||
clientTrafficSecret: clientTrafficSecret, |
|||
serverTrafficSecret: serverTrafficSecret, |
|||
exporterSecret: exporterSecret, |
|||
} |
|||
toSend = append(toSend, []HandshakeAction{ |
|||
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys}, |
|||
ReadEarlyData{}, |
|||
}...) |
|||
return nextState, toSend, AlertNoAlert |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") |
|||
toSend = append(toSend, []HandshakeAction{ |
|||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, |
|||
ReadPastEarlyData{}, |
|||
}...) |
|||
waitFlight2 := ServerStateWaitFlight2{ |
|||
AuthCertificate: state.Caps.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: params, |
|||
handshakeHash: handshakeHash, |
|||
masterSecret: masterSecret, |
|||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, |
|||
clientTrafficSecret: clientTrafficSecret, |
|||
serverTrafficSecret: serverTrafficSecret, |
|||
exporterSecret: exporterSecret, |
|||
} |
|||
nextState, moreToSend, alert := waitFlight2.Next(nil) |
|||
toSend = append(toSend, moreToSend...) |
|||
return nextState, toSend, alert |
|||
} |
|||
|
|||
type ServerStateWaitEOED struct { |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
handshakeHash hash.Hash |
|||
clientTrafficSecret []byte |
|||
serverTrafficSecret []byte |
|||
exporterSecret []byte |
|||
} |
|||
|
|||
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData { |
|||
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
if len(hm.body) > 0 { |
|||
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]") |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
|
|||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) |
|||
|
|||
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]") |
|||
toSend := []HandshakeAction{ |
|||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, |
|||
} |
|||
waitFlight2 := ServerStateWaitFlight2{ |
|||
AuthCertificate: state.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
clientTrafficSecret: state.clientTrafficSecret, |
|||
serverTrafficSecret: state.serverTrafficSecret, |
|||
exporterSecret: state.exporterSecret, |
|||
} |
|||
nextState, moreToSend, alert := waitFlight2.Next(nil) |
|||
toSend = append(toSend, moreToSend...) |
|||
return nextState, toSend, alert |
|||
} |
|||
|
|||
type ServerStateWaitFlight2 struct { |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
handshakeHash hash.Hash |
|||
clientTrafficSecret []byte |
|||
serverTrafficSecret []byte |
|||
exporterSecret []byte |
|||
} |
|||
|
|||
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm != nil { |
|||
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
if state.Params.UsingClientAuth { |
|||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]") |
|||
nextState := ServerStateWaitCert{ |
|||
AuthCertificate: state.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
handshakeHash: state.handshakeHash, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
clientTrafficSecret: state.clientTrafficSecret, |
|||
serverTrafficSecret: state.serverTrafficSecret, |
|||
exporterSecret: state.exporterSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]") |
|||
nextState := ServerStateWaitFinished{ |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
handshakeHash: state.handshakeHash, |
|||
clientTrafficSecret: state.clientTrafficSecret, |
|||
serverTrafficSecret: state.serverTrafficSecret, |
|||
exporterSecret: state.exporterSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
type ServerStateWaitCert struct { |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
handshakeHash hash.Hash |
|||
clientTrafficSecret []byte |
|||
serverTrafficSecret []byte |
|||
exporterSecret []byte |
|||
} |
|||
|
|||
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeCertificate { |
|||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
cert := &CertificateBody{} |
|||
_, err := cert.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
|
|||
if len(cert.CertificateList) == 0 { |
|||
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate") |
|||
|
|||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]") |
|||
nextState := ServerStateWaitFinished{ |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
handshakeHash: state.handshakeHash, |
|||
clientTrafficSecret: state.clientTrafficSecret, |
|||
serverTrafficSecret: state.serverTrafficSecret, |
|||
exporterSecret: state.exporterSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]") |
|||
nextState := ServerStateWaitCV{ |
|||
AuthCertificate: state.AuthCertificate, |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
handshakeHash: state.handshakeHash, |
|||
clientTrafficSecret: state.clientTrafficSecret, |
|||
serverTrafficSecret: state.serverTrafficSecret, |
|||
clientCertificate: cert, |
|||
exporterSecret: state.exporterSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
type ServerStateWaitCV struct { |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
|
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
|
|||
handshakeHash hash.Hash |
|||
clientTrafficSecret []byte |
|||
serverTrafficSecret []byte |
|||
exporterSecret []byte |
|||
|
|||
clientCertificate *CertificateBody |
|||
} |
|||
|
|||
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { |
|||
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm)) |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
certVerify := &CertificateVerifyBody{} |
|||
_, err := certVerify.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
// Verify client signature over handshake hash
|
|||
hcv := state.handshakeHash.Sum(nil) |
|||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) |
|||
|
|||
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey |
|||
if err := certVerify.Verify(clientPublicKey, hcv); err != nil { |
|||
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
if state.AuthCertificate != nil { |
|||
err := state.AuthCertificate(state.clientCertificate.CertificateList) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate") |
|||
return nil, nil, AlertBadCertificate |
|||
} |
|||
} else { |
|||
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate") |
|||
} |
|||
|
|||
// If it passes, record the certificateVerify in the transcript hash
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
|
|||
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]") |
|||
nextState := ServerStateWaitFinished{ |
|||
Params: state.Params, |
|||
cryptoParams: state.cryptoParams, |
|||
masterSecret: state.masterSecret, |
|||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, |
|||
handshakeHash: state.handshakeHash, |
|||
clientTrafficSecret: state.clientTrafficSecret, |
|||
serverTrafficSecret: state.serverTrafficSecret, |
|||
exporterSecret: state.exporterSecret, |
|||
} |
|||
return nextState, nil, AlertNoAlert |
|||
} |
|||
|
|||
type ServerStateWaitFinished struct { |
|||
Params ConnectionParameters |
|||
cryptoParams CipherSuiteParams |
|||
|
|||
masterSecret []byte |
|||
clientHandshakeTrafficSecret []byte |
|||
|
|||
handshakeHash hash.Hash |
|||
clientTrafficSecret []byte |
|||
serverTrafficSecret []byte |
|||
exporterSecret []byte |
|||
} |
|||
|
|||
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil || hm.msgType != HandshakeTypeFinished { |
|||
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} |
|||
_, err := fin.Unmarshal(hm.body) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
// Verify client Finished data
|
|||
h5 := state.handshakeHash.Sum(nil) |
|||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) |
|||
|
|||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) |
|||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) |
|||
|
|||
if !bytes.Equal(fin.VerifyData, clientFinishedData) { |
|||
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify") |
|||
return nil, nil, AlertHandshakeFailure |
|||
} |
|||
|
|||
// Compute the resumption secret
|
|||
state.handshakeHash.Write(hm.Marshal()) |
|||
h6 := state.handshakeHash.Sum(nil) |
|||
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6) |
|||
|
|||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) |
|||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) |
|||
|
|||
// Compute client traffic keys
|
|||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) |
|||
|
|||
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") |
|||
nextState := StateConnected{ |
|||
Params: state.Params, |
|||
isClient: false, |
|||
cryptoParams: state.cryptoParams, |
|||
resumptionSecret: resumptionSecret, |
|||
clientTrafficSecret: state.clientTrafficSecret, |
|||
serverTrafficSecret: state.serverTrafficSecret, |
|||
exporterSecret: state.exporterSecret, |
|||
} |
|||
toSend := []HandshakeAction{ |
|||
RekeyIn{Label: "application", KeySet: clientTrafficKeys}, |
|||
} |
|||
return nextState, toSend, AlertNoAlert |
|||
} |
|||
@ -0,0 +1,230 @@ |
|||
package mint |
|||
|
|||
import ( |
|||
"time" |
|||
) |
|||
|
|||
// Marker interface for actions that an implementation should take based on
|
|||
// state transitions.
|
|||
type HandshakeAction interface{} |
|||
|
|||
type SendHandshakeMessage struct { |
|||
Message *HandshakeMessage |
|||
} |
|||
|
|||
type SendEarlyData struct{} |
|||
|
|||
type ReadEarlyData struct{} |
|||
|
|||
type ReadPastEarlyData struct{} |
|||
|
|||
type RekeyIn struct { |
|||
Label string |
|||
KeySet keySet |
|||
} |
|||
|
|||
type RekeyOut struct { |
|||
Label string |
|||
KeySet keySet |
|||
} |
|||
|
|||
type StorePSK struct { |
|||
PSK PreSharedKey |
|||
} |
|||
|
|||
type HandshakeState interface { |
|||
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) |
|||
} |
|||
|
|||
type AppExtensionHandler interface { |
|||
Send(hs HandshakeType, el *ExtensionList) error |
|||
Receive(hs HandshakeType, el *ExtensionList) error |
|||
} |
|||
|
|||
// Capabilities objects represent the capabilities of a TLS client or server,
|
|||
// as an input to TLS negotiation
|
|||
type Capabilities struct { |
|||
// For both client and server
|
|||
CipherSuites []CipherSuite |
|||
Groups []NamedGroup |
|||
SignatureSchemes []SignatureScheme |
|||
PSKs PreSharedKeyCache |
|||
Certificates []*Certificate |
|||
AuthCertificate func(chain []CertificateEntry) error |
|||
ExtensionHandler AppExtensionHandler |
|||
|
|||
// For client
|
|||
PSKModes []PSKKeyExchangeMode |
|||
|
|||
// For server
|
|||
NextProtos []string |
|||
AllowEarlyData bool |
|||
RequireCookie bool |
|||
CookieHandler CookieHandler |
|||
RequireClientAuth bool |
|||
} |
|||
|
|||
// ConnectionOptions objects represent per-connection settings for a client
|
|||
// initiating a connection
|
|||
type ConnectionOptions struct { |
|||
ServerName string |
|||
NextProtos []string |
|||
EarlyData []byte |
|||
} |
|||
|
|||
// ConnectionParameters objects represent the parameters negotiated for a
|
|||
// connection.
|
|||
type ConnectionParameters struct { |
|||
UsingPSK bool |
|||
UsingDH bool |
|||
ClientSendingEarlyData bool |
|||
UsingEarlyData bool |
|||
UsingClientAuth bool |
|||
|
|||
CipherSuite CipherSuite |
|||
ServerName string |
|||
NextProto string |
|||
} |
|||
|
|||
// StateConnected is symmetric between client and server
|
|||
type StateConnected struct { |
|||
Params ConnectionParameters |
|||
isClient bool |
|||
cryptoParams CipherSuiteParams |
|||
resumptionSecret []byte |
|||
clientTrafficSecret []byte |
|||
serverTrafficSecret []byte |
|||
exporterSecret []byte |
|||
} |
|||
|
|||
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { |
|||
var trafficKeys keySet |
|||
if state.isClient { |
|||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, |
|||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) |
|||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) |
|||
} else { |
|||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, |
|||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) |
|||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) |
|||
} |
|||
|
|||
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) |
|||
return nil, AlertInternalError |
|||
} |
|||
|
|||
toSend := []HandshakeAction{ |
|||
SendHandshakeMessage{kum}, |
|||
RekeyOut{Label: "update", KeySet: trafficKeys}, |
|||
} |
|||
return toSend, AlertNoAlert |
|||
} |
|||
|
|||
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { |
|||
tkt, err := NewSessionTicket(length, lifetime) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) |
|||
return nil, AlertInternalError |
|||
} |
|||
|
|||
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime}) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err) |
|||
return nil, AlertInternalError |
|||
} |
|||
|
|||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, |
|||
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size()) |
|||
|
|||
newPSK := PreSharedKey{ |
|||
CipherSuite: state.cryptoParams.Suite, |
|||
IsResumption: true, |
|||
Identity: tkt.Ticket, |
|||
Key: resumptionKey, |
|||
NextProto: state.Params.NextProto, |
|||
ReceivedAt: time.Now(), |
|||
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second), |
|||
TicketAgeAdd: tkt.TicketAgeAdd, |
|||
} |
|||
|
|||
tktm, err := HandshakeMessageFromBody(tkt) |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) |
|||
return nil, AlertInternalError |
|||
} |
|||
|
|||
toSend := []HandshakeAction{ |
|||
StorePSK{newPSK}, |
|||
SendHandshakeMessage{tktm}, |
|||
} |
|||
return toSend, AlertNoAlert |
|||
} |
|||
|
|||
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { |
|||
if hm == nil { |
|||
logf(logTypeHandshake, "[StateConnected] Unexpected message") |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
bodyGeneric, err := hm.ToBody() |
|||
if err != nil { |
|||
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err) |
|||
return nil, nil, AlertDecodeError |
|||
} |
|||
|
|||
switch body := bodyGeneric.(type) { |
|||
case *KeyUpdateBody: |
|||
var trafficKeys keySet |
|||
if !state.isClient { |
|||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, |
|||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) |
|||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) |
|||
} else { |
|||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, |
|||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) |
|||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) |
|||
} |
|||
|
|||
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}} |
|||
|
|||
// If requested, roll outbound keys and send a KeyUpdate
|
|||
if body.KeyUpdateRequest == KeyUpdateRequested { |
|||
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) |
|||
if alert != AlertNoAlert { |
|||
return nil, nil, alert |
|||
} |
|||
|
|||
toSend = append(toSend, moreToSend...) |
|||
} |
|||
|
|||
return state, toSend, AlertNoAlert |
|||
|
|||
case *NewSessionTicketBody: |
|||
// XXX: Allow NewSessionTicket in both directions?
|
|||
if !state.isClient { |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
|
|||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, |
|||
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) |
|||
|
|||
psk := PreSharedKey{ |
|||
CipherSuite: state.cryptoParams.Suite, |
|||
IsResumption: true, |
|||
Identity: body.Ticket, |
|||
Key: resumptionKey, |
|||
NextProto: state.Params.NextProto, |
|||
ReceivedAt: time.Now(), |
|||
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second), |
|||
TicketAgeAdd: body.TicketAgeAdd, |
|||
} |
|||
|
|||
toSend := []HandshakeAction{StorePSK{psk}} |
|||
return state, toSend, AlertNoAlert |
|||
} |
|||
|
|||
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType) |
|||
return nil, nil, AlertUnexpectedMessage |
|||
} |
|||
@ -0,0 +1,74 @@ |
|||
TLS Syntax |
|||
========== |
|||
|
|||
TLS defines [its own syntax](https://tlswg.github.io/tls13-spec/#rfc.section.3) |
|||
for describing structures used in that protocol. To facilitate experimentation |
|||
with TLS in Go, this module maps that syntax to the Go structure syntax, taking |
|||
advantage of Go's type annotations to encode non-type information carried in the |
|||
TLS presentation format. |
|||
|
|||
For example, in the TLS specification, a ClientHello message has the following |
|||
structure: |
|||
|
|||
~~~~~ |
|||
uint16 ProtocolVersion; |
|||
opaque Random[32]; |
|||
uint8 CipherSuite[2]; |
|||
enum { ... (65535)} ExtensionType; |
|||
|
|||
struct { |
|||
ExtensionType extension_type; |
|||
opaque extension_data<0..2^16-1>; |
|||
} Extension; |
|||
|
|||
struct { |
|||
ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ |
|||
Random random; |
|||
opaque legacy_session_id<0..32>; |
|||
CipherSuite cipher_suites<2..2^16-2>; |
|||
opaque legacy_compression_methods<1..2^8-1>; |
|||
Extension extensions<0..2^16-1>; |
|||
} ClientHello; |
|||
~~~~~ |
|||
|
|||
This maps to the following Go type definitions: |
|||
|
|||
~~~~~ |
|||
type protocolVersion uint16 |
|||
type random [32]byte |
|||
type cipherSuite uint16 // or [2]byte |
|||
type extensionType uint16 |
|||
|
|||
type extension struct { |
|||
ExtensionType extensionType |
|||
ExtensionData []byte `tls:"head=2"` |
|||
} |
|||
|
|||
type clientHello struct { |
|||
LegacyVersion protocolVersion |
|||
Random random |
|||
LegacySessionID []byte `tls:"head=1,max=32"` |
|||
CipherSuites []cipherSuite `tls:"head=2,min=2"` |
|||
LegacyCompressionMethods []byte `tls:"head=1,min=1"` |
|||
Extensions []extension `tls:"head=2"` |
|||
} |
|||
~~~~~ |
|||
|
|||
Then you can just declare, marshal, and unmarshal structs just like you would |
|||
with, say JSON. |
|||
|
|||
The available annotations right now are all related to vectors: |
|||
|
|||
* `head`: The number of bytes of length to use as a "header" |
|||
* `min`: The minimum length of the vector, in bytes |
|||
* `max`: The maximum length of the vector, in bytes |
|||
|
|||
## Not supported |
|||
|
|||
* The `select()` syntax for creating alternate version of the same struct (see, |
|||
e.g., the KeyShare extension) |
|||
|
|||
* The backreference syntax for array lengths or select parameters, as in `opaque |
|||
fragment[TLSPlaintext.length]`. Note, however, that in cases where the length |
|||
immediately preceds the array, these can be reframed as vectors with |
|||
appropriate sizes. |
|||
@ -0,0 +1,243 @@ |
|||
package syntax |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"reflect" |
|||
"runtime" |
|||
) |
|||
|
|||
func Unmarshal(data []byte, v interface{}) (int, error) { |
|||
// Check for well-formedness.
|
|||
// Avoids filling out half a data structure
|
|||
// before discovering a JSON syntax error.
|
|||
d := decodeState{} |
|||
d.Write(data) |
|||
return d.unmarshal(v) |
|||
} |
|||
|
|||
// These are the options that can be specified in the struct tag. Right now,
|
|||
// all of them apply to variable-length vectors and nothing else
|
|||
type decOpts struct { |
|||
head uint // length of length in bytes
|
|||
min uint // minimum size in bytes
|
|||
max uint // maximum size in bytes
|
|||
} |
|||
|
|||
type decodeState struct { |
|||
bytes.Buffer |
|||
} |
|||
|
|||
func (d *decodeState) unmarshal(v interface{}) (read int, err error) { |
|||
defer func() { |
|||
if r := recover(); r != nil { |
|||
if _, ok := r.(runtime.Error); ok { |
|||
panic(r) |
|||
} |
|||
if s, ok := r.(string); ok { |
|||
panic(s) |
|||
} |
|||
err = r.(error) |
|||
} |
|||
}() |
|||
|
|||
rv := reflect.ValueOf(v) |
|||
if rv.Kind() != reflect.Ptr || rv.IsNil() { |
|||
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)") |
|||
} |
|||
|
|||
read = d.value(rv) |
|||
return read, nil |
|||
} |
|||
|
|||
func (e *decodeState) value(v reflect.Value) int { |
|||
return valueDecoder(v)(e, v, decOpts{}) |
|||
} |
|||
|
|||
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int |
|||
|
|||
func valueDecoder(v reflect.Value) decoderFunc { |
|||
return typeDecoder(v.Type().Elem()) |
|||
} |
|||
|
|||
func typeDecoder(t reflect.Type) decoderFunc { |
|||
// Note: Omits the caching / wait-group things that encoding/json uses
|
|||
return newTypeDecoder(t) |
|||
} |
|||
|
|||
func newTypeDecoder(t reflect.Type) decoderFunc { |
|||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
|||
|
|||
switch t.Kind() { |
|||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
|||
return uintDecoder |
|||
case reflect.Array: |
|||
return newArrayDecoder(t) |
|||
case reflect.Slice: |
|||
return newSliceDecoder(t) |
|||
case reflect.Struct: |
|||
return newStructDecoder(t) |
|||
default: |
|||
panic(fmt.Errorf("Unsupported type (%s)", t)) |
|||
} |
|||
} |
|||
|
|||
///// Specific decoders below
|
|||
|
|||
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { |
|||
var uintLen int |
|||
switch v.Elem().Kind() { |
|||
case reflect.Uint8: |
|||
uintLen = 1 |
|||
case reflect.Uint16: |
|||
uintLen = 2 |
|||
case reflect.Uint32: |
|||
uintLen = 4 |
|||
case reflect.Uint64: |
|||
uintLen = 8 |
|||
} |
|||
|
|||
buf := make([]byte, uintLen) |
|||
n, err := d.Read(buf) |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
if n != uintLen { |
|||
panic(fmt.Errorf("Insufficient data to read uint")) |
|||
} |
|||
|
|||
val := uint64(0) |
|||
for _, b := range buf { |
|||
val = (val << 8) + uint64(b) |
|||
} |
|||
|
|||
v.Elem().SetUint(val) |
|||
return uintLen |
|||
} |
|||
|
|||
//////////
|
|||
|
|||
type arrayDecoder struct { |
|||
elemDec decoderFunc |
|||
} |
|||
|
|||
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { |
|||
n := v.Elem().Type().Len() |
|||
read := 0 |
|||
for i := 0; i < n; i += 1 { |
|||
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts) |
|||
} |
|||
return read |
|||
} |
|||
|
|||
func newArrayDecoder(t reflect.Type) decoderFunc { |
|||
dec := &arrayDecoder{typeDecoder(t.Elem())} |
|||
return dec.decode |
|||
} |
|||
|
|||
//////////
|
|||
|
|||
type sliceDecoder struct { |
|||
elementType reflect.Type |
|||
elementDec decoderFunc |
|||
} |
|||
|
|||
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { |
|||
if opts.head == 0 { |
|||
panic(fmt.Errorf("Cannot decode a slice without a header length")) |
|||
} |
|||
|
|||
lengthBytes := make([]byte, opts.head) |
|||
n, err := d.Read(lengthBytes) |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
if uint(n) != opts.head { |
|||
panic(fmt.Errorf("Not enough data to read header")) |
|||
} |
|||
|
|||
length := uint(0) |
|||
for _, b := range lengthBytes { |
|||
length = (length << 8) + uint(b) |
|||
} |
|||
|
|||
if opts.max > 0 && length > opts.max { |
|||
panic(fmt.Errorf("Length of vector exceeds declared max")) |
|||
} |
|||
if length < opts.min { |
|||
panic(fmt.Errorf("Length of vector below declared min")) |
|||
} |
|||
|
|||
data := make([]byte, length) |
|||
n, err = d.Read(data) |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
if uint(n) != length { |
|||
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length)) |
|||
} |
|||
|
|||
elemBuf := &decodeState{} |
|||
elemBuf.Write(data) |
|||
elems := []reflect.Value{} |
|||
read := int(opts.head) |
|||
for elemBuf.Len() > 0 { |
|||
elem := reflect.New(sd.elementType) |
|||
read += sd.elementDec(elemBuf, elem, opts) |
|||
elems = append(elems, elem) |
|||
} |
|||
|
|||
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems))) |
|||
for i := 0; i < len(elems); i += 1 { |
|||
v.Elem().Index(i).Set(elems[i].Elem()) |
|||
} |
|||
return read |
|||
} |
|||
|
|||
func newSliceDecoder(t reflect.Type) decoderFunc { |
|||
dec := &sliceDecoder{ |
|||
elementType: t.Elem(), |
|||
elementDec: typeDecoder(t.Elem()), |
|||
} |
|||
return dec.decode |
|||
} |
|||
|
|||
//////////
|
|||
|
|||
type structDecoder struct { |
|||
fieldOpts []decOpts |
|||
fieldDecs []decoderFunc |
|||
} |
|||
|
|||
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { |
|||
read := 0 |
|||
for i := range sd.fieldDecs { |
|||
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i]) |
|||
} |
|||
return read |
|||
} |
|||
|
|||
func newStructDecoder(t reflect.Type) decoderFunc { |
|||
n := t.NumField() |
|||
sd := structDecoder{ |
|||
fieldOpts: make([]decOpts, n), |
|||
fieldDecs: make([]decoderFunc, n), |
|||
} |
|||
|
|||
for i := 0; i < n; i += 1 { |
|||
f := t.Field(i) |
|||
|
|||
tag := f.Tag.Get("tls") |
|||
tagOpts := parseTag(tag) |
|||
|
|||
sd.fieldOpts[i] = decOpts{ |
|||
head: tagOpts["head"], |
|||
max: tagOpts["max"], |
|||
min: tagOpts["min"], |
|||
} |
|||
|
|||
sd.fieldDecs[i] = typeDecoder(f.Type) |
|||
} |
|||
|
|||
return sd.decode |
|||
} |
|||
@ -0,0 +1,187 @@ |
|||
package syntax |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"reflect" |
|||
"runtime" |
|||
) |
|||
|
|||
func Marshal(v interface{}) ([]byte, error) { |
|||
e := &encodeState{} |
|||
err := e.marshal(v, encOpts{}) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return e.Bytes(), nil |
|||
} |
|||
|
|||
// These are the options that can be specified in the struct tag. Right now,
|
|||
// all of them apply to variable-length vectors and nothing else
|
|||
type encOpts struct { |
|||
head uint // length of length in bytes
|
|||
min uint // minimum size in bytes
|
|||
max uint // maximum size in bytes
|
|||
} |
|||
|
|||
type encodeState struct { |
|||
bytes.Buffer |
|||
} |
|||
|
|||
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) { |
|||
defer func() { |
|||
if r := recover(); r != nil { |
|||
if _, ok := r.(runtime.Error); ok { |
|||
panic(r) |
|||
} |
|||
if s, ok := r.(string); ok { |
|||
panic(s) |
|||
} |
|||
err = r.(error) |
|||
} |
|||
}() |
|||
e.reflectValue(reflect.ValueOf(v), opts) |
|||
return nil |
|||
} |
|||
|
|||
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) { |
|||
valueEncoder(v)(e, v, opts) |
|||
} |
|||
|
|||
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts) |
|||
|
|||
func valueEncoder(v reflect.Value) encoderFunc { |
|||
if !v.IsValid() { |
|||
panic(fmt.Errorf("Cannot encode an invalid value")) |
|||
} |
|||
return typeEncoder(v.Type()) |
|||
} |
|||
|
|||
func typeEncoder(t reflect.Type) encoderFunc { |
|||
// Note: Omits the caching / wait-group things that encoding/json uses
|
|||
return newTypeEncoder(t) |
|||
} |
|||
|
|||
func newTypeEncoder(t reflect.Type) encoderFunc { |
|||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
|||
|
|||
switch t.Kind() { |
|||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
|||
return uintEncoder |
|||
case reflect.Array: |
|||
return newArrayEncoder(t) |
|||
case reflect.Slice: |
|||
return newSliceEncoder(t) |
|||
case reflect.Struct: |
|||
return newStructEncoder(t) |
|||
default: |
|||
panic(fmt.Errorf("Unsupported type (%s)", t)) |
|||
} |
|||
} |
|||
|
|||
///// Specific encoders below
|
|||
|
|||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { |
|||
u := v.Uint() |
|||
switch v.Type().Kind() { |
|||
case reflect.Uint8: |
|||
e.WriteByte(byte(u)) |
|||
case reflect.Uint16: |
|||
e.Write([]byte{byte(u >> 8), byte(u)}) |
|||
case reflect.Uint32: |
|||
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) |
|||
case reflect.Uint64: |
|||
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32), |
|||
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) |
|||
} |
|||
} |
|||
|
|||
//////////
|
|||
|
|||
type arrayEncoder struct { |
|||
elemEnc encoderFunc |
|||
} |
|||
|
|||
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { |
|||
n := v.Len() |
|||
for i := 0; i < n; i += 1 { |
|||
ae.elemEnc(e, v.Index(i), opts) |
|||
} |
|||
} |
|||
|
|||
func newArrayEncoder(t reflect.Type) encoderFunc { |
|||
enc := &arrayEncoder{typeEncoder(t.Elem())} |
|||
return enc.encode |
|||
} |
|||
|
|||
//////////
|
|||
|
|||
type sliceEncoder struct { |
|||
ae *arrayEncoder |
|||
} |
|||
|
|||
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { |
|||
if opts.head == 0 { |
|||
panic(fmt.Errorf("Cannot encode a slice without a header length")) |
|||
} |
|||
|
|||
arrayState := &encodeState{} |
|||
se.ae.encode(arrayState, v, opts) |
|||
|
|||
n := uint(arrayState.Len()) |
|||
if opts.max > 0 && n > opts.max { |
|||
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max)) |
|||
} |
|||
if n>>(8*opts.head) > 0 { |
|||
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head)) |
|||
} |
|||
if n < opts.min { |
|||
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min)) |
|||
} |
|||
|
|||
for i := int(opts.head - 1); i >= 0; i -= 1 { |
|||
e.WriteByte(byte(n >> (8 * uint(i)))) |
|||
} |
|||
e.Write(arrayState.Bytes()) |
|||
} |
|||
|
|||
func newSliceEncoder(t reflect.Type) encoderFunc { |
|||
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}} |
|||
return enc.encode |
|||
} |
|||
|
|||
//////////
|
|||
|
|||
type structEncoder struct { |
|||
fieldOpts []encOpts |
|||
fieldEncs []encoderFunc |
|||
} |
|||
|
|||
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { |
|||
for i := range se.fieldEncs { |
|||
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i]) |
|||
} |
|||
} |
|||
|
|||
func newStructEncoder(t reflect.Type) encoderFunc { |
|||
n := t.NumField() |
|||
se := structEncoder{ |
|||
fieldOpts: make([]encOpts, n), |
|||
fieldEncs: make([]encoderFunc, n), |
|||
} |
|||
|
|||
for i := 0; i < n; i += 1 { |
|||
f := t.Field(i) |
|||
tag := f.Tag.Get("tls") |
|||
tagOpts := parseTag(tag) |
|||
|
|||
se.fieldOpts[i] = encOpts{ |
|||
head: tagOpts["head"], |
|||
max: tagOpts["max"], |
|||
min: tagOpts["min"], |
|||
} |
|||
se.fieldEncs[i] = typeEncoder(f.Type) |
|||
} |
|||
|
|||
return se.encode |
|||
} |
|||
@ -0,0 +1,30 @@ |
|||
package syntax |
|||
|
|||
import ( |
|||
"strconv" |
|||
"strings" |
|||
) |
|||
|
|||
// `tls:"head=2,min=2,max=255"`
|
|||
|
|||
type tagOptions map[string]uint |
|||
|
|||
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
|||
// name=value pairs, where the values MUST be unsigned integers
|
|||
func parseTag(tag string) tagOptions { |
|||
opts := tagOptions{} |
|||
for _, token := range strings.Split(tag, ",") { |
|||
if strings.Index(token, "=") == -1 { |
|||
continue |
|||
} |
|||
|
|||
parts := strings.Split(token, "=") |
|||
if len(parts[0]) == 0 { |
|||
continue |
|||
} |
|||
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 { |
|||
opts[parts[0]] = uint(val) |
|||
} |
|||
} |
|||
return opts |
|||
} |
|||
@ -0,0 +1,168 @@ |
|||
package mint |
|||
|
|||
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
|
|||
|
|||
import ( |
|||
"errors" |
|||
"net" |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
// Server returns a new TLS server side connection
|
|||
// using conn as the underlying transport.
|
|||
// The configuration config must be non-nil and must include
|
|||
// at least one certificate or else set GetCertificate.
|
|||
func Server(conn net.Conn, config *Config) *Conn { |
|||
return NewConn(conn, config, false) |
|||
} |
|||
|
|||
// Client returns a new TLS client side connection
|
|||
// using conn as the underlying transport.
|
|||
// The config cannot be nil: users must set either ServerName or
|
|||
// InsecureSkipVerify in the config.
|
|||
func Client(conn net.Conn, config *Config) *Conn { |
|||
return NewConn(conn, config, true) |
|||
} |
|||
|
|||
// A listener implements a network listener (net.Listener) for TLS connections.
|
|||
type Listener struct { |
|||
net.Listener |
|||
config *Config |
|||
} |
|||
|
|||
// Accept waits for and returns the next incoming TLS connection.
|
|||
// The returned connection c is a *tls.Conn.
|
|||
func (l *Listener) Accept() (c net.Conn, err error) { |
|||
c, err = l.Listener.Accept() |
|||
if err != nil { |
|||
return |
|||
} |
|||
server := Server(c, l.config) |
|||
err = server.Handshake() |
|||
if err == AlertNoAlert { |
|||
err = nil |
|||
} |
|||
c = server |
|||
return |
|||
} |
|||
|
|||
// NewListener creates a Listener which accepts connections from an inner
|
|||
// Listener and wraps each connection with Server.
|
|||
// The configuration config must be non-nil and must include
|
|||
// at least one certificate or else set GetCertificate.
|
|||
func NewListener(inner net.Listener, config *Config) net.Listener { |
|||
l := new(Listener) |
|||
l.Listener = inner |
|||
l.config = config |
|||
return l |
|||
} |
|||
|
|||
// Listen creates a TLS listener accepting connections on the
|
|||
// given network address using net.Listen.
|
|||
// The configuration config must be non-nil and must include
|
|||
// at least one certificate or else set GetCertificate.
|
|||
func Listen(network, laddr string, config *Config) (net.Listener, error) { |
|||
if config == nil || !config.ValidForServer() { |
|||
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") |
|||
} |
|||
l, err := net.Listen(network, laddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return NewListener(l, config), nil |
|||
} |
|||
|
|||
type TimeoutError struct{} |
|||
|
|||
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" } |
|||
func (TimeoutError) Timeout() bool { return true } |
|||
func (TimeoutError) Temporary() bool { return true } |
|||
|
|||
// DialWithDialer connects to the given network address using dialer.Dial and
|
|||
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
|||
// timeout or deadline given in the dialer apply to connection and TLS
|
|||
// handshake as a whole.
|
|||
//
|
|||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
|||
// configuration; see the documentation of Config for the defaults.
|
|||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { |
|||
// We want the Timeout and Deadline values from dialer to cover the
|
|||
// whole process: TCP connection and TLS handshake. This means that we
|
|||
// also need to start our own timers now.
|
|||
timeout := dialer.Timeout |
|||
|
|||
if !dialer.Deadline.IsZero() { |
|||
deadlineTimeout := dialer.Deadline.Sub(time.Now()) |
|||
if timeout == 0 || deadlineTimeout < timeout { |
|||
timeout = deadlineTimeout |
|||
} |
|||
} |
|||
|
|||
var errChannel chan error |
|||
|
|||
if timeout != 0 { |
|||
errChannel = make(chan error, 2) |
|||
time.AfterFunc(timeout, func() { |
|||
errChannel <- TimeoutError{} |
|||
}) |
|||
} |
|||
|
|||
rawConn, err := dialer.Dial(network, addr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
colonPos := strings.LastIndex(addr, ":") |
|||
if colonPos == -1 { |
|||
colonPos = len(addr) |
|||
} |
|||
hostname := addr[:colonPos] |
|||
|
|||
if config == nil { |
|||
config = &Config{} |
|||
} |
|||
// If no ServerName is set, infer the ServerName
|
|||
// from the hostname we're connecting to.
|
|||
if config.ServerName == "" { |
|||
// Make a copy to avoid polluting argument or default.
|
|||
c := config.Clone() |
|||
c.ServerName = hostname |
|||
config = c |
|||
} |
|||
|
|||
conn := Client(rawConn, config) |
|||
|
|||
if timeout == 0 { |
|||
err = conn.Handshake() |
|||
if err == AlertNoAlert { |
|||
err = nil |
|||
} |
|||
} else { |
|||
go func() { |
|||
errChannel <- conn.Handshake() |
|||
}() |
|||
|
|||
err = <-errChannel |
|||
if err == AlertNoAlert { |
|||
err = nil |
|||
} |
|||
} |
|||
|
|||
if err != nil { |
|||
rawConn.Close() |
|||
return nil, err |
|||
} |
|||
|
|||
return conn, nil |
|||
} |
|||
|
|||
// Dial connects to the given network address using net.Dial
|
|||
// and then initiates a TLS handshake, returning the resulting
|
|||
// TLS connection.
|
|||
// Dial interprets a nil configuration as equivalent to
|
|||
// the zero configuration; see the documentation of Config
|
|||
// for the defaults.
|
|||
func Dial(network, addr string, config *Config) (*Conn, error) { |
|||
return DialWithDialer(new(net.Dialer), network, addr, config) |
|||
} |
|||
@ -1,58 +0,0 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto/cipher" |
|||
"errors" |
|||
|
|||
"github.com/lucas-clemente/aes12" |
|||
|
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
type aeadAESGCM struct { |
|||
otherIV []byte |
|||
myIV []byte |
|||
encrypter cipher.AEAD |
|||
decrypter cipher.AEAD |
|||
} |
|||
|
|||
// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size
|
|||
//
|
|||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
|||
// tag size, and couples the cipher and aes packages closely.
|
|||
// See https://github.com/lucas-clemente/aes12.
|
|||
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { |
|||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 { |
|||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs") |
|||
} |
|||
encrypterCipher, err := aes12.NewCipher(myKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
encrypter, err := aes12.NewGCM(encrypterCipher) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
decrypterCipher, err := aes12.NewCipher(otherKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
decrypter, err := aes12.NewGCM(decrypterCipher) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return &aeadAESGCM{ |
|||
otherIV: otherIV, |
|||
myIV: myIV, |
|||
encrypter: encrypter, |
|||
decrypter: decrypter, |
|||
}, nil |
|||
} |
|||
|
|||
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { |
|||
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData) |
|||
} |
|||
|
|||
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { |
|||
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData) |
|||
} |
|||
@ -1,14 +0,0 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
|
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { |
|||
res := make([]byte, 12) |
|||
copy(res[0:4], iv) |
|||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) |
|||
return res |
|||
} |
|||
@ -1,240 +0,0 @@ |
|||
package flowcontrol |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"sync" |
|||
|
|||
"github.com/lucas-clemente/quic-go/congestion" |
|||
"github.com/lucas-clemente/quic-go/handshake" |
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
) |
|||
|
|||
type flowControlManager struct { |
|||
connectionParameters handshake.ConnectionParametersManager |
|||
rttStats *congestion.RTTStats |
|||
|
|||
streamFlowController map[protocol.StreamID]*flowController |
|||
connFlowController *flowController |
|||
mutex sync.RWMutex |
|||
} |
|||
|
|||
var _ FlowControlManager = &flowControlManager{} |
|||
|
|||
var errMapAccess = errors.New("Error accessing the flowController map.") |
|||
|
|||
// NewFlowControlManager creates a new flow control manager
|
|||
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager { |
|||
return &flowControlManager{ |
|||
connectionParameters: connectionParameters, |
|||
rttStats: rttStats, |
|||
streamFlowController: make(map[protocol.StreamID]*flowController), |
|||
connFlowController: newFlowController(0, false, connectionParameters, rttStats), |
|||
} |
|||
} |
|||
|
|||
// NewStream creates new flow controllers for a stream
|
|||
// it does nothing if the stream already exists
|
|||
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) { |
|||
f.mutex.Lock() |
|||
defer f.mutex.Unlock() |
|||
|
|||
if _, ok := f.streamFlowController[streamID]; ok { |
|||
return |
|||
} |
|||
|
|||
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats) |
|||
} |
|||
|
|||
// RemoveStream removes a closed stream from flow control
|
|||
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { |
|||
f.mutex.Lock() |
|||
delete(f.streamFlowController, streamID) |
|||
f.mutex.Unlock() |
|||
} |
|||
|
|||
// ResetStream should be called when receiving a RstStreamFrame
|
|||
// it updates the byte offset to the value in the RstStreamFrame
|
|||
// streamID must not be 0 here
|
|||
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { |
|||
f.mutex.Lock() |
|||
defer f.mutex.Unlock() |
|||
|
|||
streamFlowController, err := f.getFlowController(streamID) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
increment, err := streamFlowController.UpdateHighestReceived(byteOffset) |
|||
if err != nil { |
|||
return qerr.StreamDataAfterTermination |
|||
} |
|||
|
|||
if streamFlowController.CheckFlowControlViolation() { |
|||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) |
|||
} |
|||
|
|||
if streamFlowController.ContributesToConnection() { |
|||
f.connFlowController.IncrementHighestReceived(increment) |
|||
if f.connFlowController.CheckFlowControlViolation() { |
|||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// UpdateHighestReceived updates the highest received byte offset for a stream
|
|||
// it adds the number of additional bytes to connection level flow control
|
|||
// streamID must not be 0 here
|
|||
func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { |
|||
f.mutex.Lock() |
|||
defer f.mutex.Unlock() |
|||
|
|||
streamFlowController, err := f.getFlowController(streamID) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
|
|||
// this error can be ignored here
|
|||
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset) |
|||
|
|||
if streamFlowController.CheckFlowControlViolation() { |
|||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) |
|||
} |
|||
|
|||
if streamFlowController.ContributesToConnection() { |
|||
f.connFlowController.IncrementHighestReceived(increment) |
|||
if f.connFlowController.CheckFlowControlViolation() { |
|||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// streamID must not be 0 here
|
|||
func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error { |
|||
f.mutex.Lock() |
|||
defer f.mutex.Unlock() |
|||
|
|||
fc, err := f.getFlowController(streamID) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
fc.AddBytesRead(n) |
|||
if fc.ContributesToConnection() { |
|||
f.connFlowController.AddBytesRead(n) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { |
|||
f.mutex.Lock() |
|||
defer f.mutex.Unlock() |
|||
|
|||
// get WindowUpdates for streams
|
|||
for id, fc := range f.streamFlowController { |
|||
if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary { |
|||
res = append(res, WindowUpdate{StreamID: id, Offset: offset}) |
|||
if fc.ContributesToConnection() && newIncrement != 0 { |
|||
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier)) |
|||
} |
|||
} |
|||
} |
|||
// get a WindowUpdate for the connection
|
|||
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary { |
|||
res = append(res, WindowUpdate{StreamID: 0, Offset: offset}) |
|||
} |
|||
|
|||
return |
|||
} |
|||
|
|||
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { |
|||
f.mutex.RLock() |
|||
defer f.mutex.RUnlock() |
|||
|
|||
// StreamID can be 0 when retransmitting
|
|||
if streamID == 0 { |
|||
return f.connFlowController.receiveWindow, nil |
|||
} |
|||
|
|||
flowController, err := f.getFlowController(streamID) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
return flowController.receiveWindow, nil |
|||
} |
|||
|
|||
// streamID must not be 0 here
|
|||
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { |
|||
f.mutex.Lock() |
|||
defer f.mutex.Unlock() |
|||
|
|||
fc, err := f.getFlowController(streamID) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
fc.AddBytesSent(n) |
|||
if fc.ContributesToConnection() { |
|||
f.connFlowController.AddBytesSent(n) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// must not be called with StreamID 0
|
|||
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) { |
|||
f.mutex.RLock() |
|||
defer f.mutex.RUnlock() |
|||
|
|||
fc, err := f.getFlowController(streamID) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
res := fc.SendWindowSize() |
|||
|
|||
if fc.ContributesToConnection() { |
|||
res = utils.MinByteCount(res, f.connFlowController.SendWindowSize()) |
|||
} |
|||
|
|||
return res, nil |
|||
} |
|||
|
|||
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { |
|||
f.mutex.RLock() |
|||
defer f.mutex.RUnlock() |
|||
|
|||
return f.connFlowController.SendWindowSize() |
|||
} |
|||
|
|||
// streamID may be 0 here
|
|||
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) { |
|||
f.mutex.Lock() |
|||
defer f.mutex.Unlock() |
|||
|
|||
var fc *flowController |
|||
if streamID == 0 { |
|||
fc = f.connFlowController |
|||
} else { |
|||
var err error |
|||
fc, err = f.getFlowController(streamID) |
|||
if err != nil { |
|||
return false, err |
|||
} |
|||
} |
|||
|
|||
return fc.UpdateSendWindow(offset), nil |
|||
} |
|||
|
|||
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) { |
|||
streamFlowController, ok := f.streamFlowController[streamID] |
|||
if !ok { |
|||
return nil, errMapAccess |
|||
} |
|||
return streamFlowController, nil |
|||
} |
|||
@ -1,198 +0,0 @@ |
|||
package flowcontrol |
|||
|
|||
import ( |
|||
"errors" |
|||
"time" |
|||
|
|||
"github.com/lucas-clemente/quic-go/congestion" |
|||
"github.com/lucas-clemente/quic-go/handshake" |
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
type flowController struct { |
|||
streamID protocol.StreamID |
|||
contributesToConnection bool // does the stream contribute to connection level flow control
|
|||
|
|||
connectionParameters handshake.ConnectionParametersManager |
|||
rttStats *congestion.RTTStats |
|||
|
|||
bytesSent protocol.ByteCount |
|||
sendWindow protocol.ByteCount |
|||
|
|||
lastWindowUpdateTime time.Time |
|||
|
|||
bytesRead protocol.ByteCount |
|||
highestReceived protocol.ByteCount |
|||
receiveWindow protocol.ByteCount |
|||
receiveWindowIncrement protocol.ByteCount |
|||
maxReceiveWindowIncrement protocol.ByteCount |
|||
} |
|||
|
|||
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
|
|||
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") |
|||
|
|||
// newFlowController gets a new flow controller
|
|||
func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController { |
|||
fc := flowController{ |
|||
streamID: streamID, |
|||
contributesToConnection: contributesToConnection, |
|||
connectionParameters: connectionParameters, |
|||
rttStats: rttStats, |
|||
} |
|||
|
|||
if streamID == 0 { |
|||
fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow() |
|||
fc.receiveWindowIncrement = fc.receiveWindow |
|||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow() |
|||
} else { |
|||
fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow() |
|||
fc.receiveWindowIncrement = fc.receiveWindow |
|||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow() |
|||
} |
|||
|
|||
return &fc |
|||
} |
|||
|
|||
func (c *flowController) ContributesToConnection() bool { |
|||
return c.contributesToConnection |
|||
} |
|||
|
|||
func (c *flowController) getSendWindow() protocol.ByteCount { |
|||
if c.sendWindow == 0 { |
|||
if c.streamID == 0 { |
|||
return c.connectionParameters.GetSendConnectionFlowControlWindow() |
|||
} |
|||
return c.connectionParameters.GetSendStreamFlowControlWindow() |
|||
} |
|||
return c.sendWindow |
|||
} |
|||
|
|||
func (c *flowController) AddBytesSent(n protocol.ByteCount) { |
|||
c.bytesSent += n |
|||
} |
|||
|
|||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
|||
// it returns true if the window was actually updated
|
|||
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { |
|||
if newOffset > c.sendWindow { |
|||
c.sendWindow = newOffset |
|||
return true |
|||
} |
|||
return false |
|||
} |
|||
|
|||
func (c *flowController) SendWindowSize() protocol.ByteCount { |
|||
sendWindow := c.getSendWindow() |
|||
|
|||
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
|
|||
return 0 |
|||
} |
|||
return sendWindow - c.bytesSent |
|||
} |
|||
|
|||
func (c *flowController) SendWindowOffset() protocol.ByteCount { |
|||
return c.getSendWindow() |
|||
} |
|||
|
|||
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
|||
// Should **only** be used for the stream-level FlowController
|
|||
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
|
|||
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
|
|||
// It should only be treated as an error when resetting a stream
|
|||
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) { |
|||
if byteOffset == c.highestReceived { |
|||
return 0, nil |
|||
} |
|||
if byteOffset > c.highestReceived { |
|||
increment := byteOffset - c.highestReceived |
|||
c.highestReceived = byteOffset |
|||
return increment, nil |
|||
} |
|||
return 0, ErrReceivedSmallerByteOffset |
|||
} |
|||
|
|||
// IncrementHighestReceived adds an increment to the highestReceived value
|
|||
// Should **only** be used for the connection-level FlowController
|
|||
func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) { |
|||
c.highestReceived += increment |
|||
} |
|||
|
|||
func (c *flowController) AddBytesRead(n protocol.ByteCount) { |
|||
// pretend we sent a WindowUpdate when reading the first byte
|
|||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
|||
if c.bytesRead == 0 { |
|||
c.lastWindowUpdateTime = time.Now() |
|||
} |
|||
c.bytesRead += n |
|||
} |
|||
|
|||
// MaybeUpdateWindow updates the receive window, if necessary
|
|||
// if the receive window increment is changed, the new value is returned, otherwise a 0
|
|||
// the last return value is the new offset of the receive window
|
|||
func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) { |
|||
diff := c.receiveWindow - c.bytesRead |
|||
|
|||
// Chromium implements the same threshold
|
|||
if diff < (c.receiveWindowIncrement / 2) { |
|||
var newWindowIncrement protocol.ByteCount |
|||
oldWindowIncrement := c.receiveWindowIncrement |
|||
|
|||
c.maybeAdjustWindowIncrement() |
|||
if c.receiveWindowIncrement != oldWindowIncrement { |
|||
newWindowIncrement = c.receiveWindowIncrement |
|||
} |
|||
|
|||
c.lastWindowUpdateTime = time.Now() |
|||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement |
|||
return true, newWindowIncrement, c.receiveWindow |
|||
} |
|||
|
|||
return false, 0, 0 |
|||
} |
|||
|
|||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
|||
func (c *flowController) maybeAdjustWindowIncrement() { |
|||
if c.lastWindowUpdateTime.IsZero() { |
|||
return |
|||
} |
|||
|
|||
rtt := c.rttStats.SmoothedRTT() |
|||
if rtt == 0 { |
|||
return |
|||
} |
|||
|
|||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) |
|||
|
|||
// interval between the window updates is sufficiently large, no need to increase the increment
|
|||
if timeSinceLastWindowUpdate >= 2*rtt { |
|||
return |
|||
} |
|||
|
|||
oldWindowSize := c.receiveWindowIncrement |
|||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) |
|||
|
|||
// debug log, if the window size was actually increased
|
|||
if oldWindowSize < c.receiveWindowIncrement { |
|||
newWindowSize := c.receiveWindowIncrement / (1 << 10) |
|||
if c.streamID == 0 { |
|||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize) |
|||
} else { |
|||
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
|||
// it is intended be used for the connection-level flow controller
|
|||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
|||
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { |
|||
if inc > c.receiveWindowIncrement { |
|||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) |
|||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
|||
} |
|||
} |
|||
|
|||
func (c *flowController) CheckFlowControlViolation() bool { |
|||
return c.highestReceived > c.receiveWindow |
|||
} |
|||
@ -1,26 +0,0 @@ |
|||
package flowcontrol |
|||
|
|||
import "github.com/lucas-clemente/quic-go/protocol" |
|||
|
|||
// WindowUpdate provides the data for WindowUpdateFrames.
|
|||
type WindowUpdate struct { |
|||
StreamID protocol.StreamID |
|||
Offset protocol.ByteCount |
|||
} |
|||
|
|||
// A FlowControlManager manages the flow control
|
|||
type FlowControlManager interface { |
|||
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) |
|||
RemoveStream(streamID protocol.StreamID) |
|||
// methods needed for receiving data
|
|||
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error |
|||
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error |
|||
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error |
|||
GetWindowUpdates() []WindowUpdate |
|||
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) |
|||
// methods needed for sending data
|
|||
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error |
|||
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) |
|||
RemainingConnectionWindowSize() protocol.ByteCount |
|||
UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) |
|||
} |
|||
@ -1,9 +0,0 @@ |
|||
package frames |
|||
|
|||
import "github.com/lucas-clemente/quic-go/protocol" |
|||
|
|||
// AckRange is an ACK range
|
|||
type AckRange struct { |
|||
FirstPacketNumber protocol.PacketNumber |
|||
LastPacketNumber protocol.PacketNumber |
|||
} |
|||
@ -1,44 +0,0 @@ |
|||
package frames |
|||
|
|||
import ( |
|||
"bytes" |
|||
|
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
// A BlockedFrame in QUIC
|
|||
type BlockedFrame struct { |
|||
StreamID protocol.StreamID |
|||
} |
|||
|
|||
//Write writes a BlockedFrame frame
|
|||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { |
|||
b.WriteByte(0x05) |
|||
utils.WriteUint32(b, uint32(f.StreamID)) |
|||
return nil |
|||
} |
|||
|
|||
// MinLength of a written frame
|
|||
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { |
|||
return 1 + 4, nil |
|||
} |
|||
|
|||
// ParseBlockedFrame parses a BLOCKED frame
|
|||
func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) { |
|||
frame := &BlockedFrame{} |
|||
|
|||
// read the TypeByte
|
|||
_, err := r.ReadByte() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
sid, err := utils.ReadUint32(r) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
frame.StreamID = protocol.StreamID(sid) |
|||
|
|||
return frame, nil |
|||
} |
|||
@ -1,28 +0,0 @@ |
|||
package frames |
|||
|
|||
import "github.com/lucas-clemente/quic-go/internal/utils" |
|||
|
|||
// LogFrame logs a frame, either sent or received
|
|||
func LogFrame(frame Frame, sent bool) { |
|||
if !utils.Debug() { |
|||
return |
|||
} |
|||
dir := "<-" |
|||
if sent { |
|||
dir = "->" |
|||
} |
|||
switch f := frame.(type) { |
|||
case *StreamFrame: |
|||
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) |
|||
case *StopWaitingFrame: |
|||
if sent { |
|||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) |
|||
} else { |
|||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) |
|||
} |
|||
case *AckFrame: |
|||
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) |
|||
default: |
|||
utils.Debugf("\t%s %#v", dir, frame) |
|||
} |
|||
} |
|||
@ -1,54 +0,0 @@ |
|||
package frames |
|||
|
|||
import ( |
|||
"bytes" |
|||
|
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
// A WindowUpdateFrame in QUIC
|
|||
type WindowUpdateFrame struct { |
|||
StreamID protocol.StreamID |
|||
ByteOffset protocol.ByteCount |
|||
} |
|||
|
|||
//Write writes a RST_STREAM frame
|
|||
func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { |
|||
typeByte := uint8(0x04) |
|||
b.WriteByte(typeByte) |
|||
|
|||
utils.WriteUint32(b, uint32(f.StreamID)) |
|||
utils.WriteUint64(b, uint64(f.ByteOffset)) |
|||
return nil |
|||
} |
|||
|
|||
// MinLength of a written frame
|
|||
func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { |
|||
return 1 + 4 + 8, nil |
|||
} |
|||
|
|||
// ParseWindowUpdateFrame parses a RST_STREAM frame
|
|||
func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) { |
|||
frame := &WindowUpdateFrame{} |
|||
|
|||
// read the TypeByte
|
|||
_, err := r.ReadByte() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
sid, err := utils.ReadUint32(r) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
frame.StreamID = protocol.StreamID(sid) |
|||
|
|||
byteOffset, err := utils.ReadUint64(r) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
frame.ByteOffset = protocol.ByteCount(byteOffset) |
|||
|
|||
return frame, nil |
|||
} |
|||
@ -1,296 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"net" |
|||
"net/http" |
|||
"strings" |
|||
"sync" |
|||
|
|||
"golang.org/x/net/http2" |
|||
"golang.org/x/net/http2/hpack" |
|||
"golang.org/x/net/idna" |
|||
|
|||
quic "github.com/lucas-clemente/quic-go" |
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
) |
|||
|
|||
type roundTripperOpts struct { |
|||
DisableCompression bool |
|||
} |
|||
|
|||
var dialAddr = quic.DialAddr |
|||
|
|||
// client is a HTTP2 client doing QUIC requests
|
|||
type client struct { |
|||
mutex sync.RWMutex |
|||
|
|||
tlsConf *tls.Config |
|||
config *quic.Config |
|||
opts *roundTripperOpts |
|||
|
|||
hostname string |
|||
encryptionLevel protocol.EncryptionLevel |
|||
handshakeErr error |
|||
dialOnce sync.Once |
|||
|
|||
session quic.Session |
|||
headerStream quic.Stream |
|||
headerErr *qerr.QuicError |
|||
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
|||
requestWriter *requestWriter |
|||
|
|||
responses map[protocol.StreamID]chan *http.Response |
|||
} |
|||
|
|||
var _ http.RoundTripper = &client{} |
|||
|
|||
var defaultQuicConfig = &quic.Config{ |
|||
RequestConnectionIDTruncation: true, |
|||
KeepAlive: true, |
|||
} |
|||
|
|||
// newClient creates a new client
|
|||
func newClient( |
|||
hostname string, |
|||
tlsConfig *tls.Config, |
|||
opts *roundTripperOpts, |
|||
quicConfig *quic.Config, |
|||
) *client { |
|||
config := defaultQuicConfig |
|||
if quicConfig != nil { |
|||
config = quicConfig |
|||
} |
|||
return &client{ |
|||
hostname: authorityAddr("https", hostname), |
|||
responses: make(map[protocol.StreamID]chan *http.Response), |
|||
encryptionLevel: protocol.EncryptionUnencrypted, |
|||
tlsConf: tlsConfig, |
|||
config: config, |
|||
opts: opts, |
|||
headerErrored: make(chan struct{}), |
|||
} |
|||
} |
|||
|
|||
// dial dials the connection
|
|||
func (c *client) dial() error { |
|||
var err error |
|||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
// once the version has been negotiated, open the header stream
|
|||
c.headerStream, err = c.session.OpenStream() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if c.headerStream.StreamID() != 3 { |
|||
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3") |
|||
} |
|||
c.requestWriter = newRequestWriter(c.headerStream) |
|||
go c.handleHeaderStream() |
|||
return nil |
|||
} |
|||
|
|||
func (c *client) handleHeaderStream() { |
|||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) |
|||
h2framer := http2.NewFramer(nil, c.headerStream) |
|||
|
|||
var lastStream protocol.StreamID |
|||
|
|||
for { |
|||
frame, err := h2framer.ReadFrame() |
|||
if err != nil { |
|||
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") |
|||
break |
|||
} |
|||
lastStream = protocol.StreamID(frame.Header().StreamID) |
|||
hframe, ok := frame.(*http2.HeadersFrame) |
|||
if !ok { |
|||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") |
|||
break |
|||
} |
|||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} |
|||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) |
|||
if err != nil { |
|||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") |
|||
break |
|||
} |
|||
|
|||
c.mutex.RLock() |
|||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] |
|||
c.mutex.RUnlock() |
|||
if !ok { |
|||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) |
|||
break |
|||
} |
|||
|
|||
rsp, err := responseFromHeaders(mhframe) |
|||
if err != nil { |
|||
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) |
|||
} |
|||
responseChan <- rsp |
|||
} |
|||
|
|||
// stop all running request
|
|||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) |
|||
close(c.headerErrored) |
|||
} |
|||
|
|||
// Roundtrip executes a request and returns a response
|
|||
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { |
|||
// TODO: add port to address, if it doesn't have one
|
|||
if req.URL.Scheme != "https" { |
|||
return nil, errors.New("quic http2: unsupported scheme") |
|||
} |
|||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { |
|||
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host) |
|||
} |
|||
|
|||
c.dialOnce.Do(func() { |
|||
c.handshakeErr = c.dial() |
|||
}) |
|||
|
|||
if c.handshakeErr != nil { |
|||
return nil, c.handshakeErr |
|||
} |
|||
|
|||
hasBody := (req.Body != nil) |
|||
|
|||
responseChan := make(chan *http.Response) |
|||
dataStream, err := c.session.OpenStreamSync() |
|||
if err != nil { |
|||
_ = c.CloseWithError(err) |
|||
return nil, err |
|||
} |
|||
c.mutex.Lock() |
|||
c.responses[dataStream.StreamID()] = responseChan |
|||
c.mutex.Unlock() |
|||
|
|||
var requestedGzip bool |
|||
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { |
|||
requestedGzip = true |
|||
} |
|||
// TODO: add support for trailers
|
|||
endStream := !hasBody |
|||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) |
|||
if err != nil { |
|||
_ = c.CloseWithError(err) |
|||
return nil, err |
|||
} |
|||
|
|||
resc := make(chan error, 1) |
|||
if hasBody { |
|||
go func() { |
|||
resc <- c.writeRequestBody(dataStream, req.Body) |
|||
}() |
|||
} |
|||
|
|||
var res *http.Response |
|||
|
|||
var receivedResponse bool |
|||
var bodySent bool |
|||
|
|||
if !hasBody { |
|||
bodySent = true |
|||
} |
|||
|
|||
for !(bodySent && receivedResponse) { |
|||
select { |
|||
case res = <-responseChan: |
|||
receivedResponse = true |
|||
c.mutex.Lock() |
|||
delete(c.responses, dataStream.StreamID()) |
|||
c.mutex.Unlock() |
|||
case err := <-resc: |
|||
bodySent = true |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
case <-c.headerErrored: |
|||
// an error occured on the header stream
|
|||
_ = c.CloseWithError(c.headerErr) |
|||
return nil, c.headerErr |
|||
} |
|||
} |
|||
|
|||
// TODO: correctly set this variable
|
|||
var streamEnded bool |
|||
isHead := (req.Method == "HEAD") |
|||
|
|||
res = setLength(res, isHead, streamEnded) |
|||
|
|||
if streamEnded || isHead { |
|||
res.Body = noBody |
|||
} else { |
|||
res.Body = dataStream |
|||
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { |
|||
res.Header.Del("Content-Encoding") |
|||
res.Header.Del("Content-Length") |
|||
res.ContentLength = -1 |
|||
res.Body = &gzipReader{body: res.Body} |
|||
res.Uncompressed = true |
|||
} |
|||
} |
|||
|
|||
res.Request = req |
|||
return res, nil |
|||
} |
|||
|
|||
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) { |
|||
defer func() { |
|||
cerr := body.Close() |
|||
if err == nil { |
|||
// TODO: what to do with dataStream here? Maybe reset it?
|
|||
err = cerr |
|||
} |
|||
}() |
|||
|
|||
_, err = io.Copy(dataStream, body) |
|||
if err != nil { |
|||
// TODO: what to do with dataStream here? Maybe reset it?
|
|||
return err |
|||
} |
|||
return dataStream.Close() |
|||
} |
|||
|
|||
// Close closes the client
|
|||
func (c *client) CloseWithError(e error) error { |
|||
if c.session == nil { |
|||
return nil |
|||
} |
|||
return c.session.Close(e) |
|||
} |
|||
|
|||
func (c *client) Close() error { |
|||
return c.CloseWithError(nil) |
|||
} |
|||
|
|||
// copied from net/transport.go
|
|||
|
|||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
|||
// and returns a host:port. The port 443 is added if needed.
|
|||
func authorityAddr(scheme string, authority string) (addr string) { |
|||
host, port, err := net.SplitHostPort(authority) |
|||
if err != nil { // authority didn't have a port
|
|||
port = "443" |
|||
if scheme == "http" { |
|||
port = "80" |
|||
} |
|||
host = authority |
|||
} |
|||
if a, err := idna.ToASCII(host); err == nil { |
|||
host = a |
|||
} |
|||
// IPv6 address literal, without a port:
|
|||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { |
|||
return host + ":" + port |
|||
} |
|||
return net.JoinHostPort(host, port) |
|||
} |
|||
@ -1,35 +0,0 @@ |
|||
package h2quic |
|||
|
|||
// copied from net/transport.go
|
|||
|
|||
// gzipReader wraps a response body so it can lazily
|
|||
// call gzip.NewReader on the first call to Read
|
|||
import ( |
|||
"compress/gzip" |
|||
"io" |
|||
) |
|||
|
|||
// call gzip.NewReader on the first call to Read
|
|||
type gzipReader struct { |
|||
body io.ReadCloser // underlying Response.Body
|
|||
zr *gzip.Reader // lazily-initialized gzip reader
|
|||
zerr error // sticky error
|
|||
} |
|||
|
|||
func (gz *gzipReader) Read(p []byte) (n int, err error) { |
|||
if gz.zerr != nil { |
|||
return 0, gz.zerr |
|||
} |
|||
if gz.zr == nil { |
|||
gz.zr, err = gzip.NewReader(gz.body) |
|||
if err != nil { |
|||
gz.zerr = err |
|||
return 0, err |
|||
} |
|||
} |
|||
return gz.zr.Read(p) |
|||
} |
|||
|
|||
func (gz *gzipReader) Close() error { |
|||
return gz.body.Close() |
|||
} |
|||
@ -1,80 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"errors" |
|||
"net/http" |
|||
"net/url" |
|||
"strconv" |
|||
"strings" |
|||
|
|||
"golang.org/x/net/http2/hpack" |
|||
) |
|||
|
|||
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { |
|||
var path, authority, method, contentLengthStr string |
|||
httpHeaders := http.Header{} |
|||
|
|||
for _, h := range headers { |
|||
switch h.Name { |
|||
case ":path": |
|||
path = h.Value |
|||
case ":method": |
|||
method = h.Value |
|||
case ":authority": |
|||
authority = h.Value |
|||
case "content-length": |
|||
contentLengthStr = h.Value |
|||
default: |
|||
if !h.IsPseudo() { |
|||
httpHeaders.Add(h.Name, h.Value) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
|||
if len(httpHeaders["Cookie"]) > 0 { |
|||
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) |
|||
} |
|||
|
|||
if len(path) == 0 || len(authority) == 0 || len(method) == 0 { |
|||
return nil, errors.New(":path, :authority and :method must not be empty") |
|||
} |
|||
|
|||
u, err := url.Parse(path) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
var contentLength int64 |
|||
if len(contentLengthStr) > 0 { |
|||
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
|
|||
return &http.Request{ |
|||
Method: method, |
|||
URL: u, |
|||
Proto: "HTTP/2.0", |
|||
ProtoMajor: 2, |
|||
ProtoMinor: 0, |
|||
Header: httpHeaders, |
|||
Body: nil, |
|||
ContentLength: contentLength, |
|||
Host: authority, |
|||
RequestURI: path, |
|||
TLS: &tls.ConnectionState{}, |
|||
}, nil |
|||
} |
|||
|
|||
func hostnameFromRequest(req *http.Request) string { |
|||
if len(req.Host) > 0 { |
|||
return req.Host |
|||
} |
|||
if req.URL != nil { |
|||
return req.URL.Host |
|||
} |
|||
return "" |
|||
} |
|||
@ -1,29 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"io" |
|||
|
|||
quic "github.com/lucas-clemente/quic-go" |
|||
) |
|||
|
|||
type requestBody struct { |
|||
requestRead bool |
|||
dataStream quic.Stream |
|||
} |
|||
|
|||
// make sure the requestBody can be used as a http.Request.Body
|
|||
var _ io.ReadCloser = &requestBody{} |
|||
|
|||
func newRequestBody(stream quic.Stream) *requestBody { |
|||
return &requestBody{dataStream: stream} |
|||
} |
|||
|
|||
func (b *requestBody) Read(p []byte) (int, error) { |
|||
b.requestRead = true |
|||
return b.dataStream.Read(p) |
|||
} |
|||
|
|||
func (b *requestBody) Close() error { |
|||
// stream's Close() closes the write side, not the read side
|
|||
return nil |
|||
} |
|||
@ -1,201 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"net/http" |
|||
"strconv" |
|||
"strings" |
|||
"sync" |
|||
|
|||
"golang.org/x/net/http2" |
|||
"golang.org/x/net/http2/hpack" |
|||
"golang.org/x/net/lex/httplex" |
|||
|
|||
quic "github.com/lucas-clemente/quic-go" |
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
) |
|||
|
|||
type requestWriter struct { |
|||
mutex sync.Mutex |
|||
headerStream quic.Stream |
|||
|
|||
henc *hpack.Encoder |
|||
hbuf bytes.Buffer // HPACK encoder writes into this
|
|||
} |
|||
|
|||
const defaultUserAgent = "quic-go" |
|||
|
|||
func newRequestWriter(headerStream quic.Stream) *requestWriter { |
|||
rw := &requestWriter{ |
|||
headerStream: headerStream, |
|||
} |
|||
rw.henc = hpack.NewEncoder(&rw.hbuf) |
|||
return rw |
|||
} |
|||
|
|||
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error { |
|||
// TODO: add support for trailers
|
|||
// TODO: add support for gzip compression
|
|||
// TODO: write continuation frames, if the header frame is too long
|
|||
|
|||
w.mutex.Lock() |
|||
defer w.mutex.Unlock() |
|||
|
|||
w.encodeHeaders(req, requestGzip, "", actualContentLength(req)) |
|||
h2framer := http2.NewFramer(w.headerStream, nil) |
|||
return h2framer.WriteHeaders(http2.HeadersFrameParam{ |
|||
StreamID: uint32(dataStreamID), |
|||
EndHeaders: true, |
|||
EndStream: endStream, |
|||
BlockFragment: w.hbuf.Bytes(), |
|||
Priority: http2.PriorityParam{Weight: 0xff}, |
|||
}) |
|||
} |
|||
|
|||
// the rest of this files is copied from http2.Transport
|
|||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { |
|||
w.hbuf.Reset() |
|||
|
|||
host := req.Host |
|||
if host == "" { |
|||
host = req.URL.Host |
|||
} |
|||
host, err := httplex.PunycodeHostPort(host) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
var path string |
|||
if req.Method != "CONNECT" { |
|||
path = req.URL.RequestURI() |
|||
if !validPseudoPath(path) { |
|||
orig := path |
|||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) |
|||
if !validPseudoPath(path) { |
|||
if req.URL.Opaque != "" { |
|||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) |
|||
} else { |
|||
return nil, fmt.Errorf("invalid request :path %q", orig) |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Check for any invalid headers and return an error before we
|
|||
// potentially pollute our hpack state. (We want to be able to
|
|||
// continue to reuse the hpack encoder for future requests)
|
|||
for k, vv := range req.Header { |
|||
if !httplex.ValidHeaderFieldName(k) { |
|||
return nil, fmt.Errorf("invalid HTTP header name %q", k) |
|||
} |
|||
for _, v := range vv { |
|||
if !httplex.ValidHeaderFieldValue(v) { |
|||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// 8.1.2.3 Request Pseudo-Header Fields
|
|||
// The :path pseudo-header field includes the path and query parts of the
|
|||
// target URI (the path-absolute production and optionally a '?' character
|
|||
// followed by the query production (see Sections 3.3 and 3.4 of
|
|||
// [RFC3986]).
|
|||
w.writeHeader(":authority", host) |
|||
w.writeHeader(":method", req.Method) |
|||
if req.Method != "CONNECT" { |
|||
w.writeHeader(":path", path) |
|||
w.writeHeader(":scheme", req.URL.Scheme) |
|||
} |
|||
if trailers != "" { |
|||
w.writeHeader("trailer", trailers) |
|||
} |
|||
|
|||
var didUA bool |
|||
for k, vv := range req.Header { |
|||
lowKey := strings.ToLower(k) |
|||
switch lowKey { |
|||
case "host", "content-length": |
|||
// Host is :authority, already sent.
|
|||
// Content-Length is automatic, set below.
|
|||
continue |
|||
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": |
|||
// Per 8.1.2.2 Connection-Specific Header
|
|||
// Fields, don't send connection-specific
|
|||
// fields. We have already checked if any
|
|||
// are error-worthy so just ignore the rest.
|
|||
continue |
|||
case "user-agent": |
|||
// Match Go's http1 behavior: at most one
|
|||
// User-Agent. If set to nil or empty string,
|
|||
// then omit it. Otherwise if not mentioned,
|
|||
// include the default (below).
|
|||
didUA = true |
|||
if len(vv) < 1 { |
|||
continue |
|||
} |
|||
vv = vv[:1] |
|||
if vv[0] == "" { |
|||
continue |
|||
} |
|||
} |
|||
for _, v := range vv { |
|||
w.writeHeader(lowKey, v) |
|||
} |
|||
} |
|||
if shouldSendReqContentLength(req.Method, contentLength) { |
|||
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) |
|||
} |
|||
if addGzipHeader { |
|||
w.writeHeader("accept-encoding", "gzip") |
|||
} |
|||
if !didUA { |
|||
w.writeHeader("user-agent", defaultUserAgent) |
|||
} |
|||
return w.hbuf.Bytes(), nil |
|||
} |
|||
|
|||
func (w *requestWriter) writeHeader(name, value string) { |
|||
utils.Debugf("http2: Transport encoding header %q = %q", name, value) |
|||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) |
|||
} |
|||
|
|||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
|||
// a "content-length" request header. This logic is basically a copy of the net/http
|
|||
// transferWriter.shouldSendContentLength.
|
|||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
|||
// -1 means unknown.
|
|||
func shouldSendReqContentLength(method string, contentLength int64) bool { |
|||
if contentLength > 0 { |
|||
return true |
|||
} |
|||
if contentLength < 0 { |
|||
return false |
|||
} |
|||
// For zero bodies, whether we send a content-length depends on the method.
|
|||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
|||
switch method { |
|||
case "POST", "PUT", "PATCH": |
|||
return true |
|||
default: |
|||
return false |
|||
} |
|||
} |
|||
|
|||
func validPseudoPath(v string) bool { |
|||
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" |
|||
} |
|||
|
|||
// actualContentLength returns a sanitized version of
|
|||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
|||
// means unknown.
|
|||
func actualContentLength(req *http.Request) int64 { |
|||
if req.Body == nil { |
|||
return 0 |
|||
} |
|||
if req.ContentLength != 0 { |
|||
return req.ContentLength |
|||
} |
|||
return -1 |
|||
} |
|||
@ -1,111 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"bytes" |
|||
"errors" |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/textproto" |
|||
"strconv" |
|||
"strings" |
|||
|
|||
"golang.org/x/net/http2" |
|||
) |
|||
|
|||
// copied from net/http2/transport.go
|
|||
|
|||
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") |
|||
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) |
|||
|
|||
// from the handleResponse function
|
|||
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { |
|||
if f.Truncated { |
|||
return nil, errResponseHeaderListSize |
|||
} |
|||
|
|||
status := f.PseudoValue("status") |
|||
if status == "" { |
|||
return nil, errors.New("missing status pseudo header") |
|||
} |
|||
statusCode, err := strconv.Atoi(status) |
|||
if err != nil { |
|||
return nil, errors.New("malformed non-numeric status pseudo header") |
|||
} |
|||
|
|||
if statusCode == 100 { |
|||
// TODO: handle this
|
|||
|
|||
// traceGot100Continue(cs.trace)
|
|||
// if cs.on100 != nil {
|
|||
// cs.on100() // forces any write delay timer to fire
|
|||
// }
|
|||
// cs.pastHeaders = false // do it all again
|
|||
// return nil, nil
|
|||
} |
|||
|
|||
header := make(http.Header) |
|||
res := &http.Response{ |
|||
Proto: "HTTP/2.0", |
|||
ProtoMajor: 2, |
|||
Header: header, |
|||
StatusCode: statusCode, |
|||
Status: status + " " + http.StatusText(statusCode), |
|||
} |
|||
for _, hf := range f.RegularFields() { |
|||
key := http.CanonicalHeaderKey(hf.Name) |
|||
if key == "Trailer" { |
|||
t := res.Trailer |
|||
if t == nil { |
|||
t = make(http.Header) |
|||
res.Trailer = t |
|||
} |
|||
foreachHeaderElement(hf.Value, func(v string) { |
|||
t[http.CanonicalHeaderKey(v)] = nil |
|||
}) |
|||
} else { |
|||
header[key] = append(header[key], hf.Value) |
|||
} |
|||
} |
|||
|
|||
return res, nil |
|||
} |
|||
|
|||
// continuation of the handleResponse function
|
|||
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { |
|||
if !streamEnded || isHead { |
|||
res.ContentLength = -1 |
|||
if clens := res.Header["Content-Length"]; len(clens) == 1 { |
|||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { |
|||
res.ContentLength = clen64 |
|||
} else { |
|||
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
|||
// more safe smuggling-wise to ignore.
|
|||
} |
|||
} else if len(clens) > 1 { |
|||
// TODO: care? unlike http/1, it won't mess up our framing, so it's
|
|||
// more safe smuggling-wise to ignore.
|
|||
} |
|||
} |
|||
return res |
|||
} |
|||
|
|||
// copied from net/http/server.go
|
|||
|
|||
// foreachHeaderElement splits v according to the "#rule" construction
|
|||
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
|
|||
func foreachHeaderElement(v string, fn func(string)) { |
|||
v = textproto.TrimString(v) |
|||
if v == "" { |
|||
return |
|||
} |
|||
if !strings.Contains(v, ",") { |
|||
fn(v) |
|||
return |
|||
} |
|||
for _, f := range strings.Split(v, ",") { |
|||
if f = textproto.TrimString(f); f != "" { |
|||
fn(f) |
|||
} |
|||
} |
|||
} |
|||
@ -1,108 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"bytes" |
|||
"net/http" |
|||
"strconv" |
|||
"strings" |
|||
"sync" |
|||
|
|||
quic "github.com/lucas-clemente/quic-go" |
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"golang.org/x/net/http2" |
|||
"golang.org/x/net/http2/hpack" |
|||
) |
|||
|
|||
type responseWriter struct { |
|||
dataStreamID protocol.StreamID |
|||
dataStream quic.Stream |
|||
|
|||
headerStream quic.Stream |
|||
headerStreamMutex *sync.Mutex |
|||
|
|||
header http.Header |
|||
status int // status code passed to WriteHeader
|
|||
headerWritten bool |
|||
} |
|||
|
|||
func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter { |
|||
return &responseWriter{ |
|||
header: http.Header{}, |
|||
headerStream: headerStream, |
|||
headerStreamMutex: headerStreamMutex, |
|||
dataStream: dataStream, |
|||
dataStreamID: dataStreamID, |
|||
} |
|||
} |
|||
|
|||
func (w *responseWriter) Header() http.Header { |
|||
return w.header |
|||
} |
|||
|
|||
func (w *responseWriter) WriteHeader(status int) { |
|||
if w.headerWritten { |
|||
return |
|||
} |
|||
w.headerWritten = true |
|||
w.status = status |
|||
|
|||
var headers bytes.Buffer |
|||
enc := hpack.NewEncoder(&headers) |
|||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) |
|||
|
|||
for k, v := range w.header { |
|||
for index := range v { |
|||
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) |
|||
} |
|||
} |
|||
|
|||
utils.Infof("Responding with %d", status) |
|||
w.headerStreamMutex.Lock() |
|||
defer w.headerStreamMutex.Unlock() |
|||
h2framer := http2.NewFramer(w.headerStream, nil) |
|||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{ |
|||
StreamID: uint32(w.dataStreamID), |
|||
EndHeaders: true, |
|||
BlockFragment: headers.Bytes(), |
|||
}) |
|||
if err != nil { |
|||
utils.Errorf("could not write h2 header: %s", err.Error()) |
|||
} |
|||
} |
|||
|
|||
func (w *responseWriter) Write(p []byte) (int, error) { |
|||
if !w.headerWritten { |
|||
w.WriteHeader(200) |
|||
} |
|||
if !bodyAllowedForStatus(w.status) { |
|||
return 0, http.ErrBodyNotAllowed |
|||
} |
|||
return w.dataStream.Write(p) |
|||
} |
|||
|
|||
func (w *responseWriter) Flush() {} |
|||
|
|||
// TODO: Implement a functional CloseNotify method.
|
|||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) } |
|||
|
|||
// test that we implement http.Flusher
|
|||
var _ http.Flusher = &responseWriter{} |
|||
|
|||
// test that we implement http.CloseNotifier
|
|||
var _ http.CloseNotifier = &responseWriter{} |
|||
|
|||
// copied from http2/http2.go
|
|||
// bodyAllowedForStatus reports whether a given response status code
|
|||
// permits a body. See RFC 2616, section 4.4.
|
|||
func bodyAllowedForStatus(status int) bool { |
|||
switch { |
|||
case status >= 100 && status <= 199: |
|||
return false |
|||
case status == 204: |
|||
return false |
|||
case status == 304: |
|||
return false |
|||
} |
|||
return true |
|||
} |
|||
@ -1,168 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"net/http" |
|||
"strings" |
|||
"sync" |
|||
|
|||
quic "github.com/lucas-clemente/quic-go" |
|||
|
|||
"golang.org/x/net/lex/httplex" |
|||
) |
|||
|
|||
type roundTripCloser interface { |
|||
http.RoundTripper |
|||
io.Closer |
|||
} |
|||
|
|||
// RoundTripper implements the http.RoundTripper interface
|
|||
type RoundTripper struct { |
|||
mutex sync.Mutex |
|||
|
|||
// DisableCompression, if true, prevents the Transport from
|
|||
// requesting compression with an "Accept-Encoding: gzip"
|
|||
// request header when the Request contains no existing
|
|||
// Accept-Encoding value. If the Transport requests gzip on
|
|||
// its own and gets a gzipped response, it's transparently
|
|||
// decoded in the Response.Body. However, if the user
|
|||
// explicitly requested gzip it is not automatically
|
|||
// uncompressed.
|
|||
DisableCompression bool |
|||
|
|||
// TLSClientConfig specifies the TLS configuration to use with
|
|||
// tls.Client. If nil, the default configuration is used.
|
|||
TLSClientConfig *tls.Config |
|||
|
|||
// QuicConfig is the quic.Config used for dialing new connections.
|
|||
// If nil, reasonable default values will be used.
|
|||
QuicConfig *quic.Config |
|||
|
|||
clients map[string]roundTripCloser |
|||
} |
|||
|
|||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
|||
type RoundTripOpt struct { |
|||
// OnlyCachedConn controls whether the RoundTripper may
|
|||
// create a new QUIC connection. If set true and
|
|||
// no cached connection is available, RoundTrip
|
|||
// will return ErrNoCachedConn.
|
|||
OnlyCachedConn bool |
|||
} |
|||
|
|||
var _ roundTripCloser = &RoundTripper{} |
|||
|
|||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
|||
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available") |
|||
|
|||
// RoundTripOpt is like RoundTrip, but takes options.
|
|||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { |
|||
if req.URL == nil { |
|||
closeRequestBody(req) |
|||
return nil, errors.New("quic: nil Request.URL") |
|||
} |
|||
if req.URL.Host == "" { |
|||
closeRequestBody(req) |
|||
return nil, errors.New("quic: no Host in request URL") |
|||
} |
|||
if req.Header == nil { |
|||
closeRequestBody(req) |
|||
return nil, errors.New("quic: nil Request.Header") |
|||
} |
|||
|
|||
if req.URL.Scheme == "https" { |
|||
for k, vv := range req.Header { |
|||
if !httplex.ValidHeaderFieldName(k) { |
|||
return nil, fmt.Errorf("quic: invalid http header field name %q", k) |
|||
} |
|||
for _, v := range vv { |
|||
if !httplex.ValidHeaderFieldValue(v) { |
|||
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) |
|||
} |
|||
} |
|||
} |
|||
} else { |
|||
closeRequestBody(req) |
|||
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) |
|||
} |
|||
|
|||
if req.Method != "" && !validMethod(req.Method) { |
|||
closeRequestBody(req) |
|||
return nil, fmt.Errorf("quic: invalid method %q", req.Method) |
|||
} |
|||
|
|||
hostname := authorityAddr("https", hostnameFromRequest(req)) |
|||
cl, err := r.getClient(hostname, opt.OnlyCachedConn) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return cl.RoundTrip(req) |
|||
} |
|||
|
|||
// RoundTrip does a round trip.
|
|||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
|||
return r.RoundTripOpt(req, RoundTripOpt{}) |
|||
} |
|||
|
|||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) { |
|||
r.mutex.Lock() |
|||
defer r.mutex.Unlock() |
|||
|
|||
if r.clients == nil { |
|||
r.clients = make(map[string]roundTripCloser) |
|||
} |
|||
|
|||
client, ok := r.clients[hostname] |
|||
if !ok { |
|||
if onlyCached { |
|||
return nil, ErrNoCachedConn |
|||
} |
|||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) |
|||
r.clients[hostname] = client |
|||
} |
|||
return client, nil |
|||
} |
|||
|
|||
// Close closes the QUIC connections that this RoundTripper has used
|
|||
func (r *RoundTripper) Close() error { |
|||
r.mutex.Lock() |
|||
defer r.mutex.Unlock() |
|||
for _, client := range r.clients { |
|||
if err := client.Close(); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
r.clients = nil |
|||
return nil |
|||
} |
|||
|
|||
func closeRequestBody(req *http.Request) { |
|||
if req.Body != nil { |
|||
req.Body.Close() |
|||
} |
|||
} |
|||
|
|||
func validMethod(method string) bool { |
|||
/* |
|||
Method = "OPTIONS" ; Section 9.2 |
|||
| "GET" ; Section 9.3 |
|||
| "HEAD" ; Section 9.4 |
|||
| "POST" ; Section 9.5 |
|||
| "PUT" ; Section 9.6 |
|||
| "DELETE" ; Section 9.7 |
|||
| "TRACE" ; Section 9.8 |
|||
| "CONNECT" ; Section 9.9 |
|||
| extension-method |
|||
extension-method = token |
|||
token = 1*<any CHAR except CTLs or separators> |
|||
*/ |
|||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 |
|||
} |
|||
|
|||
// copied from net/http/http.go
|
|||
func isNotToken(r rune) bool { |
|||
return !httplex.IsTokenRune(r) |
|||
} |
|||
@ -1,382 +0,0 @@ |
|||
package h2quic |
|||
|
|||
import ( |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
"net" |
|||
"net/http" |
|||
"runtime" |
|||
"strconv" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
|
|||
quic "github.com/lucas-clemente/quic-go" |
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
"golang.org/x/net/http2" |
|||
"golang.org/x/net/http2/hpack" |
|||
) |
|||
|
|||
type streamCreator interface { |
|||
quic.Session |
|||
GetOrOpenStream(protocol.StreamID) (quic.Stream, error) |
|||
} |
|||
|
|||
type remoteCloser interface { |
|||
CloseRemote(protocol.ByteCount) |
|||
} |
|||
|
|||
// allows mocking of quic.Listen and quic.ListenAddr
|
|||
var ( |
|||
quicListen = quic.Listen |
|||
quicListenAddr = quic.ListenAddr |
|||
) |
|||
|
|||
// Server is a HTTP2 server listening for QUIC connections.
|
|||
type Server struct { |
|||
*http.Server |
|||
|
|||
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
|
|||
// If nil, it uses reasonable default values.
|
|||
QuicConfig *quic.Config |
|||
|
|||
// Private flag for demo, do not use
|
|||
CloseAfterFirstRequest bool |
|||
|
|||
port uint32 // used atomically
|
|||
|
|||
listenerMutex sync.Mutex |
|||
listener quic.Listener |
|||
|
|||
supportedVersionsAsString string |
|||
} |
|||
|
|||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
|||
func (s *Server) ListenAndServe() error { |
|||
if s.Server == nil { |
|||
return errors.New("use of h2quic.Server without http.Server") |
|||
} |
|||
return s.serveImpl(s.TLSConfig, nil) |
|||
} |
|||
|
|||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
|||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { |
|||
var err error |
|||
certs := make([]tls.Certificate, 1) |
|||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
// We currently only use the cert-related stuff from tls.Config,
|
|||
// so we don't need to make a full copy.
|
|||
config := &tls.Config{ |
|||
Certificates: certs, |
|||
} |
|||
return s.serveImpl(config, nil) |
|||
} |
|||
|
|||
// Serve an existing UDP connection.
|
|||
func (s *Server) Serve(conn net.PacketConn) error { |
|||
return s.serveImpl(s.TLSConfig, conn) |
|||
} |
|||
|
|||
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { |
|||
if s.Server == nil { |
|||
return errors.New("use of h2quic.Server without http.Server") |
|||
} |
|||
s.listenerMutex.Lock() |
|||
if s.listener != nil { |
|||
s.listenerMutex.Unlock() |
|||
return errors.New("ListenAndServe may only be called once") |
|||
} |
|||
|
|||
var ln quic.Listener |
|||
var err error |
|||
if conn == nil { |
|||
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig) |
|||
} else { |
|||
ln, err = quicListen(conn, tlsConfig, s.QuicConfig) |
|||
} |
|||
if err != nil { |
|||
s.listenerMutex.Unlock() |
|||
return err |
|||
} |
|||
s.listener = ln |
|||
s.listenerMutex.Unlock() |
|||
|
|||
for { |
|||
sess, err := ln.Accept() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
go s.handleHeaderStream(sess.(streamCreator)) |
|||
} |
|||
} |
|||
|
|||
func (s *Server) handleHeaderStream(session streamCreator) { |
|||
stream, err := session.AcceptStream() |
|||
if err != nil { |
|||
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) |
|||
return |
|||
} |
|||
if stream.StreamID() != 3 { |
|||
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3")) |
|||
return |
|||
} |
|||
|
|||
hpackDecoder := hpack.NewDecoder(4096, nil) |
|||
h2framer := http2.NewFramer(nil, stream) |
|||
|
|||
go func() { |
|||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
|||
for { |
|||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil { |
|||
// QuicErrors must originate from stream.Read() returning an error.
|
|||
// In this case, the session has already logged the error, so we don't
|
|||
// need to log it again.
|
|||
if _, ok := err.(*qerr.QuicError); !ok { |
|||
utils.Errorf("error handling h2 request: %s", err.Error()) |
|||
} |
|||
session.Close(err) |
|||
return |
|||
} |
|||
} |
|||
}() |
|||
} |
|||
|
|||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { |
|||
h2frame, err := h2framer.ReadFrame() |
|||
if err != nil { |
|||
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") |
|||
} |
|||
h2headersFrame, ok := h2frame.(*http2.HeadersFrame) |
|||
if !ok { |
|||
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame") |
|||
} |
|||
if !h2headersFrame.HeadersEnded() { |
|||
return errors.New("http2 header continuation not implemented") |
|||
} |
|||
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment()) |
|||
if err != nil { |
|||
utils.Errorf("invalid http2 headers encoding: %s", err.Error()) |
|||
return err |
|||
} |
|||
|
|||
req, err := requestFromHeaders(headers) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
req.RemoteAddr = session.RemoteAddr().String() |
|||
|
|||
if utils.Debug() { |
|||
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) |
|||
} else { |
|||
utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) |
|||
} |
|||
|
|||
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID)) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
|
|||
if dataStream == nil { |
|||
return nil |
|||
} |
|||
|
|||
var streamEnded bool |
|||
if h2headersFrame.StreamEnded() { |
|||
dataStream.(remoteCloser).CloseRemote(0) |
|||
streamEnded = true |
|||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
|||
} |
|||
|
|||
reqBody := newRequestBody(dataStream) |
|||
req.Body = reqBody |
|||
|
|||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) |
|||
|
|||
go func() { |
|||
handler := s.Handler |
|||
if handler == nil { |
|||
handler = http.DefaultServeMux |
|||
} |
|||
panicked := false |
|||
func() { |
|||
defer func() { |
|||
if p := recover(); p != nil { |
|||
// Copied from net/http/server.go
|
|||
const size = 64 << 10 |
|||
buf := make([]byte, size) |
|||
buf = buf[:runtime.Stack(buf, false)] |
|||
utils.Errorf("http: panic serving: %v\n%s", p, buf) |
|||
panicked = true |
|||
} |
|||
}() |
|||
handler.ServeHTTP(responseWriter, req) |
|||
}() |
|||
if panicked { |
|||
responseWriter.WriteHeader(500) |
|||
} else { |
|||
responseWriter.WriteHeader(200) |
|||
} |
|||
if responseWriter.dataStream != nil { |
|||
if !streamEnded && !reqBody.requestRead { |
|||
responseWriter.dataStream.Reset(nil) |
|||
} |
|||
responseWriter.dataStream.Close() |
|||
} |
|||
if s.CloseAfterFirstRequest { |
|||
time.Sleep(100 * time.Millisecond) |
|||
session.Close(nil) |
|||
} |
|||
}() |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
|||
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
|||
func (s *Server) Close() error { |
|||
s.listenerMutex.Lock() |
|||
defer s.listenerMutex.Unlock() |
|||
if s.listener != nil { |
|||
err := s.listener.Close() |
|||
s.listener = nil |
|||
return err |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
|
|||
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
|||
func (s *Server) CloseGracefully(timeout time.Duration) error { |
|||
// TODO: implement
|
|||
return nil |
|||
} |
|||
|
|||
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
|
|||
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
|
|||
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
|
|||
func (s *Server) SetQuicHeaders(hdr http.Header) error { |
|||
port := atomic.LoadUint32(&s.port) |
|||
|
|||
if port == 0 { |
|||
// Extract port from s.Server.Addr
|
|||
_, portStr, err := net.SplitHostPort(s.Server.Addr) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
portInt, err := net.LookupPort("tcp", portStr) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
port = uint32(portInt) |
|||
atomic.StoreUint32(&s.port, port) |
|||
} |
|||
|
|||
if s.supportedVersionsAsString == "" { |
|||
for i, v := range protocol.SupportedVersions { |
|||
s.supportedVersionsAsString += strconv.Itoa(int(v)) |
|||
if i != len(protocol.SupportedVersions)-1 { |
|||
s.supportedVersionsAsString += "," |
|||
} |
|||
} |
|||
} |
|||
|
|||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString)) |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
|||
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
|
|||
// used when handler is nil.
|
|||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { |
|||
server := &Server{ |
|||
Server: &http.Server{ |
|||
Addr: addr, |
|||
Handler: handler, |
|||
}, |
|||
} |
|||
return server.ListenAndServeTLS(certFile, keyFile) |
|||
} |
|||
|
|||
// ListenAndServe listens on the given network address for both, TLS and QUIC
|
|||
// connetions in parallel. It returns if one of the two returns an error.
|
|||
// http.DefaultServeMux is used when handler is nil.
|
|||
// The correct Alt-Svc headers for QUIC are set.
|
|||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error { |
|||
// Load certs
|
|||
var err error |
|||
certs := make([]tls.Certificate, 1) |
|||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
// We currently only use the cert-related stuff from tls.Config,
|
|||
// so we don't need to make a full copy.
|
|||
config := &tls.Config{ |
|||
Certificates: certs, |
|||
} |
|||
|
|||
// Open the listeners
|
|||
udpAddr, err := net.ResolveUDPAddr("udp", addr) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
udpConn, err := net.ListenUDP("udp", udpAddr) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
defer udpConn.Close() |
|||
|
|||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
tcpConn, err := net.ListenTCP("tcp", tcpAddr) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
defer tcpConn.Close() |
|||
|
|||
// Start the servers
|
|||
httpServer := &http.Server{ |
|||
Addr: addr, |
|||
TLSConfig: config, |
|||
} |
|||
|
|||
quicServer := &Server{ |
|||
Server: httpServer, |
|||
} |
|||
|
|||
if handler == nil { |
|||
handler = http.DefaultServeMux |
|||
} |
|||
httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
quicServer.SetQuicHeaders(w.Header()) |
|||
handler.ServeHTTP(w, r) |
|||
}) |
|||
|
|||
hErr := make(chan error) |
|||
qErr := make(chan error) |
|||
go func() { |
|||
hErr <- httpServer.Serve(tcpConn) |
|||
}() |
|||
go func() { |
|||
qErr <- quicServer.Serve(udpConn) |
|||
}() |
|||
|
|||
select { |
|||
case err := <-hErr: |
|||
quicServer.Close() |
|||
return err |
|||
case err := <-qErr: |
|||
// Cannot close the HTTP server or wait for requests to complete properly :/
|
|||
return err |
|||
} |
|||
} |
|||
@ -1,265 +0,0 @@ |
|||
package handshake |
|||
|
|||
import ( |
|||
"bytes" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/lucas-clemente/quic-go/internal/utils" |
|||
"github.com/lucas-clemente/quic-go/protocol" |
|||
"github.com/lucas-clemente/quic-go/qerr" |
|||
) |
|||
|
|||
// ConnectionParametersManager negotiates and stores the connection parameters
|
|||
// A ConnectionParametersManager can be used for a server as well as a client
|
|||
// For the server:
|
|||
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
|
|||
// 2. call GetHelloMap to get the values to send in the SHLO
|
|||
// For the client:
|
|||
// 1. call GetHelloMap to get the values to send in a CHLO
|
|||
// 2. call SetFromMap with the values received in the SHLO
|
|||
type ConnectionParametersManager interface { |
|||
SetFromMap(map[Tag][]byte) error |
|||
GetHelloMap() (map[Tag][]byte, error) |
|||
|
|||
GetSendStreamFlowControlWindow() protocol.ByteCount |
|||
GetSendConnectionFlowControlWindow() protocol.ByteCount |
|||
GetReceiveStreamFlowControlWindow() protocol.ByteCount |
|||
GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount |
|||
GetReceiveConnectionFlowControlWindow() protocol.ByteCount |
|||
GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount |
|||
GetMaxOutgoingStreams() uint32 |
|||
GetMaxIncomingStreams() uint32 |
|||
GetIdleConnectionStateLifetime() time.Duration |
|||
TruncateConnectionID() bool |
|||
} |
|||
|
|||
type connectionParametersManager struct { |
|||
mutex sync.RWMutex |
|||
|
|||
version protocol.VersionNumber |
|||
perspective protocol.Perspective |
|||
|
|||
flowControlNegotiated bool |
|||
|
|||
truncateConnectionID bool |
|||
maxStreamsPerConnection uint32 |
|||
maxIncomingDynamicStreamsPerConnection uint32 |
|||
idleConnectionStateLifetime time.Duration |
|||
sendStreamFlowControlWindow protocol.ByteCount |
|||
sendConnectionFlowControlWindow protocol.ByteCount |
|||
receiveStreamFlowControlWindow protocol.ByteCount |
|||
receiveConnectionFlowControlWindow protocol.ByteCount |
|||
maxReceiveStreamFlowControlWindow protocol.ByteCount |
|||
maxReceiveConnectionFlowControlWindow protocol.ByteCount |
|||
} |
|||
|
|||
var _ ConnectionParametersManager = &connectionParametersManager{} |
|||
|
|||
// ErrMalformedTag is returned when the tag value cannot be read
|
|||
var ( |
|||
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") |
|||
ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported") |
|||
) |
|||
|
|||
// NewConnectionParamatersManager creates a new connection parameters manager
|
|||
func NewConnectionParamatersManager( |
|||
pers protocol.Perspective, v protocol.VersionNumber, |
|||
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount, |
|||
) ConnectionParametersManager { |
|||
h := &connectionParametersManager{ |
|||
perspective: pers, |
|||
version: v, |
|||
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
|
|||
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
|
|||
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, |
|||
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, |
|||
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, |
|||
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, |
|||
} |
|||
|
|||
if h.perspective == protocol.PerspectiveServer { |
|||
h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout |
|||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
|||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective
|
|||
} else { |
|||
h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient |
|||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
|||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective
|
|||
} |
|||
|
|||
return h |
|||
} |
|||
|
|||
// SetFromMap reads all params
|
|||
func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error { |
|||
h.mutex.Lock() |
|||
defer h.mutex.Unlock() |
|||
|
|||
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer { |
|||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) |
|||
if err != nil { |
|||
return ErrMalformedTag |
|||
} |
|||
h.truncateConnectionID = (clientValue == 0) |
|||
} |
|||
if value, ok := params[TagMSPC]; ok { |
|||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) |
|||
if err != nil { |
|||
return ErrMalformedTag |
|||
} |
|||
h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) |
|||
} |
|||
if value, ok := params[TagMIDS]; ok { |
|||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) |
|||
if err != nil { |
|||
return ErrMalformedTag |
|||
} |
|||
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue) |
|||
} |
|||
if value, ok := params[TagICSL]; ok { |
|||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) |
|||
if err != nil { |
|||
return ErrMalformedTag |
|||
} |
|||
h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second) |
|||
} |
|||
if value, ok := params[TagSFCW]; ok { |
|||
if h.flowControlNegotiated { |
|||
return ErrFlowControlRenegotiationNotSupported |
|||
} |
|||
sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) |
|||
if err != nil { |
|||
return ErrMalformedTag |
|||
} |
|||
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow) |
|||
} |
|||
if value, ok := params[TagCFCW]; ok { |
|||
if h.flowControlNegotiated { |
|||
return ErrFlowControlRenegotiationNotSupported |
|||
} |
|||
sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) |
|||
if err != nil { |
|||
return ErrMalformedTag |
|||
} |
|||
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow) |
|||
} |
|||
|
|||
_, containsSFCW := params[TagSFCW] |
|||
_, containsCFCW := params[TagCFCW] |
|||
if containsCFCW || containsSFCW { |
|||
h.flowControlNegotiated = true |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { |
|||
return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) |
|||
} |
|||
|
|||
func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { |
|||
return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) |
|||
} |
|||
|
|||
func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { |
|||
if h.perspective == protocol.PerspectiveServer { |
|||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer) |
|||
} |
|||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient) |
|||
} |
|||
|
|||
// GetHelloMap gets all parameters needed for the Hello message
|
|||
func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) { |
|||
sfcw := bytes.NewBuffer([]byte{}) |
|||
utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) |
|||
cfcw := bytes.NewBuffer([]byte{}) |
|||
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) |
|||
mspc := bytes.NewBuffer([]byte{}) |
|||
utils.WriteUint32(mspc, h.maxStreamsPerConnection) |
|||
mids := bytes.NewBuffer([]byte{}) |
|||
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection) |
|||
icsl := bytes.NewBuffer([]byte{}) |
|||
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second)) |
|||
|
|||
return map[Tag][]byte{ |
|||
TagICSL: icsl.Bytes(), |
|||
TagMSPC: mspc.Bytes(), |
|||
TagMIDS: mids.Bytes(), |
|||
TagCFCW: cfcw.Bytes(), |
|||
TagSFCW: sfcw.Bytes(), |
|||
}, nil |
|||
} |
|||
|
|||
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|||
func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { |
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
return h.sendStreamFlowControlWindow |
|||
} |
|||
|
|||
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|||
func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { |
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
return h.sendConnectionFlowControlWindow |
|||
} |
|||
|
|||
// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
|||
func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { |
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
return h.receiveStreamFlowControlWindow |
|||
} |
|||
|
|||
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
|||
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { |
|||
return h.maxReceiveStreamFlowControlWindow |
|||
} |
|||
|
|||
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
|||
func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { |
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
return h.receiveConnectionFlowControlWindow |
|||
} |
|||
|
|||
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
|||
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { |
|||
return h.maxReceiveConnectionFlowControlWindow |
|||
} |
|||
|
|||
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
|
|||
func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 { |
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
|
|||
return h.maxIncomingDynamicStreamsPerConnection |
|||
} |
|||
|
|||
// GetMaxIncomingStreams get the maximum number of incoming streams per connection
|
|||
func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 { |
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
|
|||
maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection |
|||
return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier)) |
|||
} |
|||
|
|||
// GetIdleConnectionStateLifetime gets the idle timeout
|
|||
func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { |
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
return h.idleConnectionStateLifetime |
|||
} |
|||
|
|||
// TruncateConnectionID determines if the client requests truncated ConnectionIDs
|
|||
func (h *connectionParametersManager) TruncateConnectionID() bool { |
|||
if h.perspective == protocol.PerspectiveClient { |
|||
return false |
|||
} |
|||
|
|||
h.mutex.RLock() |
|||
defer h.mutex.RUnlock() |
|||
return h.truncateConnectionID |
|||
} |
|||
@ -1,100 +0,0 @@ |
|||
package handshake |
|||
|
|||
import ( |
|||
"encoding/asn1" |
|||
"fmt" |
|||
"net" |
|||
"time" |
|||
|
|||
"github.com/lucas-clemente/quic-go/crypto" |
|||
) |
|||
|
|||
const ( |
|||
stkPrefixIP byte = iota |
|||
stkPrefixString |
|||
) |
|||
|
|||
// An STK is a source address token
|
|||
type STK struct { |
|||
RemoteAddr string |
|||
SentTime time.Time |
|||
} |
|||
|
|||
// token is the struct that is used for ASN1 serialization and deserialization
|
|||
type token struct { |
|||
Data []byte |
|||
Timestamp int64 |
|||
} |
|||
|
|||
// An STKGenerator generates STKs
|
|||
type STKGenerator struct { |
|||
stkSource crypto.StkSource |
|||
} |
|||
|
|||
// NewSTKGenerator initializes a new STKGenerator
|
|||
func NewSTKGenerator() (*STKGenerator, error) { |
|||
stkSource, err := crypto.NewStkSource() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return &STKGenerator{ |
|||
stkSource: stkSource, |
|||
}, nil |
|||
} |
|||
|
|||
// NewToken generates a new STK token for a given source address
|
|||
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) { |
|||
data, err := asn1.Marshal(token{ |
|||
Data: encodeRemoteAddr(raddr), |
|||
Timestamp: time.Now().Unix(), |
|||
}) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return g.stkSource.NewToken(data) |
|||
} |
|||
|
|||
// DecodeToken decodes an STK token
|
|||
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) { |
|||
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
|
|||
if len(encrypted) == 0 { |
|||
return nil, nil |
|||
} |
|||
|
|||
data, err := g.stkSource.DecodeToken(encrypted) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
t := &token{} |
|||
rest, err := asn1.Unmarshal(data, t) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if len(rest) != 0 { |
|||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) |
|||
} |
|||
return &STK{ |
|||
RemoteAddr: decodeRemoteAddr(t.Data), |
|||
SentTime: time.Unix(t.Timestamp, 0), |
|||
}, nil |
|||
} |
|||
|
|||
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
|||
func encodeRemoteAddr(remoteAddr net.Addr) []byte { |
|||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { |
|||
return append([]byte{stkPrefixIP}, udpAddr.IP...) |
|||
} |
|||
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...) |
|||
} |
|||
|
|||
// decodeRemoteAddr decodes the remote address saved in the STK
|
|||
func decodeRemoteAddr(data []byte) string { |
|||
// data will never be empty for an STK that we generated. Check it to be on the safe side
|
|||
if len(data) == 0 { |
|||
return "" |
|||
} |
|||
if data[0] == stkPrefixIP { |
|||
return net.IP(data[1:]).String() |
|||
} |
|||
return string(data[1:]) |
|||
} |
|||
@ -1,9 +1,10 @@ |
|||
package crypto |
|||
|
|||
import "github.com/lucas-clemente/quic-go/protocol" |
|||
import "github.com/lucas-clemente/quic-go/internal/protocol" |
|||
|
|||
// An AEAD implements QUIC's authenticated encryption and associated data
|
|||
type AEAD interface { |
|||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) |
|||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte |
|||
Overhead() int |
|||
} |
|||
@ -0,0 +1,72 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto/cipher" |
|||
"encoding/binary" |
|||
"errors" |
|||
|
|||
"github.com/lucas-clemente/aes12" |
|||
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol" |
|||
) |
|||
|
|||
type aeadAESGCM12 struct { |
|||
otherIV []byte |
|||
myIV []byte |
|||
encrypter cipher.AEAD |
|||
decrypter cipher.AEAD |
|||
} |
|||
|
|||
var _ AEAD = &aeadAESGCM12{} |
|||
|
|||
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
|
|||
//
|
|||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
|||
// tag size, and couples the cipher and aes packages closely.
|
|||
// See https://github.com/lucas-clemente/aes12.
|
|||
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { |
|||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 { |
|||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs") |
|||
} |
|||
encrypterCipher, err := aes12.NewCipher(myKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
encrypter, err := aes12.NewGCM(encrypterCipher) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
decrypterCipher, err := aes12.NewCipher(otherKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
decrypter, err := aes12.NewGCM(decrypterCipher) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return &aeadAESGCM12{ |
|||
otherIV: otherIV, |
|||
myIV: myIV, |
|||
encrypter: encrypter, |
|||
decrypter: decrypter, |
|||
}, nil |
|||
} |
|||
|
|||
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { |
|||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) |
|||
} |
|||
|
|||
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { |
|||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) |
|||
} |
|||
|
|||
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { |
|||
res := make([]byte, 12) |
|||
copy(res[0:4], iv) |
|||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) |
|||
return res |
|||
} |
|||
|
|||
func (aead *aeadAESGCM12) Overhead() int { |
|||
return aead.encrypter.Overhead() |
|||
} |
|||
@ -0,0 +1,74 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto/aes" |
|||
"crypto/cipher" |
|||
"encoding/binary" |
|||
"errors" |
|||
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol" |
|||
) |
|||
|
|||
type aeadAESGCM struct { |
|||
otherIV []byte |
|||
myIV []byte |
|||
encrypter cipher.AEAD |
|||
decrypter cipher.AEAD |
|||
} |
|||
|
|||
var _ AEAD = &aeadAESGCM{} |
|||
|
|||
const ivLen = 12 |
|||
|
|||
// NewAEADAESGCM creates a AEAD using AES-GCM
|
|||
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { |
|||
// the IVs need to be at least 8 bytes long, otherwise we can't compute the nonce
|
|||
if len(otherIV) != ivLen || len(myIV) != ivLen { |
|||
return nil, errors.New("AES-GCM: expected 12 byte IVs") |
|||
} |
|||
|
|||
encrypterCipher, err := aes.NewCipher(myKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
encrypter, err := cipher.NewGCM(encrypterCipher) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
decrypterCipher, err := aes.NewCipher(otherKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
decrypter, err := cipher.NewGCM(decrypterCipher) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return &aeadAESGCM{ |
|||
otherIV: otherIV, |
|||
myIV: myIV, |
|||
encrypter: encrypter, |
|||
decrypter: decrypter, |
|||
}, nil |
|||
} |
|||
|
|||
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { |
|||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) |
|||
} |
|||
|
|||
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { |
|||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) |
|||
} |
|||
|
|||
func (aead *aeadAESGCM) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { |
|||
nonce := make([]byte, ivLen) |
|||
binary.BigEndian.PutUint64(nonce[ivLen-8:], uint64(packetNumber)) |
|||
for i := 0; i < ivLen; i++ { |
|||
nonce[i] ^= iv[i] |
|||
} |
|||
return nonce |
|||
} |
|||
|
|||
func (aead *aeadAESGCM) Overhead() int { |
|||
return aead.encrypter.Overhead() |
|||
} |
|||
@ -0,0 +1,49 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"github.com/bifurcation/mint" |
|||
"github.com/lucas-clemente/quic-go/internal/protocol" |
|||
) |
|||
|
|||
const ( |
|||
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret" |
|||
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" |
|||
) |
|||
|
|||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
|||
type TLSExporter interface { |
|||
GetCipherSuite() mint.CipherSuiteParams |
|||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) |
|||
} |
|||
|
|||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
|||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { |
|||
var myLabel, otherLabel string |
|||
if pers == protocol.PerspectiveClient { |
|||
myLabel = clientExporterLabel |
|||
otherLabel = serverExporterLabel |
|||
} else { |
|||
myLabel = serverExporterLabel |
|||
otherLabel = clientExporterLabel |
|||
} |
|||
myKey, myIV, err := computeKeyAndIV(tls, myLabel) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) |
|||
} |
|||
|
|||
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) { |
|||
cs := tls.GetCipherSuite() |
|||
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size()) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen) |
|||
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen) |
|||
return key, iv, nil |
|||
} |
|||
@ -0,0 +1,11 @@ |
|||
package crypto |
|||
|
|||
import "github.com/lucas-clemente/quic-go/internal/protocol" |
|||
|
|||
// NewNullAEAD creates a NullAEAD
|
|||
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) { |
|||
if v.UsesTLS() { |
|||
return newNullAEADAESGCM(connID, p) |
|||
} |
|||
return &nullAEADFNV128a{perspective: p}, nil |
|||
} |
|||
@ -0,0 +1,44 @@ |
|||
package crypto |
|||
|
|||
import ( |
|||
"crypto" |
|||
"encoding/binary" |
|||
|
|||
"github.com/bifurcation/mint" |
|||
"github.com/lucas-clemente/quic-go/internal/protocol" |
|||
) |
|||
|
|||
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39} |
|||
|
|||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) { |
|||
clientSecret, serverSecret := computeSecrets(connectionID) |
|||
|
|||
var mySecret, otherSecret []byte |
|||
if pers == protocol.PerspectiveClient { |
|||
mySecret = clientSecret |
|||
otherSecret = serverSecret |
|||
} else { |
|||
mySecret = serverSecret |
|||
otherSecret = clientSecret |
|||
} |
|||
|
|||
myKey, myIV := computeNullAEADKeyAndIV(mySecret) |
|||
otherKey, otherIV := computeNullAEADKeyAndIV(otherSecret) |
|||
|
|||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) |
|||
} |
|||
|
|||
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { |
|||
connID := make([]byte, 8) |
|||
binary.BigEndian.PutUint64(connID, uint64(connectionID)) |
|||
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID) |
|||
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size()) |
|||
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size()) |
|||
return |
|||
} |
|||
|
|||
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { |
|||
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16) |
|||
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12) |
|||
return |
|||
} |
|||
Some files were not shown because too many files changed in this diff
Loading…
Reference in new issue