Improve packet-layer buffering and parsing logic

SendRicochetPacket now has error handling, correctly encodes channel
ids, accepts any io.Writer, and ensures that all data is written. All
callers should be changed at some point to handle errors also.

RecvRicochetPackets is refactored to return only one packet per call and
avoid reading more data than it will consume, which simplifies the logic
and fixes a number of problems with short reads or large packets. Also
fixed an error in bounds checking that caused a remote panic for invalid
packet sizes. It also now accepts any io.Reader.

Tests are updated and expanded, and now pass.

Changes to Ricochet.processConnection are whitespace-only, because of
the removal of the inner packets loop.
This commit is contained in:
John Brooks 2016-10-02 15:39:08 -07:00 committed by Sarah Jamie Lewis
parent cc50e0dfe9
commit 47ba383334
3 changed files with 300 additions and 381 deletions

View File

@ -109,147 +109,175 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
return return
} }
packets, err := r.rni.RecvRicochetPackets(oc.conn) packet, err := r.rni.RecvRicochetPacket(oc.conn)
if err != nil { if err != nil {
oc.Close()
return return
} }
for _, packet := range packets { if len(packet.Data) == 0 {
service.OnChannelClosed(oc, packet.Channel)
continue
}
if len(packet.Data) == 0 { if packet.Channel == 0 {
service.OnChannelClosed(oc, packet.Channel)
res := new(Protocol_Data_Control.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err != nil {
service.OnGenericError(oc, packet.Channel)
continue continue
} }
if packet.Channel == 0 { if res.GetOpenChannel() != nil {
opm := res.GetOpenChannel()
res := new(Protocol_Data_Control.Packet) if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" {
err := proto.Unmarshal(packet.Data[:], res) // Channel is already in use.
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
if err != nil {
service.OnGenericError(oc, packet.Channel)
continue continue
} }
if res.GetOpenChannel() != nil { // If I am a Client, the server can only open even numbered channels
opm := res.GetOpenChannel() if oc.Client && opm.GetChannelIdentifier()%2 != 0 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue
}
if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { // If I am a Server, the client can only open odd numbered channels
// Channel is already in use. if !oc.Client && opm.GetChannelIdentifier()%2 != 1 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue
}
switch opm.GetChannelType() {
case "im.ricochet.auth.hidden-service":
if oc.Client {
// Servers are authed by default and can't auth with hidden-service
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue } else if oc.IsAuthed {
} // Can't auth if already authed
// If I am a Client, the server can only open even numbered channels
if oc.Client && opm.GetChannelIdentifier()%2 != 0 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue } else if oc.HasChannel("im.ricochet.auth.hidden-service") {
} // Can't open more than 1 auth channel
// If I am a Server, the client can only open odd numbered channels
if !oc.Client && opm.GetChannelIdentifier()%2 != 1 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue } else {
} clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie)
if err == nil {
switch opm.GetChannelType() { clientCookieB := [16]byte{}
case "im.ricochet.auth.hidden-service": copy(clientCookieB[:], clientCookie.([]byte)[:])
if oc.Client { service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB)
// Servers are authed by default and can't auth with hidden-service
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else if oc.IsAuthed {
// Can't auth if already authed
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else if oc.HasChannel("im.ricochet.auth.hidden-service") {
// Can't open more than 1 auth channel
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else { } else {
clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) // Must include Client Cookie
if err == nil { service.OnBadUsageError(oc, opm.GetChannelIdentifier())
clientCookieB := [16]byte{} }
copy(clientCookieB[:], clientCookie.([]byte)[:]) }
service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB) case "im.ricochet.chat":
} else { if !oc.IsAuthed {
// Must include Client Cookie // Can't open chat channel if not authorized
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't open chat channel if not a known contact
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else {
service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat")
}
case "im.ricochet.contact.request":
if oc.Client {
// Servers are not allowed to send contact requests
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else if !oc.IsAuthed {
// Can't open a contact channel if not authed
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else if oc.HasChannel("im.ricochet.contact.request") {
// Only 1 contact channel is allowed to be open at a time
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else {
contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest)
if err == nil {
contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest)
if check {
service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText())
break
} }
} }
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
}
default:
service.OnUnknownTypeError(oc, opm.GetChannelIdentifier())
}
} else if res.GetChannelResult() != nil {
crm := res.GetChannelResult()
if crm.GetOpened() {
switch oc.GetChannelType(crm.GetChannelIdentifier()) {
case "im.ricochet.auth.hidden-service":
serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie)
if err == nil {
serverCookieB := [16]byte{}
copy(serverCookieB[:], serverCookie.([]byte)[:])
service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB)
} else {
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
case "im.ricochet.chat": case "im.ricochet.chat":
if !oc.IsAuthed { service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier())
// Can't open chat channel if not authorized
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't open chat channel if not a known contact
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else {
service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat")
}
case "im.ricochet.contact.request": case "im.ricochet.contact.request":
if oc.Client { responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response)
// Servers are not allowed to send contact requests if err == nil {
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) response, check := responseI.(*Protocol_Data_ContactRequest.Response)
} else if !oc.IsAuthed { if check {
// Can't open a contact channel if not authed service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String())
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) break
} else if oc.HasChannel("im.ricochet.contact.request") {
// Only 1 contact channel is allowed to be open at a time
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else {
contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest)
if err == nil {
contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest)
if check {
service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText())
break
}
} }
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} }
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
default: default:
service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
} else if res.GetChannelResult() != nil {
crm := res.GetChannelResult()
if crm.GetOpened() {
switch oc.GetChannelType(crm.GetChannelIdentifier()) {
case "im.ricochet.auth.hidden-service":
serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie)
if err == nil {
serverCookieB := [16]byte{}
copy(serverCookieB[:], serverCookie.([]byte)[:])
service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB)
} else {
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
case "im.ricochet.chat":
service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier())
case "im.ricochet.contact.request":
responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response)
if err == nil {
response, check := responseI.(*Protocol_Data_ContactRequest.Response)
if check {
service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String())
break
}
}
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
default:
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
} else {
if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" {
service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String())
} else {
oc.CloseChannel(crm.GetChannelIdentifier())
}
} }
} else { } else {
// Unknown Message if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" {
oc.CloseChannel(packet.Channel) service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String())
} else {
oc.CloseChannel(crm.GetChannelIdentifier())
}
} }
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { } else {
res := new(Protocol_Data_AuthHiddenService.Packet) // Unknown Message
oc.CloseChannel(packet.Channel)
}
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" {
res := new(Protocol_Data_AuthHiddenService.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs
service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname))
} else if res.GetResult() != nil && oc.Client { // Only Servers Send Results
service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
} else {
// If neither of the above are satisfied we just close the connection
oc.Close()
}
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.IsAuthed {
// Can't send chat messages if not authorized
service.OnUnauthorizedError(oc, packet.Channel)
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't send chat message if not a known contact
service.OnUnauthorizedError(oc, packet.Channel)
} else {
res := new(Protocol_Data_Chat.Packet)
err := proto.Unmarshal(packet.Data[:], res) err := proto.Unmarshal(packet.Data[:], res)
if err != nil { if err != nil {
@ -257,73 +285,42 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
continue continue
} }
if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs if res.GetChatMessage() != nil {
service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText())
} else if res.GetResult() != nil && oc.Client { // Only Servers Send Results } else if res.GetChatAcknowledge() != nil {
service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()))
} else { } else {
// If neither of the above are satisfied we just close the connection // If neither of the above are satisfied we just close the connection
oc.Close() oc.Close()
} }
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.IsAuthed {
// Can't send chat messages if not authorized
service.OnUnauthorizedError(oc, packet.Channel)
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't send chat message if not a known contact
service.OnUnauthorizedError(oc, packet.Channel)
} else {
res := new(Protocol_Data_Chat.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
if res.GetChatMessage() != nil {
service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText())
} else if res.GetChatAcknowledge() != nil {
service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()))
} else {
// If neither of the above are satisfied we just close the connection
oc.Close()
}
}
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.Client {
// Clients are not allowed to send contact request responses
service.OnBadUsageError(oc, packet.Channel)
} else if !oc.IsAuthed {
// Can't send a contact request if not authed
service.OnBadUsageError(oc, packet.Channel)
} else {
res := new(Protocol_Data_ContactRequest.Response)
err := proto.Unmarshal(packet.Data[:], res)
log.Printf("%v", res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String())
}
} else if oc.GetChannelType(packet.Channel) == "none" {
// Invalid Channel Assignment
oc.CloseChannel(packet.Channel)
} else {
oc.Close()
} }
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.Client {
// Clients are not allowed to send contact request responses
service.OnBadUsageError(oc, packet.Channel)
} else if !oc.IsAuthed {
// Can't send a contact request if not authed
service.OnBadUsageError(oc, packet.Channel)
} else {
res := new(Protocol_Data_ContactRequest.Response)
err := proto.Unmarshal(packet.Data[:], res)
log.Printf("%v", res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String())
}
} else if oc.GetChannelType(packet.Channel) == "none" {
// Invalid Channel Assignment
oc.CloseChannel(packet.Channel)
} else {
oc.Close()
} }
} }
} }

