You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

186 lines
4.9 KiB

package clientcore
import (
"bytes"
"encoding/binary"
"strings"
"testing"
)
// TestWrapConnRoundTrip wraps a payload with a client-side wrapConn and checks
// the leading header bytes against wrapRTPVersion and wrapRTPPT. It also checks
// that the plaintext does not appear verbatim in the wrapped packet. Finally it
// unwraps the packet with a server-side wrapConn built from the same key and
// expects the original payload back.
func TestWrapConnRoundTrip(t *testing.T) {
	key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
	payload := []byte("dtls record bytes")

	client, err := newWrapConn(key, false)
	if err != nil {
		t.Fatalf("newWrapConn(client): %v", err)
	}
	server, err := newWrapConn(key, true)
	if err != nil {
		t.Fatalf("newWrapConn(server): %v", err)
	}

	buf := make([]byte, wrapMaxWire(len(payload)))
	written, err := client.wrapInto(buf, payload)
	if err != nil {
		t.Fatalf("wrapInto returned error: %v", err)
	}
	packet := buf[:written]

	// Byte 0 must be the RTP version byte and byte 1 the payload type. The
	// payload itself must not be visible on the wire.
	switch {
	case packet[0] != wrapRTPVersion:
		t.Fatalf("RTP byte0 = 0x%02X, want 0x%02X", packet[0], wrapRTPVersion)
	case packet[1] != wrapRTPPT:
		t.Fatalf("RTP byte1 (PT) = 0x%02X, want 0x%02X", packet[1], wrapRTPPT)
	case bytes.Contains(packet, payload):
		t.Fatalf("wrapped packet contains plaintext payload")
	}

	out := make([]byte, len(payload))
	got, err := server.unwrapPacket(packet, out)
	if err != nil {
		t.Fatalf("unwrapPacket returned error: %v", err)
	}
	if got != len(payload) {
		t.Fatalf("unwrapped len = %d, want %d", got, len(payload))
	}
	if plain := out[:got]; !bytes.Equal(plain, payload) {
		t.Fatalf("round trip mismatch: got %q want %q", plain, payload)
	}
}
func TestWrapRTPHeaderProgression(t *testing.T) {
key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
wc, err := newWrapConn(key, false)
if err != nil {
t.Fatalf("newWrapConn: %v", err)
}
payload := []byte("x")
wire1 := make([]byte, wrapMaxWire(len(payload)))
n1, err := wc.wrapInto(wire1, payload)
if err != nil {
t.Fatalf("wrapInto 1: %v", err)
}
wire2 := make([]byte, wrapMaxWire(len(payload)))
n2, err := wc.wrapInto(wire2, payload)
if err != nil {
t.Fatalf("wrapInto 2: %v", err)
}
if n1 != n2 {
t.Fatalf("wire size variance: %d vs %d", n1, n2)
}
seq1 := binary.BigEndian.Uint16(wire1[2:4])
seq2 := binary.BigEndian.Uint16(wire2[2:4])
if seq2 != seq1+1 {
t.Fatalf("seq did not increment: %d -> %d", seq1, seq2)
}
ts1 := binary.BigEndian.Uint32(wire1[4:8])
ts2 := binary.BigEndian.Uint32(wire2[4:8])
if ts2-ts1 != wrapTSStep {
t.Fatalf("timestamp step = %d, want %d", ts2-ts1, wrapTSStep)
}
if !bytes.Equal(wire1[8:12], wire2[8:12]) {
t.Fatalf("SSRC changed between packets")
}
}
// TestWrapDirectionBit checks the most significant bit of the first byte of
// sessionID and of ssrc. It must be clear for a client-side wrapConn
// (isServer=false) and set for a server-side one (isServer=true).
func TestWrapDirectionBit(t *testing.T) {
	key := bytes.Repeat([]byte{0x42}, wrapKeyLen)

	client, err := newWrapConn(key, false)
	if err != nil {
		t.Fatalf("newWrapConn(client): %v", err)
	}
	server, err := newWrapConn(key, true)
	if err != nil {
		t.Fatalf("newWrapConn(server): %v", err)
	}

	// Checked in a fixed order so the first failure reported is deterministic.
	checks := []struct {
		label string
		first byte
		want  int // expected value of bit 7: 0 or 1
	}{
		{"client sessionID", client.sessionID[0], 0},
		{"server sessionID", server.sessionID[0], 1},
		{"client SSRC", client.ssrc[0], 0},
		{"server SSRC", server.ssrc[0], 1},
	}
	for _, c := range checks {
		got := 0
		if c.first&0x80 != 0 {
			got = 1
		}
		if got != c.want {
			t.Fatalf("%s MSB should be %d, got 0x%02X", c.label, c.want, c.first)
		}
	}
}
// TestDecodeWrapKeyRequiresValidKeyWhenEnabled checks decodeWrapKey in three
// situations:
//   - When wrapping is disabled, it returns (nil, nil) for an empty key string.
//   - When enabled, it rejects an empty string and a hex string one byte short
//     of wrapKeyLen.
//   - When enabled, it accepts a wrapKeyLen-byte hex string.
func TestDecodeWrapKeyRequiresValidKeyWhenEnabled(t *testing.T) {
	key, err := decodeWrapKey(false, "")
	if err != nil || key != nil {
		t.Fatalf("disabled decodeWrapKey = (%v, %v), want (nil, nil)", key, err)
	}

	rejected := []struct {
		input string
		label string
	}{
		{"", "empty"},
		{strings.Repeat("ab", wrapKeyLen-1), "short"},
	}
	for _, r := range rejected {
		if _, rerr := decodeWrapKey(true, r.input); rerr == nil {
			t.Fatalf("decodeWrapKey accepted %s key", r.label)
		}
	}

	key, err = decodeWrapKey(true, strings.Repeat("ab", wrapKeyLen))
	if err != nil {
		t.Fatalf("decodeWrapKey returned error: %v", err)
	}
	if got := len(key); got != wrapKeyLen {
		t.Fatalf("decoded key len = %d, want %d", got, wrapKeyLen)
	}
}
func TestUnwrapRejectsShortPacket(t *testing.T) {
key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
wc, err := newWrapConn(key, false)
if err != nil {
t.Fatalf("newWrapConn: %v", err)
}
if _, err := wc.unwrapPacket([]byte("short"), make([]byte, 16)); err == nil {
t.Fatalf("unwrapPacket accepted short packet")
}
}
// TestUnwrapRejectsTamperedPacket checks that a server-side wrapConn refuses
// packets in two cases, each using a freshly wrapped packet:
//   - a byte past the header (in the encrypted region) has been flipped;
//   - header byte 8 (the first SSRC byte) has been flipped. The test expects
//     the header to be authenticated as AAD.
//
// Each case wraps into its own buffer of wrapMaxWire(len(payload)) bytes. A
// previous packet's truncated slice is not reused because its length may be
// smaller than what wrapInto requires.
//
// Before tampering, an untouched copy of the packet is unwrapped on a separate
// fresh server-side wrapConn. This ensures the expected rejection is caused by
// the tampering itself and not by, for example, a key/direction mismatch that
// would reject every packet.
func TestUnwrapRejectsTamperedPacket(t *testing.T) {
	key := bytes.Repeat([]byte{0x42}, wrapKeyLen)
	client, err := newWrapConn(key, false)
	if err != nil {
		t.Fatalf("newWrapConn(client): %v", err)
	}
	server, err := newWrapConn(key, true)
	if err != nil {
		t.Fatalf("newWrapConn(server): %v", err)
	}

	payload := []byte("integrity test")
	dst := make([]byte, 1600)

	cases := []struct {
		name   string
		offset int
		mask   byte
	}{
		{name: "ciphertext", offset: int(wrapHeaderLen) + 1, mask: 0xFF},
		{name: "AAD", offset: 8, mask: 0x01},
	}

	for _, tc := range cases {
		wire := make([]byte, wrapMaxWire(len(payload)))
		n, err := client.wrapInto(wire, payload)
		if err != nil {
			t.Fatalf("wrapInto: %v", err)
		}
		wire = wire[:n]
		if tc.offset >= len(wire) {
			t.Fatalf("%s case: tamper offset %d outside %d-byte packet", tc.name, tc.offset, len(wire))
		}

		// Control: the pristine packet must be accepted. A copy is used in case
		// unwrapPacket works in place, and a separate server is used so the
		// control cannot influence the state of the server under test.
		verifier, err := newWrapConn(key, true)
		if err != nil {
			t.Fatalf("newWrapConn(verifier): %v", err)
		}
		pristine := append([]byte(nil), wire...)
		if _, verr := verifier.unwrapPacket(pristine, dst); verr != nil {
			t.Fatalf("%s case: unwrapPacket rejected untampered packet: %v", tc.name, verr)
		}

		wire[tc.offset] ^= tc.mask
		if _, unwrapErr := server.unwrapPacket(wire, dst); unwrapErr == nil {
			t.Fatalf("unwrapPacket accepted tampered %s", tc.name)
		}
	}
}