ricochet-go/vendor/github.com/s-rah/go-ricochet/connection/connection.go

499 lines
16 KiB
Go

package connection
import (
"context"
"errors"
"fmt"
"github.com/golang/protobuf/proto"
"github.com/s-rah/go-ricochet/channels"
"github.com/s-rah/go-ricochet/utils"
"github.com/s-rah/go-ricochet/wire/control"
"io"
"log"
"sync"
)
// Connection encapsulates the state required to maintain a connection to
// a ricochet service.
//
// Instances are built by NewInboundConnection or NewOutboundConnection, which
// allocate the internal channels and start the background packet reader. The
// unexported channels below are the only means by which the reader goroutine,
// Process, Do/DoContext and Break coordinate with each other.
type Connection struct {
// NOTE(review): presumably supplies the RecvRicochetPacket and
// SendRicochetPacket methods called throughout this file - confirm in utils.
utils.RicochetNetwork
// channelManager tracks the channels opened on this connection. It is set
// to a server or client manager by the constructors.
channelManager *ChannelManager
// Ricochet Network Loop
// packetChannel carries each packet read by start() to Process.
packetChannel chan utils.RicochetData
// errorChannel carries the read error that terminates start() to Process.
errorChannel chan error
// breakChannel is signalled by Break to ask Process to return.
breakChannel chan bool
// breakResultChannel carries Process's answer to Break: nil on a normal
// break, ConnectionClosedError if the connection is being closed.
breakResultChannel chan error
// unlockChannel is signalled by Do/DoContext to pause the Process loop.
unlockChannel chan bool
// unlockResponseChannel is signalled by Do/DoContext once their callback
// has finished, letting the Process loop resume.
unlockResponseChannel chan bool
// messageBuilder is used by controlPacket to build RejectOpenChannel replies.
messageBuilder utils.MessageBuilder
// trace enables output from traceLog; it is set through TraceLog.
trace bool
// closed is set by Process, while holding processBlockMutex, once the
// connection has failed. Do, DoContext, Break and Process refuse to run
// when it is true.
closed bool
// closing is set by Process as soon as a connection error is received,
// before closed can be set. Do and DoContext check it after they have
// paused the Process loop.
closing bool
// This mutex is exclusively for preventing races during blocking
// interactions with Process; specifically Do and Break. Don't use
// it for anything else. See those functions for an explanation.
processBlockMutex sync.Mutex
// Conn is the underlying stream. start() reads packets from it and Process
// closes it when a connection error occurs.
Conn io.ReadWriteCloser
// IsInbound is true for connections built by NewInboundConnection and false
// for those built by NewOutboundConnection.
IsInbound bool
// Authentication records granted authorizations, keyed by the string a
// channel handler reports from RequiresAuthentication()/Type(). Only the
// presence of a key is checked when a channel is opened.
Authentication map[string]bool
// RemoteHostname is the hostname passed to NewOutboundConnection; it is
// left empty for inbound connections.
RemoteHostname string
}
// init allocates the Authentication map and the unbuffered channels used to
// coordinate with the Process loop, then launches the background packet
// reader. It is called exactly once by each constructor, after Conn is set.
func (rc *Connection) init() {
	rc.Authentication = make(map[string]bool)

	// Reader goroutine -> Process.
	rc.packetChannel = make(chan utils.RicochetData)
	rc.errorChannel = make(chan error)

	// Break <-> Process.
	rc.breakChannel = make(chan bool)
	rc.breakResultChannel = make(chan error)

	// Do/DoContext <-> Process.
	rc.unlockChannel = make(chan bool)
	rc.unlockResponseChannel = make(chan bool)

	// Everything the reader touches is ready; start it last.
	go rc.start()
}
// NewInboundConnection creates a new Connection struct modelling an inbound
// connection over conn. The returned Connection has IsInbound set, uses a
// server channel manager, and already has its background packet reader
// running against conn.
func NewInboundConnection(conn io.ReadWriteCloser) *Connection {
	rc := &Connection{
		Conn:      conn,
		IsInbound: true,
	}
	rc.init()
	rc.channelManager = NewServerChannelManager()
	return rc
}
// NewOutboundConnection creates a new Connection struct modelling an outbound
// connection over conn to remoteHostname. The returned Connection has
// IsInbound cleared, uses a client channel manager, and already has its
// background packet reader running against conn.
func NewOutboundConnection(conn io.ReadWriteCloser, remoteHostname string) *Connection {
	rc := &Connection{
		Conn:           conn,
		IsInbound:      false,
		RemoteHostname: remoteHostname,
	}
	rc.init()
	rc.channelManager = NewClientChannelManager()
	return rc
}
// TraceLog turns trace logging for this connection on (enabled == true) or
// off. While on, internal events are written through the standard logger.
func (rc *Connection) TraceLog(enabled bool) {
	rc.trace = enabled
}
// start is the background reader loop launched by init. It reads packets
// from rc.Conn and hands each one to Process over packetChannel. The first
// read error is delivered on errorChannel, after which the loop exits.
// Both sends block until Process (or another receiver) takes the value.
func (rc *Connection) start() {
	for {
		packet, err := rc.RecvRicochetPacket(rc.Conn)
		if err == nil {
			rc.packetChannel <- packet
			continue
		}
		rc.errorChannel <- err
		return
	}
}
// Do allows any function utilizing Connection to be run safely, if you're
// careful. All operations which require access (directly or indirectly) to
// Connection while Process is running need to use Do. Calls to Do without
// Process running will block unless the connection is closed, which is
// returned as ConnectionClosedError.
//
// Like a mutex, Do cannot be called recursively. This will deadlock. As
// a result, no API in this library that can be reached from the application
// should use Do, with few exceptions. This would make the API impossible
// to use safely in many cases.
//
// Do is safe to call from methods of connection.Handler and channel.Handler
// that are called by Process.
func (rc *Connection) Do(do func() error) error {
	// Do is DoContext with a context that can never be cancelled, so the
	// blocking hand-off to Process behaves exactly like a plain channel send.
	//
	// The hand-off is a complicated little dance that prevents a race when
	// the Process call is returning for a connection error. If we simply
	// checked rc.closed and then tried to send, Process could change
	// rc.closed and stop reading before the send executes, and we would
	// deadlock.
	//
	// To prevent this, every function that blocks on Process acquires
	// processBlockMutex, aborts if rc.closed is true, performs its blocking
	// channel operations, and only then releases the mutex.
	//
	// This works because Process always uses a separate goroutine to acquire
	// processBlockMutex before changing rc.closed, and the mutex guarantees
	// that no blocking channel operation can happen during or after that
	// change. Since these operations block the Process loop anyway, multiple
	// concurrent calls to Do/Break behave as before: they just queue on the
	// mutex before they queue on the channel.
	//
	// Process also sets rc.closing while it is trying to acquire the mutex
	// to close the connection down; a call that gets through in that window
	// behaves as if the connection was already closed.
	return rc.DoContext(context.Background(), func(context.Context) error {
		return do()
	})
}
// DoContext behaves in the same way as Do, but also respects the provided
// context when blocked, and passes the context to the callback function.
//
// DoContext should be used when any call to Do may need to be cancelled
// or timed out.
//
// It returns ConnectionClosedError if the connection is closed or closing,
// ctx.Err() if ctx is done before the Process loop yields, and otherwise
// whatever the callback returns. Note that waiting for processBlockMutex
// itself is not interrupted by ctx.
func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) error) error {
	// Hold processBlockMutex for the whole exchange with Process so that
	// rc.closed cannot flip between the check below and the channel send;
	// Do documents the full reasoning.
	rc.processBlockMutex.Lock()
	defer rc.processBlockMutex.Unlock()

	if rc.closed {
		return utils.ConnectionClosedError
	}

	// Force process to soft-break so we can lock, unless ctx ends first.
	rc.traceLog("request unlocking of process loop for do()")
	select {
	case <-ctx.Done():
		rc.traceLog("giving up on unlocking process loop for do() because context cancelled")
		return ctx.Err()
	case rc.unlockChannel <- true:
	}
	rc.traceLog("process loop is unlocked for do()")

	// From here on Process is parked waiting for our response, so it must
	// be released on every return path.
	defer func() {
		rc.traceLog("giving up lock process loop after do() ")
		rc.unlockResponseChannel <- true
	}()

	// Process is in the middle of shutting down; act as if already closed.
	if rc.closing {
		return utils.ConnectionClosedError
	}

	return do(ctx)
}
// RequestOpenChannel sends an OpenChannel message to the remote client.
// An error is returned only if the requirements for opening this channel
// are not met on the local side (a nil error return does not mean the
// channel was opened successfully, because channels open asynchronously).
//
// ctype is used for trace logging only; the channel is created from handler.
//
// Local failures are:
//   - utils.UnauthorizedActionError when handler requires an authentication
//     that has not been recorded in rc.Authentication;
//   - the error from the channel manager when the channel cannot be allocated;
//   - the error from handler.OpenOutbound when the open request cannot be
//     built. In that case the channel is removed again and nothing is sent.
//
// On any error the returned channel is nil.
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
	rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))

	// Check that we have the authentication already. Only the presence of
	// the key is enforced.
	if required := handler.RequiresAuthentication(); required != "none" {
		if _, authed := rc.Authentication[required]; !authed {
			return nil, utils.UnauthorizedActionError
		}
	}

	channel, err := rc.channelManager.OpenChannelRequest(handler)
	if err != nil {
		rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
		return nil, err
	}

	// Wire the channel back to this connection.
	channel.SendMessage = func(message []byte) {
		rc.SendRicochetPacket(rc.Conn, channel.ID, message)
	}
	channel.DelegateAuthorization = func() {
		rc.Authentication[handler.Type()] = true
	}
	channel.CloseChannel = func() {
		// An empty packet on a channel tells the peer it is closed.
		rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
		rc.channelManager.RemoveChannel(channel.ID)
	}

	response, err := handler.OpenOutbound(channel)
	if err != nil {
		// The request was never sent, so forget the channel and report the
		// failure instead of handing back a channel that no longer exists.
		rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
		rc.channelManager.RemoveChannel(channel.ID)
		return nil, err
	}

	rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
	rc.SendRicochetPacket(rc.Conn, 0, response)
	return channel, nil
}
// processUserCallback should be used to wrap any calls into handlers or
// application code from the Process goroutine. It handles calls to Do
// from within that code to prevent deadlocks.
//
// cb runs on its own goroutine while the caller keeps servicing unlock
// requests from Do/DoContext; the function returns once cb has finished.
func (rc *Connection) processUserCallback(cb func()) {
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		cb()
	}()

	for {
		select {
		case <-rc.unlockChannel:
			// A Do call is running (possibly from inside cb); wait until it
			// hands control back.
			<-rc.unlockResponseChannel
		case <-finished:
			return
		}
	}
}
// Process receives socket and protocol events for the connection and
// dispatches them to the application-provided `handler` and to channel
// handlers, one event at a time.
//
// Handler methods other than OnClosed are invoked through
// processUserCallback, so they may call Do; OnClosed is invoked directly.
//
// Process must be running in order to handle any events on the connection,
// including connection close.
//
// Process blocks until the connection is closed or until Break() is called.
// If the connection is closed, a non-nil error is returned (the read error
// that ended the connection, or ConnectionClosedError if the connection was
// already closed on entry). After Break, nil is returned.
func (rc *Connection) Process(handler Handler) error {
if rc.closed {
return utils.ConnectionClosedError
}
rc.traceLog("entering process loop")
rc.processUserCallback(func() { handler.OnReady(rc) })
// There are exactly two ways out of this loop: a signal on breakChannel
// caused by a call to Break, or a connection-fatal error on errorChannel.
//
// In the Break case, no particular care is necessary; it is the caller's
// responsibility to make sure there aren't e.g. concurrent calls to Do.
//
// Because connection errors can happen spontaneously, they must carefully
// prevent concurrent calls to Break or Do that could deadlock when Process
// returns.
for {
var packet utils.RicochetData
select {
// A Do/DoContext call wants exclusive access: park here until it
// signals that its callback has finished.
case <-rc.unlockChannel:
<-rc.unlockResponseChannel
continue
// Break was called: acknowledge it and leave the connection open.
case <-rc.breakChannel:
rc.traceLog("process has ended after break")
rc.breakResultChannel <- nil
return nil
// A packet arrived from the reader goroutine; it is handled below the
// select.
case packet = <-rc.packetChannel:
break
// The reader goroutine failed: close the connection down.
case err := <-rc.errorChannel:
rc.Conn.Close()
rc.closing = true
// In order to safely close down concurrent calls to Do or Break,
// processBlockMutex must be held before setting rc.closed. That cannot
// happen in this goroutine, because one of those calls may already hold
// the mutex and be blocking on a channel send to this method. So the
// process here is to have a goroutine acquire the lock, set rc.closed, and
// signal back. Meanwhile, this one keeps handling unlockChannel and
// breakChannel.
closedChan := make(chan struct{})
go func() {
rc.processBlockMutex.Lock()
defer rc.processBlockMutex.Unlock()
rc.closed = true
close(closedChan)
}()
// Keep accepting calls from Do or Break until closedChan signals that they're
// safely shut down. Do callers see rc.closing and return
// ConnectionClosedError; Break callers are told the connection is closed.
clearLoop:
for {
select {
case <-rc.unlockChannel:
<-rc.unlockResponseChannel
case <-rc.breakChannel:
rc.breakResultChannel <- utils.ConnectionClosedError
case <-closedChan:
break clearLoop
}
}
// This is the one case where processUserCallback isn't necessary, because
// all calls to Do immediately return ConnectionClosedError now.
handler.OnClosed(err)
return err
}
// Channel 0 is the control channel; every other channel belongs to a
// channel handler.
if packet.Channel == 0 {
rc.traceLog(fmt.Sprintf("received control packet on channel %d", packet.Channel))
res := new(Protocol_Data_Control.Packet)
err := proto.Unmarshal(packet.Data[:], res)
// NOTE(review): a control packet that fails to unmarshal is dropped
// without any log output or reply - confirm this is intended.
if err == nil {
// Wrap controlPacket in processUserCallback, since it calls out in many
// places, and wrapping the rest is harmless.
rc.processUserCallback(func() { rc.controlPacket(handler, res) })
}
} else {
// Let's check to see if we have defined this channel.
channel, found := rc.channelManager.GetChannel(packet.Channel)
if found {
// An empty packet on a known channel means the peer closed it.
if len(packet.Data) == 0 {
rc.traceLog(fmt.Sprintf("removing channel %d", packet.Channel))
rc.channelManager.RemoveChannel(packet.Channel)
rc.processUserCallback(func() { channel.Handler.Closed(utils.ChannelClosedByPeerError) })
} else {
rc.traceLog(fmt.Sprintf("received packet on %v channel %d", channel.Handler.Type(), packet.Channel))
// Send The Ricochet Packet to the Handler
rc.processUserCallback(func() { channel.Handler.Packet(packet.Data[:]) })
}
} else {
// When a non-zero packet is received for an unknown
// channel, the recipient responds by closing
// that channel. An empty packet (itself a close) for an
// unknown channel gets no reply.
rc.traceLog(fmt.Sprintf("received packet on unknown channel %d. closing.", packet.Channel))
if len(packet.Data) != 0 {
rc.SendRicochetPacket(rc.Conn, packet.Channel, []byte{})
}
}
}
}
}
// controlPacket handles one decoded control-channel (channel 0) packet.
//
// Exactly one of the following is acted on, in this order of precedence:
//   - OpenChannel: asks handler for a channel handler of the requested type,
//     enforces that handler's authentication requirement, registers the
//     channel and replies with the handler's response, or rejects/aborts.
//   - ChannelResult: forwards the peer's answer to the handler of the
//     outbound channel it refers to; results for unknown channels are ignored.
//   - KeepAlive: replies with a KeepAlive if a response was requested.
//   - EnableFeatures: replies with an empty FeaturesEnabled list.
//   - FeaturesEnabled: only traced.
//
// Packets carrying none of these are ignored. It is called by Process
// through processUserCallback, so handlers invoked here may call Do.
func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.Packet) {
	switch {
	case res.GetOpenChannel() != nil:
		opm := res.GetOpenChannel()

		chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
		if err != nil {
			response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError")
			rc.SendRicochetPacket(rc.Conn, 0, response)
			return
		}

		// Check that we have the authentication already. Only the presence
		// of the key is enforced.
		if required := chandler.RequiresAuthentication(); required != "none" {
			rc.traceLog(fmt.Sprintf("channel %v requires authorization of type %v", chandler.Type(), required))
			if _, authed := rc.Authentication[required]; !authed {
				// An empty packet on channel 0 is sent instead of a reject
				// message for unauthorized requests.
				rc.SendRicochetPacket(rc.Conn, 0, []byte{})
				rc.traceLog(fmt.Sprintf("do not have required authorization to open channel type %v", chandler.Type()))
				return
			}
			rc.traceLog("succeeded authorization check")
		}

		channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
		if err != nil {
			// Send Error Packet
			response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError")
			rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
			rc.SendRicochetPacket(rc.Conn, 0, response)
			return
		}

		// Wire the channel back to this connection.
		channel.SendMessage = func(message []byte) {
			rc.SendRicochetPacket(rc.Conn, channel.ID, message)
		}
		channel.DelegateAuthorization = func() {
			rc.Authentication[chandler.Type()] = true
		}
		channel.CloseChannel = func() {
			rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
			rc.channelManager.RemoveChannel(channel.ID)
		}

		response, err := chandler.OpenInbound(channel, opm)
		if err == nil && !channel.Pending {
			rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
			rc.SendRicochetPacket(rc.Conn, 0, response)
			return
		}

		// The handler failed or left the channel pending: drop it.
		rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
		rc.channelManager.RemoveChannel(channel.ID)
		rc.SendRicochetPacket(rc.Conn, 0, []byte{})

	case res.GetChannelResult() != nil:
		cr := res.GetChannelResult()
		id := cr.GetChannelIdentifier()

		channel, found := rc.channelManager.GetChannel(id)
		if !found {
			// channel is not valid here and must not be dereferenced; a peer
			// can send a result for any identifier, so only the id is logged.
			rc.traceLog(fmt.Sprintf("channel result recived for unknown channel: %v", id))
			return
		}

		if cr.GetOpened() {
			rc.traceLog(fmt.Sprintf("channel of type %v opened on %v", channel.Type, id))
			channel.Handler.OpenOutboundResult(nil, cr)
		} else {
			rc.traceLog(fmt.Sprintf("channel of type %v rejected on %v", channel.Type, id))
			channel.Handler.OpenOutboundResult(errors.New(""), cr)
		}

	case res.GetKeepAlive() != nil:
		// XXX Though not currently part of the protocol
		// We should likely put these calls behind
		// authentication.
		rc.traceLog("received keep alive packet")
		if res.GetKeepAlive().GetResponseRequested() {
			messageBuilder := new(utils.MessageBuilder)
			raw := messageBuilder.KeepAlive(true)
			rc.traceLog("sending keep alive response")
			rc.SendRicochetPacket(rc.Conn, 0, raw)
		}

	case res.GetEnableFeatures() != nil:
		rc.traceLog("received features enabled packet")
		messageBuilder := new(utils.MessageBuilder)
		raw := messageBuilder.FeaturesEnabled([]string{})
		rc.traceLog("sending featured enabled empty response")
		rc.SendRicochetPacket(rc.Conn, 0, raw)

	case res.GetFeaturesEnabled() != nil:
		// TODO We should never send out an enabled features
		// request.
		rc.traceLog("sending unsolicited features enabled response")
	}
}
// traceLog writes message to the standard logger when tracing has been
// enabled through TraceLog; otherwise it does nothing.
//
// message is logged verbatim. It is already fully formatted by callers and
// may contain '%' characters (for example from error text), so it must not
// be used as a format string.
func (rc *Connection) traceLog(message string) {
	if !rc.trace {
		return
	}
	log.Print(message)
}
// Break causes Process() to return, but does not close the underlying connection
// Break returns an error if it would not be valid to call Process() again for
// the connection now. Currently, the only such error is ConnectionClosedError.
//
// Break blocks until a running Process acknowledges the request.
func (rc *Connection) Break() error {
	// See Do() for an explanation of the concurrency here; it's complicated.
	// The summary is that this mutex prevents races on connection close that
	// could lead to deadlocks in Break().
	rc.processBlockMutex.Lock()
	defer rc.processBlockMutex.Unlock()

	if rc.closed {
		rc.traceLog("ignoring break because connection is already closed")
		return utils.ConnectionClosedError
	}

	rc.traceLog("breaking out of process loop")
	rc.breakChannel <- true

	// Wait for Process to end; it reports nil for a normal break, or
	// ConnectionClosedError if it was already shutting down.
	result := <-rc.breakResultChannel
	return result
}
// Channel is a convenience method for returning a given channel to the caller
// of Process() - TODO - this is kind of ugly.
//
// The lookup by channel type and direction is delegated to the connection's
// channel manager, and its result is returned unchanged.
func (rc *Connection) Channel(ctype string, way channels.Direction) *channels.Channel {
	channel := rc.channelManager.Channel(ctype, way)
	return channel
}