From cc50e0dfe9a1025470c705a9efb2e4c929c09f9a Mon Sep 17 00:00:00 2001 From: John Brooks Date: Sat, 1 Oct 2016 21:28:57 -0700 Subject: [PATCH] Fix buffering in version negotiation The service-side version negotiation had a buffer overread that would cause remotely triggerable panic. Refactor that code to resolve that issue, follow the spec more exactly, and avoid reading more data from the socket than is used for version negotiation, in case clients write optimistically. --- ricochet.go | 82 ++++++++++++++++++++++++++++------------------------- 1 file changed, 43 insertions(+), 39 deletions(-) diff --git a/ricochet.go b/ricochet.go index 69a5c12..6bf642c 100644 --- a/ricochet.go +++ b/ricochet.go @@ -8,6 +8,7 @@ import ( "github.com/s-rah/go-ricochet/contact" "github.com/s-rah/go-ricochet/control" "github.com/s-rah/go-ricochet/utils" + "io" "log" "net" "strconv" @@ -327,52 +328,55 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService } } -// negotiateVersion Perform version negotiation with the connected host. +// Perform version negotiation on the connection, and create an OpenConnection if successful func (r *Ricochet) negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) { - version := make([]byte, 4) - version[0] = 0x49 - version[1] = 0x4D - version[2] = 0x01 - version[3] = 0x01 + versions := []byte{0x49, 0x4D, 0x01, 0x01} - // If this was initiated by us then we need to initiate the version info. + // Outbound side of the connection sends a list of supported versions if outbound { - // Send Version String - - conn.Write(version) - res, err := r.rni.Recv(conn) - - if len(res) != 1 || err != nil { - return nil, errors.New("Failed Version Negotiating") - } - - if res[0] != 1 { - return nil, errors.New("Failed Version Negotiating - Invalid Version ") - } - } else { - // Do Version Negotiation - - buf := make([]byte, 10) - n, err := conn.Read(buf) - if err != nil && n >= 4 { + if n, err := conn.Write(versions); err != nil || n < len(versions) { return nil, err } - if buf[0] == version[0] && buf[1] == version[1] { - foundVersion := false - if buf[2] >= 1 { - for i := 3; i < n; i++ { - if buf[i] == 0x01 { - conn.Write([]byte{0x01}) - foundVersion = true - } - } + res := make([]byte, 1) + if _, err := io.ReadAtLeast(conn, res, len(res)); err != nil { + return nil, err + } + + if res[0] != 0x01 { + return nil, errors.New("unsupported protocol version") + } + } else { + // Read version response header + header := make([]byte, 3) + if _, err := io.ReadAtLeast(conn, header, len(header)); err != nil { + return nil, err + } + + if header[0] != versions[0] || header[1] != versions[1] || header[2] < 1 { + return nil, errors.New("invalid protocol response") + } + + // Read list of supported versions (which is header[2] bytes long) + versionList := make([]byte, header[2]) + if _, err := io.ReadAtLeast(conn, versionList, len(versionList)); err != nil { + return nil, err + } + + selectedVersion := byte(0xff) + for _, v := range versionList { + if v == 0x01 { + selectedVersion = v + break } - if !foundVersion { - return nil, errors.New("Failed Version Negotiating - No Available Version") - } - } else { - return nil, errors.New("Failed Version Negotiating - Invalid Version Header") + } + + if n, err := conn.Write([]byte{selectedVersion}); err != nil || n < 1 { + return nil, err + } + + if selectedVersion == 0xff { + return nil, errors.New("no supported protocol version") } }