You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
196 lines
5.0 KiB
196 lines
5.0 KiB
// SPDX-License-Identifier: MIT
|
|
|
|
package main
|
|
|
|
import (
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
dtlsnet "github.com/pion/dtls/v3/pkg/net"
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
)
|
|
|
|
// The wire format is identical to the client's. The server sets the MSB of
// sessionID/SSRC; the client clears it. RTP header fields are tracked per connection.
|
|
|
|
// Sizes and fixed header values of the wrapped packet layout:
//
//	[ RTP-style header (12) | nonce (12) | ciphertext | tag (16) ]
const (
	// wrapKeyLen is the required length of the shared AEAD key, in bytes.
	wrapKeyLen = 32

	// Section sizes, in bytes.
	wrapRTPHdrLen = 12
	wrapNonceLen  = 12
	wrapTagLen    = 16

	// wrapHeaderLen is everything that precedes the ciphertext; the whole
	// header (RTP-style fields plus nonce) is used as AEAD additional data.
	wrapHeaderLen = wrapRTPHdrLen + wrapNonceLen
	// wrapOverhead is the number of bytes added to every payload on the wire.
	wrapOverhead = wrapHeaderLen + wrapTagLen

	// wrapRTPVersion is written to byte 0 of the header (0x80 has the top
	// two bits set to binary 10, i.e. RTP version 2 with no flags).
	wrapRTPVersion = 0x80
	// wrapRTPPT is written to byte 1 of the header (payload type 111).
	wrapRTPPT = 0x6F
	// wrapTSStep is how much the header timestamp advances per packet sent.
	// NOTE(review): presumably mimics 20 ms audio frames at 48 kHz; confirm.
	wrapTSStep = 960
)
|
|
|
|
var bufPool = sync.Pool{
|
|
New: func() any {
|
|
b := make([]byte, 1600+wrapOverhead)
|
|
return &b
|
|
},
|
|
}
|
|
|
|
// wrapState carries the AEAD instance shared by a listener and every
// connection it accepts.
type wrapState struct {
	// aead seals outgoing packets and opens incoming ones.
	aead cipher.AEAD
}
|
|
|
|
func newWrapState(key []byte) (*wrapState, error) {
|
|
if len(key) != wrapKeyLen {
|
|
return nil, fmt.Errorf("wrap: key must be %d bytes (got %d)", wrapKeyLen, len(key))
|
|
}
|
|
aead, err := chacha20poly1305.New(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wrap: aead init: %w", err)
|
|
}
|
|
return &wrapState{aead: aead}, nil
|
|
}
|
|
|
|
func listenWrapped(addr *net.UDPAddr, key []byte) (dtlsnet.PacketListener, error) {
|
|
ws, err := newWrapState(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
innerConn, err := listenPacketInfoUDP("udp", addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("wrap: udp listen: %w", err)
|
|
}
|
|
return &wrapPacketListener{
|
|
inner: newUDPPacketListener(innerConn, nil),
|
|
ws: ws,
|
|
}, nil
|
|
}
|
|
|
|
type wrapPacketListener struct {
|
|
inner dtlsnet.PacketListener
|
|
ws *wrapState
|
|
}
|
|
|
|
func (l *wrapPacketListener) Accept() (net.PacketConn, net.Addr, error) {
|
|
pc, addr, err := l.inner.Accept()
|
|
if err != nil {
|
|
return pc, addr, err
|
|
}
|
|
c := &wrapPacketConn{inner: pc, ws: l.ws}
|
|
|
|
var rnd [16]byte
|
|
if _, err := rand.Read(rnd[:]); err != nil {
|
|
return nil, addr, fmt.Errorf("wrap: rand init: %w", err)
|
|
}
|
|
copy(c.sessionID[:], rnd[0:4])
|
|
copy(c.ssrc[:], rnd[4:8])
|
|
c.sessionID[0] |= 0x80
|
|
c.ssrc[0] |= 0x80
|
|
c.seq.Store(uint32(binary.BigEndian.Uint16(rnd[8:10])))
|
|
c.timestamp.Store(binary.BigEndian.Uint32(rnd[10:14]))
|
|
|
|
var cb [8]byte
|
|
if _, err := rand.Read(cb[:]); err != nil {
|
|
return nil, addr, fmt.Errorf("wrap: counter rand: %w", err)
|
|
}
|
|
c.counter.Store(binary.BigEndian.Uint64(cb[:]))
|
|
return c, addr, nil
|
|
}
|
|
|
|
// Close delegates to the inner listener's Close and returns its result.
func (l *wrapPacketListener) Close() error {
	return l.inner.Close()
}
|
|
// Addr reports the address of the inner listener.
func (l *wrapPacketListener) Addr() net.Addr {
	return l.inner.Addr()
}
|
|
|
|
type wrapPacketConn struct {
|
|
inner net.PacketConn
|
|
ws *wrapState
|
|
sessionID [4]byte
|
|
ssrc [4]byte
|
|
counter atomic.Uint64
|
|
seq atomic.Uint32
|
|
timestamp atomic.Uint32
|
|
}
|
|
|
|
func (c *wrapPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
|
bp, ok := bufPool.Get().(*[]byte)
|
|
if !ok {
|
|
return 0, nil, errors.New("wrap: buffer pool returned invalid type")
|
|
}
|
|
buf := *bp
|
|
need := len(p) + wrapOverhead
|
|
if cap(buf) < need {
|
|
buf = make([]byte, need)
|
|
*bp = buf
|
|
}
|
|
defer bufPool.Put(bp)
|
|
|
|
n, addr, err := c.inner.ReadFrom(buf[:cap(buf)])
|
|
if err != nil {
|
|
return 0, addr, err
|
|
}
|
|
wire := buf[:n]
|
|
if len(wire) < wrapOverhead {
|
|
return 0, addr, errors.New("wrap: packet too short")
|
|
}
|
|
nonce := wire[wrapRTPHdrLen : wrapRTPHdrLen+wrapNonceLen]
|
|
aad := wire[:wrapHeaderLen]
|
|
ct := wire[wrapHeaderLen:]
|
|
|
|
plain, err := c.ws.aead.Open(ct[:0], nonce, ct, aad)
|
|
if err != nil {
|
|
return 0, addr, fmt.Errorf("wrap: AEAD open: %w", err)
|
|
}
|
|
if len(plain) > len(p) {
|
|
return 0, addr, errors.New("wrap: dst buffer too small")
|
|
}
|
|
copy(p[:len(plain)], plain)
|
|
return len(plain), addr, nil
|
|
}
|
|
|
|
func (c *wrapPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
|
wireLen := wrapOverhead + len(p)
|
|
|
|
bp, ok := bufPool.Get().(*[]byte)
|
|
if !ok {
|
|
return 0, errors.New("wrap: buffer pool returned invalid type")
|
|
}
|
|
out := *bp
|
|
if cap(out) < wireLen {
|
|
out = make([]byte, wireLen)
|
|
*bp = out
|
|
}
|
|
out = out[:wireLen]
|
|
defer bufPool.Put(bp)
|
|
|
|
out[0] = wrapRTPVersion
|
|
out[1] = wrapRTPPT
|
|
seq := uint16(c.seq.Add(1) - 1)
|
|
binary.BigEndian.PutUint16(out[2:4], seq)
|
|
ts := c.timestamp.Add(wrapTSStep) - wrapTSStep
|
|
binary.BigEndian.PutUint32(out[4:8], ts)
|
|
copy(out[8:12], c.ssrc[:])
|
|
|
|
noncePos := wrapRTPHdrLen
|
|
copy(out[noncePos:noncePos+4], c.sessionID[:])
|
|
ctr := c.counter.Add(1) - 1
|
|
binary.BigEndian.PutUint64(out[noncePos+4:noncePos+wrapNonceLen], ctr)
|
|
|
|
nonce := out[noncePos : noncePos+wrapNonceLen]
|
|
aad := out[:wrapHeaderLen]
|
|
ctPos := wrapHeaderLen
|
|
copy(out[ctPos:], p)
|
|
c.ws.aead.Seal(out[ctPos:ctPos], nonce, out[ctPos:ctPos+len(p)], aad)
|
|
|
|
if _, err := c.inner.WriteTo(out, addr); err != nil {
|
|
return 0, err
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
// Close delegates to the inner connection's Close and returns its result.
func (c *wrapPacketConn) Close() error {
	return c.inner.Close()
}
|
|
// LocalAddr reports the local address of the inner connection.
func (c *wrapPacketConn) LocalAddr() net.Addr {
	return c.inner.LocalAddr()
}
|
|
// SetDeadline forwards the deadline t to the inner connection.
func (c *wrapPacketConn) SetDeadline(t time.Time) error {
	return c.inner.SetDeadline(t)
}
|
|
// SetReadDeadline forwards the read deadline t to the inner connection.
func (c *wrapPacketConn) SetReadDeadline(t time.Time) error {
	return c.inner.SetReadDeadline(t)
}
|
|
// SetWriteDeadline forwards the write deadline t to the inner connection.
func (c *wrapPacketConn) SetWriteDeadline(t time.Time) error {
	return c.inner.SetWriteDeadline(t)
}
|
|
|