diff --git a/src/yggdrasil/link.go b/src/yggdrasil/link.go index c4f52a38..07adbe84 100644 --- a/src/yggdrasil/link.go +++ b/src/yggdrasil/link.go @@ -119,7 +119,8 @@ func (l *link) listen(uri string) error { } switch u.Scheme { case "tcp": - return l.tcp.listen(u.Host) + _, err := l.tcp.listen(u.Host) + return err default: return errors.New("unknown listen scheme: " + u.Scheme) } diff --git a/src/yggdrasil/multicast.go b/src/yggdrasil/multicast.go index 59f0eea5..401f6783 100644 --- a/src/yggdrasil/multicast.go +++ b/src/yggdrasil/multicast.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "regexp" - "sync" "time" "golang.org/x/net/ipv6" @@ -16,19 +15,16 @@ type multicast struct { reconfigure chan chan error sock *ipv6.PacketConn groupAddr string - myAddr *net.TCPAddr - myAddrMutex sync.RWMutex + listeners map[string]*tcpListener } func (m *multicast) init(core *Core) { m.core = core m.reconfigure = make(chan chan error, 1) + m.listeners = make(map[string]*tcpListener) go func() { for { e := <-m.reconfigure - m.myAddrMutex.Lock() - m.myAddr = m.core.link.tcp.getAddr() - m.myAddrMutex.Unlock() e <- nil } }() @@ -94,10 +90,12 @@ func (m *multicast) interfaces() []net.Interface { continue } for _, expr := range exprs { + // Compile each regular expression e, err := regexp.Compile(expr) if err != nil { panic(err) } + // Does the interface match the regular expression? Store it if so if e.MatchString(iface.Name) { interfaces = append(interfaces, iface) } @@ -107,10 +105,6 @@ func (m *multicast) interfaces() []net.Interface { } func (m *multicast) announce() { - var anAddr net.TCPAddr - m.myAddrMutex.Lock() - m.myAddr = m.core.link.tcp.getAddr() - m.myAddrMutex.Unlock() groupAddr, err := net.ResolveUDPAddr("udp6", m.groupAddr) if err != nil { panic(err) @@ -121,27 +115,47 @@ func (m *multicast) announce() { } for { for _, iface := range m.interfaces() { - m.sock.JoinGroup(&iface, groupAddr) + // Find interface addresses addrs, err := iface.Addrs() if err != nil { panic(err) } - m.myAddrMutex.RLock() - anAddr.Port = m.myAddr.Port - m.myAddrMutex.RUnlock() for _, addr := range addrs { addrIP, _, _ := net.ParseCIDR(addr.String()) + // Ignore IPv4 addresses if addrIP.To4() != nil { continue - } // IPv6 only + } + // Ignore non-link-local addresses if !addrIP.IsLinkLocalUnicast() { continue } - anAddr.IP = addrIP - anAddr.Zone = iface.Name - destAddr.Zone = iface.Name - msg := []byte(anAddr.String()) - m.sock.WriteTo(msg, nil, destAddr) + // Join the multicast group + m.sock.JoinGroup(&iface, groupAddr) + // Try and see if we already have a TCP listener for this interface + var listener *tcpListener + if _, ok := m.listeners[iface.Name]; !ok { + // No listener was found - let's create one + listenaddr := fmt.Sprintf("[%s%%%s]:0", addrIP, iface.Name) + if l, err := m.core.link.tcp.listen(listenaddr); err == nil { + // Store the listener so that we can stop it later if needed + listener = &tcpListener{ + listener: l, + stop: make(chan bool), + } + m.listeners[iface.Name] = listener + } + } else { + // An existing listener was found + listener = m.listeners[iface.Name] + } + // Get the listener details and construct the multicast beacon + lladdr := (*listener.listener).Addr().String() + if a, err := net.ResolveTCPAddr("tcp6", lladdr); err == nil { + destAddr.Zone = iface.Name + msg := []byte(a.String()) + m.sock.WriteTo(msg, nil, destAddr) + } break } time.Sleep(time.Second) diff --git a/src/yggdrasil/tcp.go b/src/yggdrasil/tcp.go index 80d9ccdb..652f5abe 100644 --- a/src/yggdrasil/tcp.go +++ b/src/yggdrasil/tcp.go @@ -33,13 +33,17 @@ const tcp_ping_interval = (default_timeout * 2 / 3) // The TCP listener and information about active TCP connections, to avoid duplication. type tcp struct { - link *link - reconfigure chan chan error - mutex sync.Mutex // Protecting the below - listeners map[string]net.Listener - listenerstops map[string]chan bool - calls map[string]struct{} - conns map[linkInfo](chan struct{}) + link *link + reconfigure chan chan error + mutex sync.Mutex // Protecting the below + listeners map[string]*tcpListener + calls map[string]struct{} + conns map[linkInfo](chan struct{}) +} + +type tcpListener struct { + listener *net.Listener + stop chan bool } // Wrapper function to set additional options for specific connection types. @@ -60,7 +64,7 @@ func (t *tcp) getAddr() *net.TCPAddr { t.mutex.Lock() defer t.mutex.Unlock() for _, listener := range t.listeners { - return listener.Addr().(*net.TCPAddr) + return (*listener.listener).Addr().(*net.TCPAddr) } return nil } @@ -72,8 +76,7 @@ func (t *tcp) init(l *link) error { t.mutex.Lock() t.calls = make(map[string]struct{}) t.conns = make(map[linkInfo](chan struct{})) - t.listeners = make(map[string]net.Listener) - t.listenerstops = make(map[string]chan bool) + t.listeners = make(map[string]*tcpListener) t.mutex.Unlock() go func() { @@ -89,7 +92,7 @@ func (t *tcp) init(l *link) error { e <- errors.New("unknown scheme: " + add) continue } - if err := t.listen(add[6:]); err != nil { + if _, err := t.listen(add[6:]); err != nil { e <- err continue } @@ -110,7 +113,7 @@ func (t *tcp) init(l *link) error { if listenaddr[:6] != "tcp://" { continue } - if err := t.listen(listenaddr[6:]); err != nil { + if _, err := t.listen(listenaddr[6:]); err != nil { return err } } @@ -118,7 +121,7 @@ func (t *tcp) init(l *link) error { return nil } -func (t *tcp) listen(listenaddr string) error { +func (t *tcp) listen(listenaddr string) (*net.Listener, error) { var err error ctx := context.Background() @@ -127,36 +130,36 @@ func (t *tcp) listen(listenaddr string) error { } listener, err := lc.Listen(ctx, "tcp", listenaddr) if err == nil { + l := tcpListener{ + listener: &listener, + stop: make(chan bool, 1), + } t.mutex.Lock() - t.listeners[listenaddr] = listener - t.listenerstops[listenaddr] = make(chan bool, 1) + t.listeners[listenaddr[6:]] = &l t.mutex.Unlock() - go t.listener(listenaddr) - return nil + go t.listener(&l) + return &listener, nil } - return err + return nil, err } // Runs the listener, which spawns off goroutines for incoming connections. -func (t *tcp) listener(listenaddr string) { - t.mutex.Lock() - listener, ok1 := t.listeners[listenaddr] - listenerstop, ok2 := t.listenerstops[listenaddr] - t.mutex.Unlock() - if !ok1 || !ok2 { - t.link.core.log.Errorln("Tried to start TCP listener for", listenaddr, "which doesn't exist") +func (t *tcp) listener(listener *tcpListener) { + if listener == nil { return } - reallistenaddr := listener.Addr().String() - defer listener.Close() + reallistener := *listener.listener + reallistenaddr := reallistener.Addr().String() + stop := listener.stop + defer reallistener.Close() t.link.core.log.Infoln("Listening for TCP on:", reallistenaddr) accepted := make(chan bool) for { var sock net.Conn var err error go func() { - sock, err = listener.Accept() + sock, err = reallistener.Accept() accepted <- true }() select { @@ -166,7 +169,7 @@ func (t *tcp) listener(listenaddr string) { return } go t.handler(sock, true) - case <-listenerstop: + case <-stop: t.link.core.log.Errorln("Stopping TCP listener on:", reallistenaddr) return }