Add support for Multipath TCP

This commit is contained in:
Neil Alexander 2023-08-12 17:50:24 +01:00
parent fec96a38a4
commit ed8991571a
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
9 changed files with 108 additions and 48 deletions

View File

@ -24,6 +24,7 @@ type PeerEntry struct {
PublicKey string `json:"key"` PublicKey string `json:"key"`
Port uint64 `json:"port"` Port uint64 `json:"port"`
Priority uint64 `json:"priority"` Priority uint64 `json:"priority"`
Multipath bool `json:"multipath,omitempty"`
RXBytes DataUnit `json:"bytes_recvd,omitempty"` RXBytes DataUnit `json:"bytes_recvd,omitempty"`
TXBytes DataUnit `json:"bytes_sent,omitempty"` TXBytes DataUnit `json:"bytes_sent,omitempty"`
Uptime float64 `json:"uptime,omitempty"` Uptime float64 `json:"uptime,omitempty"`
@ -41,6 +42,7 @@ func (a *AdminSocket) getPeersHandler(req *GetPeersRequest, res *GetPeersRespons
Up: p.Up, Up: p.Up,
Inbound: p.Inbound, Inbound: p.Inbound,
Priority: uint64(p.Priority), // can't be uint8 thanks to gobind Priority: uint64(p.Priority), // can't be uint8 thanks to gobind
Multipath: p.Multipath,
URI: p.URI, URI: p.URI,
RXBytes: DataUnit(p.RXBytes), RXBytes: DataUnit(p.RXBytes),
TXBytes: DataUnit(p.TXBytes), TXBytes: DataUnit(p.TXBytes),

View File

@ -30,6 +30,7 @@ type PeerInfo struct {
Coords []uint64 Coords []uint64
Port uint64 Port uint64
Priority uint8 Priority uint8
Multipath bool
RXBytes uint64 RXBytes uint64
TXBytes uint64 TXBytes uint64
Uptime time.Duration Uptime time.Duration
@ -87,6 +88,7 @@ func (c *Core) GetPeers() []PeerInfo {
peerinfo.RXBytes = atomic.LoadUint64(&c.rx) peerinfo.RXBytes = atomic.LoadUint64(&c.rx)
peerinfo.TXBytes = atomic.LoadUint64(&c.tx) peerinfo.TXBytes = atomic.LoadUint64(&c.tx)
peerinfo.Uptime = time.Since(c.up) peerinfo.Uptime = time.Since(c.up)
peerinfo.Multipath = isMPTCP(c)
} }
if p, ok := conns[conn]; ok { if p, ok := conns[conn]; ok {
peerinfo.Key = p.Key peerinfo.Key = p.Key

View File

@ -44,7 +44,7 @@ type links struct {
type linkProtocol interface { type linkProtocol interface {
dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error)
listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error)
} }
// linkInfo is used as a map key // linkInfo is used as a map key
@ -72,6 +72,7 @@ type linkOptions struct {
tlsSNI string tlsSNI string
password []byte password []byte
maxBackoff time.Duration maxBackoff time.Duration
multipath bool
} }
type Listener struct { type Listener struct {
@ -140,6 +141,7 @@ const ErrLinkPinnedKeyInvalid = linkError("pinned public key is invalid")
const ErrLinkPasswordInvalid = linkError("password is invalid") const ErrLinkPasswordInvalid = linkError("password is invalid")
const ErrLinkUnrecognisedSchema = linkError("link schema unknown") const ErrLinkUnrecognisedSchema = linkError("link schema unknown")
const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid") const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid")
const ErrLinkMultipathInvalid = linkError("multipath invalid")
func (l *links) add(u *url.URL, sintf string, linkType linkType) error { func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
var retErr error var retErr error
@ -193,6 +195,17 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
} }
options.maxBackoff = d options.maxBackoff = d
} }
if p := u.Query().Get("multipath"); p != "" {
switch p {
case "true", "1":
options.multipath = true
case "false", "0":
options.multipath = false
default:
retErr = ErrLinkMultipathInvalid
return
}
}
// SNI headers must contain hostnames and not IP addresses, so we must make sure // SNI headers must contain hostnames and not IP addresses, so we must make sure
// that we do not populate the SNI with an IP literal. We do this by splitting // that we do not populate the SNI with an IP literal. We do this by splitting
// the host-port combo from the query option and then seeing if it parses to an // the host-port combo from the query option and then seeing if it parses to an
@ -379,7 +392,7 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
return retErr return retErr
} }
func (l *links) remove(u *url.URL, sintf string, linkType linkType) error { func (l *links) remove(u *url.URL, sintf string, _ linkType) error {
var retErr error var retErr error
phony.Block(l, func() { phony.Block(l, func() {
// Generate the link info and see whether we think we already // Generate the link info and see whether we think we already
@ -422,7 +435,36 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) {
cancel() cancel()
return nil, ErrLinkUnrecognisedSchema return nil, ErrLinkUnrecognisedSchema
} }
listener, err := protocol.listen(ctx, u, sintf)
var options linkOptions
if p := u.Query().Get("priority"); p != "" {
pi, err := strconv.ParseUint(p, 10, 8)
if err != nil {
cancel()
return nil, ErrLinkPriorityInvalid
}
options.priority = uint8(pi)
}
if p := u.Query().Get("password"); p != "" {
if len(p) > blake2b.Size {
cancel()
return nil, ErrLinkPasswordInvalid
}
options.password = []byte(p)
}
if p := u.Query().Get("multipath"); p != "" {
switch p {
case "true", "1":
options.multipath = true
case "false", "0":
options.multipath = false
default:
cancel()
return nil, ErrLinkMultipathInvalid
}
}
listener, err := protocol.listen(ctx, u, sintf, options)
if err != nil { if err != nil {
cancel() cancel()
return nil, err return nil, err
@ -433,21 +475,6 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) {
Cancel: cancel, Cancel: cancel,
} }
var options linkOptions
if p := u.Query().Get("priority"); p != "" {
pi, err := strconv.ParseUint(p, 10, 8)
if err != nil {
return nil, ErrLinkPriorityInvalid
}
options.priority = uint8(pi)
}
if p := u.Query().Get("password"); p != "" {
if len(p) > blake2b.Size {
return nil, ErrLinkPasswordInvalid
}
options.password = []byte(p)
}
go func() { go func() {
l.core.log.Infof("%s listener started on %s", strings.ToUpper(u.Scheme), listener.Addr()) l.core.log.Infof("%s listener started on %s", strings.ToUpper(u.Scheme), listener.Addr())
defer l.core.log.Infof("%s listener stopped on %s", strings.ToUpper(u.Scheme), listener.Addr()) defer l.core.log.Infof("%s listener stopped on %s", strings.ToUpper(u.Scheme), listener.Addr())
@ -567,7 +594,7 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn, s
switch { switch {
case err != nil: case err != nil:
return fmt.Errorf("write handshake: %w", err) return fmt.Errorf("write handshake: %w", err)
case err == nil && n != len(metaBytes): case n != len(metaBytes):
return fmt.Errorf("incomplete handshake send") return fmt.Errorf("incomplete handshake send")
} }
meta = version_metadata{} meta = version_metadata{}

View File

@ -65,7 +65,7 @@ func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, option
}, nil }, nil
} }
func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
ql, err := quic.ListenAddr(url.Host, l.tlsconfig, l.quicconfig) ql, err := quic.ListenAddr(url.Host, l.tlsconfig, l.quicconfig)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -47,6 +47,6 @@ func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options
return conn, nil return conn, nil
} }
func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
return nil, fmt.Errorf("SOCKS listener not supported") return nil, fmt.Errorf("SOCKS listener not supported")
} }