View File

@ -1,10 +1,10 @@
package utils package utils
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"net" "io"
"strconv"
) )
// RicochetData is a structure containing the raw data and the channel it the // RicochetData is a structure containing the raw data and the channel it the
@ -14,79 +14,67 @@ type RicochetData struct {
Data []byte Data []byte
} }
func (rd RicochetData) Equals(other RicochetData) bool {
return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data)
}
// RicochetNetworkInterface abstract operations that interact with ricochet's // RicochetNetworkInterface abstract operations that interact with ricochet's
// packet layer. // packet layer.
type RicochetNetworkInterface interface { type RicochetNetworkInterface interface {
Recv(conn net.Conn) ([]byte, error) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error
SendRicochetPacket(conn net.Conn, channel int32, data []byte) RecvRicochetPacket(reader io.Reader) (RicochetData, error)
RecvRicochetPackets(conn net.Conn) ([]RicochetData, error)
} }
// RicochetNetwork is a concrete implementation of the RicochetNetworkInterface // RicochetNetwork is a concrete implementation of the RicochetNetworkInterface
type RicochetNetwork struct { type RicochetNetwork struct {
} }
// Recv reads data from the client, and returns the raw byte array, else error.
func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) {
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
return nil, err
}
ret := make([]byte, n)
copy(ret[:], buf[:])
return ret, nil
}
// SendRicochetPacket places the data into a structure needed for the client to // SendRicochetPacket places the data into a structure needed for the client to
// decode the packet and writes the packet to the network. // decode the packet and writes the packet to the network.
func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) { func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error {
header := make([]byte, 4+len(data)) packet := make([]byte, 4+len(data))
header[0] = byte(len(header) >> 8) if len(packet) > 65535 {
header[1] = byte(len(header) & 0x00FF) return errors.New("packet too large")
header[2] = 0x00 }
header[3] = byte(channel) binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet)))
copy(header[4:], data[:]) if channel < 0 || channel > 65535 {
conn.Write(header) return errors.New("invalid channel ID")
}
binary.BigEndian.PutUint16(packet[2:4], uint16(channel))
copy(packet[4:], data[:])
for pos := 0; pos < len(packet); {
n, err := dst.Write(packet[pos:])
if err != nil {
return err
}
pos += n
}
return nil
} }
// RecvRicochetPackets returns an array of new messages received from the ricochet client // RecvRicochetPacket returns the next packet from reader as a RicochetData
func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) { // structure, or an error.
buf, err := rn.Recv(conn) func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) {
if err != nil && len(buf) < 4 { packet := RicochetData{}
return nil, errors.New("failed to retrieve new messages from the client")
// Read the four-byte header to get packet length
header := make([]byte, 4)
if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil {
return packet, err
} }
pos := 0 size := int(binary.BigEndian.Uint16(header[0:2]))
finished := false if size < 4 {
var datas []RicochetData return packet, errors.New("invalid packet length")
for !finished {
size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2]))
channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4]))
if size < 4 {
return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")")
}
if pos+size > len(buf) {
return datas, errors.New("partial data packet received")
}
data := RicochetData{}
data.Channel = int32(channel)
if pos+4 >= len(buf) {
data.Data = make([]byte, 0)
} else {
data.Data = buf[pos+4 : pos+size]
}
datas = append(datas, data)
pos += size
if pos >= len(buf) {
finished = true
}
} }
return datas, nil
packet.Channel = int32(binary.BigEndian.Uint16(header[2:4]))
packet.Data = make([]byte, size-4)
if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil {
return packet, err
}
return packet, nil
} }

