Deduplicate some logic

This commit is contained in:
Neil Alexander 2022-11-26 16:18:15 +00:00
parent 1adc88ec77
commit 723097fbf6
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
2 changed files with 44 additions and 34 deletions

View File

@ -31,19 +31,26 @@ func (l *links) newLinkTCP() *linkTCP {
return lt return lt
} }
func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { type tcpDialer struct {
info linkInfo
dialer *net.Dialer
addr *net.TCPAddr
}
func (l *linkTCP) dialersFor(url *url.URL, options linkOptions, sintf string) ([]*tcpDialer, error) {
host, p, err := net.SplitHostPort(url.Host) host, p, err := net.SplitHostPort(url.Host)
if err != nil { if err != nil {
return err return nil, err
} }
port, err := strconv.Atoi(p) port, err := strconv.Atoi(p)
if err != nil { if err != nil {
return err return nil, err
} }
ips, err := net.LookupIP(host) ips, err := net.LookupIP(host)
if err != nil { if err != nil {
return err return nil, err
} }
dialers := make([]*tcpDialer, 0, len(ips))
for _, ip := range ips { for _, ip := range ips {
addr := &net.TCPAddr{ addr := &net.TCPAddr{
IP: ip, IP: ip,
@ -55,10 +62,30 @@ func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error {
} }
info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr)) info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr))
if l.links.isConnectedTo(info) { if l.links.isConnectedTo(info) {
return nil, nil
}
dialers = append(dialers, &tcpDialer{
info: info,
dialer: dialer,
addr: addr,
})
}
return dialers, nil
}
func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error {
dialers, err := l.dialersFor(url, options, sintf)
if err != nil {
return err
}
if len(dialers) == 0 {
return nil return nil
} }
conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) for _, d := range dialers {
var conn net.Conn
conn, err = d.dialer.DialContext(l.core.ctx, "tcp", d.addr.String())
if err != nil { if err != nil {
l.core.log.Warnf("Failed to connect to %s: %s", d.addr, err)
continue continue
} }
name := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") name := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/")
@ -66,9 +93,9 @@ func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error {
url: url, url: url,
sintf: sintf, sintf: sintf,
} }
return l.handler(dial, name, info, conn, options, false, false) return l.handler(dial, name, d.info, conn, options, false, false)
} }
return fmt.Errorf("failed to connect via %d addresses", len(ips)) return fmt.Errorf("failed to connect via %d address(es), last error: %w", len(dialers), err)
} }
func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {

View File

@ -13,7 +13,6 @@ import (
"math/big" "math/big"
"net" "net"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
@ -48,38 +47,22 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
} }
func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) error { func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) error {
host, p, err := net.SplitHostPort(url.Host) dialers, err := l.tcp.dialersFor(url, options, sintf)
if err != nil { if err != nil {
return err return err
} }
port, err := strconv.Atoi(p) if len(dialers) == 0 {
if err != nil {
return err
}
ips, err := net.LookupIP(host)
if err != nil {
return err
}
for _, ip := range ips {
addr := &net.TCPAddr{
IP: ip,
Port: port,
}
dialer, err := l.tcp.dialerFor(addr, sintf)
if err != nil {
continue
}
info := linkInfoFor("tls", sintf, tcpIDFor(dialer.LocalAddr, addr))
if l.links.isConnectedTo(info) {
return nil return nil
} }
for _, d := range dialers {
tlsconfig := l.config.Clone() tlsconfig := l.config.Clone()
tlsconfig.ServerName = sni tlsconfig.ServerName = sni
tlsdialer := &tls.Dialer{ tlsdialer := &tls.Dialer{
NetDialer: dialer, NetDialer: d.dialer,
Config: tlsconfig, Config: tlsconfig,
} }
conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String()) var conn net.Conn
conn, err = tlsdialer.DialContext(l.core.ctx, "tcp", d.addr.String())
if err != nil { if err != nil {
continue continue
} }
@ -88,9 +71,9 @@ func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) err
url: url, url: url,
sintf: sintf, sintf: sintf,
} }
return l.handler(dial, name, info, conn, options, false, false) return l.handler(dial, name, d.info, conn, options, false, false)
} }
return fmt.Errorf("failed to connect via %d addresses", len(ips)) return fmt.Errorf("failed to connect via %d address(es), last error: %w", len(dialers), err)
} }
func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {