diff --git a/ricochet.go b/ricochet.go index 6bf642c..b2036ab 100644 --- a/ricochet.go +++ b/ricochet.go @@ -109,147 +109,175 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService return } - packets, err := r.rni.RecvRicochetPackets(oc.conn) - + packet, err := r.rni.RecvRicochetPacket(oc.conn) if err != nil { + oc.Close() return } - for _, packet := range packets { + if len(packet.Data) == 0 { + service.OnChannelClosed(oc, packet.Channel) + continue + } - if len(packet.Data) == 0 { - service.OnChannelClosed(oc, packet.Channel) + if packet.Channel == 0 { + + res := new(Protocol_Data_Control.Packet) + err := proto.Unmarshal(packet.Data[:], res) + + if err != nil { + service.OnGenericError(oc, packet.Channel) continue } - if packet.Channel == 0 { + if res.GetOpenChannel() != nil { + opm := res.GetOpenChannel() - res := new(Protocol_Data_Control.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - service.OnGenericError(oc, packet.Channel) + if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { + // Channel is already in use. + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) continue } - if res.GetOpenChannel() != nil { - opm := res.GetOpenChannel() + // If I am a Client, the server can only open even numbered channels + if oc.Client && opm.GetChannelIdentifier()%2 != 0 { + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + continue + } - if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { - // Channel is already in use. + // If I am a Server, the client can only open odd numbered channels + if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + continue + } + + switch opm.GetChannelType() { + case "im.ricochet.auth.hidden-service": + if oc.Client { + // Servers are authed by default and can't auth with hidden-service service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - // If I am a Client, the server can only open even numbered channels - if oc.Client && opm.GetChannelIdentifier()%2 != 0 { + } else if oc.IsAuthed { + // Can't auth if already authed service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - // If I am a Server, the client can only open odd numbered channels - if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { + } else if oc.HasChannel("im.ricochet.auth.hidden-service") { + // Can't open more than 1 auth channel service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - switch opm.GetChannelType() { - case "im.ricochet.auth.hidden-service": - if oc.Client { - // Servers are authed by default and can't auth with hidden-service - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if oc.IsAuthed { - // Can't auth if already authed - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.auth.hidden-service") { - // Can't open more than 1 auth channel - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } else { + clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) + if err == nil { + clientCookieB := [16]byte{} + copy(clientCookieB[:], clientCookie.([]byte)[:]) + service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB) } else { - clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) - if err == nil { - clientCookieB := [16]byte{} - copy(clientCookieB[:], clientCookie.([]byte)[:]) - service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB) - } else { - // Must include Client Cookie - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + // Must include Client Cookie + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } + } + case "im.ricochet.chat": + if !oc.IsAuthed { + // Can't open chat channel if not authorized + service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) + } else if !service.IsKnownContact(oc.OtherHostname) { + // Can't open chat channel if not a known contact + service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) + } else { + service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat") + } + case "im.ricochet.contact.request": + if oc.Client { + // Servers are not allowed to send contact requests + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } else if !oc.IsAuthed { + // Can't open a contact channel if not authed + service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) + } else if oc.HasChannel("im.ricochet.contact.request") { + // Only 1 contact channel is allowed to be open at a time + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } else { + contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) + if err == nil { + contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) + if check { + service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) + break } } + service.OnBadUsageError(oc, opm.GetChannelIdentifier()) + } + default: + service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) + } + } else if res.GetChannelResult() != nil { + crm := res.GetChannelResult() + if crm.GetOpened() { + switch oc.GetChannelType(crm.GetChannelIdentifier()) { + case "im.ricochet.auth.hidden-service": + serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) + if err == nil { + serverCookieB := [16]byte{} + copy(serverCookieB[:], serverCookie.([]byte)[:]) + service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB) + } else { + service.OnBadUsageError(oc, crm.GetChannelIdentifier()) + } case "im.ricochet.chat": - if !oc.IsAuthed { - // Can't open chat channel if not authorized - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else if !service.IsKnownContact(oc.OtherHostname) { - // Can't open chat channel if not a known contact - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else { - service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat") - } + service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier()) case "im.ricochet.contact.request": - if oc.Client { - // Servers are not allowed to send contact requests - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if !oc.IsAuthed { - // Can't open a contact channel if not authed - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.contact.request") { - // Only 1 contact channel is allowed to be open at a time - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else { - contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) - if err == nil { - contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) - if check { - service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) - break - } + responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) + if err == nil { + response, check := responseI.(*Protocol_Data_ContactRequest.Response) + if check { + service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String()) + break } - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) } + service.OnBadUsageError(oc, crm.GetChannelIdentifier()) default: - service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) - } - } else if res.GetChannelResult() != nil { - crm := res.GetChannelResult() - if crm.GetOpened() { - switch oc.GetChannelType(crm.GetChannelIdentifier()) { - case "im.ricochet.auth.hidden-service": - serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) - if err == nil { - serverCookieB := [16]byte{} - copy(serverCookieB[:], serverCookie.([]byte)[:]) - service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB) - } else { - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - } - case "im.ricochet.chat": - service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier()) - case "im.ricochet.contact.request": - responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) - if err == nil { - response, check := responseI.(*Protocol_Data_ContactRequest.Response) - if check { - service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String()) - break - } - } - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - default: - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - } - } else { - if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { - service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String()) - } else { - oc.CloseChannel(crm.GetChannelIdentifier()) - } + service.OnBadUsageError(oc, crm.GetChannelIdentifier()) } } else { - // Unknown Message - oc.CloseChannel(packet.Channel) + if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { + service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String()) + } else { + oc.CloseChannel(crm.GetChannelIdentifier()) + } } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { - res := new(Protocol_Data_AuthHiddenService.Packet) + } else { + // Unknown Message + oc.CloseChannel(packet.Channel) + } + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { + res := new(Protocol_Data_AuthHiddenService.Packet) + err := proto.Unmarshal(packet.Data[:], res) + + if err != nil { + oc.CloseChannel(packet.Channel) + continue + } + + if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs + service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) + } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results + service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + } else { + // If neither of the above are satisfied we just close the connection + oc.Close() + } + + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { + + // NOTE: These auth checks should be redundant, however they + // are included here for defense-in-depth if for some reason + // a previously authed connection becomes untrusted / not known and + // the state is not cleaned up. + if !oc.IsAuthed { + // Can't send chat messages if not authorized + service.OnUnauthorizedError(oc, packet.Channel) + } else if !service.IsKnownContact(oc.OtherHostname) { + // Can't send chat message if not a known contact + service.OnUnauthorizedError(oc, packet.Channel) + } else { + res := new(Protocol_Data_Chat.Packet) err := proto.Unmarshal(packet.Data[:], res) if err != nil { @@ -257,73 +285,42 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService continue } - if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs - service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) - } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results - service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + if res.GetChatMessage() != nil { + service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) + } else if res.GetChatAcknowledge() != nil { + service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId())) } else { // If neither of the above are satisfied we just close the connection oc.Close() } - - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { - - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.IsAuthed { - // Can't send chat messages if not authorized - service.OnUnauthorizedError(oc, packet.Channel) - } else if !service.IsKnownContact(oc.OtherHostname) { - // Can't send chat message if not a known contact - service.OnUnauthorizedError(oc, packet.Channel) - } else { - res := new(Protocol_Data_Chat.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - - if res.GetChatMessage() != nil { - service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) - } else if res.GetChatAcknowledge() != nil { - service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId())) - } else { - // If neither of the above are satisfied we just close the connection - oc.Close() - } - } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { - - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.Client { - // Clients are not allowed to send contact request responses - service.OnBadUsageError(oc, packet.Channel) - } else if !oc.IsAuthed { - // Can't send a contact request if not authed - service.OnBadUsageError(oc, packet.Channel) - } else { - res := new(Protocol_Data_ContactRequest.Response) - err := proto.Unmarshal(packet.Data[:], res) - log.Printf("%v", res) - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String()) - } - } else if oc.GetChannelType(packet.Channel) == "none" { - // Invalid Channel Assignment - oc.CloseChannel(packet.Channel) - } else { - oc.Close() } + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { + + // NOTE: These auth checks should be redundant, however they + // are included here for defense-in-depth if for some reason + // a previously authed connection becomes untrusted / not known and + // the state is not cleaned up. + if !oc.Client { + // Clients are not allowed to send contact request responses + service.OnBadUsageError(oc, packet.Channel) + } else if !oc.IsAuthed { + // Can't send a contact request if not authed + service.OnBadUsageError(oc, packet.Channel) + } else { + res := new(Protocol_Data_ContactRequest.Response) + err := proto.Unmarshal(packet.Data[:], res) + log.Printf("%v", res) + if err != nil { + oc.CloseChannel(packet.Channel) + continue + } + service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String()) + } + } else if oc.GetChannelType(packet.Channel) == "none" { + // Invalid Channel Assignment + oc.CloseChannel(packet.Channel) + } else { + oc.Close() } } } diff --git a/utils/networking.go b/utils/networking.go index 9412558..3b433e6 100644 --- a/utils/networking.go +++ b/utils/networking.go @@ -1,10 +1,10 @@ package utils import ( + "bytes" "encoding/binary" "errors" - "net" - "strconv" + "io" ) // RicochetData is a structure containing the raw data and the channel it the @@ -14,79 +14,67 @@ type RicochetData struct { Data []byte } +func (rd RicochetData) Equals(other RicochetData) bool { + return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data) +} + // RicochetNetworkInterface abstract operations that interact with ricochet's // packet layer. type RicochetNetworkInterface interface { - Recv(conn net.Conn) ([]byte, error) - SendRicochetPacket(conn net.Conn, channel int32, data []byte) - RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) + SendRicochetPacket(dst io.Writer, channel int32, data []byte) error + RecvRicochetPacket(reader io.Reader) (RicochetData, error) } // RicochetNetwork is a concrete implementation of the RicochetNetworkInterface type RicochetNetwork struct { } -// Recv reads data from the client, and returns the raw byte array, else error. -func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) { - buf := make([]byte, 4096) - n, err := conn.Read(buf) - if err != nil { - return nil, err - } - ret := make([]byte, n) - copy(ret[:], buf[:]) - return ret, nil -} - // SendRicochetPacket places the data into a structure needed for the client to // decode the packet and writes the packet to the network. -func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) { - header := make([]byte, 4+len(data)) - header[0] = byte(len(header) >> 8) - header[1] = byte(len(header) & 0x00FF) - header[2] = 0x00 - header[3] = byte(channel) - copy(header[4:], data[:]) - conn.Write(header) +func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error { + packet := make([]byte, 4+len(data)) + if len(packet) > 65535 { + return errors.New("packet too large") + } + binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet))) + if channel < 0 || channel > 65535 { + return errors.New("invalid channel ID") + } + binary.BigEndian.PutUint16(packet[2:4], uint16(channel)) + copy(packet[4:], data[:]) + + for pos := 0; pos < len(packet); { + n, err := dst.Write(packet[pos:]) + if err != nil { + return err + } + pos += n + } + return nil } -// RecvRicochetPackets returns an array of new messages received from the ricochet client -func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) { - buf, err := rn.Recv(conn) - if err != nil && len(buf) < 4 { - return nil, errors.New("failed to retrieve new messages from the client") +// RecvRicochetPacket returns the next packet from reader as a RicochetData +// structure, or an error. +func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) { + packet := RicochetData{} + + // Read the four-byte header to get packet length + header := make([]byte, 4) + if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil { + return packet, err } - pos := 0 - finished := false - var datas []RicochetData - - for !finished { - size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2])) - channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4])) - - if size < 4 { - return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")") - } - - if pos+size > len(buf) { - return datas, errors.New("partial data packet received") - } - - data := RicochetData{} - data.Channel = int32(channel) - - if pos+4 >= len(buf) { - data.Data = make([]byte, 0) - } else { - data.Data = buf[pos+4 : pos+size] - } - - datas = append(datas, data) - pos += size - if pos >= len(buf) { - finished = true - } + size := int(binary.BigEndian.Uint16(header[0:2])) + if size < 4 { + return packet, errors.New("invalid packet length") } - return datas, nil + + packet.Channel = int32(binary.BigEndian.Uint16(header[2:4])) + packet.Data = make([]byte, size-4) + + if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil { + return packet, err + } + + return packet, nil } diff --git a/utils/networking_test.go b/utils/networking_test.go index 38824cf..fed6d9d 100644 --- a/utils/networking_test.go +++ b/utils/networking_test.go @@ -1,171 +1,105 @@ package utils -import "testing" -import "net" -import "time" +import ( + "bytes" + "io" + "testing" + "testing/iotest" +) -type MockConn struct { - Written []byte - MockOutput []byte +// Valid packets and their encoded forms +var packetTests = []struct { + packet RicochetData + encoded []byte +}{ + {RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}}, + {RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}}, + {RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)}, } -func (mc *MockConn) Read(b []byte) (int, error) { - copy(b[:], mc.MockOutput[:]) - return len(mc.MockOutput), nil -} - -func (mc *MockConn) Write(written []byte) (int, error) { - mc.Written = written - return 0, nil -} - -func (mc *MockConn) LocalAddr() net.Addr { - return nil -} - -func (mc *MockConn) RemoteAddr() net.Addr { - return nil -} - -func (mc *MockConn) Close() error { - return nil -} - -func (mc *MockConn) SetDeadline(t time.Time) error { - return nil -} - -func (mc *MockConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (mc *MockConn) SetWriteDeadline(t time.Time) error { - return nil -} - -func TestSentRicochetPacket(t *testing.T) { - conn := new(MockConn) +// Test sending valid packets +func TestSendRicochetPacket(t *testing.T) { rni := RicochetNetwork{} - rni.SendRicochetPacket(conn, 1, []byte{}) - if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 { - t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written) + for _, td := range packetTests { + var buf bytes.Buffer + err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data) + if err != nil { + t.Errorf("Error sending packet %v: %v", td.packet, err) + } else if !bytes.Equal(buf.Bytes(), td.encoded) { + t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes()) + } } } -func TestRecv(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF} +// Test sending invalid packets +func TestSendRicochetPacket_Invalid(t *testing.T) { rni := RicochetNetwork{} - buf, err := rni.Recv(conn) - if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF { - t.Errorf("Output of Recv was Unexpected: %x", buf) + invalidPackets := []RicochetData{ + RicochetData{-1, []byte{}}, + RicochetData{65536, []byte{}}, + RicochetData{0, make([]byte, 65532)}, + } + + for _, td := range invalidPackets { + var buf bytes.Buffer + err := rni.SendRicochetPacket(&buf, td.Channel, td.Data) + // Expect error + if err == nil { + t.Errorf("Expected error when sending invalid packet %v", td) + } } } +// Test receiving valid packets func TestRecvRicochetPacket(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{00, 0x04, 0x00, 0x01} + var buf bytes.Buffer + for _, td := range packetTests { + if _, err := buf.Write(td.encoded); err != nil { + t.Error(err) + return + } + } + // Use a HalfReader to test behavior on short socket reads also + reader := iotest.HalfReader(&buf) rni := RicochetNetwork{} - rp, err := rni.RecvRicochetPackets(conn) - if err != nil { - t.Errorf("error extracting ricochet packets: %v", err) - return - } - - if len(rp) != 1 { - t.Errorf("unexpected number of ricochet packets: %d", len(rp)) - } else { - if rp[0].Channel != 1 { - t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel) - } - - if len(rp[0].Data) != 0 { - t.Errorf("expected emptry packet, instead got %x", rp[0].Data) + for _, td := range packetTests { + packet, err := rni.RecvRicochetPacket(reader) + if err != nil { + t.Errorf("Error receiving packet %v: %v", td.packet, err) + return + } else if !packet.Equals(td.packet) { + t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet) } } + if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF { + if err != nil { + t.Errorf("Expected EOF on packet stream but received error: %v", err) + } else { + t.Errorf("Expected EOF but received packet: %v", packet) + } + } } -func TestRecvRicochetPacketInvalid(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{00, 0x01, 0x00, 0x01} - +// Test receiving invalid packets +func TestRecvRicochetPacket_Invalid(t *testing.T) { rni := RicochetNetwork{} - _, err := rni.RecvRicochetPackets(conn) - - if err == nil { - t.Errorf("recv should have errored due to invalid packets %v", err) + invalidPackets := [][]byte{ + []byte{0x00, 0x00, 0x00, 0x00}, + []byte{0x00, 0x03, 0x00, 0x00}, + []byte{0xff}, + []byte{0x00, 0x06, 0x00, 0x00, 0x00}, + []byte{}, } - conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01} - - _, err = rni.RecvRicochetPackets(conn) - - if err == nil { - t.Errorf("recv should have errored due to invalid packets %v", err) + for _, td := range invalidPackets { + buf := bytes.NewBuffer(td) + packet, err := rni.RecvRicochetPacket(buf) + // Expect error + if err == nil { + t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet) + } } - -} - -func TestRecvRicochetPacketLong(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF} - - rni := RicochetNetwork{} - rp, err := rni.RecvRicochetPackets(conn) - - if err != nil { - t.Errorf("error extracting ricochet packets: %v", err) - return - } - - if len(rp) != 1 { - t.Errorf("unexpected number of ricochet packets: %d", len(rp)) - } else { - if rp[0].Channel != 255 { - t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel) - } - - if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF { - t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data) - } - } - -} - -func TestRecvRicochetPacketMultiplex(t *testing.T) { - conn := new(MockConn) - conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF} - - rni := RicochetNetwork{} - rp, err := rni.RecvRicochetPackets(conn) - - if err != nil { - t.Errorf("error extracting ricochet packets: %v", err) - return - } - - if len(rp) != 2 { - t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp)) - } else { - - if rp[0].Channel != 1 { - t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel) - } - - if len(rp[0].Data) != 0 { - t.Errorf("expected empty packet, instead got %x", rp[0].Data) - } - - if rp[1].Channel != 255 { - t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel) - } - - if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF { - t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data) - } - } - }