View File

@ -1,171 +1,105 @@
package utils package utils
import "testing" import (
import "net" "bytes"
import "time" "io"
"testing"
"testing/iotest"
)
type MockConn struct { // Valid packets and their encoded forms
Written []byte var packetTests = []struct {
MockOutput []byte packet RicochetData
encoded []byte
}{
{RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}},
{RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}},
{RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)},
} }
func (mc *MockConn) Read(b []byte) (int, error) { // Test sending valid packets
copy(b[:], mc.MockOutput[:]) func TestSendRicochetPacket(t *testing.T) {
return len(mc.MockOutput), nil
}
func (mc *MockConn) Write(written []byte) (int, error) {
mc.Written = written
return 0, nil
}
func (mc *MockConn) LocalAddr() net.Addr {
return nil
}
func (mc *MockConn) RemoteAddr() net.Addr {
return nil
}
func (mc *MockConn) Close() error {
return nil
}
func (mc *MockConn) SetDeadline(t time.Time) error {
return nil
}
func (mc *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (mc *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestSentRicochetPacket(t *testing.T) {
conn := new(MockConn)
rni := RicochetNetwork{} rni := RicochetNetwork{}
rni.SendRicochetPacket(conn, 1, []byte{}) for _, td := range packetTests {
if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 { var buf bytes.Buffer
t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written) err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data)
if err != nil {
t.Errorf("Error sending packet %v: %v", td.packet, err)
} else if !bytes.Equal(buf.Bytes(), td.encoded) {
t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes())
}
} }
} }
func TestRecv(t *testing.T) { // Test sending invalid packets
conn := new(MockConn) func TestSendRicochetPacket_Invalid(t *testing.T) {
conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{} rni := RicochetNetwork{}
buf, err := rni.Recv(conn) invalidPackets := []RicochetData{
if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF { RicochetData{-1, []byte{}},
t.Errorf("Output of Recv was Unexpected: %x", buf) RicochetData{65536, []byte{}},
RicochetData{0, make([]byte, 65532)},
}
for _, td := range invalidPackets {
var buf bytes.Buffer
err := rni.SendRicochetPacket(&buf, td.Channel, td.Data)
// Expect error
if err == nil {
t.Errorf("Expected error when sending invalid packet %v", td)
}
} }
} }
// Test receiving valid packets
func TestRecvRicochetPacket(t *testing.T) { func TestRecvRicochetPacket(t *testing.T) {
conn := new(MockConn) var buf bytes.Buffer
conn.MockOutput = []byte{00, 0x04, 0x00, 0x01} for _, td := range packetTests {
if _, err := buf.Write(td.encoded); err != nil {
t.Error(err)
return
}
}
// Use a HalfReader to test behavior on short socket reads also
reader := iotest.HalfReader(&buf)
rni := RicochetNetwork{} rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil { for _, td := range packetTests {
t.Errorf("error extracting ricochet packets: %v", err) packet, err := rni.RecvRicochetPacket(reader)
return if err != nil {
} t.Errorf("Error receiving packet %v: %v", td.packet, err)
return
if len(rp) != 1 { } else if !packet.Equals(td.packet) {
t.Errorf("unexpected number of ricochet packets: %d", len(rp)) t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet)
} else {
if rp[0].Channel != 1 {
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
}
if len(rp[0].Data) != 0 {
t.Errorf("expected emptry packet, instead got %x", rp[0].Data)
} }
} }
if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF {
if err != nil {
t.Errorf("Expected EOF on packet stream but received error: %v", err)
} else {
t.Errorf("Expected EOF but received packet: %v", packet)
}
}
} }
func TestRecvRicochetPacketInvalid(t *testing.T) { // Test receiving invalid packets
conn := new(MockConn) func TestRecvRicochetPacket_Invalid(t *testing.T) {
conn.MockOutput = []byte{00, 0x01, 0x00, 0x01}
rni := RicochetNetwork{} rni := RicochetNetwork{}
_, err := rni.RecvRicochetPackets(conn) invalidPackets := [][]byte{
[]byte{0x00, 0x00, 0x00, 0x00},
if err == nil { []byte{0x00, 0x03, 0x00, 0x00},
t.Errorf("recv should have errored due to invalid packets %v", err) []byte{0xff},
[]byte{0x00, 0x06, 0x00, 0x00, 0x00},
[]byte{},
} }
conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01} for _, td := range invalidPackets {
buf := bytes.NewBuffer(td)
_, err = rni.RecvRicochetPackets(conn) packet, err := rni.RecvRicochetPacket(buf)
// Expect error
if err == nil { if err == nil {
t.Errorf("recv should have errored due to invalid packets %v", err) t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet)
}
} }
}
func TestRecvRicochetPacketLong(t *testing.T) {
conn := new(MockConn)
conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil {
t.Errorf("error extracting ricochet packets: %v", err)
return
}
if len(rp) != 1 {
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
} else {
if rp[0].Channel != 255 {
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
}
if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF {
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
}
}
}
func TestRecvRicochetPacketMultiplex(t *testing.T) {
conn := new(MockConn)
conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil {
t.Errorf("error extracting ricochet packets: %v", err)
return
}
if len(rp) != 2 {
t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp))
} else {
if rp[0].Channel != 1 {
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
}
if len(rp[0].Data) != 0 {
t.Errorf("expected empty packet, instead got %x", rp[0].Data)
}
if rp[1].Channel != 255 {
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
}
if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF {
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
}
}
} }