16 changed files with 2383 additions and 13 deletions
@ -0,0 +1,208 @@ |
|||
package api |
|||
|
|||
import ( |
|||
"bytes" |
|||
"context" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
"fmt" |
|||
"io" |
|||
"net" |
|||
"time" |
|||
|
|||
fhttp "github.com/bogdanfinn/fhttp" |
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/internal" |
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/models" |
|||
) |
|||
|
|||
// Register creates a new user account by registering a WireGuard public key and generating a random Android-like device identifier.
|
|||
// The WireGuard private key isn't stored anywhere, therefore it won't be usable. It's sole purpose is to mimic the Android app's registration process.
|
|||
//
|
|||
// This function sends a POST request to the API to register a new user and returns the created account data.
|
|||
//
|
|||
// Parameters:
|
|||
// - model: string - The device model string to register. (e.g., "PC")
|
|||
// - locale: string - The user's locale. (e.g., "en-US")
|
|||
// - jwt: string - Team token to register.
|
|||
// - acceptTos: bool - Whether the user accepts the Terms of Service (TOS). If false, the user will be prompted to accept.
|
|||
//
|
|||
// Returns:
|
|||
// - models.AccountData: The account data returned from the registration process.
|
|||
// - error: An error if registration fails at any step.
|
|||
//
|
|||
// Example:
|
|||
//
|
|||
// account, err := Register("PC", "en-US", "", false)
|
|||
// if err != nil {
|
|||
// log.Fatalf("Registration failed: %v", err)
|
|||
// }
|
|||
func Register(model, locale, jwt string, acceptTos bool) (models.AccountData, error) { |
|||
wgKey, err := internal.GenerateRandomWgPubkey() |
|||
if err != nil { |
|||
return models.AccountData{}, fmt.Errorf("failed to generate wg key: %v", err) |
|||
} |
|||
serial, err := internal.GenerateRandomAndroidSerial() |
|||
if err != nil { |
|||
return models.AccountData{}, fmt.Errorf("failed to generate serial: %v", err) |
|||
} |
|||
|
|||
if !acceptTos { |
|||
fmt.Print("You must accept the Terms of Service (https://www.cloudflare.com/application/terms/) to register. Do you agree? (y/n): ") |
|||
var response string |
|||
if _, err := fmt.Scanln(&response); err != nil { |
|||
return models.AccountData{}, fmt.Errorf("failed to read user input: %v", err) |
|||
} |
|||
if response != "y" { |
|||
return models.AccountData{}, fmt.Errorf("user did not accept TOS") |
|||
} |
|||
} |
|||
|
|||
data := models.Registration{ |
|||
Key: wgKey, |
|||
InstallID: "", |
|||
FcmToken: "", |
|||
Tos: internal.TimeAsCfString(time.Now()), |
|||
Model: model, |
|||
Serial: serial, |
|||
OsVersion: "", |
|||
KeyType: internal.KeyTypeWg, |
|||
TunType: internal.TunTypeWg, |
|||
Locale: locale, |
|||
} |
|||
|
|||
jsonData, err := json.Marshal(data) |
|||
if err != nil { |
|||
return models.AccountData{}, fmt.Errorf("failed to marshal json: %v", err) |
|||
} |
|||
|
|||
tr := &fhttp.Transport{ |
|||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { |
|||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", "api.cloudflareclient.com") |
|||
if err != nil || len(ips) == 0 { |
|||
return nil, fmt.Errorf("DNS resolution failed for api.cloudflareclient.com") |
|||
} |
|||
return net.DialTimeout("tcp", net.JoinHostPort(ips[0].String(), "443"), 10*time.Second) |
|||
}, |
|||
} |
|||
httpClient := &fhttp.Client{ |
|||
Transport: tr, |
|||
Timeout: 60 * time.Second, |
|||
} |
|||
|
|||
req, err := fhttp.NewRequest("POST", "https://consumer-masque.cloudflareclient.com/"+internal.ApiVersion+"/reg", bytes.NewBuffer(jsonData)) |
|||
if err != nil { |
|||
return models.AccountData{}, fmt.Errorf("failed to create request: %v", err) |
|||
} |
|||
req.Host = "api.cloudflareclient.com" |
|||
|
|||
for k, v := range internal.Headers { |
|||
req.Header.Set(k, v) |
|||
} |
|||
|
|||
if jwt != "" { |
|||
req.Header.Set("CF-Access-Jwt-Assertion", jwt) |
|||
} |
|||
|
|||
resp, err := httpClient.Do(req) |
|||
if err != nil { |
|||
return models.AccountData{}, fmt.Errorf("failed to send request: %v", err) |
|||
} |
|||
defer resp.Body.Close() |
|||
|
|||
if resp.StatusCode != fhttp.StatusOK { |
|||
return models.AccountData{}, fmt.Errorf("failed to register: %v", resp.Status) |
|||
} |
|||
|
|||
var accountData models.AccountData |
|||
if err := json.NewDecoder(resp.Body).Decode(&accountData); err != nil { |
|||
return models.AccountData{}, fmt.Errorf("failed to decode response: %v", err) |
|||
} |
|||
|
|||
return accountData, nil |
|||
} |
|||
|
|||
// EnrollKey updates an existing user account with a new MASQUE public key.
|
|||
//
|
|||
// This function sends a PATCH request to update the user's account with a new key.
|
|||
//
|
|||
// Parameters:
|
|||
// - accountData: models.AccountData - The account data of the user being updated.
|
|||
// - pubKey: []byte - The new MASQUE public key in binary format.
|
|||
// - deviceName: string - The name of the device to enroll. (optional)
|
|||
//
|
|||
// Returns:
|
|||
// - models.AccountData: The updated account data.
|
|||
// - error: An error if the update process fails.
|
|||
//
|
|||
// Example:
|
|||
//
|
|||
// updatedAccount, apiErr, err := EnrollKey(account, pubKey, "PC")
|
|||
// if err != nil {
|
|||
// log.Fatalf("Key enrollment failed: %v", err)
|
|||
// }
|
|||
func EnrollKey(accountData models.AccountData, pubKey []byte, deviceName string) (models.AccountData, *models.APIError, error) { |
|||
deviceUpdate := models.DeviceUpdate{ |
|||
Key: base64.StdEncoding.EncodeToString(pubKey), |
|||
KeyType: internal.KeyTypeMasque, |
|||
TunType: internal.TunTypeMasque, |
|||
} |
|||
|
|||
if deviceName != "" { |
|||
deviceUpdate.Name = deviceName |
|||
} |
|||
|
|||
jsonData, err := json.Marshal(deviceUpdate) |
|||
if err != nil { |
|||
return models.AccountData{}, nil, fmt.Errorf("failed to marshal json: %v", err) |
|||
} |
|||
|
|||
tr := &fhttp.Transport{ |
|||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { |
|||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", "api.cloudflareclient.com") |
|||
if err != nil || len(ips) == 0 { |
|||
return nil, fmt.Errorf("DNS resolution failed for api.cloudflareclient.com") |
|||
} |
|||
return net.DialTimeout("tcp", net.JoinHostPort(ips[0].String(), "443"), 10*time.Second) |
|||
}, |
|||
} |
|||
httpClient := &fhttp.Client{ |
|||
Transport: tr, |
|||
Timeout: 60 * time.Second, |
|||
} |
|||
|
|||
req, err := fhttp.NewRequest("PATCH", "https://consumer-masque.cloudflareclient.com/"+internal.ApiVersion+"/reg/"+accountData.ID, bytes.NewBuffer(jsonData)) |
|||
if err != nil { |
|||
return models.AccountData{}, nil, fmt.Errorf("failed to create request: %v", err) |
|||
} |
|||
req.Host = "api.cloudflareclient.com" |
|||
|
|||
for k, v := range internal.Headers { |
|||
req.Header.Set(k, v) |
|||
} |
|||
req.Header.Set("Authorization", "Bearer "+accountData.Token) |
|||
|
|||
resp, err := httpClient.Do(req) |
|||
if err != nil { |
|||
return models.AccountData{}, nil, fmt.Errorf("failed to send request: %v", err) |
|||
} |
|||
defer resp.Body.Close() |
|||
|
|||
body, err := io.ReadAll(resp.Body) |
|||
if err != nil { |
|||
return models.AccountData{}, nil, fmt.Errorf("failed to read response body: %v", err) |
|||
} |
|||
|
|||
if resp.StatusCode != fhttp.StatusOK { |
|||
var apiErr models.APIError |
|||
if err := json.Unmarshal(body, &apiErr); err != nil { |
|||
return models.AccountData{}, nil, fmt.Errorf("failed to parse error response: %v", err) |
|||
} |
|||
return models.AccountData{}, &apiErr, fmt.Errorf("failed to update: %s", resp.Status) |
|||
} |
|||
|
|||
if err := json.Unmarshal(body, &accountData); err != nil { |
|||
return models.AccountData{}, nil, fmt.Errorf("failed to decode response: %v", err) |
|||
} |
|||
|
|||
return accountData, nil, nil |
|||
} |
|||
@ -0,0 +1,200 @@ |
|||
package api |
|||
|
|||
import ( |
|||
"context" |
|||
"crypto/ecdsa" |
|||
"crypto/tls" |
|||
"crypto/x509" |
|||
"errors" |
|||
"fmt" |
|||
"net" |
|||
"net/http" |
|||
|
|||
connectip "github.com/Diniboy1123/connect-ip-go" |
|||
"github.com/quic-go/quic-go" |
|||
"github.com/quic-go/quic-go/http3" |
|||
"github.com/yosida95/uritemplate/v3" |
|||
) |
|||
|
|||
// fixedPeerConn wraps a net.PacketConn and makes it behave like a point-to-point
|
|||
// connection to a fixed peer (e.g. the Cloudflare MASQUE endpoint).
|
|||
// This is critical when using a TURN relay as the QUIC transport: the relay
|
|||
// conn knows how to send/receive via TURN indications, but quic-go needs
|
|||
// the connection to look like a direct pipe to the remote.
|
|||
// Matches the fixedPeerConn from the working vk-turn-usque-old implementation.
|
|||
type fixedPeerConn struct { |
|||
net.PacketConn |
|||
peer net.Addr |
|||
} |
|||
|
|||
func (c *fixedPeerConn) Write(p []byte) (n int, err error) { |
|||
return c.PacketConn.WriteTo(p, c.peer) |
|||
} |
|||
|
|||
func (c *fixedPeerConn) Read(p []byte) (n int, err error) { |
|||
n, _, err = c.PacketConn.ReadFrom(p) |
|||
return n, err |
|||
} |
|||
|
|||
func (c *fixedPeerConn) RemoteAddr() net.Addr { |
|||
return c.peer |
|||
} |
|||
|
|||
// PrepareTlsConfig creates a TLS configuration using the provided certificate and SNI (Server Name Indication).
|
|||
// It also verifies the peer's public key against the provided public key.
|
|||
//
|
|||
// Parameters:
|
|||
// - privKey: *ecdsa.PrivateKey - The private key to use for TLS authentication.
|
|||
// - peerPubKey: *ecdsa.PublicKey - The endpoint's public key to pin to.
|
|||
// - cert: [][]byte - The certificate chain to use for TLS authentication.
|
|||
// - sni: string - The Server Name Indication (SNI) to use.
|
|||
//
|
|||
// Returns:
|
|||
// - *tls.Config: A TLS configuration for secure communication.
|
|||
// - error: An error if TLS setup fails.
|
|||
func PrepareTlsConfig(privKey *ecdsa.PrivateKey, peerPubKey *ecdsa.PublicKey, cert [][]byte, sni string) (*tls.Config, error) { |
|||
tlsConfig := &tls.Config{ |
|||
Certificates: []tls.Certificate{ |
|||
{ |
|||
Certificate: cert, |
|||
PrivateKey: privKey, |
|||
}, |
|||
}, |
|||
ServerName: sni, |
|||
NextProtos: []string{http3.NextProtoH3}, |
|||
MinVersion: tls.VersionTLS12, |
|||
MaxVersion: tls.VersionTLS13, |
|||
CurvePreferences: []tls.CurveID{ |
|||
tls.X25519, |
|||
tls.CurveP256, |
|||
}, |
|||
CipherSuites: []uint16{ |
|||
tls.TLS_AES_128_GCM_SHA256, |
|||
tls.TLS_AES_256_GCM_SHA384, |
|||
tls.TLS_CHACHA20_POLY1305_SHA256, |
|||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, |
|||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, |
|||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, |
|||
}, |
|||
// WARN: SNI is usually not for the endpoint, so we must skip verification
|
|||
InsecureSkipVerify: true, |
|||
// we pin to the endpoint public key
|
|||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { |
|||
if len(rawCerts) == 0 { |
|||
return nil |
|||
} |
|||
|
|||
cert, err := x509.ParseCertificate(rawCerts[0]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if _, ok := cert.PublicKey.(*ecdsa.PublicKey); !ok { |
|||
// we only support ECDSA
|
|||
// TODO: don't hardcode cert type in the future
|
|||
// as backend can start using different cert types
|
|||
return x509.ErrUnsupportedAlgorithm |
|||
} |
|||
|
|||
if !cert.PublicKey.(*ecdsa.PublicKey).Equal(peerPubKey) { |
|||
// reason is incorrect, but the best I could figure
|
|||
// detail explains the actual reason
|
|||
|
|||
//10 is NoValidChains, but we support go1.22 where it's not defined
|
|||
return x509.CertificateInvalidError{Cert: cert, Reason: 10, Detail: "remote endpoint has a different public key than what we trust in config.json"} |
|||
} |
|||
|
|||
return nil |
|||
}, |
|||
} |
|||
|
|||
return tlsConfig, nil |
|||
} |
|||
|
|||
// ConnectTunnel establishes a QUIC connection and sets up a Connect-IP tunnel with the provided endpoint.
|
|||
// Endpoint address is used to check whether the authentication/connection is successful or not.
|
|||
// Requires modified connect-ip-go for now to support Cloudflare's non RFC compliant implementation.
|
|||
//
|
|||
// Parameters:
|
|||
// - ctx: context.Context - The QUIC TLS context.
|
|||
// - tlsConfig: *tls.Config - The TLS configuration for secure communication.
|
|||
// - quicConfig: *quic.Config - The QUIC configuration settings.
|
|||
// - connectUri: string - The URI template for the Connect-IP request.
|
|||
// - endpoint: *net.UDPAddr - The UDP address of the QUIC server.
|
|||
// - baseConn: net.PacketConn - Optional pre-allocated connection (e.g. from VK TURN relay). If nil, a new UDP socket is created.
|
|||
//
|
|||
// Returns:
|
|||
// - net.PacketConn: The packet connection used for the QUIC session.
|
|||
// - *http3.Transport: The HTTP/3 transport used for initial request.
|
|||
// - *connectip.Conn: The Connect-IP connection instance.
|
|||
// - *http.Response: The response from the Connect-IP handshake.
|
|||
// - error: An error if the connection setup fails.
|
|||
func ConnectTunnel(ctx context.Context, tlsConfig *tls.Config, quicConfig *quic.Config, connectUri string, endpoint *net.UDPAddr, baseConn net.PacketConn) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) { |
|||
var conn net.PacketConn |
|||
var err error |
|||
|
|||
if baseConn != nil { |
|||
// Wrap the TURN relay conn in fixedPeerConn so quic-go sees it as a
|
|||
// point-to-point connection to the Cloudflare endpoint.
|
|||
// Without this wrapping, some QUIC packet flows don't survive the
|
|||
// TURN relay hop (e.g. keepalives and connect-ip IP packets time out).
|
|||
conn = &fixedPeerConn{PacketConn: baseConn, peer: endpoint} |
|||
} else { |
|||
// Create a new UDP socket for direct connection to the Cloudflare MASQUE endpoint
|
|||
var udpConn *net.UDPConn |
|||
if endpoint.IP.To4() == nil { |
|||
udpConn, err = net.ListenUDP("udp", &net.UDPAddr{ |
|||
IP: net.IPv6zero, |
|||
Port: 0, |
|||
}) |
|||
} else { |
|||
udpConn, err = net.ListenUDP("udp", &net.UDPAddr{ |
|||
IP: net.IPv4zero, |
|||
Port: 0, |
|||
}) |
|||
} |
|||
if err != nil { |
|||
return nil, nil, nil, nil, err |
|||
} |
|||
conn = udpConn |
|||
} |
|||
|
|||
qconn, err := quic.Dial( |
|||
ctx, |
|||
conn, |
|||
endpoint, |
|||
tlsConfig, |
|||
quicConfig, |
|||
) |
|||
if err != nil { |
|||
return conn, nil, nil, nil, err |
|||
} |
|||
|
|||
tr := &http3.Transport{ |
|||
EnableDatagrams: true, |
|||
AdditionalSettings: map[uint64]uint64{ |
|||
// SETTINGS_H3_DATAGRAM (current IETF RFC 9297) - required by Cloudflare
|
|||
0x33: 1, |
|||
// SETTINGS_H3_DATAGRAM_00 (deprecated draft, but official client still sends it)
|
|||
0x276: 1, |
|||
}, |
|||
DisableCompression: true, |
|||
} |
|||
|
|||
hconn := tr.NewClientConn(qconn) |
|||
|
|||
additionalHeaders := http.Header{ |
|||
"User-Agent": []string{""}, |
|||
} |
|||
|
|||
template := uritemplate.MustNew(connectUri) |
|||
ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true) |
|||
if err != nil { |
|||
if err.Error() == "CRYPTO_ERROR 0x131 (remote): tls: access denied" { |
|||
return conn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") |
|||
} |
|||
return conn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %v", err) |
|||
} |
|||
|
|||
return conn, tr, ipConn, rsp, nil |
|||
} |
|||
@ -0,0 +1,348 @@ |
|||
package api |
|||
|
|||
import ( |
|||
"context" |
|||
"crypto/rand" |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
"log" |
|||
"math/big" |
|||
"net" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
|
|||
connectip "github.com/Diniboy1123/connect-ip-go" |
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/internal" |
|||
"github.com/songgao/water" |
|||
"golang.zx2c4.com/wireguard/tun" |
|||
) |
|||
|
|||
// Verbose controls whether diagnostic logs like tunnel stats are printed.
|
|||
var Verbose bool |
|||
|
|||
// NetBuffer is a pool of byte slices with a fixed capacity.
|
|||
// Helps to reduce memory allocations and improve performance.
|
|||
// It uses a sync.Pool to manage the byte slices.
|
|||
// The capacity of the byte slices is set when the pool is created.
|
|||
type NetBuffer struct { |
|||
capacity int |
|||
buf sync.Pool |
|||
} |
|||
|
|||
// Get returns a byte slice from the pool.
|
|||
func (n *NetBuffer) Get() []byte { |
|||
return *(n.buf.Get().(*[]byte)) |
|||
} |
|||
|
|||
// Put places a byte slice back into the pool.
|
|||
// It checks if the capacity of the byte slice matches the pool's capacity.
|
|||
// If it doesn't match, the byte slice is not returned to the pool.
|
|||
func (n *NetBuffer) Put(buf []byte) { |
|||
if cap(buf) != n.capacity { |
|||
return |
|||
} |
|||
n.buf.Put(&buf) |
|||
} |
|||
|
|||
// NewNetBuffer creates a new NetBuffer with the specified capacity.
|
|||
// The capacity must be greater than 0.
|
|||
func NewNetBuffer(capacity int) *NetBuffer { |
|||
if capacity <= 0 { |
|||
panic("capacity must be greater than 0") |
|||
} |
|||
return &NetBuffer{ |
|||
capacity: capacity, |
|||
buf: sync.Pool{ |
|||
New: func() interface{} { |
|||
b := make([]byte, capacity) |
|||
return &b |
|||
}, |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// TunnelDevice abstracts a TUN device so that we can use the same tunnel-maintenance code
|
|||
// regardless of the underlying implementation.
|
|||
type TunnelDevice interface { |
|||
// ReadPacket reads a packet from the device (using the given mtu) and returns its contents.
|
|||
ReadPacket(buf []byte) (int, error) |
|||
// WritePacket writes a packet to the device.
|
|||
WritePacket(pkt []byte) error |
|||
} |
|||
|
|||
// NetstackAdapter wraps a tun.Device (e.g. from netstack) to satisfy TunnelDevice.
|
|||
type NetstackAdapter struct { |
|||
dev tun.Device |
|||
tunnelBufPool sync.Pool |
|||
tunnelSizesPool sync.Pool |
|||
} |
|||
|
|||
func (n *NetstackAdapter) ReadPacket(buf []byte) (int, error) { |
|||
packetBufsPtr := n.tunnelBufPool.Get().(*[][]byte) |
|||
sizesPtr := n.tunnelSizesPool.Get().(*[]int) |
|||
|
|||
defer func() { |
|||
(*packetBufsPtr)[0] = nil |
|||
n.tunnelBufPool.Put(packetBufsPtr) |
|||
n.tunnelSizesPool.Put(sizesPtr) |
|||
}() |
|||
|
|||
(*packetBufsPtr)[0] = buf |
|||
(*sizesPtr)[0] = 0 |
|||
|
|||
_, err := n.dev.Read(*packetBufsPtr, *sizesPtr, 0) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
return (*sizesPtr)[0], nil |
|||
} |
|||
|
|||
func (n *NetstackAdapter) WritePacket(pkt []byte) error { |
|||
// Write expects a slice of packet buffers.
|
|||
_, err := n.dev.Write([][]byte{pkt}, 0) |
|||
return err |
|||
} |
|||
|
|||
// NewNetstackAdapter creates a new NetstackAdapter.
|
|||
func NewNetstackAdapter(dev tun.Device) TunnelDevice { |
|||
return &NetstackAdapter{ |
|||
dev: dev, |
|||
tunnelBufPool: sync.Pool{ |
|||
New: func() interface{} { |
|||
buf := make([][]byte, 1) |
|||
return &buf |
|||
}, |
|||
}, |
|||
tunnelSizesPool: sync.Pool{ |
|||
New: func() interface{} { |
|||
sizes := make([]int, 1) |
|||
return &sizes |
|||
}, |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// WaterAdapter wraps a *water.Interface so it satisfies TunnelDevice.
|
|||
type WaterAdapter struct { |
|||
iface *water.Interface |
|||
} |
|||
|
|||
func (w *WaterAdapter) ReadPacket(buf []byte) (int, error) { |
|||
n, err := w.iface.Read(buf) |
|||
if err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
return n, nil |
|||
} |
|||
|
|||
func (w *WaterAdapter) WritePacket(pkt []byte) error { |
|||
_, err := w.iface.Write(pkt) |
|||
return err |
|||
} |
|||
|
|||
// NewWaterAdapter creates a new WaterAdapter.
|
|||
func NewWaterAdapter(iface *water.Interface) TunnelDevice { |
|||
return &WaterAdapter{iface: iface} |
|||
} |
|||
|
|||
// GetRelayConnFunc is a function type that returns a pre-allocated packet connection
|
|||
// for use as a TURN relay (e.g. from VK TURN). If nil is provided, a direct UDP
|
|||
// connection to the MASQUE endpoint will be created.
|
|||
type GetRelayConnFunc func(ctx context.Context) (net.PacketConn, error) |
|||
|
|||
// MaintainTunnel continuously connects to the MASQUE server, then starts two
|
|||
// forwarding goroutines: one forwarding from the device to the IP connection (and handling
|
|||
// any ICMP reply), and the other forwarding from the IP connection to the device.
|
|||
// If an error occurs in either loop, the connection is closed and a reconnect is attempted.
|
|||
//
|
|||
// Parameters:
|
|||
// - ctx: context.Context - The context for the connection.
|
|||
// - tlsConfig: *tls.Config - The TLS configuration for secure communication.
|
|||
// - keepalivePeriod: time.Duration - The keepalive period for the QUIC connection.
|
|||
// - initialPacketSize: uint16 - The initial packet size for the QUIC connection.
|
|||
// - endpoint: *net.UDPAddr - The UDP address of the MASQUE server.
|
|||
// - device: TunnelDevice - The TUN device to forward packets to and from.
|
|||
// - mtu: int - The MTU of the TUN device.
|
|||
// - reconnectDelay: time.Duration - The delay between reconnect attempts.
|
|||
// - getRelayConn: GetRelayConnFunc - Optional function to obtain a TURN relay connection.
|
|||
// If nil, a direct UDP connection to the endpoint is used.
|
|||
// - onReady: func(bool) - Optional callback fired with true when connected, and false when disconnected.
|
|||
func MaintainTunnel(ctx context.Context, tlsConfig *tls.Config, keepalivePeriod time.Duration, initialPacketSize uint16, endpoint *net.UDPAddr, device TunnelDevice, mtu int, reconnectDelay time.Duration, getRelayConn GetRelayConnFunc, onReady func(bool)) { |
|||
packetBufferPool := NewNetBuffer(mtu) |
|||
for { |
|||
// Check if context is done before attempting connection
|
|||
select { |
|||
case <-ctx.Done(): |
|||
return |
|||
default: |
|||
} |
|||
|
|||
log.Printf("Establishing MASQUE connection to %s:%d", endpoint.IP, endpoint.Port) |
|||
|
|||
// Optionally obtain a TURN relay packet connection
|
|||
var baseConn net.PacketConn |
|||
if getRelayConn != nil { |
|||
var err error |
|||
baseConn, err = getRelayConn(ctx) |
|||
if err != nil { |
|||
log.Printf("Failed to obtain TURN relay connection: %v", err) |
|||
select { |
|||
case <-ctx.Done(): |
|||
return |
|||
case <-time.After(reconnectDelay): |
|||
} |
|||
continue |
|||
} |
|||
} |
|||
|
|||
udpConn, tr, ipConn, rsp, err := ConnectTunnel( |
|||
ctx, |
|||
tlsConfig, |
|||
internal.DefaultQuicConfig(keepalivePeriod, initialPacketSize), |
|||
internal.ConnectURI, |
|||
endpoint, |
|||
baseConn, |
|||
) |
|||
if err != nil { |
|||
log.Printf("Failed to connect tunnel: %v", err) |
|||
if udpConn != nil { |
|||
udpConn.Close() |
|||
} |
|||
select { |
|||
case <-ctx.Done(): |
|||
return |
|||
case <-time.After(reconnectDelay): |
|||
} |
|||
continue |
|||
} |
|||
if rsp.StatusCode != 200 { |
|||
log.Printf("Tunnel connection failed: %s", rsp.Status) |
|||
ipConn.Close() |
|||
if udpConn != nil { |
|||
udpConn.Close() |
|||
} |
|||
if tr != nil { |
|||
tr.Close() |
|||
} |
|||
select { |
|||
case <-ctx.Done(): |
|||
return |
|||
case <-time.After(reconnectDelay): |
|||
} |
|||
continue |
|||
} |
|||
|
|||
log.Println("Connected to MASQUE server") |
|||
if onReady != nil { |
|||
onReady(true) |
|||
} |
|||
errChan := make(chan error, 2) |
|||
|
|||
// Packet counters for diagnostics
|
|||
var txPkts, rxPkts atomic.Int64 |
|||
go func() { |
|||
ticker := time.NewTicker(10 * time.Second) |
|||
defer ticker.Stop() |
|||
for { |
|||
select { |
|||
case <-ticker.C: |
|||
if Verbose { |
|||
log.Printf("[Warp] Tunnel stats: TX=%d pkts, RX=%d pkts", txPkts.Load(), rxPkts.Load()) |
|||
} |
|||
case <-ctx.Done(): |
|||
return |
|||
} |
|||
} |
|||
}() |
|||
|
|||
go func() { |
|||
for { |
|||
buf := packetBufferPool.Get() |
|||
n, err := device.ReadPacket(buf) |
|||
if err != nil { |
|||
packetBufferPool.Put(buf) |
|||
errChan <- fmt.Errorf("failed to read from TUN device: %v", err) |
|||
return |
|||
} |
|||
txPkts.Add(1) |
|||
|
|||
paddedSize := n |
|||
if n < mtu-100 { |
|||
randOffset, _ := rand.Int(rand.Reader, big.NewInt(64)) |
|||
paddedSize = n + int(randOffset.Int64()) |
|||
if paddedSize > mtu { |
|||
paddedSize = mtu |
|||
} |
|||
if paddedSize > n { |
|||
_, _ = rand.Read(buf[n:paddedSize]) |
|||
} |
|||
} |
|||
|
|||
icmp, err := ipConn.WritePacket(buf[:paddedSize]) |
|||
if err != nil { |
|||
packetBufferPool.Put(buf) |
|||
if errors.As(err, new(*connectip.CloseError)) { |
|||
errChan <- fmt.Errorf("connection closed while writing to IP connection: %v", err) |
|||
return |
|||
} |
|||
log.Printf("Error writing to IP connection: %v, continuing...", err) |
|||
continue |
|||
} |
|||
packetBufferPool.Put(buf) |
|||
|
|||
if len(icmp) > 0 { |
|||
if err := device.WritePacket(icmp); err != nil { |
|||
if errors.As(err, new(*connectip.CloseError)) { |
|||
errChan <- fmt.Errorf("connection closed while writing ICMP to TUN device: %v", err) |
|||
return |
|||
} |
|||
log.Printf("Error writing ICMP to TUN device: %v, continuing...", err) |
|||
} |
|||
} |
|||
} |
|||
}() |
|||
|
|||
go func() { |
|||
buf := packetBufferPool.Get() |
|||
defer packetBufferPool.Put(buf) |
|||
for { |
|||
n, err := ipConn.ReadPacket(buf, true) |
|||
if err != nil { |
|||
if errors.As(err, new(*connectip.CloseError)) { |
|||
errChan <- fmt.Errorf("connection closed while reading from IP connection: %v", err) |
|||
return |
|||
} |
|||
log.Printf("Error reading from IP connection: %v, continuing...", err) |
|||
continue |
|||
} |
|||
rxPkts.Add(1) |
|||
if err := device.WritePacket(buf[:n]); err != nil { |
|||
errChan <- fmt.Errorf("failed to write to TUN device: %v", err) |
|||
return |
|||
} |
|||
} |
|||
}() |
|||
|
|||
err = <-errChan |
|||
if onReady != nil { |
|||
onReady(false) |
|||
} |
|||
log.Printf("Tunnel connection lost: %v. Reconnecting...", err) |
|||
ipConn.Close() |
|||
if udpConn != nil { |
|||
udpConn.Close() |
|||
} |
|||
if tr != nil { |
|||
tr.Close() |
|||
} |
|||
select { |
|||
case <-ctx.Done(): |
|||
return |
|||
case <-time.After(reconnectDelay): |
|||
} |
|||
} |
|||
} |
|||
@ -0,0 +1,120 @@ |
|||
package config |
|||
|
|||
import ( |
|||
"crypto/ecdsa" |
|||
"crypto/x509" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
"encoding/pem" |
|||
"fmt" |
|||
"os" |
|||
) |
|||
|
|||
// Config represents the application configuration structure, containing essential details such as keys, endpoints, and access tokens.
|
|||
type Config struct { |
|||
PrivateKey string `json:"private_key"` // Base64-encoded ECDSA private key
|
|||
EndpointV4 string `json:"endpoint_v4"` // IPv4 address of the endpoint
|
|||
EndpointV6 string `json:"endpoint_v6"` // IPv6 address of the endpoint
|
|||
EndpointPubKey string `json:"endpoint_pub_key"` // PEM-encoded ECDSA public key of the endpoint to verify against
|
|||
License string `json:"license"` // Application license key
|
|||
ID string `json:"id"` // Device unique identifier
|
|||
AccessToken string `json:"access_token"` // Authentication token for API access
|
|||
IPv4 string `json:"ipv4"` // Assigned IPv4 address
|
|||
IPv6 string `json:"ipv6"` // Assigned IPv6 address
|
|||
} |
|||
|
|||
// AppConfig holds the global application configuration.
|
|||
var AppConfig Config |
|||
|
|||
// ConfigLoaded indicates whether the configuration has been successfully loaded.
|
|||
var ConfigLoaded bool |
|||
|
|||
// LoadConfig loads the application configuration from a JSON file.
|
|||
//
|
|||
// Parameters:
|
|||
// - configPath: string - The path to the configuration JSON file.
|
|||
//
|
|||
// Returns:
|
|||
// - error: An error if the configuration file cannot be loaded or parsed.
|
|||
func LoadConfig(configPath string) error { |
|||
file, err := os.Open(configPath) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to open config file: %v", err) |
|||
} |
|||
defer file.Close() |
|||
|
|||
decoder := json.NewDecoder(file) |
|||
if err := decoder.Decode(&AppConfig); err != nil { |
|||
return fmt.Errorf("failed to decode config file: %v", err) |
|||
} |
|||
|
|||
ConfigLoaded = true |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// SaveConfig writes the current application configuration to a prettified JSON file.
|
|||
//
|
|||
// Parameters:
|
|||
// - configPath: string - The path to save the configuration JSON file.
|
|||
//
|
|||
// Returns:
|
|||
// - error: An error if the configuration file cannot be written.
|
|||
func (*Config) SaveConfig(configPath string) error { |
|||
file, err := os.Create(configPath) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to create config file: %v", err) |
|||
} |
|||
defer file.Close() |
|||
|
|||
encoder := json.NewEncoder(file) |
|||
encoder.SetIndent("", " ") |
|||
if err := encoder.Encode(AppConfig); err != nil { |
|||
return fmt.Errorf("failed to encode config file: %v", err) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// GetEcPrivateKey retrieves the ECDSA private key from the stored Base64-encoded string.
|
|||
//
|
|||
// Returns:
|
|||
// - *ecdsa.PrivateKey: The parsed ECDSA private key.
|
|||
// - error: An error if decoding or parsing the private key fails.
|
|||
func (*Config) GetEcPrivateKey() (*ecdsa.PrivateKey, error) { |
|||
privKeyB64, err := base64.StdEncoding.DecodeString(AppConfig.PrivateKey) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to decode private key: %v", err) |
|||
} |
|||
|
|||
privKey, err := x509.ParseECPrivateKey(privKeyB64) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to parse private key: %v", err) |
|||
} |
|||
|
|||
return privKey, nil |
|||
} |
|||
|
|||
// GetEcEndpointPublicKey retrieves the ECDSA public key from the stored PEM-encoded string.
|
|||
//
|
|||
// Returns:
|
|||
// - *ecdsa.PublicKey: The parsed ECDSA public key.
|
|||
// - error: An error if decoding or parsing the public key fails.
|
|||
func (*Config) GetEcEndpointPublicKey() (*ecdsa.PublicKey, error) { |
|||
endpointPubKeyB64, _ := pem.Decode([]byte(AppConfig.EndpointPubKey)) |
|||
if endpointPubKeyB64 == nil { |
|||
return nil, fmt.Errorf("failed to decode endpoint public key") |
|||
} |
|||
|
|||
pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64.Bytes) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to parse public key: %v", err) |
|||
} |
|||
|
|||
ecPubKey, ok := pubKey.(*ecdsa.PublicKey) |
|||
if !ok { |
|||
return nil, fmt.Errorf("failed to assert public key as ECDSA") |
|||
} |
|||
|
|||
return ecPubKey, nil |
|||
} |
|||
@ -0,0 +1,23 @@ |
|||
package internal |
|||
|
|||
const ( |
|||
ApiUrl = "https://api.cloudflareclient.com" |
|||
ApiVersion = "v0a4471" |
|||
ConnectSNI = "consumer-masque.cloudflareclient.com" |
|||
// unused for now
|
|||
ZeroTierSNI = "zt-masque.cloudflareclient.com" |
|||
ConnectURI = "https://cloudflareaccess.com" |
|||
DefaultModel = "PC" |
|||
KeyTypeWg = "curve25519" |
|||
TunTypeWg = "wireguard" |
|||
KeyTypeMasque = "secp256r1" |
|||
TunTypeMasque = "masque" |
|||
DefaultLocale = "en_US" |
|||
) |
|||
|
|||
var Headers = map[string]string{ |
|||
"User-Agent": "WARP for Android", |
|||
"CF-Client-Version": "a-6.35-4471", |
|||
"Content-Type": "application/json; charset=UTF-8", |
|||
"Connection": "Keep-Alive", |
|||
} |
|||
@ -0,0 +1,195 @@ |
|||
package internal |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"net" |
|||
"net/netip" |
|||
"time" |
|||
|
|||
"golang.zx2c4.com/wireguard/tun/netstack" |
|||
) |
|||
|
|||
// TunnelDNSResolver implements a DNS resolver that uses the provided DNS servers
|
|||
// either inside a MASQUE tunnel (if TunNet is set) or over the system network (if TunNet is nil).
|
|||
type TunnelDNSResolver struct { |
|||
// TunNet is the network stack for the tunnel you want to use for DNS resolution.
|
|||
// If nil, DNS queries are sent over the system network.
|
|||
TunNet *netstack.Net |
|||
|
|||
// DNSAddrs is the list of DNS servers to use for resolution.
|
|||
DNSAddrs []netip.Addr |
|||
|
|||
// Timeout is the timeout for DNS queries on a specific server before trying the next one.
|
|||
Timeout time.Duration |
|||
} |
|||
|
|||
// Resolve performs a DNS lookup using the provided DNS resolvers.
|
|||
// It tries each resolver in order until one succeeds, sending queries either through the tunnel
|
|||
// or over the system network depending on TunNet.
|
|||
//
|
|||
// Parameters:
|
|||
// - ctx: context.Context - The context for the DNS lookup.
|
|||
// - name: string - The domain name to resolve.
|
|||
//
|
|||
// Returns:
|
|||
// - context.Context: The original context for the DNS lookup.
|
|||
// - net.IP: The resolved IP address.
|
|||
// - error: An error if the lookup fails.
|
|||
func (r TunnelDNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { |
|||
if len(r.DNSAddrs) == 0 { |
|||
return ctx, nil, fmt.Errorf("no DNS servers configured") |
|||
} |
|||
|
|||
var queryCtx context.Context = ctx |
|||
var cancel context.CancelFunc |
|||
if r.Timeout > 0 { |
|||
queryCtx, cancel = context.WithTimeout(ctx, r.Timeout) |
|||
defer cancel() |
|||
} |
|||
|
|||
type result struct { |
|||
ip net.IP |
|||
err error |
|||
} |
|||
results := make(chan result, len(r.DNSAddrs)) |
|||
|
|||
for _, dnsAddr := range r.DNSAddrs { |
|||
dnsHost := net.JoinHostPort(dnsAddr.String(), "53") |
|||
|
|||
go func(dnsHost string) { |
|||
var dialFunc func(context.Context, string, string) (net.Conn, error) |
|||
if r.TunNet != nil { |
|||
dialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { |
|||
return r.TunNet.DialContext(ctx, "udp", dnsHost) |
|||
} |
|||
} else { |
|||
dialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { |
|||
return net.Dial("udp", dnsHost) |
|||
} |
|||
} |
|||
|
|||
resolver := &net.Resolver{ |
|||
PreferGo: true, |
|||
Dial: dialFunc, |
|||
} |
|||
ips, err := resolver.LookupIP(queryCtx, "ip", name) |
|||
if err == nil && len(ips) > 0 { |
|||
results <- result{ip: ips[0], err: nil} |
|||
} else { |
|||
results <- result{ip: nil, err: err} |
|||
} |
|||
}(dnsHost) |
|||
} |
|||
|
|||
var lastErr error |
|||
for i := 0; i < len(r.DNSAddrs); i++ { |
|||
res := <-results |
|||
if res.err == nil && res.ip != nil { |
|||
if cancel != nil { |
|||
cancel() |
|||
} |
|||
return ctx, res.ip, nil |
|||
} |
|||
lastErr = res.err |
|||
} |
|||
|
|||
return ctx, nil, fmt.Errorf("all DNS servers failed: %v", lastErr) |
|||
} |
|||
|
|||
// NewNetstackResolver returns a *net.Resolver that uses the tunnel network stack
|
|||
// and provided DNS servers for DNS queries.
|
|||
//
|
|||
// Parameters:
|
|||
// - tunNet: *netstack.Net - The tunnel network stack.
|
|||
// - dnsAddrs: []netip.Addr - DNS server addresses.
|
|||
//
|
|||
// Returns:
|
|||
// - *net.Resolver - A resolver that routes queries through the tunnel.
|
|||
func NewNetstackResolver(tunNet *netstack.Net, dnsAddrs []netip.Addr) *net.Resolver { |
|||
return &net.Resolver{ |
|||
PreferGo: true, |
|||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { |
|||
if len(dnsAddrs) == 0 { |
|||
return nil, fmt.Errorf("no DNS servers configured") |
|||
} |
|||
if len(dnsAddrs) > 1 { |
|||
return raceDial(ctx, tunNet, dnsAddrs) |
|||
} |
|||
dnsHost := net.JoinHostPort(dnsAddrs[0].String(), "53") |
|||
return tunNet.DialContext(ctx, "udp", dnsHost) |
|||
}, |
|||
} |
|||
} |
|||
|
|||
func raceDial(ctx context.Context, tunNet *netstack.Net, addrs []netip.Addr) (net.Conn, error) { |
|||
type result struct { |
|||
conn net.Conn |
|||
err error |
|||
} |
|||
resChan := make(chan result, len(addrs)) |
|||
childCtx, cancel := context.WithCancel(ctx) |
|||
defer cancel() |
|||
|
|||
for _, addr := range addrs { |
|||
go func(a netip.Addr) { |
|||
dnsHost := net.JoinHostPort(a.String(), "53") |
|||
conn, err := tunNet.DialContext(childCtx, "udp", dnsHost) |
|||
if err == nil { |
|||
select { |
|||
case resChan <- result{conn: conn}: |
|||
case <-ctx.Done(): |
|||
conn.Close() |
|||
} |
|||
} else { |
|||
select { |
|||
case resChan <- result{err: err}: |
|||
case <-ctx.Done(): |
|||
} |
|||
} |
|||
}(addr) |
|||
} |
|||
|
|||
var lastErr error |
|||
for i := 0; i < len(addrs); i++ { |
|||
res := <-resChan |
|||
if res.err == nil { |
|||
return res.conn, nil |
|||
} |
|||
lastErr = res.err |
|||
} |
|||
return nil, fmt.Errorf("all DNS race dials failed: %w", lastErr) |
|||
} |
|||
|
|||
// NewStaticResolver returns a *net.Resolver that uses the provided DNS servers
|
|||
// for lookups over the system network.
|
|||
func NewStaticResolver(dnsAddrs []netip.Addr) *net.Resolver { |
|||
return &net.Resolver{ |
|||
PreferGo: true, |
|||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { |
|||
if len(dnsAddrs) == 0 { |
|||
return nil, fmt.Errorf("no DNS servers configured") |
|||
} |
|||
dnsHost := net.JoinHostPort(dnsAddrs[0].String(), "53") |
|||
return net.Dial("udp", dnsHost) |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// GetProxyResolver returns the appropriate *net.Resolver for proxy use
|
|||
// based on the localDNS flag.
|
|||
//
|
|||
// Parameters:
|
|||
// - localDNS: bool - Whether to use the system network for DNS.
|
|||
// - tunNet: *netstack.Net - The tunnel network stack (if localDNS is false).
|
|||
// - dnsAddrs: []netip.Addr - DNS server addresses.
|
|||
// - timeout: time.Duration - Timeout for DNS queries.
|
|||
//
|
|||
// Returns:
|
|||
// - *net.Resolver - A resolver suitable for use with proxy connections.
|
|||
func GetProxyResolver(localDNS bool, tunNet *netstack.Net, dnsAddrs []netip.Addr, timeout time.Duration) *net.Resolver { |
|||
if localDNS { |
|||
return NewStaticResolver(dnsAddrs) |
|||
} |
|||
return NewNetstackResolver(tunNet, dnsAddrs) |
|||
} |
|||
@ -0,0 +1,318 @@ |
|||
package internal |
|||
|
|||
import ( |
|||
"crypto/ecdsa" |
|||
"crypto/elliptic" |
|||
"crypto/rand" |
|||
"crypto/x509" |
|||
"encoding/base64" |
|||
"encoding/hex" |
|||
"errors" |
|||
"log" |
|||
"math/big" |
|||
"net" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/quic-go/quic-go" |
|||
) |
|||
|
|||
// PortMapping represents a network port forwarding rule.
|
|||
type PortMapping struct { |
|||
BindAddress string // The address to bind the local port.
|
|||
LocalPort int // The local port number.
|
|||
RemoteIP string // The remote destination IP address.
|
|||
RemotePort int // The remote destination port number.
|
|||
} |
|||
|
|||
// GenerateRandomAndroidSerial generates a random 8-byte Android-like device identifier
|
|||
// and returns it as a hexadecimal string.
|
|||
//
|
|||
// Returns:
|
|||
// - string: A randomly generated 16-character hexadecimal serial number.
|
|||
// - error: An error if random data generation fails.
|
|||
func GenerateRandomAndroidSerial() (string, error) { |
|||
serial := make([]byte, 8) |
|||
if _, err := rand.Read(serial); err != nil { |
|||
return "", err |
|||
} |
|||
return hex.EncodeToString(serial), nil |
|||
} |
|||
|
|||
// GenerateRandomWgPubkey generates a random 32-byte WireGuard like public key
|
|||
// and returns it as a base64-encoded string.
|
|||
//
|
|||
// Returns:
|
|||
// - string: A randomly generated WireGuard like public key in base64 format.
|
|||
// - error: An error if random data generation fails.
|
|||
func GenerateRandomWgPubkey() (string, error) { |
|||
publicKey := make([]byte, 32) |
|||
if _, err := rand.Read(publicKey); err != nil { |
|||
return "", err |
|||
} |
|||
return base64.StdEncoding.EncodeToString(publicKey), nil |
|||
} |
|||
|
|||
// TimeAsCfString formats a given time.Time into a Cloudflare-compatible string format.
|
|||
//
|
|||
// The format follows the standard: "YYYY-MM-DDTHH:MM:SS.sss-07:00".
|
|||
//
|
|||
// Parameters:
|
|||
// - t: time.Time to format.
|
|||
//
|
|||
// Returns:
|
|||
// - string: The formatted time string.
|
|||
func TimeAsCfString(t time.Time) string { |
|||
return t.Format("2006-01-02T15:04:05.000-07:00") |
|||
} |
|||
|
|||
// GenerateEcKeyPair generates a new ECDSA key pair using the P-256 curve.
|
|||
//
|
|||
// Returns:
|
|||
// - []byte: The marshalled private key in ASN.1 DER format.
|
|||
// - []byte: The marshalled public key in PKIX format.
|
|||
// - error: An error if key generation or marshalling fails.
|
|||
func GenerateEcKeyPair() ([]byte, []byte, error) { |
|||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
|
|||
marshalledPrivKey, err := x509.MarshalECPrivateKey(privKey) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
|
|||
marshalledPubKey, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
|
|||
return marshalledPrivKey, marshalledPubKey, nil |
|||
} |
|||
|
|||
// GenerateCert creates a self-signed certificate using the provided ECDSA private and public keys.
|
|||
//
|
|||
// The certificate is valid for 24 hours.
|
|||
//
|
|||
// Parameters:
|
|||
// - privKey: *ecdsa.PrivateKey - The private key to sign the certificate.
|
|||
// - pubKey: *ecdsa.PublicKey - The public key to include in the certificate.
|
|||
//
|
|||
// Returns:
|
|||
// - [][]byte: A slice containing the certificate in DER format.
|
|||
// - error: An error if certificate generation fails.
|
|||
func GenerateCert(privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) ([][]byte, error) { |
|||
cert, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{ |
|||
SerialNumber: big.NewInt(0), |
|||
NotBefore: time.Now(), |
|||
NotAfter: time.Now().Add(1 * 24 * time.Hour), |
|||
}, &x509.Certificate{}, &privKey.PublicKey, privKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return [][]byte{cert}, nil |
|||
} |
|||
|
|||
// DefaultQuicConfig returns a MASQUE compatible default QUIC configuration with specified keep-alive period and initial packet size.
|
|||
//
|
|||
// Parameters:
|
|||
// - keepalivePeriod: time.Duration - The duration for sending QUIC keep-alive packets.
|
|||
// - initialPacketSize: uint16 - The initial size of QUIC packets. (1242 seems used by the original implementation)
|
|||
//
|
|||
// Returns:
|
|||
// - *quic.Config: A pointer to a configured QUIC configuration object.
|
|||
func DefaultQuicConfig(keepalivePeriod time.Duration, initialPacketSize uint16) *quic.Config { |
|||
return &quic.Config{ |
|||
EnableDatagrams: true, |
|||
InitialPacketSize: initialPacketSize, |
|||
KeepAlivePeriod: keepalivePeriod, |
|||
} |
|||
} |
|||
|
|||
// parsePortMapping is an internal helper function that parses a port mapping string into its components.
|
|||
//
|
|||
// It handles IPv6 addresses enclosed in brackets and various format edge cases.
|
|||
//
|
|||
// Parameters:
|
|||
// - port: string - The port mapping string.
|
|||
//
|
|||
// Returns:
|
|||
// - string: The bind address.
|
|||
// - int: The local port.
|
|||
// - string: The remote hostname/IP.
|
|||
// - int: The remote port.
|
|||
// - error: An error if parsing fails.
|
|||
func parsePortMapping(port string) (bindAddress string, localPort int, remoteHost string, remotePort int, err error) { |
|||
parts := strings.Split(port, ":") |
|||
|
|||
// Handle IPv6 addresses (which are enclosed in brackets)
|
|||
if len(parts) >= 4 && strings.HasPrefix(parts[0], "[") && strings.Contains(parts[0], "]") { |
|||
bindAddress = parts[0] |
|||
parts = parts[1:] // Shift parts forward
|
|||
} else if len(parts) == 3 { |
|||
bindAddress = "localhost" // Default to localhost
|
|||
} else if len(parts) == 4 { |
|||
bindAddress = parts[0] |
|||
parts = parts[1:] // Shift forward
|
|||
} else { |
|||
return "", 0, "", 0, errors.New("invalid port mapping format (expected format: [bind_address:]local_port:remote_host:remote_port)") |
|||
} |
|||
|
|||
// Parse local port
|
|||
localPort, err = strconv.Atoi(parts[0]) |
|||
if err != nil || localPort <= 0 || localPort > 65535 { |
|||
return "", 0, "", 0, errors.New("invalid local port") |
|||
} |
|||
|
|||
// Validate remote host (allow both hostnames and IPs)
|
|||
remoteHost = parts[1] |
|||
if net.ParseIP(remoteHost) == nil && !isValidHostname(remoteHost) { |
|||
return "", 0, "", 0, errors.New("invalid remote hostname/IP") |
|||
} |
|||
|
|||
// Parse remote port
|
|||
remotePort, err = strconv.Atoi(parts[2]) |
|||
if err != nil || remotePort <= 0 || remotePort > 65535 { |
|||
return "", 0, "", 0, errors.New("invalid remote port") |
|||
} |
|||
|
|||
// If bindAddress is an IPv6 address, remove brackets for proper binding
|
|||
if strings.HasPrefix(bindAddress, "[") && strings.HasSuffix(bindAddress, "]") { |
|||
bindAddress = strings.Trim(bindAddress, "[]") |
|||
} |
|||
|
|||
// Convert "localhost" or hostnames to actual addresses
|
|||
if bindAddress == "*" { |
|||
bindAddress = "0.0.0.0" // Allow all interfaces
|
|||
} |
|||
|
|||
// Validate bind address (support both IPs and hostnames)
|
|||
bindAddress, err = resolveBindAddress(bindAddress) |
|||
if err != nil { |
|||
return "", 0, "", 0, errors.New("invalid local address: " + err.Error()) |
|||
} |
|||
|
|||
remoteHost, err = resolveBindAddress(remoteHost) |
|||
if err != nil { |
|||
return "", 0, "", 0, errors.New("invalid remote address: " + err.Error()) |
|||
} |
|||
|
|||
return bindAddress, localPort, remoteHost, remotePort, nil |
|||
} |
|||
|
|||
// ParsePortMapping parses a port mapping string into a structured PortMapping.
|
|||
//
|
|||
// The expected format is: `[bind_address:]local_port:remote_host:remote_port`.
|
|||
//
|
|||
// Parameters:
|
|||
// - port: string - The port mapping string.
|
|||
//
|
|||
// Returns:
|
|||
// - PortMapping: A structured representation of the parsed port mapping.
|
|||
// - error: An error if the parsing fails.
|
|||
func ParsePortMapping(port string) (PortMapping, error) { |
|||
bindAddress, localPort, remoteHost, remotePort, err := parsePortMapping(port) |
|||
if err != nil { |
|||
return PortMapping{}, err |
|||
} |
|||
|
|||
return PortMapping{ |
|||
BindAddress: bindAddress, |
|||
LocalPort: localPort, |
|||
RemoteIP: remoteHost, |
|||
RemotePort: remotePort, |
|||
}, nil |
|||
} |
|||
|
|||
// resolveBindAddress resolves a hostname or IP to its string representation.
|
|||
//
|
|||
// Parameters:
|
|||
// - addr: string - The hostname or IP.
|
|||
//
|
|||
// Returns:
|
|||
// - string: The resolved IP address.
|
|||
// - error: An error if resolution fails.
|
|||
func resolveBindAddress(addr string) (string, error) { |
|||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr+":0") // Resolve the address
|
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
return tcpAddr.IP.String(), nil // Return resolved IP
|
|||
} |
|||
|
|||
// isValidHostname checks if a given hostname is valid.
|
|||
// Pretty ugly for now, needs to be refactored.
|
|||
//
|
|||
// Parameters:
|
|||
// - hostname: string - The hostname to validate.
|
|||
//
|
|||
// Returns:
|
|||
// - bool: True if valid, false otherwise.
|
|||
func isValidHostname(hostname string) bool { |
|||
// Must contain at least one dot (.) unless it's "localhost"
|
|||
if hostname == "localhost" { |
|||
return true |
|||
} |
|||
return strings.Contains(hostname, ".") |
|||
} |
|||
|
|||
// LoginToBase64 encodes a username and password into a base64-encoded string in "username:password" format.
|
|||
// This is commonly used for HTTP Basic Authentication.
|
|||
//
|
|||
// Parameters:
|
|||
// - username: string - The username to encode.
|
|||
// - password: string - The password to encode.
|
|||
//
|
|||
// Returns:
|
|||
// - string: The base64-encoded "username:password" string.
|
|||
func LoginToBase64(username, password string) string { |
|||
return base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) |
|||
} |
|||
|
|||
// CheckIfname validates a network interface name according to the following rules:
|
|||
// - Must not be empty.
|
|||
// - Should not exceed 15 characters (warning if it does).
|
|||
// - Should not contain non-ASCII characters (warning if it does).
|
|||
// - Should not contain invalid characters: '/', whitespace, or control characters.
|
|||
//
|
|||
// Parameters:
|
|||
// - name: string - The interface name to validate.
|
|||
//
|
|||
// Returns:
|
|||
// - error: An error if the name is invalid, or nil if valid.
|
|||
func CheckIfname(name string) error { |
|||
if name == "" { |
|||
return errors.New("interface name cannot be empty") |
|||
} |
|||
|
|||
if len(name) >= 16 { |
|||
log.Printf("Warning: interface name '%s' is longer than %d characters", name, 16-1) |
|||
} |
|||
|
|||
var invalidChar bool |
|||
var hasWhitespace bool |
|||
|
|||
for _, r := range name { |
|||
if r > 127 { |
|||
invalidChar = true |
|||
break |
|||
} |
|||
if r == '/' || r == ' ' || strings.ContainsRune("\t\n\v\f\r", r) { |
|||
hasWhitespace = true |
|||
break |
|||
} |
|||
} |
|||
|
|||
if invalidChar { |
|||
log.Printf("Warning: interface name contains non-ASCII character") |
|||
} |
|||
|
|||
if hasWhitespace { |
|||
return errors.New("interface name contains invalid character: '/' or whitespace") |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
@ -0,0 +1,65 @@ |
|||
//go:build windows
|
|||
|
|||
package internal |
|||
|
|||
import ( |
|||
"fmt" |
|||
"log" |
|||
"os/exec" |
|||
) |
|||
|
|||
func SetIPv4Address(ifaceName, ipAddr, mask string) error { |
|||
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", |
|||
fmt.Sprintf("name=\"%s\"", ifaceName), |
|||
"static", ipAddr, mask) |
|||
|
|||
output, err := cmd.CombinedOutput() |
|||
if err != nil { |
|||
return fmt.Errorf("%s", output) |
|||
} |
|||
|
|||
log.Println("IPv4 address set successfully:", ipAddr) |
|||
return nil |
|||
} |
|||
|
|||
func SetIPv6Address(ifaceName, ipAddr, mask string) error { |
|||
cmd := exec.Command("netsh", "interface", "ipv6", "set", "address", |
|||
fmt.Sprintf("interface=\"%s\"", ifaceName), |
|||
ipAddr+"/"+mask) |
|||
|
|||
output, err := cmd.CombinedOutput() |
|||
if err != nil { |
|||
return fmt.Errorf("%s", output) |
|||
} |
|||
|
|||
log.Println("IPv6 address set successfully:", ipAddr) |
|||
return nil |
|||
} |
|||
|
|||
func SetIPv4MTU(ifaceName string, mtu int) error { |
|||
cmd := exec.Command("netsh", "interface", "ipv4", "set", "subinterface", |
|||
fmt.Sprintf("\"%s\"", ifaceName), |
|||
fmt.Sprintf("mtu=%d", mtu)) |
|||
|
|||
output, err := cmd.CombinedOutput() |
|||
if err != nil { |
|||
return fmt.Errorf("%s", output) |
|||
} |
|||
|
|||
log.Println("IPv4 MTU set successfully:", mtu) |
|||
return nil |
|||
} |
|||
|
|||
func SetIPv6MTU(ifaceName string, mtu int) error { |
|||
cmd := exec.Command("netsh", "interface", "ipv6", "set", "subinterface", |
|||
fmt.Sprintf("\"%s\"", ifaceName), |
|||
fmt.Sprintf("mtu=%d", mtu)) |
|||
|
|||
output, err := cmd.CombinedOutput() |
|||
if err != nil { |
|||
return fmt.Errorf("%s", output) |
|||
} |
|||
|
|||
log.Println("IPv6 MTU set successfully:", mtu) |
|||
return nil |
|||
} |
|||
@ -0,0 +1,56 @@ |
|||
package models |
|||
|
|||
// Known error messages from the API
|
|||
const ( |
|||
InvalidPublicKey = "Invalid public key" |
|||
) |
|||
|
|||
type APIError struct { |
|||
// not sure what type this is, so we will settle for interface{}
|
|||
// for now
|
|||
Result interface{} `json:"result"` |
|||
Success bool `json:"success"` |
|||
Errors []ErrorInfo `json:"errors"` |
|||
Messages []string `json:"messages"` |
|||
} |
|||
|
|||
type ErrorInfo struct { |
|||
Code int `json:"code"` |
|||
Message string `json:"message"` |
|||
} |
|||
|
|||
// ErrorsAsString returns a string representation of the errors in the APIError.
|
|||
// It concatenates the error messages into a single string, separated by semicolons.
|
|||
//
|
|||
// Parameters:
|
|||
// - separator: string - The string to use as a separator between error messages.
|
|||
//
|
|||
// Returns:
|
|||
// - string: A string containing all error messages, separated by the specified separator.
|
|||
func (e *APIError) ErrorsAsString(separator string) string { |
|||
var result string |
|||
for _, err := range e.Errors { |
|||
result += err.Message + separator |
|||
} |
|||
if len(result) > 0 { |
|||
return result[:len(result)-len(separator)] |
|||
} |
|||
return result |
|||
} |
|||
|
|||
// HasErrorMessage checks if the APIError contains a specific error message.
|
|||
// It returns true if the error message is found, otherwise false.
|
|||
//
|
|||
// Parameters:
|
|||
// - message: string - The error message to check for.
|
|||
//
|
|||
// Returns:
|
|||
// - bool: true if the error message is found, otherwise false.
|
|||
func (e *APIError) HasErrorMessage(message string) bool { |
|||
for _, err := range e.Errors { |
|||
if err.Message == message { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
@ -0,0 +1,8 @@ |
|||
package models |
|||
|
|||
type DeviceUpdate struct { |
|||
Key string `json:"key"` |
|||
KeyType string `json:"key_type"` |
|||
TunType string `json:"tunnel_type"` |
|||
Name string `json:"name,omitempty"` |
|||
} |
|||
@ -0,0 +1,102 @@ |
|||
package models |
|||
|
|||
type Registration struct { |
|||
Key string `json:"key"` |
|||
InstallID string `json:"install_id"` |
|||
FcmToken string `json:"fcm_token"` |
|||
Tos string `json:"tos"` |
|||
Model string `json:"model"` |
|||
Serial string `json:"serial_number"` |
|||
OsVersion string `json:"os_version"` |
|||
KeyType string `json:"key_type"` |
|||
TunType string `json:"tunnel_type"` |
|||
Locale string `json:"locale"` |
|||
} |
|||
|
|||
type AccountData struct { |
|||
ID string `json:"id"` |
|||
Type string `json:"type"` |
|||
Model string `json:"model"` |
|||
Name string `json:"name"` |
|||
Key string `json:"key"` |
|||
KeyType string `json:"key_type"` |
|||
TunType string `json:"tunnel_type"` |
|||
Account Account `json:"account"` |
|||
Config Config `json:"config"` |
|||
// WarpEnabled not set for ZeroTier
|
|||
WarpEnabled bool `json:"warp_enabled,omitempty"` |
|||
// Waitlist not set for ZeroTier
|
|||
Waitlist bool `json:"waitlist_enabled,omitempty"` |
|||
Created string `json:"created"` |
|||
Updated string `json:"updated"` |
|||
// Tos not set for ZeroTier
|
|||
Tos string `json:"tos,omitempty"` |
|||
// Place not set for ZeroTier
|
|||
Place int `json:"place,omitempty"` |
|||
Locale string `json:"locale"` |
|||
// Enabled not set for ZeroTier
|
|||
Enabled bool `json:"enabled,omitempty"` |
|||
InstallID string `json:"install_id"` |
|||
// Token only set for /reg call
|
|||
Token string `json:"token,omitempty"` |
|||
FcmToken string `json:"fcm_token"` |
|||
// SerialNumber not set for ZeroTier
|
|||
SerialNumber string `json:"serial_number,omitempty"` |
|||
Policy Policy `json:"policy"` |
|||
} |
|||
|
|||
type Account struct { |
|||
ID string `json:"id"` |
|||
AccountType string `json:"account_type"` |
|||
// Created not set for ZeroTier
|
|||
Created string `json:"created,omitempty"` |
|||
// Updated not set for ZeroTier
|
|||
Updated string `json:"updated,omitempty"` |
|||
// Managed only set for ZeroTier
|
|||
Managed string `json:"managed,omitempty"` |
|||
// Organization only set for ZeroTier
|
|||
Organization string `json:"organization,omitempty"` |
|||
// PremiumData not set for ZeroTier
|
|||
PremiumData int `json:"premium_data,omitempty"` |
|||
// Quota not set for ZeroTier
|
|||
Quota int `json:"quota,omitempty"` |
|||
// WarpPlus not set for ZeroTier
|
|||
WarpPlus bool `json:"warp_plus,omitempty"` |
|||
// ReferralCode not set for ZeroTier
|
|||
ReferralCount int `json:"referral_count,omitempty"` |
|||
// ReferralRenewalCount not set for ZeroTier
|
|||
ReferralRenewalCount int `json:"referral_renewal_countdown,omitempty"` |
|||
// Role not set for ZeroTier
|
|||
Role string `json:"role,omitempty"` |
|||
// License not set for ZeroTier
|
|||
License string `json:"license,omitempty"` |
|||
} |
|||
|
|||
type Config struct { |
|||
ClientID string `json:"client_id"` |
|||
Peers []Peer `json:"peers"` |
|||
Interface struct { |
|||
Addresses struct { |
|||
V4 string `json:"v4"` |
|||
V6 string `json:"v6"` |
|||
} `json:"addresses"` |
|||
} `json:"interface"` |
|||
Services struct { |
|||
HTTPProxy string `json:"http_proxy"` |
|||
} `json:"services"` |
|||
} |
|||
|
|||
type Peer struct { |
|||
PublicKey string `json:"public_key"` |
|||
Endpoint struct { |
|||
V4 string `json:"v4"` |
|||
V6 string `json:"v6"` |
|||
Host string `json:"host"` |
|||
Ports []int `json:"ports"` |
|||
} `json:"endpoint"` |
|||
} |
|||
|
|||
type Policy struct { |
|||
TunnelProtocol string `json:"tunnel_protocol"` |
|||
// TODO: add ZeroTier fields
|
|||
} |
|||
@ -0,0 +1,305 @@ |
|||
// Package proxy implements a mixed SOCKS5/HTTP proxy server that listens on a single
|
|||
// port and automatically detects the protocol from the first byte of each connection.
|
|||
package proxy |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"io" |
|||
"log" |
|||
"net" |
|||
"net/http" |
|||
"net/netip" |
|||
"strings" |
|||
"sync/atomic" |
|||
"time" |
|||
|
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/internal" |
|||
"github.com/things-go/go-socks5" |
|||
"golang.zx2c4.com/wireguard/tun/netstack" |
|||
) |
|||
|
|||
// netResolverAdapter wraps a *net.Resolver so it satisfies socks5.NameResolver.
|
|||
// Using *net.Resolver (via NewNetstackResolver) instead of TunnelDNSResolver gives us
|
|||
// Go's built-in UDP retry / exponential-backoff logic, which is far more resilient to
|
|||
// the packet loss inherent in a UDP-over-MASQUE-over-VK-TURN chain.
|
|||
type netResolverAdapter struct { |
|||
r *net.Resolver |
|||
} |
|||
|
|||
func (a netResolverAdapter) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { |
|||
dnsCtx, dnsCancel := context.WithTimeout(context.Background(), 45*time.Second) |
|||
defer dnsCancel() |
|||
ips, err := a.r.LookupIP(dnsCtx, "ip", name) |
|||
if err != nil || len(ips) == 0 { |
|||
log.Printf("[Warp] SOCKS5: DNS failed for %s: %v", name, err) |
|||
return ctx, nil, fmt.Errorf("DNS lookup %s: %w", name, err) |
|||
} |
|||
return ctx, ips[0], nil |
|||
} |
|||
|
|||
// MixedProxy listens on a single address and routes incoming connections to either
|
|||
// a SOCKS5 handler or an HTTP/CONNECT handler based on the first byte received.
|
|||
// Both protocols resolve DNS through the MASQUE tunnel via *net.Resolver,
|
|||
// which has built-in retry logic suitable for high-latency/lossy relay paths.
|
|||
type MixedProxy struct { |
|||
addr string |
|||
tunNet *netstack.Net |
|||
resolver *net.Resolver // tunnel-aware resolver (used by both HTTP and SOCKS5)
|
|||
socks netResolverAdapter // socks5.NameResolver adapter around the same resolver
|
|||
ready atomic.Bool // whether the tunnel is fully connected
|
|||
} |
|||
|
|||
// NewMixedProxy creates a new MixedProxy.
|
|||
//
|
|||
// Parameters:
|
|||
// - addr: The address to listen on (e.g. "127.0.0.1:4080").
|
|||
// - tunNet: The netstack network (from the MASQUE tunnel).
|
|||
// - dnsAddrs: DNS servers to use inside the tunnel (e.g. 162.159.36.1).
|
|||
// - localDNS: if true, use the system resolver instead of routing DNS through the tunnel.
|
|||
func NewMixedProxy(addr string, tunNet *netstack.Net, dnsAddrs []netip.Addr, localDNS bool) *MixedProxy { |
|||
var resolver *net.Resolver |
|||
if localDNS { |
|||
resolver = &net.Resolver{PreferGo: false} |
|||
log.Printf("[Warp] Using local (system) DNS resolver") |
|||
} else { |
|||
// Tunnel resolver — DNS goes through MASQUE to 162.159.36.1.
|
|||
resolver = internal.NewNetstackResolver(tunNet, dnsAddrs) |
|||
} |
|||
return &MixedProxy{ |
|||
addr: addr, |
|||
tunNet: tunNet, |
|||
resolver: resolver, |
|||
socks: netResolverAdapter{r: resolver}, |
|||
} |
|||
} |
|||
|
|||
// SetReady updates the tunnel connection state.
|
|||
// When false, the proxy quickly rejects pending connections.
|
|||
func (m *MixedProxy) SetReady(ready bool) { |
|||
m.ready.Store(ready) |
|||
} |
|||
|
|||
// ListenAndServe starts the mixed proxy server and blocks until the context is cancelled.
|
|||
func (m *MixedProxy) ListenAndServe(ctx context.Context) error { |
|||
listener, err := net.Listen("tcp", m.addr) |
|||
if err != nil { |
|||
return fmt.Errorf("mixed proxy: listen on %s: %w", m.addr, err) |
|||
} |
|||
defer listener.Close() |
|||
|
|||
context.AfterFunc(ctx, func() { _ = listener.Close() }) |
|||
|
|||
log.Printf("[Warp] Mixed proxy (SOCKS5+HTTP) listening on %s", m.addr) |
|||
|
|||
socksServer := socks5.NewServer( |
|||
socks5.WithLogger(socks5.NewLogger(log.New(io.Discard, "", 0))), |
|||
socks5.WithDial(func(sCtx context.Context, network, addr string) (net.Conn, error) { |
|||
return m.tunNet.DialContext(sCtx, network, addr) |
|||
}), |
|||
socks5.WithResolver(m.socks), |
|||
) |
|||
|
|||
for { |
|||
conn, err := listener.Accept() |
|||
if err != nil { |
|||
select { |
|||
case <-ctx.Done(): |
|||
return nil |
|||
default: |
|||
log.Printf("[Warp] Mixed proxy accept error: %v", err) |
|||
continue |
|||
} |
|||
} |
|||
go m.handleConn(ctx, conn, socksServer) |
|||
} |
|||
} |
|||
|
|||
// handleConn peeks at the first byte to detect protocol: 0x05 = SOCKS5, else HTTP.
|
|||
func (m *MixedProxy) handleConn(ctx context.Context, conn net.Conn, socksServer *socks5.Server) { |
|||
defer func() { |
|||
if r := recover(); r != nil { |
|||
log.Printf("[Warp] Mixed proxy panic: %v", r) |
|||
conn.Close() |
|||
} |
|||
}() |
|||
|
|||
buf := make([]byte, 1) |
|||
if _, err := io.ReadFull(conn, buf); err != nil { |
|||
conn.Close() |
|||
return |
|||
} |
|||
|
|||
peeked := &peekedConn{Conn: conn, buf: buf} |
|||
|
|||
if buf[0] == 0x05 { |
|||
if !m.ready.Load() { |
|||
log.Printf("[Warp] Rejecting SOCKS5 from %s (tunnel not ready)", conn.RemoteAddr()) |
|||
conn.Close() |
|||
return |
|||
} |
|||
if err := socksServer.ServeConn(peeked); err != nil { |
|||
log.Printf("[Warp] SOCKS5 error: %v", err) |
|||
} |
|||
return |
|||
} |
|||
|
|||
if !m.ready.Load() { |
|||
conn.Close() |
|||
return |
|||
} |
|||
|
|||
m.handleHTTP(ctx, peeked) |
|||
} |
|||
|
|||
// handleHTTP serves a single HTTP/CONNECT connection.
|
|||
func (m *MixedProxy) handleHTTP(ctx context.Context, conn net.Conn) { |
|||
defer conn.Close() |
|||
|
|||
server := &http.Server{ |
|||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.Method == http.MethodConnect { |
|||
m.handleHTTPConnect(w, r) |
|||
} else { |
|||
m.handleHTTPPlain(w, r) |
|||
} |
|||
}), |
|||
} |
|||
_ = server.Serve(&oneConnListener{conn: conn}) |
|||
} |
|||
|
|||
func (m *MixedProxy) handleHTTPConnect(w http.ResponseWriter, r *http.Request) { |
|||
host, port, err := net.SplitHostPort(r.Host) |
|||
if err != nil { |
|||
http.Error(w, "invalid host", http.StatusBadRequest) |
|||
return |
|||
} |
|||
|
|||
dnsCtx, dnsCancel := context.WithTimeout(context.Background(), 45*time.Second) |
|||
defer dnsCancel() |
|||
ips, err := m.resolver.LookupIP(dnsCtx, "ip", host) |
|||
if err != nil || len(ips) == 0 { |
|||
log.Printf("[Warp] HTTP CONNECT: DNS failed for %s: %v", host, err) |
|||
http.Error(w, fmt.Sprintf("DNS failed for %s: %v", host, err), http.StatusServiceUnavailable) |
|||
return |
|||
} |
|||
destAddr := net.JoinHostPort(ips[0].String(), port) |
|||
|
|||
destConn, err := m.tunNet.DialContext(r.Context(), "tcp", destAddr) |
|||
if err != nil { |
|||
log.Printf("[Warp] HTTP CONNECT: tunnel dial failed for %s: %v", destAddr, err) |
|||
http.Error(w, fmt.Sprintf("tunnel dial failed: %v", err), http.StatusServiceUnavailable) |
|||
return |
|||
} |
|||
|
|||
hj, ok := w.(http.Hijacker) |
|||
if !ok { |
|||
http.Error(w, "hijacking not supported", http.StatusInternalServerError) |
|||
destConn.Close() |
|||
return |
|||
} |
|||
clientConn, _, err := hj.Hijack() |
|||
if err != nil { |
|||
destConn.Close() |
|||
return |
|||
} |
|||
|
|||
_, _ = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) |
|||
go func() { |
|||
defer destConn.Close() |
|||
defer clientConn.Close() |
|||
_, _ = io.Copy(destConn, clientConn) |
|||
}() |
|||
_, _ = io.Copy(clientConn, destConn) |
|||
} |
|||
|
|||
// handleHTTPPlain handles plain HTTP proxy requests (GET, POST, etc.).
|
|||
// Mirrors the working implementation from the old http-proxy.
|
|||
func (m *MixedProxy) handleHTTPPlain(w http.ResponseWriter, r *http.Request) { |
|||
if !strings.HasPrefix(r.RequestURI, "http") { |
|||
http.Error(w, "only absolute URIs supported", http.StatusBadRequest) |
|||
return |
|||
} |
|||
|
|||
client := &http.Client{ |
|||
Transport: &http.Transport{ |
|||
DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { |
|||
h, p, err := net.SplitHostPort(addr) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("invalid address: %w", err) |
|||
} |
|||
dnsCtx, dnsCancel := context.WithTimeout(context.Background(), 45*time.Second) |
|||
defer dnsCancel() |
|||
ips, err := m.resolver.LookupIP(dnsCtx, "ip", h) |
|||
if err != nil || len(ips) == 0 { |
|||
log.Printf("[Warp] HTTP plain: DNS failed for %s: %v", h, err) |
|||
return nil, fmt.Errorf("DNS failed for %s: %w", h, err) |
|||
} |
|||
return m.tunNet.DialContext(dialCtx, network, net.JoinHostPort(ips[0].String(), p)) |
|||
}, |
|||
}, |
|||
} |
|||
|
|||
req, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), r.Body) |
|||
if err != nil { |
|||
http.Error(w, "invalid request", http.StatusBadRequest) |
|||
return |
|||
} |
|||
req.Header = r.Header.Clone() |
|||
|
|||
resp, err := client.Do(req) |
|||
if err != nil { |
|||
log.Printf("[Warp] HTTP plain: upstream error for %s: %v", r.URL.Host, err) |
|||
http.Error(w, fmt.Sprintf("upstream error: %v", err), http.StatusServiceUnavailable) |
|||
return |
|||
} |
|||
defer resp.Body.Close() |
|||
|
|||
for k, vv := range resp.Header { |
|||
for _, v := range vv { |
|||
w.Header().Add(k, v) |
|||
} |
|||
} |
|||
w.WriteHeader(resp.StatusCode) |
|||
_, _ = io.Copy(w, resp.Body) |
|||
} |
|||
|
|||
// peekedConn wraps a net.Conn and re-injects already-read bytes into the stream.
|
|||
type peekedConn struct { |
|||
net.Conn |
|||
buf []byte |
|||
offset int |
|||
} |
|||
|
|||
func (p *peekedConn) Read(b []byte) (int, error) { |
|||
if p.offset < len(p.buf) { |
|||
n := copy(b, p.buf[p.offset:]) |
|||
p.offset += n |
|||
return n, nil |
|||
} |
|||
return p.Conn.Read(b) |
|||
} |
|||
|
|||
// oneConnListener serves a single pre-accepted connection to http.Server.Serve.
|
|||
type oneConnListener struct { |
|||
conn net.Conn |
|||
done chan struct{} |
|||
} |
|||
|
|||
func (l *oneConnListener) Accept() (net.Conn, error) { |
|||
if l.done == nil { |
|||
l.done = make(chan struct{}) |
|||
return l.conn, nil |
|||
} |
|||
<-l.done |
|||
return nil, fmt.Errorf("oneConnListener: done") |
|||
} |
|||
|
|||
func (l *oneConnListener) Close() error { |
|||
if l.done != nil { |
|||
close(l.done) |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (l *oneConnListener) Addr() net.Addr { return l.conn.LocalAddr() } |
|||
@ -0,0 +1,256 @@ |
|||
// Package warp provides the Warp (MASQUE/Cloudflare) mode runner for the vk-turn-proxy client.
|
|||
// It is activated by the -warp flag in the main binary and reuses existing vk-turn flags
|
|||
// for VK TURN relay integration (-vk-link, -listen, etc.).
|
|||
package warp |
|||
|
|||
import ( |
|||
"context" |
|||
"encoding/base64" |
|||
"fmt" |
|||
"log" |
|||
"net" |
|||
"net/netip" |
|||
"os" |
|||
"path/filepath" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/api" |
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/config" |
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/internal" |
|||
"github.com/cacggghp/vk-turn-proxy/client/warp/proxy" |
|||
"golang.zx2c4.com/wireguard/tun/netstack" |
|||
) |
|||
|
|||
// RunnerConfig holds all parameters needed to start the Warp mode.
|
|||
type RunnerConfig struct { |
|||
// ConfigPath is the path to the Warp config.json.
|
|||
// If empty and the file is not found at the default path, registration is triggered.
|
|||
ConfigPath string |
|||
// ProxyAddr is the address for the mixed SOCKS5/HTTP proxy (e.g. "127.0.0.1:4080").
|
|||
ProxyAddr string |
|||
// GetRelayConn is an optional function that provides a pre-allocated TURN relay connection.
|
|||
// If nil, Warp connects directly to Cloudflare's MASQUE endpoint.
|
|||
GetRelayConn api.GetRelayConnFunc |
|||
// ConnectPort is the port for the MASQUE QUIC connection (default 443).
|
|||
ConnectPort int |
|||
// UseIPv6 selects the IPv6 endpoint instead of IPv4 for the MASQUE connection.
|
|||
UseIPv6 bool |
|||
// KeepalivePeriod is the QUIC keepalive interval.
|
|||
KeepalivePeriod time.Duration |
|||
// InitialPacketSize is the initial QUIC packet size.
|
|||
InitialPacketSize uint16 |
|||
// ReconnectDelay is the delay between tunnel reconnect attempts.
|
|||
ReconnectDelay time.Duration |
|||
// MTU is the MTU for the virtual TUN device.
|
|||
MTU int |
|||
// LocalDNS skips tunnel DNS and uses the system resolver.
|
|||
// Useful when 162.159.36.1 is unreachable over the TURN relay.
|
|||
LocalDNS bool |
|||
// Debug enables verbose logging in the warp/api package.
|
|||
Debug bool |
|||
} |
|||
|
|||
// DefaultRunnerConfig returns a RunnerConfig with sensible defaults.
|
|||
func DefaultRunnerConfig() RunnerConfig { |
|||
return RunnerConfig{ |
|||
ConfigPath: "config.json", |
|||
ProxyAddr: "127.0.0.1:4080", |
|||
ConnectPort: 443, |
|||
UseIPv6: false, |
|||
KeepalivePeriod: 30 * time.Second, |
|||
InitialPacketSize: 1242, |
|||
ReconnectDelay: 1 * time.Second, |
|||
MTU: 1200, // Lowered to avoid fragmentation over TURN relay
|
|||
} |
|||
} |
|||
|
|||
// Run starts the Warp-in-VK-TURN mode.
|
|||
// It handles config loading/registration, then starts the MASQUE tunnel and mixed proxy.
|
|||
func Run(ctx context.Context, cfg RunnerConfig) error { |
|||
// 1. Resolve config path to absolute
|
|||
cfgPath, err := resolveConfigPath(cfg.ConfigPath) |
|||
if err != nil { |
|||
return fmt.Errorf("warp: resolve config path: %w", err) |
|||
} |
|||
|
|||
// 2. Try to load config
|
|||
if err := config.LoadConfig(cfgPath); err != nil { |
|||
if cfg.ConfigPath != "" && cfg.ConfigPath != "config.json" { |
|||
// User explicitly specified a config path — error out
|
|||
return fmt.Errorf("warp: config file not found at %s: %w", cfgPath, err) |
|||
} |
|||
// Default path not found — start interactive registration
|
|||
log.Printf("[Warp] Config not found at %s. Starting registration...", cfgPath) |
|||
if err := runInteractiveRegistration(cfgPath); err != nil { |
|||
return fmt.Errorf("warp: registration failed: %w", err) |
|||
} |
|||
} |
|||
|
|||
// 3. Prepare TLS keys from config
|
|||
privKey, err := config.AppConfig.GetEcPrivateKey() |
|||
if err != nil { |
|||
return fmt.Errorf("warp: get private key: %w", err) |
|||
} |
|||
peerPubKey, err := config.AppConfig.GetEcEndpointPublicKey() |
|||
if err != nil { |
|||
return fmt.Errorf("warp: get peer public key: %w", err) |
|||
} |
|||
cert, err := internal.GenerateCert(privKey, &privKey.PublicKey) |
|||
if err != nil { |
|||
return fmt.Errorf("warp: generate cert: %w", err) |
|||
} |
|||
|
|||
tlsConfig, err := api.PrepareTlsConfig(privKey, peerPubKey, cert, internal.ConnectSNI) |
|||
if err != nil { |
|||
return fmt.Errorf("warp: prepare TLS config: %w", err) |
|||
} |
|||
|
|||
// 4. Determine MASQUE endpoint
|
|||
connectPort := cfg.ConnectPort |
|||
if connectPort <= 0 { |
|||
connectPort = 443 |
|||
} |
|||
var endpoint *net.UDPAddr |
|||
if cfg.UseIPv6 { |
|||
addr := net.JoinHostPort(config.AppConfig.EndpointV6, fmt.Sprint(connectPort)) |
|||
endpoint, err = net.ResolveUDPAddr("udp", addr) |
|||
if err != nil { |
|||
return fmt.Errorf("warp: resolve IPv6 endpoint: %w", err) |
|||
} |
|||
} else { |
|||
addr := net.JoinHostPort(config.AppConfig.EndpointV4, fmt.Sprint(connectPort)) |
|||
endpoint, err = net.ResolveUDPAddr("udp", addr) |
|||
if err != nil { |
|||
return fmt.Errorf("warp: resolve IPv4 endpoint: %w", err) |
|||
} |
|||
if ip4 := endpoint.IP.To4(); ip4 != nil { |
|||
endpoint.IP = ip4 |
|||
} |
|||
} |
|||
|
|||
// DNS addresses: Cloudflare WARP intentionally blocks/drops most regular UDP port 53 traffic
|
|||
// over MASQUE tunnels to public servers (like 9.9.9.9 or 1.1.1.1) to enforce their DoH proxy.
|
|||
// You MUST use their internal designated DNS forwarder: 162.159.36.1
|
|||
dnsAddrs := []netip.Addr{ |
|||
netip.MustParseAddr("162.159.36.1"), |
|||
netip.MustParseAddr("1.1.1.1"), |
|||
netip.MustParseAddr("1.0.0.1"), |
|||
} |
|||
var localAddresses []netip.Addr |
|||
parseInternalIP := func(s string) (netip.Addr, error) { |
|||
// Strip mask if present (e.g. 172.16.0.2/32)
|
|||
if i := strings.Index(s, "/"); i != -1 { |
|||
s = s[:i] |
|||
} |
|||
return netip.ParseAddr(s) |
|||
} |
|||
if v4, err := parseInternalIP(config.AppConfig.IPv4); err == nil { |
|||
localAddresses = append(localAddresses, v4) |
|||
} |
|||
if v6, err := parseInternalIP(config.AppConfig.IPv6); err == nil { |
|||
localAddresses = append(localAddresses, v6) |
|||
} |
|||
|
|||
api.Verbose = cfg.Debug |
|||
tunDev, tunNet, err := netstack.CreateNetTUN(localAddresses, dnsAddrs, cfg.MTU) |
|||
if err != nil { |
|||
return fmt.Errorf("warp: create virtual TUN: %w", err) |
|||
} |
|||
defer tunDev.Close() |
|||
|
|||
// 6. Init mixed proxy so we can pass its SetReady callback
|
|||
mp := proxy.NewMixedProxy(cfg.ProxyAddr, tunNet, dnsAddrs, cfg.LocalDNS) |
|||
|
|||
// 7. Start tunnel maintenance in background
|
|||
log.Printf("[Warp] Starting MASQUE tunnel to %s (via TURN: %v)", endpoint, cfg.GetRelayConn != nil) |
|||
go api.MaintainTunnel( |
|||
ctx, |
|||
tlsConfig, |
|||
cfg.KeepalivePeriod, |
|||
cfg.InitialPacketSize, |
|||
endpoint, |
|||
api.NewNetstackAdapter(tunDev), |
|||
cfg.MTU, |
|||
cfg.ReconnectDelay, |
|||
cfg.GetRelayConn, |
|||
mp.SetReady, |
|||
) |
|||
|
|||
// 8. Start mixed proxy listener (blocks until cancelled)
|
|||
// Both SOCKS5 and HTTP resolve DNS through the MASQUE tunnel via TunnelDNSResolver,
|
|||
// then dial tunNet with the resolved IP — matching the working httpproxy.go pattern.
|
|||
return mp.ListenAndServe(ctx) |
|||
} |
|||
|
|||
// resolveConfigPath returns the absolute path for the config file.
|
|||
// If the path is relative, it is resolved relative to the executable's directory.
|
|||
func resolveConfigPath(cfgPath string) (string, error) { |
|||
if filepath.IsAbs(cfgPath) { |
|||
return cfgPath, nil |
|||
} |
|||
// Try CWD first
|
|||
if _, err := os.Stat(cfgPath); err == nil { |
|||
abs, err := filepath.Abs(cfgPath) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
return abs, nil |
|||
} |
|||
// Fall back to executable directory (useful on Android/embedded)
|
|||
exePath, err := os.Executable() |
|||
if err != nil { |
|||
return cfgPath, nil //nolint:nilerr — best effort
|
|||
} |
|||
return filepath.Join(filepath.Dir(exePath), cfgPath), nil |
|||
} |
|||
|
|||
// runInteractiveRegistration runs the interactive Cloudflare WARP registration flow.
|
|||
// It asks the user to accept TOS and choose a device name, then saves the config.
|
|||
func runInteractiveRegistration(cfgPath string) error { |
|||
log.Printf("[Warp] === Cloudflare WARP Registration ===") |
|||
|
|||
// Register (will prompt for TOS internally inside api.Register)
|
|||
accountData, err := api.Register(internal.DefaultModel, internal.DefaultLocale, "", false /* acceptTos — prompt inside */) |
|||
if err != nil { |
|||
return fmt.Errorf("register: %w", err) |
|||
} |
|||
|
|||
fmt.Print("[Warp] Enter device name (leave empty for default): ") |
|||
var deviceName string |
|||
_, _ = fmt.Scanln(&deviceName) |
|||
|
|||
privKey, pubKey, err := internal.GenerateEcKeyPair() |
|||
if err != nil { |
|||
return fmt.Errorf("generate key pair: %w", err) |
|||
} |
|||
|
|||
log.Printf("[Warp] Enrolling device key...") |
|||
updatedAccountData, apiErr, err := api.EnrollKey(accountData, pubKey, deviceName) |
|||
if err != nil { |
|||
if apiErr != nil { |
|||
return fmt.Errorf("enroll key: %v (API errors: %s)", err, apiErr.ErrorsAsString("; ")) |
|||
} |
|||
return fmt.Errorf("enroll key: %w", err) |
|||
} |
|||
|
|||
log.Printf("[Warp] Registration successful. Saving config to %s...", cfgPath) |
|||
config.AppConfig = config.Config{ |
|||
PrivateKey: base64.StdEncoding.EncodeToString(privKey), |
|||
EndpointV4: updatedAccountData.Config.Peers[0].Endpoint.V4[:len(updatedAccountData.Config.Peers[0].Endpoint.V4)-2], |
|||
EndpointV6: updatedAccountData.Config.Peers[0].Endpoint.V6[1 : len(updatedAccountData.Config.Peers[0].Endpoint.V6)-3], |
|||
EndpointPubKey: updatedAccountData.Config.Peers[0].PublicKey, |
|||
License: updatedAccountData.Account.License, |
|||
ID: updatedAccountData.ID, |
|||
AccessToken: accountData.Token, |
|||
IPv4: updatedAccountData.Config.Interface.Addresses.V4, |
|||
IPv6: updatedAccountData.Config.Interface.Addresses.V6, |
|||
} |
|||
|
|||
if err := config.AppConfig.SaveConfig(cfgPath); err != nil { |
|||
return fmt.Errorf("save config: %w", err) |
|||
} |
|||
config.ConfigLoaded = true |
|||
log.Printf("[Warp] Config saved successfully.") |
|||
return nil |
|||
} |
|||
Loading…
Reference in new issue