Multicast module state tweaks

This commit is contained in:
Neil Alexander 2024-08-11 10:41:58 +01:00
parent af9ff34995
commit ef989bef63
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
2 changed files with 16 additions and 12 deletions

View File

@ -9,6 +9,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/url" "net/url"
"sync/atomic"
"time" "time"
"github.com/Arceliar/phony" "github.com/Arceliar/phony"
@ -28,7 +29,7 @@ type Multicast struct {
core *core.Core core *core.Core
log *log.Logger log *log.Logger
sock *ipv6.PacketConn sock *ipv6.PacketConn
_isOpen bool running atomic.Bool
_listeners map[string]*listenerInfo _listeners map[string]*listenerInfo
_interfaces map[string]*interfaceInfo _interfaces map[string]*interfaceInfo
_timer *time.Timer _timer *time.Timer
@ -79,7 +80,7 @@ func New(core *core.Core, log *log.Logger, opts ...SetupOption) (*Multicast, err
} }
func (m *Multicast) _start() error { func (m *Multicast) _start() error {
if m._isOpen { if !m.running.CompareAndSwap(false, true) {
return fmt.Errorf("multicast module is already started") return fmt.Errorf("multicast module is already started")
} }
var anyEnabled bool var anyEnabled bool
@ -87,12 +88,14 @@ func (m *Multicast) _start() error {
anyEnabled = anyEnabled || intf.Beacon || intf.Listen anyEnabled = anyEnabled || intf.Beacon || intf.Listen
} }
if !anyEnabled { if !anyEnabled {
m.running.Store(false)
return nil return nil
} }
m.log.Debugln("Starting multicast module") m.log.Debugln("Starting multicast module")
defer m.log.Debugln("Started multicast module") defer m.log.Debugln("Started multicast module")
addr, err := net.ResolveUDPAddr("udp", string(m.config._groupAddr)) addr, err := net.ResolveUDPAddr("udp", string(m.config._groupAddr))
if err != nil { if err != nil {
m.running.Store(false)
return err return err
} }
listenString := fmt.Sprintf("[::]:%v", addr.Port) listenString := fmt.Sprintf("[::]:%v", addr.Port)
@ -101,6 +104,7 @@ func (m *Multicast) _start() error {
} }
conn, err := lc.ListenPacket(context.Background(), "udp6", listenString) conn, err := lc.ListenPacket(context.Background(), "udp6", listenString)
if err != nil { if err != nil {
m.running.Store(false)
return err return err
} }
m.sock = ipv6.NewPacketConn(conn) m.sock = ipv6.NewPacketConn(conn)
@ -108,7 +112,6 @@ func (m *Multicast) _start() error {
// Windows can't set this flag, so we need to handle it in other ways // Windows can't set this flag, so we need to handle it in other ways
} }
m._isOpen = true
go m.listen() go m.listen()
m.Act(nil, m._multicastStarted) m.Act(nil, m._multicastStarted)
m.Act(nil, m._announce) m.Act(nil, m._announce)
@ -118,11 +121,7 @@ func (m *Multicast) _start() error {
// IsStarted returns true if the module has been started. // IsStarted returns true if the module has been started.
func (m *Multicast) IsStarted() bool { func (m *Multicast) IsStarted() bool {
var isOpen bool return m.running.Load()
phony.Block(m, func() {
isOpen = m._isOpen
})
return isOpen
} }
// Stop stops the multicast module. // Stop stops the multicast module.
@ -136,8 +135,10 @@ func (m *Multicast) Stop() error {
} }
func (m *Multicast) _stop() error { func (m *Multicast) _stop() error {
if !m.running.CompareAndSwap(true, false) {
return nil
}
m.log.Infoln("Stopping multicast module") m.log.Infoln("Stopping multicast module")
m._isOpen = false
if m.sock != nil { if m.sock != nil {
m.sock.Close() m.sock.Close()
} }
@ -233,7 +234,7 @@ func (m *Multicast) AnnounceNow() {
} }
func (m *Multicast) _announce() { func (m *Multicast) _announce() {
if !m._isOpen { if !m.running.Load() {
return return
} }
m._updateInterfaces() m._updateInterfaces()
@ -250,7 +251,7 @@ func (m *Multicast) _announce() {
for name, info := range m._listeners { for name, info := range m._listeners {
// Prepare our stop function! // Prepare our stop function!
stop := func() { stop := func() {
info.listener.Close() info.listener.Cancel()
delete(m._listeners, name) delete(m._listeners, name)
m.log.Debugln("No longer multicasting on", name) m.log.Debugln("No longer multicasting on", name)
} }
@ -376,6 +377,9 @@ func (m *Multicast) listen() {
bs := make([]byte, 2048) bs := make([]byte, 2048)
hb := make([]byte, 0, blake2b.Size) // Reused to reduce hash allocations hb := make([]byte, 0, blake2b.Size) // Reused to reduce hash allocations
for { for {
if !m.running.Load() {
return
}
n, rcm, fromAddr, err := m.sock.ReadFrom(bs) n, rcm, fromAddr, err := m.sock.ReadFrom(bs)
if err != nil { if err != nil {
if !m.IsStarted() { if !m.IsStarted() {

View File

@ -31,7 +31,7 @@ import (
) )
func (m *Multicast) _multicastStarted() { func (m *Multicast) _multicastStarted() {
if !m._isOpen { if !m.running.Load() {
return return
} }
C.StopAWDLBrowsing() C.StopAWDLBrowsing()