Varying connection check strictness based on scope

This commit is contained in:
Neil Alexander 2022-10-22 17:45:09 +01:00
parent 8fe1c41295
commit 35ea66d651
2 changed files with 31 additions and 16 deletions

View File

@ -35,14 +35,14 @@ func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error {
if err != nil { if err != nil {
return err return err
} }
info := linkInfoFor("tcp", sintf, addr.String())
if l.links.isConnectedTo(info) {
return nil
}
dialer, err := l.dialerFor(addr, sintf) dialer, err := l.dialerFor(addr, sintf)
if err != nil { if err != nil {
return err return err
} }
info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr))
if l.links.isConnectedTo(info) {
return nil
}
conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String())
if err != nil { if err != nil {
return err return err
@ -82,10 +82,11 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
cancel() cancel()
break break
} }
addr := conn.RemoteAddr().(*net.TCPAddr) laddr := conn.LocalAddr().(*net.TCPAddr)
name := fmt.Sprintf("tcp://%s", addr) raddr := conn.RemoteAddr().(*net.TCPAddr)
info := linkInfoFor("tcp", sintf, addr.String()) name := fmt.Sprintf("tcp://%s", raddr)
if err = l.handler(name, info, conn, linkOptions{}, true, addr.IP.IsLinkLocalUnicast()); err != nil { info := linkInfoFor("tcp", sintf, tcpIDFor(laddr, raddr))
if err = l.handler(name, info, conn, linkOptions{}, true, raddr.IP.IsLinkLocalUnicast()); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err) l.core.log.Errorln("Failed to create inbound link:", err)
} }
} }
@ -179,3 +180,16 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error)
} }
return dialer, nil return dialer, nil
} }
func tcpIDFor(local net.Addr, remoteAddr *net.TCPAddr) string {
if localAddr, ok := local.(*net.TCPAddr); ok && localAddr.IP.Equal(remoteAddr.IP) {
// Nodes running on the same host — include both the IP and port.
return remoteAddr.String()
}
if remoteAddr.IP.IsLinkLocalUnicast() {
// Nodes discovered via multicast — include the IP only.
return remoteAddr.IP.String()
}
// Nodes connected remotely — include both the IP and port.
return remoteAddr.String()
}

View File

@ -51,14 +51,14 @@ func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) err
if err != nil { if err != nil {
return err return err
} }
info := linkInfoFor("tls", sintf, addr.String())
if l.links.isConnectedTo(info) {
return nil
}
dialer, err := l.tcp.dialerFor(addr, sintf) dialer, err := l.tcp.dialerFor(addr, sintf)
if err != nil { if err != nil {
return err return err
} }
info := linkInfoFor("tls", sintf, tcpIDFor(dialer.LocalAddr, addr))
if l.links.isConnectedTo(info) {
return nil
}
tlsconfig := l.config.Clone() tlsconfig := l.config.Clone()
tlsconfig.ServerName = sni tlsconfig.ServerName = sni
tlsdialer := &tls.Dialer{ tlsdialer := &tls.Dialer{
@ -105,10 +105,11 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {
cancel() cancel()
break break
} }
addr := conn.RemoteAddr().(*net.TCPAddr) laddr := conn.LocalAddr().(*net.TCPAddr)
name := fmt.Sprintf("tls://%s", addr) raddr := conn.RemoteAddr().(*net.TCPAddr)
info := linkInfoFor("tls", sintf, addr.String()) name := fmt.Sprintf("tls://%s", raddr)
if err = l.handler(name, info, conn, linkOptions{}, true, addr.IP.IsLinkLocalUnicast()); err != nil { info := linkInfoFor("tls", sintf, tcpIDFor(laddr, raddr))
if err = l.handler(name, info, conn, linkOptions{}, true, raddr.IP.IsLinkLocalUnicast()); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err) l.core.log.Errorln("Failed to create inbound link:", err)
} }
} }