diff --git a/examples/echobot/main.go b/examples/echobot/main.go index 604af11..e1f1bf9 100644 --- a/examples/echobot/main.go +++ b/examples/echobot/main.go @@ -11,27 +11,36 @@ type EchoBotService struct { goricochet.StandardRicochetService } +func (ebs *EchoBotService) OnNewConnection(oc *goricochet.OpenConnection) { + ebs.StandardRicochetService.OnNewConnection(oc) + go oc.Process(&EchoBotConnection{}) +} + +type EchoBotConnection struct { + goricochet.StandardRicochetConnection +} + // IsKnownContact is configured to always accept Contact Requests -func (ebs *EchoBotService) IsKnownContact(hostname string) bool { +func (ebc *EchoBotConnection) IsKnownContact(hostname string) bool { return true } // OnContactRequest - we always accept new contact request. -func (ebs *EchoBotService) OnContactRequest(oc *goricochet.OpenConnection, channelID int32, nick string, message string) { - ts.StandardRicochetService.OnContactRequest(oc, channelID, nick, message) - oc.AckContactRequestOnResponse(channelID, "Accepted") - oc.CloseChannel(channelID) +func (ebc *EchoBotConnection) OnContactRequest(channelID int32, nick string, message string) { + ebc.StandardRicochetConnection.OnContactRequest(channelID, nick, message) + ebc.Conn.AckContactRequestOnResponse(channelID, "Accepted") + ebc.Conn.CloseChannel(channelID) } // OnChatMessage we acknowledge the message, grab the message content and send it back - opening // a new channel if necessary. -func (ebs *EchoBotService) OnChatMessage(oc *goricochet.OpenConnection, channelID int32, messageID int32, message string) { - log.Printf("Received Message from %s: %s", oc.OtherHostname, message) - oc.AckChatMessage(channelID, messageID) - if oc.GetChannelType(6) == "none" { - oc.OpenChatChannel(6) +func (ebc *EchoBotConnection) OnChatMessage(channelID int32, messageID int32, message string) { + log.Printf("Received Message from %s: %s", ebc.Conn.OtherHostname, message) + ebc.Conn.AckChatMessage(channelID, messageID) + if ebc.Conn.GetChannelType(6) == "none" { + ebc.Conn.OpenChatChannel(6) } - oc.SendMessage(6, message) + ebc.Conn.SendMessage(6, message) } func main() { diff --git a/handlers.go b/handlers.go new file mode 100644 index 0000000..083c45f --- /dev/null +++ b/handlers.go @@ -0,0 +1,51 @@ +package goricochet + +// ServiceHandler is the interface to handle events for an inbound connection listener +type ServiceHandler interface { + // OnNewConnection is called for inbound connections to the service after protocol + // version negotiation has completed successfully. + OnNewConnection(oc *OpenConnection) + // OnFailedConnection is called for inbound connections to the service which fail + // to successfully complete version negotiation for any reason. + OnFailedConnection(err error) +} + +// ConnectionHandler is the interface to handle events for an open protocol connection, +// whether inbound or outbound. Each OpenConnection will need its own instance of an +// application type implementing ConnectionHandler, which could also be used to store +// application state related to the connection. +type ConnectionHandler interface { + // OnReady is called before OpenConnection.Process() begins from the connection + OnReady(oc *OpenConnection) + // OnDisconnect is called when the connection is closed, just before + // OpenConnection.Process() returns + OnDisconnect() + + // Authentication Management + OnAuthenticationRequest(channelID int32, clientCookie [16]byte) + OnAuthenticationChallenge(channelID int32, serverCookie [16]byte) + OnAuthenticationProof(channelID int32, publicKey []byte, signature []byte) + OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) + + // Contact Management + IsKnownContact(hostname string) bool + OnContactRequest(channelID int32, nick string, message string) + OnContactRequestAck(channelID int32, status string) + + // Managing Channels + OnOpenChannelRequest(channelID int32, channelType string) + OnOpenChannelRequestSuccess(channelID int32) + OnChannelClosed(channelID int32) + + // Chat Messages + OnChatMessage(channelID int32, messageID int32, message string) + OnChatMessageAck(channelID int32, messageID int32) + + // Handle Errors + OnFailedChannelOpen(channelID int32, errorType string) + OnGenericError(channelID int32) + OnUnknownTypeError(channelID int32) + OnUnauthorizedError(channelID int32) + OnBadUsageError(channelID int32) + OnFailedError(channelID int32) +} diff --git a/openconnection.go b/openconnection.go index af33b3e..bf6465f 100644 --- a/openconnection.go +++ b/openconnection.go @@ -4,7 +4,13 @@ import ( "crypto" "crypto/rsa" "encoding/asn1" + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/auth" + "github.com/s-rah/go-ricochet/chat" + "github.com/s-rah/go-ricochet/contact" + "github.com/s-rah/go-ricochet/control" "github.com/s-rah/go-ricochet/utils" + "log" "net" ) @@ -298,3 +304,231 @@ func (oc *OpenConnection) SendMessage(channel int32, message string) { utils.CheckError(err) oc.rni.SendRicochetPacket(oc.conn, channel, data) } + +// Process waits for new messages to arrive from the connection and uses the given +// ConnectionHandler to process them. +func (oc *OpenConnection) Process(handler ConnectionHandler) { + handler.OnReady(oc) + defer oc.Close() + defer handler.OnDisconnect() + + for { + if oc.Closed { + return + } + + packet, err := oc.rni.RecvRicochetPacket(oc.conn) + if err != nil { + oc.Close() + return + } + + if len(packet.Data) == 0 { + handler.OnChannelClosed(packet.Channel) + continue + } + + if packet.Channel == 0 { + + res := new(Protocol_Data_Control.Packet) + err := proto.Unmarshal(packet.Data[:], res) + + if err != nil { + handler.OnGenericError(packet.Channel) + continue + } + + if res.GetOpenChannel() != nil { + opm := res.GetOpenChannel() + + if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { + // Channel is already in use. + handler.OnBadUsageError(opm.GetChannelIdentifier()) + continue + } + + // If I am a Client, the server can only open even numbered channels + if oc.Client && opm.GetChannelIdentifier()%2 != 0 { + handler.OnBadUsageError(opm.GetChannelIdentifier()) + continue + } + + // If I am a Server, the client can only open odd numbered channels + if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { + handler.OnBadUsageError(opm.GetChannelIdentifier()) + continue + } + + switch opm.GetChannelType() { + case "im.ricochet.auth.hidden-service": + if oc.Client { + // Servers are authed by default and can't auth with hidden-service + handler.OnBadUsageError(opm.GetChannelIdentifier()) + } else if oc.IsAuthed { + // Can't auth if already authed + handler.OnBadUsageError(opm.GetChannelIdentifier()) + } else if oc.HasChannel("im.ricochet.auth.hidden-service") { + // Can't open more than 1 auth channel + handler.OnBadUsageError(opm.GetChannelIdentifier()) + } else { + clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) + if err == nil { + clientCookieB := [16]byte{} + copy(clientCookieB[:], clientCookie.([]byte)[:]) + handler.OnAuthenticationRequest(opm.GetChannelIdentifier(), clientCookieB) + } else { + // Must include Client Cookie + handler.OnBadUsageError(opm.GetChannelIdentifier()) + } + } + case "im.ricochet.chat": + if !oc.IsAuthed { + // Can't open chat channel if not authorized + handler.OnUnauthorizedError(opm.GetChannelIdentifier()) + } else if !handler.IsKnownContact(oc.OtherHostname) { + // Can't open chat channel if not a known contact + handler.OnUnauthorizedError(opm.GetChannelIdentifier()) + } else { + handler.OnOpenChannelRequest(opm.GetChannelIdentifier(), "im.ricochet.chat") + } + case "im.ricochet.contact.request": + if oc.Client { + // Servers are not allowed to send contact requests + handler.OnBadUsageError(opm.GetChannelIdentifier()) + } else if !oc.IsAuthed { + // Can't open a contact channel if not authed + handler.OnUnauthorizedError(opm.GetChannelIdentifier()) + } else if oc.HasChannel("im.ricochet.contact.request") { + // Only 1 contact channel is allowed to be open at a time + handler.OnBadUsageError(opm.GetChannelIdentifier()) + } else { + contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) + if err == nil { + contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) + if check { + handler.OnContactRequest(opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) + break + } + } + handler.OnBadUsageError(opm.GetChannelIdentifier()) + } + default: + handler.OnUnknownTypeError(opm.GetChannelIdentifier()) + } + } else if res.GetChannelResult() != nil { + crm := res.GetChannelResult() + if crm.GetOpened() { + switch oc.GetChannelType(crm.GetChannelIdentifier()) { + case "im.ricochet.auth.hidden-service": + serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) + if err == nil { + serverCookieB := [16]byte{} + copy(serverCookieB[:], serverCookie.([]byte)[:]) + handler.OnAuthenticationChallenge(crm.GetChannelIdentifier(), serverCookieB) + } else { + handler.OnBadUsageError(crm.GetChannelIdentifier()) + } + case "im.ricochet.chat": + handler.OnOpenChannelRequestSuccess(crm.GetChannelIdentifier()) + case "im.ricochet.contact.request": + responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) + if err == nil { + response, check := responseI.(*Protocol_Data_ContactRequest.Response) + if check { + handler.OnContactRequestAck(crm.GetChannelIdentifier(), response.GetStatus().String()) + break + } + } + handler.OnBadUsageError(crm.GetChannelIdentifier()) + default: + handler.OnBadUsageError(crm.GetChannelIdentifier()) + } + } else { + if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { + handler.OnFailedChannelOpen(crm.GetChannelIdentifier(), crm.GetCommonError().String()) + } else { + oc.CloseChannel(crm.GetChannelIdentifier()) + } + } + } else { + // Unknown Message + oc.CloseChannel(packet.Channel) + } + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { + res := new(Protocol_Data_AuthHiddenService.Packet) + err := proto.Unmarshal(packet.Data[:], res) + + if err != nil { + oc.CloseChannel(packet.Channel) + continue + } + + if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs + handler.OnAuthenticationProof(packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature()) + } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results + handler.OnAuthenticationResult(packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + } else { + // If neither of the above are satisfied we just close the connection + oc.Close() + } + + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { + + // NOTE: These auth checks should be redundant, however they + // are included here for defense-in-depth if for some reason + // a previously authed connection becomes untrusted / not known and + // the state is not cleaned up. + if !oc.IsAuthed { + // Can't send chat messages if not authorized + handler.OnUnauthorizedError(packet.Channel) + } else if !handler.IsKnownContact(oc.OtherHostname) { + // Can't send chat message if not a known contact + handler.OnUnauthorizedError(packet.Channel) + } else { + res := new(Protocol_Data_Chat.Packet) + err := proto.Unmarshal(packet.Data[:], res) + + if err != nil { + oc.CloseChannel(packet.Channel) + continue + } + + if res.GetChatMessage() != nil { + handler.OnChatMessage(packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) + } else if res.GetChatAcknowledge() != nil { + handler.OnChatMessageAck(packet.Channel, int32(res.GetChatMessage().GetMessageId())) + } else { + // If neither of the above are satisfied we just close the connection + oc.Close() + } + } + } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { + + // NOTE: These auth checks should be redundant, however they + // are included here for defense-in-depth if for some reason + // a previously authed connection becomes untrusted / not known and + // the state is not cleaned up. + if !oc.Client { + // Clients are not allowed to send contact request responses + handler.OnBadUsageError(packet.Channel) + } else if !oc.IsAuthed { + // Can't send a contact request if not authed + handler.OnBadUsageError(packet.Channel) + } else { + res := new(Protocol_Data_ContactRequest.Response) + err := proto.Unmarshal(packet.Data[:], res) + log.Printf("%v", res) + if err != nil { + oc.CloseChannel(packet.Channel) + continue + } + handler.OnContactRequestAck(packet.Channel, res.GetStatus().String()) + } + } else if oc.GetChannelType(packet.Channel) == "none" { + // Invalid Channel Assignment + oc.CloseChannel(packet.Channel) + } else { + oc.Close() + } + } +} diff --git a/ricochet.go b/ricochet.go index 65f0675..ad5260f 100644 --- a/ricochet.go +++ b/ricochet.go @@ -2,344 +2,106 @@ package goricochet import ( "errors" - "github.com/golang/protobuf/proto" - "github.com/s-rah/go-ricochet/auth" - "github.com/s-rah/go-ricochet/chat" - "github.com/s-rah/go-ricochet/contact" - "github.com/s-rah/go-ricochet/control" "github.com/s-rah/go-ricochet/utils" "io" - "log" "net" - "strconv" + "sync" ) -// Ricochet is a protocol to conducting anonymous IM. -type Ricochet struct { - newconns chan *OpenConnection - networkResolver utils.NetworkResolver - rni utils.RicochetNetworkInterface -} - -// Init sets up the Ricochet object. -func (r *Ricochet) Init() { - r.newconns = make(chan *OpenConnection) - r.networkResolver = utils.NetworkResolver{} - r.rni = new(utils.RicochetNetwork) -} - // Connect sets up a client ricochet connection to host e.g. qn6uo4cmsrfv4kzq.onion. If this // function finished successfully then the connection can be assumed to // be open and authenticated. // To specify a local port using the format "127.0.0.1:[port]|ricochet-id". -func (r *Ricochet) Connect(host string) (*OpenConnection, error) { - var err error - conn, host, err := r.networkResolver.Resolve(host) +func Connect(host string) (*OpenConnection, error) { + networkResolver := utils.NetworkResolver{} + conn, host, err := networkResolver.Resolve(host) if err != nil { return nil, err } - return r.ConnectOpen(conn, host) + return Open(conn, host) } -// ConnectOpen attempts to open up a new connection to the given host. Returns a -// pointer to the OpenConnection or an error. -func (r *Ricochet) ConnectOpen(conn net.Conn, host string) (*OpenConnection, error) { - oc, err := r.negotiateVersion(conn, true) +// Open establishes a protocol session on an established net.Conn, and returns a new +// OpenConnection instance representing this connection. On error, the connection +// will be closed. This function blocks until version negotiation has completed. +// The application should call Process() on the returned OpenConnection to continue +// handling protocol messages. +func Open(conn net.Conn, remoteHostname string) (*OpenConnection, error) { + oc, err := negotiateVersion(conn, true) if err != nil { + conn.Close() return nil, err } - oc.OtherHostname = host - r.newconns <- oc + oc.OtherHostname = remoteHostname return oc, nil } -// Server launches a new server listening on port -func (r *Ricochet) Server(service RicochetService, port int) { - ln, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port)) - if err != nil { - log.Printf("Cannot Listen on Port %v", port) - return - } +// Serve accepts incoming connections on a net.Listener, negotiates protocol, +// and calls methods of the ServiceHandler to handle inbound connections. All +// calls to ServiceHandler happen on the caller's goroutine. The listener can +// be closed at any time to close the service. +func Serve(ln net.Listener, handler ServiceHandler) error { + defer ln.Close() - r.ServeListener(service, ln) -} - -// ServeListener processes all messages given by the listener ln with the given -// RicochetService, service. -func (r *Ricochet) ServeListener(service RicochetService, ln net.Listener) { - go r.ProcessMessages(service) - service.OnReady() - for { - // accept connection on port - conn, err := ln.Accept() - if err != nil { - return - } - go r.processNewConnection(conn, service) - } -} - -// processNewConnection sets up a new connection -func (r *Ricochet) processNewConnection(conn net.Conn, service RicochetService) { - oc, err := r.negotiateVersion(conn, false) - if err == nil { - r.newconns <- oc - service.OnConnect(oc) - } -} - -// ProcessMessages is intended to be a background thread listening for all messages -// a client will send. The given RicochetService will be used to respond to messages. -// Prerequisites: -// * Must have previously issued a successful Connect() -func (r *Ricochet) ProcessMessages(service RicochetService) { - for { - oc := <-r.newconns - if oc == nil { - return - } - go r.processConnection(oc, service) - } -} - -// RequestStopMessageLoop requests that the ProcessMessages loop is stopped after handling all currently -// queued new connections. -func (r *Ricochet) RequestStopMessageLoop() { - r.newconns <- nil -} - -// ProcessConnection starts a blocking process loop which continually waits for -// new messages to arrive from the connection and uses the given RicochetService -// to process them. -func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService) { - service.OnConnect(oc) - defer service.OnDisconnect(oc) - - for { - if oc.Closed { - return - } - - packet, err := r.rni.RecvRicochetPacket(oc.conn) - if err != nil { - oc.Close() - return - } - - if len(packet.Data) == 0 { - service.OnChannelClosed(oc, packet.Channel) - continue - } - - if packet.Channel == 0 { - - res := new(Protocol_Data_Control.Packet) - err := proto.Unmarshal(packet.Data[:], res) + connChannel := make(chan interface{}) + listenErrorChannel := make(chan error) + go func() { + var pending sync.WaitGroup + for { + conn, err := ln.Accept() if err != nil { - service.OnGenericError(oc, packet.Channel) - continue + // Wait for pending connections before returning an error; this + // prevents abandoned goroutines when the outer loop stops reading + // from connChannel. + pending.Wait() + listenErrorChannel <- err + close(connChannel) + return } - if res.GetOpenChannel() != nil { - opm := res.GetOpenChannel() - - if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { - // Channel is already in use. - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - // If I am a Client, the server can only open even numbered channels - if oc.Client && opm.GetChannelIdentifier()%2 != 0 { - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - // If I am a Server, the client can only open odd numbered channels - if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - continue - } - - switch opm.GetChannelType() { - case "im.ricochet.auth.hidden-service": - if oc.Client { - // Servers are authed by default and can't auth with hidden-service - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if oc.IsAuthed { - // Can't auth if already authed - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.auth.hidden-service") { - // Can't open more than 1 auth channel - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else { - clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) - if err == nil { - clientCookieB := [16]byte{} - copy(clientCookieB[:], clientCookie.([]byte)[:]) - service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB) - } else { - // Must include Client Cookie - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } - } - case "im.ricochet.chat": - if !oc.IsAuthed { - // Can't open chat channel if not authorized - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else if !service.IsKnownContact(oc.OtherHostname) { - // Can't open chat channel if not a known contact - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else { - service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat") - } - case "im.ricochet.contact.request": - if oc.Client { - // Servers are not allowed to send contact requests - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else if !oc.IsAuthed { - // Can't open a contact channel if not authed - service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.contact.request") { - // Only 1 contact channel is allowed to be open at a time - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } else { - contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) - if err == nil { - contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) - if check { - service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) - break - } - } - service.OnBadUsageError(oc, opm.GetChannelIdentifier()) - } - default: - service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) - } - } else if res.GetChannelResult() != nil { - crm := res.GetChannelResult() - if crm.GetOpened() { - switch oc.GetChannelType(crm.GetChannelIdentifier()) { - case "im.ricochet.auth.hidden-service": - serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) - if err == nil { - serverCookieB := [16]byte{} - copy(serverCookieB[:], serverCookie.([]byte)[:]) - service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB) - } else { - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - } - case "im.ricochet.chat": - service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier()) - case "im.ricochet.contact.request": - responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) - if err == nil { - response, check := responseI.(*Protocol_Data_ContactRequest.Response) - if check { - service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String()) - break - } - } - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - default: - service.OnBadUsageError(oc, crm.GetChannelIdentifier()) - } - } else { - if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { - service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String()) - } else { - oc.CloseChannel(crm.GetChannelIdentifier()) - } - } - } else { - // Unknown Message - oc.CloseChannel(packet.Channel) - } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { - res := new(Protocol_Data_AuthHiddenService.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - - if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs - service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) - } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results - service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) - } else { - // If neither of the above are satisfied we just close the connection - oc.Close() - } - - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { - - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.IsAuthed { - // Can't send chat messages if not authorized - service.OnUnauthorizedError(oc, packet.Channel) - } else if !service.IsKnownContact(oc.OtherHostname) { - // Can't send chat message if not a known contact - service.OnUnauthorizedError(oc, packet.Channel) - } else { - res := new(Protocol_Data_Chat.Packet) - err := proto.Unmarshal(packet.Data[:], res) - + pending.Add(1) + go func() { + defer pending.Done() + oc, err := negotiateVersion(conn, false) if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - - if res.GetChatMessage() != nil { - service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) - } else if res.GetChatAcknowledge() != nil { - service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId())) + conn.Close() + connChannel <- err } else { - // If neither of the above are satisfied we just close the connection - oc.Close() + connChannel <- oc } - } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { + }() + } + }() - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.Client { - // Clients are not allowed to send contact request responses - service.OnBadUsageError(oc, packet.Channel) - } else if !oc.IsAuthed { - // Can't send a contact request if not authed - service.OnBadUsageError(oc, packet.Channel) - } else { - res := new(Protocol_Data_ContactRequest.Response) - err := proto.Unmarshal(packet.Data[:], res) - log.Printf("%v", res) - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String()) + var listenErr error + for { + select { + case err := <-listenErrorChannel: + // Remember error, wait for connChannel to close + listenErr = err + + case result, ok := <-connChannel: + if !ok { + return listenErr + } + + switch v := result.(type) { + case *OpenConnection: + handler.OnNewConnection(v) + case error: + handler.OnFailedConnection(v) } - } else if oc.GetChannelType(packet.Channel) == "none" { - // Invalid Channel Assignment - oc.CloseChannel(packet.Channel) - } else { - oc.Close() } } + + return nil } // Perform version negotiation on the connection, and create an OpenConnection if successful -func (r *Ricochet) negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) { +func negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) { versions := []byte{0x49, 0x4D, 0x01, 0x01} // Outbound side of the connection sends a list of supported versions diff --git a/ricochetservice.go b/ricochetservice.go deleted file mode 100644 index e3151fd..0000000 --- a/ricochetservice.go +++ /dev/null @@ -1,36 +0,0 @@ -package goricochet - -// RicochetService provides an interface for building automated ricochet applications. -type RicochetService interface { - OnReady() - OnConnect(oc *OpenConnection) - OnDisconnect(oc *OpenConnection) - - // Authentication Management - OnAuthenticationRequest(oc *OpenConnection, channelID int32, clientCookie [16]byte) - OnAuthenticationChallenge(oc *OpenConnection, channelID int32, serverCookie [16]byte) - OnAuthenticationProof(oc *OpenConnection, channelID int32, publicKey []byte, signature []byte, isKnownContact bool) - OnAuthenticationResult(oc *OpenConnection, channelID int32, result bool, isKnownContact bool) - - // Contact Management - IsKnownContact(hostname string) bool - OnContactRequest(oc *OpenConnection, channelID int32, nick string, message string) - OnContactRequestAck(oc *OpenConnection, channelID int32, status string) - - // Managing Channels - OnOpenChannelRequest(oc *OpenConnection, channelID int32, channelType string) - OnOpenChannelRequestSuccess(oc *OpenConnection, channelID int32) - OnChannelClosed(oc *OpenConnection, channelID int32) - - // Chat Messages - OnChatMessage(oc *OpenConnection, channelID int32, messageID int32, message string) - OnChatMessageAck(oc *OpenConnection, channelID int32, messageID int32) - - // Handle Errors - OnFailedChannelOpen(oc *OpenConnection, channelID int32, errorType string) - OnGenericError(oc *OpenConnection, channelID int32) - OnUnknownTypeError(oc *OpenConnection, channelID int32) - OnUnauthorizedError(oc *OpenConnection, channelID int32) - OnBadUsageError(oc *OpenConnection, channelID int32) - OnFailedError(oc *OpenConnection, channelID int32) -} diff --git a/standardricochetservice.go b/standardricochetservice.go index 3c51d3c..670f63f 100644 --- a/standardricochetservice.go +++ b/standardricochetservice.go @@ -9,23 +9,30 @@ import ( "github.com/s-rah/go-ricochet/utils" "io/ioutil" "log" + "net" + "strconv" ) // StandardRicochetService implements all the necessary flows to implement a // minimal, protocol compliant Ricochet Service. It can be built on by other -// applications to produce automated riochet applications. +// applications to produce automated riochet applications, and is a useful +// example for other implementations. type StandardRicochetService struct { - ricochet *Ricochet - privateKey *rsa.PrivateKey + PrivateKey *rsa.PrivateKey serverHostname string } +// StandardRicochetConnection implements the ConnectionHandler interface +// to handle events on connections. An instance of StandardRicochetConnection +// is created for each OpenConnection by the HandleConnection method. +type StandardRicochetConnection struct { + Conn *OpenConnection + PrivateKey *rsa.PrivateKey +} + // Init initializes a StandardRicochetService with the cryptographic key given // by filename. func (srs *StandardRicochetService) Init(filename string) error { - srs.ricochet = new(Ricochet) - srs.ricochet.Init() - pemData, err := ioutil.ReadFile(filename) if err != nil { @@ -37,14 +44,14 @@ func (srs *StandardRicochetService) Init(filename string) error { return errors.New("Could not setup ricochet service: no valid PEM data found") } - srs.privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + srs.PrivateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return errors.New("Could not setup ricochet service: could not parse private key") } publicKeyBytes, _ := asn1.Marshal(rsa.PublicKey{ - N: srs.privateKey.PublicKey.N, - E: srs.privateKey.PublicKey.E, + N: srs.PrivateKey.PublicKey.N, + E: srs.PrivateKey.PublicKey.E, }) srs.serverHostname = utils.GetTorHostname(publicKeyBytes) @@ -53,130 +60,151 @@ func (srs *StandardRicochetService) Init(filename string) error { return nil } -// OnReady is called once a Server has been established (by calling Listen) -func (srs *StandardRicochetService) OnReady() { -} - -// Listen starts the ricochet service. Listen must be called before any other method (apart from Init) -func (srs *StandardRicochetService) Listen(service RicochetService, port int) { - srs.ricochet.Server(service, port) -} - -// Connect can be called to initiate a new client connection to a server -func (srs *StandardRicochetService) Connect(hostname string) error { - log.Printf("Connecting to...%s", hostname) - oc, err := srs.ricochet.Connect(hostname) +// Listen starts listening for service connections on localhost `port`. +func (srs *StandardRicochetService) Listen(handler ServiceHandler, port int) { + ln, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port)) if err != nil { - return errors.New("Could not connect to: " + hostname + " " + err.Error()) + log.Printf("Cannot Listen on Port %v", port) + return + } + + Serve(ln, handler) +} + +// Connect initiates a new client connection to `hostname`, which must be in one +// of the forms accepted by the goricochet.Connect() method. +func (srs *StandardRicochetService) Connect(hostname string) (*OpenConnection, error) { + log.Printf("Connecting to...%s", hostname) + oc, err := Connect(hostname) + if err != nil { + return nil, errors.New("Could not connect to: " + hostname + " " + err.Error()) } oc.MyHostname = srs.serverHostname - return nil + return oc, nil } -// OnConnect is called when a client or server successfully passes Version Negotiation. -func (srs *StandardRicochetService) OnConnect(oc *OpenConnection) { +// OnNewConnection is called for new inbound connections to our service. This +// method implements the ServiceHandler interface. +func (srs *StandardRicochetService) OnNewConnection(oc *OpenConnection) { + oc.MyHostname = srs.serverHostname +} + +// OnFailedConnection is called for inbound connections that fail to successfully +// complete version negotiation for any reason. This method implements the +// ServiceHandler interface. +func (srs *StandardRicochetService) OnFailedConnection(err error) { + log.Printf("Inbound connection failed: %s", err) +} + +// ------ + +// OnReady is called when a client or server sucessfully passes Version Negotiation. +func (src *StandardRicochetConnection) OnReady(oc *OpenConnection) { + src.Conn = oc if oc.Client { - log.Printf("Sucessefully Connected to %s", oc.OtherHostname) + log.Printf("Successfully connected to %s", oc.OtherHostname) oc.IsAuthed = true // Connections to Servers are Considered Authenticated by Default oc.Authenticate(1) } else { - oc.MyHostname = srs.serverHostname + log.Printf("Inbound connection received") } } // OnDisconnect is called when a connection is closed -func (srs *StandardRicochetService) OnDisconnect(oc *OpenConnection) { +func (src *StandardRicochetConnection) OnDisconnect() { + log.Printf("Disconnected from %s", src.Conn.OtherHostname) } // OnAuthenticationRequest is called when a client requests Authentication -func (srs *StandardRicochetService) OnAuthenticationRequest(oc *OpenConnection, channelID int32, clientCookie [16]byte) { - oc.ConfirmAuthChannel(channelID, clientCookie) +func (src *StandardRicochetConnection) OnAuthenticationRequest(channelID int32, clientCookie [16]byte) { + src.Conn.ConfirmAuthChannel(channelID, clientCookie) } // OnAuthenticationChallenge constructs a valid authentication challenge to the serverCookie -func (srs *StandardRicochetService) OnAuthenticationChallenge(oc *OpenConnection, channelID int32, serverCookie [16]byte) { +func (src *StandardRicochetConnection) OnAuthenticationChallenge(channelID int32, serverCookie [16]byte) { // DER Encode the Public Key publickeyBytes, _ := asn1.Marshal(rsa.PublicKey{ - N: srs.privateKey.PublicKey.N, - E: srs.privateKey.PublicKey.E, + N: src.PrivateKey.PublicKey.N, + E: src.PrivateKey.PublicKey.E, }) - oc.SendProof(1, serverCookie, publickeyBytes, srs.privateKey) + src.Conn.SendProof(1, serverCookie, publickeyBytes, src.PrivateKey) } // OnAuthenticationProof is called when a client sends Proof for an existing authentication challenge -func (srs *StandardRicochetService) OnAuthenticationProof(oc *OpenConnection, channelID int32, publicKey []byte, signature []byte, isKnownContact bool) { - result := oc.ValidateProof(channelID, publicKey, signature) - oc.SendAuthenticationResult(channelID, result, isKnownContact) - oc.IsAuthed = result - oc.CloseChannel(channelID) +func (src *StandardRicochetConnection) OnAuthenticationProof(channelID int32, publicKey []byte, signature []byte) { + result := src.Conn.ValidateProof(channelID, publicKey, signature) + // This implementation always sends 'true', indicating that the contact is known + src.Conn.SendAuthenticationResult(channelID, result, true) + src.Conn.IsAuthed = result + src.Conn.CloseChannel(channelID) } // OnAuthenticationResult is called once a server has returned the result of the Proof Verification -func (srs *StandardRicochetService) OnAuthenticationResult(oc *OpenConnection, channelID int32, result bool, isKnownContact bool) { - oc.IsAuthed = result +func (src *StandardRicochetConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { + src.Conn.IsAuthed = result } // IsKnownContact allows a caller to determine if a hostname an authorized contact. -func (srs *StandardRicochetService) IsKnownContact(hostname string) bool { +func (src *StandardRicochetConnection) IsKnownContact(hostname string) bool { return false } // OnContactRequest is called when a client sends a new contact request -func (srs *StandardRicochetService) OnContactRequest(oc *OpenConnection, channelID int32, nick string, message string) { +func (src *StandardRicochetConnection) OnContactRequest(channelID int32, nick string, message string) { } // OnContactRequestAck is called when a server sends a reply to an existing contact request -func (srs *StandardRicochetService) OnContactRequestAck(oc *OpenConnection, channelID int32, status string) { +func (src *StandardRicochetConnection) OnContactRequestAck(channelID int32, status string) { } // OnOpenChannelRequest is called when a client or server requests to open a new channel -func (srs *StandardRicochetService) OnOpenChannelRequest(oc *OpenConnection, channelID int32, channelType string) { - oc.AckOpenChannel(channelID, channelType) +func (src *StandardRicochetConnection) OnOpenChannelRequest(channelID int32, channelType string) { + src.Conn.AckOpenChannel(channelID, channelType) } // OnOpenChannelRequestSuccess is called when a client or server responds to an open channel request -func (srs *StandardRicochetService) OnOpenChannelRequestSuccess(oc *OpenConnection, channelID int32) { +func (src *StandardRicochetConnection) OnOpenChannelRequestSuccess(channelID int32) { } // OnChannelClosed is called when a client or server closes an existing channel -func (srs *StandardRicochetService) OnChannelClosed(oc *OpenConnection, channelID int32) { +func (src *StandardRicochetConnection) OnChannelClosed(channelID int32) { } // OnChatMessage is called when a new chat message is received. -func (srs *StandardRicochetService) OnChatMessage(oc *OpenConnection, channelID int32, messageID int32, message string) { - oc.AckChatMessage(channelID, messageID) +func (src *StandardRicochetConnection) OnChatMessage(channelID int32, messageID int32, message string) { + src.Conn.AckChatMessage(channelID, messageID) } // OnChatMessageAck is called when a new chat message is ascknowledged. -func (srs *StandardRicochetService) OnChatMessageAck(oc *OpenConnection, channelID int32, messageID int32) { +func (src *StandardRicochetConnection) OnChatMessageAck(channelID int32, messageID int32) { } // OnFailedChannelOpen is called when a server fails to open a channel -func (srs *StandardRicochetService) OnFailedChannelOpen(oc *OpenConnection, channelID int32, errorType string) { - oc.UnsetChannel(channelID) +func (src *StandardRicochetConnection) OnFailedChannelOpen(channelID int32, errorType string) { + src.Conn.UnsetChannel(channelID) } // OnGenericError is called when a generalized error is returned from the peer -func (srs *StandardRicochetService) OnGenericError(oc *OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "GenericError") +func (src *StandardRicochetConnection) OnGenericError(channelID int32) { + src.Conn.RejectOpenChannel(channelID, "GenericError") } //OnUnknownTypeError is called when an unknown type error is returned from the peer -func (srs *StandardRicochetService) OnUnknownTypeError(oc *OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "UnknownTypeError") +func (src *StandardRicochetConnection) OnUnknownTypeError(channelID int32) { + src.Conn.RejectOpenChannel(channelID, "UnknownTypeError") } // OnUnauthorizedError is called when an unathorized error is returned from the peer -func (srs *StandardRicochetService) OnUnauthorizedError(oc *OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "UnauthorizedError") +func (src *StandardRicochetConnection) OnUnauthorizedError(channelID int32) { + src.Conn.RejectOpenChannel(channelID, "UnauthorizedError") } // OnBadUsageError is called when a bad usage error is returned from the peer -func (srs *StandardRicochetService) OnBadUsageError(oc *OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "BadUsageError") +func (src *StandardRicochetConnection) OnBadUsageError(channelID int32) { + src.Conn.RejectOpenChannel(channelID, "BadUsageError") } // OnFailedError is called when a failed error is returned from the peer -func (srs *StandardRicochetService) OnFailedError(oc *OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "FailedError") +func (src *StandardRicochetConnection) OnFailedError(channelID int32) { + src.Conn.RejectOpenChannel(channelID, "FailedError") } diff --git a/standardricochetservice_bad_usage_error_test.go b/standardricochetservice_bad_usage_error_test.go index 104fe7a..03778f9 100644 --- a/standardricochetservice_bad_usage_error_test.go +++ b/standardricochetservice_bad_usage_error_test.go @@ -11,70 +11,76 @@ type TestBadUsageService struct { ChannelClosed int } -func (ts *TestBadUsageService) OnConnect(oc *OpenConnection) { +type TestBadUsageConnection struct { + StandardRicochetConnection + Service *TestBadUsageService +} + +func (ts *TestBadUsageService) OnNewConnection(oc *OpenConnection) { + ts.StandardRicochetService.OnNewConnection(oc) + go oc.Process(&TestBadUsageConnection{Service: ts}) +} + +func (tc *TestBadUsageConnection) OnReady(oc *OpenConnection) { if oc.Client { oc.OpenChannel(17, "im.ricochet.auth.hidden-service") // Fail because no Extension } - ts.StandardRicochetService.OnConnect(oc) + tc.StandardRicochetConnection.OnReady(oc) if oc.Client { oc.Authenticate(103) // Should Fail because cannot open more than one auth-hidden-service channel at once } } -func (ts *TestBadUsageService) OnAuthenticationProof(oc *OpenConnection, channelID int32, publicKey []byte, signature []byte, isKnownContact bool) { - oc.Authenticate(2) // Try to authenticate again...will fail servers don't auth - oc.SendContactRequest(4, "test", "test") // Only clients can send contact requests - ts.StandardRicochetService.OnAuthenticationProof(oc, channelID, publicKey, signature, isKnownContact) - oc.OpenChatChannel(5) // Fail because server can only open even numbered channels - oc.OpenChatChannel(3) // Fail because already in use... +func (tc *TestBadUsageConnection) OnAuthenticationProof(channelID int32, publicKey []byte, signature []byte) { + tc.Conn.Authenticate(2) // Try to authenticate again...will fail servers don't auth + tc.Conn.SendContactRequest(4, "test", "test") // Only clients can send contact requests + tc.StandardRicochetConnection.OnAuthenticationProof(channelID, publicKey, signature) + tc.Conn.OpenChatChannel(5) // Fail because server can only open even numbered channels + tc.Conn.OpenChatChannel(3) // Fail because already in use... } // OnContactRequest is called when a client sends a new contact request -func (ts *TestBadUsageService) OnContactRequest(oc *OpenConnection, channelID int32, nick string, message string) { - oc.AckContactRequestOnResponse(channelID, "Pending") // Done to keep the contact request channel open +func (tc *TestBadUsageConnection) OnContactRequest(channelID int32, nick string, message string) { + tc.Conn.AckContactRequestOnResponse(channelID, "Pending") // Done to keep the contact request channel open } -func (ts *TestBadUsageService) OnAuthenticationResult(oc *OpenConnection, channelID int32, result bool, isKnownContact bool) { - ts.StandardRicochetService.OnAuthenticationResult(oc, channelID, result, isKnownContact) +func (tc *TestBadUsageConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { + tc.StandardRicochetConnection.OnAuthenticationResult(channelID, result, isKnownContact) - oc.OpenChatChannel(3) // Succeed - oc.OpenChatChannel(3) // Should fail as duplicate (channel already in use) + tc.Conn.OpenChatChannel(3) // Succeed + tc.Conn.OpenChatChannel(3) // Should fail as duplicate (channel already in use) - oc.OpenChatChannel(6) // Should fail because clients are not allowed to open even numbered channels + tc.Conn.OpenChatChannel(6) // Should fail because clients are not allowed to open even numbered channels - oc.SendMessage(101, "test") // Should fail as 101 doesn't exist + tc.Conn.SendMessage(101, "test") // Should fail as 101 doesn't exist - oc.Authenticate(1) // Try to authenticate again...will fail because we have already authenticated + tc.Conn.Authenticate(1) // Try to authenticate again...will fail because we have already authenticated - oc.OpenChannel(19, "im.ricochet.contact.request") // Will Fail - oc.SendContactRequest(11, "test", "test") // Succeed - oc.SendContactRequest(13, "test", "test") // Trigger singleton contact request check + tc.Conn.OpenChannel(19, "im.ricochet.contact.request") // Will Fail + tc.Conn.SendContactRequest(11, "test", "test") // Succeed + tc.Conn.SendContactRequest(13, "test", "test") // Trigger singleton contact request check - oc.OpenChannel(15, "im.ricochet.not-a-real-type") // Fail UnknownType + tc.Conn.OpenChannel(15, "im.ricochet.not-a-real-type") // Fail UnknownType } // OnChannelClose is called when a client or server closes an existing channel -func (ts *TestBadUsageService) OnChannelClosed(oc *OpenConnection, channelID int32) { +func (tc *TestBadUsageConnection) OnChannelClosed(channelID int32) { if channelID == 101 { log.Printf("Received Channel Closed: %v", channelID) - ts.ChannelClosed++ + tc.Service.ChannelClosed++ } } -func (ts *TestBadUsageService) OnFailedChannelOpen(oc *OpenConnection, channelID int32, errorType string) { +func (tc *TestBadUsageConnection) OnFailedChannelOpen(channelID int32, errorType string) { log.Printf("Failed Channel Open %v %v", channelID, errorType) - ts.StandardRicochetService.OnFailedChannelOpen(oc, channelID, errorType) + tc.StandardRicochetConnection.OnFailedChannelOpen(channelID, errorType) if errorType == "BadUsageError" { - ts.BadUsageErrorCount++ + tc.Service.BadUsageErrorCount++ } else if errorType == "UnknownTypeError" { - ts.UnknownTypeErrorCount++ + tc.Service.UnknownTypeErrorCount++ } } -func (ts *TestBadUsageService) IsKnownContact(hostname string) bool { - return true -} - func TestBadUsageServer(t *testing.T) { ricochetService := new(TestBadUsageService) err := ricochetService.Init("./private_key") @@ -95,10 +101,16 @@ func TestBadUsageServer(t *testing.T) { } go ricochetService2.Listen(ricochetService2, 9885) - err = ricochetService2.Connect("127.0.0.1:9884|kwke2hntvyfqm7dr") + oc, err := ricochetService2.Connect("127.0.0.1:9884|kwke2hntvyfqm7dr") if err != nil { t.Errorf("Could not connect to ricochet service: %v", err) } + go oc.Process(&TestBadUsageConnection{ + Service: ricochetService2, + StandardRicochetConnection: StandardRicochetConnection{ + PrivateKey: ricochetService2.PrivateKey, + }, + }) time.Sleep(time.Second * 3) if ricochetService2.ChannelClosed != 1 || ricochetService2.BadUsageErrorCount != 7 || ricochetService.BadUsageErrorCount != 4 || ricochetService2.UnknownTypeErrorCount != 1 { diff --git a/standardricochetservice_test.go b/standardricochetservice_test.go index 41bb91b..10ab2fe 100644 --- a/standardricochetservice_test.go +++ b/standardricochetservice_test.go @@ -6,52 +6,69 @@ import "log" type TestService struct { StandardRicochetService - ReceivedMessage bool - KnownContact bool // Mocking contact request } -func (ts *TestService) OnAuthenticationResult(oc *OpenConnection, channelID int32, result bool, isKnownContact bool) { - ts.StandardRicochetService.OnAuthenticationResult(oc, channelID, result, isKnownContact) +func (ts *TestService) OnNewConnection(oc *OpenConnection) { + ts.StandardRicochetService.OnNewConnection(oc) + go oc.Process(&TestConnection{}) +} + +type TestConnection struct { + StandardRicochetConnection + KnownContact bool // Mocking contact request +} + +func (tc *TestConnection) IsKnownContact(hostname string) bool { + return tc.KnownContact +} + +func (tc *TestConnection) OnAuthenticationProof(channelID int32, publicKey, signature []byte) { + result := tc.Conn.ValidateProof(channelID, publicKey, signature) + tc.Conn.SendAuthenticationResult(channelID, result, tc.KnownContact) + tc.Conn.IsAuthed = result + tc.Conn.CloseChannel(channelID) +} + +func (tc *TestConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { + tc.StandardRicochetConnection.OnAuthenticationResult(channelID, result, isKnownContact) if !isKnownContact { log.Printf("Sending Contact Request") - oc.SendContactRequest(3, "test", "test") + tc.Conn.SendContactRequest(3, "test", "test") } } -func (ts *TestService) OnContactRequest(oc *OpenConnection, channelID int32, nick string, message string) { - ts.StandardRicochetService.OnContactRequest(oc, channelID, nick, message) - oc.AckContactRequestOnResponse(channelID, "Pending") - oc.AckContactRequest(channelID, "Accepted") - ts.KnownContact = true - oc.CloseChannel(channelID) +func (tc *TestConnection) OnContactRequest(channelID int32, nick string, message string) { + tc.StandardRicochetConnection.OnContactRequest(channelID, nick, message) + tc.Conn.AckContactRequestOnResponse(channelID, "Pending") + tc.Conn.AckContactRequest(channelID, "Accepted") + tc.KnownContact = true + tc.Conn.CloseChannel(channelID) } -func (ts *TestService) OnOpenChannelRequestSuccess(oc *OpenConnection, channelID int32) { - ts.StandardRicochetService.OnOpenChannelRequestSuccess(oc, channelID) - oc.SendMessage(channelID, "TEST MESSAGE") +func (tc *TestConnection) OnOpenChannelRequestSuccess(channelID int32) { + tc.StandardRicochetConnection.OnOpenChannelRequestSuccess(channelID) + tc.Conn.SendMessage(channelID, "TEST MESSAGE") } -func (ts *TestService) OnContactRequestAck(oc *OpenConnection, channelID int32, status string) { - ts.StandardRicochetService.OnContactRequestAck(oc, channelID, status) +func (tc *TestConnection) OnContactRequestAck(channelID int32, status string) { + tc.StandardRicochetConnection.OnContactRequestAck(channelID, status) if status == "Accepted" { log.Printf("Got accepted contact request") - ts.KnownContact = true - oc.OpenChatChannel(5) + tc.KnownContact = true + tc.Conn.OpenChatChannel(5) } else if status == "Pending" { log.Printf("Got pending contact request") } } -func (ts *TestService) OnChatMessage(oc *OpenConnection, channelID int32, messageID int32, message string) { - ts.StandardRicochetService.OnChatMessage(oc, channelID, messageID, message) +func (tc *TestConnection) OnChatMessage(channelID int32, messageID int32, message string) { + tc.StandardRicochetConnection.OnChatMessage(channelID, messageID, message) if message == "TEST MESSAGE" { - ts.ReceivedMessage = true + receivedMessage = true } } -func (ts *TestService) IsKnownContact(hostname string) bool { - return ts.KnownContact -} +var receivedMessage bool func TestServer(t *testing.T) { ricochetService := new(TestService) @@ -73,16 +90,21 @@ func TestServer(t *testing.T) { } go ricochetService2.Listen(ricochetService2, 9879) - err = ricochetService2.Connect("127.0.0.1:9878|kwke2hntvyfqm7dr") + oc, err := ricochetService2.Connect("127.0.0.1:9878|kwke2hntvyfqm7dr") if err != nil { t.Errorf("Could not connect to ricochet service: %v", err) } + testClient := &TestConnection{ + StandardRicochetConnection: StandardRicochetConnection{ + PrivateKey: ricochetService2.PrivateKey, + }, + } + go oc.Process(testClient) time.Sleep(time.Second * 5) // Wait a bit longer - if !ricochetService.ReceivedMessage { + if !receivedMessage { t.Errorf("Test server did not receive message") } - } func TestServerInvalidKey(t *testing.T) { @@ -100,7 +122,7 @@ func TestServerCouldNotConnect(t *testing.T) { if err != nil { t.Errorf("Could not initate ricochet service: %v", err) } - err = ricochetService.Connect("127.0.0.1:65535|kwke2hntvyfqm7dr") + _, err = ricochetService.Connect("127.0.0.1:65535|kwke2hntvyfqm7dr") if err == nil { t.Errorf("Should not have been been able to connect to 127.0.0.1:65535|kwke2hntvyfqm7dr") } diff --git a/standardricochetservice_unauth_test.go b/standardricochetservice_unauth_test.go index df2c67c..b839878 100644 --- a/standardricochetservice_unauth_test.go +++ b/standardricochetservice_unauth_test.go @@ -10,10 +10,19 @@ import "log" type TestUnauthorizedService struct { StandardRicochetService +} + +func (ts *TestUnauthorizedService) OnNewConnection(oc *OpenConnection) { + go oc.Process(&StandardRicochetConnection{}) +} + +type TestUnauthorizedConnection struct { + StandardRicochetConnection FailedToOpen int } -func (ts *TestUnauthorizedService) OnConnect(oc *OpenConnection) { +func (tc *TestUnauthorizedConnection) OnReady(oc *OpenConnection) { + tc.StandardRicochetConnection.OnReady(oc) if oc.Client { log.Printf("Attempting Authentication Not Authorized") oc.IsAuthed = true // Connections to Servers are Considered Authenticated by Default @@ -23,10 +32,10 @@ func (ts *TestUnauthorizedService) OnConnect(oc *OpenConnection) { } } -func (ts *TestUnauthorizedService) OnFailedChannelOpen(oc *OpenConnection, channelID int32, errorType string) { - oc.UnsetChannel(channelID) +func (tc *TestUnauthorizedConnection) OnFailedChannelOpen(channelID int32, errorType string) { + tc.Conn.UnsetChannel(channelID) if errorType == "UnauthorizedError" { - ts.FailedToOpen++ + tc.FailedToOpen++ } } @@ -50,13 +59,19 @@ func TestUnauthorizedClientReject(t *testing.T) { } go ricochetService2.Listen(ricochetService2, 9881) - err = ricochetService2.Connect("127.0.0.1:9880|kwke2hntvyfqm7dr") + oc, err := ricochetService2.Connect("127.0.0.1:9880|kwke2hntvyfqm7dr") if err != nil { t.Errorf("Could not connect to ricochet service: %v", err) } + connectionHandler := &TestUnauthorizedConnection{ + StandardRicochetConnection: StandardRicochetConnection{ + PrivateKey: ricochetService2.PrivateKey, + }, + } + go oc.Process(connectionHandler) time.Sleep(time.Second * 2) - if ricochetService2.FailedToOpen != 2 { + if connectionHandler.FailedToOpen != 2 { t.Errorf("Test server did not reject open channels with unauthorized error") } diff --git a/standardricochetservice_unknown_contact_test.go b/standardricochetservice_unknown_contact_test.go index ba49fb4..2e7a702 100644 --- a/standardricochetservice_unknown_contact_test.go +++ b/standardricochetservice_unknown_contact_test.go @@ -6,29 +6,44 @@ import "log" type TestUnknownContactService struct { StandardRicochetService +} + +func (ts *TestUnknownContactService) OnNewConnection(oc *OpenConnection) { + go oc.Process(&TestUnknownContactConnection{}) +} + +type TestUnknownContactConnection struct { + StandardRicochetConnection FailedToOpen bool } -func (ts *TestUnknownContactService) OnAuthenticationResult(oc *OpenConnection, channelID int32, result bool, isKnownContact bool) { - log.Printf("Authentication Result") - ts.StandardRicochetService.OnAuthenticationResult(oc, channelID, result, isKnownContact) - oc.OpenChatChannel(5) -} - -func (ts *TestUnknownContactService) OnFailedChannelOpen(oc *OpenConnection, channelID int32, errorType string) { - log.Printf("Failed Channel Open %v", errorType) - oc.UnsetChannel(channelID) - if errorType == "UnauthorizedError" { - ts.FailedToOpen = true - } -} - -func (ts *TestUnknownContactService) IsKnownContact(hostname string) bool { +func (tc *TestUnknownContactConnection) IsKnownContact(hostname string) bool { return false } +func (tc *TestUnknownContactConnection) OnAuthenticationProof(channelID int32, publicKey, signature []byte) { + result := tc.Conn.ValidateProof(channelID, publicKey, signature) + tc.Conn.SendAuthenticationResult(channelID, result, false) + tc.Conn.IsAuthed = result + tc.Conn.CloseChannel(channelID) +} + +func (tc *TestUnknownContactConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { + log.Printf("Authentication Result") + tc.StandardRicochetConnection.OnAuthenticationResult(channelID, result, isKnownContact) + tc.Conn.OpenChatChannel(5) +} + +func (tc *TestUnknownContactConnection) OnFailedChannelOpen(channelID int32, errorType string) { + log.Printf("Failed Channel Open %v", errorType) + tc.Conn.UnsetChannel(channelID) + if errorType == "UnauthorizedError" { + tc.FailedToOpen = true + } +} + func TestUnknownContactServer(t *testing.T) { - ricochetService := new(StandardRicochetService) + ricochetService := new(TestUnknownContactService) err := ricochetService.Init("./private_key") if err != nil { @@ -39,21 +54,19 @@ func TestUnknownContactServer(t *testing.T) { time.Sleep(time.Second * 2) - ricochetService2 := new(TestUnknownContactService) - err = ricochetService2.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService2.Listen(ricochetService2, 9883) - err = ricochetService2.Connect("127.0.0.1:9882|kwke2hntvyfqm7dr") + oc, err := ricochetService.Connect("127.0.0.1:9882|kwke2hntvyfqm7dr") if err != nil { t.Errorf("Could not connect to ricochet service: %v", err) } + connectionHandler := &TestUnknownContactConnection{ + StandardRicochetConnection: StandardRicochetConnection{ + PrivateKey: ricochetService.PrivateKey, + }, + } + go oc.Process(connectionHandler) time.Sleep(time.Second * 2) - if !ricochetService2.FailedToOpen { + if !connectionHandler.FailedToOpen { t.Errorf("Test server did receive message should have failed") }