diff --git a/ricochet.go b/ricochet.go index 69a5c12..6bf642c 100644 --- a/ricochet.go +++ b/ricochet.go @@ -8,6 +8,7 @@ import ( "github.com/s-rah/go-ricochet/contact" "github.com/s-rah/go-ricochet/control" "github.com/s-rah/go-ricochet/utils" + "io" "log" "net" "strconv" @@ -327,52 +328,55 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService } } -// negotiateVersion Perform version negotiation with the connected host. +// Perform version negotiation on the connection, and create an OpenConnection if successful func (r *Ricochet) negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) { - version := make([]byte, 4) - version[0] = 0x49 - version[1] = 0x4D - version[2] = 0x01 - version[3] = 0x01 + versions := []byte{0x49, 0x4D, 0x01, 0x01} - // If this was initiated by us then we need to initiate the version info. + // Outbound side of the connection sends a list of supported versions if outbound { - // Send Version String - - conn.Write(version) - res, err := r.rni.Recv(conn) - - if len(res) != 1 || err != nil { - return nil, errors.New("Failed Version Negotiating") - } - - if res[0] != 1 { - return nil, errors.New("Failed Version Negotiating - Invalid Version ") - } - } else { - // Do Version Negotiation - - buf := make([]byte, 10) - n, err := conn.Read(buf) - if err != nil && n >= 4 { + if n, err := conn.Write(versions); err != nil || n < len(versions) { return nil, err } - if buf[0] == version[0] && buf[1] == version[1] { - foundVersion := false - if buf[2] >= 1 { - for i := 3; i < n; i++ { - if buf[i] == 0x01 { - conn.Write([]byte{0x01}) - foundVersion = true - } - } + res := make([]byte, 1) + if _, err := io.ReadAtLeast(conn, res, len(res)); err != nil { + return nil, err + } + + if res[0] != 0x01 { + return nil, errors.New("unsupported protocol version") + } + } else { + // Read version response header + header := make([]byte, 3) + if _, err := io.ReadAtLeast(conn, header, len(header)); err != nil { + return nil, err + } + + if header[0] != versions[0] || header[1] != versions[1] || header[2] < 1 { + return nil, errors.New("invalid protocol response") + } + + // Read list of supported versions (which is header[2] bytes long) + versionList := make([]byte, header[2]) + if _, err := io.ReadAtLeast(conn, versionList, len(versionList)); err != nil { + return nil, err + } + + selectedVersion := byte(0xff) + for _, v := range versionList { + if v == 0x01 { + selectedVersion = v + break } - if !foundVersion { - return nil, errors.New("Failed Version Negotiating - No Available Version") - } - } else { - return nil, errors.New("Failed Version Negotiating - Invalid Version Header") + } + + if n, err := conn.Write([]byte{selectedVersion}); err != nil || n < 1 { + return nil, err + } + + if selectedVersion == 0xff { + return nil, errors.New("no supported protocol version") } }