Merge pull request #1046 from yggdrasil-network/neil/handshake

Tweak link handshake
This commit is contained in:
Arceliar 2023-06-19 06:23:47 -05:00 committed by GitHub
commit 99dd8f85d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 14 deletions

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
@ -485,16 +484,10 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn) e
case err == nil && n != len(metaBytes): case err == nil && n != len(metaBytes):
return fmt.Errorf("incomplete handshake send") return fmt.Errorf("incomplete handshake send")
} }
if _, err = io.ReadFull(conn, metaBytes); err != nil {
return fmt.Errorf("read handshake: %w", err)
}
if err = conn.SetDeadline(time.Time{}); err != nil {
return fmt.Errorf("failed to clear handshake deadline: %w", err)
}
meta = version_metadata{} meta = version_metadata{}
base := version_getBaseMetadata() base := version_getBaseMetadata()
if !meta.decode(metaBytes) { if !meta.decode(conn) {
return errors.New("failed to decode metadata") return conn.Close()
} }
if !meta.check() { if !meta.check() {
return fmt.Errorf("remote node incompatible version (local %s, remote %s)", return fmt.Errorf("remote node incompatible version (local %s, remote %s)",
@ -502,6 +495,9 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn) e
fmt.Sprintf("%d.%d", meta.majorVer, meta.minorVer), fmt.Sprintf("%d.%d", meta.majorVer, meta.minorVer),
) )
} }
if err = conn.SetDeadline(time.Time{}); err != nil {
return fmt.Errorf("failed to clear handshake deadline: %w", err)
}
// Check if the remote side matches the keys we expected. This is a bit of a weak // Check if the remote side matches the keys we expected. This is a bit of a weak
// check - in future versions we really should check a signature or something like that. // check - in future versions we really should check a signature or something like that.
if pinned := options.pinnedEd25519Keys; len(pinned) > 0 { if pinned := options.pinnedEd25519Keys; len(pinned) > 0 {

View File

@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"crypto/ed25519" "crypto/ed25519"
"encoding/binary" "encoding/binary"
"io"
) )
// This is the version-specific metadata exchanged at the start of a connection. // This is the version-specific metadata exchanged at the start of a connection.
@ -44,6 +45,7 @@ func version_getBaseMetadata() version_metadata {
func (m *version_metadata) encode() []byte { func (m *version_metadata) encode() []byte {
bs := make([]byte, 0, 64) bs := make([]byte, 0, 64)
bs = append(bs, 'm', 'e', 't', 'a') bs = append(bs, 'm', 'e', 't', 'a')
bs = append(bs, 0, 0) // Remaining message length
bs = binary.BigEndian.AppendUint16(bs, metaVersionMajor) bs = binary.BigEndian.AppendUint16(bs, metaVersionMajor)
bs = binary.BigEndian.AppendUint16(bs, 2) bs = binary.BigEndian.AppendUint16(bs, 2)
@ -61,16 +63,25 @@ func (m *version_metadata) encode() []byte {
bs = binary.BigEndian.AppendUint16(bs, 1) bs = binary.BigEndian.AppendUint16(bs, 1)
bs = append(bs, m.priority) bs = append(bs, m.priority)
binary.BigEndian.PutUint16(bs[4:6], uint16(len(bs)-6))
return bs return bs
} }
// Decodes version metadata from its wire format into the struct. // Decodes version metadata from its wire format into the struct.
func (m *version_metadata) decode(bs []byte) bool { func (m *version_metadata) decode(r io.Reader) bool {
meta := [4]byte{'m', 'e', 't', 'a'} bh := [6]byte{}
if !bytes.Equal(bs[:4], meta[:]) { if _, err := io.ReadFull(r, bh[:]); err != nil {
return false return false
} }
for bs = bs[4:]; len(bs) >= 4; { meta := [4]byte{'m', 'e', 't', 'a'}
if !bytes.Equal(bh[:4], meta[:]) {
return false
}
bs := make([]byte, binary.BigEndian.Uint16(bh[4:6]))
if _, err := io.ReadFull(r, bs); err != nil {
return false
}
for len(bs) >= 4 {
op := binary.BigEndian.Uint16(bs[:2]) op := binary.BigEndian.Uint16(bs[:2])
oplen := binary.BigEndian.Uint16(bs[2:4]) oplen := binary.BigEndian.Uint16(bs[2:4])
if bs = bs[4:]; len(bs) < int(oplen) { if bs = bs[4:]; len(bs) < int(oplen) {

View File

@ -1,6 +1,7 @@
package core package core
import ( import (
"bytes"
"crypto/ed25519" "crypto/ed25519"
"math/rand" "math/rand"
"reflect" "reflect"
@ -22,7 +23,7 @@ func TestVersionRoundtrip(t *testing.T) {
test.publicKey = make(ed25519.PublicKey, ed25519.PublicKeySize) test.publicKey = make(ed25519.PublicKey, ed25519.PublicKeySize)
rand.Read(test.publicKey) rand.Read(test.publicKey)
encoded := test.encode() encoded := bytes.NewBuffer(test.encode())
decoded := &version_metadata{} decoded := &version_metadata{}
if !decoded.decode(encoded) { if !decoded.decode(encoded) {
t.Fatalf("failed to decode") t.Fatalf("failed to decode")