diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index fd65743c..e337b34f 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -73,18 +73,26 @@ func (c *Conn) Read(b []byte) (int, error) { return 0, errors.New("session was closed") } defer util.PutBytes(p.Payload) - if !c.session.nonceIsOK(&p.Nonce) { - return 0, errors.New("packet dropped due to invalid nonce") + err := func() error { + c.session.theirNonceMutex.Lock() + defer c.session.theirNonceMutex.Unlock() + if !c.session.nonceIsOK(&p.Nonce) { + return errors.New("packet dropped due to invalid nonce") + } + bs, isOK := crypto.BoxOpen(&c.session.sharedSesKey, p.Payload, &p.Nonce) + if !isOK { + util.PutBytes(bs) + return errors.New("packet dropped due to decryption failure") + } + b = b[:0] + b = append(b, bs...) + c.session.updateNonce(&p.Nonce) + c.session.time = time.Now() + return nil + }() + if err != nil { + return 0, err } - bs, isOK := crypto.BoxOpen(&c.session.sharedSesKey, p.Payload, &p.Nonce) - if !isOK { - util.PutBytes(bs) - return 0, errors.New("packet dropped due to decryption failure") - } - b = b[:0] - b = append(b, bs...) - c.session.updateNonce(&p.Nonce) - c.session.time = time.Now() atomic.AddUint64(&c.session.bytesRecvd, uint64(len(b))) return len(b), nil case <-c.session.closed: diff --git a/src/yggdrasil/session.go b/src/yggdrasil/session.go index e46761d5..c5319bc7 100644 --- a/src/yggdrasil/session.go +++ b/src/yggdrasil/session.go @@ -79,10 +79,8 @@ func (s *sessionInfo) update(p *sessionPing) bool { s.theirSesPub = p.SendSesPub s.theirHandle = p.Handle s.sharedSesKey = *crypto.GetSharedKey(&s.mySesPriv, &s.theirSesPub) - s.theirNonceMutex.Lock() s.theirNonce = crypto.BoxNonce{} s.theirNonceMask = 0 - s.theirNonceMutex.Unlock() } if p.MTU >= 1280 || p.MTU == 0 { s.theirMTU = p.MTU @@ -272,10 +270,6 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo { return nil } sinfo := sessionInfo{} - sinfo.myNonceMutex.Lock() - sinfo.theirNonceMutex.Lock() - defer sinfo.myNonceMutex.Unlock() - defer sinfo.theirNonceMutex.Unlock() sinfo.core = ss.core sinfo.reconfigure = make(chan chan error, 1) sinfo.theirPermPub = *theirPermKey @@ -395,9 +389,7 @@ func (ss *sessions) getPing(sinfo *sessionInfo) sessionPing { Coords: coords, MTU: sinfo.myMTU, } - sinfo.myNonceMutex.Lock() sinfo.myNonce.Increment() - sinfo.myNonceMutex.Unlock() return ref } @@ -501,8 +493,6 @@ func (sinfo *sessionInfo) getMTU() uint16 { // Checks if a packet's nonce is recent enough to fall within the window of allowed packets, and not already received. func (sinfo *sessionInfo) nonceIsOK(theirNonce *crypto.BoxNonce) bool { // The bitmask is to allow for some non-duplicate out-of-order packets - sinfo.theirNonceMutex.Lock() - defer sinfo.theirNonceMutex.Unlock() diff := theirNonce.Minus(&sinfo.theirNonce) if diff > 0 { return true @@ -512,8 +502,6 @@ func (sinfo *sessionInfo) nonceIsOK(theirNonce *crypto.BoxNonce) bool { // Updates the nonce mask by (possibly) shifting the bitmask and setting the bit corresponding to this nonce to 1, and then updating the most recent nonce func (sinfo *sessionInfo) updateNonce(theirNonce *crypto.BoxNonce) { - sinfo.theirNonceMutex.Lock() - defer sinfo.theirNonceMutex.Unlock() // Shift nonce mask if needed // Set bit diff := theirNonce.Minus(&sinfo.theirNonce)