diff --git a/src/admin/getpeers.go b/src/admin/getpeers.go index f79877ca..c50c4590 100644 --- a/src/admin/getpeers.go +++ b/src/admin/getpeers.go @@ -24,6 +24,7 @@ type PeerEntry struct { PublicKey string `json:"key"` Port uint64 `json:"port"` Priority uint64 `json:"priority"` + Multipath bool `json:"multipath,omitempty"` RXBytes DataUnit `json:"bytes_recvd,omitempty"` TXBytes DataUnit `json:"bytes_sent,omitempty"` Uptime float64 `json:"uptime,omitempty"` @@ -37,14 +38,15 @@ func (a *AdminSocket) getPeersHandler(req *GetPeersRequest, res *GetPeersRespons res.Peers = make([]PeerEntry, 0, len(peers)) for _, p := range peers { peer := PeerEntry{ - Port: p.Port, - Up: p.Up, - Inbound: p.Inbound, - Priority: uint64(p.Priority), // can't be uint8 thanks to gobind - URI: p.URI, - RXBytes: DataUnit(p.RXBytes), - TXBytes: DataUnit(p.TXBytes), - Uptime: p.Uptime.Seconds(), + Port: p.Port, + Up: p.Up, + Inbound: p.Inbound, + Priority: uint64(p.Priority), // can't be uint8 thanks to gobind + Multipath: p.Multipath, + URI: p.URI, + RXBytes: DataUnit(p.RXBytes), + TXBytes: DataUnit(p.TXBytes), + Uptime: p.Uptime.Seconds(), } if p.Latency > 0 { peer.Latency = p.Latency diff --git a/src/core/api.go b/src/core/api.go index 875d7bf2..ace3a961 100644 --- a/src/core/api.go +++ b/src/core/api.go @@ -30,6 +30,7 @@ type PeerInfo struct { Coords []uint64 Port uint64 Priority uint8 + Multipath bool RXBytes uint64 TXBytes uint64 Uptime time.Duration @@ -87,6 +88,7 @@ func (c *Core) GetPeers() []PeerInfo { peerinfo.RXBytes = atomic.LoadUint64(&c.rx) peerinfo.TXBytes = atomic.LoadUint64(&c.tx) peerinfo.Uptime = time.Since(c.up) + peerinfo.Multipath = isMPTCP(c) } if p, ok := conns[conn]; ok { peerinfo.Key = p.Key diff --git a/src/core/link.go b/src/core/link.go index 6bcbec59..0fda11a6 100644 --- a/src/core/link.go +++ b/src/core/link.go @@ -44,7 +44,7 @@ type links struct { type linkProtocol interface { dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) - listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) + listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) } // linkInfo is used as a map key @@ -72,6 +72,7 @@ type linkOptions struct { tlsSNI string password []byte maxBackoff time.Duration + multipath bool } type Listener struct { @@ -140,6 +141,7 @@ const ErrLinkPinnedKeyInvalid = linkError("pinned public key is invalid") const ErrLinkPasswordInvalid = linkError("password is invalid") const ErrLinkUnrecognisedSchema = linkError("link schema unknown") const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid") +const ErrLinkMultipathInvalid = linkError("multipath invalid") func (l *links) add(u *url.URL, sintf string, linkType linkType) error { var retErr error @@ -193,6 +195,17 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error { } options.maxBackoff = d } + if p := u.Query().Get("multipath"); p != "" { + switch p { + case "true", "1": + options.multipath = true + case "false", "0": + options.multipath = false + default: + retErr = ErrLinkMultipathInvalid + return + } + } // SNI headers must contain hostnames and not IP addresses, so we must make sure // that we do not populate the SNI with an IP literal. We do this by splitting // the host-port combo from the query option and then seeing if it parses to an @@ -379,7 +392,7 @@ func (l *links) add(u *url.URL, sintf string, linkType linkType) error { return retErr } -func (l *links) remove(u *url.URL, sintf string, linkType linkType) error { +func (l *links) remove(u *url.URL, sintf string, _ linkType) error { var retErr error phony.Block(l, func() { // Generate the link info and see whether we think we already @@ -422,7 +435,36 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) { cancel() return nil, ErrLinkUnrecognisedSchema } - listener, err := protocol.listen(ctx, u, sintf) + + var options linkOptions + if p := u.Query().Get("priority"); p != "" { + pi, err := strconv.ParseUint(p, 10, 8) + if err != nil { + cancel() + return nil, ErrLinkPriorityInvalid + } + options.priority = uint8(pi) + } + if p := u.Query().Get("password"); p != "" { + if len(p) > blake2b.Size { + cancel() + return nil, ErrLinkPasswordInvalid + } + options.password = []byte(p) + } + if p := u.Query().Get("multipath"); p != "" { + switch p { + case "true", "1": + options.multipath = true + case "false", "0": + options.multipath = false + default: + cancel() + return nil, ErrLinkMultipathInvalid + } + } + + listener, err := protocol.listen(ctx, u, sintf, options) if err != nil { cancel() return nil, err @@ -433,21 +475,6 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) { Cancel: cancel, } - var options linkOptions - if p := u.Query().Get("priority"); p != "" { - pi, err := strconv.ParseUint(p, 10, 8) - if err != nil { - return nil, ErrLinkPriorityInvalid - } - options.priority = uint8(pi) - } - if p := u.Query().Get("password"); p != "" { - if len(p) > blake2b.Size { - return nil, ErrLinkPasswordInvalid - } - options.password = []byte(p) - } - go func() { l.core.log.Infof("%s listener started on %s", strings.ToUpper(u.Scheme), listener.Addr()) defer l.core.log.Infof("%s listener stopped on %s", strings.ToUpper(u.Scheme), listener.Addr()) @@ -567,7 +594,7 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn, s switch { case err != nil: return fmt.Errorf("write handshake: %w", err) - case err == nil && n != len(metaBytes): + case n != len(metaBytes): return fmt.Errorf("incomplete handshake send") } meta = version_metadata{} diff --git a/src/core/link_quic.go b/src/core/link_quic.go index 9ad5456d..e99ed6f1 100644 --- a/src/core/link_quic.go +++ b/src/core/link_quic.go @@ -65,7 +65,7 @@ func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, option }, nil } -func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { +func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) { ql, err := quic.ListenAddr(url.Host, l.tlsconfig, l.quicconfig) if err != nil { return nil, err diff --git a/src/core/link_socks.go b/src/core/link_socks.go index b92374d4..709d1ae6 100644 --- a/src/core/link_socks.go +++ b/src/core/link_socks.go @@ -47,6 +47,6 @@ func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options return conn, nil } -func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { +func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) { return nil, fmt.Errorf("SOCKS listener not supported") } diff --git a/src/core/link_tcp.go b/src/core/link_tcp.go index f595aeb9..d75c2c9b 100644 --- a/src/core/link_tcp.go +++ b/src/core/link_tcp.go @@ -36,7 +36,7 @@ type tcpDialer struct { addr *net.TCPAddr } -func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) { +func (l *linkTCP) dialersFor(url *url.URL, info linkInfo, options linkOptions) ([]*tcpDialer, error) { host, p, err := net.SplitHostPort(url.Host) if err != nil { return nil, err @@ -55,7 +55,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) IP: ip, Port: port, } - dialer, err := l.dialerFor(addr, info.sintf) + dialer, err := l.dialerFor(addr, info.sintf, options.multipath) if err != nil { continue } @@ -69,7 +69,7 @@ func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) } func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - dialers, err := l.dialersFor(url, info) + dialers, err := l.dialersFor(url, info, options) if err != nil { return nil, err } @@ -88,17 +88,21 @@ func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options return nil, err } -func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { +func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) { hostport := url.Host if sintf != "" { if host, port, err := net.SplitHostPort(hostport); err == nil { hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port) } } - return l.listenconfig.Listen(ctx, "tcp", hostport) + lc := *l.listenconfig + if options.multipath { + setMPTCPForListener(&lc) + } + return lc.Listen(ctx, "tcp", hostport) } -func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) { +func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string, mptcp bool) (*net.Dialer, error) { if dst.IP.IsLinkLocalUnicast() { if sintf != "" { dst.Zone = sintf @@ -112,6 +116,9 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) KeepAlive: -1, Control: l.tcpContext, } + if mptcp { + setMPTCPForDialer(dialer) + } if sintf != "" { dialer.Control = l.getControl(sintf) ief, err := net.InterfaceByName(sintf) diff --git a/src/core/link_tcp_mptcp.go b/src/core/link_tcp_mptcp.go new file mode 100644 index 00000000..ef46db91 --- /dev/null +++ b/src/core/link_tcp_mptcp.go @@ -0,0 +1,30 @@ +package core + +import ( + "crypto/tls" + "net" +) + +func setMPTCPForDialer(d *net.Dialer) { + d.SetMultipathTCP(true) +} + +func setMPTCPForListener(lc *net.ListenConfig) { + lc.SetMultipathTCP(true) +} + +func isMPTCP(c net.Conn) bool { + switch tc := c.(type) { + case *net.TCPConn: + mp, _ := tc.MultipathTCP() + return mp + case *tls.Conn: + if tc, ok := tc.NetConn().(*net.TCPConn); ok { + mp, _ := tc.MultipathTCP() + return mp + } + return false + default: + return false + } +} diff --git a/src/core/link_tls.go b/src/core/link_tls.go index a93227f6..bc00bf82 100644 --- a/src/core/link_tls.go +++ b/src/core/link_tls.go @@ -3,7 +3,6 @@ package core import ( "context" "crypto/tls" - "fmt" "net" "net/url" @@ -34,7 +33,7 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS { } func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - dialers, err := l.tcp.dialersFor(url, info) + dialers, err := l.tcp.dialersFor(url, info, options) if err != nil { return nil, err } @@ -58,17 +57,10 @@ func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options return nil, err } -func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { - hostport := url.Host - if sintf != "" { - if host, port, err := net.SplitHostPort(hostport); err == nil { - hostport = fmt.Sprintf("[%s%%%s]:%s", host, sintf, port) - } - } - listener, err := l.listener.Listen(ctx, "tcp", hostport) +func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string, options linkOptions) (net.Listener, error) { + listener, err := l.tcp.listen(ctx, url, sintf, options) if err != nil { return nil, err } - tlslistener := tls.NewListener(listener, l.config) - return tlslistener, nil + return tls.NewListener(listener, l.config), nil } diff --git a/src/core/link_unix.go b/src/core/link_unix.go index 8dde8946..b2c5f59b 100644 --- a/src/core/link_unix.go +++ b/src/core/link_unix.go @@ -40,6 +40,6 @@ func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, option return l.dialer.DialContext(ctx, "unix", addr.String()) } -func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { +func (l *linkUNIX) listen(ctx context.Context, url *url.URL, _ string, _ linkOptions) (net.Listener, error) { return l.listener.Listen(ctx, "unix", url.Path) }