You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
4.9 KiB
186 lines
4.9 KiB
package clientcore
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestWrapConnRoundTrip(t *testing.T) {
|
|
key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
|
|
payload := []byte("dtls record bytes")
|
|
|
|
client, err := newWrapConn(key, false)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn(client): %v", err)
|
|
}
|
|
server, err := newWrapConn(key, true)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn(server): %v", err)
|
|
}
|
|
|
|
wire := make([]byte, wrapMaxWire(len(payload)))
|
|
n, err := client.wrapInto(wire, payload)
|
|
if err != nil {
|
|
t.Fatalf("wrapInto returned error: %v", err)
|
|
}
|
|
wire = wire[:n]
|
|
|
|
if wire[0] != wrapRTPVersion {
|
|
t.Fatalf("RTP byte0 = 0x%02X, want 0x%02X", wire[0], wrapRTPVersion)
|
|
}
|
|
if wire[1] != wrapRTPPT {
|
|
t.Fatalf("RTP byte1 (PT) = 0x%02X, want 0x%02X", wire[1], wrapRTPPT)
|
|
}
|
|
if bytes.Contains(wire, payload) {
|
|
t.Fatalf("wrapped packet contains plaintext payload")
|
|
}
|
|
|
|
dst := make([]byte, len(payload))
|
|
n, err = server.unwrapPacket(wire, dst)
|
|
if err != nil {
|
|
t.Fatalf("unwrapPacket returned error: %v", err)
|
|
}
|
|
if n != len(payload) {
|
|
t.Fatalf("unwrapped len = %d, want %d", n, len(payload))
|
|
}
|
|
if !bytes.Equal(dst[:n], payload) {
|
|
t.Fatalf("round trip mismatch: got %q want %q", dst[:n], payload)
|
|
}
|
|
}
|
|
|
|
func TestWrapRTPHeaderProgression(t *testing.T) {
|
|
key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
|
|
wc, err := newWrapConn(key, false)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn: %v", err)
|
|
}
|
|
payload := []byte("x")
|
|
|
|
wire1 := make([]byte, wrapMaxWire(len(payload)))
|
|
n1, err := wc.wrapInto(wire1, payload)
|
|
if err != nil {
|
|
t.Fatalf("wrapInto 1: %v", err)
|
|
}
|
|
wire2 := make([]byte, wrapMaxWire(len(payload)))
|
|
n2, err := wc.wrapInto(wire2, payload)
|
|
if err != nil {
|
|
t.Fatalf("wrapInto 2: %v", err)
|
|
}
|
|
if n1 != n2 {
|
|
t.Fatalf("wire size variance: %d vs %d", n1, n2)
|
|
}
|
|
|
|
seq1 := binary.BigEndian.Uint16(wire1[2:4])
|
|
seq2 := binary.BigEndian.Uint16(wire2[2:4])
|
|
if seq2 != seq1+1 {
|
|
t.Fatalf("seq did not increment: %d -> %d", seq1, seq2)
|
|
}
|
|
|
|
ts1 := binary.BigEndian.Uint32(wire1[4:8])
|
|
ts2 := binary.BigEndian.Uint32(wire2[4:8])
|
|
if ts2-ts1 != wrapTSStep {
|
|
t.Fatalf("timestamp step = %d, want %d", ts2-ts1, wrapTSStep)
|
|
}
|
|
|
|
if !bytes.Equal(wire1[8:12], wire2[8:12]) {
|
|
t.Fatalf("SSRC changed between packets")
|
|
}
|
|
}
|
|
|
|
func TestWrapDirectionBit(t *testing.T) {
|
|
key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
|
|
client, err := newWrapConn(key, false)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn(client): %v", err)
|
|
}
|
|
server, err := newWrapConn(key, true)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn(server): %v", err)
|
|
}
|
|
|
|
if client.sessionID[0]&0x80 != 0 {
|
|
t.Fatalf("client sessionID MSB should be 0, got 0x%02X", client.sessionID[0])
|
|
}
|
|
if server.sessionID[0]&0x80 == 0 {
|
|
t.Fatalf("server sessionID MSB should be 1, got 0x%02X", server.sessionID[0])
|
|
}
|
|
if client.ssrc[0]&0x80 != 0 {
|
|
t.Fatalf("client SSRC MSB should be 0, got 0x%02X", client.ssrc[0])
|
|
}
|
|
if server.ssrc[0]&0x80 == 0 {
|
|
t.Fatalf("server SSRC MSB should be 1, got 0x%02X", server.ssrc[0])
|
|
}
|
|
}
|
|
|
|
func TestDecodeWrapKeyRequiresValidKeyWhenEnabled(t *testing.T) {
|
|
if key, err := decodeWrapKey(false, ""); err != nil || key != nil {
|
|
t.Fatalf("disabled decodeWrapKey = (%v, %v), want (nil, nil)", key, err)
|
|
}
|
|
|
|
if _, err := decodeWrapKey(true, ""); err == nil {
|
|
t.Fatalf("decodeWrapKey accepted empty key")
|
|
}
|
|
|
|
shortHex := strings.Repeat("ab", wrapKeyLen-1)
|
|
if _, err := decodeWrapKey(true, shortHex); err == nil {
|
|
t.Fatalf("decodeWrapKey accepted short key")
|
|
}
|
|
|
|
fullHex := strings.Repeat("ab", wrapKeyLen)
|
|
key, err := decodeWrapKey(true, fullHex)
|
|
if err != nil {
|
|
t.Fatalf("decodeWrapKey returned error: %v", err)
|
|
}
|
|
if len(key) != wrapKeyLen {
|
|
t.Fatalf("decoded key len = %d, want %d", len(key), wrapKeyLen)
|
|
}
|
|
}
|
|
|
|
func TestUnwrapRejectsShortPacket(t *testing.T) {
|
|
key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
|
|
wc, err := newWrapConn(key, false)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn: %v", err)
|
|
}
|
|
if _, err := wc.unwrapPacket([]byte("short"), make([]byte, 16)); err == nil {
|
|
t.Fatalf("unwrapPacket accepted short packet")
|
|
}
|
|
}
|
|
|
|
func TestUnwrapRejectsTamperedPacket(t *testing.T) {
|
|
key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
|
|
client, err := newWrapConn(key, false)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn(client): %v", err)
|
|
}
|
|
server, err := newWrapConn(key, true)
|
|
if err != nil {
|
|
t.Fatalf("newWrapConn(server): %v", err)
|
|
}
|
|
|
|
payload := []byte("integrity test")
|
|
wire := make([]byte, wrapMaxWire(len(payload)))
|
|
n, err := client.wrapInto(wire, payload)
|
|
if err != nil {
|
|
t.Fatalf("wrapInto: %v", err)
|
|
}
|
|
wire = wire[:n]
|
|
wire[wrapHeaderLen+1] ^= 0xFF
|
|
|
|
dst := make([]byte, 1600)
|
|
if _, unwrapErr := server.unwrapPacket(wire, dst); unwrapErr == nil {
|
|
t.Fatalf("unwrapPacket accepted tampered ciphertext")
|
|
}
|
|
|
|
n, err = client.wrapInto(wire, payload)
|
|
if err != nil {
|
|
t.Fatalf("wrapInto: %v", err)
|
|
}
|
|
wire = wire[:n]
|
|
wire[8] ^= 0x01
|
|
if _, unwrapErr := server.unwrapPacket(wire, dst); unwrapErr == nil {
|
|
t.Fatalf("unwrapPacket accepted tampered AAD")
|
|
}
|
|
}
|
|
|