Add support for Multipath TCP

This commit is contained in:
Neil Alexander 2023-08-12 17:50:24 +01:00
parent fec96a38a4
commit ed8991571a
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
9 changed files with 108 additions and 48 deletions

View File

@ -24,6 +24,7 @@ type PeerEntry struct {
PublicKey string `json:"key"`
Port uint64 `json:"port"`
Priority uint64 `json:"priority"`
Multipath bool `json:"multipath,omitempty"`
RXBytes DataUnit `json:"bytes_recvd,omitempty"`
TXBytes DataUnit `json:"bytes_sent,omitempty"`
Uptime float64 `json:"uptime,omitempty"`
@ -37,14 +38,15 @@ func (a *AdminSocket) getPeersHandler(req *GetPeersRequest, res *GetPeersRespons
res.Peers = make([]PeerEntry, 0, len(peers))
for _, p := range peers {
peer := PeerEntry{
Port: p.Port,
Up: p.Up,
Inbound: p.Inbound,
Priority: uint64(p.Priority), // can't be uint8 thanks to gobind
URI: p.URI,
RXBytes: DataUnit(p.RXBytes),
TXBytes: DataUnit(p.TXBytes),
Uptime: p.Uptime.Seconds(),
Port: p.Port,
Up: p.Up,
Inbound: p.Inbound,
Priority: uint64(p.Priority), // can't be uint8 thanks to gobind
Multipath: p.Multipath,
URI: p.URI,
RXBytes: DataUnit(p.RXBytes),
TXBytes: DataUnit(p.TXBytes),
Uptime: p.Uptime.Seconds(),
}
if p.Latency > 0 {
peer.Latency = p.Latency

View File

@ -30,6 +30,7 @@ type PeerInfo struct {
Coords []uint64
Port uint64
Priority uint8
Multipath bool
RXBytes uint64
TXBytes uint64
Uptime time.Duration
@ -87,6 +88,7 @@ func (c *Core) GetPeers() []PeerInfo {
peerinfo.RXBytes = atomic.LoadUint64(&c.rx)
peerinfo.TXBytes = atomic.LoadUint64(&c.tx)
peerinfo.Uptime = time.Since(c.up)
peerinfo.Multipath = isMPTCP(c)
}
if p, ok := conns[conn]; ok {
peerinfo.Key = p.Key

View File

@ -44,7 +44,7 @@ type links struct {
type linkProtocol interface {
dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error)
listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error)
listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error)
}
// linkInfo is used as a map key
@ -72,6 +72,7 @@ type linkOptions struct {
tlsSNI string
password []byte
maxBackoff time.Duration
multipath bool
}
type Listener struct {
@ -140,6 +141,7 @@ const ErrLinkPinnedKeyInvalid = linkError("pinned public key is invalid")
const ErrLinkPasswordInvalid = linkError("password is invalid")
const ErrLinkUnrecognisedSchema = linkError("link schema unknown")
const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid")
const ErrLinkMultipathInvalid = linkError("multipath invalid")
func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
var retErr error
@ -193,6 +195,17 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
}
options.maxBackoff = d
}
if p := u.Query().Get("multipath"); p != "" {
switch p {
case "true", "1":
options.multipath = true
case "false", "0":
options.multipath = false
default:
retErr = ErrLinkMultipathInvalid
return
}
}
// SNI headers must contain hostnames and not IP addresses, so we must make sure
// that we do not populate the SNI with an IP literal. We do this by splitting
// the host-port combo from the query option and then seeing if it parses to an
@ -379,7 +392,7 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
return retErr
}
func (l *links) remove(u *url.URL, sintf string, linkType linkType) error {
func (l *links) remove(u *url.URL, sintf string, _ linkType) error {
var retErr error
phony.Block(l, func() {
// Generate the link info and see whether we think we already
@ -422,7 +435,36 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) {
cancel()
return nil, ErrLinkUnrecognisedSchema
}
listener, err := protocol.listen(ctx, u, sintf)
var options linkOptions
if p := u.Query().Get("priority"); p != "" {
pi, err := strconv.ParseUint(p, 10, 8)
if err != nil {
cancel()
return nil, ErrLinkPriorityInvalid
}
options.priority = uint8(pi)
}
if p := u.Query().Get("password"); p != "" {
if len(p) > blake2b.Size {
cancel()
return nil, ErrLinkPasswordInvalid
}
options.password = []byte(p)
}
if p := u.Query().Get("multipath"); p != "" {
switch p {
case "true", "1":
options.multipath = true
case "false", "0":
options.multipath = false
default:
cancel()
return nil, ErrLinkMultipathInvalid
}
}
listener, err := protocol.listen(ctx, u, sintf, options)
if err != nil {
cancel()
return nil, err
@ -433,21 +475,6 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) {
Cancel: cancel,
}
var options linkOptions
if p := u.Query().Get("priority"); p != "" {
pi, err := strconv.ParseUint(p, 10, 8)
if err != nil {
return nil, ErrLinkPriorityInvalid
}
options.priority = uint8(pi)
}
if p := u.Query().Get("password"); p != "" {
if len(p) > blake2b.Size {
return nil, ErrLinkPasswordInvalid
}
options.password = []byte(p)
}
go func() {
l.core.log.Infof("%s listener started on %s", strings.ToUpper(u.Scheme), listener.Addr())
defer l.core.log.Infof("%s listener stopped on %s", strings.ToUpper(u.Scheme), listener.Addr())
@ -567,7 +594,7 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn, s
switch {
case err != nil:
return fmt.Errorf("write handshake: %w", err)
case err == nil && n != len(metaBytes):
case n != len(metaBytes):
return fmt.Errorf("incomplete handshake send")
}
meta = version_metadata{}

View File

@ -65,7 +65,7 @@ func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, option
}, nil
}
func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
ql, err := quic.ListenAddr(url.Host, l.tlsconfig, l.quicconfig)
if err != nil {
return nil, err

View File

@ -47,6 +47,6 @@ func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options
return conn, nil
}
func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
return nil, fmt.Errorf("SOCKS listener not supported")
}

View File

@ -36,7 +36,7 @@ type tcpDialer struct {
addr *net.TCPAddr
}
func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) {
func (l *linkTCP) dialersFor(url *url.URL, info linkInfo, options linkOptions) ([]*tcpDialer, error) {
host, p, err := net.SplitHostPort(url.Host)
if err != nil {
return nil, err
@ -55,7 +55,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
IP: ip,
Port: port,
}
dialer, err := l.dialerFor(addr, info.sintf)
dialer, err := l.dialerFor(addr, info.sintf, options.multipath)
if err != nil {
continue
}
@ -69,7 +69,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
}
func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
dialers, err := l.dialersFor(url, info)
dialers, err := l.dialersFor(url, info, options)
if err != nil {
return nil, err
}
@ -88,17 +88,21 @@ func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options
return nil, err
}
func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {
func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
hostport := url.Host
if sintf != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
}
}
return l.listenconfig.Listen(ctx, "tcp", hostport)
lc := *l.listenconfig
if options.multipath {
setMPTCPForListener(&lc)
}
return lc.Listen(ctx, "tcp", hostport)
}
func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) {
func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string, mptcp bool) (*net.Dialer, error) {
if dst.IP.IsLinkLocalUnicast() {
if sintf != "" {
dst.Zone = sintf
@ -112,6 +116,9 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error)
KeepAlive: -1,
Control: l.tcpContext,
}
if mptcp {
setMPTCPForDialer(dialer)
}
if sintf != "" {
dialer.Control = l.getControl(sintf)
ief, err := net.InterfaceByName(sintf)

View File

@ -0,0 +1,30 @@
package core
import (
"crypto/tls"
"net"
)
func setMPTCPForDialer(d *net.Dialer) {
d.SetMultipathTCP(true)
}
func setMPTCPForListener(lc *net.ListenConfig) {
lc.SetMultipathTCP(true)
}
func isMPTCP(c net.Conn) bool {
switch tc := c.(type) {
case *net.TCPConn:
mp, _ := tc.MultipathTCP()
return mp
case *tls.Conn:
if tc, ok := tc.NetConn().(*net.TCPConn); ok {
mp, _ := tc.MultipathTCP()
return mp
}
return false
default:
return false
}
}

View File

@ -3,7 +3,6 @@ package core
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
@ -34,7 +33,7 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
}
func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
dialers, err := l.tcp.dialersFor(url, info)
dialers, err := l.tcp.dialersFor(url, info, options)
if err != nil {
return nil, err
}
@ -58,17 +57,10 @@ func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options
return nil, err
}
func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {
hostport := url.Host
if sintf != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
}
}
listener, err := l.listener.Listen(ctx, "tcp", hostport)
func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
listener, err := l.tcp.listen(ctx, url, sintf, options)
if err != nil {
return nil, err
}
tlslistener := tls.NewListener(listener, l.config)
return tlslistener, nil
return tls.NewListener(listener, l.config), nil
}

View File

@ -40,6 +40,6 @@ func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, option
return l.dialer.DialContext(ctx, "unix", addr.String())
}
func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
return l.listener.Listen(ctx, "unix", url.Path)
}