diff --git a/cli/cli.go b/cli/cli.go index 9a7fd21..6d35c06 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -2,21 +2,33 @@ package main import ( "bytes" + "flag" "fmt" "github.com/chzyer/readline" + ricochet "github.com/ricochet-im/ricochet-go/core" rpc "github.com/ricochet-im/ricochet-go/rpc" "google.golang.org/grpc" "log" + "net" "os" + "strings" + "time" ) -const ( - defaultAddress = "127.0.0.1:58281" -) +var ( + LogBuffer bytes.Buffer -var LogBuffer bytes.Buffer + // Flags + backendAddress string + unsafeBackend bool +) func main() { + flag.StringVar(&backendAddress, "backend", "", "Connect to the client backend running on `address`") + flag.BoolVar(&unsafeBackend, "allow-unsafe-backend", false, "Allow a remote backend address. This is NOT RECOMMENDED and may harm your security or privacy. Do not use without a secure, trusted link") + flag.Parse() + + // Set up readline input, err := readline.NewEx(&readline.Config{ InterruptPrompt: "^C", EOFPrompt: "exit", @@ -28,13 +40,15 @@ func main() { defer input.Close() log.SetOutput(&LogBuffer) - conn, err := grpc.Dial(defaultAddress, grpc.WithInsecure()) + // Connect to RPC backend, start in-process backend if necessary + conn, err := connectClientBackend() if err != nil { - fmt.Printf("connection failed: %v\n", err) + fmt.Printf("backend failed: %v\n", err) os.Exit(1) } defer conn.Close() + // Configure client and UI client := &Client{ Backend: rpc.NewRicochetCoreClient(conn), } @@ -44,8 +58,8 @@ func main() { Client: client, } + // Initialize data from backend and start UI command loop fmt.Print("Connecting to backend...\n") - go func() { if err := client.Initialize(); err != nil { fmt.Printf("Error: %s\n", err) @@ -58,3 +72,66 @@ func main() { Ui.CommandLoop() } + +func connectClientBackend() (*grpc.ClientConn, error) { + if backendAddress == "" { + // In-process backend, using 'InnerNet' as a fake socket + address, err := startLocalBackend() + if err != nil { + return nil, err + } + return grpc.Dial(address, grpc.WithInsecure(), grpc.WithDialer(DialInnerNet)) + } else { + // External backend + if strings.HasPrefix(backendAddress, "unix:") { + return grpc.Dial(backendAddress[5:], grpc.WithInsecure(), + grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", address, timeout) + })) + } else { + host, _, err := net.SplitHostPort(backendAddress) + if err != nil { + return nil, err + } + ip := net.ParseIP(host) + if !unsafeBackend && (ip == nil || !ip.IsLoopback()) { + return nil, fmt.Errorf("Host '%s' is not a loopback address.\nRead the warnings and use -allow-unsafe-backend for non-local addresses", host) + } + + return grpc.Dial(backendAddress, grpc.WithInsecure()) + } + } +} + +func startLocalBackend() (string, error) { + config, err := ricochet.LoadConfig(".") + if err != nil { + return "", err + } + + core := new(ricochet.Ricochet) + if err := core.Init(config); err != nil { + return "", err + } + + listener, err := ListenInnerNet("ricochet.rpc") + if err != nil { + return "", err + } + + server := &ricochet.RpcServer{ + Core: core, + } + + go func() { + grpcServer := grpc.NewServer() + rpc.RegisterRicochetCoreServer(grpcServer, server) + err := grpcServer.Serve(listener) + if err != nil { + log.Printf("backend exited: %v", err) + os.Exit(1) + } + }() + + return "ricochet.rpc", nil +} diff --git a/cli/innernet.go b/cli/innernet.go new file mode 100644 index 0000000..00b73de --- /dev/null +++ b/cli/innernet.go @@ -0,0 +1,172 @@ +package main + +import ( + "errors" + "io" + "net" + "sync" + "time" +) + +var ( + listeners map[string]*InnerNetListener = make(map[string]*InnerNetListener) + listenersSync sync.Mutex +) + +// Compatible with net.Listener, but uses io.Pipe for in-process +// communication without any real sockets. +type InnerNetListener struct { + connChannel chan *InnerNetConn + connSync sync.Mutex + addr InnerNetAddr + closed bool +} + +type InnerNetAddr struct { + addr string +} + +type InnerNetConn struct { + readPipe *io.PipeReader + writePipe *io.PipeWriter + localAddr InnerNetAddr + remoteAddr InnerNetAddr + readDeadline time.Time + writeDeadline time.Time +} + +func ListenInnerNet(addr string) (*InnerNetListener, error) { + listener := &InnerNetListener{ + connChannel: make(chan *InnerNetConn), + addr: InnerNetAddr{addr: addr}, + } + + listenersSync.Lock() + defer listenersSync.Unlock() + if _, exists := listeners[addr]; exists { + return nil, errors.New("Server already exists") + } + listeners[addr] = listener + + return listener, nil +} + +func DialInnerNet(addr string, timeout time.Duration) (net.Conn, error) { + listenersSync.Lock() + listener := listeners[addr] + listenersSync.Unlock() + if listener == nil { + return nil, errors.New("Server does not exist") + } + + var returnErr error + result := make(chan *InnerNetConn) + + go func() { + listener.connSync.Lock() + defer listener.connSync.Unlock() + if listener.closed { + returnErr = errors.New("Connection refused") + result <- nil + return + } + + clientRead, serverWrite := io.Pipe() + serverRead, clientWrite := io.Pipe() + client := &InnerNetConn{ + readPipe: clientRead, + writePipe: clientWrite, + remoteAddr: InnerNetAddr{addr: addr}, + } + server := &InnerNetConn{ + readPipe: serverRead, + writePipe: serverWrite, + localAddr: InnerNetAddr{addr: addr}, + } + + listener.connChannel <- server + result <- client + }() + + select { + case re := <-result: + return re, returnErr + case <-time.After(timeout): + return nil, errors.New("Connection timeout") + } +} + +func (l *InnerNetListener) Accept() (net.Conn, error) { + if l.closed { + return nil, errors.New("Closed") + } + + conn, ok := <-l.connChannel + if !ok { + return nil, errors.New("Closed") + } + return conn, nil +} + +func (l *InnerNetListener) Close() error { + if l.closed { + return nil + } + + listenersSync.Lock() + delete(listeners, l.addr.addr) + listenersSync.Unlock() + + l.connSync.Lock() + l.closed = true + close(l.connChannel) + l.connSync.Unlock() + + return nil +} + +func (l *InnerNetListener) Addr() net.Addr { + return l.addr +} + +func (a InnerNetAddr) Network() string { + return "innernet" +} + +func (a InnerNetAddr) String() string { + return a.addr +} + +func (c *InnerNetConn) Read(b []byte) (int, error) { + return c.readPipe.Read(b) +} + +func (c *InnerNetConn) Write(b []byte) (int, error) { + return c.writePipe.Write(b) +} + +func (c *InnerNetConn) Close() error { + c.writePipe.Close() + c.readPipe.Close() + return nil +} + +func (c *InnerNetConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *InnerNetConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *InnerNetConn) SetDeadline(t time.Time) error { + return errors.New("Not implemented") +} + +func (c *InnerNetConn) SetReadDeadline(t time.Time) error { + return errors.New("Not implemented") +} + +func (c *InnerNetConn) SetWriteDeadline(t time.Time) error { + return errors.New("Not implemented") +}