From f9bc09c5202ff0e65106974700af8135fb39899d Mon Sep 17 00:00:00 2001 From: John Brooks Date: Sun, 9 Oct 2016 17:31:26 -0700 Subject: [PATCH] core: Adapt to protocol API changes --- core/contact.go | 9 +- core/identity.go | 36 +++++++- core/protocol.go | 226 +++++++++++++++++++++-------------------------- core/ricochet.go | 2 - 4 files changed, 144 insertions(+), 129 deletions(-) diff --git a/core/contact.go b/core/contact.go index d0a871c..acb396f 100644 --- a/core/contact.go +++ b/core/contact.go @@ -221,7 +221,7 @@ func (c *Contact) connectOutbound(ctx context.Context, connChannel chan *protoco } log.Printf("Successful outbound connection to contact %s", hostname) - oc, err := c.core.Protocol.ConnectOpen(conn, hostname[0:16]) + oc, err := protocol.Open(conn, hostname[0:16]) if err != nil { log.Printf("Contact connection protocol failure: %s", err) oc.Close() @@ -242,6 +242,13 @@ func (c *Contact) connectOutbound(ctx context.Context, connChannel chan *protoco // OnConnectionClosed. Alternatively, it will break because this // is fragile and dumb. // XXX BUG: This means no backoff for authentication failure + handler := &ProtocolConnection{ + Conn: oc, + Contact: c, + MyHostname: c.core.Identity.Address()[9:], + PrivateKey: c.core.Identity.PrivateKey(), + } + go oc.Process(handler) return } } diff --git a/core/identity.go b/core/identity.go index f406796..05c8814 100644 --- a/core/identity.go +++ b/core/identity.go @@ -4,6 +4,7 @@ import ( "crypto/rsa" "encoding/base64" "errors" + protocol "github.com/s-rah/go-ricochet" "github.com/special/notricochet/core/utils" "github.com/yawning/bulb/utils/pkcs1" "log" @@ -98,6 +99,29 @@ func (me *Identity) setPrivateKey(key *rsa.PrivateKey) error { return nil } +type identityService struct { + Identity *Identity + MyHostname string +} + +func (is *identityService) OnNewConnection(oc *protocol.OpenConnection) { + log.Printf("Inbound connection accepted") + oc.MyHostname = is.MyHostname + // XXX Should have pre-auth handling, timeouts + identity := is.Identity + handler := &ProtocolConnection{ + Conn: oc, + GetContactByHostname: func(hostname string) *Contact { + return identity.ContactList().ContactByAddress("ricochet:" + hostname) + }, + } + go oc.Process(handler) +} + +func (is *identityService) OnFailedConnection(err error) { + log.Printf("Inbound connection failed: %v", err) +} + // BUG(special): No error handling for failures under publishService func (me *Identity) publishService(key *rsa.PrivateKey) { // This call will block until a control connection is available and the @@ -127,7 +151,17 @@ func (me *Identity) publishService(key *rsa.PrivateKey) { } log.Printf("Identity service published, accepting connections") - go me.core.Protocol.ServeListener(listener) + is := &identityService{ + Identity: me, + MyHostname: me.Address()[9:], + } + + err = protocol.Serve(listener, is) + if err != nil { + log.Printf("Identity listener failed: %v", err) + // XXX handle + return + } } func (me *Identity) Address() string { diff --git a/core/protocol.go b/core/protocol.go index 4b33787..05151ef 100644 --- a/core/protocol.go +++ b/core/protocol.go @@ -1,212 +1,188 @@ package core import ( + "crypto/rsa" "encoding/asn1" protocol "github.com/s-rah/go-ricochet" "log" - "net" "time" ) -type Protocol struct { - core *Ricochet +type ProtocolConnection struct { + Conn *protocol.OpenConnection + Contact *Contact - service *protocol.Ricochet - handler *protocolHandler + // Client-side authentication + MyHostname string + PrivateKey rsa.PrivateKey + + // Service-side authentication + GetContactByHostname func(hostname string) *Contact } -// Implements protocol.RicochetService -type protocolHandler struct { - p *Protocol -} - -func CreateProtocol(core *Ricochet) *Protocol { - p := &Protocol{ - core: core, - service: new(protocol.Ricochet), +func (pc *ProtocolConnection) OnReady(oc *protocol.OpenConnection) { + if pc.Conn != nil && pc.Conn != oc { + log.Panicf("ProtocolConnection is already assigned connection %v, but OnReady called for connection %v", pc.Conn, oc) } - p.handler = &protocolHandler{p: p} - p.service.Init() - return p -} -func (p *Protocol) ServeListener(listener net.Listener) { - p.service.ServeListener(p.handler, listener) -} + pc.Conn = oc -// Strangely, ServeListener starts a background routine that watches a channel -// on p.service for new connections and dispatches their events to the handler -// for the listener. API needs a little work here. -func (p *Protocol) ConnectOpen(conn net.Conn, host string) (*protocol.OpenConnection, error) { - oc, err := p.service.ConnectOpen(conn, host) - if err != nil { - return nil, err - } - oc.MyHostname = p.core.Identity.Address()[9:] - return oc, nil -} - -func (handler *protocolHandler) OnReady() { - log.Printf("protocol: OnReady") -} - -func (handler *protocolHandler) OnConnect(oc *protocol.OpenConnection) { - log.Printf("protocol: OnConnect: %v", oc) - if oc.Client { - log.Printf("Connected to %s", oc.OtherHostname) - oc.IsAuthed = true // Outbound connections are authenticated - oc.Authenticate(1) - } else { - // Strip ricochet: - oc.MyHostname = handler.p.core.Identity.Address()[9:] + if pc.Conn.Client { + log.Printf("Connected to %s", pc.Conn.OtherHostname) + pc.Conn.MyHostname = pc.MyHostname + pc.Conn.IsAuthed = true // Outbound connections are authenticated + pc.Conn.Authenticate(1) } } -func (handler *protocolHandler) OnDisconnect(oc *protocol.OpenConnection) { - log.Printf("protocol: OnDisconnect: %v", oc) - if oc.OtherHostname != "" { - contact := handler.p.core.Identity.ContactList().ContactByAddress("ricochet:" + oc.OtherHostname) - if contact != nil { - contact.OnConnectionClosed(oc) - } +func (pc *ProtocolConnection) OnDisconnect() { + log.Printf("protocol: OnDisconnect: %v", pc) + if pc.Contact != nil { + pc.Contact.OnConnectionClosed(pc.Conn) } } // Authentication Management -func (handler *protocolHandler) OnAuthenticationRequest(oc *protocol.OpenConnection, channelID int32, clientCookie [16]byte) { +func (pc *ProtocolConnection) OnAuthenticationRequest(channelID int32, clientCookie [16]byte) { log.Printf("protocol: OnAuthenticationRequest") - oc.ConfirmAuthChannel(channelID, clientCookie) + pc.Conn.ConfirmAuthChannel(channelID, clientCookie) } -func (handler *protocolHandler) OnAuthenticationChallenge(oc *protocol.OpenConnection, channelID int32, serverCookie [16]byte) { +func (pc *ProtocolConnection) OnAuthenticationChallenge(channelID int32, serverCookie [16]byte) { log.Printf("protocol: OnAuthenticationChallenge") - privateKey := handler.p.core.Identity.PrivateKey() - publicKeyBytes, _ := asn1.Marshal(privateKey.PublicKey) - oc.SendProof(1, serverCookie, publicKeyBytes, &privateKey) + publicKeyBytes, _ := asn1.Marshal(pc.PrivateKey.PublicKey) + pc.Conn.SendProof(1, serverCookie, publicKeyBytes, &pc.PrivateKey) } -func (handler *protocolHandler) OnAuthenticationProof(oc *protocol.OpenConnection, channelID int32, publicKey []byte, signature []byte, isKnownContact bool) { - result := oc.ValidateProof(channelID, publicKey, signature) +func (pc *ProtocolConnection) OnAuthenticationProof(channelID int32, publicKey []byte, signature []byte) { + result := pc.Conn.ValidateProof(channelID, publicKey, signature) - var contact *Contact if result { - if len(oc.OtherHostname) != 16 { - log.Printf("protocol: Invalid format for hostname '%s' in authentication proof", oc.OtherHostname) + if len(pc.Conn.OtherHostname) != 16 { + log.Printf("protocol: Invalid format for hostname '%s' in authentication proof", pc.Conn.OtherHostname) result = false } else { - contact = handler.p.core.Identity.ContactList().ContactByAddress("ricochet:" + oc.OtherHostname) + pc.Contact = pc.GetContactByHostname(pc.Conn.OtherHostname) } } - isKnownContact = (contact != nil) + isKnownContact := (pc.Contact != nil) - oc.SendAuthenticationResult(channelID, result, isKnownContact) - oc.IsAuthed = result - oc.CloseChannel(channelID) + pc.Conn.SendAuthenticationResult(channelID, result, isKnownContact) + pc.Conn.IsAuthed = result + pc.Conn.CloseChannel(channelID) - log.Printf("protocol: OnAuthenticationProof, result: %v, contact: %v", result, contact) - if result && contact != nil { - contact.OnConnectionAuthenticated(oc) + log.Printf("protocol: OnAuthenticationProof, result: %v, contact: %v", result, pc.Contact) + if result && pc.Contact != nil { + pc.Contact.OnConnectionAuthenticated(pc.Conn) } } -func (handler *protocolHandler) OnAuthenticationResult(oc *protocol.OpenConnection, channelID int32, result bool, isKnownContact bool) { - oc.IsAuthed = result - oc.CloseChannel(channelID) +func (pc *ProtocolConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { + pc.Conn.IsAuthed = result + pc.Conn.CloseChannel(channelID) if !result { - log.Printf("protocol: Outbound connection authentication to %s failed", oc.OtherHostname) - oc.Close() + log.Printf("protocol: Outbound connection authentication to %s failed", pc.Conn.OtherHostname) + pc.Conn.Close() return } // XXX Contact request, removed cases if !isKnownContact { - log.Printf("protocol: Outbound connection authentication to %s succeeded, but we are not a known contact", oc.OtherHostname) - oc.Close() + log.Printf("protocol: Outbound connection authentication to %s succeeded, but we are not a known contact", pc.Conn.OtherHostname) + pc.Conn.Close() return } - contact := handler.p.core.Identity.ContactList().ContactByAddress("ricochet:" + oc.OtherHostname) - if contact == nil { - log.Printf("protocol: Outbound connection authenticated to %s succeeded, but no matching contact found", oc.OtherHostname) - oc.Close() - return + log.Printf("protocol: Outbound connection to %s authenticated", pc.Conn.OtherHostname) + if pc.Contact != nil { + pc.Contact.OnConnectionAuthenticated(pc.Conn) } - - log.Printf("protocol: Outbound connection to %s authenticated", oc.OtherHostname) - contact.OnConnectionAuthenticated(oc) } // Contact Management -func (handler *protocolHandler) IsKnownContact(hostname string) bool { - contact := handler.p.core.Identity.ContactList().ContactByAddress("ricochet:" + hostname) - return contact != nil +func (pc *ProtocolConnection) OnContactRequest(channelID int32, nick string, message string) { } -func (handler *protocolHandler) OnContactRequest(oc *protocol.OpenConnection, channelID int32, nick string, message string) { -} - -func (handler *protocolHandler) OnContactRequestAck(oc *protocol.OpenConnection, channelID int32, status string) { +func (pc *ProtocolConnection) OnContactRequestAck(channelID int32, status string) { } // Managing Channels -func (handler *protocolHandler) OnOpenChannelRequest(oc *protocol.OpenConnection, channelID int32, channelType string) { - log.Printf("open channel request: %v %v", channelID, channelType) - oc.AckOpenChannel(channelID, channelType) +func (pc *ProtocolConnection) IsChannelAllowed(channelType string) bool { + switch channelType { + case "im.ricochet.auth.hidden-service": + return !pc.Conn.IsAuthed && pc.Contact == nil + case "im.ricochet.chat": + return pc.Conn.IsAuthed && pc.Contact != nil + case "im.ricochet.contact.request": + return pc.Conn.IsAuthed && pc.Contact == nil + } + + return false } -func (handler *protocolHandler) OnOpenChannelRequestSuccess(oc *protocol.OpenConnection, channelID int32) { - log.Printf("open channel request success: %v %v", channelID) +func (pc *ProtocolConnection) OnOpenChannelRequest(channelID int32, channelType string) { + log.Printf("open channel request: %v %v", channelID, channelType) + pc.Conn.AckOpenChannel(channelID, channelType) } -func (handler *protocolHandler) OnChannelClosed(oc *protocol.OpenConnection, channelID int32) { + +func (pc *ProtocolConnection) OnOpenChannelRequestSuccess(channelID int32) { + log.Printf("open channel request success: %v", channelID) +} + +func (pc *ProtocolConnection) OnChannelClosed(channelID int32) { log.Printf("channel closed: %v", channelID) } // Chat Messages // XXX messageID should be (at least) uint32 -func (handler *protocolHandler) OnChatMessage(oc *protocol.OpenConnection, channelID int32, messageID int32, message string) { +func (pc *ProtocolConnection) OnChatMessage(channelID int32, messageID int32, message string) { // XXX no time delta? // XXX sanity checks, message contents, etc log.Printf("chat message: %d %d %s", channelID, messageID, message) - // XXX ugllly - contact := handler.p.core.Identity.ContactList().ContactByAddress("ricochet:" + oc.OtherHostname) - if contact != nil { - conversation := contact.Conversation() - conversation.Receive(uint64(messageID), time.Now().Unix(), message) + // XXX error case + if pc.Contact == nil { + pc.Conn.Close() } - oc.AckChatMessage(channelID, messageID) + // XXX cache? + conversation := pc.Contact.Conversation() + conversation.Receive(uint64(messageID), time.Now().Unix(), message) + + pc.Conn.AckChatMessage(channelID, messageID) } -func (handler *protocolHandler) OnChatMessageAck(oc *protocol.OpenConnection, channelID int32, messageID int32) { + +func (pc *ProtocolConnection) OnChatMessageAck(channelID int32, messageID int32) { // XXX no success log.Printf("chat ack: %d %d", channelID, messageID) - // XXX Also ugly - contact := handler.p.core.Identity.ContactList().ContactByAddress("ricochet:" + oc.OtherHostname) - if contact != nil { - conversation := contact.Conversation() - conversation.UpdateSentStatus(uint64(messageID), true) + // XXX error case + if pc.Contact == nil { + pc.Conn.Close() } + + conversation := pc.Contact.Conversation() + conversation.UpdateSentStatus(uint64(messageID), true) } // Handle Errors -func (handler *protocolHandler) OnFailedChannelOpen(oc *protocol.OpenConnection, channelID int32, errorType string) { +func (pc *ProtocolConnection) OnFailedChannelOpen(channelID int32, errorType string) { log.Printf("failed channel open: %d %s", channelID, errorType) - oc.UnsetChannel(channelID) + pc.Conn.UnsetChannel(channelID) } -func (handler *protocolHandler) OnGenericError(oc *protocol.OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "GenericError") +func (pc *ProtocolConnection) OnGenericError(channelID int32) { + pc.Conn.RejectOpenChannel(channelID, "GenericError") } -func (handler *protocolHandler) OnUnknownTypeError(oc *protocol.OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "UnknownTypeError") +func (pc *ProtocolConnection) OnUnknownTypeError(channelID int32) { + pc.Conn.RejectOpenChannel(channelID, "UnknownTypeError") } -func (handler *protocolHandler) OnUnauthorizedError(oc *protocol.OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "UnauthorizedError") +func (pc *ProtocolConnection) OnUnauthorizedError(channelID int32) { + pc.Conn.RejectOpenChannel(channelID, "UnauthorizedError") } -func (handler *protocolHandler) OnBadUsageError(oc *protocol.OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "BadUsageError") +func (pc *ProtocolConnection) OnBadUsageError(channelID int32) { + pc.Conn.RejectOpenChannel(channelID, "BadUsageError") } -func (handler *protocolHandler) OnFailedError(oc *protocol.OpenConnection, channelID int32) { - oc.RejectOpenChannel(channelID, "FailedError") +func (pc *ProtocolConnection) OnFailedError(channelID int32) { + pc.Conn.RejectOpenChannel(channelID, "FailedError") } diff --git a/core/ricochet.go b/core/ricochet.go index d9ec46b..839e54c 100644 --- a/core/ricochet.go +++ b/core/ricochet.go @@ -3,7 +3,6 @@ package core type Ricochet struct { Config *Config Network *Network - Protocol *Protocol Identity *Identity } @@ -11,7 +10,6 @@ func (core *Ricochet) Init(conf *Config) error { var err error core.Config = conf core.Network = CreateNetwork() - core.Protocol = CreateProtocol(core) core.Identity, err = CreateIdentity(core) if err != nil { return err