go-ricochet/connection/connection.go

322 lines
8.8 KiB
Go

package connection
import (
"errors"
"github.com/golang/protobuf/proto"
"github.com/s-rah/go-ricochet/channels"
"github.com/s-rah/go-ricochet/utils"
"github.com/s-rah/go-ricochet/wire/control"
"io"
"log"
"time"
)
// Connection encapsulates the state required to maintain a connection to
// a ricochet service.
type Connection struct {
utils.RicochetNetwork
channelManager *ChannelManager
// Ricochet Network Loop
packetChannel chan utils.RicochetData
errorChannel chan error
breakChannel chan bool
breakResultChannel chan bool
unlockChannel chan bool
unlockResponseChannel chan bool
messageBuilder utils.MessageBuilder
Conn io.ReadWriteCloser
IsInbound bool
Authentication map[string]bool
RemoteHostname string
}
func (rc *Connection) init() {
rc.packetChannel = make(chan utils.RicochetData)
rc.errorChannel = make(chan error)
rc.breakChannel = make(chan bool)
rc.breakResultChannel = make(chan bool)
rc.unlockChannel = make(chan bool)
rc.unlockResponseChannel = make(chan bool)
rc.Authentication = make(map[string]bool)
go rc.start()
}
// NewInboundConnection creates a new Connection struct
// modelling an Inbound Connection
func NewInboundConnection(conn io.ReadWriteCloser) *Connection {
rc := new(Connection)
rc.Conn = conn
rc.IsInbound = true
rc.init()
rc.channelManager = NewServerChannelManager()
return rc
}
// NewOutboundConnection creates a new Connection struct
// modelling an Inbound Connection
func NewOutboundConnection(conn io.ReadWriteCloser, remoteHostname string) *Connection {
rc := new(Connection)
rc.Conn = conn
rc.IsInbound = false
rc.init()
rc.RemoteHostname = remoteHostname
rc.channelManager = NewClientChannelManager()
return rc
}
// start
func (rc *Connection) start() {
for {
packet, err := rc.RecvRicochetPacket(rc.Conn)
if err != nil {
rc.errorChannel <- err
return
}
rc.packetChannel <- packet
}
}
// Do allows any function utilizing Connection to be run safetly.
// All operations which require access to Connection managed resources should
// use Do()
func (rc *Connection) Do(do func() error) error {
// Force process to soft-break so we can lock
log.Printf("UnLocking Processloop")
rc.unlockChannel <- true
log.Printf("Unlocked Processloop")
ret := do()
log.Printf("Giving up lock Processloop")
rc.unlockResponseChannel <- true
return ret
}
// RequestOpenChannel sends an OpenChannel message to the remote client.
// and error is returned only if the requirements for opening this channel
// are not met on the local side (a nill error return does not mean the
// channel was opened successfully)
func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error {
return rc.Do(func() error {
chandler, err := handler.OnOpenChannelRequest(ctype)
if err != nil {
return err
}
// Check that we have the authentication already
if chandler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[chandler.RequiresAuthentication()]
if !authed {
return errors.New("connection is not auth'd")
}
}
channel, err := rc.channelManager.OpenChannelRequest(chandler)
if err != nil {
return err
}
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.Authentication[chandler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := chandler.OpenOutbound(channel)
if err == nil {
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.channelManager.RemoveChannel(channel.ID)
}
return nil
})
}
// Process receives socket and protocol events for the connection. Methods
// of the application-provided `handler` will be called from this goroutine
// for all events.
//
// Process must be running in order to handle any events on the connection,
// including connection close.
//
// Process blocks until the connection is closed or until Break() is called.
// If the connection is closed, a non-nil error is returned.
func (rc *Connection) Process(handler Handler) error {
log.Printf("Entering Processloop")
handler.OnReady(rc)
breaked := false
for !breaked {
var packet utils.RicochetData
tick := time.Tick(30 * time.Second)
select {
case <-rc.unlockChannel:
<-rc.unlockResponseChannel
continue
case <-rc.breakChannel:
log.Printf("Process has Ended as Expected!!!")
breaked = true
continue
case packet = <-rc.packetChannel:
break
case err := <-rc.errorChannel:
rc.Conn.Close()
handler.OnClosed(err)
return err
case <-tick:
log.Printf("timeout")
return errors.New("peer timed out")
}
log.Printf("Received Packet on Channel %d", packet.Channel)
if packet.Channel == 0 {
res := new(Protocol_Data_Control.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err == nil {
rc.controlPacket(handler, res)
}
} else {
// Let's check to see if we have defined this channel.
channel, found := rc.channelManager.GetChannel(packet.Channel)
if found {
if len(packet.Data) == 0 {
rc.channelManager.RemoveChannel(packet.Channel)
(*channel.Handler).Closed(errors.New("channel closed by peer"))
} else {
// Send The Ricochet Packet to the Handler
(*channel.Handler).Packet(packet.Data[:])
}
} else {
// When a non-zero packet is received for an unknown
// channel, the recipient responds by closing
// that channel.
if len(packet.Data) != 0 {
rc.SendRicochetPacket(rc.Conn, packet.Channel, []byte{})
}
}
}
}
rc.breakResultChannel <- true
return nil
}
func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.Packet) {
if res.GetOpenChannel() != nil {
opm := res.GetOpenChannel()
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
if err != nil {
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError")
rc.SendRicochetPacket(rc.Conn, 0, response)
return
}
// Check that we have the authentication already
if chandler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[chandler.RequiresAuthentication()]
if !authed {
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
return
}
}
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
if err == nil {
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.Authentication[chandler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := chandler.OpenInbound(channel, opm)
if err == nil && channel.Pending == false {
log.Printf("Opening Channel %v on %v", channel.Type, channel.ID)
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.channelManager.RemoveChannel(channel.ID)
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
}
} else {
// Send Error Packet
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError")
rc.SendRicochetPacket(rc.Conn, 0, response)
}
} else if res.GetChannelResult() != nil {
cr := res.GetChannelResult()
id := cr.GetChannelIdentifier()
channel, found := rc.channelManager.GetChannel(id)
if !found {
return
}
if cr.GetOpened() {
(*channel.Handler).OpenOutboundResult(nil, cr)
} else {
(*channel.Handler).OpenOutboundResult(errors.New(""), cr)
}
} else if res.GetKeepAlive() != nil {
// XXX Though not currently part of the protocol
// We should likely put these calls behind
// authentication.
if res.GetKeepAlive().GetResponseRequested() {
messageBuilder := new(utils.MessageBuilder)
raw := messageBuilder.KeepAlive(true)
rc.SendRicochetPacket(rc.Conn, 0, raw)
}
} else if res.GetEnableFeatures() != nil {
// TODO Respond with an Empty List
messageBuilder := new(utils.MessageBuilder)
raw := messageBuilder.FeaturesEnabled([]string{})
rc.SendRicochetPacket(rc.Conn, 0, raw)
} else if res.GetFeaturesEnabled() != nil {
// TODO We should never send out an enabled features
// request.
}
}
// Break causes Process() to return, but does not close the underlying connection
func (rc *Connection) Break() {
log.Printf("breaking...")
rc.breakChannel <- true
<-rc.breakResultChannel // Wait for Process to End
}
// Channel is a convienciance method for returning a given channel to the caller
// of Process() - TODO - this is kind of ugly.
func (rc *Connection) Channel(ctype string, way channels.Direction) *channels.Channel {
return rc.channelManager.Channel(ctype, way)
}