Fleshing out error cases

This commit is contained in:
Sarah Jamie Lewis 2016-01-02 00:47:32 -08:00
parent 3469935bcb
commit f684fd8694
2 changed files with 42 additions and 37 deletions

View File

@ -12,7 +12,7 @@ func main() {
ricochet.Init("./private_key", true) ricochet.Init("./private_key", true)
ricochet.Connect("kwke2hntvyfqm7dr", "127.0.0.1:55555|jlq67qzo6s4yp3sp") ricochet.Connect("kwke2hntvyfqm7dr", "127.0.0.1:55555|jlq67qzo6s4yp3sp")
// Not needed passed the initial run // Not needed past the initial run
// TODO need to wait for contact response before sending OpenChannel // TODO need to wait for contact response before sending OpenChannel
// ricochet.SendContactRequest("EchoBot", "I'm an EchoBot") // ricochet.SendContactRequest("EchoBot", "I'm an EchoBot")

View File

@ -144,9 +144,8 @@ func (r *Ricochet) Connect(from string, to string) error {
r.logger.Fatal("Cannot Marshal Open Channel Message: ", err) r.logger.Fatal("Cannot Marshal Open Channel Message: ", err)
} }
openChannel := r.constructProtocol(data, 0) r.sendPacket(data, 0)
r.logger.Print("Opening Channel: ", pc) r.logger.Print("Opening Channel: ", pc)
r.send(openChannel)
response, _ := r.getMessages() response, _ := r.getMessages()
openChannelResponse, _ := r.decodePacket(response[0], CONTROL) openChannelResponse, _ := r.decodePacket(response[0], CONTROL)
@ -203,11 +202,8 @@ func (r *Ricochet) Connect(from string, to string) error {
} }
data, err = proto.Marshal(ahsPacket) data, err = proto.Marshal(ahsPacket)
r.sendPacket(data, 1)
sendProof := r.constructProtocol(data, 1)
r.logger.Print("Constructed Proof: ", ahsPacket)
r.send(sendProof)
response, _ = r.getMessages() response, _ = r.getMessages()
resultResponse, _ := r.decodePacket(response[0], AUTH) resultResponse, _ := r.decodePacket(response[0], AUTH)
r.logger.Print("Received Result: ", resultResponse) r.logger.Print("Received Result: ", resultResponse)
@ -230,9 +226,7 @@ func (r *Ricochet) OpenChannel(channelType string, id int) error {
} }
data, _ := proto.Marshal(pc) data, _ := proto.Marshal(pc)
openChannel := r.constructProtocol(data, 0) r.sendPacket(data, 0)
r.logger.Print("Opening Channel: ", pc)
r.send(openChannel)
return nil return nil
} }
@ -261,9 +255,8 @@ func (r *Ricochet) SendContactRequest(nick string, message string) {
r.logger.Fatal("Cannot Marshal Open Channel Message: ", err) r.logger.Fatal("Cannot Marshal Open Channel Message: ", err)
} }
openChannel := r.constructProtocol(data, 0) r.sendPacket(data, 0)
r.logger.Print("Opening Channel: ", pc)
r.send(openChannel)
} }
// SendMessage sends a Chat Message (message) to a give Channel (channel). // SendMessage sends a Chat Message (message) to a give Channel (channel).
@ -280,9 +273,8 @@ func (r *Ricochet) SendMessage(message string, channel int) {
} }
data, _ := proto.Marshal(chatPacket) data, _ := proto.Marshal(chatPacket)
chatMessageBytes := r.constructProtocol(data, channel) r.sendPacket(data, channel)
r.logger.Print("Sending Message: ", chatPacket)
r.send(chatMessageBytes)
} }
// negotiateVersion Perform version negotiation with the connected host. // negotiateVersion Perform version negotiation with the connected host.
@ -292,7 +284,7 @@ func (r *Ricochet) negotiateVersion() {
version[1] = 0x4D version[1] = 0x4D
version[2] = 0x01 version[2] = 0x01
version[3] = 0x01 version[3] = 0x01
r.send(version) fmt.Fprintf(r.conn, "%s", version)
r.logger.Print("Negotiating Version ", version) r.logger.Print("Negotiating Version ", version)
res, err := r.recv() res, err := r.recv()
@ -307,9 +299,9 @@ func (r *Ricochet) negotiateVersion() {
r.logger.Print("Successfully Negotiated Version ", res[0]) r.logger.Print("Successfully Negotiated Version ", res[0])
} }
// constructProtocol places the data into a structure needed for the client to // sendPacket places the data into a structure needed for the client to
// decode the packet. // decode the packet and writes the packet to the network.
func (r *Ricochet) constructProtocol(data []byte, channel int) []byte { func (r *Ricochet) sendPacket(data []byte, channel int) {
header := make([]byte, 4+len(data)) header := make([]byte, 4+len(data))
r.logger.Print("Wrting Packet of Size: ", len(header)) r.logger.Print("Wrting Packet of Size: ", len(header))
header[0] = byte(len(header) >> 8) header[0] = byte(len(header) >> 8)
@ -317,13 +309,9 @@ func (r *Ricochet) constructProtocol(data []byte, channel int) []byte {
header[2] = 0x00 header[2] = 0x00
header[3] = byte(channel) header[3] = byte(channel)
copy(header[4:], data[:]) copy(header[4:], data[:])
return header fmt.Fprintf(r.conn, "%s", header)
} }
// send is a utility funtion to send data to the connected client.
func (r *Ricochet) send(data []byte) {
fmt.Fprintf(r.conn, "%s", data)
}
// Listen blocks and waits for a new message to arrive from the connected user // Listen blocks and waits for a new message to arrive from the connected user
// once a message has arrived, it returns the message and the channel it occured // once a message has arrived, it returns the message and the channel it occured
@ -349,9 +337,12 @@ func (r *Ricochet) Listen() (string, int, error) {
ChatAcknowledge: cr, ChatAcknowledge: cr,
} }
data, _ := proto.Marshal(pc) data,err := proto.Marshal(pc)
ack := r.constructProtocol(data, message.Channel) if err != nil {
r.send(ack) return "",0,errors.New("Failed to serialize chat message")
}
r.sendPacket(data, message.Channel)
return message.DataPacket.GetChatMessage().GetMessageText(), message.Channel, nil return message.DataPacket.GetChatMessage().GetMessageText(), message.Channel, nil
} }
@ -370,7 +361,12 @@ func (r *Ricochet) ListenAndWait() error {
for _, packet := range packets { for _, packet := range packets {
if packet.Channel == 0 { if packet.Channel == 0 {
// This is a Control Channel Message // This is a Control Channel Message
message, _ := r.decodePacket(packet, CONTROL) message,err := r.decodePacket(packet, CONTROL)
if err != nil {
r.logger.Printf("Failed to decode control packet, discarding")
break;
}
// Automatically accept new channels // Automatically accept new channels
if message.ControlPacket.GetOpenChannel() != nil { if message.ControlPacket.GetOpenChannel() != nil {
@ -384,10 +380,14 @@ func (r *Ricochet) ListenAndWait() error {
ChannelResult: cr, ChannelResult: cr,
} }
data, _ := proto.Marshal(pc) data,err := proto.Marshal(pc)
openChannel := r.constructProtocol(data, 0)
r.logger.Print("Opening Channel: ", pc) if err != nil {
r.send(openChannel) r.logger.Printf("Failed to marshal control packet. Something went very wrong.")
}
r.sendPacket(data, 0)
r.channelState[int(message.ControlPacket.GetOpenChannel().GetChannelIdentifier())] = 1 r.channelState[int(message.ControlPacket.GetOpenChannel().GetChannelIdentifier())] = 1
break break
} }
@ -406,7 +406,13 @@ func (r *Ricochet) ListenAndWait() error {
} else { } else {
// At this point the only other expected type of message // At this point the only other expected type of message
// is a Chat Message // is a Chat Message
message, _ := r.decodePacket(packet, DATA) message,err := r.decodePacket(packet, DATA)
if err != nil {
r.logger.Printf("Failed to decode data packet, discarding")
break;
}
r.logger.Print("Receieved Data Packet: ", message) r.logger.Print("Receieved Data Packet: ", message)
r.channel <- message r.channel <- message
} }
@ -454,7 +460,6 @@ func (r *Ricochet) getMessages() ([]RicochetData, error) {
for !finished { for !finished {
size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2])) size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2]))
channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4])) channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4]))
r.logger.Println(buf[pos+2 : pos+4])
if pos+size > len(buf) { if pos+size > len(buf) {
return datas, errors.New("Partial data packet received") return datas, errors.New("Partial data packet received")
@ -464,7 +469,7 @@ func (r *Ricochet) getMessages() ([]RicochetData, error) {
Channel: int(channel), Channel: int(channel),
Data: buf[pos+4 : pos+size], Data: buf[pos+4 : pos+size],
} }
r.logger.Println("Got new Data:", data)
datas = append(datas, data) datas = append(datas, data)
pos += size pos += size
if pos >= len(buf) { if pos >= len(buf) {