go-ricochet/ricochet.go

99 lines
2.9 KiB
Go

package goricochet
import (
"github.com/s-rah/go-ricochet/utils"
"github.com/s-rah/go-ricochet/connection"
"io"
"net"
"log"
)
// Open obtains a connection to remoteHostname via utils.NetworkResolver.Resolve
// and then runs the client side of the Ricochet version negotiation on it.
// It blocks until negotiation has completed and returns the resulting
// *connection.Connection, created as an outbound connection.
//
// The hostname handed to the new connection is the one returned by Resolve,
// which may differ from the argument. If Resolve fails, its error is returned
// unchanged. If version negotiation fails, the underlying connection is closed
// and the negotiation error is returned.
func Open(remoteHostname string) (*connection.Connection, error) {
	networkResolver := utils.NetworkResolver{}
	log.Printf("Connecting...")
	conn, resolvedHostname, err := networkResolver.Resolve(remoteHostname)
	if err != nil {
		return nil, err
	}
	log.Printf("Connected...negotiating version")
	rc, err := negotiateVersion(conn, resolvedHostname)
	if err != nil {
		// Best-effort cleanup so the socket is not leaked; the negotiation
		// error is the one reported, so a Close error is deliberately ignored.
		_ = conn.Close()
		return nil, err
	}
	log.Printf("Connected...negotiated version")
	return rc, nil
}
// negotiateVersion runs the client side of the Ricochet version negotiation on
// an already-open conn. It writes the bytes 0x49 0x4D ("IM"), a version-list
// length of 1, and the single advertised version 0x01. It then reads exactly
// one reply byte, which is the version selected by the peer.
//
// A failed or short write, or a failed read, yields utils.VersionNegotiationError.
// A reply other than 0x01 yields utils.VersionNegotiationFailed. On success,
// conn is wrapped as an outbound connection to remoteHostname. conn is never
// closed here; on error the caller owns it.
func negotiateVersion(conn net.Conn, remoteHostname string) (*connection.Connection, error) {
	request := []byte{0x49, 0x4D, 0x01, 0x01}
	written, err := conn.Write(request)
	if err != nil || written < len(request) {
		return nil, utils.VersionNegotiationError
	}

	var reply [1]byte
	if _, err = io.ReadFull(conn, reply[:]); err != nil {
		return nil, utils.VersionNegotiationError
	}

	if reply[0] == 0x01 {
		return connection.NewOutboundConnection(conn, remoteHostname), nil
	}
	return nil, utils.VersionNegotiationFailed
}
// NegotiateVersionInbound performs the responding side of the Ricochet version
// negotiation on conn. It first reads a three-byte header: the magic bytes
// 0x49 0x4D ("IM") followed by the number of advertised versions. It then
// reads that many version bytes. If 0x01 is among them, 0x01 is written back
// as the selection and conn is wrapped as an inbound connection. Otherwise,
// 0xff is written back and negotiation fails.
//
// Errors:
//   - failure reading the header returns the underlying read error as-is;
//   - a bad magic or a zero version count returns utils.VersionNegotiationError;
//   - failure reading the version list returns utils.VersionNegotiationError;
//   - failure writing the selection, or no supported version, returns
//     utils.VersionNegotiationFailed.
//
// conn is never closed here; on error the caller owns it.
func NegotiateVersionInbound(conn net.Conn) (*connection.Connection, error) {
	const (
		magic0       = 0x49
		magic1       = 0x4D
		version1     = byte(0x01)
		noneSelected = byte(0xff)
	)

	hdr := make([]byte, 3)
	if _, err := io.ReadFull(conn, hdr); err != nil {
		return nil, err
	}
	count := hdr[2]
	validHeader := hdr[0] == magic0 && hdr[1] == magic1 && count >= 1
	if !validHeader {
		return nil, utils.VersionNegotiationError
	}

	offered := make([]byte, count)
	if _, err := io.ReadFull(conn, offered); err != nil {
		return nil, utils.VersionNegotiationError
	}

	choice := noneSelected
	for i := 0; i < len(offered); i++ {
		if offered[i] == version1 {
			choice = version1
			break
		}
	}

	// The selection byte is written even when nothing matched (0xff), so the
	// peer is told that no common version was found.
	if n, err := conn.Write([]byte{choice}); err != nil || n < 1 {
		return nil, utils.VersionNegotiationFailed
	}
	if choice == noneSelected {
		return nil, utils.VersionNegotiationFailed
	}
	return connection.NewInboundConnection(conn), nil
}