From 47ba383334ad4760e1e462caf197252490995b1f Mon Sep 17 00:00:00 2001 From: John Brooks Date: Sun, 2 Oct 2016 15:39:08 -0700 Subject: [PATCH] Improve packet-layer buffering and parsing logic SendRicochetPacket now has error handling, correctly encodes channel ids, accepts any io.Writer, and ensures that all data is written. All callers should be changed at some point to handle errors also. RecvRicochetPackets is refactored to return only one packet per call and avoid reading more data than it will consume, which simplifies the logic and fixes a number of problems with short reads or large packets. Also fixed an error in bounds checking that caused a remote panic for invalid packet sizes. It also now accepts any io.Reader. Tests are updated and expanded, and now pass. Changes to Ricochet.processConnection are whitespace-only, because of the removal of the inner packets loop. --- ricochet.go | 349 +++++++++++++++++++-------------------- utils/networking.go | 110 ++++++------ utils/networking_test.go | 222 +++++++++---------------- 3 files changed, 300 insertions(+), 381 deletions(-) diff --git a/ricochet.go b/ricochet.go index 6bf642c..b2036ab 100644 --- a/ricochet.go +++ b/ricochet.go @@ -109,147 +109,175 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService return } - packets, err := r.rni.RecvRicochetPackets(oc.conn) - + packet, err := r.rni.RecvRicochetPacket(oc.conn) if err != nil { + oc.Close() return } - for _, packet := range packets { + if len(packet.Data) == 0 { + service.OnChannelClosed(oc, packet.Channel) + continue + } - if len(packet.Data) == 0 { - service.OnChannelClosed(oc, packet.Channel) + if packet.Channel == 0 { + + res := new(Protocol_Data_Control.Packet) + err := proto.Unmarshal(packet.Data[:], res) + + if err != nil { + service.OnGenericError(oc, packet.Channel) continue } - if packet.Channel == 0 { + if res.GetOpenChannel() != nil { + opm := res.GetOpenChannel() - res := new(Protocol_Data_Control.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - service.OnGenericError(oc, packet.Channel) + if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { + // Channel is already in use. + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) continue } - if res.GetOpenChannel() != nil { - opm := res.GetOpenChannel() + // If I am a Client, the server can only open even numbered channels + if oc.Client && opm.GetChannelIdentifier()%2 != 0 { + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + continue + } - if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { - // Channel is already in use. + // If I am a Server, the client can only open odd numbered channels + if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + continue + } + + switch opm.GetChannelType() { + case "im.ricochet.auth.hidden-service": + if oc.Client { + // Servers are authed by default and can't auth with hidden-service service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - // If I am a Client, the server can only open even numbered channels - if oc.Client && opm.GetChannelIdentifier()%2 != 0 { + } else if oc.IsAuthed { + // Can't auth if already authed service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - // If I am a Server, the client can only open odd numbered channels - if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { + } else if oc.HasChannel("im.ricochet.auth.hidden-service") { + // Can't open more than 1 auth channel service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - switch opm.GetChannelType() { - case "im.ricochet.auth.hidden-service": - if oc.Client { - // Servers are authed by default and can't auth with hidden-service - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if oc.IsAuthed { - // Can't auth if already authed - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.auth.hidden-service") { - // Can't open more than 1 auth channel - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } else { + clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) + if err == nil { + clientCookieB := [16]byte{} + copy(clientCookieB[:], clientCookie.([]byte)[:]) + service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB) } else { - clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) - if err == nil { - clientCookieB := [16]byte{} - copy(clientCookieB[:], clientCookie.([]byte)[:]) - service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB) - } else { - // Must include Client Cookie - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + // Must include Client Cookie + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } + } + case "im.ricochet.chat": + if !oc.IsAuthed { + // Can't open chat channel if not authorized + service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) + } else if !service.IsKnownContact(oc.OtherHostname) { + // Can't open chat channel if not a known contact + service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) + } else { + service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat") + } + case "im.ricochet.contact.request": + if oc.Client { + // Servers are not allowed to send contact requests + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } else if !oc.IsAuthed { + // Can't open a contact channel if not authed + service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) + } else if oc.HasChannel("im.ricochet.contact.request") { + // Only 1 contact channel is allowed to be open at a time + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } else { + contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) + if err == nil { + contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) + if check { + service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) + break } } + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } + default: + service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) + } + } else if res.GetChannelResult() != nil { + crm := res.GetChannelResult() + if crm.GetOpened() { + switch oc.GetChannelType(crm.GetChannelIdentifier()) { + case "im.ricochet.auth.hidden-service": + serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) + if err == nil { + serverCookieB := [16]byte{} + copy(serverCookieB[:], serverCookie.([]byte)[:]) + service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB) + } else { + service.OnBadUsageError(oc, crm.GetChannelIdentifier()) + } case "im.ricochet.chat": - if !oc.IsAuthed { - // Can't open chat channel if not authorized - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else if !service.IsKnownContact(oc.OtherHostname) { - // Can't open chat channel if not a known contact - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else { - service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat") - } + service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier()) case "im.ricochet.contact.request": - if oc.Client { - // Servers are not allowed to send contact requests - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if !oc.IsAuthed { - // Can't open a contact channel if not authed - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.contact.request") { - // Only 1 contact channel is allowed to be open at a time - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else { - contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) - if err == nil { - contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) - if check { - service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) - break - } + responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) + if err == nil { + response, check := responseI.(*Protocol_Data_ContactRequest.Response) + if check { + service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String()) + break } - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } + service.OnBadUsageError(oc, crm.GetChannelIdentifier()) default: - service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) - } - } else if res.GetChannelResult() != nil { - crm := res.GetChannelResult() - if crm.GetOpened() { - switch oc.GetChannelType(crm.GetChannelIdentifier()) { - case "im.ricochet.auth.hidden-service": - serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) - if err == nil { - serverCookieB := [16]byte{} - copy(serverCookieB[:], serverCookie.([]byte)[:]) - service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB) - } else { - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - } - case "im.ricochet.chat": - service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier()) - case "im.ricochet.contact.request": - responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) - if err == nil { - response, check := responseI.(*Protocol_Data_ContactRequest.Response) - if check { - service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String()) - break - } - } - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - default: - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - } - } else { - if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { - service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String()) - } else { - oc.CloseChannel(crm.GetChannelIdentifier()) - } + service.OnBadUsageError(oc, crm.GetChannelIdentifier()) } } else { - // Unknown Message - oc.CloseChannel(packet.Channel) + if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { + service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String()) + } else { + oc.CloseChannel(crm.GetChannelIdentifier()) + } } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { - res := new(Protocol_Data_AuthHiddenService.Packet) + } else { + // Unknown Message + oc.CloseChannel(packet.Channel) + } + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { + res := new(Protocol_Data_AuthHiddenService.Packet) + err := proto.Unmarshal(packet.Data[:], res) + + if err != nil { + oc.CloseChannel(packet.Channel) + continue + } + + if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs + service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) + } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results + service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + } else { + // If neither of the above are satisfied we just close the connection + oc.Close() + } + + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { + + // NOTE: These auth checks should be redundant, however they + // are included here for defense-in-depth if for some reason + // a previously authed connection becomes untrusted / not known and + // the state is not cleaned up. + if !oc.IsAuthed { + // Can't send chat messages if not authorized + service.OnUnauthorizedError(oc, packet.Channel) + } else if !service.IsKnownContact(oc.OtherHostname) { + // Can't send chat message if not a known contact + service.OnUnauthorizedError(oc, packet.Channel) + } else { + res := new(Protocol_Data_Chat.Packet) err := proto.Unmarshal(packet.Data[:], res) if err != nil { @@ -257,73 +285,42 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService continue } - if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs - service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) - } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results - service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + if res.GetChatMessage() != nil { + service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) + } else if res.GetChatAcknowledge() != nil { + service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId())) } else { // If neither of the above are satisfied we just close the connection oc.Close() } - - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { - - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.IsAuthed { - // Can't send chat messages if not authorized - service.OnUnauthorizedError(oc, packet.Channel) - } else if !service.IsKnownContact(oc.OtherHostname) { - // Can't send chat message if not a known contact - service.OnUnauthorizedError(oc, packet.Channel) - } else { - res := new(Protocol_Data_Chat.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - - if res.GetChatMessage() != nil { - service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) - } else if res.GetChatAcknowledge() != nil { - service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId())) - } else { - // If neither of the above are satisfied we just close the connection - oc.Close() - } - } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { - - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.Client { - // Clients are not allowed to send contact request responses - service.OnBadUsageError(oc, packet.Channel) - } else if !oc.IsAuthed { - // Can't send a contact request if not authed - service.OnBadUsageError(oc, packet.Channel) - } else { - res := new(Protocol_Data_ContactRequest.Response) - err := proto.Unmarshal(packet.Data[:], res) - log.Printf("%v", res) - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String()) - } - } else if oc.GetChannelType(packet.Channel) == "none" { - // Invalid Channel Assignment - oc.CloseChannel(packet.Channel) - } else { - oc.Close() } + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { + + // NOTE: These auth checks should be redundant, however they + // are included here for defense-in-depth if for some reason + // a previously authed connection becomes untrusted / not known and + // the state is not cleaned up. + if !oc.Client { + // Clients are not allowed to send contact request responses + service.OnBadUsageError(oc, packet.Channel) + } else if !oc.IsAuthed { + // Can't send a contact request if not authed + service.OnBadUsageError(oc, packet.Channel) + } else { + res := new(Protocol_Data_ContactRequest.Response) + err := proto.Unmarshal(packet.Data[:], res) + log.Printf("%v", res) + if err != nil { + oc.CloseChannel(packet.Channel) + continue + } + service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String()) + } + } else if oc.GetChannelType(packet.Channel) == "none" { + // Invalid Channel Assignment + oc.CloseChannel(packet.Channel) + } else { + oc.Close() } } } diff --git a/utils/networking.go b/utils/networking.go index 9412558..3b433e6 100644 --- a/utils/networking.go +++ b/utils/networking.go @@ -1,10 +1,10 @@ package utils import ( + "bytes" "encoding/binary" "errors" - "net" - "strconv" + "io" ) // RicochetData is a structure containing the raw data and the channel it the @@ -14,79 +14,67 @@ type RicochetData struct { Data []byte } +func (rd RicochetData) Equals(other RicochetData) bool { + return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data) +} + // RicochetNetworkInterface abstract operations that interact with ricochet's // packet layer. type RicochetNetworkInterface interface { - Recv(conn net.Conn) ([]byte, error) - SendRicochetPacket(conn net.Conn, channel int32, data []byte) - RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) + SendRicochetPacket(dst io.Writer, channel int32, data []byte) error + RecvRicochetPacket(reader io.Reader) (RicochetData, error) } // RicochetNetwork is a concrete implementation of the RicochetNetworkInterface type RicochetNetwork struct { } -// Recv reads data from the client, and returns the raw byte array, else error. -func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) { - buf := make([]byte, 4096) - n, err := conn.Read(buf) - if err != nil { - return nil, err - } - ret := make([]byte, n) - copy(ret[:], buf[:]) - return ret, nil -} - // SendRicochetPacket places the data into a structure needed for the client to // decode the packet and writes the packet to the network. -func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) { - header := make([]byte, 4+len(data)) - header[0] = byte(len(header) >> 8) - header[1] = byte(len(header) & 0x00FF) - header[2] = 0x00 - header[3] = byte(channel) - copy(header[4:], data[:]) - conn.Write(header) +func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error { + packet := make([]byte, 4+len(data)) + if len(packet) > 65535 { + return errors.New("packet too large") + } + binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet))) + if channel < 0 || channel > 65535 { + return errors.New("invalid channel ID") + } + binary.BigEndian.PutUint16(packet[2:4], uint16(channel)) + copy(packet[4:], data[:]) + + for pos := 0; pos < len(packet); { + n, err := dst.Write(packet[pos:]) + if err != nil { + return err + } + pos += n + } + return nil } -// RecvRicochetPackets returns an array of new messages received from the ricochet client -func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) { - buf, err := rn.Recv(conn) - if err != nil && len(buf) < 4 { - return nil, errors.New("failed to retrieve new messages from the client") +// RecvRicochetPacket returns the next packet from reader as a RicochetData +// structure, or an error. +func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) { + packet := RicochetData{} + + // Read the four-byte header to get packet length + header := make([]byte, 4) + if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil { + return packet, err } - pos := 0 - finished := false - var datas []RicochetData - - for !finished { - size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2])) - channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4])) - - if size < 4 { - return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")") - } - - if pos+size > len(buf) { - return datas, errors.New("partial data packet received") - } - - data := RicochetData{} - data.Channel = int32(channel) - - if pos+4 >= len(buf) { - data.Data = make([]byte, 0) - } else { - data.Data = buf[pos+4 : pos+size] - } - - datas = append(datas, data) - pos += size - if pos >= len(buf) { - finished = true - } + size := int(binary.BigEndian.Uint16(header[0:2])) + if size < 4 { + return packet, errors.New("invalid packet length") } - return datas, nil + + packet.Channel = int32(binary.BigEndian.Uint16(header[2:4])) + packet.Data = make([]byte, size-4) + + if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil { + return packet, err + } + + return packet, nil } diff --git a/utils/networking_test.go b/utils/networking_test.go index 38824cf..fed6d9d 100644 --- a/utils/networking_test.go +++ b/utils/networking_test.go @@ -1,171 +1,105 @@ package utils -import "testing" -import "net" -import "time" +import ( + "bytes" + "io" + "testing" + "testing/iotest" +) -type MockConn struct { - Written []byte - MockOutput []byte +// Valid packets and their encoded forms +var packetTests = []struct { + packet RicochetData + encoded []byte +}{ + {RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}}, + {RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}}, + {RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)}, } -func (mc *MockConn) Read(b []byte) (int, error) { - copy(b[:], mc.MockOutput[:]) - return len(mc.MockOutput), nil -} - -func (mc *MockConn) Write(written []byte) (int, error) { - mc.Written = written - return 0, nil -} - -func (mc *MockConn) LocalAddr() net.Addr { - return nil -} - -func (mc *MockConn) RemoteAddr() net.Addr { - return nil -} - -func (mc *MockConn) Close() error { - return nil -} - -func (mc *MockConn) SetDeadline(t time.Time) error { - return nil -} - -func (mc *MockConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (mc *MockConn) SetWriteDeadline(t time.Time) error { - return nil -} - -func TestSentRicochetPacket(t *testing.T) { - conn := new(MockConn) +// Test sending valid packets +func TestSendRicochetPacket(t *testing.T) { rni := RicochetNetwork{} - rni.SendRicochetPacket(conn, 1, []byte{}) - if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 { - t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written) + for _, td := range packetTests { + var buf bytes.Buffer + err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data) + if err != nil { + t.Errorf("Error sending packet %v: %v", td.packet, err) + } else if !bytes.Equal(buf.Bytes(), td.encoded) { + t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes()) + } } } -func TestRecv(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF} +// Test sending invalid packets +func TestSendRicochetPacket_Invalid(t *testing.T) { rni := RicochetNetwork{} - buf, err := rni.Recv(conn) - if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF { - t.Errorf("Output of Recv was Unexpected: %x", buf) + invalidPackets := []RicochetData{ + RicochetData{-1, []byte{}}, + RicochetData{65536, []byte{}}, + RicochetData{0, make([]byte, 65532)}, + } + + for _, td := range invalidPackets { + var buf bytes.Buffer + err := rni.SendRicochetPacket(&buf, td.Channel, td.Data) + // Expect error + if err == nil { + t.Errorf("Expected error when sending invalid packet %v", td) + } } } +// Test receiving valid packets func TestRecvRicochetPacket(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{00, 0x04, 0x00, 0x01} + var buf bytes.Buffer + for _, td := range packetTests { + if _, err := buf.Write(td.encoded); err != nil { + t.Error(err) + return + } + } + // Use a HalfReader to test behavior on short socket reads also + reader := iotest.HalfReader(&buf) rni := RicochetNetwork{} - rp, err := rni.RecvRicochetPackets(conn) - if err != nil { - t.Errorf("error extracting ricochet packets: %v", err) - return - } - - if len(rp) != 1 { - t.Errorf("unexpected number of ricochet packets: %d", len(rp)) - } else { - if rp[0].Channel != 1 { - t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel) - } - - if len(rp[0].Data) != 0 { - t.Errorf("expected emptry packet, instead got %x", rp[0].Data) + for _, td := range packetTests { + packet, err := rni.RecvRicochetPacket(reader) + if err != nil { + t.Errorf("Error receiving packet %v: %v", td.packet, err) + return + } else if !packet.Equals(td.packet) { + t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet) } } + if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF { + if err != nil { + t.Errorf("Expected EOF on packet stream but received error: %v", err) + } else { + t.Errorf("Expected EOF but received packet: %v", packet) + } + } } -func TestRecvRicochetPacketInvalid(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{00, 0x01, 0x00, 0x01} - +// Test receiving invalid packets +func TestRecvRicochetPacket_Invalid(t *testing.T) { rni := RicochetNetwork{} - _, err := rni.RecvRicochetPackets(conn) - - if err == nil { - t.Errorf("recv should have errored due to invalid packets %v", err) + invalidPackets := [][]byte{ + []byte{0x00, 0x00, 0x00, 0x00}, + []byte{0x00, 0x03, 0x00, 0x00}, + []byte{0xff}, + []byte{0x00, 0x06, 0x00, 0x00, 0x00}, + []byte{}, } - conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01} - - _, err = rni.RecvRicochetPackets(conn) - - if err == nil { - t.Errorf("recv should have errored due to invalid packets %v", err) + for _, td := range invalidPackets { + buf := bytes.NewBuffer(td) + packet, err := rni.RecvRicochetPacket(buf) + // Expect error + if err == nil { + t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet) + } } - -} - -func TestRecvRicochetPacketLong(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF} - - rni := RicochetNetwork{} - rp, err := rni.RecvRicochetPackets(conn) - - if err != nil { - t.Errorf("error extracting ricochet packets: %v", err) - return - } - - if len(rp) != 1 { - t.Errorf("unexpected number of ricochet packets: %d", len(rp)) - } else { - if rp[0].Channel != 255 { - t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel) - } - - if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF { - t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data) - } - } - -} - -func TestRecvRicochetPacketMultiplex(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF} - - rni := RicochetNetwork{} - rp, err := rni.RecvRicochetPackets(conn) - - if err != nil { - t.Errorf("error extracting ricochet packets: %v", err) - return - } - - if len(rp) != 2 { - t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp)) - } else { - - if rp[0].Channel != 1 { - t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel) - } - - if len(rp[0].Data) != 0 { - t.Errorf("expected empty packet, instead got %x", rp[0].Data) - } - - if rp[1].Channel != 255 { - t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel) - } - - if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF { - t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data) - } - } - }