diff --git a/src/core/link.go b/src/core/link.go index c13d6af3..c847e5ac 100644 --- a/src/core/link.go +++ b/src/core/link.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/hex" - "errors" "fmt" "io" "math" @@ -485,16 +484,10 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn) e case err == nil && n != len(metaBytes): return fmt.Errorf("incomplete handshake send") } - if _, err = io.ReadFull(conn, metaBytes); err != nil { - return fmt.Errorf("read handshake: %w", err) - } - if err = conn.SetDeadline(time.Time{}); err != nil { - return fmt.Errorf("failed to clear handshake deadline: %w", err) - } meta = version_metadata{} base := version_getBaseMetadata() - if !meta.decode(metaBytes) { - return errors.New("failed to decode metadata") + if !meta.decode(conn) { + return conn.Close() } if !meta.check() { return fmt.Errorf("remote node incompatible version (local %s, remote %s)", @@ -502,6 +495,9 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn) e fmt.Sprintf("%d.%d", meta.majorVer, meta.minorVer), ) } + if err = conn.SetDeadline(time.Time{}); err != nil { + return fmt.Errorf("failed to clear handshake deadline: %w", err) + } // Check if the remote side matches the keys we expected. This is a bit of a weak // check - in future versions we really should check a signature or something like that. if pinned := options.pinnedEd25519Keys; len(pinned) > 0 { diff --git a/src/core/version.go b/src/core/version.go index 3787d1e0..0820fbdd 100644 --- a/src/core/version.go +++ b/src/core/version.go @@ -8,6 +8,7 @@ import ( "bytes" "crypto/ed25519" "encoding/binary" + "io" ) // This is the version-specific metadata exchanged at the start of a connection. @@ -44,6 +45,7 @@ func version_getBaseMetadata() version_metadata { func (m *version_metadata) encode() []byte { bs := make([]byte, 0, 64) bs = append(bs, 'm', 'e', 't', 'a') + bs = append(bs, 0, 0) // Remaining message length bs = binary.BigEndian.AppendUint16(bs, metaVersionMajor) bs = binary.BigEndian.AppendUint16(bs, 2) @@ -61,16 +63,25 @@ func (m *version_metadata) encode() []byte { bs = binary.BigEndian.AppendUint16(bs, 1) bs = append(bs, m.priority) + binary.BigEndian.PutUint16(bs[4:6], uint16(len(bs)-6)) return bs } // Decodes version metadata from its wire format into the struct. -func (m *version_metadata) decode(bs []byte) bool { - meta := [4]byte{'m', 'e', 't', 'a'} - if !bytes.Equal(bs[:4], meta[:]) { +func (m *version_metadata) decode(r io.Reader) bool { + bh := [6]byte{} + if _, err := io.ReadFull(r, bh[:]); err != nil { return false } - for bs = bs[4:]; len(bs) >= 4; { + meta := [4]byte{'m', 'e', 't', 'a'} + if !bytes.Equal(bh[:4], meta[:]) { + return false + } + bs := make([]byte, binary.BigEndian.Uint16(bh[4:6])) + if _, err := io.ReadFull(r, bs); err != nil { + return false + } + for len(bs) >= 4 { op := binary.BigEndian.Uint16(bs[:2]) oplen := binary.BigEndian.Uint16(bs[2:4]) if bs = bs[4:]; len(bs) < int(oplen) { diff --git a/src/core/version_test.go b/src/core/version_test.go index 6fb7895c..511cb358 100644 --- a/src/core/version_test.go +++ b/src/core/version_test.go @@ -1,6 +1,7 @@ package core import ( + "bytes" "crypto/ed25519" "math/rand" "reflect" @@ -22,7 +23,7 @@ func TestVersionRoundtrip(t *testing.T) { test.publicKey = make(ed25519.PublicKey, ed25519.PublicKeySize) rand.Read(test.publicKey) - encoded := test.encode() + encoded := bytes.NewBuffer(test.encode()) decoded := &version_metadata{} if !decoded.decode(encoded) { t.Fatalf("failed to decode")