Fix buffering in version negotiation

The service-side version negotiation had a buffer overread that would
cause remotely triggerable panic.

Refactor that code to resolve that issue, follow the spec more exactly,
and avoid reading more data from the socket than is used for version
negotiation, in case clients write optimistically.
This commit is contained in:
John Brooks 2016-10-01 21:28:57 -07:00 committed by Sarah Jamie Lewis
parent 1c317fc186
commit cc50e0dfe9
1 changed files with 43 additions and 39 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/s-rah/go-ricochet/contact"
"github.com/s-rah/go-ricochet/control"
"github.com/s-rah/go-ricochet/utils"
"io"
"log"
"net"
"strconv"
@ -327,52 +328,55 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
}
}
// negotiateVersion Perform version negotiation with the connected host.
// Perform version negotiation on the connection, and create an OpenConnection if successful
func (r *Ricochet) negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) {
version := make([]byte, 4)
version[0] = 0x49
version[1] = 0x4D
version[2] = 0x01
version[3] = 0x01
versions := []byte{0x49, 0x4D, 0x01, 0x01}
// If this was initiated by us then we need to initiate the version info.
// Outbound side of the connection sends a list of supported versions
if outbound {
// Send Version String
conn.Write(version)
res, err := r.rni.Recv(conn)
if len(res) != 1 || err != nil {
return nil, errors.New("Failed Version Negotiating")
}
if res[0] != 1 {
return nil, errors.New("Failed Version Negotiating - Invalid Version ")
}
} else {
// Do Version Negotiation
buf := make([]byte, 10)
n, err := conn.Read(buf)
if err != nil && n >= 4 {
if n, err := conn.Write(versions); err != nil || n < len(versions) {
return nil, err
}
if buf[0] == version[0] && buf[1] == version[1] {
foundVersion := false
if buf[2] >= 1 {
for i := 3; i < n; i++ {
if buf[i] == 0x01 {
conn.Write([]byte{0x01})
foundVersion = true
}
}
res := make([]byte, 1)
if _, err := io.ReadAtLeast(conn, res, len(res)); err != nil {
return nil, err
}
if res[0] != 0x01 {
return nil, errors.New("unsupported protocol version")
}
} else {
// Read version response header
header := make([]byte, 3)
if _, err := io.ReadAtLeast(conn, header, len(header)); err != nil {
return nil, err
}
if header[0] != versions[0] || header[1] != versions[1] || header[2] < 1 {
return nil, errors.New("invalid protocol response")
}
// Read list of supported versions (which is header[2] bytes long)
versionList := make([]byte, header[2])
if _, err := io.ReadAtLeast(conn, versionList, len(versionList)); err != nil {
return nil, err
}
selectedVersion := byte(0xff)
for _, v := range versionList {
if v == 0x01 {
selectedVersion = v
break
}
if !foundVersion {
return nil, errors.New("Failed Version Negotiating - No Available Version")
}
} else {
return nil, errors.New("Failed Version Negotiating - Invalid Version Header")
}
if n, err := conn.Write([]byte{selectedVersion}); err != nil || n < 1 {
return nil, err
}
if selectedVersion == 0xff {
return nil, errors.New("no supported protocol version")
}
}