View File

@ -36,7 +36,7 @@ type tcpDialer struct {
addr *net.TCPAddr addr *net.TCPAddr
} }
func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) { func (l *linkTCP) dialersFor(url *url.URL, info linkInfo, options linkOptions) ([]*tcpDialer, error) {
host, p, err := net.SplitHostPort(url.Host) host, p, err := net.SplitHostPort(url.Host)
if err != nil { if err != nil {
return nil, err return nil, err
@ -55,7 +55,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
IP: ip, IP: ip,
Port: port, Port: port,
} }
dialer, err := l.dialerFor(addr, info.sintf) dialer, err := l.dialerFor(addr, info.sintf, options.multipath)
if err != nil { if err != nil {
continue continue
} }
@ -69,7 +69,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error)
} }
func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
dialers, err := l.dialersFor(url, info) dialers, err := l.dialersFor(url, info, options)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,17 +88,21 @@ func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options
return nil, err return nil, err
} }
func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
hostport := url.Host hostport := url.Host
if sintf != "" { if sintf != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil { if host, port, err := net.SplitHostPort(hostport); err == nil {
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port) hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
} }
} }
return l.listenconfig.Listen(ctx, "tcp", hostport) lc := *l.listenconfig
if options.multipath {
setMPTCPForListener(&lc)
}
return lc.Listen(ctx, "tcp", hostport)
} }
func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) { func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string, mptcp bool) (*net.Dialer, error) {
if dst.IP.IsLinkLocalUnicast() { if dst.IP.IsLinkLocalUnicast() {
if sintf != "" { if sintf != "" {
dst.Zone = sintf dst.Zone = sintf
@ -112,6 +116,9 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error)
KeepAlive: -1, KeepAlive: -1,
Control: l.tcpContext, Control: l.tcpContext,
} }
if mptcp {
setMPTCPForDialer(dialer)
}
if sintf != "" { if sintf != "" {
dialer.Control = l.getControl(sintf) dialer.Control = l.getControl(sintf)
ief, err := net.InterfaceByName(sintf) ief, err := net.InterfaceByName(sintf)

View File

@ -0,0 +1,30 @@
package core
import (
"crypto/tls"
"net"
)
func setMPTCPForDialer(d *net.Dialer) {
d.SetMultipathTCP(true)
}
func setMPTCPForListener(lc *net.ListenConfig) {
lc.SetMultipathTCP(true)
}
func isMPTCP(c net.Conn) bool {
switch tc := c.(type) {
case *net.TCPConn:
mp, _ := tc.MultipathTCP()
return mp
case *tls.Conn:
if tc, ok := tc.NetConn().(*net.TCPConn); ok {
mp, _ := tc.MultipathTCP()
return mp
}
return false
default:
return false
}
}

View File

@ -3,7 +3,6 @@ package core
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"net/url" "net/url"
@ -34,7 +33,7 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
} }
func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
dialers, err := l.tcp.dialersFor(url, info) dialers, err := l.tcp.dialersFor(url, info, options)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -58,17 +57,10 @@ func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options
return nil, err return nil, err
} }
func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) {
hostport := url.Host listener, err := l.tcp.listen(ctx, url, sintf, options)
if sintf != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port)
}
}
listener, err := l.listener.Listen(ctx, "tcp", hostport)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlslistener := tls.NewListener(listener, l.config) return tls.NewListener(listener, l.config), nil
return tlslistener, nil
} }

View File

@ -40,6 +40,6 @@ func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, option
return l.dialer.DialContext(ctx, "unix", addr.String()) return l.dialer.DialContext(ctx, "unix", addr.String())
} }
func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) {
return l.listener.Listen(ctx, "unix", url.Path) return l.listener.Listen(ctx, "unix", url.Path)
} }