go-ricochet/connection/connection.go

354 lines
11 KiB
Go

package connection
import (
"errors"
"github.com/golang/protobuf/proto"
"github.com/s-rah/go-ricochet/channels"
"github.com/s-rah/go-ricochet/utils"
"github.com/s-rah/go-ricochet/wire/control"
"io"
"log"
"time"
"fmt"
)
// Connection bundles the state needed to drive one ricochet connection: the
// underlying transport, the channel manager, and the unexported channels that
// the reader goroutine, Process, Do and Break use to coordinate.
type Connection struct {
	utils.RicochetNetwork

	// channelManager is assigned by the constructors after init().
	channelManager *ChannelManager

	// Ricochet network loop: start() feeds these, Process drains them.
	packetChannel chan utils.RicochetData
	errorChannel  chan error

	// Break() signals on breakChannel and waits on breakResultChannel.
	breakChannel, breakResultChannel chan bool

	// Do() signals on unlockChannel before running its callback and on
	// unlockResponseChannel once it is finished.
	unlockChannel, unlockResponseChannel chan bool

	messageBuilder utils.MessageBuilder

	// trace gates traceLog output; toggled through TraceLog.
	trace bool

	// Conn is the transport packets are read from and written to.
	Conn io.ReadWriteCloser

	// IsInbound is true for connections built by NewInboundConnection.
	IsInbound bool

	// Authentication records, by type name, the authorizations that channel
	// handlers have delegated to this connection.
	Authentication map[string]bool

	// RemoteHostname is only populated by NewOutboundConnection.
	RemoteHostname string
}
// init allocates the authentication map and the (unbuffered) coordination
// channels, then launches the reader goroutine. rc.Conn must already be set,
// because start() begins reading from it immediately.
func (rc *Connection) init() {
	rc.Authentication = make(map[string]bool)

	// Reader goroutine -> Process.
	rc.packetChannel = make(chan utils.RicochetData)
	rc.errorChannel = make(chan error)

	// Break() <-> Process handshake.
	rc.breakChannel, rc.breakResultChannel = make(chan bool), make(chan bool)

	// Do() <-> Process handshake.
	rc.unlockChannel, rc.unlockResponseChannel = make(chan bool), make(chan bool)

	go rc.start()
}
// NewInboundConnection returns a Connection wrapping conn that is flagged as
// inbound and uses a server-side channel manager. The reader goroutine is
// already running when the Connection is returned.
func NewInboundConnection(conn io.ReadWriteCloser) *Connection {
	rc := &Connection{
		Conn:      conn,
		IsInbound: true,
	}
	rc.init()
	rc.channelManager = NewServerChannelManager()
	return rc
}
// NewOutboundConnection returns a Connection wrapping conn that models an
// outbound connection to remoteHostname and uses a client-side channel
// manager. The reader goroutine is already running when the Connection is
// returned.
func NewOutboundConnection(conn io.ReadWriteCloser, remoteHostname string) *Connection {
	rc := &Connection{
		Conn:           conn,
		IsInbound:      false,
		RemoteHostname: remoteHostname,
	}
	rc.init()
	rc.channelManager = NewClientChannelManager()
	return rc
}
// TraceLog turns trace logging for this connection on (true) or off (false).
// While enabled, internal trace messages are written through the standard
// log package.
func (rc *Connection) TraceLog(enabled bool) {
	rc.trace = enabled
}
// start is the reader goroutine launched by init. It reads packets from
// rc.Conn and forwards each one on packetChannel. The first read error is
// forwarded on errorChannel, after which the goroutine exits. Both sends are
// on unbuffered channels, so they block until Process receives them.
func (rc *Connection) start() {
	for {
		packet, err := rc.RecvRicochetPacket(rc.Conn)
		if err == nil {
			rc.packetChannel <- packet
			continue
		}
		rc.errorChannel <- err
		return
	}
}
// Do runs do while the Process loop is parked, so that do can safely touch
// state owned by the Connection. All operations which require access to
// Connection managed resources should go through Do.
//
// Do blocks until a running Process loop accepts the unlock request; it will
// therefore block indefinitely if Process is not running, or if it is called
// from the goroutine that is executing Process.
//
// The value returned by do is returned unchanged.
func (rc *Connection) Do(do func() error) error {
	// Force the process loop to soft-break so we effectively hold a lock.
	rc.traceLog("request unlocking of process loop for do()")
	rc.unlockChannel <- true
	rc.traceLog("process loop is unlocked for do()")

	// Release via defer: if do panics, the process loop would otherwise stay
	// parked on unlockResponseChannel forever.
	defer func() {
		rc.traceLog("giving up lock process loop after do() ")
		rc.unlockResponseChannel <- true
	}()

	return do()
}
// RequestOpenChannel asks the remote peer to open a channel of type ctype.
//
// handler.OnOpenChannelRequest supplies the channel handler for ctype. The
// request is executed inside Do, so a Process loop must be running.
//
// An error is returned only when the request could not be issued from the
// local side: the handler rejected the type, the connection lacks the
// authentication the channel type requires, the channel manager could not
// allocate a channel, or the channel handler failed to build the open
// request. A nil error means the request was sent; it does not mean the peer
// accepted it (the peer's answer arrives later as a channel result).
func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error {
	rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
	return rc.Do(func() error {
		chandler, err := handler.OnOpenChannelRequest(ctype)
		if err != nil {
			rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err))
			return err
		}

		// Enforce the authentication requirement. The stored value is
		// consulted (not just key presence) so an entry explicitly set to
		// false is never treated as authenticated.
		if required := chandler.RequiresAuthentication(); required != "none" && !rc.Authentication[required] {
			return errors.New("connection is not auth'd")
		}

		channel, err := rc.channelManager.OpenChannelRequest(chandler)
		if err != nil {
			rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err))
			return err
		}

		// Wire the channel's callbacks back to this connection.
		channel.SendMessage = func(message []byte) {
			rc.SendRicochetPacket(rc.Conn, channel.ID, message)
		}
		channel.DelegateAuthorization = func() {
			rc.Authentication[chandler.Type()] = true
		}
		channel.CloseChannel = func() {
			// An empty packet on the channel tells the peer it is closed.
			rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
			rc.channelManager.RemoveChannel(channel.ID)
		}

		response, err := chandler.OpenOutbound(channel)
		if err != nil {
			// Nothing was sent to the peer: release the channel and report
			// the failure instead of silently returning success.
			rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err))
			rc.channelManager.RemoveChannel(channel.ID)
			return err
		}

		rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
		rc.SendRicochetPacket(rc.Conn, 0, response)
		return nil
	})
}
// Process receives socket and protocol events for the connection. Methods
// of the application-provided `handler` will be called from this goroutine
// for all events.
//
// Process must be running in order to handle any events on the connection,
// including connection close. It is also what services Do() and Break().
//
// Process blocks until one of the following happens:
//   - Break() is called: Process acknowledges the break and returns nil.
//   - The reader reports an error: the connection is closed,
//     handler.OnClosed is invoked with the error, and the error is returned.
//   - No event arrives for 30 seconds: a "peer timed out" error is returned.
//     In this case the connection is NOT closed and OnClosed is NOT called.
func (rc *Connection) Process(handler Handler) error {
	// peerTimeout is the longest the loop waits for any single event.
	const peerTimeout = 30 * time.Second

	rc.traceLog("entering process loop")
	handler.OnReady(rc)

	// A single reusable timer replaces a per-iteration time.Tick, whose
	// ticker can never be stopped and so accumulated with every event.
	idle := time.NewTimer(peerTimeout)
	defer idle.Stop()

	breaked := false
	for !breaked {
		// Re-arm the idle timer for this iteration, draining a stale
		// expiry first so Reset starts from a clean state.
		if !idle.Stop() {
			select {
			case <-idle.C:
			default:
			}
		}
		idle.Reset(peerTimeout)

		var packet utils.RicochetData
		select {
		case <-rc.unlockChannel:
			// Do() holds the loop until it signals completion.
			<-rc.unlockResponseChannel
			continue
		case <-rc.breakChannel:
			rc.traceLog("process has ended after break")
			breaked = true
			continue
		case packet = <-rc.packetChannel:
			// Fall out of the select and dispatch the packet below.
		case err := <-rc.errorChannel:
			rc.Conn.Close()
			handler.OnClosed(err)
			return err
		case <-idle.C:
			rc.traceLog("peer timed out")
			return errors.New("peer timed out")
		}

		// Channel 0 carries control messages.
		if packet.Channel == 0 {
			rc.traceLog(fmt.Sprintf("received control packet on channel %d", packet.Channel))
			res := new(Protocol_Data_Control.Packet)
			if err := proto.Unmarshal(packet.Data[:], res); err != nil {
				// Malformed control packets are ignored, but leave a trace.
				rc.traceLog(fmt.Sprintf("failed to unmarshal control packet: %v", err))
				continue
			}
			rc.controlPacket(handler, res)
			continue
		}

		// Let's check to see if we have defined this channel.
		channel, found := rc.channelManager.GetChannel(packet.Channel)
		switch {
		case !found:
			// When a non-zero packet is received for an unknown
			// channel, the recipient responds by closing
			// that channel.
			rc.traceLog(fmt.Sprintf("received packet on unknown channel %d. closing.", packet.Channel))
			if len(packet.Data) != 0 {
				rc.SendRicochetPacket(rc.Conn, packet.Channel, []byte{})
			}
		case len(packet.Data) == 0:
			// An empty packet means the peer closed the channel.
			rc.traceLog(fmt.Sprintf("removing channel %d", packet.Channel))
			rc.channelManager.RemoveChannel(packet.Channel)
			(*channel.Handler).Closed(errors.New("channel closed by peer"))
		default:
			rc.traceLog(fmt.Sprintf("received packet on %v channel %d", (*channel.Handler).Type(), packet.Channel))
			// Send The Ricochet Packet to the Handler
			(*channel.Handler).Packet(packet.Data[:])
		}
	}

	// Unblock the Break() call that ended the loop.
	rc.breakResultChannel <- true
	return nil
}
// controlPacket dispatches a decoded control-channel (channel 0) packet.
//
// It handles, in order of precedence: OpenChannel requests from the peer,
// ChannelResult replies to our own open requests, KeepAlive, EnableFeatures
// and FeaturesEnabled. Packets carrying none of these are ignored. It is
// called from Process, on the Process goroutine.
func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.Packet) {
	switch {
	case res.GetOpenChannel() != nil:
		opm := res.GetOpenChannel()
		chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
		if err != nil {
			response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError")
			rc.SendRicochetPacket(rc.Conn, 0, response)
			return
		}

		// Check that we have the authentication already
		if required := chandler.RequiresAuthentication(); required != "none" {
			rc.traceLog(fmt.Sprintf("channel %v requires authorization of type %v", chandler.Type(), required))
			// Enforce Authentication Check. The stored value is consulted
			// (not just key presence) so an entry explicitly set to false
			// is never treated as authenticated.
			if !rc.Authentication[required] {
				rc.SendRicochetPacket(rc.Conn, 0, []byte{})
				rc.traceLog(fmt.Sprintf("do not have required authorization to open channel type %v", chandler.Type()))
				return
			}
			rc.traceLog("succeeded authorization check")
		}

		channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
		if err != nil {
			// Send Error Packet
			response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError")
			rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
			rc.SendRicochetPacket(rc.Conn, 0, response)
			return
		}

		// Wire the channel's callbacks back to this connection.
		channel.SendMessage = func(message []byte) {
			rc.SendRicochetPacket(rc.Conn, channel.ID, message)
		}
		channel.DelegateAuthorization = func() {
			rc.Authentication[chandler.Type()] = true
		}
		channel.CloseChannel = func() {
			// An empty packet on the channel tells the peer it is closed.
			rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
			rc.channelManager.RemoveChannel(channel.ID)
		}

		response, err := chandler.OpenInbound(channel, opm)
		if err == nil && !channel.Pending {
			rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
			rc.SendRicochetPacket(rc.Conn, 0, response)
		} else {
			rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
			rc.channelManager.RemoveChannel(channel.ID)
			rc.SendRicochetPacket(rc.Conn, 0, []byte{})
		}

	case res.GetChannelResult() != nil:
		cr := res.GetChannelResult()
		id := cr.GetChannelIdentifier()
		channel, found := rc.channelManager.GetChannel(id)
		if !found {
			// channel is not valid here (presumably nil), so it must not be
			// dereferenced: the peer controls id and could otherwise crash
			// the process with a result for a channel we never requested.
			rc.traceLog(fmt.Sprintf("channel result recived for unknown channel: %v", id))
			return
		}

		if cr.GetOpened() {
			rc.traceLog(fmt.Sprintf("channel of type %v opened on %v", channel.Type, id))
			(*channel.Handler).OpenOutboundResult(nil, cr)
		} else {
			rc.traceLog(fmt.Sprintf("channel of type %v rejected on %v", channel.Type, id))
			// NOTE(review): the error text is empty; handlers can only
			// distinguish rejection by err != nil and by inspecting cr.
			(*channel.Handler).OpenOutboundResult(errors.New(""), cr)
		}

	case res.GetKeepAlive() != nil:
		// XXX Though not currently part of the protocol
		// We should likely put these calls behind
		// authentication.
		rc.traceLog("received keep alive packet")
		if res.GetKeepAlive().GetResponseRequested() {
			messageBuilder := new(utils.MessageBuilder)
			raw := messageBuilder.KeepAlive(true)
			rc.traceLog("sending keep alive response")
			rc.SendRicochetPacket(rc.Conn, 0, raw)
		}

	case res.GetEnableFeatures() != nil:
		// No optional features are supported: always answer with an empty set.
		rc.traceLog("received features enabled packet")
		messageBuilder := new(utils.MessageBuilder)
		raw := messageBuilder.FeaturesEnabled([]string{})
		rc.traceLog("sending featured enabled empty response")
		rc.SendRicochetPacket(rc.Conn, 0, raw)

	case res.GetFeaturesEnabled() != nil:
		// TODO We should never send out an enabled features
		// request.
		rc.traceLog("sending unsolicited features enabled response")
	}
}
// traceLog writes message to the standard logger when tracing is enabled
// (see TraceLog) and is a no-op otherwise.
//
// message is logged verbatim. It must not be used as a format string: trace
// messages are pre-formatted by callers and may contain peer-supplied text,
// so a literal '%' has to survive unchanged.
func (rc *Connection) traceLog(message string) {
	if !rc.trace {
		return
	}
	log.Print(message)
}
// Break asks a running Process() loop to return and waits until it has done
// so. The underlying connection is left open. Break blocks until Process
// picks up the request, so it must only be called while Process is running.
func (rc *Connection) Break() {
	rc.traceLog("breaking out of process loop")

	// Request the break...
	rc.breakChannel <- true

	// ...and wait for Process to confirm that it has left its loop.
	<-rc.breakResultChannel
}
// Channel is a convenience accessor that hands the caller of Process() the
// channel the channel manager reports for the given type and direction.
// TODO - this is kind of ugly.
func (rc *Connection) Channel(ctype string, way channels.Direction) *channels.Channel {
	match := rc.channelManager.Channel(ctype, way)
	return match
}