Improve packet-layer buffering and parsing logic
SendRicochetPacket now has error handling, correctly encodes channel ids, accepts any io.Writer, and ensures that all data is written. All callers should be changed at some point to handle errors also. RecvRicochetPackets is refactored to return only one packet per call and avoid reading more data than it will consume, which simplifies the logic and fixes a number of problems with short reads or large packets. Also fixed an error in bounds checking that caused a remote panic for invalid packet sizes. It also now accepts any io.Reader. Tests are updated and expanded, and now pass. Changes to Ricochet.processConnection are whitespace-only, because of the removal of the inner packets loop.
This commit is contained in:
parent
cc50e0dfe9
commit
47ba383334
|
@ -109,14 +109,12 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
packets, err := r.rni.RecvRicochetPackets(oc.conn)
|
packet, err := r.rni.RecvRicochetPacket(oc.conn)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
oc.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, packet := range packets {
|
|
||||||
|
|
||||||
if len(packet.Data) == 0 {
|
if len(packet.Data) == 0 {
|
||||||
service.OnChannelClosed(oc, packet.Channel)
|
service.OnChannelClosed(oc, packet.Channel)
|
||||||
continue
|
continue
|
||||||
|
@ -325,7 +323,6 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
|
||||||
oc.Close()
|
oc.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform version negotiation on the connection, and create an OpenConnection if successful
|
// Perform version negotiation on the connection, and create an OpenConnection if successful
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"io"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RicochetData is a structure containing the raw data and the channel it the
|
// RicochetData is a structure containing the raw data and the channel it the
|
||||||
|
@ -14,79 +14,67 @@ type RicochetData struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rd RicochetData) Equals(other RicochetData) bool {
|
||||||
|
return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data)
|
||||||
|
}
|
||||||
|
|
||||||
// RicochetNetworkInterface abstract operations that interact with ricochet's
|
// RicochetNetworkInterface abstract operations that interact with ricochet's
|
||||||
// packet layer.
|
// packet layer.
|
||||||
type RicochetNetworkInterface interface {
|
type RicochetNetworkInterface interface {
|
||||||
Recv(conn net.Conn) ([]byte, error)
|
SendRicochetPacket(dst io.Writer, channel int32, data []byte) error
|
||||||
SendRicochetPacket(conn net.Conn, channel int32, data []byte)
|
RecvRicochetPacket(reader io.Reader) (RicochetData, error)
|
||||||
RecvRicochetPackets(conn net.Conn) ([]RicochetData, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RicochetNetwork is a concrete implementation of the RicochetNetworkInterface
|
// RicochetNetwork is a concrete implementation of the RicochetNetworkInterface
|
||||||
type RicochetNetwork struct {
|
type RicochetNetwork struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recv reads data from the client, and returns the raw byte array, else error.
|
|
||||||
func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) {
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
n, err := conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ret := make([]byte, n)
|
|
||||||
copy(ret[:], buf[:])
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendRicochetPacket places the data into a structure needed for the client to
|
// SendRicochetPacket places the data into a structure needed for the client to
|
||||||
// decode the packet and writes the packet to the network.
|
// decode the packet and writes the packet to the network.
|
||||||
func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) {
|
func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error {
|
||||||
header := make([]byte, 4+len(data))
|
packet := make([]byte, 4+len(data))
|
||||||
header[0] = byte(len(header) >> 8)
|
if len(packet) > 65535 {
|
||||||
header[1] = byte(len(header) & 0x00FF)
|
return errors.New("packet too large")
|
||||||
header[2] = 0x00
|
}
|
||||||
header[3] = byte(channel)
|
binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet)))
|
||||||
copy(header[4:], data[:])
|
if channel < 0 || channel > 65535 {
|
||||||
conn.Write(header)
|
return errors.New("invalid channel ID")
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(packet[2:4], uint16(channel))
|
||||||
|
copy(packet[4:], data[:])
|
||||||
|
|
||||||
|
for pos := 0; pos < len(packet); {
|
||||||
|
n, err := dst.Write(packet[pos:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pos += n
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecvRicochetPackets returns an array of new messages received from the ricochet client
|
// RecvRicochetPacket returns the next packet from reader as a RicochetData
|
||||||
func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) {
|
// structure, or an error.
|
||||||
buf, err := rn.Recv(conn)
|
func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) {
|
||||||
if err != nil && len(buf) < 4 {
|
packet := RicochetData{}
|
||||||
return nil, errors.New("failed to retrieve new messages from the client")
|
|
||||||
|
// Read the four-byte header to get packet length
|
||||||
|
header := make([]byte, 4)
|
||||||
|
if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil {
|
||||||
|
return packet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pos := 0
|
size := int(binary.BigEndian.Uint16(header[0:2]))
|
||||||
finished := false
|
|
||||||
var datas []RicochetData
|
|
||||||
|
|
||||||
for !finished {
|
|
||||||
size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2]))
|
|
||||||
channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4]))
|
|
||||||
|
|
||||||
if size < 4 {
|
if size < 4 {
|
||||||
return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")")
|
return packet, errors.New("invalid packet length")
|
||||||
}
|
}
|
||||||
|
|
||||||
if pos+size > len(buf) {
|
packet.Channel = int32(binary.BigEndian.Uint16(header[2:4]))
|
||||||
return datas, errors.New("partial data packet received")
|
packet.Data = make([]byte, size-4)
|
||||||
|
|
||||||
|
if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil {
|
||||||
|
return packet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
data := RicochetData{}
|
return packet, nil
|
||||||
data.Channel = int32(channel)
|
|
||||||
|
|
||||||
if pos+4 >= len(buf) {
|
|
||||||
data.Data = make([]byte, 0)
|
|
||||||
} else {
|
|
||||||
data.Data = buf[pos+4 : pos+size]
|
|
||||||
}
|
|
||||||
|
|
||||||
datas = append(datas, data)
|
|
||||||
pos += size
|
|
||||||
if pos >= len(buf) {
|
|
||||||
finished = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return datas, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,171 +1,105 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
import "net"
|
"bytes"
|
||||||
import "time"
|
"io"
|
||||||
|
"testing"
|
||||||
|
"testing/iotest"
|
||||||
|
)
|
||||||
|
|
||||||
type MockConn struct {
|
// Valid packets and their encoded forms
|
||||||
Written []byte
|
var packetTests = []struct {
|
||||||
MockOutput []byte
|
packet RicochetData
|
||||||
|
encoded []byte
|
||||||
|
}{
|
||||||
|
{RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}},
|
||||||
|
{RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}},
|
||||||
|
{RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mc *MockConn) Read(b []byte) (int, error) {
|
// Test sending valid packets
|
||||||
copy(b[:], mc.MockOutput[:])
|
func TestSendRicochetPacket(t *testing.T) {
|
||||||
return len(mc.MockOutput), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MockConn) Write(written []byte) (int, error) {
|
|
||||||
mc.Written = written
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MockConn) LocalAddr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MockConn) RemoteAddr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MockConn) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MockConn) SetDeadline(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MockConn) SetReadDeadline(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MockConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSentRicochetPacket(t *testing.T) {
|
|
||||||
conn := new(MockConn)
|
|
||||||
rni := RicochetNetwork{}
|
rni := RicochetNetwork{}
|
||||||
rni.SendRicochetPacket(conn, 1, []byte{})
|
for _, td := range packetTests {
|
||||||
if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 {
|
var buf bytes.Buffer
|
||||||
t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written)
|
err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error sending packet %v: %v", td.packet, err)
|
||||||
|
} else if !bytes.Equal(buf.Bytes(), td.encoded) {
|
||||||
|
t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRecv(t *testing.T) {
|
// Test sending invalid packets
|
||||||
conn := new(MockConn)
|
func TestSendRicochetPacket_Invalid(t *testing.T) {
|
||||||
conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
|
||||||
rni := RicochetNetwork{}
|
rni := RicochetNetwork{}
|
||||||
buf, err := rni.Recv(conn)
|
invalidPackets := []RicochetData{
|
||||||
if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF {
|
RicochetData{-1, []byte{}},
|
||||||
t.Errorf("Output of Recv was Unexpected: %x", buf)
|
RicochetData{65536, []byte{}},
|
||||||
|
RicochetData{0, make([]byte, 65532)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, td := range invalidPackets {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := rni.SendRicochetPacket(&buf, td.Channel, td.Data)
|
||||||
|
// Expect error
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error when sending invalid packet %v", td)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test receiving valid packets
|
||||||
func TestRecvRicochetPacket(t *testing.T) {
|
func TestRecvRicochetPacket(t *testing.T) {
|
||||||
conn := new(MockConn)
|
var buf bytes.Buffer
|
||||||
conn.MockOutput = []byte{00, 0x04, 0x00, 0x01}
|
for _, td := range packetTests {
|
||||||
|
if _, err := buf.Write(td.encoded); err != nil {
|
||||||
rni := RicochetNetwork{}
|
t.Error(err)
|
||||||
rp, err := rni.RecvRicochetPackets(conn)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("error extracting ricochet packets: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(rp) != 1 {
|
// Use a HalfReader to test behavior on short socket reads also
|
||||||
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
|
reader := iotest.HalfReader(&buf)
|
||||||
|
rni := RicochetNetwork{}
|
||||||
|
|
||||||
|
for _, td := range packetTests {
|
||||||
|
packet, err := rni.RecvRicochetPacket(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error receiving packet %v: %v", td.packet, err)
|
||||||
|
return
|
||||||
|
} else if !packet.Equals(td.packet) {
|
||||||
|
t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected EOF on packet stream but received error: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if rp[0].Channel != 1 {
|
t.Errorf("Expected EOF but received packet: %v", packet)
|
||||||
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rp[0].Data) != 0 {
|
|
||||||
t.Errorf("expected emptry packet, instead got %x", rp[0].Data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRecvRicochetPacketInvalid(t *testing.T) {
|
// Test receiving invalid packets
|
||||||
conn := new(MockConn)
|
func TestRecvRicochetPacket_Invalid(t *testing.T) {
|
||||||
conn.MockOutput = []byte{00, 0x01, 0x00, 0x01}
|
|
||||||
|
|
||||||
rni := RicochetNetwork{}
|
rni := RicochetNetwork{}
|
||||||
_, err := rni.RecvRicochetPackets(conn)
|
invalidPackets := [][]byte{
|
||||||
|
[]byte{0x00, 0x00, 0x00, 0x00},
|
||||||
|
[]byte{0x00, 0x03, 0x00, 0x00},
|
||||||
|
[]byte{0xff},
|
||||||
|
[]byte{0x00, 0x06, 0x00, 0x00, 0x00},
|
||||||
|
[]byte{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, td := range invalidPackets {
|
||||||
|
buf := bytes.NewBuffer(td)
|
||||||
|
packet, err := rni.RecvRicochetPacket(buf)
|
||||||
|
// Expect error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("recv should have errored due to invalid packets %v", err)
|
t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01}
|
|
||||||
|
|
||||||
_, err = rni.RecvRicochetPackets(conn)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("recv should have errored due to invalid packets %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecvRicochetPacketLong(t *testing.T) {
|
|
||||||
conn := new(MockConn)
|
|
||||||
conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
|
|
||||||
|
|
||||||
rni := RicochetNetwork{}
|
|
||||||
rp, err := rni.RecvRicochetPackets(conn)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("error extracting ricochet packets: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rp) != 1 {
|
|
||||||
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
|
|
||||||
} else {
|
|
||||||
if rp[0].Channel != 255 {
|
|
||||||
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF {
|
|
||||||
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecvRicochetPacketMultiplex(t *testing.T) {
|
|
||||||
conn := new(MockConn)
|
|
||||||
conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
|
|
||||||
|
|
||||||
rni := RicochetNetwork{}
|
|
||||||
rp, err := rni.RecvRicochetPackets(conn)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("error extracting ricochet packets: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rp) != 2 {
|
|
||||||
t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp))
|
|
||||||
} else {
|
|
||||||
|
|
||||||
if rp[0].Channel != 1 {
|
|
||||||
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rp[0].Data) != 0 {
|
|
||||||
t.Errorf("expected empty packet, instead got %x", rp[0].Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rp[1].Channel != 255 {
|
|
||||||
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF {
|
|
||||||
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue