diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index 7ac45260..8de2c868 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -3,6 +3,7 @@ package main import ( "context" "crypto/ed25519" + "crypto/sha1" "encoding/hex" "encoding/json" "flag" @@ -195,6 +196,10 @@ func main() { n := &node{} + // Track certificate fingerprints for configured roots, so + // that we can match them using the multicast discriminator. + fingerprints := map[[20]byte]struct{}{} + // Setup the Yggdrasil node itself. { options := []core.SetupOption{ @@ -214,6 +219,7 @@ func main() { } for _, root := range cfg.RootCertificates { options = append(options, core.RootCertificate(*root)) + fingerprints[sha1.Sum(root.Raw[:])] = struct{}{} } for _, allowed := range cfg.AllowedPublicKeys { k, err := hex.DecodeString(allowed) @@ -252,6 +258,29 @@ func main() { Priority: uint8(intf.Priority), }) } + if len(fingerprints) > 0 { + var matcher multicast.DiscriminatorMatch = func(b []byte) bool { + // Break apart the discriminator into 20-byte chunks and + // see whether any of them match the configured root CA + // fingerprints. If any of them match, we'll return true. + var f [20]byte + for len(b) >= len(f) { + b = b[copy(f[:], b):] + if _, ok := fingerprints[f]; ok { + return true + } + } + return false + } + // Populate our own discriminator with the fingerprints of our + // configured root CAs. + var discriminator multicast.Discriminator + for f := range fingerprints { + discriminator = append(discriminator, f[:]...) + } + options = append(options, matcher) + options = append(options, discriminator) + } if n.multicast, err = multicast.New(n.core, logger, options...); err != nil { panic(err) } diff --git a/src/multicast/advertisement.go b/src/multicast/advertisement.go index 4b65b609..69c29b6a 100644 --- a/src/multicast/advertisement.go +++ b/src/multicast/advertisement.go @@ -7,22 +7,33 @@ import ( ) type multicastAdvertisement struct { - PublicKey ed25519.PublicKey - Port uint16 + MajorVersion uint16 + MinorVersion uint16 + PublicKey ed25519.PublicKey + Port uint16 + Discriminator []byte } func (m *multicastAdvertisement) MarshalBinary() ([]byte, error) { - b := make([]byte, 0, ed25519.PublicKeySize+2) + b := make([]byte, 0, ed25519.PublicKeySize+8+len(m.Discriminator)) + b = binary.BigEndian.AppendUint16(b, m.MajorVersion) + b = binary.BigEndian.AppendUint16(b, m.MinorVersion) b = append(b, m.PublicKey...) b = binary.BigEndian.AppendUint16(b, m.Port) + b = binary.BigEndian.AppendUint16(b, uint16(len(m.Discriminator))) + b = append(b, m.Discriminator...) return b, nil } func (m *multicastAdvertisement) UnmarshalBinary(b []byte) error { - if len(b) < ed25519.PublicKeySize+2 { + if len(b) < ed25519.PublicKeySize+8 { return fmt.Errorf("invalid multicast beacon") } - m.PublicKey = b[:ed25519.PublicKeySize] - m.Port = binary.BigEndian.Uint16(b[ed25519.PublicKeySize:]) + m.MajorVersion = binary.BigEndian.Uint16(b[0:2]) + m.MinorVersion = binary.BigEndian.Uint16(b[2:4]) + m.PublicKey = append(m.PublicKey[:0], b[4:4+ed25519.PublicKeySize]...) + m.Port = binary.BigEndian.Uint16(b[4+ed25519.PublicKeySize : 6+ed25519.PublicKeySize]) + dl := binary.BigEndian.Uint16(b[6+ed25519.PublicKeySize : 8+ed25519.PublicKeySize]) + m.Discriminator = append(m.Discriminator[:0], b[8+ed25519.PublicKeySize:8+ed25519.PublicKeySize+dl]...) return nil } diff --git a/src/multicast/advertisement_test.go b/src/multicast/advertisement_test.go new file mode 100644 index 00000000..7132322d --- /dev/null +++ b/src/multicast/advertisement_test.go @@ -0,0 +1,38 @@ +package multicast + +import ( + "crypto/ed25519" + "reflect" + "testing" +) + +func TestMulticastAdvertisementRoundTrip(t *testing.T) { + pk, sk, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + + orig := multicastAdvertisement{ + MajorVersion: 1, + MinorVersion: 2, + PublicKey: pk, + Port: 3, + Discriminator: sk, // any bytes will do + } + + ob, err := orig.MarshalBinary() + if err != nil { + t.Fatal(err) + } + + var new multicastAdvertisement + if err := new.UnmarshalBinary(ob); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(orig, new) { + t.Logf("original: %+v", orig) + t.Logf("new: %+v", new) + t.Fatalf("differences found after round-trip") + } +} diff --git a/src/multicast/multicast.go b/src/multicast/multicast.go index 9cd67ff1..f58af934 100644 --- a/src/multicast/multicast.go +++ b/src/multicast/multicast.go @@ -1,6 +1,7 @@ package multicast import ( + "bytes" "context" "encoding/hex" "fmt" @@ -30,8 +31,10 @@ type Multicast struct { _interfaces map[string]*interfaceInfo _timer *time.Timer config struct { - _groupAddr GroupAddress - _interfaces map[MulticastInterface]struct{} + _discriminator []byte + _discriminatorMatch func([]byte) bool + _groupAddr GroupAddress + _interfaces map[MulticastInterface]struct{} } } @@ -321,8 +324,11 @@ func (m *Multicast) _announce() { } addr := linfo.listener.Addr().(*net.TCPAddr) adv := multicastAdvertisement{ - PublicKey: m.core.PublicKey(), - Port: uint16(addr.Port), + MajorVersion: core.ProtocolVersionMajor, + MinorVersion: core.ProtocolVersionMinor, + PublicKey: m.core.PublicKey(), + Port: uint16(addr.Port), + Discriminator: m.config._discriminator, } msg, err := adv.MarshalBinary() if err != nil { @@ -373,7 +379,16 @@ func (m *Multicast) listen() { if err := adv.UnmarshalBinary(bs[:n]); err != nil { continue } - if adv.PublicKey.Equal(m.core.PublicKey()) { + switch { + case adv.MajorVersion != core.ProtocolVersionMajor: + continue + case adv.MinorVersion != core.ProtocolVersionMinor: + continue + case adv.PublicKey.Equal(m.core.PublicKey()): + continue + case m.config._discriminatorMatch == nil && !bytes.Equal(adv.Discriminator, m.config._discriminator): + continue + case m.config._discriminatorMatch != nil && !m.config._discriminatorMatch(adv.Discriminator): continue } from := fromAddr.(*net.UDPAddr) diff --git a/src/multicast/options.go b/src/multicast/options.go index f36284ed..aa740605 100644 --- a/src/multicast/options.go +++ b/src/multicast/options.go @@ -8,6 +8,10 @@ func (m *Multicast) _applyOption(opt SetupOption) { m.config._interfaces[v] = struct{}{} case GroupAddress: m.config._groupAddr = v + case Discriminator: + m.config._discriminator = append(m.config._discriminator[:0], v...) + case DiscriminatorMatch: + m.config._discriminatorMatch = v } } @@ -24,6 +28,10 @@ type MulticastInterface struct { } type GroupAddress string +type Discriminator []byte +type DiscriminatorMatch func([]byte) bool func (a MulticastInterface) isSetupOption() {} func (a GroupAddress) isSetupOption() {} +func (a Discriminator) isSetupOption() {} +func (a DiscriminatorMatch) isSetupOption() {}