diff --git a/.travis.yml b/.travis.yml index f996bec..3c2946f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,6 +16,6 @@ install: script: - - cd $TRAVIS_BUILD_DIR && ./tests.sh + - cd $TRAVIS_BUILD_DIR && ./testing/tests.sh - test -z "$GOFMT" - goveralls -coverprofile=./coverage.out -service travis-ci diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json new file mode 100644 index 0000000..1ca7ee2 --- /dev/null +++ b/Godeps/Godeps.json @@ -0,0 +1,15 @@ +{ + "ImportPath": "github.com/s-rah/go-ricochet", + "GoVersion": "go1.7", + "GodepVersion": "v79", + "Deps": [ + { + "ImportPath": "github.com/golang/protobuf/proto", + "Rev": "8ee79997227bf9b34611aee7946ae64735e6fd93" + }, + { + "ImportPath": "golang.org/x/net/proxy", + "Rev": "60c41d1de8da134c05b7b40154a9a82bf5b7edb9" + } + ] +} diff --git a/Godeps/Readme b/Godeps/Readme new file mode 100644 index 0000000..4cdaa53 --- /dev/null +++ b/Godeps/Readme @@ -0,0 +1,5 @@ +This directory tree is generated automatically by godep. + +Please do not edit. + +See https://github.com/tools/godep for more information. diff --git a/LICENSE b/LICENSE index 51326fa..53fed65 100644 --- a/LICENSE +++ b/LICENSE @@ -25,8 +25,8 @@ SOFTWARE. -------------------------------------------------------------------------------- -Autogenerated protobuf code was generated using the proto file from Ricochet. -They are covered under the following license. +Autogenerated protobuf code was generated using the proto file from Ricochet. +They are covered under the following license. Ricochet - https://ricochet.im/ Copyright (C) 2014, John Brooks @@ -61,10 +61,4 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- -The go-ricochet logo is based on an image by Olga Shalakhina - who in turn modified the original gopher images made by -Renee French. The image is licensed under Creative Commons 3.0 Attributions. - --------------------------------------------------------------------------------- - go-ricochet is not affiliated with or endorsed by Ricochet.im or the Tor Project. diff --git a/README.md b/README.md index bcdea29..8fd2b2f 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # GoRicochet [![Build Status](https://travis-ci.org/s-rah/go-ricochet.svg?branch=master)](https://travis-ci.org/s-rah/go-ricochet) [![Go Report Card](https://goreportcard.com/badge/github.com/s-rah/go-ricochet)](https://goreportcard.com/report/github.com/s-rah/go-ricochet) [![Coverage Status](https://coveralls.io/repos/github/s-rah/go-ricochet/badge.svg?branch=master)](https://coveralls.io/github/s-rah/go-ricochet?branch=master) -![GoRicochet](logo.png) - GoRicochet is an experimental implementation of the [Ricochet Protocol](https://ricochet.im) in Go. diff --git a/application/application.go b/application/application.go new file mode 100644 index 0000000..04ebef8 --- /dev/null +++ b/application/application.go @@ -0,0 +1,36 @@ +package application + +import ( + "errors" + "github.com/s-rah/go-ricochet/channels" + "github.com/s-rah/go-ricochet/connection" +) + +// RicochetApplication bundles many useful constructs that are +// likely standard in a ricochet application +type RicochetApplication struct { + connection *Connection +} + +// NewRicochetApplication ... +func NewRicochetApplication(connection *Connection) *RicochetApplication { + ra := new(RicochetApplication) + ra.connection = connection + return ra +} + +// SendMessage ... +func (ra *RicochetApplication) SendChatMessage(message []string) error { + return ra.connection.Do(func() error { + channel := ra.connection.Channel("im.ricochet.chat", channels.Outbound) + if channel != nil { + chatchannel, ok := (*channel.Handler).(*channels.ChatChannel) + if ok { + chatchannel.SendMessage(message) + } + } else { + return errors.New("") + } + return nil + }) +} diff --git a/authhandler.go b/authhandler.go deleted file mode 100644 index f2890ef..0000000 --- a/authhandler.go +++ /dev/null @@ -1,57 +0,0 @@ -package goricochet - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "io" -) - -// AuthenticationHandler manages the state required for the AuthHiddenService -// authentication scheme for ricochet. -type AuthenticationHandler struct { - clientCookie [16]byte - serverCookie [16]byte -} - -// AddClientCookie adds a client cookie to the state. -func (ah *AuthenticationHandler) AddClientCookie(cookie []byte) { - copy(ah.clientCookie[:], cookie[:16]) -} - -// AddServerCookie adds a server cookie to the state. -func (ah *AuthenticationHandler) AddServerCookie(cookie []byte) { - copy(ah.serverCookie[:], cookie[:16]) -} - -// GenRandom generates a random 16byte cookie string. -func (ah *AuthenticationHandler) GenRandom() [16]byte { - var cookie [16]byte - io.ReadFull(rand.Reader, cookie[:]) - return cookie -} - -// GenClientCookie generates and adds a client cookie to the state. -func (ah *AuthenticationHandler) GenClientCookie() [16]byte { - ah.clientCookie = ah.GenRandom() - return ah.clientCookie -} - -// GenServerCookie generates and adds a server cookie to the state. -func (ah *AuthenticationHandler) GenServerCookie() [16]byte { - ah.serverCookie = ah.GenRandom() - return ah.serverCookie -} - -// GenChallenge constructs the challenge parameter for the AuthHiddenService session. -// The challenge is the a Sha256HMAC(clientHostname+serverHostname, key=clientCookie+serverCookie) -func (ah *AuthenticationHandler) GenChallenge(clientHostname string, serverHostname string) []byte { - key := make([]byte, 32) - copy(key[0:16], ah.clientCookie[:]) - copy(key[16:], ah.serverCookie[:]) - value := []byte(clientHostname + serverHostname) - mac := hmac.New(sha256.New, key) - mac.Write(value) - hmac := mac.Sum(nil) - return hmac -} diff --git a/authhandler_test.go b/authhandler_test.go deleted file mode 100644 index bfcdca1..0000000 --- a/authhandler_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package goricochet - -import "testing" -import "bytes" - -func TestGenChallenge(t *testing.T) { - authHandler := new(AuthenticationHandler) - authHandler.AddClientCookie([]byte("abcdefghijklmnop")) - authHandler.AddServerCookie([]byte("qrstuvwxyz012345")) - challenge := authHandler.GenChallenge("test.onion", "notareal.onion") - expectedChallenge := []byte{0xf5, 0xdb, 0xfd, 0xf0, 0x3d, 0x94, 0x14, 0xf1, 0x4b, 0x37, 0x93, 0xe2, 0xa5, 0x11, 0x4a, 0x98, 0x31, 0x90, 0xea, 0xb8, 0x95, 0x7a, 0x2e, 0xaa, 0xd0, 0xd2, 0x0c, 0x74, 0x95, 0xba, 0xab, 0x73} - t.Log(challenge, expectedChallenge) - if bytes.Compare(challenge[:], expectedChallenge[:]) != 0 { - t.Errorf("AuthenticationHandler Challenge Is Invalid, Got %x, Expected %x", challenge, expectedChallenge) - } -} - -func TestGenClientCookie(t *testing.T) { - authHandler := new(AuthenticationHandler) - clientCookie := authHandler.GenClientCookie() - if clientCookie != authHandler.clientCookie { - t.Errorf("AuthenticationHandler Client Cookies are Different %x %x", clientCookie, authHandler.clientCookie) - } -} - -func TestGenServerCookie(t *testing.T) { - authHandler := new(AuthenticationHandler) - serverCookie := authHandler.GenServerCookie() - if serverCookie != authHandler.serverCookie { - t.Errorf("AuthenticationHandler Server Cookies are Different %x %x", serverCookie, authHandler.serverCookie) - } -} diff --git a/channels/channel.go b/channels/channel.go new file mode 100644 index 0000000..cc263c4 --- /dev/null +++ b/channels/channel.go @@ -0,0 +1,34 @@ +package channels + +// Direction indicated whether we or the remote peer opened the channel +type Direction int + +const ( + // Inbound indcates the channel was opened by the remote peer + Inbound Direction = iota + // Outbound indicated the channel was opened by us + Outbound +) + +// AuthChannelResult captures the result of an authentication flow +type AuthChannelResult struct { + Accepted bool + IsKnownContact bool +} + +// Channel holds the state of a channel on an open connection +type Channel struct { + ID int32 + + Type string + Direction Direction + Handler *Handler + Pending bool + ServerHostname string + ClientHostname string + + // Functions for updating the underlying Connection + SendMessage func([]byte) + CloseChannel func() + DelegateAuthorization func() +} diff --git a/channels/chatchannel.go b/channels/chatchannel.go new file mode 100644 index 0000000..c4e6e19 --- /dev/null +++ b/channels/chatchannel.go @@ -0,0 +1,146 @@ +package channels + +import ( + "crypto/rand" + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/chat" + "github.com/s-rah/go-ricochet/wire/control" + "math" + "math/big" + "time" +) + +// ChatChannel implements the ChannelHandler interface for a channel of +// type "im.ricochet.chat". The channel may be inbound or outbound. +// +// ChatChannel implements protocol-level sanity and state validation, but +// does not handle or acknowledge chat messages. The application must provide +// a ChatChannelHandler implementation to handle chat events. +type ChatChannel struct { + // Methods of Handler are called for chat events on this channel + Handler ChatChannelHandler + channel *Channel + lastMessageID uint32 +} + +// ChatChannelHandler is implemented by an application type to receive +// events from a ChatChannel. +// +// Note that ChatChannelHandler is composable with other interfaces, including +// ConnectionHandler; there is no need to use a distinct type as a +// ChatChannelHandler. +type ChatChannelHandler interface { + // ChatMessage is called when a chat message is received. Return true to acknowledge + // the message successfully, and false to NACK and refuse the message. + ChatMessage(messageID uint32, when time.Time, message string) bool + // ChatMessageAck is called when an acknowledgement of a sent message is received. + ChatMessageAck(messageID uint32) +} + +// SendMessage sends a given message using this channe +func (cc *ChatChannel) SendMessage(message string) { + messageBuilder := new(utils.MessageBuilder) + //TODO Implement Chat Number + data := messageBuilder.ChatMessage(message, cc.lastMessageID) + cc.lastMessageID++ + cc.channel.SendMessage(data) +} + +// Acknowledge indicates the given messageID was received +func (cc *ChatChannel) Acknowledge(messageID uint32) { + messageBuilder := new(utils.MessageBuilder) + cc.channel.SendMessage(messageBuilder.AckChatMessage(messageID)) +} + +// Type returns the type string for this channel, e.g. "im.ricochet.chat". +func (cc *ChatChannel) Type() string { + return "im.ricochet.chat" +} + +// Closed is called when the channel is closed for any reason. +func (cc *ChatChannel) Closed(err error) { + +} + +// OnlyClientCanOpen - for chat channels any side can open +func (cc *ChatChannel) OnlyClientCanOpen() bool { + return false +} + +// Singleton - for chat channels there can only be one instance per direction +func (cc *ChatChannel) Singleton() bool { + return true +} + +// Bidirectional - for chat channels are not bidrectional +func (cc *ChatChannel) Bidirectional() bool { + return false +} + +// RequiresAuthentication - chat channels require hidden service auth +func (cc *ChatChannel) RequiresAuthentication() string { + return "im.ricochet.auth.hidden-service" +} + +// OpenInbound is the first method called for an inbound channel request. +// If an error is returned, the channel is rejected. If a RawMessage is +// returned, it will be sent as the ChannelResult message. +func (cc *ChatChannel) OpenInbound(channel *Channel, raw *Protocol_Data_Control.OpenChannel) ([]byte, error) { + cc.channel = channel + id, err := rand.Int(rand.Reader, big.NewInt(math.MaxUint32)) + if err != nil { + return nil, err + } + cc.lastMessageID = uint32(id.Uint64()) + cc.channel.Pending = false + messageBuilder := new(utils.MessageBuilder) + return messageBuilder.AckOpenChannel(channel.ID), nil +} + +// OpenOutbound is the first method called for an outbound channel request. +// If an error is returned, the channel is not opened. If a RawMessage is +// returned, it will be sent as the OpenChannel message. +func (cc *ChatChannel) OpenOutbound(channel *Channel) ([]byte, error) { + cc.channel = channel + id, err := rand.Int(rand.Reader, big.NewInt(math.MaxUint32)) + if err != nil { + return nil, err + } + cc.lastMessageID = uint32(id.Uint64()) + messageBuilder := new(utils.MessageBuilder) + return messageBuilder.OpenChannel(channel.ID, cc.Type()), nil +} + +// OpenOutboundResult is called when a response is received for an +// outbound OpenChannel request. If `err` is non-nil, the channel was +// rejected and Closed will be called immediately afterwards. `raw` +// contains the raw protocol message including any extension data. +func (cc *ChatChannel) OpenOutboundResult(err error, crm *Protocol_Data_Control.ChannelResult) { + if err == nil { + if crm.GetOpened() { + cc.channel.Pending = false + } + } +} + +// Packet is called for each raw packet received on this channel. +func (cc *ChatChannel) Packet(data []byte) { + if !cc.channel.Pending { + res := new(Protocol_Data_Chat.Packet) + err := proto.Unmarshal(data, res) + if err == nil { + if res.GetChatMessage() != nil { + ack := cc.Handler.ChatMessage(res.GetChatMessage().GetMessageId(), time.Now(), res.GetChatMessage().GetMessageText()) + if ack { + cc.Acknowledge(res.GetChatMessage().GetMessageId()) + } else { + //XXX + } + } else if res.GetChatAcknowledge() != nil { + cc.Handler.ChatMessageAck(res.GetChatMessage().GetMessageId()) + } + // XXX? + } + } +} diff --git a/channels/chatchannel_test.go b/channels/chatchannel_test.go new file mode 100644 index 0000000..b404258 --- /dev/null +++ b/channels/chatchannel_test.go @@ -0,0 +1,124 @@ +package channels + +import ( + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/chat" + "github.com/s-rah/go-ricochet/wire/control" + "testing" + "time" +) + +func TestChatChannelOptions(t *testing.T) { + chatChannel := new(ChatChannel) + + if chatChannel.Type() != "im.ricochet.chat" { + t.Errorf("ChatChannel has wrong type %s", chatChannel.Type()) + } + + if chatChannel.OnlyClientCanOpen() { + t.Errorf("ChatChannel should be able to be opened by everyone") + } + if !chatChannel.Singleton() { + t.Errorf("ChatChannel should be a Singelton") + } + if chatChannel.Bidirectional() { + t.Errorf("ChatChannel should not be bidirectional") + } + if chatChannel.RequiresAuthentication() != "im.ricochet.auth.hidden-service" { + t.Errorf("ChatChannel should require im.ricochet.auth.hidden-service. Instead requires: %s", chatChannel.RequiresAuthentication()) + } +} + +func TestChatChannelOpenInbound(t *testing.T) { + messageBuilder := new(utils.MessageBuilder) + ocm := messageBuilder.OpenChannel(2, "im.ricochet.chat") + + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ocm[:], res) + opm := res.GetOpenChannel() + + chatChannel := new(ChatChannel) + channel := Channel{ID: 1} + response, err := chatChannel.OpenInbound(&channel, opm) + + if err == nil { + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + } else { + t.Errorf("Error while parsing chatchannel openinbound output: %v", err) + } +} + +func TestChatChannelOpenOutbound(t *testing.T) { + chatChannel := new(ChatChannel) + channel := Channel{ID: 1} + response, err := chatChannel.OpenOutbound(&channel) + if err == nil { + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + if res.GetOpenChannel() != nil { + // XXX + } else { + t.Errorf("ChatChannel OpenOutbound was not an OpenChannelRequest %v", err) + } + } else { + t.Errorf("Error while parsing openputput output: %v", err) + } +} + +type TestChatChannelHandler struct { +} + +func (tcch *TestChatChannelHandler) ChatMessage(messageID uint32, when time.Time, message string) bool { + return true +} + +func (tcch *TestChatChannelHandler) ChatMessageAck(messageID uint32) { + +} + +func TestChatChannelOperations(t *testing.T) { + + // We test OpenOutboundElsewhere + chatChannel := new(ChatChannel) + chatChannel.Handler = new(TestChatChannelHandler) + channel := Channel{ID: 5} + channel.SendMessage = func(data []byte) { + res := new(Protocol_Data_Chat.Packet) + err := proto.Unmarshal(data, res) + if res.GetChatMessage() != nil { + if err == nil { + if res.GetChatMessage().GetMessageId() != 0 { + t.Log("Got Message ID:", res.GetChatMessage().GetMessageId()) + return + } + t.Errorf("message id was 0 should be random") + return + } + t.Errorf("error sending chat message: %v", err) + } + } + chatChannel.OpenOutbound(&channel) + + messageBuilder := new(utils.MessageBuilder) + ack := messageBuilder.AckOpenChannel(5) + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ack[:], res) + cr := res.GetChannelResult() + + chatChannel.OpenOutboundResult(nil, cr) + if channel.Pending { + t.Errorf("After Successful Result ChatChannel Is Still Pending") + } + + chat := messageBuilder.ChatMessage("message text", 0) + chatChannel.Packet(chat) + + chatChannel.SendMessage("hello") + +} diff --git a/channels/contactrequestchannel.go b/channels/contactrequestchannel.go new file mode 100644 index 0000000..f501c06 --- /dev/null +++ b/channels/contactrequestchannel.go @@ -0,0 +1,148 @@ +package channels + +import ( + "errors" + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/contact" + "github.com/s-rah/go-ricochet/wire/control" +) + +// ContactRequestChannel implements the ChannelHandler interface for a channel of +// type "im.ricochet.contact.request". The channel may be inbound or outbound. +// a ContactRequestChannelHandler implementation to handle chat events. +type ContactRequestChannel struct { + // Methods of Handler are called for chat events on this channel + Handler ContactRequestChannelHandler + channel *Channel +} + +// ContactRequestChannelHandler is implemented by an application type to receive +// events from a ContactRequestChannel. +// +// Note that ContactRequestChannelHandler is composable with other interfaces, including +// ConnectionHandler; there is no need to use a distinct type as a +// ContactRequestChannelHandler. +type ContactRequestChannelHandler interface { + GetContactDetails() (string, string) + ContactRequest(name string, message string) string + ContactRequestRejected() + ContactRequestAccepted() + ContactRequestError() +} + +// OnlyClientCanOpen - only clients can open contact requests +func (crc *ContactRequestChannel) OnlyClientCanOpen() bool { + return true +} + +// Singleton - only one contact request can be opened per side +func (crc *ContactRequestChannel) Singleton() bool { + return true +} + +// Bidirectional - only clients can send messages +func (crc *ContactRequestChannel) Bidirectional() bool { + return false +} + +// RequiresAuthentication - contact requests require hidden service auth +func (crc *ContactRequestChannel) RequiresAuthentication() string { + return "im.ricochet.auth.hidden-service" +} + +// Type returns the type string for this channel, e.g. "im.ricochet.chat". +func (crc *ContactRequestChannel) Type() string { + return "im.ricochet.contact.request" +} + +// Closed is called when the channel is closed for any reason. +func (crc *ContactRequestChannel) Closed(err error) { + +} + +// OpenInbound is the first method called for an inbound channel request. +// If an error is returned, the channel is rejected. If a RawMessage is +// returned, it will be sent as the ChannelResult message. +func (crc *ContactRequestChannel) OpenInbound(channel *Channel, oc *Protocol_Data_Control.OpenChannel) ([]byte, error) { + crc.channel = channel + contactRequestI, err := proto.GetExtension(oc, Protocol_Data_ContactRequest.E_ContactRequest) + if err == nil { + contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) + if check { + + if len(contactRequest.GetNickname()) > int(Protocol_Data_ContactRequest.Limits_NicknameMaxCharacters) { + // Violation of the Protocol + return nil, errors.New("invalid nickname") + } + + if len(contactRequest.GetMessageText()) > int(Protocol_Data_ContactRequest.Limits_MessageMaxCharacters) { + // Violation of the Protocol + return nil, errors.New("invalid message") + } + + result := crc.Handler.ContactRequest(contactRequest.GetNickname(), contactRequest.GetMessageText()) + messageBuilder := new(utils.MessageBuilder) + return messageBuilder.ReplyToContactRequestOnResponse(channel.ID, result), nil + } + } + return nil, errors.New("could not parse contact request extension") +} + +// OpenOutbound is the first method called for an outbound channel request. +// If an error is returned, the channel is not opened. If a RawMessage is +// returned, it will be sent as the OpenChannel message. +func (crc *ContactRequestChannel) OpenOutbound(channel *Channel) ([]byte, error) { + crc.channel = channel + name, message := crc.Handler.GetContactDetails() + messageBuilder := new(utils.MessageBuilder) + return messageBuilder.OpenContactRequestChannel(channel.ID, name, message), nil +} + +// OpenOutboundResult is called when a response is received for an +// outbound OpenChannel request. If `err` is non-nil, the channel was +// rejected and Closed will be called immediately afterwards. `raw` +// contains the raw protocol message including any extension data. +func (crc *ContactRequestChannel) OpenOutboundResult(err error, crm *Protocol_Data_Control.ChannelResult) { + if err == nil { + if crm.GetOpened() { + responseI, err := proto.GetExtension(crm, Protocol_Data_ContactRequest.E_Response) + if err == nil { + response, check := responseI.(*Protocol_Data_ContactRequest.Response) + if check { + crc.handleStatus(response.GetStatus().String()) + return + } + } + } + } + crc.channel.SendMessage([]byte{}) +} + +func (crc *ContactRequestChannel) handleStatus(status string) { + switch status { + case "Accepted": + crc.Handler.ContactRequestAccepted() + case "Pending": + break + case "Rejected": + crc.Handler.ContactRequestRejected() + break + case "Error": + crc.Handler.ContactRequestError() + break + } +} + +// Packet is called for each raw packet received on this channel. +func (crc *ContactRequestChannel) Packet(data []byte) { + if !crc.channel.Pending { + response := new(Protocol_Data_ContactRequest.Response) + err := proto.Unmarshal(data, response) + if err == nil { + crc.handleStatus(response.GetStatus().String()) + return + } + } + crc.channel.SendMessage([]byte{}) +} diff --git a/channels/contactrequestchannel_test.go b/channels/contactrequestchannel_test.go new file mode 100644 index 0000000..ad0076c --- /dev/null +++ b/channels/contactrequestchannel_test.go @@ -0,0 +1,226 @@ +package channels + +import ( + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/contact" + "github.com/s-rah/go-ricochet/wire/control" + "testing" +) + +type TestContactRequestHandler struct { + Received bool +} + +func (tcrh *TestContactRequestHandler) GetContactDetails() (string, string) { + return "", "" +} + +func (tcrh *TestContactRequestHandler) ContactRequest(name string, message string) string { + if name == "test_nickname" && message == "test_message" { + tcrh.Received = true + } + return "Pending" +} + +func (tcrh *TestContactRequestHandler) ContactRequestRejected() { +} +func (tcrh *TestContactRequestHandler) ContactRequestAccepted() { +} +func (tcrh *TestContactRequestHandler) ContactRequestError() { +} + +func TestContactRequestOptions(t *testing.T) { + contactRequestChannel := new(ContactRequestChannel) + + if contactRequestChannel.Type() != "im.ricochet.contact.request" { + t.Errorf("ContactRequestChannel has wrong type %s", contactRequestChannel.Type()) + } + + if !contactRequestChannel.OnlyClientCanOpen() { + t.Errorf("ContactRequestChannel Should be Client Open Only") + } + if !contactRequestChannel.Singleton() { + t.Errorf("ContactRequestChannel Should be a Singelton") + } + if contactRequestChannel.Bidirectional() { + t.Errorf("ContactRequestChannel Should not be bidirectional") + } + if contactRequestChannel.RequiresAuthentication() != "im.ricochet.auth.hidden-service" { + t.Errorf("ContactRequestChannel should requires im.ricochet.auth.hidden-service Authentication. Instead defines: %s", contactRequestChannel.RequiresAuthentication()) + } +} + +func TestContactRequestOpenOutbound(t *testing.T) { + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + response, err := contactRequestChannel.OpenOutbound(&channel) + if err == nil { + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + if res.GetOpenChannel() != nil { + // XXX + } else { + t.Errorf("ContactReuqest OpenOutbound was not an OpenChannelRequest %v", err) + } + } else { + t.Errorf("Error while parsing openputput output: %v", err) + } +} + +func TestContactRequestOpenOutboundResult(t *testing.T) { + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + contactRequestChannel.OpenOutbound(&channel) + + messageBuilder := new(utils.MessageBuilder) + ack := messageBuilder.ReplyToContactRequestOnResponse(1, "Accepted") + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ack[:], res) + cr := res.GetChannelResult() + + contactRequestChannel.OpenOutboundResult(nil, cr) + +} + +func TestContactRequestOpenInbound(t *testing.T) { + opm := BuildOpenChannel("test_nickname", "test_message") + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + response, err := contactRequestChannel.OpenInbound(&channel, opm) + + if err == nil { + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + + responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) + if err == nil { + response, check := responseI.(*Protocol_Data_ContactRequest.Response) + if check { + if response.GetStatus().String() != "Pending" { + t.Errorf("Contact Request Response should have been Pending, but instead was: %v", response.GetStatus().String()) + } + } else { + t.Errorf("Error while parsing openinbound output: %v", err) + } + } else { + t.Errorf("Error while parsing openinbound output: %v", err) + } + } else { + t.Errorf("Error while parsing openinbound output: %v", err) + } + + if !handler.Received { + t.Errorf("Contact Request was not received by Handler") + } +} + +func TestContactRequestPacket(t *testing.T) { + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + contactRequestChannel.OpenOutbound(&channel) + + messageBuilder := new(utils.MessageBuilder) + ack := messageBuilder.ReplyToContactRequestOnResponse(1, "Pending") + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ack[:], res) + cr := res.GetChannelResult() + + contactRequestChannel.OpenOutboundResult(nil, cr) + + ackp := messageBuilder.ReplyToContactRequest(1, "Accepted") + contactRequestChannel.Packet(ackp) +} + +func TestContactRequestRejected(t *testing.T) { + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + contactRequestChannel.OpenOutbound(&channel) + + messageBuilder := new(utils.MessageBuilder) + ack := messageBuilder.ReplyToContactRequestOnResponse(1, "Pending") + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ack[:], res) + cr := res.GetChannelResult() + + contactRequestChannel.OpenOutboundResult(nil, cr) + + ackp := messageBuilder.ReplyToContactRequest(1, "Rejected") + contactRequestChannel.Packet(ackp) +} + +func TestContactRequestError(t *testing.T) { + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + contactRequestChannel.OpenOutbound(&channel) + + messageBuilder := new(utils.MessageBuilder) + ack := messageBuilder.ReplyToContactRequestOnResponse(1, "Pending") + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ack[:], res) + cr := res.GetChannelResult() + + contactRequestChannel.OpenOutboundResult(nil, cr) + + ackp := messageBuilder.ReplyToContactRequest(1, "Error") + contactRequestChannel.Packet(ackp) +} + +func BuildOpenChannel(nickname string, message string) *Protocol_Data_Control.OpenChannel { + // Construct the Open Authentication Channel Message + messageBuilder := new(utils.MessageBuilder) + ocm := messageBuilder.OpenContactRequestChannel(1, nickname, message) + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ocm[:], res) + return res.GetOpenChannel() +} + +func TestInvalidNickname(t *testing.T) { + opm := BuildOpenChannel("this nickname is far too long at well over the limit of 30 characters", "test_message") + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + _, err := contactRequestChannel.OpenInbound(&channel, opm) + if err == nil { + t.Errorf("Open Inbound should have failed because of invalid nickname") + } +} + +func TestInvalidMessage(t *testing.T) { + var message string + for i := 0; i < 2001; i++ { + message += "a" + } + opm := BuildOpenChannel("test_nickname", message) + contactRequestChannel := new(ContactRequestChannel) + handler := new(TestContactRequestHandler) + contactRequestChannel.Handler = handler + channel := Channel{ID: 1} + _, err := contactRequestChannel.OpenInbound(&channel, opm) + if err == nil { + t.Errorf("Open Inbound should have failed because of invalid message") + } +} diff --git a/channels/handler.go b/channels/handler.go new file mode 100644 index 0000000..ac21033 --- /dev/null +++ b/channels/handler.go @@ -0,0 +1,51 @@ +package channels + +import ( + "github.com/s-rah/go-ricochet/wire/control" +) + +// Handler reacts to low-level events on a protocol channel. There +// should be a unique instance of a ChannelHandler type per channel. +// +// Applications generally don't need to implement ChannelHandler directly; +// instead, use the built-in implementations for common channel types, and +// their individual callback interfaces. ChannelHandler is useful when +// implementing new channel types, or modifying low level default behavior. +type Handler interface { + // Type returns the type string for this channel, e.g. "im.ricochet.chat". + Type() string + + // Closed is called when the channel is closed for any reason. + Closed(err error) + + // OnlyClientCanOpen indicates if only a client can open a given channel + OnlyClientCanOpen() bool + + // Singleton indicates if a channel can only have one instance per direction + Singleton() bool + + // Bidirectional indicates if messages can be send by either side + Bidirectional() bool + + // RequiresAuthentication describes what authentication is needed for the channel + RequiresAuthentication() string + + // OpenInbound is the first method called for an inbound channel request. + // If an error is returned, the channel is rejected. If a RawMessage is + // returned, it will be sent as the ChannelResult message. + OpenInbound(channel *Channel, raw *Protocol_Data_Control.OpenChannel) ([]byte, error) + + // OpenOutbound is the first method called for an outbound channel request. + // If an error is returned, the channel is not opened. If a RawMessage is + // returned, it will be sent as the OpenChannel message. + OpenOutbound(channel *Channel) ([]byte, error) + + // OpenOutboundResult is called when a response is received for an + // outbound OpenChannel request. If `err` is non-nil, the channel was + // rejected and Closed will be called immediately afterwards. `raw` + // contains the raw protocol message including any extension data. + OpenOutboundResult(err error, raw *Protocol_Data_Control.ChannelResult) + + // Packet is called for each raw packet received on this channel. + Packet(data []byte) +} diff --git a/channels/hiddenserviceauthchannel.go b/channels/hiddenserviceauthchannel.go new file mode 100644 index 0000000..f8e007d --- /dev/null +++ b/channels/hiddenserviceauthchannel.go @@ -0,0 +1,253 @@ +package channels + +import ( + "crypto" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/asn1" + "errors" + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/auth" + "github.com/s-rah/go-ricochet/wire/control" + "io" + "log" +) + +// HiddenServiceAuthChannel wraps implementation o fim.ricochet.auth.hidden-service" +type HiddenServiceAuthChannel struct { + // Methods of Handler are called for events on this channel + Handler AuthChannelHandler + // PrivateKey must be set for client-side authentication channels + PrivateKey *rsa.PrivateKey + // Server Hostname must be set for client-side authentication channels + ServerHostname string + + // Internal state + clientCookie, serverCookie [16]byte + channel *Channel +} + +// AuthChannelHandler ... +type AuthChannelHandler interface { + // Client + ClientAuthResult(accepted bool, isKnownContact bool) + + // Server + ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) + ServerAuthInvalid(err error) +} + +// Type returns the type string for this channel, e.g. "im.ricochet.chat". +func (ah *HiddenServiceAuthChannel) Type() string { + return "im.ricochet.auth.hidden-service" +} + +// Singleton Returns whether or not the given channel type is a singleton +func (ah *HiddenServiceAuthChannel) Singleton() bool { + return true +} + +// OnlyClientCanOpen ... +func (ah *HiddenServiceAuthChannel) OnlyClientCanOpen() bool { + return true +} + +// Bidirectional Returns whether or not the given channel allows anyone to send messages +func (ah *HiddenServiceAuthChannel) Bidirectional() bool { + return false +} + +// RequiresAuthentication Returns whether or not the given channel type requires authentication +func (ah *HiddenServiceAuthChannel) RequiresAuthentication() string { + return "none" +} + +// Closed is called when the channel is closed for any reason. +func (ah *HiddenServiceAuthChannel) Closed(err error) { + +} + +// OpenInbound is the first method called for an inbound channel request. +// If an error is returned, the channel is rejected. If a RawMessage is +// returned, it will be sent as the ChannelResult message. +// Remote -> [Open Authentication Channel] -> Local +func (ah *HiddenServiceAuthChannel) OpenInbound(channel *Channel, oc *Protocol_Data_Control.OpenChannel) ([]byte, error) { + ah.channel = channel + clientCookie, _ := proto.GetExtension(oc, Protocol_Data_AuthHiddenService.E_ClientCookie) + if len(clientCookie.([]byte)[:]) != 16 { + // reutrn without opening channel. + return nil, errors.New("invalid client cookie") + } + ah.AddClientCookie(clientCookie.([]byte)[:]) + messageBuilder := new(utils.MessageBuilder) + channel.Pending = false + return messageBuilder.ConfirmAuthChannel(ah.channel.ID, ah.GenServerCookie()), nil +} + +// OpenOutbound is the first method called for an outbound channel request. +// If an error is returned, the channel is not opened. If a RawMessage is +// returned, it will be sent as the OpenChannel message. +// Local -> [Open Authentication Channel] -> Remote +func (ah *HiddenServiceAuthChannel) OpenOutbound(channel *Channel) ([]byte, error) { + ah.channel = channel + messageBuilder := new(utils.MessageBuilder) + return messageBuilder.OpenAuthenticationChannel(ah.channel.ID, ah.GenClientCookie()), nil +} + +// OpenOutboundResult is called when a response is received for an +// outbound OpenChannel request. If `err` is non-nil, the channel was +// rejected and Closed will be called immediately afterwards. `raw` +// contains the raw protocol message including any extension data. +// Input: Remote -> [ChannelResult] -> {Client} +// Output: {Client} -> [Proof] -> Remote +func (ah *HiddenServiceAuthChannel) OpenOutboundResult(err error, crm *Protocol_Data_Control.ChannelResult) { + + if err == nil { + + if crm.GetOpened() { + serverCookie, _ := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) + + if len(serverCookie.([]byte)[:]) != 16 { + ah.channel.SendMessage([]byte{}) + return + } + + ah.AddServerCookie(serverCookie.([]byte)[:]) + + publicKeyBytes, _ := asn1.Marshal(rsa.PublicKey{ + N: ah.PrivateKey.PublicKey.N, + E: ah.PrivateKey.PublicKey.E, + }) + + clientHostname := utils.GetTorHostname(publicKeyBytes) + challenge := ah.GenChallenge(clientHostname, ah.ServerHostname) + + signature, err := rsa.SignPKCS1v15(nil, ah.PrivateKey, crypto.SHA256, challenge) + + if err != nil { + ah.channel.SendMessage([]byte{}) + return + } + + messageBuilder := new(utils.MessageBuilder) + proof := messageBuilder.Proof(publicKeyBytes, signature) + ah.channel.SendMessage(proof) + } + } +} + +// Packet is called for each raw packet received on this channel. +// Input: Remote -> [Proof] -> Client +// OR +// Input: Remote -> [Result] -> Client +func (ah *HiddenServiceAuthChannel) Packet(data []byte) { + res := new(Protocol_Data_AuthHiddenService.Packet) + err := proto.Unmarshal(data[:], res) + + if err != nil { + ah.channel.CloseChannel() + return + } + + if res.GetProof() != nil && ah.channel.Direction == Inbound { + provisionalClientHostname := utils.GetTorHostname(res.GetProof().GetPublicKey()) + + publicKeyBytes, err := asn1.Marshal(rsa.PublicKey{ + N: ah.PrivateKey.PublicKey.N, + E: ah.PrivateKey.PublicKey.E, + }) + + if err != nil { + ah.Handler.ServerAuthInvalid(err) + ah.channel.SendMessage([]byte{}) + return + } + + serverHostname := utils.GetTorHostname(publicKeyBytes) + + publicKey := rsa.PublicKey{} + _, err = asn1.Unmarshal(res.GetProof().GetPublicKey(), &publicKey) + if err != nil { + ah.Handler.ServerAuthInvalid(err) + ah.channel.SendMessage([]byte{}) + return + } + + challenge := ah.GenChallenge(provisionalClientHostname, serverHostname) + + err = rsa.VerifyPKCS1v15(&publicKey, crypto.SHA256, challenge[:], res.GetProof().GetSignature()) + + if err == nil { + // Signature is Good + accepted, isKnownContact := ah.Handler.ServerAuthValid(provisionalClientHostname, publicKey) + + // Send Result + messageBuilder := new(utils.MessageBuilder) + result := messageBuilder.AuthResult(accepted, isKnownContact) + ah.channel.DelegateAuthorization() + ah.channel.SendMessage(result) + } else { + // Auth Failed + messageBuilder := new(utils.MessageBuilder) + result := messageBuilder.AuthResult(false, false) + ah.channel.SendMessage(result) + ah.Handler.ServerAuthInvalid(err) + } + + } else if res.GetResult() != nil && ah.channel.Direction == Outbound { + ah.Handler.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + if res.GetResult().GetAccepted() { + ah.channel.DelegateAuthorization() + } + } + + // Any other combination of packets is completely invalid + // Fail the Authorization right here. + ah.channel.CloseChannel() +} + +// AddClientCookie adds a client cookie to the state. +func (ah *HiddenServiceAuthChannel) AddClientCookie(cookie []byte) { + copy(ah.clientCookie[:], cookie[:16]) +} + +// AddServerCookie adds a server cookie to the state. +func (ah *HiddenServiceAuthChannel) AddServerCookie(cookie []byte) { + copy(ah.serverCookie[:], cookie[:16]) +} + +// GenRandom generates a random 16byte cookie string. +func (ah *HiddenServiceAuthChannel) GenRandom() [16]byte { + var cookie [16]byte + io.ReadFull(rand.Reader, cookie[:]) + return cookie +} + +// GenClientCookie generates and adds a client cookie to the state. +func (ah *HiddenServiceAuthChannel) GenClientCookie() [16]byte { + ah.clientCookie = ah.GenRandom() + return ah.clientCookie +} + +// GenServerCookie generates and adds a server cookie to the state. +func (ah *HiddenServiceAuthChannel) GenServerCookie() [16]byte { + ah.serverCookie = ah.GenRandom() + return ah.serverCookie +} + +// GenChallenge constructs the challenge parameter for the AuthHiddenService session. +// The challenge is the a Sha256HMAC(clientHostname+serverHostname, key=clientCookie+serverCookie) +func (ah *HiddenServiceAuthChannel) GenChallenge(clientHostname string, serverHostname string) []byte { + key := make([]byte, 32) + copy(key[0:16], ah.clientCookie[:]) + copy(key[16:], ah.serverCookie[:]) + log.Printf("CHALLENGE: %s %s %v", clientHostname, serverHostname, key) + value := []byte(clientHostname + serverHostname) + mac := hmac.New(sha256.New, key) + mac.Write(value) + hmac := mac.Sum(nil) + return hmac +} diff --git a/channels/hiddenserviceauthchannel_test.go b/channels/hiddenserviceauthchannel_test.go new file mode 100644 index 0000000..75456c9 --- /dev/null +++ b/channels/hiddenserviceauthchannel_test.go @@ -0,0 +1,163 @@ +package channels + +import ( + "bytes" + "crypto/rsa" + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/control" + "testing" +) + +func TestGenChallenge(t *testing.T) { + authHandler := new(HiddenServiceAuthChannel) + authHandler.AddClientCookie([]byte("abcdefghijklmnop")) + authHandler.AddServerCookie([]byte("qrstuvwxyz012345")) + challenge := authHandler.GenChallenge("test.onion", "notareal.onion") + expectedChallenge := []byte{0xf5, 0xdb, 0xfd, 0xf0, 0x3d, 0x94, 0x14, 0xf1, 0x4b, 0x37, 0x93, 0xe2, 0xa5, 0x11, 0x4a, 0x98, 0x31, 0x90, 0xea, 0xb8, 0x95, 0x7a, 0x2e, 0xaa, 0xd0, 0xd2, 0x0c, 0x74, 0x95, 0xba, 0xab, 0x73} + t.Log(challenge, expectedChallenge) + if bytes.Compare(challenge[:], expectedChallenge[:]) != 0 { + t.Errorf("HiddenServiceAuthChannel Challenge Is Invalid, Got %x, Expected %x", challenge, expectedChallenge) + } +} + +func TestGenClientCookie(t *testing.T) { + authHandler := new(HiddenServiceAuthChannel) + clientCookie := authHandler.GenClientCookie() + if clientCookie != authHandler.clientCookie { + t.Errorf("HiddenServiceAuthChannel Client Cookies are Different %x %x", clientCookie, authHandler.clientCookie) + } +} + +func TestGenServerCookie(t *testing.T) { + authHandler := new(HiddenServiceAuthChannel) + serverCookie := authHandler.GenServerCookie() + if serverCookie != authHandler.serverCookie { + t.Errorf("HiddenServiceAuthChannel Server Cookies are Different %x %x", serverCookie, authHandler.serverCookie) + } +} + +func TestHiddenServiceAuthChannelOptions(t *testing.T) { + hiddenServiceAuthChannel := new(HiddenServiceAuthChannel) + + if hiddenServiceAuthChannel.Type() != "im.ricochet.auth.hidden-service" { + t.Errorf("AuthHiddenService has wrong type %s", hiddenServiceAuthChannel.Type()) + } + + if !hiddenServiceAuthChannel.OnlyClientCanOpen() { + t.Errorf("AuthHiddenService Should be Client Open Only") + } + if !hiddenServiceAuthChannel.Singleton() { + t.Errorf("AuthHiddenService Should be a Singelton") + } + if hiddenServiceAuthChannel.Bidirectional() { + t.Errorf("AuthHiddenService Should not be bidirectional") + } + if hiddenServiceAuthChannel.RequiresAuthentication() != "none" { + t.Errorf("AuthHiddenService should require no authorization. Instead requires: %s", hiddenServiceAuthChannel.RequiresAuthentication()) + } +} + +func GetOpenAuthenticationChannelMessage() *Protocol_Data_Control.OpenChannel { + // Construct the Open Authentication Channel Message + messageBuilder := new(utils.MessageBuilder) + ocm := messageBuilder.OpenAuthenticationChannel(1, [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ocm[:], res) + return res.GetOpenChannel() +} + +func TestAuthenticationOpenInbound(t *testing.T) { + + opm := GetOpenAuthenticationChannelMessage() + authHandler := new(HiddenServiceAuthChannel) + channel := Channel{ID: 1} + response, err := authHandler.OpenInbound(&channel, opm) + + if err == nil { + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + + if res.GetChannelResult() == nil || !res.GetChannelResult().GetOpened() { + t.Errorf("Response not a Open Channel Result %v", res) + } + } else { + t.Errorf("HiddenServiceAuthChannel OpenOutbound Failed: %v", err) + } +} + +func TestAuthenticationOpenOutbound(t *testing.T) { + authHandler := new(HiddenServiceAuthChannel) + channel := Channel{ID: 1} + response, err := authHandler.OpenOutbound(&channel) + + if err == nil { + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + + if res.GetOpenChannel() == nil { + t.Errorf("Open Channel Packet not included %v", res) + } + } else { + t.Errorf("HiddenServiceAuthChannel OpenInbound Failed: %v", err) + } + +} + +type SimpleTestAuthHandler struct { +} + +// Client +func (stah *SimpleTestAuthHandler) ClientAuthResult(accepted bool, isKnownContact bool) { + +} + +// Server +func (stah *SimpleTestAuthHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { + return true, true +} + +func (stah *SimpleTestAuthHandler) ServerAuthInvalid(err error) { + +} + +func TestAuthenticationOpenOutboundResult(t *testing.T) { + + privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key") + + authHandlerA := new(HiddenServiceAuthChannel) + authHandlerB := new(HiddenServiceAuthChannel) + simpleTestAuthHandler := new(SimpleTestAuthHandler) + + authHandlerA.ServerHostname = "kwke2hntvyfqm7dr" + authHandlerA.PrivateKey = privateKey + authHandlerA.Handler = simpleTestAuthHandler + channelA := Channel{ID: 1, Direction: Outbound} + channelA.SendMessage = func(message []byte) { + authHandlerB.Packet(message) + } + channelA.DelegateAuthorization = func() {} + channelA.CloseChannel = func() {} + response, _ := authHandlerA.OpenOutbound(&channelA) + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + + authHandlerB.ServerHostname = "kwke2hntvyfqm7dr" + authHandlerB.PrivateKey = privateKey + authHandlerB.Handler = simpleTestAuthHandler + channelB := Channel{ID: 1, Direction: Inbound} + channelB.SendMessage = func(message []byte) { + authHandlerA.Packet(message) + } + channelB.DelegateAuthorization = func() {} + channelB.CloseChannel = func() {} + response, _ = authHandlerB.OpenInbound(&channelB, res.GetOpenChannel()) + res = new(Protocol_Data_Control.Packet) + proto.Unmarshal(response[:], res) + + authHandlerA.OpenOutboundResult(nil, res.GetChannelResult()) + +} diff --git a/connection/autoconnectionhandler.go b/connection/autoconnectionhandler.go new file mode 100644 index 0000000..238d115 --- /dev/null +++ b/connection/autoconnectionhandler.go @@ -0,0 +1,94 @@ +package connection + +import ( + "crypto/rsa" + "github.com/s-rah/go-ricochet/channels" + "github.com/s-rah/go-ricochet/utils" + "log" +) + +// AutoConnectionHandler implements the ConnectionHandler interface on behalf of +// the provided application type by automatically providing support for any +// built-in channel type whose high level interface is implemented by the +// application. For example, if the application's type implements the +// ChatChannelHandler interface, `im.ricochet.chat` will be available to the peer. +// +// The application handler can be any other type. To override or augment any of +// AutoConnectionHandler's behavior (such as adding new channel types, or reacting +// to connection close events), this type can be embedded in the type that it serves. +type AutoConnectionHandler struct { + handlerMap map[string]func() channels.Handler + connection *Connection + authResultChannel chan channels.AuthChannelResult + sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) +} + +// Init ... +func (ach *AutoConnectionHandler) Init(privateKey *rsa.PrivateKey, serverHostname string) { + ach.handlerMap = make(map[string]func() channels.Handler) + ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler { + hsau := new(channels.HiddenServiceAuthChannel) + hsau.PrivateKey = privateKey + hsau.Handler = ach + hsau.ServerHostname = serverHostname + return hsau + }) + ach.authResultChannel = make(chan channels.AuthChannelResult) +} + +// SetServerAuthHandler ... +func (ach *AutoConnectionHandler) SetServerAuthHandler(sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)) { + ach.sach = sach +} + +// OnReady ... +func (ach *AutoConnectionHandler) OnReady(oc *Connection) { + ach.connection = oc +} + +// OnClosed is called when the OpenConnection has closed for any reason. +func (ach *AutoConnectionHandler) OnClosed(err error) { +} + +// WaitForAuthenticationEvent ... +func (ach *AutoConnectionHandler) WaitForAuthenticationEvent() channels.AuthChannelResult { + return <-ach.authResultChannel +} + +// ClientAuthResult ... +func (ach *AutoConnectionHandler) ClientAuthResult(accepted bool, isKnownContact bool) { + log.Printf("Got auth result %v %v", accepted, isKnownContact) + ach.authResultChannel <- channels.AuthChannelResult{Accepted: accepted, IsKnownContact: isKnownContact} +} + +// ServerAuthValid ... +func (ach *AutoConnectionHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { + // Do something + accepted, isKnownContact := ach.sach(hostname, publicKey) + ach.authResultChannel <- channels.AuthChannelResult{Accepted: accepted, IsKnownContact: isKnownContact} + return accepted, isKnownContact +} + +// ServerAuthInvalid ... +func (ach *AutoConnectionHandler) ServerAuthInvalid(err error) { + ach.authResultChannel <- channels.AuthChannelResult{Accepted: false, IsKnownContact: false} +} + +// RegisterChannelHandler ... +func (ach *AutoConnectionHandler) RegisterChannelHandler(ctype string, handler func() channels.Handler) { + _, exists := ach.handlerMap[ctype] + if !exists { + ach.handlerMap[ctype] = handler + } +} + +// OnOpenChannelRequest ... +func (ach *AutoConnectionHandler) OnOpenChannelRequest(ctype string) (channels.Handler, error) { + handler, ok := ach.handlerMap[ctype] + if ok { + h := handler() + log.Printf("Got Channel Handler") + return h, nil + } + return nil, utils.UnknownChannelTypeError +} diff --git a/connection/autoconnectionhandler_test.go b/connection/autoconnectionhandler_test.go new file mode 100644 index 0000000..977a022 --- /dev/null +++ b/connection/autoconnectionhandler_test.go @@ -0,0 +1,36 @@ +package connection + +import ( + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/control" + "testing" +) + +// Test sending valid packets +func TestInit(t *testing.T) { + ach := new(AutoConnectionHandler) + privateKey, err := utils.LoadPrivateKeyFromFile("../testing/private_key") + + ach.Init(privateKey, "") + + // Construct the Open Authentication Channel Message + messageBuilder := new(utils.MessageBuilder) + ocm := messageBuilder.OpenAuthenticationChannel(1, [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + + // We have just constructed this so there is little + // point in doing error checking here in the test + res := new(Protocol_Data_Control.Packet) + proto.Unmarshal(ocm[:], res) + opm := res.GetOpenChannel() + //ocmessage, _ := proto.Marshal(opm) + handler, err := ach.OnOpenChannelRequest(opm.GetChannelType()) + + if err == nil { + if handler.Type() != "im.ricochet.auth.hidden-service" { + t.Errorf("Failed to authentication handler: %v", handler.Type()) + } + } else { + t.Errorf("Failed to build handler: %v", err) + } +} diff --git a/connection/channelmanager.go b/connection/channelmanager.go new file mode 100644 index 0000000..8dfe5a4 --- /dev/null +++ b/connection/channelmanager.go @@ -0,0 +1,114 @@ +package connection + +import ( + "errors" + "github.com/s-rah/go-ricochet/channels" +) + +// ChannelManager encapsulates the logic for server and client side assignment +// and removal of channels. +type ChannelManager struct { + channels map[int32]*channels.Channel + nextFreeChannel int32 + isClient bool +} + +// NewClientChannelManager construsts a new channel manager enforcing behaviour +// of a ricochet client +func NewClientChannelManager() *ChannelManager { + channelManager := new(ChannelManager) + channelManager.channels = make(map[int32]*channels.Channel) + channelManager.nextFreeChannel = 1 + channelManager.isClient = true + return channelManager +} + +// NewServerChannelManager construsts a new channel manager enforcing behaviour +// from a ricochet server +func NewServerChannelManager() *ChannelManager { + channelManager := new(ChannelManager) + channelManager.channels = make(map[int32]*channels.Channel) + channelManager.nextFreeChannel = 2 + channelManager.isClient = false + return channelManager +} + +// OpenChannelRequest constructs a channel type ready for processing given a request +// from the client. +func (cm *ChannelManager) OpenChannelRequest(chandler channels.Handler) (*channels.Channel, error) { + // Some channels only allow us to open one of them per connection + if chandler.Singleton() && cm.Channel(chandler.Type(), channels.Outbound) != nil { + return nil, errors.New("Connection already has channel of type " + chandler.Type()) + } + + channel := new(channels.Channel) + channel.ID = cm.nextFreeChannel + cm.nextFreeChannel += 2 + channel.Type = chandler.Type() + channel.Handler = &chandler + channel.Pending = true + channel.Direction = channels.Outbound + cm.channels[channel.ID] = channel + return channel, nil +} + +// OpenChannelRequestFromPeer constructs a channel type ready for processing given a request +// from the remote peer. +func (cm *ChannelManager) OpenChannelRequestFromPeer(channelID int32, chandler channels.Handler) (*channels.Channel, error) { + if cm.isClient && (channelID%2) != 0 { + // Server is trying to open odd numbered channels + return nil, errors.New("server may only open even numbered channels") + } else if !cm.isClient && (channelID%2) == 0 { + // Server is trying to open odd numbered channels + return nil, errors.New("client may only open odd numbered channels") + } + + _, exists := cm.channels[channelID] + if exists { + return nil, errors.New("channel id is already in use") + } + + // Some channels only allow us to open one of them per connection + if chandler.Singleton() && cm.Channel(chandler.Type(), channels.Inbound) != nil { + return nil, errors.New("Connection already has channel of type " + chandler.Type()) + } + + channel := new(channels.Channel) + channel.ID = channelID + channel.Type = chandler.Type() + channel.Handler = &chandler + + channel.Pending = true + channel.Direction = channels.Inbound + cm.channels[channelID] = channel + return channel, nil +} + +// Channel finds an open or pending `type` channel in the direction `way` (Inbound +// or Outbound), and returns the associated state. Returns nil if no matching channel +// exists or if multiple matching channels exist. +func (cm *ChannelManager) Channel(ctype string, way channels.Direction) *channels.Channel { + var foundChannel *channels.Channel + for _, channel := range cm.channels { + if (*channel.Handler).Type() == ctype && channel.Direction == way { + if foundChannel == nil { + foundChannel = channel + } else { + // we have found multiple channels. + return nil + } + } + } + return foundChannel +} + +// GetChannel finds and returns a given channel if it is found +func (cm *ChannelManager) GetChannel(channelID int32) (*channels.Channel, bool) { + channel, found := cm.channels[channelID] + return channel, found +} + +// RemoveChannel removes a given channel id. +func (cm *ChannelManager) RemoveChannel(channelID int32) { + delete(cm.channels, channelID) +} diff --git a/connection/channelmanager_test.go b/connection/channelmanager_test.go new file mode 100644 index 0000000..d6dbace --- /dev/null +++ b/connection/channelmanager_test.go @@ -0,0 +1,62 @@ +package connection + +import ( + "github.com/s-rah/go-ricochet/channels" + "testing" +) + +func TestClientManagerDuplicateChannel(t *testing.T) { + ccm := NewClientChannelManager() + chatChannel := new(channels.ChatChannel) + _, err := ccm.OpenChannelRequestFromPeer(2, chatChannel) + if err != nil { + t.Errorf("Opening ChatChannel should have succeeded, instead: %v", err) + } + _, err = ccm.OpenChannelRequestFromPeer(2, chatChannel) + if err == nil { + t.Errorf("Opening ChatChannel should have failed") + } + + _, err = ccm.OpenChannelRequestFromPeer(4, chatChannel) + if err == nil { + t.Errorf("Opening ChatChannel should have failed because there should be only 1") + } +} + +func TestClientManagerBadServer(t *testing.T) { + ccm := NewClientChannelManager() + // Servers are not allowed to open odd numbered channels + _, err := ccm.OpenChannelRequestFromPeer(3, nil) + if err == nil { + t.Errorf("OpenChannelRequestFromPeer should have failed") + } +} + +func TestServerManagerBadClient(t *testing.T) { + scm := NewServerChannelManager() + // Clients are not allowed to open even numbered channels + _, err := scm.OpenChannelRequestFromPeer(2, nil) + if err == nil { + t.Errorf("OpenChannelRequestFromPeer should have failed") + } +} + +func TestLocalDuplicate(t *testing.T) { + scm := NewServerChannelManager() + chatChannel := new(channels.ChatChannel) + channel, err := scm.OpenChannelRequest(chatChannel) + if err != nil { + t.Errorf("OpenChannelRequest should not have failed: %v", err) + } + + _, err = scm.OpenChannelRequest(chatChannel) + if err == nil { + t.Errorf("OpenChannelRequest should have failed") + } + + scm.RemoveChannel(channel.ID) + _, err = scm.OpenChannelRequest(chatChannel) + if err != nil { + t.Errorf("OpenChannelRequest should not have failed: %v", err) + } +} diff --git a/connection/connection.go b/connection/connection.go new file mode 100644 index 0000000..398c063 --- /dev/null +++ b/connection/connection.go @@ -0,0 +1,321 @@ +package connection + +import ( + "errors" + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/channels" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/control" + "io" + "log" + "time" +) + +// Connection encapsulates the state required to maintain a connection to +// a ricochet service. +type Connection struct { + utils.RicochetNetwork + + channelManager *ChannelManager + + // Ricochet Network Loop + packetChannel chan utils.RicochetData + errorChannel chan error + + breakChannel chan bool + breakResultChannel chan bool + + unlockChannel chan bool + unlockResponseChannel chan bool + + messageBuilder utils.MessageBuilder + + Conn io.ReadWriteCloser + IsInbound bool + Authentication map[string]bool + RemoteHostname string +} + +func (rc *Connection) init() { + + rc.packetChannel = make(chan utils.RicochetData) + rc.errorChannel = make(chan error) + + rc.breakChannel = make(chan bool) + rc.breakResultChannel = make(chan bool) + + rc.unlockChannel = make(chan bool) + rc.unlockResponseChannel = make(chan bool) + + rc.Authentication = make(map[string]bool) + go rc.start() +} + +// NewInboundConnection creates a new Connection struct +// modelling an Inbound Connection +func NewInboundConnection(conn io.ReadWriteCloser) *Connection { + rc := new(Connection) + rc.Conn = conn + rc.IsInbound = true + rc.init() + rc.channelManager = NewServerChannelManager() + return rc +} + +// NewOutboundConnection creates a new Connection struct +// modelling an Inbound Connection +func NewOutboundConnection(conn io.ReadWriteCloser, remoteHostname string) *Connection { + rc := new(Connection) + rc.Conn = conn + rc.IsInbound = false + rc.init() + rc.RemoteHostname = remoteHostname + rc.channelManager = NewClientChannelManager() + return rc +} + +// start +func (rc *Connection) start() { + for { + packet, err := rc.RecvRicochetPacket(rc.Conn) + if err != nil { + rc.errorChannel <- err + return + } + rc.packetChannel <- packet + } +} + +// Do allows any function utilizing Connection to be run safetly. +// All operations which require access to Connection managed resources should +// use Do() +func (rc *Connection) Do(do func() error) error { + // Force process to soft-break so we can lock + log.Printf("UnLocking Processloop") + rc.unlockChannel <- true + log.Printf("Unlocked Processloop") + ret := do() + log.Printf("Giving up lock Processloop") + rc.unlockResponseChannel <- true + return ret +} + +// RequestOpenChannel sends an OpenChannel message to the remote client. +// and error is returned only if the requirements for opening this channel +// are not met on the local side (a nill error return does not mean the +// channel was opened successfully) +func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error { + return rc.Do(func() error { + chandler, err := handler.OnOpenChannelRequest(ctype) + + if err != nil { + return err + } + + // Check that we have the authentication already + if chandler.RequiresAuthentication() != "none" { + // Enforce Authentication Check. + _, authed := rc.Authentication[chandler.RequiresAuthentication()] + if !authed { + return errors.New("connection is not auth'd") + } + } + + channel, err := rc.channelManager.OpenChannelRequest(chandler) + + if err != nil { + return err + } + + channel.SendMessage = func(message []byte) { + rc.SendRicochetPacket(rc.Conn, channel.ID, message) + } + channel.DelegateAuthorization = func() { + rc.Authentication[chandler.Type()] = true + } + channel.CloseChannel = func() { + rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) + rc.channelManager.RemoveChannel(channel.ID) + } + response, err := chandler.OpenOutbound(channel) + if err == nil { + rc.SendRicochetPacket(rc.Conn, 0, response) + } else { + rc.channelManager.RemoveChannel(channel.ID) + } + return nil + }) +} + +// Process receives socket and protocol events for the connection. Methods +// of the application-provided `handler` will be called from this goroutine +// for all events. +// +// Process must be running in order to handle any events on the connection, +// including connection close. +// +// Process blocks until the connection is closed or until Break() is called. +// If the connection is closed, a non-nil error is returned. +func (rc *Connection) Process(handler Handler) error { + log.Printf("Entering Processloop") + handler.OnReady(rc) + breaked := false + for !breaked { + + var packet utils.RicochetData + tick := time.Tick(30 * time.Second) + select { + case <-rc.unlockChannel: + <-rc.unlockResponseChannel + continue + case <-rc.breakChannel: + log.Printf("Process has Ended as Expected!!!") + breaked = true + continue + case packet = <-rc.packetChannel: + break + case err := <-rc.errorChannel: + rc.Conn.Close() + handler.OnClosed(err) + return err + case <-tick: + log.Printf("timeout") + return errors.New("peer timed out") + } + + log.Printf("Received Packet on Channel %d", packet.Channel) + + if packet.Channel == 0 { + res := new(Protocol_Data_Control.Packet) + err := proto.Unmarshal(packet.Data[:], res) + if err == nil { + rc.controlPacket(handler, res) + } + } else { + // Let's check to see if we have defined this channel. + channel, found := rc.channelManager.GetChannel(packet.Channel) + if found { + if len(packet.Data) == 0 { + rc.channelManager.RemoveChannel(packet.Channel) + (*channel.Handler).Closed(errors.New("channel closed by peer")) + } else { + // Send The Ricochet Packet to the Handler + (*channel.Handler).Packet(packet.Data[:]) + } + } else { + // When a non-zero packet is received for an unknown + // channel, the recipient responds by closing + // that channel. + if len(packet.Data) != 0 { + rc.SendRicochetPacket(rc.Conn, packet.Channel, []byte{}) + } + } + } + } + + rc.breakResultChannel <- true + return nil + +} + +func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.Packet) { + + if res.GetOpenChannel() != nil { + + opm := res.GetOpenChannel() + chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType()) + + if err != nil { + + response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError") + rc.SendRicochetPacket(rc.Conn, 0, response) + return + } + + // Check that we have the authentication already + if chandler.RequiresAuthentication() != "none" { + // Enforce Authentication Check. + _, authed := rc.Authentication[chandler.RequiresAuthentication()] + if !authed { + rc.SendRicochetPacket(rc.Conn, 0, []byte{}) + return + } + } + + channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler) + + if err == nil { + + channel.SendMessage = func(message []byte) { + rc.SendRicochetPacket(rc.Conn, channel.ID, message) + } + channel.DelegateAuthorization = func() { + rc.Authentication[chandler.Type()] = true + } + channel.CloseChannel = func() { + rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) + rc.channelManager.RemoveChannel(channel.ID) + } + + response, err := chandler.OpenInbound(channel, opm) + if err == nil && channel.Pending == false { + log.Printf("Opening Channel %v on %v", channel.Type, channel.ID) + rc.SendRicochetPacket(rc.Conn, 0, response) + } else { + rc.channelManager.RemoveChannel(channel.ID) + rc.SendRicochetPacket(rc.Conn, 0, []byte{}) + } + } else { + // Send Error Packet + response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError") + rc.SendRicochetPacket(rc.Conn, 0, response) + + } + } else if res.GetChannelResult() != nil { + cr := res.GetChannelResult() + id := cr.GetChannelIdentifier() + + channel, found := rc.channelManager.GetChannel(id) + + if !found { + return + } + + if cr.GetOpened() { + (*channel.Handler).OpenOutboundResult(nil, cr) + } else { + (*channel.Handler).OpenOutboundResult(errors.New(""), cr) + } + + } else if res.GetKeepAlive() != nil { + // XXX Though not currently part of the protocol + // We should likely put these calls behind + // authentication. + if res.GetKeepAlive().GetResponseRequested() { + messageBuilder := new(utils.MessageBuilder) + raw := messageBuilder.KeepAlive(true) + rc.SendRicochetPacket(rc.Conn, 0, raw) + } + } else if res.GetEnableFeatures() != nil { + // TODO Respond with an Empty List + messageBuilder := new(utils.MessageBuilder) + raw := messageBuilder.FeaturesEnabled([]string{}) + rc.SendRicochetPacket(rc.Conn, 0, raw) + } else if res.GetFeaturesEnabled() != nil { + // TODO We should never send out an enabled features + // request. + } +} + +// Break causes Process() to return, but does not close the underlying connection +func (rc *Connection) Break() { + log.Printf("breaking...") + rc.breakChannel <- true + <-rc.breakResultChannel // Wait for Process to End +} + +// Channel is a convienciance method for returning a given channel to the caller +// of Process() - TODO - this is kind of ugly. +func (rc *Connection) Channel(ctype string, way channels.Direction) *channels.Channel { + return rc.channelManager.Channel(ctype, way) +} diff --git a/connection/connection_test.go b/connection/connection_test.go new file mode 100644 index 0000000..9b979fd --- /dev/null +++ b/connection/connection_test.go @@ -0,0 +1,88 @@ +package connection + +import ( + "crypto/rsa" + "github.com/s-rah/go-ricochet/utils" + "net" + "testing" + "time" +) + +// Server +func ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { + return true, true +} + +func TestProcessAuthAsServer(t *testing.T) { + + ln, _ := net.Listen("tcp", "127.0.0.1:0") + + go func() { + cconn, _ := net.Dial("tcp", ln.Addr().String()) + + orc := NewOutboundConnection(cconn, "kwke2hntvyfqm7dr") + privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key") + + known, err := HandleOutboundConnection(orc).ProcessAuthAsClient(privateKey) + if err != nil { + t.Errorf("Error while testing ProcessAuthAsClient (in ProcessAuthAsServer) %v", err) + return + } else if !known { + t.Errorf("Client should have been known to the server, instead known was: %v", known) + return + } + }() + + conn, _ := ln.Accept() + privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key") + + rc := NewInboundConnection(conn) + err := HandleInboundConnection(rc).ProcessAuthAsServer(privateKey, ServerAuthValid) + if err != nil { + t.Errorf("Error while testing ProcessAuthAsServer: %v", err) + } +} + +func TestProcessServerAuthFail(t *testing.T) { + + ln, _ := net.Listen("tcp", "127.0.0.1:0") + + go func() { + cconn, _ := net.Dial("tcp", ln.Addr().String()) + + orc := NewOutboundConnection(cconn, "kwke2hntvyfqm7dr") + privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key") + + HandleOutboundConnection(orc).ProcessAuthAsClient(privateKey) + + }() + + conn, _ := ln.Accept() + privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key_auth_fail_test") + + rc := NewInboundConnection(conn) + err := HandleInboundConnection(rc).ProcessAuthAsServer(privateKey, ServerAuthValid) + if err == nil { + t.Errorf("Error while testing ProcessAuthAsServer - should have failed %v", err) + } +} + +func TestProcessAuthTimeout(t *testing.T) { + + ln, _ := net.Listen("tcp", "127.0.0.1:0") + + go func() { + net.Dial("tcp", ln.Addr().String()) + time.Sleep(16 * time.Second) + + }() + + conn, _ := ln.Accept() + privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key") + + rc := NewInboundConnection(conn) + err := HandleInboundConnection(rc).ProcessAuthAsServer(privateKey, ServerAuthValid) + if err != utils.ActionTimedOutError { + t.Errorf("Error while testing TestProcessAuthTimeout - Should have timed out after 15 seconds") + } +} diff --git a/connection/handler.go b/connection/handler.go new file mode 100644 index 0000000..b2fd78b --- /dev/null +++ b/connection/handler.go @@ -0,0 +1,28 @@ +package connection + +import ( + "github.com/s-rah/go-ricochet/channels" +) + +// Handler reacts to low-level events on a protocol connection. +// There should be a unique instance of a ConnectionHandler type per +// OpenConnection. +type Handler interface { + // OnReady is called when the connection begins using this handler. + OnReady(oc *Connection) + + // OnClosed is called when the OpenConnection has closed for any reason. + OnClosed(err error) + + // OpenChannelRequest is called when the peer asks to open a channel of + // `type`. `raw` contains the protocol OpenChannel message including any + // extension data. If this channel type is recognized and allowed by this + // connection in this state, return a type implementing ChannelHandler for + // events related to this channel. Returning an error or nil rejects the + // channel. + // + // Channel type handlers may implement additional state and sanity checks. + // A non-nil return from this function does not guarantee that the channel + // will be opened. + OnOpenChannelRequest(ctype string) (channels.Handler, error) +} diff --git a/connection/inboundconnectionhandler.go b/connection/inboundconnectionhandler.go new file mode 100644 index 0000000..a070131 --- /dev/null +++ b/connection/inboundconnectionhandler.go @@ -0,0 +1,61 @@ +package connection + +import ( + "crypto/rsa" + "github.com/s-rah/go-ricochet/channels" + "github.com/s-rah/go-ricochet/policies" + "github.com/s-rah/go-ricochet/utils" +) + +// InboundConnectionHandler is a convieniance wrapper for handling inbound +// connections +type InboundConnectionHandler struct { + connection *Connection +} + +// HandleInboundConnection returns an InboundConnectionHandler given a connection +func HandleInboundConnection(c *Connection) *InboundConnectionHandler { + ich := new(InboundConnectionHandler) + ich.connection = c + return ich +} + +// ProcessAuthAsServer blocks until authentication has succeeded, failed, or the +// connection is closed. A non-nil error is returned in all cases other than successful +// and accepted authentication. +// +// ProcessAuthAsServer cannot be called at the same time as any other call to a Process +// function. Another Process function must be called after this function successfully +// returns to continue handling connection events. +// +// The acceptCallback function is called after receiving a valid authentication proof +// with the client's authenticated hostname and public key. acceptCallback must return +// true to accept authentication and allow the connection to continue, and also returns a +// boolean indicating whether the contact is known and recognized. Unknown contacts will +// assume they are required to send a contact request before any other activity. +func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.PrivateKey, sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)) error { + + ach := new(AutoConnectionHandler) + ach.Init(privateKey, ich.connection.RemoteHostname) + ach.SetServerAuthHandler(sach) + + var authResult channels.AuthChannelResult + go func() { + authResult = ach.WaitForAuthenticationEvent() + ich.connection.Break() + }() + + policy := policies.UnknownPurposeTimeout + err := policy.ExecuteAction(func() error { + return ich.connection.Process(ach) + }) + + if err == nil { + if authResult.Accepted == true { + return nil + } + return utils.ClientFailedToAuthenticateError + } + + return err +} diff --git a/connection/outboundconnectionhandler.go b/connection/outboundconnectionhandler.go new file mode 100644 index 0000000..2cc8b78 --- /dev/null +++ b/connection/outboundconnectionhandler.go @@ -0,0 +1,62 @@ +package connection + +import ( + "crypto/rsa" + "errors" + "github.com/s-rah/go-ricochet/channels" + "github.com/s-rah/go-ricochet/policies" + "log" +) + +// OutboundConnectionHandler is a convieniance wrapper for handling outbound +// connections +type OutboundConnectionHandler struct { + connection *Connection +} + +// HandleOutboundConnection returns an OutboundConnectionHandler given a connection +func HandleOutboundConnection(c *Connection) *OutboundConnectionHandler { + och := new(OutboundConnectionHandler) + och.connection = c + return och +} + +// ProcessAuthAsClient blocks until authentication has succeeded or failed with the +// provided privateKey, or the connection is closed. A non-nil error is returned in all +// cases other than successful authentication. +// +// ProcessAuthAsClient cannot be called at the same time as any other call to a Porcess +// function. Another Process function must be called after this function successfully +// returns to continue handling connection events. +// +// For successful authentication, the `known` return value indicates whether the peer +// accepts us as a known contact. Unknown contacts will generally need to send a contact +// request before any other activity. +func (och *OutboundConnectionHandler) ProcessAuthAsClient(privateKey *rsa.PrivateKey) (bool, error) { + ach := new(AutoConnectionHandler) + ach.Init(privateKey, och.connection.RemoteHostname) + + var result channels.AuthChannelResult + go func() { + err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", ach) + if err != nil { + return + } + log.Printf("waiting for auth result") + result = ach.WaitForAuthenticationEvent() + log.Printf("received auth result") + och.connection.Break() + }() + + policy := policies.UnknownPurposeTimeout + err := policy.ExecuteAction(func() error { + return och.connection.Process(ach) + }) + + if err == nil { + if result.Accepted == true { + return result.IsKnownContact, nil + } + } + return false, errors.New("authentication was not accepted by the server") +} diff --git a/examples/echobot/main.go b/examples/echobot/main.go index e1f1bf9..d3be6c4 100644 --- a/examples/echobot/main.go +++ b/examples/echobot/main.go @@ -2,49 +2,97 @@ package main import ( "github.com/s-rah/go-ricochet" + "github.com/s-rah/go-ricochet/channels" + "github.com/s-rah/go-ricochet/connection" + "github.com/s-rah/go-ricochet/utils" "log" + "time" ) // EchoBotService is an example service which simply echoes back what a client // sends it. -type EchoBotService struct { - goricochet.StandardRicochetService +type RicochetEchoBot struct { + connection.AutoConnectionHandler + messages chan string } -func (ebs *EchoBotService) OnNewConnection(oc *goricochet.OpenConnection) { - ebs.StandardRicochetService.OnNewConnection(oc) - go oc.Process(&EchoBotConnection{}) +func (echobot *RicochetEchoBot) GetContactDetails() (string, string) { + return "EchoBot", "I LIVE 😈😈!!!!" } -type EchoBotConnection struct { - goricochet.StandardRicochetConnection +func (echobot *RicochetEchoBot) ContactRequest(name string, message string) string { + return "Pending" } -// IsKnownContact is configured to always accept Contact Requests -func (ebc *EchoBotConnection) IsKnownContact(hostname string) bool { +func (echobot *RicochetEchoBot) ContactRequestRejected() { +} +func (echobot *RicochetEchoBot) ContactRequestAccepted() { +} +func (echobot *RicochetEchoBot) ContactRequestError() { +} + +func (echobot *RicochetEchoBot) ChatMessage(messageID uint32, when time.Time, message string) bool { + echobot.messages <- message return true } -// OnContactRequest - we always accept new contact request. -func (ebc *EchoBotConnection) OnContactRequest(channelID int32, nick string, message string) { - ebc.StandardRicochetConnection.OnContactRequest(channelID, nick, message) - ebc.Conn.AckContactRequestOnResponse(channelID, "Accepted") - ebc.Conn.CloseChannel(channelID) +func (echobot *RicochetEchoBot) ChatMessageAck(messageID uint32) { + } -// OnChatMessage we acknowledge the message, grab the message content and send it back - opening -// a new channel if necessary. -func (ebc *EchoBotConnection) OnChatMessage(channelID int32, messageID int32, message string) { - log.Printf("Received Message from %s: %s", ebc.Conn.OtherHostname, message) - ebc.Conn.AckChatMessage(channelID, messageID) - if ebc.Conn.GetChannelType(6) == "none" { - ebc.Conn.OpenChatChannel(6) +func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string) { + + privateKey, _ := utils.LoadPrivateKeyFromFile(privateKeyFile) + echobot.messages = make(chan string) + + echobot.Init(privateKey, hostname) + echobot.RegisterChannelHandler("im.ricochet.contact.request", func() channels.Handler { + contact := new(channels.ContactRequestChannel) + contact.Handler = echobot + return contact + }) + echobot.RegisterChannelHandler("im.ricochet.chat", func() channels.Handler { + chat := new(channels.ChatChannel) + chat.Handler = echobot + return chat + }) + + rc, _ := goricochet.Open(hostname) + known, err := connection.HandleOutboundConnection(rc).ProcessAuthAsClient(privateKey) + if err == nil { + + go rc.Process(echobot) + + if !known { + err := rc.RequestOpenChannel("im.ricochet.contact.request", echobot) + if err != nil { + log.Printf("could not contact %s", err) + } + } + + rc.RequestOpenChannel("im.ricochet.chat", echobot) + for { + message := <-echobot.messages + log.Printf("Received Message: %s", message) + rc.Do(func() error { + log.Printf("Finding Chat Channel") + channel := rc.Channel("im.ricochet.chat", channels.Outbound) + if channel != nil { + log.Printf("Found Chat Channel") + chatchannel, ok := (*channel.Handler).(*channels.ChatChannel) + if ok { + chatchannel.SendMessage(message) + } + } else { + log.Printf("Could not find chat channel") + } + return nil + }) + } } - ebc.Conn.SendMessage(6, message) } func main() { - ricochetService := new(EchoBotService) - ricochetService.Init("./private_key") - ricochetService.Listen(ricochetService, 12345) + echoBot := new(RicochetEchoBot) + echoBot.Connect("private_key", "oqf7z4ot6kuejgam") } diff --git a/handlers.go b/handlers.go deleted file mode 100644 index 083c45f..0000000 --- a/handlers.go +++ /dev/null @@ -1,51 +0,0 @@ -package goricochet - -// ServiceHandler is the interface to handle events for an inbound connection listener -type ServiceHandler interface { - // OnNewConnection is called for inbound connections to the service after protocol - // version negotiation has completed successfully. - OnNewConnection(oc *OpenConnection) - // OnFailedConnection is called for inbound connections to the service which fail - // to successfully complete version negotiation for any reason. - OnFailedConnection(err error) -} - -// ConnectionHandler is the interface to handle events for an open protocol connection, -// whether inbound or outbound. Each OpenConnection will need its own instance of an -// application type implementing ConnectionHandler, which could also be used to store -// application state related to the connection. -type ConnectionHandler interface { - // OnReady is called before OpenConnection.Process() begins from the connection - OnReady(oc *OpenConnection) - // OnDisconnect is called when the connection is closed, just before - // OpenConnection.Process() returns - OnDisconnect() - - // Authentication Management - OnAuthenticationRequest(channelID int32, clientCookie [16]byte) - OnAuthenticationChallenge(channelID int32, serverCookie [16]byte) - OnAuthenticationProof(channelID int32, publicKey []byte, signature []byte) - OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) - - // Contact Management - IsKnownContact(hostname string) bool - OnContactRequest(channelID int32, nick string, message string) - OnContactRequestAck(channelID int32, status string) - - // Managing Channels - OnOpenChannelRequest(channelID int32, channelType string) - OnOpenChannelRequestSuccess(channelID int32) - OnChannelClosed(channelID int32) - - // Chat Messages - OnChatMessage(channelID int32, messageID int32, message string) - OnChatMessageAck(channelID int32, messageID int32) - - // Handle Errors - OnFailedChannelOpen(channelID int32, errorType string) - OnGenericError(channelID int32) - OnUnknownTypeError(channelID int32) - OnUnauthorizedError(channelID int32) - OnBadUsageError(channelID int32) - OnFailedError(channelID int32) -} diff --git a/logo.png b/logo.png deleted file mode 100644 index 16b01c5..0000000 Binary files a/logo.png and /dev/null differ diff --git a/messagebuilder_test.go b/messagebuilder_test.go deleted file mode 100644 index e0083de..0000000 --- a/messagebuilder_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package goricochet - -import "testing" - -func TestOpenChatChannel(t *testing.T) { - messageBuilder := new(MessageBuilder) - _, err := messageBuilder.OpenChannel(1, "im.ricochet.chat") - if err != nil { - t.Errorf("Error building open chat channel message: %s", err) - } - // TODO: More Indepth Test Of Output -} - -func TestOpenContactRequestChannel(t *testing.T) { - messageBuilder := new(MessageBuilder) - _, err := messageBuilder.OpenContactRequestChannel(3, "Nickname", "Message") - if err != nil { - t.Errorf("Error building open contact request channel message: %s", err) - } - // TODO: More Indepth Test Of Output -} - -func TestOpenAuthenticationChannel(t *testing.T) { - messageBuilder := new(MessageBuilder) - _, err := messageBuilder.OpenAuthenticationChannel(1, [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - if err != nil { - t.Errorf("Error building open authentication channel message: %s", err) - } - // TODO: More Indepth Test Of Output -} - -func TestChatMessage(t *testing.T) { - messageBuilder := new(MessageBuilder) - _, err := messageBuilder.ChatMessage("Hello World", 0) - if err != nil { - t.Errorf("Error building chat message: %s", err) - } - // TODO: More Indepth Test Of Output -} diff --git a/openconnection.go b/openconnection.go deleted file mode 100644 index bf6465f..0000000 --- a/openconnection.go +++ /dev/null @@ -1,534 +0,0 @@ -package goricochet - -import ( - "crypto" - "crypto/rsa" - "encoding/asn1" - "github.com/golang/protobuf/proto" - "github.com/s-rah/go-ricochet/auth" - "github.com/s-rah/go-ricochet/chat" - "github.com/s-rah/go-ricochet/contact" - "github.com/s-rah/go-ricochet/control" - "github.com/s-rah/go-ricochet/utils" - "log" - "net" -) - -// OpenConnection encapsulates the state required to maintain a connection to -// a ricochet service. -// Notably OpenConnection does not enforce limits on the channelIDs, channel Assignments -// or the direction of messages. These are considered to be service enforced rules. -// (and services are considered to be the best to define them). -type OpenConnection struct { - conn net.Conn - authHandler map[int32]*AuthenticationHandler - channels map[int32]string - rni utils.RicochetNetworkInterface - - Client bool - IsAuthed bool - MyHostname string - OtherHostname string - Closed bool -} - -// Init initializes a OpenConnection object to a default state. -func (oc *OpenConnection) Init(outbound bool, conn net.Conn) { - oc.conn = conn - oc.authHandler = make(map[int32]*AuthenticationHandler) - oc.channels = make(map[int32]string) - oc.rni = new(utils.RicochetNetwork) - - oc.Client = outbound - oc.IsAuthed = false - oc.MyHostname = "" - oc.OtherHostname = "" -} - -// UnsetChannel removes a type association from the channel. -func (oc *OpenConnection) UnsetChannel(channel int32) { - oc.channels[channel] = "none" -} - -// GetChannelType returns the type of the channel on this connection -func (oc *OpenConnection) GetChannelType(channel int32) string { - if val, ok := oc.channels[channel]; ok { - return val - } - return "none" -} - -func (oc *OpenConnection) setChannel(channel int32, channelType string) { - oc.channels[channel] = channelType -} - -// HasChannel returns true if the connection has a channel of an associated type, false otherwise -func (oc *OpenConnection) HasChannel(channelType string) bool { - for _, val := range oc.channels { - if val == channelType { - return true - } - } - return false -} - -// CloseChannel closes a given channel -// Prerequisites: -// * Must have previously connected to a service -func (oc *OpenConnection) CloseChannel(channel int32) { - oc.UnsetChannel(channel) - oc.rni.SendRicochetPacket(oc.conn, channel, []byte{}) -} - -// Close closes the entire connection -func (oc *OpenConnection) Close() { - oc.conn.Close() - oc.Closed = true -} - -// Authenticate opens an Authentication Channel and send a client cookie -// Prerequisites: -// * Must have previously connected to a service -func (oc *OpenConnection) Authenticate(channel int32) { - defer utils.RecoverFromError() - - oc.authHandler[channel] = new(AuthenticationHandler) - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.OpenAuthenticationChannel(channel, oc.authHandler[channel].GenClientCookie()) - utils.CheckError(err) - - oc.setChannel(channel, "im.ricochet.auth.hidden-service") - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// ConfirmAuthChannel responds to a new authentication request. -// Prerequisites: -// * Must have previously connected to a service -func (oc *OpenConnection) ConfirmAuthChannel(channel int32, clientCookie [16]byte) { - defer utils.RecoverFromError() - - oc.authHandler[channel] = new(AuthenticationHandler) - oc.authHandler[channel].AddClientCookie(clientCookie[:]) - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.ConfirmAuthChannel(channel, oc.authHandler[channel].GenServerCookie()) - utils.CheckError(err) - - oc.setChannel(channel, "im.ricochet.auth.hidden-service") - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// SendProof sends an authentication proof in response to a challenge. -// Prerequisites: -// * Must have previously connected to a service -// * channel must be of type auth -func (oc *OpenConnection) SendProof(channel int32, serverCookie [16]byte, publicKeyBytes []byte, privateKey *rsa.PrivateKey) { - - if oc.authHandler[channel] == nil { - return // NoOp - } - - oc.authHandler[channel].AddServerCookie(serverCookie[:]) - - challenge := oc.authHandler[channel].GenChallenge(oc.MyHostname, oc.OtherHostname) - signature, _ := rsa.SignPKCS1v15(nil, privateKey, crypto.SHA256, challenge) - - defer utils.RecoverFromError() - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.Proof(publicKeyBytes, signature) - utils.CheckError(err) - - oc.rni.SendRicochetPacket(oc.conn, channel, data) -} - -// ValidateProof determines if the given public key and signature align with the -// already established challenge vector for this communication -// Prerequisites: -// * Must have previously connected to a service -// * Client and Server must have already sent their respective cookies (Authenticate and ConfirmAuthChannel) -func (oc *OpenConnection) ValidateProof(channel int32, publicKeyBytes []byte, signature []byte) bool { - - if oc.authHandler[channel] == nil { - return false - } - - provisionalHostname := utils.GetTorHostname(publicKeyBytes) - publicKey := new(rsa.PublicKey) - _, err := asn1.Unmarshal(publicKeyBytes, publicKey) - if err != nil { - return false - } - challenge := oc.authHandler[channel].GenChallenge(provisionalHostname, oc.MyHostname) - err = rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, challenge[:], signature) - if err == nil { - oc.OtherHostname = provisionalHostname - return true - } - return false - -} - -// SendAuthenticationResult responds to an existed authentication Proof -// Prerequisites: -// * Must have previously connected to a service -// * channel must be of type auth -func (oc *OpenConnection) SendAuthenticationResult(channel int32, accepted bool, isKnownContact bool) { - defer utils.RecoverFromError() - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.AuthResult(accepted, isKnownContact) - utils.CheckError(err) - oc.rni.SendRicochetPacket(oc.conn, channel, data) -} - -// OpenChatChannel opens a new chat channel with the given id -// Prerequisites: -// * Must have previously connected to a service -// * If acting as the client, id must be odd, else even -func (oc *OpenConnection) OpenChatChannel(channel int32) { - defer utils.RecoverFromError() - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.OpenChannel(channel, "im.ricochet.chat") - utils.CheckError(err) - - oc.setChannel(channel, "im.ricochet.chat") - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// OpenChannel opens a new chat channel with the given id -// Prerequisites: -// * Must have previously connected to a service -// * If acting as the client, id must be odd, else even -func (oc *OpenConnection) OpenChannel(channel int32, channelType string) { - defer utils.RecoverFromError() - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.OpenChannel(channel, channelType) - utils.CheckError(err) - - oc.setChannel(channel, channelType) - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// AckOpenChannel acknowledges a previously received open channel message -// Prerequisites: -// * Must have previously connected and authenticated to a service -func (oc *OpenConnection) AckOpenChannel(channel int32, channeltype string) { - defer utils.RecoverFromError() - messageBuilder := new(MessageBuilder) - - data, err := messageBuilder.AckOpenChannel(channel) - utils.CheckError(err) - - oc.setChannel(channel, channeltype) - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// RejectOpenChannel acknowledges a rejects a previously received open channel message -// Prerequisites: -// * Must have previously connected -func (oc *OpenConnection) RejectOpenChannel(channel int32, errortype string) { - defer utils.RecoverFromError() - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.RejectOpenChannel(channel, errortype) - utils.CheckError(err) - - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// SendContactRequest initiates a contact request to the server. -// Prerequisites: -// * Must have previously connected and authenticated to a service -func (oc *OpenConnection) SendContactRequest(channel int32, nick string, message string) { - defer utils.RecoverFromError() - - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.OpenContactRequestChannel(channel, nick, message) - utils.CheckError(err) - - oc.setChannel(channel, "im.ricochet.contact.request") - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// AckContactRequestOnResponse responds a contact request from a client -// Prerequisites: -// * Must have previously connected and authenticated to a service -// * Must have previously received a Contact Request -func (oc *OpenConnection) AckContactRequestOnResponse(channel int32, status string) { - defer utils.RecoverFromError() - - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.ReplyToContactRequestOnResponse(channel, status) - utils.CheckError(err) - - oc.setChannel(channel, "im.ricochet.contact.request") - oc.rni.SendRicochetPacket(oc.conn, 0, data) -} - -// AckContactRequest responds to contact request from a client -// Prerequisites: -// * Must have previously connected and authenticated to a service -// * Must have previously received a Contact Request -func (oc *OpenConnection) AckContactRequest(channel int32, status string) { - defer utils.RecoverFromError() - - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.ReplyToContactRequest(channel, status) - utils.CheckError(err) - - oc.setChannel(channel, "im.ricochet.contact.request") - oc.rni.SendRicochetPacket(oc.conn, channel, data) -} - -// AckChatMessage acknowledges a previously received chat message. -// Prerequisites: -// * Must have previously connected and authenticated to a service -// * Must have established a known contact status with the other service -// * Must have received a Chat message on an open im.ricochet.chat channel with the messageID -func (oc *OpenConnection) AckChatMessage(channel int32, messageID int32) { - defer utils.RecoverFromError() - - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.AckChatMessage(messageID) - utils.CheckError(err) - - oc.rni.SendRicochetPacket(oc.conn, channel, data) -} - -// SendMessage sends a Chat Message (message) to a give Channel (channel). -// Prerequisites: -// * Must have previously connected and authenticated to a service -// * Must have established a known contact status with the other service -// * Must have previously opened channel with OpenChanel of type im.ricochet.chat -func (oc *OpenConnection) SendMessage(channel int32, message string) { - defer utils.RecoverFromError() - messageBuilder := new(MessageBuilder) - data, err := messageBuilder.ChatMessage(message, 0) - utils.CheckError(err) - oc.rni.SendRicochetPacket(oc.conn, channel, data) -} - -// Process waits for new messages to arrive from the connection and uses the given -// ConnectionHandler to process them. -func (oc *OpenConnection) Process(handler ConnectionHandler) { - handler.OnReady(oc) - defer oc.Close() - defer handler.OnDisconnect() - - for { - if oc.Closed { - return - } - - packet, err := oc.rni.RecvRicochetPacket(oc.conn) - if err != nil { - oc.Close() - return - } - - if len(packet.Data) == 0 { - handler.OnChannelClosed(packet.Channel) - continue - } - - if packet.Channel == 0 { - - res := new(Protocol_Data_Control.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - handler.OnGenericError(packet.Channel) - continue - } - - if res.GetOpenChannel() != nil { - opm := res.GetOpenChannel() - - if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { - // Channel is already in use. - handler.OnBadUsageError(opm.GetChannelIdentifier()) - continue - } - - // If I am a Client, the server can only open even numbered channels - if oc.Client && opm.GetChannelIdentifier()%2 != 0 { - handler.OnBadUsageError(opm.GetChannelIdentifier()) - continue - } - - // If I am a Server, the client can only open odd numbered channels - if !oc.Client && opm.GetChannelIdentifier()%2 != 1 { - handler.OnBadUsageError(opm.GetChannelIdentifier()) - continue - } - - switch opm.GetChannelType() { - case "im.ricochet.auth.hidden-service": - if oc.Client { - // Servers are authed by default and can't auth with hidden-service - handler.OnBadUsageError(opm.GetChannelIdentifier()) - } else if oc.IsAuthed { - // Can't auth if already authed - handler.OnBadUsageError(opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.auth.hidden-service") { - // Can't open more than 1 auth channel - handler.OnBadUsageError(opm.GetChannelIdentifier()) - } else { - clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) - if err == nil { - clientCookieB := [16]byte{} - copy(clientCookieB[:], clientCookie.([]byte)[:]) - handler.OnAuthenticationRequest(opm.GetChannelIdentifier(), clientCookieB) - } else { - // Must include Client Cookie - handler.OnBadUsageError(opm.GetChannelIdentifier()) - } - } - case "im.ricochet.chat": - if !oc.IsAuthed { - // Can't open chat channel if not authorized - handler.OnUnauthorizedError(opm.GetChannelIdentifier()) - } else if !handler.IsKnownContact(oc.OtherHostname) { - // Can't open chat channel if not a known contact - handler.OnUnauthorizedError(opm.GetChannelIdentifier()) - } else { - handler.OnOpenChannelRequest(opm.GetChannelIdentifier(), "im.ricochet.chat") - } - case "im.ricochet.contact.request": - if oc.Client { - // Servers are not allowed to send contact requests - handler.OnBadUsageError(opm.GetChannelIdentifier()) - } else if !oc.IsAuthed { - // Can't open a contact channel if not authed - handler.OnUnauthorizedError(opm.GetChannelIdentifier()) - } else if oc.HasChannel("im.ricochet.contact.request") { - // Only 1 contact channel is allowed to be open at a time - handler.OnBadUsageError(opm.GetChannelIdentifier()) - } else { - contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest) - if err == nil { - contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest) - if check { - handler.OnContactRequest(opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText()) - break - } - } - handler.OnBadUsageError(opm.GetChannelIdentifier()) - } - default: - handler.OnUnknownTypeError(opm.GetChannelIdentifier()) - } - } else if res.GetChannelResult() != nil { - crm := res.GetChannelResult() - if crm.GetOpened() { - switch oc.GetChannelType(crm.GetChannelIdentifier()) { - case "im.ricochet.auth.hidden-service": - serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie) - if err == nil { - serverCookieB := [16]byte{} - copy(serverCookieB[:], serverCookie.([]byte)[:]) - handler.OnAuthenticationChallenge(crm.GetChannelIdentifier(), serverCookieB) - } else { - handler.OnBadUsageError(crm.GetChannelIdentifier()) - } - case "im.ricochet.chat": - handler.OnOpenChannelRequestSuccess(crm.GetChannelIdentifier()) - case "im.ricochet.contact.request": - responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response) - if err == nil { - response, check := responseI.(*Protocol_Data_ContactRequest.Response) - if check { - handler.OnContactRequestAck(crm.GetChannelIdentifier(), response.GetStatus().String()) - break - } - } - handler.OnBadUsageError(crm.GetChannelIdentifier()) - default: - handler.OnBadUsageError(crm.GetChannelIdentifier()) - } - } else { - if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" { - handler.OnFailedChannelOpen(crm.GetChannelIdentifier(), crm.GetCommonError().String()) - } else { - oc.CloseChannel(crm.GetChannelIdentifier()) - } - } - } else { - // Unknown Message - oc.CloseChannel(packet.Channel) - } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { - res := new(Protocol_Data_AuthHiddenService.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - - if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs - handler.OnAuthenticationProof(packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature()) - } else if res.GetResult() != nil && oc.Client { // Only Servers Send Results - handler.OnAuthenticationResult(packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) - } else { - // If neither of the above are satisfied we just close the connection - oc.Close() - } - - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" { - - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.IsAuthed { - // Can't send chat messages if not authorized - handler.OnUnauthorizedError(packet.Channel) - } else if !handler.IsKnownContact(oc.OtherHostname) { - // Can't send chat message if not a known contact - handler.OnUnauthorizedError(packet.Channel) - } else { - res := new(Protocol_Data_Chat.Packet) - err := proto.Unmarshal(packet.Data[:], res) - - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - - if res.GetChatMessage() != nil { - handler.OnChatMessage(packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText()) - } else if res.GetChatAcknowledge() != nil { - handler.OnChatMessageAck(packet.Channel, int32(res.GetChatMessage().GetMessageId())) - } else { - // If neither of the above are satisfied we just close the connection - oc.Close() - } - } - } else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" { - - // NOTE: These auth checks should be redundant, however they - // are included here for defense-in-depth if for some reason - // a previously authed connection becomes untrusted / not known and - // the state is not cleaned up. - if !oc.Client { - // Clients are not allowed to send contact request responses - handler.OnBadUsageError(packet.Channel) - } else if !oc.IsAuthed { - // Can't send a contact request if not authed - handler.OnBadUsageError(packet.Channel) - } else { - res := new(Protocol_Data_ContactRequest.Response) - err := proto.Unmarshal(packet.Data[:], res) - log.Printf("%v", res) - if err != nil { - oc.CloseChannel(packet.Channel) - continue - } - handler.OnContactRequestAck(packet.Channel, res.GetStatus().String()) - } - } else if oc.GetChannelType(packet.Channel) == "none" { - // Invalid Channel Assignment - oc.CloseChannel(packet.Channel) - } else { - oc.Close() - } - } -} diff --git a/openconnection_test.go b/openconnection_test.go deleted file mode 100644 index 9487e16..0000000 --- a/openconnection_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package goricochet - -import "testing" - -func TestOpenConnectionAuth(t *testing.T) { - -} diff --git a/policies/timeoutpolicy.go b/policies/timeoutpolicy.go new file mode 100644 index 0000000..f342575 --- /dev/null +++ b/policies/timeoutpolicy.go @@ -0,0 +1,32 @@ +package policies + +import ( + "github.com/s-rah/go-ricochet/utils" + "time" +) + +// TimeoutPolicy is a convieance interface for enforcing common timeout patterns +type TimeoutPolicy time.Duration + +// Selection of common timeout policies +const ( + UnknownPurposeTimeout TimeoutPolicy = TimeoutPolicy(15 * time.Second) +) + +// ExecuteAction runs a function and returns an error if it hasn't returned +// by the time specified by TimeoutPolicy +func (tp *TimeoutPolicy) ExecuteAction(action func() error) error { + + c := make(chan error) + go func() { + c <- action() + }() + + tick := time.Tick(time.Duration(*tp)) + select { + case <-tick: + return utils.ActionTimedOutError + case err := <-c: + return err + } +} diff --git a/policies/timeoutpolicy_test.go b/policies/timeoutpolicy_test.go new file mode 100644 index 0000000..f3b9eb4 --- /dev/null +++ b/policies/timeoutpolicy_test.go @@ -0,0 +1,30 @@ +package policies + +import ( + "testing" + "time" +) + +func TestTimeoutPolicy(t *testing.T) { + policy := UnknownPurposeTimeout + result := func() error { + time.Sleep(2 * time.Second) + return nil + } + err := policy.ExecuteAction(result) + if err != nil { + t.Errorf("Action should ahve returned nil: %v", err) + } +} + +func TestTimeoutPolicyExpires(t *testing.T) { + policy := TimeoutPolicy(1 * time.Second) + result := func() error { + time.Sleep(5 * time.Second) + return nil + } + err := policy.ExecuteAction(result) + if err == nil { + t.Errorf("Action should have returned err") + } +} diff --git a/ricochet.go b/ricochet.go index ad5260f..c9e5f5f 100644 --- a/ricochet.go +++ b/ricochet.go @@ -1,158 +1,55 @@ package goricochet import ( - "errors" - "github.com/s-rah/go-ricochet/utils" - "io" - "net" - "sync" + "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/connection" + "io" + "net" + "log" ) - -// Connect sets up a client ricochet connection to host e.g. qn6uo4cmsrfv4kzq.onion. If this -// function finished successfully then the connection can be assumed to -// be open and authenticated. -// To specify a local port using the format "127.0.0.1:[port]|ricochet-id". -func Connect(host string) (*OpenConnection, error) { - networkResolver := utils.NetworkResolver{} - conn, host, err := networkResolver.Resolve(host) - - if err != nil { - return nil, err - } - - return Open(conn, host) -} - // Open establishes a protocol session on an established net.Conn, and returns a new // OpenConnection instance representing this connection. On error, the connection // will be closed. This function blocks until version negotiation has completed. // The application should call Process() on the returned OpenConnection to continue // handling protocol messages. -func Open(conn net.Conn, remoteHostname string) (*OpenConnection, error) { - oc, err := negotiateVersion(conn, true) - if err != nil { - conn.Close() - return nil, err - } - oc.OtherHostname = remoteHostname - return oc, nil +func Open(remoteHostname string) (*connection.Connection, error) { + networkResolver := utils.NetworkResolver{} + log.Printf("Connecting...") + conn, remoteHostname, err := networkResolver.Resolve(remoteHostname) + + if err != nil { + return nil, err + } + + log.Printf("Connected...negotiating version") + rc, err := negotiateVersion(conn, remoteHostname) + if err != nil { + conn.Close() + return nil, err + } + log.Printf("Connected...negotiated version") + return rc, nil } -// Serve accepts incoming connections on a net.Listener, negotiates protocol, -// and calls methods of the ServiceHandler to handle inbound connections. All -// calls to ServiceHandler happen on the caller's goroutine. The listener can -// be closed at any time to close the service. -func Serve(ln net.Listener, handler ServiceHandler) error { - defer ln.Close() - connChannel := make(chan interface{}) - listenErrorChannel := make(chan error) +// negotiate version takes an open network connection and executes +// the ricochet version negotiation procedure. +func negotiateVersion(conn net.Conn, remoteHostname string) (*connection.Connection, error) { + versions := []byte{0x49, 0x4D, 0x01, 0x01} + if n, err := conn.Write(versions); err != nil || n < len(versions) { + return nil, utils.VersionNegotiationError + } - go func() { - var pending sync.WaitGroup - for { - conn, err := ln.Accept() - if err != nil { - // Wait for pending connections before returning an error; this - // prevents abandoned goroutines when the outer loop stops reading - // from connChannel. - pending.Wait() - listenErrorChannel <- err - close(connChannel) - return - } + res := make([]byte, 1) + if _, err := io.ReadAtLeast(conn, res, len(res)); err != nil { + return nil, utils.VersionNegotiationError + } - pending.Add(1) - go func() { - defer pending.Done() - oc, err := negotiateVersion(conn, false) - if err != nil { - conn.Close() - connChannel <- err - } else { - connChannel <- oc - } - }() - } - }() - - var listenErr error - for { - select { - case err := <-listenErrorChannel: - // Remember error, wait for connChannel to close - listenErr = err - - case result, ok := <-connChannel: - if !ok { - return listenErr - } - - switch v := result.(type) { - case *OpenConnection: - handler.OnNewConnection(v) - case error: - handler.OnFailedConnection(v) - } - } - } - - return nil + if res[0] != 0x01 { + return nil, utils.VersionNegotiationFailed + } + rc := connection.NewOutboundConnection(conn,remoteHostname) + return rc, nil } -// Perform version negotiation on the connection, and create an OpenConnection if successful -func negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) { - versions := []byte{0x49, 0x4D, 0x01, 0x01} - // Outbound side of the connection sends a list of supported versions - if outbound { - if n, err := conn.Write(versions); err != nil || n < len(versions) { - return nil, err - } - - res := make([]byte, 1) - if _, err := io.ReadAtLeast(conn, res, len(res)); err != nil { - return nil, err - } - - if res[0] != 0x01 { - return nil, errors.New("unsupported protocol version") - } - } else { - // Read version response header - header := make([]byte, 3) - if _, err := io.ReadAtLeast(conn, header, len(header)); err != nil { - return nil, err - } - - if header[0] != versions[0] || header[1] != versions[1] || header[2] < 1 { - return nil, errors.New("invalid protocol response") - } - - // Read list of supported versions (which is header[2] bytes long) - versionList := make([]byte, header[2]) - if _, err := io.ReadAtLeast(conn, versionList, len(versionList)); err != nil { - return nil, err - } - - selectedVersion := byte(0xff) - for _, v := range versionList { - if v == 0x01 { - selectedVersion = v - break - } - } - - if n, err := conn.Write([]byte{selectedVersion}); err != nil || n < 1 { - return nil, err - } - - if selectedVersion == 0xff { - return nil, errors.New("no supported protocol version") - } - } - - oc := new(OpenConnection) - oc.Init(outbound, conn) - return oc, nil -} diff --git a/ricochet_test.go b/ricochet_test.go new file mode 100644 index 0000000..1c55f78 --- /dev/null +++ b/ricochet_test.go @@ -0,0 +1,70 @@ +package goricochet + +import ( + "testing" + "github.com/s-rah/go-ricochet/utils" + "net" + "time" +) + + +func SimpleServer() { + ln,_ := net.Listen("tcp", "127.0.0.1:11000") + conn,_ := ln.Accept() + b := make([]byte, 4) + n,err := conn.Read(b) + if n == 4 && err == nil { + conn.Write([]byte{0x01}) + } + conn.Close() +} + +func BadVersionNegotiation() { + ln,_ := net.Listen("tcp", "127.0.0.1:11001") + conn,_ := ln.Accept() + // We are already testing negotiation bytes, we don't care, just send a termination. + conn.Write([]byte{0x00}) + conn.Close() +} + +func NotRicochetServer() { + ln,_ := net.Listen("tcp", "127.0.0.1:11002") + conn,_ := ln.Accept() + conn.Close() +} + +func TestRicochet(t *testing.T) { + go SimpleServer() + // Wait for Server to Initialize + time.Sleep(time.Second) + + rc,err := Open("127.0.0.1:11000|abcdefghijklmno.onion") + if err == nil { + if rc.IsInbound { + t.Errorf("RicochetConnection declares itself as an Inbound connection after an Outbound attempt...that shouldn't happen") + } + return + } + t.Errorf("RicochetProtocol: Open Failed: %v", err) +} + +func TestBadVersionNegotiation(t*testing.T) { + go BadVersionNegotiation() + time.Sleep(time.Second) + + _,err := Open("127.0.0.1:11001|abcdefghijklmno.onion") + if err != utils.VersionNegotiationFailed { + t.Errorf("RicochetProtocol: Server Had No Correct Version - Should Have Failed: err = %v", err) + } +} + + +func TestNotARicochetServer(t*testing.T) { + go NotRicochetServer() + time.Sleep(time.Second) + + _,err := Open("127.0.0.1:11002|abcdefghijklmno.onion") + if err != utils.VersionNegotiationError { + t.Errorf("RicochetProtocol: Server Had No Correct Version - Should Have Failed: err = %v", err) + } +} diff --git a/standardricochetservice.go b/standardricochetservice.go deleted file mode 100644 index 670f63f..0000000 --- a/standardricochetservice.go +++ /dev/null @@ -1,210 +0,0 @@ -package goricochet - -import ( - "crypto/rsa" - "crypto/x509" - "encoding/asn1" - "encoding/pem" - "errors" - "github.com/s-rah/go-ricochet/utils" - "io/ioutil" - "log" - "net" - "strconv" -) - -// StandardRicochetService implements all the necessary flows to implement a -// minimal, protocol compliant Ricochet Service. It can be built on by other -// applications to produce automated riochet applications, and is a useful -// example for other implementations. -type StandardRicochetService struct { - PrivateKey *rsa.PrivateKey - serverHostname string -} - -// StandardRicochetConnection implements the ConnectionHandler interface -// to handle events on connections. An instance of StandardRicochetConnection -// is created for each OpenConnection by the HandleConnection method. -type StandardRicochetConnection struct { - Conn *OpenConnection - PrivateKey *rsa.PrivateKey -} - -// Init initializes a StandardRicochetService with the cryptographic key given -// by filename. -func (srs *StandardRicochetService) Init(filename string) error { - pemData, err := ioutil.ReadFile(filename) - - if err != nil { - return errors.New("Could not setup ricochet service: could not read private key") - } - - block, _ := pem.Decode(pemData) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return errors.New("Could not setup ricochet service: no valid PEM data found") - } - - srs.PrivateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return errors.New("Could not setup ricochet service: could not parse private key") - } - - publicKeyBytes, _ := asn1.Marshal(rsa.PublicKey{ - N: srs.PrivateKey.PublicKey.N, - E: srs.PrivateKey.PublicKey.E, - }) - - srs.serverHostname = utils.GetTorHostname(publicKeyBytes) - log.Printf("Initialised ricochet service for %s", srs.serverHostname) - - return nil -} - -// Listen starts listening for service connections on localhost `port`. -func (srs *StandardRicochetService) Listen(handler ServiceHandler, port int) { - ln, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port)) - if err != nil { - log.Printf("Cannot Listen on Port %v", port) - return - } - - Serve(ln, handler) -} - -// Connect initiates a new client connection to `hostname`, which must be in one -// of the forms accepted by the goricochet.Connect() method. -func (srs *StandardRicochetService) Connect(hostname string) (*OpenConnection, error) { - log.Printf("Connecting to...%s", hostname) - oc, err := Connect(hostname) - if err != nil { - return nil, errors.New("Could not connect to: " + hostname + " " + err.Error()) - } - oc.MyHostname = srs.serverHostname - return oc, nil -} - -// OnNewConnection is called for new inbound connections to our service. This -// method implements the ServiceHandler interface. -func (srs *StandardRicochetService) OnNewConnection(oc *OpenConnection) { - oc.MyHostname = srs.serverHostname -} - -// OnFailedConnection is called for inbound connections that fail to successfully -// complete version negotiation for any reason. This method implements the -// ServiceHandler interface. -func (srs *StandardRicochetService) OnFailedConnection(err error) { - log.Printf("Inbound connection failed: %s", err) -} - -// ------ - -// OnReady is called when a client or server sucessfully passes Version Negotiation. -func (src *StandardRicochetConnection) OnReady(oc *OpenConnection) { - src.Conn = oc - if oc.Client { - log.Printf("Successfully connected to %s", oc.OtherHostname) - oc.IsAuthed = true // Connections to Servers are Considered Authenticated by Default - oc.Authenticate(1) - } else { - log.Printf("Inbound connection received") - } -} - -// OnDisconnect is called when a connection is closed -func (src *StandardRicochetConnection) OnDisconnect() { - log.Printf("Disconnected from %s", src.Conn.OtherHostname) -} - -// OnAuthenticationRequest is called when a client requests Authentication -func (src *StandardRicochetConnection) OnAuthenticationRequest(channelID int32, clientCookie [16]byte) { - src.Conn.ConfirmAuthChannel(channelID, clientCookie) -} - -// OnAuthenticationChallenge constructs a valid authentication challenge to the serverCookie -func (src *StandardRicochetConnection) OnAuthenticationChallenge(channelID int32, serverCookie [16]byte) { - // DER Encode the Public Key - publickeyBytes, _ := asn1.Marshal(rsa.PublicKey{ - N: src.PrivateKey.PublicKey.N, - E: src.PrivateKey.PublicKey.E, - }) - src.Conn.SendProof(1, serverCookie, publickeyBytes, src.PrivateKey) -} - -// OnAuthenticationProof is called when a client sends Proof for an existing authentication challenge -func (src *StandardRicochetConnection) OnAuthenticationProof(channelID int32, publicKey []byte, signature []byte) { - result := src.Conn.ValidateProof(channelID, publicKey, signature) - // This implementation always sends 'true', indicating that the contact is known - src.Conn.SendAuthenticationResult(channelID, result, true) - src.Conn.IsAuthed = result - src.Conn.CloseChannel(channelID) -} - -// OnAuthenticationResult is called once a server has returned the result of the Proof Verification -func (src *StandardRicochetConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { - src.Conn.IsAuthed = result -} - -// IsKnownContact allows a caller to determine if a hostname an authorized contact. -func (src *StandardRicochetConnection) IsKnownContact(hostname string) bool { - return false -} - -// OnContactRequest is called when a client sends a new contact request -func (src *StandardRicochetConnection) OnContactRequest(channelID int32, nick string, message string) { -} - -// OnContactRequestAck is called when a server sends a reply to an existing contact request -func (src *StandardRicochetConnection) OnContactRequestAck(channelID int32, status string) { -} - -// OnOpenChannelRequest is called when a client or server requests to open a new channel -func (src *StandardRicochetConnection) OnOpenChannelRequest(channelID int32, channelType string) { - src.Conn.AckOpenChannel(channelID, channelType) -} - -// OnOpenChannelRequestSuccess is called when a client or server responds to an open channel request -func (src *StandardRicochetConnection) OnOpenChannelRequestSuccess(channelID int32) { -} - -// OnChannelClosed is called when a client or server closes an existing channel -func (src *StandardRicochetConnection) OnChannelClosed(channelID int32) { -} - -// OnChatMessage is called when a new chat message is received. -func (src *StandardRicochetConnection) OnChatMessage(channelID int32, messageID int32, message string) { - src.Conn.AckChatMessage(channelID, messageID) -} - -// OnChatMessageAck is called when a new chat message is ascknowledged. -func (src *StandardRicochetConnection) OnChatMessageAck(channelID int32, messageID int32) { -} - -// OnFailedChannelOpen is called when a server fails to open a channel -func (src *StandardRicochetConnection) OnFailedChannelOpen(channelID int32, errorType string) { - src.Conn.UnsetChannel(channelID) -} - -// OnGenericError is called when a generalized error is returned from the peer -func (src *StandardRicochetConnection) OnGenericError(channelID int32) { - src.Conn.RejectOpenChannel(channelID, "GenericError") -} - -//OnUnknownTypeError is called when an unknown type error is returned from the peer -func (src *StandardRicochetConnection) OnUnknownTypeError(channelID int32) { - src.Conn.RejectOpenChannel(channelID, "UnknownTypeError") -} - -// OnUnauthorizedError is called when an unathorized error is returned from the peer -func (src *StandardRicochetConnection) OnUnauthorizedError(channelID int32) { - src.Conn.RejectOpenChannel(channelID, "UnauthorizedError") -} - -// OnBadUsageError is called when a bad usage error is returned from the peer -func (src *StandardRicochetConnection) OnBadUsageError(channelID int32) { - src.Conn.RejectOpenChannel(channelID, "BadUsageError") -} - -// OnFailedError is called when a failed error is returned from the peer -func (src *StandardRicochetConnection) OnFailedError(channelID int32) { - src.Conn.RejectOpenChannel(channelID, "FailedError") -} diff --git a/standardricochetservice_bad_usage_error_test.go b/standardricochetservice_bad_usage_error_test.go deleted file mode 100644 index 03778f9..0000000 --- a/standardricochetservice_bad_usage_error_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package goricochet - -import "testing" -import "time" -import "log" - -type TestBadUsageService struct { - StandardRicochetService - BadUsageErrorCount int - UnknownTypeErrorCount int - ChannelClosed int -} - -type TestBadUsageConnection struct { - StandardRicochetConnection - Service *TestBadUsageService -} - -func (ts *TestBadUsageService) OnNewConnection(oc *OpenConnection) { - ts.StandardRicochetService.OnNewConnection(oc) - go oc.Process(&TestBadUsageConnection{Service: ts}) -} - -func (tc *TestBadUsageConnection) OnReady(oc *OpenConnection) { - if oc.Client { - oc.OpenChannel(17, "im.ricochet.auth.hidden-service") // Fail because no Extension - } - tc.StandardRicochetConnection.OnReady(oc) - if oc.Client { - oc.Authenticate(103) // Should Fail because cannot open more than one auth-hidden-service channel at once - } -} - -func (tc *TestBadUsageConnection) OnAuthenticationProof(channelID int32, publicKey []byte, signature []byte) { - tc.Conn.Authenticate(2) // Try to authenticate again...will fail servers don't auth - tc.Conn.SendContactRequest(4, "test", "test") // Only clients can send contact requests - tc.StandardRicochetConnection.OnAuthenticationProof(channelID, publicKey, signature) - tc.Conn.OpenChatChannel(5) // Fail because server can only open even numbered channels - tc.Conn.OpenChatChannel(3) // Fail because already in use... -} - -// OnContactRequest is called when a client sends a new contact request -func (tc *TestBadUsageConnection) OnContactRequest(channelID int32, nick string, message string) { - tc.Conn.AckContactRequestOnResponse(channelID, "Pending") // Done to keep the contact request channel open -} - -func (tc *TestBadUsageConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { - tc.StandardRicochetConnection.OnAuthenticationResult(channelID, result, isKnownContact) - - tc.Conn.OpenChatChannel(3) // Succeed - tc.Conn.OpenChatChannel(3) // Should fail as duplicate (channel already in use) - - tc.Conn.OpenChatChannel(6) // Should fail because clients are not allowed to open even numbered channels - - tc.Conn.SendMessage(101, "test") // Should fail as 101 doesn't exist - - tc.Conn.Authenticate(1) // Try to authenticate again...will fail because we have already authenticated - - tc.Conn.OpenChannel(19, "im.ricochet.contact.request") // Will Fail - tc.Conn.SendContactRequest(11, "test", "test") // Succeed - tc.Conn.SendContactRequest(13, "test", "test") // Trigger singleton contact request check - - tc.Conn.OpenChannel(15, "im.ricochet.not-a-real-type") // Fail UnknownType -} - -// OnChannelClose is called when a client or server closes an existing channel -func (tc *TestBadUsageConnection) OnChannelClosed(channelID int32) { - if channelID == 101 { - log.Printf("Received Channel Closed: %v", channelID) - tc.Service.ChannelClosed++ - } -} - -func (tc *TestBadUsageConnection) OnFailedChannelOpen(channelID int32, errorType string) { - log.Printf("Failed Channel Open %v %v", channelID, errorType) - tc.StandardRicochetConnection.OnFailedChannelOpen(channelID, errorType) - if errorType == "BadUsageError" { - tc.Service.BadUsageErrorCount++ - } else if errorType == "UnknownTypeError" { - tc.Service.UnknownTypeErrorCount++ - } -} - -func TestBadUsageServer(t *testing.T) { - ricochetService := new(TestBadUsageService) - err := ricochetService.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService.Listen(ricochetService, 9884) - - time.Sleep(time.Second * 2) - - ricochetService2 := new(TestBadUsageService) - err = ricochetService2.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService2.Listen(ricochetService2, 9885) - oc, err := ricochetService2.Connect("127.0.0.1:9884|kwke2hntvyfqm7dr") - if err != nil { - t.Errorf("Could not connect to ricochet service: %v", err) - } - go oc.Process(&TestBadUsageConnection{ - Service: ricochetService2, - StandardRicochetConnection: StandardRicochetConnection{ - PrivateKey: ricochetService2.PrivateKey, - }, - }) - - time.Sleep(time.Second * 3) - if ricochetService2.ChannelClosed != 1 || ricochetService2.BadUsageErrorCount != 7 || ricochetService.BadUsageErrorCount != 4 || ricochetService2.UnknownTypeErrorCount != 1 { - t.Errorf("Invalid number of errors seen Closed:%v, Client Bad Usage:%v UnknownTypeErrorCount: %v, Server Bad Usage: %v ", ricochetService2.ChannelClosed, ricochetService2.BadUsageErrorCount, ricochetService2.UnknownTypeErrorCount, ricochetService.BadUsageErrorCount) - } - -} diff --git a/standardricochetservice_test.go b/standardricochetservice_test.go deleted file mode 100644 index 10ab2fe..0000000 --- a/standardricochetservice_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package goricochet - -import "testing" -import "time" -import "log" - -type TestService struct { - StandardRicochetService -} - -func (ts *TestService) OnNewConnection(oc *OpenConnection) { - ts.StandardRicochetService.OnNewConnection(oc) - go oc.Process(&TestConnection{}) -} - -type TestConnection struct { - StandardRicochetConnection - KnownContact bool // Mocking contact request -} - -func (tc *TestConnection) IsKnownContact(hostname string) bool { - return tc.KnownContact -} - -func (tc *TestConnection) OnAuthenticationProof(channelID int32, publicKey, signature []byte) { - result := tc.Conn.ValidateProof(channelID, publicKey, signature) - tc.Conn.SendAuthenticationResult(channelID, result, tc.KnownContact) - tc.Conn.IsAuthed = result - tc.Conn.CloseChannel(channelID) -} - -func (tc *TestConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { - tc.StandardRicochetConnection.OnAuthenticationResult(channelID, result, isKnownContact) - if !isKnownContact { - log.Printf("Sending Contact Request") - tc.Conn.SendContactRequest(3, "test", "test") - } -} - -func (tc *TestConnection) OnContactRequest(channelID int32, nick string, message string) { - tc.StandardRicochetConnection.OnContactRequest(channelID, nick, message) - tc.Conn.AckContactRequestOnResponse(channelID, "Pending") - tc.Conn.AckContactRequest(channelID, "Accepted") - tc.KnownContact = true - tc.Conn.CloseChannel(channelID) -} - -func (tc *TestConnection) OnOpenChannelRequestSuccess(channelID int32) { - tc.StandardRicochetConnection.OnOpenChannelRequestSuccess(channelID) - tc.Conn.SendMessage(channelID, "TEST MESSAGE") -} - -func (tc *TestConnection) OnContactRequestAck(channelID int32, status string) { - tc.StandardRicochetConnection.OnContactRequestAck(channelID, status) - if status == "Accepted" { - log.Printf("Got accepted contact request") - tc.KnownContact = true - tc.Conn.OpenChatChannel(5) - } else if status == "Pending" { - log.Printf("Got pending contact request") - } -} - -func (tc *TestConnection) OnChatMessage(channelID int32, messageID int32, message string) { - tc.StandardRicochetConnection.OnChatMessage(channelID, messageID, message) - if message == "TEST MESSAGE" { - receivedMessage = true - } -} - -var receivedMessage bool - -func TestServer(t *testing.T) { - ricochetService := new(TestService) - err := ricochetService.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService.Listen(ricochetService, 9878) - - time.Sleep(time.Second * 2) - - ricochetService2 := new(TestService) - err = ricochetService2.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService2.Listen(ricochetService2, 9879) - oc, err := ricochetService2.Connect("127.0.0.1:9878|kwke2hntvyfqm7dr") - if err != nil { - t.Errorf("Could not connect to ricochet service: %v", err) - } - testClient := &TestConnection{ - StandardRicochetConnection: StandardRicochetConnection{ - PrivateKey: ricochetService2.PrivateKey, - }, - } - go oc.Process(testClient) - - time.Sleep(time.Second * 5) // Wait a bit longer - if !receivedMessage { - t.Errorf("Test server did not receive message") - } -} - -func TestServerInvalidKey(t *testing.T) { - ricochetService := new(TestService) - err := ricochetService.Init("./private_key.does.not.exist") - - if err == nil { - t.Errorf("Should not have initate ricochet service, private key should not exist") - } -} - -func TestServerCouldNotConnect(t *testing.T) { - ricochetService := new(TestService) - err := ricochetService.Init("./private_key") - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - _, err = ricochetService.Connect("127.0.0.1:65535|kwke2hntvyfqm7dr") - if err == nil { - t.Errorf("Should not have been been able to connect to 127.0.0.1:65535|kwke2hntvyfqm7dr") - } -} diff --git a/standardricochetservice_unauth_test.go b/standardricochetservice_unauth_test.go deleted file mode 100644 index b839878..0000000 --- a/standardricochetservice_unauth_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package goricochet - -import "testing" -import "time" -import "log" - -// The purpose of this test is to exercise the Unauthorized Error flows that occur -// when a client attempts to open a Chat Channel or Send a Contact Reuqest before Authentication -// itself with the Service. - -type TestUnauthorizedService struct { - StandardRicochetService -} - -func (ts *TestUnauthorizedService) OnNewConnection(oc *OpenConnection) { - go oc.Process(&StandardRicochetConnection{}) -} - -type TestUnauthorizedConnection struct { - StandardRicochetConnection - FailedToOpen int -} - -func (tc *TestUnauthorizedConnection) OnReady(oc *OpenConnection) { - tc.StandardRicochetConnection.OnReady(oc) - if oc.Client { - log.Printf("Attempting Authentication Not Authorized") - oc.IsAuthed = true // Connections to Servers are Considered Authenticated by Default - // REMOVED Authenticate - oc.OpenChatChannel(5) - oc.SendContactRequest(3, "test", "test") - } -} - -func (tc *TestUnauthorizedConnection) OnFailedChannelOpen(channelID int32, errorType string) { - tc.Conn.UnsetChannel(channelID) - if errorType == "UnauthorizedError" { - tc.FailedToOpen++ - } -} - -func TestUnauthorizedClientReject(t *testing.T) { - ricochetService := new(TestService) - err := ricochetService.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService.Listen(ricochetService, 9880) - - time.Sleep(time.Second * 2) - - ricochetService2 := new(TestUnauthorizedService) - err = ricochetService2.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService2.Listen(ricochetService2, 9881) - oc, err := ricochetService2.Connect("127.0.0.1:9880|kwke2hntvyfqm7dr") - if err != nil { - t.Errorf("Could not connect to ricochet service: %v", err) - } - connectionHandler := &TestUnauthorizedConnection{ - StandardRicochetConnection: StandardRicochetConnection{ - PrivateKey: ricochetService2.PrivateKey, - }, - } - go oc.Process(connectionHandler) - - time.Sleep(time.Second * 2) - if connectionHandler.FailedToOpen != 2 { - t.Errorf("Test server did not reject open channels with unauthorized error") - } - -} diff --git a/standardricochetservice_unknown_contact_test.go b/standardricochetservice_unknown_contact_test.go deleted file mode 100644 index 2e7a702..0000000 --- a/standardricochetservice_unknown_contact_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package goricochet - -import "testing" -import "time" -import "log" - -type TestUnknownContactService struct { - StandardRicochetService -} - -func (ts *TestUnknownContactService) OnNewConnection(oc *OpenConnection) { - go oc.Process(&TestUnknownContactConnection{}) -} - -type TestUnknownContactConnection struct { - StandardRicochetConnection - FailedToOpen bool -} - -func (tc *TestUnknownContactConnection) IsKnownContact(hostname string) bool { - return false -} - -func (tc *TestUnknownContactConnection) OnAuthenticationProof(channelID int32, publicKey, signature []byte) { - result := tc.Conn.ValidateProof(channelID, publicKey, signature) - tc.Conn.SendAuthenticationResult(channelID, result, false) - tc.Conn.IsAuthed = result - tc.Conn.CloseChannel(channelID) -} - -func (tc *TestUnknownContactConnection) OnAuthenticationResult(channelID int32, result bool, isKnownContact bool) { - log.Printf("Authentication Result") - tc.StandardRicochetConnection.OnAuthenticationResult(channelID, result, isKnownContact) - tc.Conn.OpenChatChannel(5) -} - -func (tc *TestUnknownContactConnection) OnFailedChannelOpen(channelID int32, errorType string) { - log.Printf("Failed Channel Open %v", errorType) - tc.Conn.UnsetChannel(channelID) - if errorType == "UnauthorizedError" { - tc.FailedToOpen = true - } -} - -func TestUnknownContactServer(t *testing.T) { - ricochetService := new(TestUnknownContactService) - err := ricochetService.Init("./private_key") - - if err != nil { - t.Errorf("Could not initate ricochet service: %v", err) - } - - go ricochetService.Listen(ricochetService, 9882) - - time.Sleep(time.Second * 2) - - oc, err := ricochetService.Connect("127.0.0.1:9882|kwke2hntvyfqm7dr") - if err != nil { - t.Errorf("Could not connect to ricochet service: %v", err) - } - connectionHandler := &TestUnknownContactConnection{ - StandardRicochetConnection: StandardRicochetConnection{ - PrivateKey: ricochetService.PrivateKey, - }, - } - go oc.Process(connectionHandler) - - time.Sleep(time.Second * 2) - if !connectionHandler.FailedToOpen { - t.Errorf("Test server did receive message should have failed") - } - -} diff --git a/testing/private_key b/testing/private_key new file mode 100644 index 0000000..40b9757 --- /dev/null +++ b/testing/private_key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQC3xEJBH4oVFaotPJw6dezx67Gv4Xukw8CZRGqNFO8yF7Rejtcj +/0RTqqZwj6H6FjxY60dgYnN6IphW0juemNZhxOXeM/5Gb5xO+kWGi5Qt87aSDxnA +MDLgqw79ihuD3m1C1TBz0olmjXPU1VtadZuZcVBST7SLs2/k55GNNr7BoQIDAQAB +AoGBAK3ybVCdnSQWLM7DJ5LC23Wnx7sXceVlkiLCOyWuYjiFbatwBD/DupaD2yaD +HyzN7XOxyg93QZ2jr5XHTL30KEAn/3akNBsX3sjHZnjVfTwD5+oZKd7HYMMxekWf +87TIx2IHvGEo2NaFMLkEZ5TX3Gre8CYOofjFcpj4661ZfYp9AkEA9I0EmQX26ibs +CRGkwPuEj5q5N/PmIHgMWr1pepOlmzJjnxy6SI3NUwmzKrqM6YUM8loSywqfVMrJ +RVzA5jp76wJBAMBeu2hS8KcUTIu66j0pXMhI5wDA3yLiO53TEMwufCPXcaWUMH+e +5AIPL7aZ8ouf895OH0TZKxPNMnbrJ+5F0aMCQDoi/CDUxipMLnjJdP1bzdvF0Jp4 +pRC6+VTpCpZVW11V0VEWJ0LwUwuWlr1ls/If60ACIc2bLN2fh9Gxhzo0VRkCQQCS +nKCAVhYLgLEGHaLAknGgQ8+rB1QIphuBoYc/1n3OYzi+VT7RRSvJVgGrTZFJUNLw +LuIt+sWWBeHcOETqmFO5AkEAwwfcxs8QZtX6hCj2MTPi8Q28LIoA/M6eAqYc2I0B +eXxf2J2Qco7sMmBLr1Jp3jZNd5W2fMtlhUZAomOj4piVOA== +-----END RSA PRIVATE KEY----- diff --git a/testing/private_key_auth_fail_test b/testing/private_key_auth_fail_test new file mode 100644 index 0000000..10341f7 --- /dev/null +++ b/testing/private_key_auth_fail_test @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQC9eXEz2sONLCHcaW3OR2kB1fwp+DkQYC74J4FkrdbuSLoPi/fZ +l0bRQZXKprZGhQsH0z1ERuD5wJD/XDws3XdIJuiGw8wEttwFe8lbsBsRedmjqsAy +NukE1gZDoVYAwYgyLz7Obch7m+2h4M42uMDzyGno4nXKIV/1hTfLJvqw6QIDAQAB +AoGADp+Kzxe5M/IOAvbYFK2KOywKtCqGLO9fcKOL5vtLtURDp+ODk3WLb6cCKovH +UZX/DfGNrvFRd7UW+75gno3RIMxbdyC8AcKNz8jnYzSpG2/tXL8LNAZxV5OdbxG3 +S2iVB/rOt49ilH2WcaqUkSqL0+goPLcJy2k/owV0aPEOUwECQQDsTdHbkYt7cSKn +aJtIRV1j3M1Tzu7ZJYLzDF5S0VECP80Gb9gCpMPSt45hGk6AzMGZFCImi9vmiW2c +TzFgLHbZAkEAzURjG0o9YRhesZkg+PoJ33zakg+Tp/6FYY73eBqLg71iO2YS9YIR +DwJ9IG//V8oqFm0dhW20LLbvTqtWyspgkQJBAN5ai7I0Ti+l0Zn9kMB8pNgnGP5X +peCmr4XMiaUcWUHojyATdgtmxu0s08kDXANOqI1GqKvkxtMzVfTTf/6jWGECQQCY +e3DT2PZ3pk7Rx1sDGVs0Nd94GTIq3ZvfuQCEq9Nv7cOHNHBpCFH7wHGLIyef44IY +Xr5LXA84GDz1R7qVsnjBAkB1qYel38r3NoMvVLhCUh2HLZSTxPF9V7iE+5OvakIJ ++Glb45PyloFIobv1yQoIOJlu+uoilGRbOiMUVG1uS0Tj +-----END RSA PRIVATE KEY----- diff --git a/tests.sh b/testing/tests.sh similarity index 61% rename from tests.sh rename to testing/tests.sh index 3ef2b4c..0da8fbe 100755 --- a/tests.sh +++ b/testing/tests.sh @@ -4,6 +4,9 @@ set -e pwd go test -coverprofile=main.cover.out -v . go test -coverprofile=utils.cover.out -v ./utils +go test -coverprofile=channels.cover.out -v ./channels +go test -coverprofile=connection.cover.out -v ./connection +go test -coverprofile=policies.cover.out -v ./policies echo "mode: set" > coverage.out && cat *.cover.out | grep -v mode: | sort -r | \ awk '{if($1 != last) {print $0;last=$1}}' >> coverage.out rm -rf *.cover.out diff --git a/utils/crypto.go b/utils/crypto.go new file mode 100644 index 0000000..2f10c9c --- /dev/null +++ b/utils/crypto.go @@ -0,0 +1,25 @@ +package utils + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "io/ioutil" +) + +// LoadPrivateKeyFromFile loads a private key from a file... +func LoadPrivateKeyFromFile(filename string) (*rsa.PrivateKey, error) { + pemData, err := ioutil.ReadFile(filename) + + if err != nil { + return nil, err + } + + block, _ := pem.Decode(pemData) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, errors.New("not a private key") + } + + return x509.ParsePKCS1PrivateKey(block.Bytes) +} diff --git a/utils/error.go b/utils/error.go index 7b21265..b977245 100644 --- a/utils/error.go +++ b/utils/error.go @@ -1,7 +1,30 @@ package utils -import "fmt" -import "log" +import ( + "fmt" + "log" +) + +// Error captures various common ricochet errors +type Error string + +func (e Error) Error() string { return string(e) } + +// Defining Versions +const ( + VersionNegotiationError = Error("VersionNegotiationError") + VersionNegotiationFailed = Error("VersionNegotiationFailed") + + RicochetConnectionClosed = Error("RicochetConnectionClosed") + RicochetProtocolError = Error("RicochetProtocolError") + + UnknownChannelTypeError = Error("UnknownChannelTypeError") + UnauthorizedChannelTypeError = Error("UnauthorizedChannelTypeError") + + ActionTimedOutError = Error("ActionTimedOutError") + + ClientFailedToAuthenticateError = Error("ClientFailedToAuthenticateError") +) // RecoverFromError doesn't really recover from anything....see comment below func RecoverFromError() { diff --git a/messagebuilder.go b/utils/messagebuilder.go similarity index 69% rename from messagebuilder.go rename to utils/messagebuilder.go index 874f213..bd5616f 100644 --- a/messagebuilder.go +++ b/utils/messagebuilder.go @@ -1,12 +1,11 @@ -package goricochet +package utils import ( "github.com/golang/protobuf/proto" - "github.com/s-rah/go-ricochet/auth" - "github.com/s-rah/go-ricochet/chat" - "github.com/s-rah/go-ricochet/contact" - "github.com/s-rah/go-ricochet/control" - "github.com/s-rah/go-ricochet/utils" + "github.com/s-rah/go-ricochet/wire/auth" + "github.com/s-rah/go-ricochet/wire/chat" + "github.com/s-rah/go-ricochet/wire/contact" + "github.com/s-rah/go-ricochet/wire/control" ) // MessageBuilder allows a client to construct specific data packets for the @@ -16,7 +15,7 @@ type MessageBuilder struct { // OpenChannel contructs a message which will request to open a channel for // chat on the given channelID. -func (mb *MessageBuilder) OpenChannel(channelID int32, channelType string) ([]byte, error) { +func (mb *MessageBuilder) OpenChannel(channelID int32, channelType string) []byte { oc := &Protocol_Data_Control.OpenChannel{ ChannelIdentifier: proto.Int32(channelID), ChannelType: proto.String(channelType), @@ -24,11 +23,13 @@ func (mb *MessageBuilder) OpenChannel(channelID int32, channelType string) ([]by pc := &Protocol_Data_Control.Packet{ OpenChannel: oc, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } // AckOpenChannel constructs a message to acknowledge a previous open channel operation. -func (mb *MessageBuilder) AckOpenChannel(channelID int32) ([]byte, error) { +func (mb *MessageBuilder) AckOpenChannel(channelID int32) []byte { cr := &Protocol_Data_Control.ChannelResult{ ChannelIdentifier: proto.Int32(channelID), Opened: proto.Bool(true), @@ -36,11 +37,13 @@ func (mb *MessageBuilder) AckOpenChannel(channelID int32) ([]byte, error) { pc := &Protocol_Data_Control.Packet{ ChannelResult: cr, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } // RejectOpenChannel constructs a channel result message, stating the channel failed to open and a reason -func (mb *MessageBuilder) RejectOpenChannel(channelID int32, error string) ([]byte, error) { +func (mb *MessageBuilder) RejectOpenChannel(channelID int32, error string) []byte { errorNum := Protocol_Data_Control.ChannelResult_CommonError_value[error] commonError := Protocol_Data_Control.ChannelResult_CommonError(errorNum) @@ -53,28 +56,32 @@ func (mb *MessageBuilder) RejectOpenChannel(channelID int32, error string) ([]by pc := &Protocol_Data_Control.Packet{ ChannelResult: cr, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } // ConfirmAuthChannel constructs a message to acknowledge a previous open channel operation. -func (mb *MessageBuilder) ConfirmAuthChannel(channelID int32, serverCookie [16]byte) ([]byte, error) { +func (mb *MessageBuilder) ConfirmAuthChannel(channelID int32, serverCookie [16]byte) []byte { cr := &Protocol_Data_Control.ChannelResult{ ChannelIdentifier: proto.Int32(channelID), Opened: proto.Bool(true), } err := proto.SetExtension(cr, Protocol_Data_AuthHiddenService.E_ServerCookie, serverCookie[:]) - utils.CheckError(err) + CheckError(err) pc := &Protocol_Data_Control.Packet{ ChannelResult: cr, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } // OpenContactRequestChannel contructs a message which will reuqest to open a channel for // a contact request on the given channelID, with the given nick and message. -func (mb *MessageBuilder) OpenContactRequestChannel(channelID int32, nick string, message string) ([]byte, error) { +func (mb *MessageBuilder) OpenContactRequestChannel(channelID int32, nick string, message string) []byte { // Construct a Contact Request Channel oc := &Protocol_Data_Control.OpenChannel{ ChannelIdentifier: proto.Int32(channelID), @@ -87,16 +94,18 @@ func (mb *MessageBuilder) OpenContactRequestChannel(channelID int32, nick string } err := proto.SetExtension(oc, Protocol_Data_ContactRequest.E_ContactRequest, contactRequest) - utils.CheckError(err) + CheckError(err) pc := &Protocol_Data_Control.Packet{ OpenChannel: oc, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } // ReplyToContactRequestOnResponse constructs a message to acknowledge contact request -func (mb *MessageBuilder) ReplyToContactRequestOnResponse(channelID int32, status string) ([]byte, error) { +func (mb *MessageBuilder) ReplyToContactRequestOnResponse(channelID int32, status string) []byte { cr := &Protocol_Data_Control.ChannelResult{ ChannelIdentifier: proto.Int32(channelID), Opened: proto.Bool(true), @@ -109,42 +118,49 @@ func (mb *MessageBuilder) ReplyToContactRequestOnResponse(channelID int32, statu } err := proto.SetExtension(cr, Protocol_Data_ContactRequest.E_Response, contactRequest) - utils.CheckError(err) + CheckError(err) pc := &Protocol_Data_Control.Packet{ ChannelResult: cr, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } // ReplyToContactRequest constructs a message to acknowledge a contact request -func (mb *MessageBuilder) ReplyToContactRequest(channelID int32, status string) ([]byte, error) { +func (mb *MessageBuilder) ReplyToContactRequest(channelID int32, status string) []byte { statusNum := Protocol_Data_ContactRequest.Response_Status_value[status] responseStatus := Protocol_Data_ContactRequest.Response_Status(statusNum) contactRequest := &Protocol_Data_ContactRequest.Response{ Status: &responseStatus, } - return proto.Marshal(contactRequest) + + ret, err := proto.Marshal(contactRequest) + CheckError(err) + return ret } // OpenAuthenticationChannel constructs a message which will reuqest to open a channel for // authentication on the given channelID, with the given cookie -func (mb *MessageBuilder) OpenAuthenticationChannel(channelID int32, clientCookie [16]byte) ([]byte, error) { +func (mb *MessageBuilder) OpenAuthenticationChannel(channelID int32, clientCookie [16]byte) []byte { oc := &Protocol_Data_Control.OpenChannel{ ChannelIdentifier: proto.Int32(channelID), ChannelType: proto.String("im.ricochet.auth.hidden-service"), } err := proto.SetExtension(oc, Protocol_Data_AuthHiddenService.E_ClientCookie, clientCookie[:]) - utils.CheckError(err) + CheckError(err) pc := &Protocol_Data_Control.Packet{ OpenChannel: oc, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } // Proof constructs a proof message with the given public key and signature. -func (mb *MessageBuilder) Proof(publicKeyBytes []byte, signatureBytes []byte) ([]byte, error) { +func (mb *MessageBuilder) Proof(publicKeyBytes []byte, signatureBytes []byte) []byte { proof := &Protocol_Data_AuthHiddenService.Proof{ PublicKey: publicKeyBytes, Signature: signatureBytes, @@ -155,11 +171,13 @@ func (mb *MessageBuilder) Proof(publicKeyBytes []byte, signatureBytes []byte) ([ Result: nil, } - return proto.Marshal(ahsPacket) + ret, err := proto.Marshal(ahsPacket) + CheckError(err) + return ret } // AuthResult constructs a response to a Proof -func (mb *MessageBuilder) AuthResult(accepted bool, isKnownContact bool) ([]byte, error) { +func (mb *MessageBuilder) AuthResult(accepted bool, isKnownContact bool) []byte { // Construct a Result Message result := &Protocol_Data_AuthHiddenService.Result{ Accepted: proto.Bool(accepted), @@ -171,29 +189,74 @@ func (mb *MessageBuilder) AuthResult(accepted bool, isKnownContact bool) ([]byte Result: result, } - return proto.Marshal(ahsPacket) + ret, err := proto.Marshal(ahsPacket) + CheckError(err) + return ret } // ChatMessage constructs a chat message with the given content. -func (mb *MessageBuilder) ChatMessage(message string, messageID int32) ([]byte, error) { +func (mb *MessageBuilder) ChatMessage(message string, messageID uint32) []byte { cm := &Protocol_Data_Chat.ChatMessage{ - MessageId: proto.Uint32(uint32(messageID)), + MessageId: proto.Uint32(messageID), MessageText: proto.String(message), } chatPacket := &Protocol_Data_Chat.Packet{ ChatMessage: cm, } - return proto.Marshal(chatPacket) + ret, err := proto.Marshal(chatPacket) + CheckError(err) + return ret } // AckChatMessage constructs a chat message acknowledgement. -func (mb *MessageBuilder) AckChatMessage(messageID int32) ([]byte, error) { +func (mb *MessageBuilder) AckChatMessage(messageID uint32) []byte { cr := &Protocol_Data_Chat.ChatAcknowledge{ - MessageId: proto.Uint32(uint32(messageID)), + MessageId: proto.Uint32(messageID), Accepted: proto.Bool(true), } pc := &Protocol_Data_Chat.Packet{ ChatAcknowledge: cr, } - return proto.Marshal(pc) + ret, err := proto.Marshal(pc) + CheckError(err) + return ret +} + +// KeepAlive ... +func (mb *MessageBuilder) KeepAlive(responseRequested bool) []byte { + ka := &Protocol_Data_Control.KeepAlive{ + ResponseRequested: proto.Bool(responseRequested), + } + pc := &Protocol_Data_Control.Packet{ + KeepAlive: ka, + } + ret, err := proto.Marshal(pc) + CheckError(err) + return ret +} + +// EnableFeatures ... +func (mb *MessageBuilder) EnableFeatures(features []string) []byte { + ef := &Protocol_Data_Control.EnableFeatures{ + Feature: features, + } + pc := &Protocol_Data_Control.Packet{ + EnableFeatures: ef, + } + ret, err := proto.Marshal(pc) + CheckError(err) + return ret +} + +// FeaturesEnabled ... +func (mb *MessageBuilder) FeaturesEnabled(features []string) []byte { + fe := &Protocol_Data_Control.FeaturesEnabled{ + Feature: features, + } + pc := &Protocol_Data_Control.Packet{ + FeaturesEnabled: fe, + } + ret, err := proto.Marshal(pc) + CheckError(err) + return ret } diff --git a/utils/messagebuilder_test.go b/utils/messagebuilder_test.go new file mode 100644 index 0000000..cdd9f27 --- /dev/null +++ b/utils/messagebuilder_test.go @@ -0,0 +1,74 @@ +package utils + +import ( + "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/wire/control" + "testing" +) + +func TestOpenChatChannel(t *testing.T) { + messageBuilder := new(MessageBuilder) + messageBuilder.OpenChannel(1, "im.ricochet.chat") + // TODO: More Indepth Test Of Output +} + +func TestOpenContactRequestChannel(t *testing.T) { + messageBuilder := new(MessageBuilder) + messageBuilder.OpenContactRequestChannel(3, "Nickname", "Message") + // TODO: More Indepth Test Of Output +} + +func TestOpenAuthenticationChannel(t *testing.T) { + messageBuilder := new(MessageBuilder) + messageBuilder.OpenAuthenticationChannel(1, [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + // TODO: More Indepth Test Of Output +} + +func TestChatMessage(t *testing.T) { + messageBuilder := new(MessageBuilder) + messageBuilder.ChatMessage("Hello World", 0) + // TODO: More Indepth Test Of Output +} + +func TestKeepAlive(t *testing.T) { + messageBuilder := new(MessageBuilder) + raw := messageBuilder.KeepAlive(true) + res := new(Protocol_Data_Control.Packet) + err := proto.Unmarshal(raw, res) + if err != nil || res.GetKeepAlive() == nil || !res.GetKeepAlive().GetResponseRequested() { + t.Errorf("Decoding Keep Alive Packet failed or no response requested: %v %v", err, res) + } +} + +func TestFeaturesEnabled(t *testing.T) { + messageBuilder := new(MessageBuilder) + features := []string{"feature1", "feature2"} + raw := messageBuilder.FeaturesEnabled(features) + res := new(Protocol_Data_Control.Packet) + err := proto.Unmarshal(raw, res) + if err != nil || res.GetFeaturesEnabled() == nil { + t.Errorf("Decoding FeaturesEnabled Packet failed: %v %v", err, res) + } + + for i, v := range res.GetFeaturesEnabled().GetFeature() { + if v != features[i] { + t.Errorf("Requested Features do not match %v %v", res.GetFeaturesEnabled().GetFeature(), features) + } + } +} + +func TestEnableFeatures(t *testing.T) { + messageBuilder := new(MessageBuilder) + features := []string{"feature1", "feature2"} + raw := messageBuilder.EnableFeatures(features) + res := new(Protocol_Data_Control.Packet) + err := proto.Unmarshal(raw, res) + if err != nil || res.GetEnableFeatures() == nil { + t.Errorf("Decoding EnableFeatures Packet failed: %v %v", err, res) + } + for i, v := range res.GetEnableFeatures().GetFeature() { + if v != features[i] { + t.Errorf("Requested Features do not match %v %v", res.GetFeaturesEnabled().GetFeature(), features) + } + } +} diff --git a/utils/networkresolver.go b/utils/networkresolver.go index e2873d8..faf2602 100644 --- a/utils/networkresolver.go +++ b/utils/networkresolver.go @@ -47,6 +47,6 @@ func (nr *NetworkResolver) Resolve(hostname string) (net.Conn, string, error) { if err != nil { return nil, "", errors.New("Cannot Dial Remote Ricochet Address") } - //conn.SetDeadline(time.Now().Add(5 * time.Second)) + return conn, resolvedHostname, nil } diff --git a/vendor/github.com/golang/protobuf/AUTHORS b/vendor/github.com/golang/protobuf/AUTHORS new file mode 100644 index 0000000..15167cd --- /dev/null +++ b/vendor/github.com/golang/protobuf/AUTHORS @@ -0,0 +1,3 @@ +# This source code refers to The Go Authors for copyright purposes. +# The master list of authors is in the main Go distribution, +# visible at http://tip.golang.org/AUTHORS. diff --git a/vendor/github.com/golang/protobuf/CONTRIBUTORS b/vendor/github.com/golang/protobuf/CONTRIBUTORS new file mode 100644 index 0000000..1c4577e --- /dev/null +++ b/vendor/github.com/golang/protobuf/CONTRIBUTORS @@ -0,0 +1,3 @@ +# This source code was written by the Go contributors. +# The master list of contributors is in the main Go distribution, +# visible at http://tip.golang.org/CONTRIBUTORS. diff --git a/vendor/github.com/golang/protobuf/LICENSE b/vendor/github.com/golang/protobuf/LICENSE new file mode 100644 index 0000000..1b1b192 --- /dev/null +++ b/vendor/github.com/golang/protobuf/LICENSE @@ -0,0 +1,31 @@ +Go support for Protocol Buffers - Google's data interchange format + +Copyright 2010 The Go Authors. All rights reserved. +https://github.com/golang/protobuf + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/vendor/github.com/golang/protobuf/proto/Makefile b/vendor/github.com/golang/protobuf/proto/Makefile new file mode 100644 index 0000000..e2e0651 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/Makefile @@ -0,0 +1,43 @@ +# Go support for Protocol Buffers - Google's data interchange format +# +# Copyright 2010 The Go Authors. All rights reserved. +# https://github.com/golang/protobuf +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +install: + go install + +test: install generate-test-pbs + go test + + +generate-test-pbs: + make install + make -C testdata + protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata,Mgoogle/protobuf/any.proto=github.com/golang/protobuf/ptypes/any:. proto3_proto/proto3.proto + make diff --git a/vendor/github.com/golang/protobuf/proto/clone.go b/vendor/github.com/golang/protobuf/proto/clone.go new file mode 100644 index 0000000..e392575 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/clone.go @@ -0,0 +1,229 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Protocol buffer deep copy and merge. +// TODO: RawMessage. + +package proto + +import ( + "log" + "reflect" + "strings" +) + +// Clone returns a deep copy of a protocol buffer. +func Clone(pb Message) Message { + in := reflect.ValueOf(pb) + if in.IsNil() { + return pb + } + + out := reflect.New(in.Type().Elem()) + // out is empty so a merge is a deep copy. + mergeStruct(out.Elem(), in.Elem()) + return out.Interface().(Message) +} + +// Merge merges src into dst. +// Required and optional fields that are set in src will be set to that value in dst. +// Elements of repeated fields will be appended. +// Merge panics if src and dst are not the same type, or if dst is nil. +func Merge(dst, src Message) { + in := reflect.ValueOf(src) + out := reflect.ValueOf(dst) + if out.IsNil() { + panic("proto: nil destination") + } + if in.Type() != out.Type() { + // Explicit test prior to mergeStruct so that mistyped nils will fail + panic("proto: type mismatch") + } + if in.IsNil() { + // Merging nil into non-nil is a quiet no-op + return + } + mergeStruct(out.Elem(), in.Elem()) +} + +func mergeStruct(out, in reflect.Value) { + sprop := GetProperties(in.Type()) + for i := 0; i < in.NumField(); i++ { + f := in.Type().Field(i) + if strings.HasPrefix(f.Name, "XXX_") { + continue + } + mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i]) + } + + if emIn, ok := extendable(in.Addr().Interface()); ok { + emOut, _ := extendable(out.Addr().Interface()) + mIn, muIn := emIn.extensionsRead() + if mIn != nil { + mOut := emOut.extensionsWrite() + muIn.Lock() + mergeExtension(mOut, mIn) + muIn.Unlock() + } + } + + uf := in.FieldByName("XXX_unrecognized") + if !uf.IsValid() { + return + } + uin := uf.Bytes() + if len(uin) > 0 { + out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...)) + } +} + +// mergeAny performs a merge between two values of the same type. +// viaPtr indicates whether the values were indirected through a pointer (implying proto2). +// prop is set if this is a struct field (it may be nil). +func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) { + if in.Type() == protoMessageType { + if !in.IsNil() { + if out.IsNil() { + out.Set(reflect.ValueOf(Clone(in.Interface().(Message)))) + } else { + Merge(out.Interface().(Message), in.Interface().(Message)) + } + } + return + } + switch in.Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, + reflect.String, reflect.Uint32, reflect.Uint64: + if !viaPtr && isProto3Zero(in) { + return + } + out.Set(in) + case reflect.Interface: + // Probably a oneof field; copy non-nil values. + if in.IsNil() { + return + } + // Allocate destination if it is not set, or set to a different type. + // Otherwise we will merge as normal. + if out.IsNil() || out.Elem().Type() != in.Elem().Type() { + out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T) + } + mergeAny(out.Elem(), in.Elem(), false, nil) + case reflect.Map: + if in.Len() == 0 { + return + } + if out.IsNil() { + out.Set(reflect.MakeMap(in.Type())) + } + // For maps with value types of *T or []byte we need to deep copy each value. + elemKind := in.Type().Elem().Kind() + for _, key := range in.MapKeys() { + var val reflect.Value + switch elemKind { + case reflect.Ptr: + val = reflect.New(in.Type().Elem().Elem()) + mergeAny(val, in.MapIndex(key), false, nil) + case reflect.Slice: + val = in.MapIndex(key) + val = reflect.ValueOf(append([]byte{}, val.Bytes()...)) + default: + val = in.MapIndex(key) + } + out.SetMapIndex(key, val) + } + case reflect.Ptr: + if in.IsNil() { + return + } + if out.IsNil() { + out.Set(reflect.New(in.Elem().Type())) + } + mergeAny(out.Elem(), in.Elem(), true, nil) + case reflect.Slice: + if in.IsNil() { + return + } + if in.Type().Elem().Kind() == reflect.Uint8 { + // []byte is a scalar bytes field, not a repeated field. + + // Edge case: if this is in a proto3 message, a zero length + // bytes field is considered the zero value, and should not + // be merged. + if prop != nil && prop.proto3 && in.Len() == 0 { + return + } + + // Make a deep copy. + // Append to []byte{} instead of []byte(nil) so that we never end up + // with a nil result. + out.SetBytes(append([]byte{}, in.Bytes()...)) + return + } + n := in.Len() + if out.IsNil() { + out.Set(reflect.MakeSlice(in.Type(), 0, n)) + } + switch in.Type().Elem().Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, + reflect.String, reflect.Uint32, reflect.Uint64: + out.Set(reflect.AppendSlice(out, in)) + default: + for i := 0; i < n; i++ { + x := reflect.Indirect(reflect.New(in.Type().Elem())) + mergeAny(x, in.Index(i), false, nil) + out.Set(reflect.Append(out, x)) + } + } + case reflect.Struct: + mergeStruct(out, in) + default: + // unknown type, so not a protocol buffer + log.Printf("proto: don't know how to copy %v", in) + } +} + +func mergeExtension(out, in map[int32]Extension) { + for extNum, eIn := range in { + eOut := Extension{desc: eIn.desc} + if eIn.value != nil { + v := reflect.New(reflect.TypeOf(eIn.value)).Elem() + mergeAny(v, reflect.ValueOf(eIn.value), false, nil) + eOut.value = v.Interface() + } + if eIn.enc != nil { + eOut.enc = make([]byte, len(eIn.enc)) + copy(eOut.enc, eIn.enc) + } + + out[extNum] = eOut + } +} diff --git a/vendor/github.com/golang/protobuf/proto/decode.go b/vendor/github.com/golang/protobuf/proto/decode.go new file mode 100644 index 0000000..aa20729 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/decode.go @@ -0,0 +1,970 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for decoding protocol buffer data to construct in-memory representations. + */ + +import ( + "errors" + "fmt" + "io" + "os" + "reflect" +) + +// errOverflow is returned when an integer is too large to be represented. +var errOverflow = errors.New("proto: integer overflow") + +// ErrInternalBadWireType is returned by generated code when an incorrect +// wire type is encountered. It does not get returned to user code. +var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof") + +// The fundamental decoders that interpret bytes on the wire. +// Those that take integer types all return uint64 and are +// therefore of type valueDecoder. + +// DecodeVarint reads a varint-encoded integer from the slice. +// It returns the integer and the number of bytes consumed, or +// zero if there is not enough. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func DecodeVarint(buf []byte) (x uint64, n int) { + for shift := uint(0); shift < 64; shift += 7 { + if n >= len(buf) { + return 0, 0 + } + b := uint64(buf[n]) + n++ + x |= (b & 0x7F) << shift + if (b & 0x80) == 0 { + return x, n + } + } + + // The number is too large to represent in a 64-bit value. + return 0, 0 +} + +func (p *Buffer) decodeVarintSlow() (x uint64, err error) { + i := p.index + l := len(p.buf) + + for shift := uint(0); shift < 64; shift += 7 { + if i >= l { + err = io.ErrUnexpectedEOF + return + } + b := p.buf[i] + i++ + x |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + p.index = i + return + } + } + + // The number is too large to represent in a 64-bit value. + err = errOverflow + return +} + +// DecodeVarint reads a varint-encoded integer from the Buffer. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func (p *Buffer) DecodeVarint() (x uint64, err error) { + i := p.index + buf := p.buf + + if i >= len(buf) { + return 0, io.ErrUnexpectedEOF + } else if buf[i] < 0x80 { + p.index++ + return uint64(buf[i]), nil + } else if len(buf)-i < 10 { + return p.decodeVarintSlow() + } + + var b uint64 + // we already checked the first byte + x = uint64(buf[i]) - 0x80 + i++ + + b = uint64(buf[i]) + i++ + x += b << 7 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 7 + + b = uint64(buf[i]) + i++ + x += b << 14 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 14 + + b = uint64(buf[i]) + i++ + x += b << 21 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 21 + + b = uint64(buf[i]) + i++ + x += b << 28 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 28 + + b = uint64(buf[i]) + i++ + x += b << 35 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 35 + + b = uint64(buf[i]) + i++ + x += b << 42 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 42 + + b = uint64(buf[i]) + i++ + x += b << 49 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 49 + + b = uint64(buf[i]) + i++ + x += b << 56 + if b&0x80 == 0 { + goto done + } + x -= 0x80 << 56 + + b = uint64(buf[i]) + i++ + x += b << 63 + if b&0x80 == 0 { + goto done + } + // x -= 0x80 << 63 // Always zero. + + return 0, errOverflow + +done: + p.index = i + return x, nil +} + +// DecodeFixed64 reads a 64-bit integer from the Buffer. +// This is the format for the +// fixed64, sfixed64, and double protocol buffer types. +func (p *Buffer) DecodeFixed64() (x uint64, err error) { + // x, err already 0 + i := p.index + 8 + if i < 0 || i > len(p.buf) { + err = io.ErrUnexpectedEOF + return + } + p.index = i + + x = uint64(p.buf[i-8]) + x |= uint64(p.buf[i-7]) << 8 + x |= uint64(p.buf[i-6]) << 16 + x |= uint64(p.buf[i-5]) << 24 + x |= uint64(p.buf[i-4]) << 32 + x |= uint64(p.buf[i-3]) << 40 + x |= uint64(p.buf[i-2]) << 48 + x |= uint64(p.buf[i-1]) << 56 + return +} + +// DecodeFixed32 reads a 32-bit integer from the Buffer. +// This is the format for the +// fixed32, sfixed32, and float protocol buffer types. +func (p *Buffer) DecodeFixed32() (x uint64, err error) { + // x, err already 0 + i := p.index + 4 + if i < 0 || i > len(p.buf) { + err = io.ErrUnexpectedEOF + return + } + p.index = i + + x = uint64(p.buf[i-4]) + x |= uint64(p.buf[i-3]) << 8 + x |= uint64(p.buf[i-2]) << 16 + x |= uint64(p.buf[i-1]) << 24 + return +} + +// DecodeZigzag64 reads a zigzag-encoded 64-bit integer +// from the Buffer. +// This is the format used for the sint64 protocol buffer type. +func (p *Buffer) DecodeZigzag64() (x uint64, err error) { + x, err = p.DecodeVarint() + if err != nil { + return + } + x = (x >> 1) ^ uint64((int64(x&1)<<63)>>63) + return +} + +// DecodeZigzag32 reads a zigzag-encoded 32-bit integer +// from the Buffer. +// This is the format used for the sint32 protocol buffer type. +func (p *Buffer) DecodeZigzag32() (x uint64, err error) { + x, err = p.DecodeVarint() + if err != nil { + return + } + x = uint64((uint32(x) >> 1) ^ uint32((int32(x&1)<<31)>>31)) + return +} + +// These are not ValueDecoders: they produce an array of bytes or a string. +// bytes, embedded messages + +// DecodeRawBytes reads a count-delimited byte buffer from the Buffer. +// This is the format used for the bytes protocol buffer +// type and for embedded messages. +func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) { + n, err := p.DecodeVarint() + if err != nil { + return nil, err + } + + nb := int(n) + if nb < 0 { + return nil, fmt.Errorf("proto: bad byte length %d", nb) + } + end := p.index + nb + if end < p.index || end > len(p.buf) { + return nil, io.ErrUnexpectedEOF + } + + if !alloc { + // todo: check if can get more uses of alloc=false + buf = p.buf[p.index:end] + p.index += nb + return + } + + buf = make([]byte, nb) + copy(buf, p.buf[p.index:]) + p.index += nb + return +} + +// DecodeStringBytes reads an encoded string from the Buffer. +// This is the format used for the proto2 string type. +func (p *Buffer) DecodeStringBytes() (s string, err error) { + buf, err := p.DecodeRawBytes(false) + if err != nil { + return + } + return string(buf), nil +} + +// Skip the next item in the buffer. Its wire type is decoded and presented as an argument. +// If the protocol buffer has extensions, and the field matches, add it as an extension. +// Otherwise, if the XXX_unrecognized field exists, append the skipped data there. +func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error { + oi := o.index + + err := o.skip(t, tag, wire) + if err != nil { + return err + } + + if !unrecField.IsValid() { + return nil + } + + ptr := structPointer_Bytes(base, unrecField) + + // Add the skipped field to struct field + obuf := o.buf + + o.buf = *ptr + o.EncodeVarint(uint64(tag<<3 | wire)) + *ptr = append(o.buf, obuf[oi:o.index]...) + + o.buf = obuf + + return nil +} + +// Skip the next item in the buffer. Its wire type is decoded and presented as an argument. +func (o *Buffer) skip(t reflect.Type, tag, wire int) error { + + var u uint64 + var err error + + switch wire { + case WireVarint: + _, err = o.DecodeVarint() + case WireFixed64: + _, err = o.DecodeFixed64() + case WireBytes: + _, err = o.DecodeRawBytes(false) + case WireFixed32: + _, err = o.DecodeFixed32() + case WireStartGroup: + for { + u, err = o.DecodeVarint() + if err != nil { + break + } + fwire := int(u & 0x7) + if fwire == WireEndGroup { + break + } + ftag := int(u >> 3) + err = o.skip(t, ftag, fwire) + if err != nil { + break + } + } + default: + err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t) + } + return err +} + +// Unmarshaler is the interface representing objects that can +// unmarshal themselves. The method should reset the receiver before +// decoding starts. The argument points to data that may be +// overwritten, so implementations should not keep references to the +// buffer. +type Unmarshaler interface { + Unmarshal([]byte) error +} + +// Unmarshal parses the protocol buffer representation in buf and places the +// decoded result in pb. If the struct underlying pb does not match +// the data in buf, the results can be unpredictable. +// +// Unmarshal resets pb before starting to unmarshal, so any +// existing data in pb is always removed. Use UnmarshalMerge +// to preserve and append to existing data. +func Unmarshal(buf []byte, pb Message) error { + pb.Reset() + return UnmarshalMerge(buf, pb) +} + +// UnmarshalMerge parses the protocol buffer representation in buf and +// writes the decoded result to pb. If the struct underlying pb does not match +// the data in buf, the results can be unpredictable. +// +// UnmarshalMerge merges into existing data in pb. +// Most code should use Unmarshal instead. +func UnmarshalMerge(buf []byte, pb Message) error { + // If the object can unmarshal itself, let it. + if u, ok := pb.(Unmarshaler); ok { + return u.Unmarshal(buf) + } + return NewBuffer(buf).Unmarshal(pb) +} + +// DecodeMessage reads a count-delimited message from the Buffer. +func (p *Buffer) DecodeMessage(pb Message) error { + enc, err := p.DecodeRawBytes(false) + if err != nil { + return err + } + return NewBuffer(enc).Unmarshal(pb) +} + +// DecodeGroup reads a tag-delimited group from the Buffer. +func (p *Buffer) DecodeGroup(pb Message) error { + typ, base, err := getbase(pb) + if err != nil { + return err + } + return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base) +} + +// Unmarshal parses the protocol buffer representation in the +// Buffer and places the decoded result in pb. If the struct +// underlying pb does not match the data in the buffer, the results can be +// unpredictable. +// +// Unlike proto.Unmarshal, this does not reset pb before starting to unmarshal. +func (p *Buffer) Unmarshal(pb Message) error { + // If the object can unmarshal itself, let it. + if u, ok := pb.(Unmarshaler); ok { + err := u.Unmarshal(p.buf[p.index:]) + p.index = len(p.buf) + return err + } + + typ, base, err := getbase(pb) + if err != nil { + return err + } + + err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base) + + if collectStats { + stats.Decode++ + } + + return err +} + +// unmarshalType does the work of unmarshaling a structure. +func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error { + var state errorState + required, reqFields := prop.reqCount, uint64(0) + + var err error + for err == nil && o.index < len(o.buf) { + oi := o.index + var u uint64 + u, err = o.DecodeVarint() + if err != nil { + break + } + wire := int(u & 0x7) + if wire == WireEndGroup { + if is_group { + if required > 0 { + // Not enough information to determine the exact field. + // (See below.) + return &RequiredNotSetError{"{Unknown}"} + } + return nil // input is satisfied + } + return fmt.Errorf("proto: %s: wiretype end group for non-group", st) + } + tag := int(u >> 3) + if tag <= 0 { + return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire) + } + fieldnum, ok := prop.decoderTags.get(tag) + if !ok { + // Maybe it's an extension? + if prop.extendable { + if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) { + if err = o.skip(st, tag, wire); err == nil { + extmap := e.extensionsWrite() + ext := extmap[int32(tag)] // may be missing + ext.enc = append(ext.enc, o.buf[oi:o.index]...) + extmap[int32(tag)] = ext + } + continue + } + } + // Maybe it's a oneof? + if prop.oneofUnmarshaler != nil { + m := structPointer_Interface(base, st).(Message) + // First return value indicates whether tag is a oneof field. + ok, err = prop.oneofUnmarshaler(m, tag, wire, o) + if err == ErrInternalBadWireType { + // Map the error to something more descriptive. + // Do the formatting here to save generated code space. + err = fmt.Errorf("bad wiretype for oneof field in %T", m) + } + if ok { + continue + } + } + err = o.skipAndSave(st, tag, wire, base, prop.unrecField) + continue + } + p := prop.Prop[fieldnum] + + if p.dec == nil { + fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name) + continue + } + dec := p.dec + if wire != WireStartGroup && wire != p.WireType { + if wire == WireBytes && p.packedDec != nil { + // a packable field + dec = p.packedDec + } else { + err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType) + continue + } + } + decErr := dec(o, p, base) + if decErr != nil && !state.shouldContinue(decErr, p) { + err = decErr + } + if err == nil && p.Required { + // Successfully decoded a required field. + if tag <= 64 { + // use bitmap for fields 1-64 to catch field reuse. + var mask uint64 = 1 << uint64(tag-1) + if reqFields&mask == 0 { + // new required field + reqFields |= mask + required-- + } + } else { + // This is imprecise. It can be fooled by a required field + // with a tag > 64 that is encoded twice; that's very rare. + // A fully correct implementation would require allocating + // a data structure, which we would like to avoid. + required-- + } + } + } + if err == nil { + if is_group { + return io.ErrUnexpectedEOF + } + if state.err != nil { + return state.err + } + if required > 0 { + // Not enough information to determine the exact field. If we use extra + // CPU, we could determine the field only if the missing required field + // has a tag <= 64 and we check reqFields. + return &RequiredNotSetError{"{Unknown}"} + } + } + return err +} + +// Individual type decoders +// For each, +// u is the decoded value, +// v is a pointer to the field (pointer) in the struct + +// Sizes of the pools to allocate inside the Buffer. +// The goal is modest amortization and allocation +// on at least 16-byte boundaries. +const ( + boolPoolSize = 16 + uint32PoolSize = 8 + uint64PoolSize = 4 +) + +// Decode a bool. +func (o *Buffer) dec_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + if len(o.bools) == 0 { + o.bools = make([]bool, boolPoolSize) + } + o.bools[0] = u != 0 + *structPointer_Bool(base, p.field) = &o.bools[0] + o.bools = o.bools[1:] + return nil +} + +func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + *structPointer_BoolVal(base, p.field) = u != 0 + return nil +} + +// Decode an int32. +func (o *Buffer) dec_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word32_Set(structPointer_Word32(base, p.field), o, uint32(u)) + return nil +} + +func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u)) + return nil +} + +// Decode an int64. +func (o *Buffer) dec_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word64_Set(structPointer_Word64(base, p.field), o, u) + return nil +} + +func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word64Val_Set(structPointer_Word64Val(base, p.field), o, u) + return nil +} + +// Decode a string. +func (o *Buffer) dec_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + *structPointer_String(base, p.field) = &s + return nil +} + +func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + *structPointer_StringVal(base, p.field) = s + return nil +} + +// Decode a slice of bytes ([]byte). +func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error { + b, err := o.DecodeRawBytes(true) + if err != nil { + return err + } + *structPointer_Bytes(base, p.field) = b + return nil +} + +// Decode a slice of bools ([]bool). +func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + v := structPointer_BoolSlice(base, p.field) + *v = append(*v, u != 0) + return nil +} + +// Decode a slice of bools ([]bool) in packed format. +func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error { + v := structPointer_BoolSlice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded bools + fin := o.index + nb + if fin < o.index { + return errOverflow + } + + y := *v + for o.index < fin { + u, err := p.valDec(o) + if err != nil { + return err + } + y = append(y, u != 0) + } + + *v = y + return nil +} + +// Decode a slice of int32s ([]int32). +func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + structPointer_Word32Slice(base, p.field).Append(uint32(u)) + return nil +} + +// Decode a slice of int32s ([]int32) in packed format. +func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error { + v := structPointer_Word32Slice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded int32s + + fin := o.index + nb + if fin < o.index { + return errOverflow + } + for o.index < fin { + u, err := p.valDec(o) + if err != nil { + return err + } + v.Append(uint32(u)) + } + return nil +} + +// Decode a slice of int64s ([]int64). +func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + + structPointer_Word64Slice(base, p.field).Append(u) + return nil +} + +// Decode a slice of int64s ([]int64) in packed format. +func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error { + v := structPointer_Word64Slice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded int64s + + fin := o.index + nb + if fin < o.index { + return errOverflow + } + for o.index < fin { + u, err := p.valDec(o) + if err != nil { + return err + } + v.Append(u) + } + return nil +} + +// Decode a slice of strings ([]string). +func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + v := structPointer_StringSlice(base, p.field) + *v = append(*v, s) + return nil +} + +// Decode a slice of slice of bytes ([][]byte). +func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error { + b, err := o.DecodeRawBytes(true) + if err != nil { + return err + } + v := structPointer_BytesSlice(base, p.field) + *v = append(*v, b) + return nil +} + +// Decode a map field. +func (o *Buffer) dec_new_map(p *Properties, base structPointer) error { + raw, err := o.DecodeRawBytes(false) + if err != nil { + return err + } + oi := o.index // index at the end of this map entry + o.index -= len(raw) // move buffer back to start of map entry + + mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V + if mptr.Elem().IsNil() { + mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem())) + } + v := mptr.Elem() // map[K]V + + // Prepare addressable doubly-indirect placeholders for the key and value types. + // See enc_new_map for why. + keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K + keybase := toStructPointer(keyptr.Addr()) // **K + + var valbase structPointer + var valptr reflect.Value + switch p.mtype.Elem().Kind() { + case reflect.Slice: + // []byte + var dummy []byte + valptr = reflect.ValueOf(&dummy) // *[]byte + valbase = toStructPointer(valptr) // *[]byte + case reflect.Ptr: + // message; valptr is **Msg; need to allocate the intermediate pointer + valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V + valptr.Set(reflect.New(valptr.Type().Elem())) + valbase = toStructPointer(valptr) + default: + // everything else + valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V + valbase = toStructPointer(valptr.Addr()) // **V + } + + // Decode. + // This parses a restricted wire format, namely the encoding of a message + // with two fields. See enc_new_map for the format. + for o.index < oi { + // tagcode for key and value properties are always a single byte + // because they have tags 1 and 2. + tagcode := o.buf[o.index] + o.index++ + switch tagcode { + case p.mkeyprop.tagcode[0]: + if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil { + return err + } + case p.mvalprop.tagcode[0]: + if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil { + return err + } + default: + // TODO: Should we silently skip this instead? + return fmt.Errorf("proto: bad map data tag %d", raw[0]) + } + } + keyelem, valelem := keyptr.Elem(), valptr.Elem() + if !keyelem.IsValid() { + keyelem = reflect.Zero(p.mtype.Key()) + } + if !valelem.IsValid() { + valelem = reflect.Zero(p.mtype.Elem()) + } + + v.SetMapIndex(keyelem, valelem) + return nil +} + +// Decode a group. +func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error { + bas := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(bas) { + // allocate new nested message + bas = toStructPointer(reflect.New(p.stype)) + structPointer_SetStructPointer(base, p.field, bas) + } + return o.unmarshalType(p.stype, p.sprop, true, bas) +} + +// Decode an embedded message. +func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) { + raw, e := o.DecodeRawBytes(false) + if e != nil { + return e + } + + bas := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(bas) { + // allocate new nested message + bas = toStructPointer(reflect.New(p.stype)) + structPointer_SetStructPointer(base, p.field, bas) + } + + // If the object can unmarshal itself, let it. + if p.isUnmarshaler { + iv := structPointer_Interface(bas, p.stype) + return iv.(Unmarshaler).Unmarshal(raw) + } + + obuf := o.buf + oi := o.index + o.buf = raw + o.index = 0 + + err = o.unmarshalType(p.stype, p.sprop, false, bas) + o.buf = obuf + o.index = oi + + return err +} + +// Decode a slice of embedded messages. +func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error { + return o.dec_slice_struct(p, false, base) +} + +// Decode a slice of embedded groups. +func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error { + return o.dec_slice_struct(p, true, base) +} + +// Decode a slice of structs ([]*struct). +func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error { + v := reflect.New(p.stype) + bas := toStructPointer(v) + structPointer_StructPointerSlice(base, p.field).Append(bas) + + if is_group { + err := o.unmarshalType(p.stype, p.sprop, is_group, bas) + return err + } + + raw, err := o.DecodeRawBytes(false) + if err != nil { + return err + } + + // If the object can unmarshal itself, let it. + if p.isUnmarshaler { + iv := v.Interface() + return iv.(Unmarshaler).Unmarshal(raw) + } + + obuf := o.buf + oi := o.index + o.buf = raw + o.index = 0 + + err = o.unmarshalType(p.stype, p.sprop, is_group, bas) + + o.buf = obuf + o.index = oi + + return err +} diff --git a/vendor/github.com/golang/protobuf/proto/encode.go b/vendor/github.com/golang/protobuf/proto/encode.go new file mode 100644 index 0000000..2b30f84 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/encode.go @@ -0,0 +1,1362 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for encoding data into the wire format for protocol buffers. + */ + +import ( + "errors" + "fmt" + "reflect" + "sort" +) + +// RequiredNotSetError is the error returned if Marshal is called with +// a protocol buffer struct whose required fields have not +// all been initialized. It is also the error returned if Unmarshal is +// called with an encoded protocol buffer that does not include all the +// required fields. +// +// When printed, RequiredNotSetError reports the first unset required field in a +// message. If the field cannot be precisely determined, it is reported as +// "{Unknown}". +type RequiredNotSetError struct { + field string +} + +func (e *RequiredNotSetError) Error() string { + return fmt.Sprintf("proto: required field %q not set", e.field) +} + +var ( + // errRepeatedHasNil is the error returned if Marshal is called with + // a struct with a repeated field containing a nil element. + errRepeatedHasNil = errors.New("proto: repeated field has nil element") + + // errOneofHasNil is the error returned if Marshal is called with + // a struct with a oneof field containing a nil element. + errOneofHasNil = errors.New("proto: oneof field has nil value") + + // ErrNil is the error returned if Marshal is called with nil. + ErrNil = errors.New("proto: Marshal called with nil") + + // ErrTooLarge is the error returned if Marshal is called with a + // message that encodes to >2GB. + ErrTooLarge = errors.New("proto: message encodes to over 2 GB") +) + +// The fundamental encoders that put bytes on the wire. +// Those that take integer types all accept uint64 and are +// therefore of type valueEncoder. + +const maxVarintBytes = 10 // maximum length of a varint + +// maxMarshalSize is the largest allowed size of an encoded protobuf, +// since C++ and Java use signed int32s for the size. +const maxMarshalSize = 1<<31 - 1 + +// EncodeVarint returns the varint encoding of x. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +// Not used by the package itself, but helpful to clients +// wishing to use the same encoding. +func EncodeVarint(x uint64) []byte { + var buf [maxVarintBytes]byte + var n int + for n = 0; x > 127; n++ { + buf[n] = 0x80 | uint8(x&0x7F) + x >>= 7 + } + buf[n] = uint8(x) + n++ + return buf[0:n] +} + +// EncodeVarint writes a varint-encoded integer to the Buffer. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func (p *Buffer) EncodeVarint(x uint64) error { + for x >= 1<<7 { + p.buf = append(p.buf, uint8(x&0x7f|0x80)) + x >>= 7 + } + p.buf = append(p.buf, uint8(x)) + return nil +} + +// SizeVarint returns the varint encoding size of an integer. +func SizeVarint(x uint64) int { + return sizeVarint(x) +} + +func sizeVarint(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} + +// EncodeFixed64 writes a 64-bit integer to the Buffer. +// This is the format for the +// fixed64, sfixed64, and double protocol buffer types. +func (p *Buffer) EncodeFixed64(x uint64) error { + p.buf = append(p.buf, + uint8(x), + uint8(x>>8), + uint8(x>>16), + uint8(x>>24), + uint8(x>>32), + uint8(x>>40), + uint8(x>>48), + uint8(x>>56)) + return nil +} + +func sizeFixed64(x uint64) int { + return 8 +} + +// EncodeFixed32 writes a 32-bit integer to the Buffer. +// This is the format for the +// fixed32, sfixed32, and float protocol buffer types. +func (p *Buffer) EncodeFixed32(x uint64) error { + p.buf = append(p.buf, + uint8(x), + uint8(x>>8), + uint8(x>>16), + uint8(x>>24)) + return nil +} + +func sizeFixed32(x uint64) int { + return 4 +} + +// EncodeZigzag64 writes a zigzag-encoded 64-bit integer +// to the Buffer. +// This is the format used for the sint64 protocol buffer type. +func (p *Buffer) EncodeZigzag64(x uint64) error { + // use signed number to get arithmetic right shift. + return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +func sizeZigzag64(x uint64) int { + return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +// EncodeZigzag32 writes a zigzag-encoded 32-bit integer +// to the Buffer. +// This is the format used for the sint32 protocol buffer type. +func (p *Buffer) EncodeZigzag32(x uint64) error { + // use signed number to get arithmetic right shift. + return p.EncodeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31)))) +} + +func sizeZigzag32(x uint64) int { + return sizeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31)))) +} + +// EncodeRawBytes writes a count-delimited byte buffer to the Buffer. +// This is the format used for the bytes protocol buffer +// type and for embedded messages. +func (p *Buffer) EncodeRawBytes(b []byte) error { + p.EncodeVarint(uint64(len(b))) + p.buf = append(p.buf, b...) + return nil +} + +func sizeRawBytes(b []byte) int { + return sizeVarint(uint64(len(b))) + + len(b) +} + +// EncodeStringBytes writes an encoded string to the Buffer. +// This is the format used for the proto2 string type. +func (p *Buffer) EncodeStringBytes(s string) error { + p.EncodeVarint(uint64(len(s))) + p.buf = append(p.buf, s...) + return nil +} + +func sizeStringBytes(s string) int { + return sizeVarint(uint64(len(s))) + + len(s) +} + +// Marshaler is the interface representing objects that can marshal themselves. +type Marshaler interface { + Marshal() ([]byte, error) +} + +// Marshal takes the protocol buffer +// and encodes it into the wire format, returning the data. +func Marshal(pb Message) ([]byte, error) { + // Can the object marshal itself? + if m, ok := pb.(Marshaler); ok { + return m.Marshal() + } + p := NewBuffer(nil) + err := p.Marshal(pb) + if p.buf == nil && err == nil { + // Return a non-nil slice on success. + return []byte{}, nil + } + return p.buf, err +} + +// EncodeMessage writes the protocol buffer to the Buffer, +// prefixed by a varint-encoded length. +func (p *Buffer) EncodeMessage(pb Message) error { + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return ErrNil + } + if err == nil { + var state errorState + err = p.enc_len_struct(GetProperties(t.Elem()), base, &state) + } + return err +} + +// Marshal takes the protocol buffer +// and encodes it into the wire format, writing the result to the +// Buffer. +func (p *Buffer) Marshal(pb Message) error { + // Can the object marshal itself? + if m, ok := pb.(Marshaler); ok { + data, err := m.Marshal() + p.buf = append(p.buf, data...) + return err + } + + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return ErrNil + } + if err == nil { + err = p.enc_struct(GetProperties(t.Elem()), base) + } + + if collectStats { + (stats).Encode++ // Parens are to work around a goimports bug. + } + + if len(p.buf) > maxMarshalSize { + return ErrTooLarge + } + return err +} + +// Size returns the encoded size of a protocol buffer. +func Size(pb Message) (n int) { + // Can the object marshal itself? If so, Size is slow. + // TODO: add Size to Marshaler, or add a Sizer interface. + if m, ok := pb.(Marshaler); ok { + b, _ := m.Marshal() + return len(b) + } + + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return 0 + } + if err == nil { + n = size_struct(GetProperties(t.Elem()), base) + } + + if collectStats { + (stats).Size++ // Parens are to work around a goimports bug. + } + + return +} + +// Individual type encoders. + +// Encode a bool. +func (o *Buffer) enc_bool(p *Properties, base structPointer) error { + v := *structPointer_Bool(base, p.field) + if v == nil { + return ErrNil + } + x := 0 + if *v { + x = 1 + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func (o *Buffer) enc_proto3_bool(p *Properties, base structPointer) error { + v := *structPointer_BoolVal(base, p.field) + if !v { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, 1) + return nil +} + +func size_bool(p *Properties, base structPointer) int { + v := *structPointer_Bool(base, p.field) + if v == nil { + return 0 + } + return len(p.tagcode) + 1 // each bool takes exactly one byte +} + +func size_proto3_bool(p *Properties, base structPointer) int { + v := *structPointer_BoolVal(base, p.field) + if !v && !p.oneof { + return 0 + } + return len(p.tagcode) + 1 // each bool takes exactly one byte +} + +// Encode an int32. +func (o *Buffer) enc_int32(p *Properties, base structPointer) error { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return ErrNil + } + x := int32(word32_Get(v)) // permit sign extension to use full 64-bit range + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func (o *Buffer) enc_proto3_int32(p *Properties, base structPointer) error { + v := structPointer_Word32Val(base, p.field) + x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range + if x == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func size_int32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return 0 + } + x := int32(word32_Get(v)) // permit sign extension to use full 64-bit range + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +func size_proto3_int32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32Val(base, p.field) + x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range + if x == 0 && !p.oneof { + return 0 + } + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +// Encode a uint32. +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_uint32(p *Properties, base structPointer) error { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return ErrNil + } + x := word32_Get(v) + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func (o *Buffer) enc_proto3_uint32(p *Properties, base structPointer) error { + v := structPointer_Word32Val(base, p.field) + x := word32Val_Get(v) + if x == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func size_uint32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return 0 + } + x := word32_Get(v) + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +func size_proto3_uint32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32Val(base, p.field) + x := word32Val_Get(v) + if x == 0 && !p.oneof { + return 0 + } + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +// Encode an int64. +func (o *Buffer) enc_int64(p *Properties, base structPointer) error { + v := structPointer_Word64(base, p.field) + if word64_IsNil(v) { + return ErrNil + } + x := word64_Get(v) + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, x) + return nil +} + +func (o *Buffer) enc_proto3_int64(p *Properties, base structPointer) error { + v := structPointer_Word64Val(base, p.field) + x := word64Val_Get(v) + if x == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, x) + return nil +} + +func size_int64(p *Properties, base structPointer) (n int) { + v := structPointer_Word64(base, p.field) + if word64_IsNil(v) { + return 0 + } + x := word64_Get(v) + n += len(p.tagcode) + n += p.valSize(x) + return +} + +func size_proto3_int64(p *Properties, base structPointer) (n int) { + v := structPointer_Word64Val(base, p.field) + x := word64Val_Get(v) + if x == 0 && !p.oneof { + return 0 + } + n += len(p.tagcode) + n += p.valSize(x) + return +} + +// Encode a string. +func (o *Buffer) enc_string(p *Properties, base structPointer) error { + v := *structPointer_String(base, p.field) + if v == nil { + return ErrNil + } + x := *v + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(x) + return nil +} + +func (o *Buffer) enc_proto3_string(p *Properties, base structPointer) error { + v := *structPointer_StringVal(base, p.field) + if v == "" { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(v) + return nil +} + +func size_string(p *Properties, base structPointer) (n int) { + v := *structPointer_String(base, p.field) + if v == nil { + return 0 + } + x := *v + n += len(p.tagcode) + n += sizeStringBytes(x) + return +} + +func size_proto3_string(p *Properties, base structPointer) (n int) { + v := *structPointer_StringVal(base, p.field) + if v == "" && !p.oneof { + return 0 + } + n += len(p.tagcode) + n += sizeStringBytes(v) + return +} + +// All protocol buffer fields are nillable, but be careful. +func isNil(v reflect.Value) bool { + switch v.Kind() { + case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return v.IsNil() + } + return false +} + +// Encode a message struct. +func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error { + var state errorState + structp := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(structp) { + return ErrNil + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, err := m.Marshal() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(data) + return state.err + } + + o.buf = append(o.buf, p.tagcode...) + return o.enc_len_struct(p.sprop, structp, &state) +} + +func size_struct_message(p *Properties, base structPointer) int { + structp := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(structp) { + return 0 + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, _ := m.Marshal() + n0 := len(p.tagcode) + n1 := sizeRawBytes(data) + return n0 + n1 + } + + n0 := len(p.tagcode) + n1 := size_struct(p.sprop, structp) + n2 := sizeVarint(uint64(n1)) // size of encoded length + return n0 + n1 + n2 +} + +// Encode a group struct. +func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error { + var state errorState + b := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(b) { + return ErrNil + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) + err := o.enc_struct(p.sprop, b) + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup)) + return state.err +} + +func size_struct_group(p *Properties, base structPointer) (n int) { + b := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(b) { + return 0 + } + + n += sizeVarint(uint64((p.Tag << 3) | WireStartGroup)) + n += size_struct(p.sprop, b) + n += sizeVarint(uint64((p.Tag << 3) | WireEndGroup)) + return +} + +// Encode a slice of bools ([]bool). +func (o *Buffer) enc_slice_bool(p *Properties, base structPointer) error { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return ErrNil + } + for _, x := range s { + o.buf = append(o.buf, p.tagcode...) + v := uint64(0) + if x { + v = 1 + } + p.valEnc(o, v) + } + return nil +} + +func size_slice_bool(p *Properties, base structPointer) int { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return 0 + } + return l * (len(p.tagcode) + 1) // each bool takes exactly one byte +} + +// Encode a slice of bools ([]bool) in packed format. +func (o *Buffer) enc_slice_packed_bool(p *Properties, base structPointer) error { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(l)) // each bool takes exactly one byte + for _, x := range s { + v := uint64(0) + if x { + v = 1 + } + p.valEnc(o, v) + } + return nil +} + +func size_slice_packed_bool(p *Properties, base structPointer) (n int) { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return 0 + } + n += len(p.tagcode) + n += sizeVarint(uint64(l)) + n += l // each bool takes exactly one byte + return +} + +// Encode a slice of bytes ([]byte). +func (o *Buffer) enc_slice_byte(p *Properties, base structPointer) error { + s := *structPointer_Bytes(base, p.field) + if s == nil { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(s) + return nil +} + +func (o *Buffer) enc_proto3_slice_byte(p *Properties, base structPointer) error { + s := *structPointer_Bytes(base, p.field) + if len(s) == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(s) + return nil +} + +func size_slice_byte(p *Properties, base structPointer) (n int) { + s := *structPointer_Bytes(base, p.field) + if s == nil && !p.oneof { + return 0 + } + n += len(p.tagcode) + n += sizeRawBytes(s) + return +} + +func size_proto3_slice_byte(p *Properties, base structPointer) (n int) { + s := *structPointer_Bytes(base, p.field) + if len(s) == 0 && !p.oneof { + return 0 + } + n += len(p.tagcode) + n += sizeRawBytes(s) + return +} + +// Encode a slice of int32s ([]int32). +func (o *Buffer) enc_slice_int32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + p.valEnc(o, uint64(x)) + } + return nil +} + +func size_slice_int32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + n += p.valSize(uint64(x)) + } + return +} + +// Encode a slice of int32s ([]int32) in packed format. +func (o *Buffer) enc_slice_packed_int32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + p.valEnc(buf, uint64(x)) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_int32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + bufSize += p.valSize(uint64(x)) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of uint32s ([]uint32). +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_slice_uint32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + x := s.Index(i) + p.valEnc(o, uint64(x)) + } + return nil +} + +func size_slice_uint32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + x := s.Index(i) + n += p.valSize(uint64(x)) + } + return +} + +// Encode a slice of uint32s ([]uint32) in packed format. +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_slice_packed_uint32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + p.valEnc(buf, uint64(s.Index(i))) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_uint32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + bufSize += p.valSize(uint64(s.Index(i))) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of int64s ([]int64). +func (o *Buffer) enc_slice_int64(p *Properties, base structPointer) error { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, s.Index(i)) + } + return nil +} + +func size_slice_int64(p *Properties, base structPointer) (n int) { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + n += p.valSize(s.Index(i)) + } + return +} + +// Encode a slice of int64s ([]int64) in packed format. +func (o *Buffer) enc_slice_packed_int64(p *Properties, base structPointer) error { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + p.valEnc(buf, s.Index(i)) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_int64(p *Properties, base structPointer) (n int) { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + bufSize += p.valSize(s.Index(i)) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of slice of bytes ([][]byte). +func (o *Buffer) enc_slice_slice_byte(p *Properties, base structPointer) error { + ss := *structPointer_BytesSlice(base, p.field) + l := len(ss) + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(ss[i]) + } + return nil +} + +func size_slice_slice_byte(p *Properties, base structPointer) (n int) { + ss := *structPointer_BytesSlice(base, p.field) + l := len(ss) + if l == 0 { + return 0 + } + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + n += sizeRawBytes(ss[i]) + } + return +} + +// Encode a slice of strings ([]string). +func (o *Buffer) enc_slice_string(p *Properties, base structPointer) error { + ss := *structPointer_StringSlice(base, p.field) + l := len(ss) + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(ss[i]) + } + return nil +} + +func size_slice_string(p *Properties, base structPointer) (n int) { + ss := *structPointer_StringSlice(base, p.field) + l := len(ss) + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + n += sizeStringBytes(ss[i]) + } + return +} + +// Encode a slice of message structs ([]*struct). +func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) error { + var state errorState + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + for i := 0; i < l; i++ { + structp := s.Index(i) + if structPointer_IsNil(structp) { + return errRepeatedHasNil + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, err := m.Marshal() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(data) + continue + } + + o.buf = append(o.buf, p.tagcode...) + err := o.enc_len_struct(p.sprop, structp, &state) + if err != nil && !state.shouldContinue(err, nil) { + if err == ErrNil { + return errRepeatedHasNil + } + return err + } + } + return state.err +} + +func size_slice_struct_message(p *Properties, base structPointer) (n int) { + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + structp := s.Index(i) + if structPointer_IsNil(structp) { + return // return the size up to this point + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, _ := m.Marshal() + n += sizeRawBytes(data) + continue + } + + n0 := size_struct(p.sprop, structp) + n1 := sizeVarint(uint64(n0)) // size of encoded length + n += n0 + n1 + } + return +} + +// Encode a slice of group structs ([]*struct). +func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error { + var state errorState + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + for i := 0; i < l; i++ { + b := s.Index(i) + if structPointer_IsNil(b) { + return errRepeatedHasNil + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) + + err := o.enc_struct(p.sprop, b) + + if err != nil && !state.shouldContinue(err, nil) { + if err == ErrNil { + return errRepeatedHasNil + } + return err + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup)) + } + return state.err +} + +func size_slice_struct_group(p *Properties, base structPointer) (n int) { + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + n += l * sizeVarint(uint64((p.Tag<<3)|WireStartGroup)) + n += l * sizeVarint(uint64((p.Tag<<3)|WireEndGroup)) + for i := 0; i < l; i++ { + b := s.Index(i) + if structPointer_IsNil(b) { + return // return size up to this point + } + + n += size_struct(p.sprop, b) + } + return +} + +// Encode an extension map. +func (o *Buffer) enc_map(p *Properties, base structPointer) error { + exts := structPointer_ExtMap(base, p.field) + if err := encodeExtensionsMap(*exts); err != nil { + return err + } + + return o.enc_map_body(*exts) +} + +func (o *Buffer) enc_exts(p *Properties, base structPointer) error { + exts := structPointer_Extensions(base, p.field) + + v, mu := exts.extensionsRead() + if v == nil { + return nil + } + + mu.Lock() + defer mu.Unlock() + if err := encodeExtensionsMap(v); err != nil { + return err + } + + return o.enc_map_body(v) +} + +func (o *Buffer) enc_map_body(v map[int32]Extension) error { + // Fast-path for common cases: zero or one extensions. + if len(v) <= 1 { + for _, e := range v { + o.buf = append(o.buf, e.enc...) + } + return nil + } + + // Sort keys to provide a deterministic encoding. + keys := make([]int, 0, len(v)) + for k := range v { + keys = append(keys, int(k)) + } + sort.Ints(keys) + + for _, k := range keys { + o.buf = append(o.buf, v[int32(k)].enc...) + } + return nil +} + +func size_map(p *Properties, base structPointer) int { + v := structPointer_ExtMap(base, p.field) + return extensionsMapSize(*v) +} + +func size_exts(p *Properties, base structPointer) int { + v := structPointer_Extensions(base, p.field) + return extensionsSize(v) +} + +// Encode a map field. +func (o *Buffer) enc_new_map(p *Properties, base structPointer) error { + var state errorState // XXX: or do we need to plumb this through? + + /* + A map defined as + map map_field = N; + is encoded in the same way as + message MapFieldEntry { + key_type key = 1; + value_type value = 2; + } + repeated MapFieldEntry map_field = N; + */ + + v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V + if v.Len() == 0 { + return nil + } + + keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype) + + enc := func() error { + if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil { + return err + } + if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil && err != ErrNil { + return err + } + return nil + } + + // Don't sort map keys. It is not required by the spec, and C++ doesn't do it. + for _, key := range v.MapKeys() { + val := v.MapIndex(key) + + keycopy.Set(key) + valcopy.Set(val) + + o.buf = append(o.buf, p.tagcode...) + if err := o.enc_len_thing(enc, &state); err != nil { + return err + } + } + return nil +} + +func size_new_map(p *Properties, base structPointer) int { + v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V + + keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype) + + n := 0 + for _, key := range v.MapKeys() { + val := v.MapIndex(key) + keycopy.Set(key) + valcopy.Set(val) + + // Tag codes for key and val are the responsibility of the sub-sizer. + keysize := p.mkeyprop.size(p.mkeyprop, keybase) + valsize := p.mvalprop.size(p.mvalprop, valbase) + entry := keysize + valsize + // Add on tag code and length of map entry itself. + n += len(p.tagcode) + sizeVarint(uint64(entry)) + entry + } + return n +} + +// mapEncodeScratch returns a new reflect.Value matching the map's value type, +// and a structPointer suitable for passing to an encoder or sizer. +func mapEncodeScratch(mapType reflect.Type) (keycopy, valcopy reflect.Value, keybase, valbase structPointer) { + // Prepare addressable doubly-indirect placeholders for the key and value types. + // This is needed because the element-type encoders expect **T, but the map iteration produces T. + + keycopy = reflect.New(mapType.Key()).Elem() // addressable K + keyptr := reflect.New(reflect.PtrTo(keycopy.Type())).Elem() // addressable *K + keyptr.Set(keycopy.Addr()) // + keybase = toStructPointer(keyptr.Addr()) // **K + + // Value types are more varied and require special handling. + switch mapType.Elem().Kind() { + case reflect.Slice: + // []byte + var dummy []byte + valcopy = reflect.ValueOf(&dummy).Elem() // addressable []byte + valbase = toStructPointer(valcopy.Addr()) + case reflect.Ptr: + // message; the generated field type is map[K]*Msg (so V is *Msg), + // so we only need one level of indirection. + valcopy = reflect.New(mapType.Elem()).Elem() // addressable V + valbase = toStructPointer(valcopy.Addr()) + default: + // everything else + valcopy = reflect.New(mapType.Elem()).Elem() // addressable V + valptr := reflect.New(reflect.PtrTo(valcopy.Type())).Elem() // addressable *V + valptr.Set(valcopy.Addr()) // + valbase = toStructPointer(valptr.Addr()) // **V + } + return +} + +// Encode a struct. +func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error { + var state errorState + // Encode fields in tag order so that decoders may use optimizations + // that depend on the ordering. + // https://developers.google.com/protocol-buffers/docs/encoding#order + for _, i := range prop.order { + p := prop.Prop[i] + if p.enc != nil { + err := p.enc(o, p, base) + if err != nil { + if err == ErrNil { + if p.Required && state.err == nil { + state.err = &RequiredNotSetError{p.Name} + } + } else if err == errRepeatedHasNil { + // Give more context to nil values in repeated fields. + return errors.New("repeated field " + p.OrigName + " has nil element") + } else if !state.shouldContinue(err, p) { + return err + } + } + if len(o.buf) > maxMarshalSize { + return ErrTooLarge + } + } + } + + // Do oneof fields. + if prop.oneofMarshaler != nil { + m := structPointer_Interface(base, prop.stype).(Message) + if err := prop.oneofMarshaler(m, o); err == ErrNil { + return errOneofHasNil + } else if err != nil { + return err + } + } + + // Add unrecognized fields at the end. + if prop.unrecField.IsValid() { + v := *structPointer_Bytes(base, prop.unrecField) + if len(o.buf)+len(v) > maxMarshalSize { + return ErrTooLarge + } + if len(v) > 0 { + o.buf = append(o.buf, v...) + } + } + + return state.err +} + +func size_struct(prop *StructProperties, base structPointer) (n int) { + for _, i := range prop.order { + p := prop.Prop[i] + if p.size != nil { + n += p.size(p, base) + } + } + + // Add unrecognized fields at the end. + if prop.unrecField.IsValid() { + v := *structPointer_Bytes(base, prop.unrecField) + n += len(v) + } + + // Factor in any oneof fields. + if prop.oneofSizer != nil { + m := structPointer_Interface(base, prop.stype).(Message) + n += prop.oneofSizer(m) + } + + return +} + +var zeroes [20]byte // longer than any conceivable sizeVarint + +// Encode a struct, preceded by its encoded length (as a varint). +func (o *Buffer) enc_len_struct(prop *StructProperties, base structPointer, state *errorState) error { + return o.enc_len_thing(func() error { return o.enc_struct(prop, base) }, state) +} + +// Encode something, preceded by its encoded length (as a varint). +func (o *Buffer) enc_len_thing(enc func() error, state *errorState) error { + iLen := len(o.buf) + o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length + iMsg := len(o.buf) + err := enc() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + lMsg := len(o.buf) - iMsg + lLen := sizeVarint(uint64(lMsg)) + switch x := lLen - (iMsg - iLen); { + case x > 0: // actual length is x bytes larger than the space we reserved + // Move msg x bytes right. + o.buf = append(o.buf, zeroes[:x]...) + copy(o.buf[iMsg+x:], o.buf[iMsg:iMsg+lMsg]) + case x < 0: // actual length is x bytes smaller than the space we reserved + // Move msg x bytes left. + copy(o.buf[iMsg+x:], o.buf[iMsg:iMsg+lMsg]) + o.buf = o.buf[:len(o.buf)+x] // x is negative + } + // Encode the length in the reserved space. + o.buf = o.buf[:iLen] + o.EncodeVarint(uint64(lMsg)) + o.buf = o.buf[:len(o.buf)+lMsg] + return state.err +} + +// errorState maintains the first error that occurs and updates that error +// with additional context. +type errorState struct { + err error +} + +// shouldContinue reports whether encoding should continue upon encountering the +// given error. If the error is RequiredNotSetError, shouldContinue returns true +// and, if this is the first appearance of that error, remembers it for future +// reporting. +// +// If prop is not nil, it may update any error with additional context about the +// field with the error. +func (s *errorState) shouldContinue(err error, prop *Properties) bool { + // Ignore unset required fields. + reqNotSet, ok := err.(*RequiredNotSetError) + if !ok { + return false + } + if s.err == nil { + if prop != nil { + err = &RequiredNotSetError{prop.Name + "." + reqNotSet.field} + } + s.err = err + } + return true +} diff --git a/vendor/github.com/golang/protobuf/proto/equal.go b/vendor/github.com/golang/protobuf/proto/equal.go new file mode 100644 index 0000000..2ed1cf5 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/equal.go @@ -0,0 +1,300 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Protocol buffer comparison. + +package proto + +import ( + "bytes" + "log" + "reflect" + "strings" +) + +/* +Equal returns true iff protocol buffers a and b are equal. +The arguments must both be pointers to protocol buffer structs. + +Equality is defined in this way: + - Two messages are equal iff they are the same type, + corresponding fields are equal, unknown field sets + are equal, and extensions sets are equal. + - Two set scalar fields are equal iff their values are equal. + If the fields are of a floating-point type, remember that + NaN != x for all x, including NaN. If the message is defined + in a proto3 .proto file, fields are not "set"; specifically, + zero length proto3 "bytes" fields are equal (nil == {}). + - Two repeated fields are equal iff their lengths are the same, + and their corresponding elements are equal. Note a "bytes" field, + although represented by []byte, is not a repeated field and the + rule for the scalar fields described above applies. + - Two unset fields are equal. + - Two unknown field sets are equal if their current + encoded state is equal. + - Two extension sets are equal iff they have corresponding + elements that are pairwise equal. + - Two map fields are equal iff their lengths are the same, + and they contain the same set of elements. Zero-length map + fields are equal. + - Every other combination of things are not equal. + +The return value is undefined if a and b are not protocol buffers. +*/ +func Equal(a, b Message) bool { + if a == nil || b == nil { + return a == b + } + v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b) + if v1.Type() != v2.Type() { + return false + } + if v1.Kind() == reflect.Ptr { + if v1.IsNil() { + return v2.IsNil() + } + if v2.IsNil() { + return false + } + v1, v2 = v1.Elem(), v2.Elem() + } + if v1.Kind() != reflect.Struct { + return false + } + return equalStruct(v1, v2) +} + +// v1 and v2 are known to have the same type. +func equalStruct(v1, v2 reflect.Value) bool { + sprop := GetProperties(v1.Type()) + for i := 0; i < v1.NumField(); i++ { + f := v1.Type().Field(i) + if strings.HasPrefix(f.Name, "XXX_") { + continue + } + f1, f2 := v1.Field(i), v2.Field(i) + if f.Type.Kind() == reflect.Ptr { + if n1, n2 := f1.IsNil(), f2.IsNil(); n1 && n2 { + // both unset + continue + } else if n1 != n2 { + // set/unset mismatch + return false + } + b1, ok := f1.Interface().(raw) + if ok { + b2 := f2.Interface().(raw) + // RawMessage + if !bytes.Equal(b1.Bytes(), b2.Bytes()) { + return false + } + continue + } + f1, f2 = f1.Elem(), f2.Elem() + } + if !equalAny(f1, f2, sprop.Prop[i]) { + return false + } + } + + if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() { + em2 := v2.FieldByName("XXX_InternalExtensions") + if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) { + return false + } + } + + if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { + em2 := v2.FieldByName("XXX_extensions") + if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { + return false + } + } + + uf := v1.FieldByName("XXX_unrecognized") + if !uf.IsValid() { + return true + } + + u1 := uf.Bytes() + u2 := v2.FieldByName("XXX_unrecognized").Bytes() + if !bytes.Equal(u1, u2) { + return false + } + + return true +} + +// v1 and v2 are known to have the same type. +// prop may be nil. +func equalAny(v1, v2 reflect.Value, prop *Properties) bool { + if v1.Type() == protoMessageType { + m1, _ := v1.Interface().(Message) + m2, _ := v2.Interface().(Message) + return Equal(m1, m2) + } + switch v1.Kind() { + case reflect.Bool: + return v1.Bool() == v2.Bool() + case reflect.Float32, reflect.Float64: + return v1.Float() == v2.Float() + case reflect.Int32, reflect.Int64: + return v1.Int() == v2.Int() + case reflect.Interface: + // Probably a oneof field; compare the inner values. + n1, n2 := v1.IsNil(), v2.IsNil() + if n1 || n2 { + return n1 == n2 + } + e1, e2 := v1.Elem(), v2.Elem() + if e1.Type() != e2.Type() { + return false + } + return equalAny(e1, e2, nil) + case reflect.Map: + if v1.Len() != v2.Len() { + return false + } + for _, key := range v1.MapKeys() { + val2 := v2.MapIndex(key) + if !val2.IsValid() { + // This key was not found in the second map. + return false + } + if !equalAny(v1.MapIndex(key), val2, nil) { + return false + } + } + return true + case reflect.Ptr: + // Maps may have nil values in them, so check for nil. + if v1.IsNil() && v2.IsNil() { + return true + } + if v1.IsNil() != v2.IsNil() { + return false + } + return equalAny(v1.Elem(), v2.Elem(), prop) + case reflect.Slice: + if v1.Type().Elem().Kind() == reflect.Uint8 { + // short circuit: []byte + + // Edge case: if this is in a proto3 message, a zero length + // bytes field is considered the zero value. + if prop != nil && prop.proto3 && v1.Len() == 0 && v2.Len() == 0 { + return true + } + if v1.IsNil() != v2.IsNil() { + return false + } + return bytes.Equal(v1.Interface().([]byte), v2.Interface().([]byte)) + } + + if v1.Len() != v2.Len() { + return false + } + for i := 0; i < v1.Len(); i++ { + if !equalAny(v1.Index(i), v2.Index(i), prop) { + return false + } + } + return true + case reflect.String: + return v1.Interface().(string) == v2.Interface().(string) + case reflect.Struct: + return equalStruct(v1, v2) + case reflect.Uint32, reflect.Uint64: + return v1.Uint() == v2.Uint() + } + + // unknown type, so not a protocol buffer + log.Printf("proto: don't know how to compare %v", v1) + return false +} + +// base is the struct type that the extensions are based on. +// x1 and x2 are InternalExtensions. +func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool { + em1, _ := x1.extensionsRead() + em2, _ := x2.extensionsRead() + return equalExtMap(base, em1, em2) +} + +func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool { + if len(em1) != len(em2) { + return false + } + + for extNum, e1 := range em1 { + e2, ok := em2[extNum] + if !ok { + return false + } + + m1, m2 := e1.value, e2.value + + if m1 != nil && m2 != nil { + // Both are unencoded. + if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) { + return false + } + continue + } + + // At least one is encoded. To do a semantically correct comparison + // we need to unmarshal them first. + var desc *ExtensionDesc + if m := extensionMaps[base]; m != nil { + desc = m[extNum] + } + if desc == nil { + log.Printf("proto: don't know how to compare extension %d of %v", extNum, base) + continue + } + var err error + if m1 == nil { + m1, err = decodeExtension(e1.enc, desc) + } + if m2 == nil && err == nil { + m2, err = decodeExtension(e2.enc, desc) + } + if err != nil { + // The encoded form is invalid. + log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err) + return false + } + if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) { + return false + } + } + + return true +} diff --git a/vendor/github.com/golang/protobuf/proto/extensions.go b/vendor/github.com/golang/protobuf/proto/extensions.go new file mode 100644 index 0000000..eaad218 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/extensions.go @@ -0,0 +1,587 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Types and routines for supporting protocol buffer extensions. + */ + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "sync" +) + +// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. +var ErrMissingExtension = errors.New("proto: missing extension") + +// ExtensionRange represents a range of message extensions for a protocol buffer. +// Used in code generated by the protocol compiler. +type ExtensionRange struct { + Start, End int32 // both inclusive +} + +// extendableProto is an interface implemented by any protocol buffer generated by the current +// proto compiler that may be extended. +type extendableProto interface { + Message + ExtensionRangeArray() []ExtensionRange + extensionsWrite() map[int32]Extension + extensionsRead() (map[int32]Extension, sync.Locker) +} + +// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous +// version of the proto compiler that may be extended. +type extendableProtoV1 interface { + Message + ExtensionRangeArray() []ExtensionRange + ExtensionMap() map[int32]Extension +} + +// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. +type extensionAdapter struct { + extendableProtoV1 +} + +func (e extensionAdapter) extensionsWrite() map[int32]Extension { + return e.ExtensionMap() +} + +func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { + return e.ExtensionMap(), notLocker{} +} + +// notLocker is a sync.Locker whose Lock and Unlock methods are nops. +type notLocker struct{} + +func (n notLocker) Lock() {} +func (n notLocker) Unlock() {} + +// extendable returns the extendableProto interface for the given generated proto message. +// If the proto message has the old extension format, it returns a wrapper that implements +// the extendableProto interface. +func extendable(p interface{}) (extendableProto, bool) { + if ep, ok := p.(extendableProto); ok { + return ep, ok + } + if ep, ok := p.(extendableProtoV1); ok { + return extensionAdapter{ep}, ok + } + return nil, false +} + +// XXX_InternalExtensions is an internal representation of proto extensions. +// +// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, +// thus gaining the unexported 'extensions' method, which can be called only from the proto package. +// +// The methods of XXX_InternalExtensions are not concurrency safe in general, +// but calls to logically read-only methods such as has and get may be executed concurrently. +type XXX_InternalExtensions struct { + // The struct must be indirect so that if a user inadvertently copies a + // generated message and its embedded XXX_InternalExtensions, they + // avoid the mayhem of a copied mutex. + // + // The mutex serializes all logically read-only operations to p.extensionMap. + // It is up to the client to ensure that write operations to p.extensionMap are + // mutually exclusive with other accesses. + p *struct { + mu sync.Mutex + extensionMap map[int32]Extension + } +} + +// extensionsWrite returns the extension map, creating it on first use. +func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { + if e.p == nil { + e.p = new(struct { + mu sync.Mutex + extensionMap map[int32]Extension + }) + e.p.extensionMap = make(map[int32]Extension) + } + return e.p.extensionMap +} + +// extensionsRead returns the extensions map for read-only use. It may be nil. +// The caller must hold the returned mutex's lock when accessing Elements within the map. +func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { + if e.p == nil { + return nil, nil + } + return e.p.extensionMap, &e.p.mu +} + +var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() +var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem() + +// ExtensionDesc represents an extension specification. +// Used in generated code from the protocol compiler. +type ExtensionDesc struct { + ExtendedType Message // nil pointer to the type that is being extended + ExtensionType interface{} // nil pointer to the extension type + Field int32 // field number + Name string // fully-qualified name of extension, for text formatting + Tag string // protobuf tag style + Filename string // name of the file in which the extension is defined +} + +func (ed *ExtensionDesc) repeated() bool { + t := reflect.TypeOf(ed.ExtensionType) + return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 +} + +// Extension represents an extension in a message. +type Extension struct { + // When an extension is stored in a message using SetExtension + // only desc and value are set. When the message is marshaled + // enc will be set to the encoded form of the message. + // + // When a message is unmarshaled and contains extensions, each + // extension will have only enc set. When such an extension is + // accessed using GetExtension (or GetExtensions) desc and value + // will be set. + desc *ExtensionDesc + value interface{} + enc []byte +} + +// SetRawExtension is for testing only. +func SetRawExtension(base Message, id int32, b []byte) { + epb, ok := extendable(base) + if !ok { + return + } + extmap := epb.extensionsWrite() + extmap[id] = Extension{enc: b} +} + +// isExtensionField returns true iff the given field number is in an extension range. +func isExtensionField(pb extendableProto, field int32) bool { + for _, er := range pb.ExtensionRangeArray() { + if er.Start <= field && field <= er.End { + return true + } + } + return false +} + +// checkExtensionTypes checks that the given extension is valid for pb. +func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { + var pbi interface{} = pb + // Check the extended type. + if ea, ok := pbi.(extensionAdapter); ok { + pbi = ea.extendableProtoV1 + } + if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { + return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) + } + // Check the range. + if !isExtensionField(pb, extension.Field) { + return errors.New("proto: bad extension number; not in declared ranges") + } + return nil +} + +// extPropKey is sufficient to uniquely identify an extension. +type extPropKey struct { + base reflect.Type + field int32 +} + +var extProp = struct { + sync.RWMutex + m map[extPropKey]*Properties +}{ + m: make(map[extPropKey]*Properties), +} + +func extensionProperties(ed *ExtensionDesc) *Properties { + key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} + + extProp.RLock() + if prop, ok := extProp.m[key]; ok { + extProp.RUnlock() + return prop + } + extProp.RUnlock() + + extProp.Lock() + defer extProp.Unlock() + // Check again. + if prop, ok := extProp.m[key]; ok { + return prop + } + + prop := new(Properties) + prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) + extProp.m[key] = prop + return prop +} + +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensions(e *XXX_InternalExtensions) error { + m, mu := e.extensionsRead() + if m == nil { + return nil // fast path + } + mu.Lock() + defer mu.Unlock() + return encodeExtensionsMap(m) +} + +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensionsMap(m map[int32]Extension) error { + for k, e := range m { + if e.value == nil || e.desc == nil { + // Extension is only in its encoded form. + continue + } + + // We don't skip extensions that have an encoded form set, + // because the extension value may have been mutated after + // the last time this function was called. + + et := reflect.TypeOf(e.desc.ExtensionType) + props := extensionProperties(e.desc) + + p := NewBuffer(nil) + // If e.value has type T, the encoder expects a *struct{ X T }. + // Pass a *T with a zero field and hope it all works out. + x := reflect.New(et) + x.Elem().Set(reflect.ValueOf(e.value)) + if err := props.enc(p, props, toStructPointer(x)); err != nil { + return err + } + e.enc = p.buf + m[k] = e + } + return nil +} + +func extensionsSize(e *XXX_InternalExtensions) (n int) { + m, mu := e.extensionsRead() + if m == nil { + return 0 + } + mu.Lock() + defer mu.Unlock() + return extensionsMapSize(m) +} + +func extensionsMapSize(m map[int32]Extension) (n int) { + for _, e := range m { + if e.value == nil || e.desc == nil { + // Extension is only in its encoded form. + n += len(e.enc) + continue + } + + // We don't skip extensions that have an encoded form set, + // because the extension value may have been mutated after + // the last time this function was called. + + et := reflect.TypeOf(e.desc.ExtensionType) + props := extensionProperties(e.desc) + + // If e.value has type T, the encoder expects a *struct{ X T }. + // Pass a *T with a zero field and hope it all works out. + x := reflect.New(et) + x.Elem().Set(reflect.ValueOf(e.value)) + n += props.size(props, toStructPointer(x)) + } + return +} + +// HasExtension returns whether the given extension is present in pb. +func HasExtension(pb Message, extension *ExtensionDesc) bool { + // TODO: Check types, field numbers, etc.? + epb, ok := extendable(pb) + if !ok { + return false + } + extmap, mu := epb.extensionsRead() + if extmap == nil { + return false + } + mu.Lock() + _, ok = extmap[extension.Field] + mu.Unlock() + return ok +} + +// ClearExtension removes the given extension from pb. +func ClearExtension(pb Message, extension *ExtensionDesc) { + epb, ok := extendable(pb) + if !ok { + return + } + // TODO: Check types, field numbers, etc.? + extmap := epb.extensionsWrite() + delete(extmap, extension.Field) +} + +// GetExtension parses and returns the given extension of pb. +// If the extension is not present and has no default value it returns ErrMissingExtension. +func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { + epb, ok := extendable(pb) + if !ok { + return nil, errors.New("proto: not an extendable proto") + } + + if err := checkExtensionTypes(epb, extension); err != nil { + return nil, err + } + + emap, mu := epb.extensionsRead() + if emap == nil { + return defaultExtensionValue(extension) + } + mu.Lock() + defer mu.Unlock() + e, ok := emap[extension.Field] + if !ok { + // defaultExtensionValue returns the default value or + // ErrMissingExtension if there is no default. + return defaultExtensionValue(extension) + } + + if e.value != nil { + // Already decoded. Check the descriptor, though. + if e.desc != extension { + // This shouldn't happen. If it does, it means that + // GetExtension was called twice with two different + // descriptors with the same field number. + return nil, errors.New("proto: descriptor conflict") + } + return e.value, nil + } + + v, err := decodeExtension(e.enc, extension) + if err != nil { + return nil, err + } + + // Remember the decoded version and drop the encoded version. + // That way it is safe to mutate what we return. + e.value = v + e.desc = extension + e.enc = nil + emap[extension.Field] = e + return e.value, nil +} + +// defaultExtensionValue returns the default value for extension. +// If no default for an extension is defined ErrMissingExtension is returned. +func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { + t := reflect.TypeOf(extension.ExtensionType) + props := extensionProperties(extension) + + sf, _, err := fieldDefault(t, props) + if err != nil { + return nil, err + } + + if sf == nil || sf.value == nil { + // There is no default value. + return nil, ErrMissingExtension + } + + if t.Kind() != reflect.Ptr { + // We do not need to return a Ptr, we can directly return sf.value. + return sf.value, nil + } + + // We need to return an interface{} that is a pointer to sf.value. + value := reflect.New(t).Elem() + value.Set(reflect.New(value.Type().Elem())) + if sf.kind == reflect.Int32 { + // We may have an int32 or an enum, but the underlying data is int32. + // Since we can't set an int32 into a non int32 reflect.value directly + // set it as a int32. + value.Elem().SetInt(int64(sf.value.(int32))) + } else { + value.Elem().Set(reflect.ValueOf(sf.value)) + } + return value.Interface(), nil +} + +// decodeExtension decodes an extension encoded in b. +func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { + o := NewBuffer(b) + + t := reflect.TypeOf(extension.ExtensionType) + + props := extensionProperties(extension) + + // t is a pointer to a struct, pointer to basic type or a slice. + // Allocate a "field" to store the pointer/slice itself; the + // pointer/slice will be stored here. We pass + // the address of this field to props.dec. + // This passes a zero field and a *t and lets props.dec + // interpret it as a *struct{ x t }. + value := reflect.New(t).Elem() + + for { + // Discard wire type and field number varint. It isn't needed. + if _, err := o.DecodeVarint(); err != nil { + return nil, err + } + + if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil { + return nil, err + } + + if o.index >= len(o.buf) { + break + } + } + return value.Interface(), nil +} + +// GetExtensions returns a slice of the extensions present in pb that are also listed in es. +// The returned slice has the same length as es; missing extensions will appear as nil elements. +func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { + epb, ok := extendable(pb) + if !ok { + return nil, errors.New("proto: not an extendable proto") + } + extensions = make([]interface{}, len(es)) + for i, e := range es { + extensions[i], err = GetExtension(epb, e) + if err == ErrMissingExtension { + err = nil + } + if err != nil { + return + } + } + return +} + +// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. +// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing +// just the Field field, which defines the extension's field number. +func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { + epb, ok := extendable(pb) + if !ok { + return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb) + } + registeredExtensions := RegisteredExtensions(pb) + + emap, mu := epb.extensionsRead() + if emap == nil { + return nil, nil + } + mu.Lock() + defer mu.Unlock() + extensions := make([]*ExtensionDesc, 0, len(emap)) + for extid, e := range emap { + desc := e.desc + if desc == nil { + desc = registeredExtensions[extid] + if desc == nil { + desc = &ExtensionDesc{Field: extid} + } + } + + extensions = append(extensions, desc) + } + return extensions, nil +} + +// SetExtension sets the specified extension of pb to the specified value. +func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { + epb, ok := extendable(pb) + if !ok { + return errors.New("proto: not an extendable proto") + } + if err := checkExtensionTypes(epb, extension); err != nil { + return err + } + typ := reflect.TypeOf(extension.ExtensionType) + if typ != reflect.TypeOf(value) { + return errors.New("proto: bad extension value type") + } + // nil extension values need to be caught early, because the + // encoder can't distinguish an ErrNil due to a nil extension + // from an ErrNil due to a missing field. Extensions are + // always optional, so the encoder would just swallow the error + // and drop all the extensions from the encoded message. + if reflect.ValueOf(value).IsNil() { + return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) + } + + extmap := epb.extensionsWrite() + extmap[extension.Field] = Extension{desc: extension, value: value} + return nil +} + +// ClearAllExtensions clears all extensions from pb. +func ClearAllExtensions(pb Message) { + epb, ok := extendable(pb) + if !ok { + return + } + m := epb.extensionsWrite() + for k := range m { + delete(m, k) + } +} + +// A global registry of extensions. +// The generated code will register the generated descriptors by calling RegisterExtension. + +var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) + +// RegisterExtension is called from the generated code. +func RegisterExtension(desc *ExtensionDesc) { + st := reflect.TypeOf(desc.ExtendedType).Elem() + m := extensionMaps[st] + if m == nil { + m = make(map[int32]*ExtensionDesc) + extensionMaps[st] = m + } + if _, ok := m[desc.Field]; ok { + panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) + } + m[desc.Field] = desc +} + +// RegisteredExtensions returns a map of the registered extensions of a +// protocol buffer struct, indexed by the extension number. +// The argument pb should be a nil pointer to the struct type. +func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { + return extensionMaps[reflect.TypeOf(pb).Elem()] +} diff --git a/vendor/github.com/golang/protobuf/proto/lib.go b/vendor/github.com/golang/protobuf/proto/lib.go new file mode 100644 index 0000000..ac4ddbc --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/lib.go @@ -0,0 +1,898 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +Package proto converts data structures to and from the wire format of +protocol buffers. It works in concert with the Go source code generated +for .proto files by the protocol compiler. + +A summary of the properties of the protocol buffer interface +for a protocol buffer variable v: + + - Names are turned from camel_case to CamelCase for export. + - There are no methods on v to set fields; just treat + them as structure fields. + - There are getters that return a field's value if set, + and return the field's default value if unset. + The getters work even if the receiver is a nil message. + - The zero value for a struct is its correct initialization state. + All desired fields must be set before marshaling. + - A Reset() method will restore a protobuf struct to its zero state. + - Non-repeated fields are pointers to the values; nil means unset. + That is, optional or required field int32 f becomes F *int32. + - Repeated fields are slices. + - Helper functions are available to aid the setting of fields. + msg.Foo = proto.String("hello") // set field + - Constants are defined to hold the default values of all fields that + have them. They have the form Default_StructName_FieldName. + Because the getter methods handle defaulted values, + direct use of these constants should be rare. + - Enums are given type names and maps from names to values. + Enum values are prefixed by the enclosing message's name, or by the + enum's type name if it is a top-level enum. Enum types have a String + method, and a Enum method to assist in message construction. + - Nested messages, groups and enums have type names prefixed with the name of + the surrounding message type. + - Extensions are given descriptor names that start with E_, + followed by an underscore-delimited list of the nested messages + that contain it (if any) followed by the CamelCased name of the + extension field itself. HasExtension, ClearExtension, GetExtension + and SetExtension are functions for manipulating extensions. + - Oneof field sets are given a single field in their message, + with distinguished wrapper types for each possible field value. + - Marshal and Unmarshal are functions to encode and decode the wire format. + +When the .proto file specifies `syntax="proto3"`, there are some differences: + + - Non-repeated fields of non-message type are values instead of pointers. + - Getters are only generated for message and oneof fields. + - Enum types do not get an Enum method. + +The simplest way to describe this is to see an example. +Given file test.proto, containing + + package example; + + enum FOO { X = 17; } + + message Test { + required string label = 1; + optional int32 type = 2 [default=77]; + repeated int64 reps = 3; + optional group OptionalGroup = 4 { + required string RequiredField = 5; + } + oneof union { + int32 number = 6; + string name = 7; + } + } + +The resulting file, test.pb.go, is: + + package example + + import proto "github.com/golang/protobuf/proto" + import math "math" + + type FOO int32 + const ( + FOO_X FOO = 17 + ) + var FOO_name = map[int32]string{ + 17: "X", + } + var FOO_value = map[string]int32{ + "X": 17, + } + + func (x FOO) Enum() *FOO { + p := new(FOO) + *p = x + return p + } + func (x FOO) String() string { + return proto.EnumName(FOO_name, int32(x)) + } + func (x *FOO) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FOO_value, data) + if err != nil { + return err + } + *x = FOO(value) + return nil + } + + type Test struct { + Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` + Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` + Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` + Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` + // Types that are valid to be assigned to Union: + // *Test_Number + // *Test_Name + Union isTest_Union `protobuf_oneof:"union"` + XXX_unrecognized []byte `json:"-"` + } + func (m *Test) Reset() { *m = Test{} } + func (m *Test) String() string { return proto.CompactTextString(m) } + func (*Test) ProtoMessage() {} + + type isTest_Union interface { + isTest_Union() + } + + type Test_Number struct { + Number int32 `protobuf:"varint,6,opt,name=number"` + } + type Test_Name struct { + Name string `protobuf:"bytes,7,opt,name=name"` + } + + func (*Test_Number) isTest_Union() {} + func (*Test_Name) isTest_Union() {} + + func (m *Test) GetUnion() isTest_Union { + if m != nil { + return m.Union + } + return nil + } + const Default_Test_Type int32 = 77 + + func (m *Test) GetLabel() string { + if m != nil && m.Label != nil { + return *m.Label + } + return "" + } + + func (m *Test) GetType() int32 { + if m != nil && m.Type != nil { + return *m.Type + } + return Default_Test_Type + } + + func (m *Test) GetOptionalgroup() *Test_OptionalGroup { + if m != nil { + return m.Optionalgroup + } + return nil + } + + type Test_OptionalGroup struct { + RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` + } + func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} } + func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) } + + func (m *Test_OptionalGroup) GetRequiredField() string { + if m != nil && m.RequiredField != nil { + return *m.RequiredField + } + return "" + } + + func (m *Test) GetNumber() int32 { + if x, ok := m.GetUnion().(*Test_Number); ok { + return x.Number + } + return 0 + } + + func (m *Test) GetName() string { + if x, ok := m.GetUnion().(*Test_Name); ok { + return x.Name + } + return "" + } + + func init() { + proto.RegisterEnum("example.FOO", FOO_name, FOO_value) + } + +To create and play with a Test object: + + package main + + import ( + "log" + + "github.com/golang/protobuf/proto" + pb "./example.pb" + ) + + func main() { + test := &pb.Test{ + Label: proto.String("hello"), + Type: proto.Int32(17), + Reps: []int64{1, 2, 3}, + Optionalgroup: &pb.Test_OptionalGroup{ + RequiredField: proto.String("good bye"), + }, + Union: &pb.Test_Name{"fred"}, + } + data, err := proto.Marshal(test) + if err != nil { + log.Fatal("marshaling error: ", err) + } + newTest := &pb.Test{} + err = proto.Unmarshal(data, newTest) + if err != nil { + log.Fatal("unmarshaling error: ", err) + } + // Now test and newTest contain the same data. + if test.GetLabel() != newTest.GetLabel() { + log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel()) + } + // Use a type switch to determine which oneof was set. + switch u := test.Union.(type) { + case *pb.Test_Number: // u.Number contains the number. + case *pb.Test_Name: // u.Name contains the string. + } + // etc. + } +*/ +package proto + +import ( + "encoding/json" + "fmt" + "log" + "reflect" + "sort" + "strconv" + "sync" +) + +// Message is implemented by generated protocol buffer messages. +type Message interface { + Reset() + String() string + ProtoMessage() +} + +// Stats records allocation details about the protocol buffer encoders +// and decoders. Useful for tuning the library itself. +type Stats struct { + Emalloc uint64 // mallocs in encode + Dmalloc uint64 // mallocs in decode + Encode uint64 // number of encodes + Decode uint64 // number of decodes + Chit uint64 // number of cache hits + Cmiss uint64 // number of cache misses + Size uint64 // number of sizes +} + +// Set to true to enable stats collection. +const collectStats = false + +var stats Stats + +// GetStats returns a copy of the global Stats structure. +func GetStats() Stats { return stats } + +// A Buffer is a buffer manager for marshaling and unmarshaling +// protocol buffers. It may be reused between invocations to +// reduce memory usage. It is not necessary to use a Buffer; +// the global functions Marshal and Unmarshal create a +// temporary Buffer and are fine for most applications. +type Buffer struct { + buf []byte // encode/decode byte stream + index int // read point + + // pools of basic types to amortize allocation. + bools []bool + uint32s []uint32 + uint64s []uint64 + + // extra pools, only used with pointer_reflect.go + int32s []int32 + int64s []int64 + float32s []float32 + float64s []float64 +} + +// NewBuffer allocates a new Buffer and initializes its internal data to +// the contents of the argument slice. +func NewBuffer(e []byte) *Buffer { + return &Buffer{buf: e} +} + +// Reset resets the Buffer, ready for marshaling a new protocol buffer. +func (p *Buffer) Reset() { + p.buf = p.buf[0:0] // for reading/writing + p.index = 0 // for reading +} + +// SetBuf replaces the internal buffer with the slice, +// ready for unmarshaling the contents of the slice. +func (p *Buffer) SetBuf(s []byte) { + p.buf = s + p.index = 0 +} + +// Bytes returns the contents of the Buffer. +func (p *Buffer) Bytes() []byte { return p.buf } + +/* + * Helper routines for simplifying the creation of optional fields of basic type. + */ + +// Bool is a helper routine that allocates a new bool value +// to store v and returns a pointer to it. +func Bool(v bool) *bool { + return &v +} + +// Int32 is a helper routine that allocates a new int32 value +// to store v and returns a pointer to it. +func Int32(v int32) *int32 { + return &v +} + +// Int is a helper routine that allocates a new int32 value +// to store v and returns a pointer to it, but unlike Int32 +// its argument value is an int. +func Int(v int) *int32 { + p := new(int32) + *p = int32(v) + return p +} + +// Int64 is a helper routine that allocates a new int64 value +// to store v and returns a pointer to it. +func Int64(v int64) *int64 { + return &v +} + +// Float32 is a helper routine that allocates a new float32 value +// to store v and returns a pointer to it. +func Float32(v float32) *float32 { + return &v +} + +// Float64 is a helper routine that allocates a new float64 value +// to store v and returns a pointer to it. +func Float64(v float64) *float64 { + return &v +} + +// Uint32 is a helper routine that allocates a new uint32 value +// to store v and returns a pointer to it. +func Uint32(v uint32) *uint32 { + return &v +} + +// Uint64 is a helper routine that allocates a new uint64 value +// to store v and returns a pointer to it. +func Uint64(v uint64) *uint64 { + return &v +} + +// String is a helper routine that allocates a new string value +// to store v and returns a pointer to it. +func String(v string) *string { + return &v +} + +// EnumName is a helper function to simplify printing protocol buffer enums +// by name. Given an enum map and a value, it returns a useful string. +func EnumName(m map[int32]string, v int32) string { + s, ok := m[v] + if ok { + return s + } + return strconv.Itoa(int(v)) +} + +// UnmarshalJSONEnum is a helper function to simplify recovering enum int values +// from their JSON-encoded representation. Given a map from the enum's symbolic +// names to its int values, and a byte buffer containing the JSON-encoded +// value, it returns an int32 that can be cast to the enum type by the caller. +// +// The function can deal with both JSON representations, numeric and symbolic. +func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32, error) { + if data[0] == '"' { + // New style: enums are strings. + var repr string + if err := json.Unmarshal(data, &repr); err != nil { + return -1, err + } + val, ok := m[repr] + if !ok { + return 0, fmt.Errorf("unrecognized enum %s value %q", enumName, repr) + } + return val, nil + } + // Old style: enums are ints. + var val int32 + if err := json.Unmarshal(data, &val); err != nil { + return 0, fmt.Errorf("cannot unmarshal %#q into enum %s", data, enumName) + } + return val, nil +} + +// DebugPrint dumps the encoded data in b in a debugging format with a header +// including the string s. Used in testing but made available for general debugging. +func (p *Buffer) DebugPrint(s string, b []byte) { + var u uint64 + + obuf := p.buf + index := p.index + p.buf = b + p.index = 0 + depth := 0 + + fmt.Printf("\n--- %s ---\n", s) + +out: + for { + for i := 0; i < depth; i++ { + fmt.Print(" ") + } + + index := p.index + if index == len(p.buf) { + break + } + + op, err := p.DecodeVarint() + if err != nil { + fmt.Printf("%3d: fetching op err %v\n", index, err) + break out + } + tag := op >> 3 + wire := op & 7 + + switch wire { + default: + fmt.Printf("%3d: t=%3d unknown wire=%d\n", + index, tag, wire) + break out + + case WireBytes: + var r []byte + + r, err = p.DecodeRawBytes(false) + if err != nil { + break out + } + fmt.Printf("%3d: t=%3d bytes [%d]", index, tag, len(r)) + if len(r) <= 6 { + for i := 0; i < len(r); i++ { + fmt.Printf(" %.2x", r[i]) + } + } else { + for i := 0; i < 3; i++ { + fmt.Printf(" %.2x", r[i]) + } + fmt.Printf(" ..") + for i := len(r) - 3; i < len(r); i++ { + fmt.Printf(" %.2x", r[i]) + } + } + fmt.Printf("\n") + + case WireFixed32: + u, err = p.DecodeFixed32() + if err != nil { + fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u) + + case WireFixed64: + u, err = p.DecodeFixed64() + if err != nil { + fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u) + + case WireVarint: + u, err = p.DecodeVarint() + if err != nil { + fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u) + + case WireStartGroup: + fmt.Printf("%3d: t=%3d start\n", index, tag) + depth++ + + case WireEndGroup: + depth-- + fmt.Printf("%3d: t=%3d end\n", index, tag) + } + } + + if depth != 0 { + fmt.Printf("%3d: start-end not balanced %d\n", p.index, depth) + } + fmt.Printf("\n") + + p.buf = obuf + p.index = index +} + +// SetDefaults sets unset protocol buffer fields to their default values. +// It only modifies fields that are both unset and have defined defaults. +// It recursively sets default values in any non-nil sub-messages. +func SetDefaults(pb Message) { + setDefaults(reflect.ValueOf(pb), true, false) +} + +// v is a pointer to a struct. +func setDefaults(v reflect.Value, recur, zeros bool) { + v = v.Elem() + + defaultMu.RLock() + dm, ok := defaults[v.Type()] + defaultMu.RUnlock() + if !ok { + dm = buildDefaultMessage(v.Type()) + defaultMu.Lock() + defaults[v.Type()] = dm + defaultMu.Unlock() + } + + for _, sf := range dm.scalars { + f := v.Field(sf.index) + if !f.IsNil() { + // field already set + continue + } + dv := sf.value + if dv == nil && !zeros { + // no explicit default, and don't want to set zeros + continue + } + fptr := f.Addr().Interface() // **T + // TODO: Consider batching the allocations we do here. + switch sf.kind { + case reflect.Bool: + b := new(bool) + if dv != nil { + *b = dv.(bool) + } + *(fptr.(**bool)) = b + case reflect.Float32: + f := new(float32) + if dv != nil { + *f = dv.(float32) + } + *(fptr.(**float32)) = f + case reflect.Float64: + f := new(float64) + if dv != nil { + *f = dv.(float64) + } + *(fptr.(**float64)) = f + case reflect.Int32: + // might be an enum + if ft := f.Type(); ft != int32PtrType { + // enum + f.Set(reflect.New(ft.Elem())) + if dv != nil { + f.Elem().SetInt(int64(dv.(int32))) + } + } else { + // int32 field + i := new(int32) + if dv != nil { + *i = dv.(int32) + } + *(fptr.(**int32)) = i + } + case reflect.Int64: + i := new(int64) + if dv != nil { + *i = dv.(int64) + } + *(fptr.(**int64)) = i + case reflect.String: + s := new(string) + if dv != nil { + *s = dv.(string) + } + *(fptr.(**string)) = s + case reflect.Uint8: + // exceptional case: []byte + var b []byte + if dv != nil { + db := dv.([]byte) + b = make([]byte, len(db)) + copy(b, db) + } else { + b = []byte{} + } + *(fptr.(*[]byte)) = b + case reflect.Uint32: + u := new(uint32) + if dv != nil { + *u = dv.(uint32) + } + *(fptr.(**uint32)) = u + case reflect.Uint64: + u := new(uint64) + if dv != nil { + *u = dv.(uint64) + } + *(fptr.(**uint64)) = u + default: + log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind) + } + } + + for _, ni := range dm.nested { + f := v.Field(ni) + // f is *T or []*T or map[T]*T + switch f.Kind() { + case reflect.Ptr: + if f.IsNil() { + continue + } + setDefaults(f, recur, zeros) + + case reflect.Slice: + for i := 0; i < f.Len(); i++ { + e := f.Index(i) + if e.IsNil() { + continue + } + setDefaults(e, recur, zeros) + } + + case reflect.Map: + for _, k := range f.MapKeys() { + e := f.MapIndex(k) + if e.IsNil() { + continue + } + setDefaults(e, recur, zeros) + } + } + } +} + +var ( + // defaults maps a protocol buffer struct type to a slice of the fields, + // with its scalar fields set to their proto-declared non-zero default values. + defaultMu sync.RWMutex + defaults = make(map[reflect.Type]defaultMessage) + + int32PtrType = reflect.TypeOf((*int32)(nil)) +) + +// defaultMessage represents information about the default values of a message. +type defaultMessage struct { + scalars []scalarField + nested []int // struct field index of nested messages +} + +type scalarField struct { + index int // struct field index + kind reflect.Kind // element type (the T in *T or []T) + value interface{} // the proto-declared default value, or nil +} + +// t is a struct type. +func buildDefaultMessage(t reflect.Type) (dm defaultMessage) { + sprop := GetProperties(t) + for _, prop := range sprop.Prop { + fi, ok := sprop.decoderTags.get(prop.Tag) + if !ok { + // XXX_unrecognized + continue + } + ft := t.Field(fi).Type + + sf, nested, err := fieldDefault(ft, prop) + switch { + case err != nil: + log.Print(err) + case nested: + dm.nested = append(dm.nested, fi) + case sf != nil: + sf.index = fi + dm.scalars = append(dm.scalars, *sf) + } + } + + return dm +} + +// fieldDefault returns the scalarField for field type ft. +// sf will be nil if the field can not have a default. +// nestedMessage will be true if this is a nested message. +// Note that sf.index is not set on return. +func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMessage bool, err error) { + var canHaveDefault bool + switch ft.Kind() { + case reflect.Ptr: + if ft.Elem().Kind() == reflect.Struct { + nestedMessage = true + } else { + canHaveDefault = true // proto2 scalar field + } + + case reflect.Slice: + switch ft.Elem().Kind() { + case reflect.Ptr: + nestedMessage = true // repeated message + case reflect.Uint8: + canHaveDefault = true // bytes field + } + + case reflect.Map: + if ft.Elem().Kind() == reflect.Ptr { + nestedMessage = true // map with message values + } + } + + if !canHaveDefault { + if nestedMessage { + return nil, true, nil + } + return nil, false, nil + } + + // We now know that ft is a pointer or slice. + sf = &scalarField{kind: ft.Elem().Kind()} + + // scalar fields without defaults + if !prop.HasDefault { + return sf, false, nil + } + + // a scalar field: either *T or []byte + switch ft.Elem().Kind() { + case reflect.Bool: + x, err := strconv.ParseBool(prop.Default) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default bool %q: %v", prop.Default, err) + } + sf.value = x + case reflect.Float32: + x, err := strconv.ParseFloat(prop.Default, 32) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default float32 %q: %v", prop.Default, err) + } + sf.value = float32(x) + case reflect.Float64: + x, err := strconv.ParseFloat(prop.Default, 64) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default float64 %q: %v", prop.Default, err) + } + sf.value = x + case reflect.Int32: + x, err := strconv.ParseInt(prop.Default, 10, 32) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default int32 %q: %v", prop.Default, err) + } + sf.value = int32(x) + case reflect.Int64: + x, err := strconv.ParseInt(prop.Default, 10, 64) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default int64 %q: %v", prop.Default, err) + } + sf.value = x + case reflect.String: + sf.value = prop.Default + case reflect.Uint8: + // []byte (not *uint8) + sf.value = []byte(prop.Default) + case reflect.Uint32: + x, err := strconv.ParseUint(prop.Default, 10, 32) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default uint32 %q: %v", prop.Default, err) + } + sf.value = uint32(x) + case reflect.Uint64: + x, err := strconv.ParseUint(prop.Default, 10, 64) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default uint64 %q: %v", prop.Default, err) + } + sf.value = x + default: + return nil, false, fmt.Errorf("proto: unhandled def kind %v", ft.Elem().Kind()) + } + + return sf, false, nil +} + +// Map fields may have key types of non-float scalars, strings and enums. +// The easiest way to sort them in some deterministic order is to use fmt. +// If this turns out to be inefficient we can always consider other options, +// such as doing a Schwartzian transform. + +func mapKeys(vs []reflect.Value) sort.Interface { + s := mapKeySorter{ + vs: vs, + // default Less function: textual comparison + less: func(a, b reflect.Value) bool { + return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface()) + }, + } + + // Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps; + // numeric keys are sorted numerically. + if len(vs) == 0 { + return s + } + switch vs[0].Kind() { + case reflect.Int32, reflect.Int64: + s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() } + case reflect.Uint32, reflect.Uint64: + s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() } + } + + return s +} + +type mapKeySorter struct { + vs []reflect.Value + less func(a, b reflect.Value) bool +} + +func (s mapKeySorter) Len() int { return len(s.vs) } +func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] } +func (s mapKeySorter) Less(i, j int) bool { + return s.less(s.vs[i], s.vs[j]) +} + +// isProto3Zero reports whether v is a zero proto3 value. +func isProto3Zero(v reflect.Value) bool { + switch v.Kind() { + case reflect.Bool: + return !v.Bool() + case reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.String: + return v.String() == "" + } + return false +} + +// ProtoPackageIsVersion2 is referenced from generated protocol buffer files +// to assert that that code is compatible with this version of the proto package. +const ProtoPackageIsVersion2 = true + +// ProtoPackageIsVersion1 is referenced from generated protocol buffer files +// to assert that that code is compatible with this version of the proto package. +const ProtoPackageIsVersion1 = true diff --git a/vendor/github.com/golang/protobuf/proto/message_set.go b/vendor/github.com/golang/protobuf/proto/message_set.go new file mode 100644 index 0000000..fd982de --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/message_set.go @@ -0,0 +1,311 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Support for message sets. + */ + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "sort" +) + +// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID. +// A message type ID is required for storing a protocol buffer in a message set. +var errNoMessageTypeID = errors.New("proto does not have a message type ID") + +// The first two types (_MessageSet_Item and messageSet) +// model what the protocol compiler produces for the following protocol message: +// message MessageSet { +// repeated group Item = 1 { +// required int32 type_id = 2; +// required string message = 3; +// }; +// } +// That is the MessageSet wire format. We can't use a proto to generate these +// because that would introduce a circular dependency between it and this package. + +type _MessageSet_Item struct { + TypeId *int32 `protobuf:"varint,2,req,name=type_id"` + Message []byte `protobuf:"bytes,3,req,name=message"` +} + +type messageSet struct { + Item []*_MessageSet_Item `protobuf:"group,1,rep"` + XXX_unrecognized []byte + // TODO: caching? +} + +// Make sure messageSet is a Message. +var _ Message = (*messageSet)(nil) + +// messageTypeIder is an interface satisfied by a protocol buffer type +// that may be stored in a MessageSet. +type messageTypeIder interface { + MessageTypeId() int32 +} + +func (ms *messageSet) find(pb Message) *_MessageSet_Item { + mti, ok := pb.(messageTypeIder) + if !ok { + return nil + } + id := mti.MessageTypeId() + for _, item := range ms.Item { + if *item.TypeId == id { + return item + } + } + return nil +} + +func (ms *messageSet) Has(pb Message) bool { + if ms.find(pb) != nil { + return true + } + return false +} + +func (ms *messageSet) Unmarshal(pb Message) error { + if item := ms.find(pb); item != nil { + return Unmarshal(item.Message, pb) + } + if _, ok := pb.(messageTypeIder); !ok { + return errNoMessageTypeID + } + return nil // TODO: return error instead? +} + +func (ms *messageSet) Marshal(pb Message) error { + msg, err := Marshal(pb) + if err != nil { + return err + } + if item := ms.find(pb); item != nil { + // reuse existing item + item.Message = msg + return nil + } + + mti, ok := pb.(messageTypeIder) + if !ok { + return errNoMessageTypeID + } + + mtid := mti.MessageTypeId() + ms.Item = append(ms.Item, &_MessageSet_Item{ + TypeId: &mtid, + Message: msg, + }) + return nil +} + +func (ms *messageSet) Reset() { *ms = messageSet{} } +func (ms *messageSet) String() string { return CompactTextString(ms) } +func (*messageSet) ProtoMessage() {} + +// Support for the message_set_wire_format message option. + +func skipVarint(buf []byte) []byte { + i := 0 + for ; buf[i]&0x80 != 0; i++ { + } + return buf[i+1:] +} + +// MarshalMessageSet encodes the extension map represented by m in the message set wire format. +// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option. +func MarshalMessageSet(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + if err := encodeExtensions(exts); err != nil { + return nil, err + } + m, _ = exts.extensionsRead() + case map[int32]Extension: + if err := encodeExtensionsMap(exts); err != nil { + return nil, err + } + m = exts + default: + return nil, errors.New("proto: not an extension map") + } + + // Sort extension IDs to provide a deterministic encoding. + // See also enc_map in encode.go. + ids := make([]int, 0, len(m)) + for id := range m { + ids = append(ids, int(id)) + } + sort.Ints(ids) + + ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))} + for _, id := range ids { + e := m[int32(id)] + // Remove the wire type and field number varint, as well as the length varint. + msg := skipVarint(skipVarint(e.enc)) + + ms.Item = append(ms.Item, &_MessageSet_Item{ + TypeId: Int32(int32(id)), + Message: msg, + }) + } + return Marshal(ms) +} + +// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. +// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option. +func UnmarshalMessageSet(buf []byte, exts interface{}) error { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m = exts.extensionsWrite() + case map[int32]Extension: + m = exts + default: + return errors.New("proto: not an extension map") + } + + ms := new(messageSet) + if err := Unmarshal(buf, ms); err != nil { + return err + } + for _, item := range ms.Item { + id := *item.TypeId + msg := item.Message + + // Restore wire type and field number varint, plus length varint. + // Be careful to preserve duplicate items. + b := EncodeVarint(uint64(id)<<3 | WireBytes) + if ext, ok := m[id]; ok { + // Existing data; rip off the tag and length varint + // so we join the new data correctly. + // We can assume that ext.enc is set because we are unmarshaling. + o := ext.enc[len(b):] // skip wire type and field number + _, n := DecodeVarint(o) // calculate length of length varint + o = o[n:] // skip length varint + msg = append(o, msg...) // join old data and new data + } + b = append(b, EncodeVarint(uint64(len(msg)))...) + b = append(b, msg...) + + m[id] = Extension{enc: b} + } + return nil +} + +// MarshalMessageSetJSON encodes the extension map represented by m in JSON format. +// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option. +func MarshalMessageSetJSON(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m, _ = exts.extensionsRead() + case map[int32]Extension: + m = exts + default: + return nil, errors.New("proto: not an extension map") + } + var b bytes.Buffer + b.WriteByte('{') + + // Process the map in key order for deterministic output. + ids := make([]int32, 0, len(m)) + for id := range m { + ids = append(ids, id) + } + sort.Sort(int32Slice(ids)) // int32Slice defined in text.go + + for i, id := range ids { + ext := m[id] + if i > 0 { + b.WriteByte(',') + } + + msd, ok := messageSetMap[id] + if !ok { + // Unknown type; we can't render it, so skip it. + continue + } + fmt.Fprintf(&b, `"[%s]":`, msd.name) + + x := ext.value + if x == nil { + x = reflect.New(msd.t.Elem()).Interface() + if err := Unmarshal(ext.enc, x.(Message)); err != nil { + return nil, err + } + } + d, err := json.Marshal(x) + if err != nil { + return nil, err + } + b.Write(d) + } + b.WriteByte('}') + return b.Bytes(), nil +} + +// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format. +// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option. +func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error { + // Common-case fast path. + if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) { + return nil + } + + // This is fairly tricky, and it's not clear that it is needed. + return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented") +} + +// A global registry of types that can be used in a MessageSet. + +var messageSetMap = make(map[int32]messageSetDesc) + +type messageSetDesc struct { + t reflect.Type // pointer to struct + name string +} + +// RegisterMessageSetType is called from the generated code. +func RegisterMessageSetType(m Message, fieldNum int32, name string) { + messageSetMap[fieldNum] = messageSetDesc{ + t: reflect.TypeOf(m), + name: name, + } +} diff --git a/vendor/github.com/golang/protobuf/proto/pointer_reflect.go b/vendor/github.com/golang/protobuf/proto/pointer_reflect.go new file mode 100644 index 0000000..fb512e2 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/pointer_reflect.go @@ -0,0 +1,484 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +build appengine js + +// This file contains an implementation of proto field accesses using package reflect. +// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can +// be used on App Engine. + +package proto + +import ( + "math" + "reflect" +) + +// A structPointer is a pointer to a struct. +type structPointer struct { + v reflect.Value +} + +// toStructPointer returns a structPointer equivalent to the given reflect value. +// The reflect value must itself be a pointer to a struct. +func toStructPointer(v reflect.Value) structPointer { + return structPointer{v} +} + +// IsNil reports whether p is nil. +func structPointer_IsNil(p structPointer) bool { + return p.v.IsNil() +} + +// Interface returns the struct pointer as an interface value. +func structPointer_Interface(p structPointer, _ reflect.Type) interface{} { + return p.v.Interface() +} + +// A field identifies a field in a struct, accessible from a structPointer. +// In this implementation, a field is identified by the sequence of field indices +// passed to reflect's FieldByIndex. +type field []int + +// toField returns a field equivalent to the given reflect field. +func toField(f *reflect.StructField) field { + return f.Index +} + +// invalidField is an invalid field identifier. +var invalidField = field(nil) + +// IsValid reports whether the field identifier is valid. +func (f field) IsValid() bool { return f != nil } + +// field returns the given field in the struct as a reflect value. +func structPointer_field(p structPointer, f field) reflect.Value { + // Special case: an extension map entry with a value of type T + // passes a *T to the struct-handling code with a zero field, + // expecting that it will be treated as equivalent to *struct{ X T }, + // which has the same memory layout. We have to handle that case + // specially, because reflect will panic if we call FieldByIndex on a + // non-struct. + if f == nil { + return p.v.Elem() + } + + return p.v.Elem().FieldByIndex(f) +} + +// ifield returns the given field in the struct as an interface value. +func structPointer_ifield(p structPointer, f field) interface{} { + return structPointer_field(p, f).Addr().Interface() +} + +// Bytes returns the address of a []byte field in the struct. +func structPointer_Bytes(p structPointer, f field) *[]byte { + return structPointer_ifield(p, f).(*[]byte) +} + +// BytesSlice returns the address of a [][]byte field in the struct. +func structPointer_BytesSlice(p structPointer, f field) *[][]byte { + return structPointer_ifield(p, f).(*[][]byte) +} + +// Bool returns the address of a *bool field in the struct. +func structPointer_Bool(p structPointer, f field) **bool { + return structPointer_ifield(p, f).(**bool) +} + +// BoolVal returns the address of a bool field in the struct. +func structPointer_BoolVal(p structPointer, f field) *bool { + return structPointer_ifield(p, f).(*bool) +} + +// BoolSlice returns the address of a []bool field in the struct. +func structPointer_BoolSlice(p structPointer, f field) *[]bool { + return structPointer_ifield(p, f).(*[]bool) +} + +// String returns the address of a *string field in the struct. +func structPointer_String(p structPointer, f field) **string { + return structPointer_ifield(p, f).(**string) +} + +// StringVal returns the address of a string field in the struct. +func structPointer_StringVal(p structPointer, f field) *string { + return structPointer_ifield(p, f).(*string) +} + +// StringSlice returns the address of a []string field in the struct. +func structPointer_StringSlice(p structPointer, f field) *[]string { + return structPointer_ifield(p, f).(*[]string) +} + +// Extensions returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return structPointer_ifield(p, f).(*XXX_InternalExtensions) +} + +// ExtMap returns the address of an extension map field in the struct. +func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { + return structPointer_ifield(p, f).(*map[int32]Extension) +} + +// NewAt returns the reflect.Value for a pointer to a field in the struct. +func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value { + return structPointer_field(p, f).Addr() +} + +// SetStructPointer writes a *struct field in the struct. +func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { + structPointer_field(p, f).Set(q.v) +} + +// GetStructPointer reads a *struct field in the struct. +func structPointer_GetStructPointer(p structPointer, f field) structPointer { + return structPointer{structPointer_field(p, f)} +} + +// StructPointerSlice the address of a []*struct field in the struct. +func structPointer_StructPointerSlice(p structPointer, f field) structPointerSlice { + return structPointerSlice{structPointer_field(p, f)} +} + +// A structPointerSlice represents the address of a slice of pointers to structs +// (themselves messages or groups). That is, v.Type() is *[]*struct{...}. +type structPointerSlice struct { + v reflect.Value +} + +func (p structPointerSlice) Len() int { return p.v.Len() } +func (p structPointerSlice) Index(i int) structPointer { return structPointer{p.v.Index(i)} } +func (p structPointerSlice) Append(q structPointer) { + p.v.Set(reflect.Append(p.v, q.v)) +} + +var ( + int32Type = reflect.TypeOf(int32(0)) + uint32Type = reflect.TypeOf(uint32(0)) + float32Type = reflect.TypeOf(float32(0)) + int64Type = reflect.TypeOf(int64(0)) + uint64Type = reflect.TypeOf(uint64(0)) + float64Type = reflect.TypeOf(float64(0)) +) + +// A word32 represents a field of type *int32, *uint32, *float32, or *enum. +// That is, v.Type() is *int32, *uint32, *float32, or *enum and v is assignable. +type word32 struct { + v reflect.Value +} + +// IsNil reports whether p is nil. +func word32_IsNil(p word32) bool { + return p.v.IsNil() +} + +// Set sets p to point at a newly allocated word with bits set to x. +func word32_Set(p word32, o *Buffer, x uint32) { + t := p.v.Type().Elem() + switch t { + case int32Type: + if len(o.int32s) == 0 { + o.int32s = make([]int32, uint32PoolSize) + } + o.int32s[0] = int32(x) + p.v.Set(reflect.ValueOf(&o.int32s[0])) + o.int32s = o.int32s[1:] + return + case uint32Type: + if len(o.uint32s) == 0 { + o.uint32s = make([]uint32, uint32PoolSize) + } + o.uint32s[0] = x + p.v.Set(reflect.ValueOf(&o.uint32s[0])) + o.uint32s = o.uint32s[1:] + return + case float32Type: + if len(o.float32s) == 0 { + o.float32s = make([]float32, uint32PoolSize) + } + o.float32s[0] = math.Float32frombits(x) + p.v.Set(reflect.ValueOf(&o.float32s[0])) + o.float32s = o.float32s[1:] + return + } + + // must be enum + p.v.Set(reflect.New(t)) + p.v.Elem().SetInt(int64(int32(x))) +} + +// Get gets the bits pointed at by p, as a uint32. +func word32_Get(p word32) uint32 { + elem := p.v.Elem() + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32 returns a reference to a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32(p structPointer, f field) word32 { + return word32{structPointer_field(p, f)} +} + +// A word32Val represents a field of type int32, uint32, float32, or enum. +// That is, v.Type() is int32, uint32, float32, or enum and v is assignable. +type word32Val struct { + v reflect.Value +} + +// Set sets *p to x. +func word32Val_Set(p word32Val, x uint32) { + switch p.v.Type() { + case int32Type: + p.v.SetInt(int64(x)) + return + case uint32Type: + p.v.SetUint(uint64(x)) + return + case float32Type: + p.v.SetFloat(float64(math.Float32frombits(x))) + return + } + + // must be enum + p.v.SetInt(int64(int32(x))) +} + +// Get gets the bits pointed at by p, as a uint32. +func word32Val_Get(p word32Val) uint32 { + elem := p.v + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32Val returns a reference to a int32, uint32, float32, or enum field in the struct. +func structPointer_Word32Val(p structPointer, f field) word32Val { + return word32Val{structPointer_field(p, f)} +} + +// A word32Slice is a slice of 32-bit values. +// That is, v.Type() is []int32, []uint32, []float32, or []enum. +type word32Slice struct { + v reflect.Value +} + +func (p word32Slice) Append(x uint32) { + n, m := p.v.Len(), p.v.Cap() + if n < m { + p.v.SetLen(n + 1) + } else { + t := p.v.Type().Elem() + p.v.Set(reflect.Append(p.v, reflect.Zero(t))) + } + elem := p.v.Index(n) + switch elem.Kind() { + case reflect.Int32: + elem.SetInt(int64(int32(x))) + case reflect.Uint32: + elem.SetUint(uint64(x)) + case reflect.Float32: + elem.SetFloat(float64(math.Float32frombits(x))) + } +} + +func (p word32Slice) Len() int { + return p.v.Len() +} + +func (p word32Slice) Index(i int) uint32 { + elem := p.v.Index(i) + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32Slice returns a reference to a []int32, []uint32, []float32, or []enum field in the struct. +func structPointer_Word32Slice(p structPointer, f field) word32Slice { + return word32Slice{structPointer_field(p, f)} +} + +// word64 is like word32 but for 64-bit values. +type word64 struct { + v reflect.Value +} + +func word64_Set(p word64, o *Buffer, x uint64) { + t := p.v.Type().Elem() + switch t { + case int64Type: + if len(o.int64s) == 0 { + o.int64s = make([]int64, uint64PoolSize) + } + o.int64s[0] = int64(x) + p.v.Set(reflect.ValueOf(&o.int64s[0])) + o.int64s = o.int64s[1:] + return + case uint64Type: + if len(o.uint64s) == 0 { + o.uint64s = make([]uint64, uint64PoolSize) + } + o.uint64s[0] = x + p.v.Set(reflect.ValueOf(&o.uint64s[0])) + o.uint64s = o.uint64s[1:] + return + case float64Type: + if len(o.float64s) == 0 { + o.float64s = make([]float64, uint64PoolSize) + } + o.float64s[0] = math.Float64frombits(x) + p.v.Set(reflect.ValueOf(&o.float64s[0])) + o.float64s = o.float64s[1:] + return + } + panic("unreachable") +} + +func word64_IsNil(p word64) bool { + return p.v.IsNil() +} + +func word64_Get(p word64) uint64 { + elem := p.v.Elem() + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return elem.Uint() + case reflect.Float64: + return math.Float64bits(elem.Float()) + } + panic("unreachable") +} + +func structPointer_Word64(p structPointer, f field) word64 { + return word64{structPointer_field(p, f)} +} + +// word64Val is like word32Val but for 64-bit values. +type word64Val struct { + v reflect.Value +} + +func word64Val_Set(p word64Val, o *Buffer, x uint64) { + switch p.v.Type() { + case int64Type: + p.v.SetInt(int64(x)) + return + case uint64Type: + p.v.SetUint(x) + return + case float64Type: + p.v.SetFloat(math.Float64frombits(x)) + return + } + panic("unreachable") +} + +func word64Val_Get(p word64Val) uint64 { + elem := p.v + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return elem.Uint() + case reflect.Float64: + return math.Float64bits(elem.Float()) + } + panic("unreachable") +} + +func structPointer_Word64Val(p structPointer, f field) word64Val { + return word64Val{structPointer_field(p, f)} +} + +type word64Slice struct { + v reflect.Value +} + +func (p word64Slice) Append(x uint64) { + n, m := p.v.Len(), p.v.Cap() + if n < m { + p.v.SetLen(n + 1) + } else { + t := p.v.Type().Elem() + p.v.Set(reflect.Append(p.v, reflect.Zero(t))) + } + elem := p.v.Index(n) + switch elem.Kind() { + case reflect.Int64: + elem.SetInt(int64(int64(x))) + case reflect.Uint64: + elem.SetUint(uint64(x)) + case reflect.Float64: + elem.SetFloat(float64(math.Float64frombits(x))) + } +} + +func (p word64Slice) Len() int { + return p.v.Len() +} + +func (p word64Slice) Index(i int) uint64 { + elem := p.v.Index(i) + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return uint64(elem.Uint()) + case reflect.Float64: + return math.Float64bits(float64(elem.Float())) + } + panic("unreachable") +} + +func structPointer_Word64Slice(p structPointer, f field) word64Slice { + return word64Slice{structPointer_field(p, f)} +} diff --git a/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go b/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go new file mode 100644 index 0000000..6b5567d --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go @@ -0,0 +1,270 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +build !appengine,!js + +// This file contains the implementation of the proto field accesses using package unsafe. + +package proto + +import ( + "reflect" + "unsafe" +) + +// NOTE: These type_Foo functions would more idiomatically be methods, +// but Go does not allow methods on pointer types, and we must preserve +// some pointer type for the garbage collector. We use these +// funcs with clunky names as our poor approximation to methods. +// +// An alternative would be +// type structPointer struct { p unsafe.Pointer } +// but that does not registerize as well. + +// A structPointer is a pointer to a struct. +type structPointer unsafe.Pointer + +// toStructPointer returns a structPointer equivalent to the given reflect value. +func toStructPointer(v reflect.Value) structPointer { + return structPointer(unsafe.Pointer(v.Pointer())) +} + +// IsNil reports whether p is nil. +func structPointer_IsNil(p structPointer) bool { + return p == nil +} + +// Interface returns the struct pointer, assumed to have element type t, +// as an interface value. +func structPointer_Interface(p structPointer, t reflect.Type) interface{} { + return reflect.NewAt(t, unsafe.Pointer(p)).Interface() +} + +// A field identifies a field in a struct, accessible from a structPointer. +// In this implementation, a field is identified by its byte offset from the start of the struct. +type field uintptr + +// toField returns a field equivalent to the given reflect field. +func toField(f *reflect.StructField) field { + return field(f.Offset) +} + +// invalidField is an invalid field identifier. +const invalidField = ^field(0) + +// IsValid reports whether the field identifier is valid. +func (f field) IsValid() bool { + return f != ^field(0) +} + +// Bytes returns the address of a []byte field in the struct. +func structPointer_Bytes(p structPointer, f field) *[]byte { + return (*[]byte)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BytesSlice returns the address of a [][]byte field in the struct. +func structPointer_BytesSlice(p structPointer, f field) *[][]byte { + return (*[][]byte)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// Bool returns the address of a *bool field in the struct. +func structPointer_Bool(p structPointer, f field) **bool { + return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BoolVal returns the address of a bool field in the struct. +func structPointer_BoolVal(p structPointer, f field) *bool { + return (*bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BoolSlice returns the address of a []bool field in the struct. +func structPointer_BoolSlice(p structPointer, f field) *[]bool { + return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// String returns the address of a *string field in the struct. +func structPointer_String(p structPointer, f field) **string { + return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StringVal returns the address of a string field in the struct. +func structPointer_StringVal(p structPointer, f field) *string { + return (*string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StringSlice returns the address of a []string field in the struct. +func structPointer_StringSlice(p structPointer, f field) *[]string { + return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// ExtMap returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { + return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// NewAt returns the reflect.Value for a pointer to a field in the struct. +func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value { + return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f))) +} + +// SetStructPointer writes a *struct field in the struct. +func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { + *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q +} + +// GetStructPointer reads a *struct field in the struct. +func structPointer_GetStructPointer(p structPointer, f field) structPointer { + return *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StructPointerSlice the address of a []*struct field in the struct. +func structPointer_StructPointerSlice(p structPointer, f field) *structPointerSlice { + return (*structPointerSlice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// A structPointerSlice represents a slice of pointers to structs (themselves submessages or groups). +type structPointerSlice []structPointer + +func (v *structPointerSlice) Len() int { return len(*v) } +func (v *structPointerSlice) Index(i int) structPointer { return (*v)[i] } +func (v *structPointerSlice) Append(p structPointer) { *v = append(*v, p) } + +// A word32 is the address of a "pointer to 32-bit value" field. +type word32 **uint32 + +// IsNil reports whether *v is nil. +func word32_IsNil(p word32) bool { + return *p == nil +} + +// Set sets *v to point at a newly allocated word set to x. +func word32_Set(p word32, o *Buffer, x uint32) { + if len(o.uint32s) == 0 { + o.uint32s = make([]uint32, uint32PoolSize) + } + o.uint32s[0] = x + *p = &o.uint32s[0] + o.uint32s = o.uint32s[1:] +} + +// Get gets the value pointed at by *v. +func word32_Get(p word32) uint32 { + return **p +} + +// Word32 returns the address of a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32(p structPointer, f field) word32 { + return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// A word32Val is the address of a 32-bit value field. +type word32Val *uint32 + +// Set sets *p to x. +func word32Val_Set(p word32Val, x uint32) { + *p = x +} + +// Get gets the value pointed at by p. +func word32Val_Get(p word32Val) uint32 { + return *p +} + +// Word32Val returns the address of a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32Val(p structPointer, f field) word32Val { + return word32Val((*uint32)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// A word32Slice is a slice of 32-bit values. +type word32Slice []uint32 + +func (v *word32Slice) Append(x uint32) { *v = append(*v, x) } +func (v *word32Slice) Len() int { return len(*v) } +func (v *word32Slice) Index(i int) uint32 { return (*v)[i] } + +// Word32Slice returns the address of a []int32, []uint32, []float32, or []enum field in the struct. +func structPointer_Word32Slice(p structPointer, f field) *word32Slice { + return (*word32Slice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// word64 is like word32 but for 64-bit values. +type word64 **uint64 + +func word64_Set(p word64, o *Buffer, x uint64) { + if len(o.uint64s) == 0 { + o.uint64s = make([]uint64, uint64PoolSize) + } + o.uint64s[0] = x + *p = &o.uint64s[0] + o.uint64s = o.uint64s[1:] +} + +func word64_IsNil(p word64) bool { + return *p == nil +} + +func word64_Get(p word64) uint64 { + return **p +} + +func structPointer_Word64(p structPointer, f field) word64 { + return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// word64Val is like word32Val but for 64-bit values. +type word64Val *uint64 + +func word64Val_Set(p word64Val, o *Buffer, x uint64) { + *p = x +} + +func word64Val_Get(p word64Val) uint64 { + return *p +} + +func structPointer_Word64Val(p structPointer, f field) word64Val { + return word64Val((*uint64)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// word64Slice is like word32Slice but for 64-bit values. +type word64Slice []uint64 + +func (v *word64Slice) Append(x uint64) { *v = append(*v, x) } +func (v *word64Slice) Len() int { return len(*v) } +func (v *word64Slice) Index(i int) uint64 { return (*v)[i] } + +func structPointer_Word64Slice(p structPointer, f field) *word64Slice { + return (*word64Slice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} diff --git a/vendor/github.com/golang/protobuf/proto/properties.go b/vendor/github.com/golang/protobuf/proto/properties.go new file mode 100644 index 0000000..ec2289c --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/properties.go @@ -0,0 +1,872 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for encoding data into the wire format for protocol buffers. + */ + +import ( + "fmt" + "log" + "os" + "reflect" + "sort" + "strconv" + "strings" + "sync" +) + +const debug bool = false + +// Constants that identify the encoding of a value on the wire. +const ( + WireVarint = 0 + WireFixed64 = 1 + WireBytes = 2 + WireStartGroup = 3 + WireEndGroup = 4 + WireFixed32 = 5 +) + +const startSize = 10 // initial slice/string sizes + +// Encoders are defined in encode.go +// An encoder outputs the full representation of a field, including its +// tag and encoder type. +type encoder func(p *Buffer, prop *Properties, base structPointer) error + +// A valueEncoder encodes a single integer in a particular encoding. +type valueEncoder func(o *Buffer, x uint64) error + +// Sizers are defined in encode.go +// A sizer returns the encoded size of a field, including its tag and encoder +// type. +type sizer func(prop *Properties, base structPointer) int + +// A valueSizer returns the encoded size of a single integer in a particular +// encoding. +type valueSizer func(x uint64) int + +// Decoders are defined in decode.go +// A decoder creates a value from its wire representation. +// Unrecognized subelements are saved in unrec. +type decoder func(p *Buffer, prop *Properties, base structPointer) error + +// A valueDecoder decodes a single integer in a particular encoding. +type valueDecoder func(o *Buffer) (x uint64, err error) + +// A oneofMarshaler does the marshaling for all oneof fields in a message. +type oneofMarshaler func(Message, *Buffer) error + +// A oneofUnmarshaler does the unmarshaling for a oneof field in a message. +type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error) + +// A oneofSizer does the sizing for all oneof fields in a message. +type oneofSizer func(Message) int + +// tagMap is an optimization over map[int]int for typical protocol buffer +// use-cases. Encoded protocol buffers are often in tag order with small tag +// numbers. +type tagMap struct { + fastTags []int + slowTags map[int]int +} + +// tagMapFastLimit is the upper bound on the tag number that will be stored in +// the tagMap slice rather than its map. +const tagMapFastLimit = 1024 + +func (p *tagMap) get(t int) (int, bool) { + if t > 0 && t < tagMapFastLimit { + if t >= len(p.fastTags) { + return 0, false + } + fi := p.fastTags[t] + return fi, fi >= 0 + } + fi, ok := p.slowTags[t] + return fi, ok +} + +func (p *tagMap) put(t int, fi int) { + if t > 0 && t < tagMapFastLimit { + for len(p.fastTags) < t+1 { + p.fastTags = append(p.fastTags, -1) + } + p.fastTags[t] = fi + return + } + if p.slowTags == nil { + p.slowTags = make(map[int]int) + } + p.slowTags[t] = fi +} + +// StructProperties represents properties for all the fields of a struct. +// decoderTags and decoderOrigNames should only be used by the decoder. +type StructProperties struct { + Prop []*Properties // properties for each field + reqCount int // required count + decoderTags tagMap // map from proto tag to struct field number + decoderOrigNames map[string]int // map from original name to struct field number + order []int // list of struct field numbers in tag order + unrecField field // field id of the XXX_unrecognized []byte field + extendable bool // is this an extendable proto + + oneofMarshaler oneofMarshaler + oneofUnmarshaler oneofUnmarshaler + oneofSizer oneofSizer + stype reflect.Type + + // OneofTypes contains information about the oneof fields in this message. + // It is keyed by the original name of a field. + OneofTypes map[string]*OneofProperties +} + +// OneofProperties represents information about a specific field in a oneof. +type OneofProperties struct { + Type reflect.Type // pointer to generated struct type for this oneof field + Field int // struct field number of the containing oneof in the message + Prop *Properties +} + +// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec. +// See encode.go, (*Buffer).enc_struct. + +func (sp *StructProperties) Len() int { return len(sp.order) } +func (sp *StructProperties) Less(i, j int) bool { + return sp.Prop[sp.order[i]].Tag < sp.Prop[sp.order[j]].Tag +} +func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order[j], sp.order[i] } + +// Properties represents the protocol-specific behavior of a single struct field. +type Properties struct { + Name string // name of the field, for error messages + OrigName string // original name before protocol compiler (always set) + JSONName string // name to use for JSON; determined by protoc + Wire string + WireType int + Tag int + Required bool + Optional bool + Repeated bool + Packed bool // relevant for repeated primitives only + Enum string // set for enum types only + proto3 bool // whether this is known to be a proto3 field; set for []byte only + oneof bool // whether this is a oneof field + + Default string // default value + HasDefault bool // whether an explicit default was provided + def_uint64 uint64 + + enc encoder + valEnc valueEncoder // set for bool and numeric types only + field field + tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType) + tagbuf [8]byte + stype reflect.Type // set for struct types only + sprop *StructProperties // set for struct types only + isMarshaler bool + isUnmarshaler bool + + mtype reflect.Type // set for map types only + mkeyprop *Properties // set for map types only + mvalprop *Properties // set for map types only + + size sizer + valSize valueSizer // set for bool and numeric types only + + dec decoder + valDec valueDecoder // set for bool and numeric types only + + // If this is a packable field, this will be the decoder for the packed version of the field. + packedDec decoder +} + +// String formats the properties in the protobuf struct field tag style. +func (p *Properties) String() string { + s := p.Wire + s = "," + s += strconv.Itoa(p.Tag) + if p.Required { + s += ",req" + } + if p.Optional { + s += ",opt" + } + if p.Repeated { + s += ",rep" + } + if p.Packed { + s += ",packed" + } + s += ",name=" + p.OrigName + if p.JSONName != p.OrigName { + s += ",json=" + p.JSONName + } + if p.proto3 { + s += ",proto3" + } + if p.oneof { + s += ",oneof" + } + if len(p.Enum) > 0 { + s += ",enum=" + p.Enum + } + if p.HasDefault { + s += ",def=" + p.Default + } + return s +} + +// Parse populates p by parsing a string in the protobuf struct field tag style. +func (p *Properties) Parse(s string) { + // "bytes,49,opt,name=foo,def=hello!" + fields := strings.Split(s, ",") // breaks def=, but handled below. + if len(fields) < 2 { + fmt.Fprintf(os.Stderr, "proto: tag has too few fields: %q\n", s) + return + } + + p.Wire = fields[0] + switch p.Wire { + case "varint": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeVarint + p.valDec = (*Buffer).DecodeVarint + p.valSize = sizeVarint + case "fixed32": + p.WireType = WireFixed32 + p.valEnc = (*Buffer).EncodeFixed32 + p.valDec = (*Buffer).DecodeFixed32 + p.valSize = sizeFixed32 + case "fixed64": + p.WireType = WireFixed64 + p.valEnc = (*Buffer).EncodeFixed64 + p.valDec = (*Buffer).DecodeFixed64 + p.valSize = sizeFixed64 + case "zigzag32": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeZigzag32 + p.valDec = (*Buffer).DecodeZigzag32 + p.valSize = sizeZigzag32 + case "zigzag64": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeZigzag64 + p.valDec = (*Buffer).DecodeZigzag64 + p.valSize = sizeZigzag64 + case "bytes", "group": + p.WireType = WireBytes + // no numeric converter for non-numeric types + default: + fmt.Fprintf(os.Stderr, "proto: tag has unknown wire type: %q\n", s) + return + } + + var err error + p.Tag, err = strconv.Atoi(fields[1]) + if err != nil { + return + } + + for i := 2; i < len(fields); i++ { + f := fields[i] + switch { + case f == "req": + p.Required = true + case f == "opt": + p.Optional = true + case f == "rep": + p.Repeated = true + case f == "packed": + p.Packed = true + case strings.HasPrefix(f, "name="): + p.OrigName = f[5:] + case strings.HasPrefix(f, "json="): + p.JSONName = f[5:] + case strings.HasPrefix(f, "enum="): + p.Enum = f[5:] + case f == "proto3": + p.proto3 = true + case f == "oneof": + p.oneof = true + case strings.HasPrefix(f, "def="): + p.HasDefault = true + p.Default = f[4:] // rest of string + if i+1 < len(fields) { + // Commas aren't escaped, and def is always last. + p.Default += "," + strings.Join(fields[i+1:], ",") + break + } + } + } +} + +func logNoSliceEnc(t1, t2 reflect.Type) { + fmt.Fprintf(os.Stderr, "proto: no slice oenc for %T = []%T\n", t1, t2) +} + +var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem() + +// Initialize the fields for encoding and decoding. +func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) { + p.enc = nil + p.dec = nil + p.size = nil + + switch t1 := typ; t1.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no coders for %v\n", t1) + + // proto3 scalar types + + case reflect.Bool: + p.enc = (*Buffer).enc_proto3_bool + p.dec = (*Buffer).dec_proto3_bool + p.size = size_proto3_bool + case reflect.Int32: + p.enc = (*Buffer).enc_proto3_int32 + p.dec = (*Buffer).dec_proto3_int32 + p.size = size_proto3_int32 + case reflect.Uint32: + p.enc = (*Buffer).enc_proto3_uint32 + p.dec = (*Buffer).dec_proto3_int32 // can reuse + p.size = size_proto3_uint32 + case reflect.Int64, reflect.Uint64: + p.enc = (*Buffer).enc_proto3_int64 + p.dec = (*Buffer).dec_proto3_int64 + p.size = size_proto3_int64 + case reflect.Float32: + p.enc = (*Buffer).enc_proto3_uint32 // can just treat them as bits + p.dec = (*Buffer).dec_proto3_int32 + p.size = size_proto3_uint32 + case reflect.Float64: + p.enc = (*Buffer).enc_proto3_int64 // can just treat them as bits + p.dec = (*Buffer).dec_proto3_int64 + p.size = size_proto3_int64 + case reflect.String: + p.enc = (*Buffer).enc_proto3_string + p.dec = (*Buffer).dec_proto3_string + p.size = size_proto3_string + + case reflect.Ptr: + switch t2 := t1.Elem(); t2.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2) + break + case reflect.Bool: + p.enc = (*Buffer).enc_bool + p.dec = (*Buffer).dec_bool + p.size = size_bool + case reflect.Int32: + p.enc = (*Buffer).enc_int32 + p.dec = (*Buffer).dec_int32 + p.size = size_int32 + case reflect.Uint32: + p.enc = (*Buffer).enc_uint32 + p.dec = (*Buffer).dec_int32 // can reuse + p.size = size_uint32 + case reflect.Int64, reflect.Uint64: + p.enc = (*Buffer).enc_int64 + p.dec = (*Buffer).dec_int64 + p.size = size_int64 + case reflect.Float32: + p.enc = (*Buffer).enc_uint32 // can just treat them as bits + p.dec = (*Buffer).dec_int32 + p.size = size_uint32 + case reflect.Float64: + p.enc = (*Buffer).enc_int64 // can just treat them as bits + p.dec = (*Buffer).dec_int64 + p.size = size_int64 + case reflect.String: + p.enc = (*Buffer).enc_string + p.dec = (*Buffer).dec_string + p.size = size_string + case reflect.Struct: + p.stype = t1.Elem() + p.isMarshaler = isMarshaler(t1) + p.isUnmarshaler = isUnmarshaler(t1) + if p.Wire == "bytes" { + p.enc = (*Buffer).enc_struct_message + p.dec = (*Buffer).dec_struct_message + p.size = size_struct_message + } else { + p.enc = (*Buffer).enc_struct_group + p.dec = (*Buffer).dec_struct_group + p.size = size_struct_group + } + } + + case reflect.Slice: + switch t2 := t1.Elem(); t2.Kind() { + default: + logNoSliceEnc(t1, t2) + break + case reflect.Bool: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_bool + p.size = size_slice_packed_bool + } else { + p.enc = (*Buffer).enc_slice_bool + p.size = size_slice_bool + } + p.dec = (*Buffer).dec_slice_bool + p.packedDec = (*Buffer).dec_slice_packed_bool + case reflect.Int32: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int32 + p.size = size_slice_packed_int32 + } else { + p.enc = (*Buffer).enc_slice_int32 + p.size = size_slice_int32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case reflect.Uint32: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_uint32 + p.size = size_slice_packed_uint32 + } else { + p.enc = (*Buffer).enc_slice_uint32 + p.size = size_slice_uint32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case reflect.Int64, reflect.Uint64: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int64 + p.size = size_slice_packed_int64 + } else { + p.enc = (*Buffer).enc_slice_int64 + p.size = size_slice_int64 + } + p.dec = (*Buffer).dec_slice_int64 + p.packedDec = (*Buffer).dec_slice_packed_int64 + case reflect.Uint8: + p.dec = (*Buffer).dec_slice_byte + if p.proto3 { + p.enc = (*Buffer).enc_proto3_slice_byte + p.size = size_proto3_slice_byte + } else { + p.enc = (*Buffer).enc_slice_byte + p.size = size_slice_byte + } + case reflect.Float32, reflect.Float64: + switch t2.Bits() { + case 32: + // can just treat them as bits + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_uint32 + p.size = size_slice_packed_uint32 + } else { + p.enc = (*Buffer).enc_slice_uint32 + p.size = size_slice_uint32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case 64: + // can just treat them as bits + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int64 + p.size = size_slice_packed_int64 + } else { + p.enc = (*Buffer).enc_slice_int64 + p.size = size_slice_int64 + } + p.dec = (*Buffer).dec_slice_int64 + p.packedDec = (*Buffer).dec_slice_packed_int64 + default: + logNoSliceEnc(t1, t2) + break + } + case reflect.String: + p.enc = (*Buffer).enc_slice_string + p.dec = (*Buffer).dec_slice_string + p.size = size_slice_string + case reflect.Ptr: + switch t3 := t2.Elem(); t3.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no ptr oenc for %T -> %T -> %T\n", t1, t2, t3) + break + case reflect.Struct: + p.stype = t2.Elem() + p.isMarshaler = isMarshaler(t2) + p.isUnmarshaler = isUnmarshaler(t2) + if p.Wire == "bytes" { + p.enc = (*Buffer).enc_slice_struct_message + p.dec = (*Buffer).dec_slice_struct_message + p.size = size_slice_struct_message + } else { + p.enc = (*Buffer).enc_slice_struct_group + p.dec = (*Buffer).dec_slice_struct_group + p.size = size_slice_struct_group + } + } + case reflect.Slice: + switch t2.Elem().Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no slice elem oenc for %T -> %T -> %T\n", t1, t2, t2.Elem()) + break + case reflect.Uint8: + p.enc = (*Buffer).enc_slice_slice_byte + p.dec = (*Buffer).dec_slice_slice_byte + p.size = size_slice_slice_byte + } + } + + case reflect.Map: + p.enc = (*Buffer).enc_new_map + p.dec = (*Buffer).dec_new_map + p.size = size_new_map + + p.mtype = t1 + p.mkeyprop = &Properties{} + p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp) + p.mvalprop = &Properties{} + vtype := p.mtype.Elem() + if vtype.Kind() != reflect.Ptr && vtype.Kind() != reflect.Slice { + // The value type is not a message (*T) or bytes ([]byte), + // so we need encoders for the pointer to this type. + vtype = reflect.PtrTo(vtype) + } + p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp) + } + + // precalculate tag code + wire := p.WireType + if p.Packed { + wire = WireBytes + } + x := uint32(p.Tag)<<3 | uint32(wire) + i := 0 + for i = 0; x > 127; i++ { + p.tagbuf[i] = 0x80 | uint8(x&0x7F) + x >>= 7 + } + p.tagbuf[i] = uint8(x) + p.tagcode = p.tagbuf[0 : i+1] + + if p.stype != nil { + if lockGetProp { + p.sprop = GetProperties(p.stype) + } else { + p.sprop = getPropertiesLocked(p.stype) + } + } +} + +var ( + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() +) + +// isMarshaler reports whether type t implements Marshaler. +func isMarshaler(t reflect.Type) bool { + // We're checking for (likely) pointer-receiver methods + // so if t is not a pointer, something is very wrong. + // The calls above only invoke isMarshaler on pointer types. + if t.Kind() != reflect.Ptr { + panic("proto: misuse of isMarshaler") + } + return t.Implements(marshalerType) +} + +// isUnmarshaler reports whether type t implements Unmarshaler. +func isUnmarshaler(t reflect.Type) bool { + // We're checking for (likely) pointer-receiver methods + // so if t is not a pointer, something is very wrong. + // The calls above only invoke isUnmarshaler on pointer types. + if t.Kind() != reflect.Ptr { + panic("proto: misuse of isUnmarshaler") + } + return t.Implements(unmarshalerType) +} + +// Init populates the properties from a protocol buffer struct tag. +func (p *Properties) Init(typ reflect.Type, name, tag string, f *reflect.StructField) { + p.init(typ, name, tag, f, true) +} + +func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructField, lockGetProp bool) { + // "bytes,49,opt,def=hello!" + p.Name = name + p.OrigName = name + if f != nil { + p.field = toField(f) + } + if tag == "" { + return + } + p.Parse(tag) + p.setEncAndDec(typ, f, lockGetProp) +} + +var ( + propertiesMu sync.RWMutex + propertiesMap = make(map[reflect.Type]*StructProperties) +) + +// GetProperties returns the list of properties for the type represented by t. +// t must represent a generated struct type of a protocol message. +func GetProperties(t reflect.Type) *StructProperties { + if t.Kind() != reflect.Struct { + panic("proto: type must have kind struct") + } + + // Most calls to GetProperties in a long-running program will be + // retrieving details for types we have seen before. + propertiesMu.RLock() + sprop, ok := propertiesMap[t] + propertiesMu.RUnlock() + if ok { + if collectStats { + stats.Chit++ + } + return sprop + } + + propertiesMu.Lock() + sprop = getPropertiesLocked(t) + propertiesMu.Unlock() + return sprop +} + +// getPropertiesLocked requires that propertiesMu is held. +func getPropertiesLocked(t reflect.Type) *StructProperties { + if prop, ok := propertiesMap[t]; ok { + if collectStats { + stats.Chit++ + } + return prop + } + if collectStats { + stats.Cmiss++ + } + + prop := new(StructProperties) + // in case of recursive protos, fill this in now. + propertiesMap[t] = prop + + // build properties + prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) || + reflect.PtrTo(t).Implements(extendableProtoV1Type) + prop.unrecField = invalidField + prop.Prop = make([]*Properties, t.NumField()) + prop.order = make([]int, t.NumField()) + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + p := new(Properties) + name := f.Name + p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false) + + if f.Name == "XXX_InternalExtensions" { // special case + p.enc = (*Buffer).enc_exts + p.dec = nil // not needed + p.size = size_exts + } else if f.Name == "XXX_extensions" { // special case + p.enc = (*Buffer).enc_map + p.dec = nil // not needed + p.size = size_map + } else if f.Name == "XXX_unrecognized" { // special case + prop.unrecField = toField(&f) + } + oneof := f.Tag.Get("protobuf_oneof") // special case + if oneof != "" { + // Oneof fields don't use the traditional protobuf tag. + p.OrigName = oneof + } + prop.Prop[i] = p + prop.order[i] = i + if debug { + print(i, " ", f.Name, " ", t.String(), " ") + if p.Tag > 0 { + print(p.String()) + } + print("\n") + } + if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && oneof == "" { + fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]") + } + } + + // Re-order prop.order. + sort.Sort(prop) + + type oneofMessage interface { + XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), func(Message) int, []interface{}) + } + if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok { + var oots []interface{} + prop.oneofMarshaler, prop.oneofUnmarshaler, prop.oneofSizer, oots = om.XXX_OneofFuncs() + prop.stype = t + + // Interpret oneof metadata. + prop.OneofTypes = make(map[string]*OneofProperties) + for _, oot := range oots { + oop := &OneofProperties{ + Type: reflect.ValueOf(oot).Type(), // *T + Prop: new(Properties), + } + sft := oop.Type.Elem().Field(0) + oop.Prop.Name = sft.Name + oop.Prop.Parse(sft.Tag.Get("protobuf")) + // There will be exactly one interface field that + // this new value is assignable to. + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Type.Kind() != reflect.Interface { + continue + } + if !oop.Type.AssignableTo(f.Type) { + continue + } + oop.Field = i + break + } + prop.OneofTypes[oop.Prop.OrigName] = oop + } + } + + // build required counts + // build tags + reqCount := 0 + prop.decoderOrigNames = make(map[string]int) + for i, p := range prop.Prop { + if strings.HasPrefix(p.Name, "XXX_") { + // Internal fields should not appear in tags/origNames maps. + // They are handled specially when encoding and decoding. + continue + } + if p.Required { + reqCount++ + } + prop.decoderTags.put(p.Tag, i) + prop.decoderOrigNames[p.OrigName] = i + } + prop.reqCount = reqCount + + return prop +} + +// Return the Properties object for the x[0]'th field of the structure. +func propByIndex(t reflect.Type, x []int) *Properties { + if len(x) != 1 { + fmt.Fprintf(os.Stderr, "proto: field index dimension %d (not 1) for type %s\n", len(x), t) + return nil + } + prop := GetProperties(t) + return prop.Prop[x[0]] +} + +// Get the address and type of a pointer to a struct from an interface. +func getbase(pb Message) (t reflect.Type, b structPointer, err error) { + if pb == nil { + err = ErrNil + return + } + // get the reflect type of the pointer to the struct. + t = reflect.TypeOf(pb) + // get the address of the struct. + value := reflect.ValueOf(pb) + b = toStructPointer(value) + return +} + +// A global registry of enum types. +// The generated code will register the generated maps by calling RegisterEnum. + +var enumValueMaps = make(map[string]map[string]int32) + +// RegisterEnum is called from the generated code to install the enum descriptor +// maps into the global table to aid parsing text format protocol buffers. +func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[string]int32) { + if _, ok := enumValueMaps[typeName]; ok { + panic("proto: duplicate enum registered: " + typeName) + } + enumValueMaps[typeName] = valueMap +} + +// EnumValueMap returns the mapping from names to integers of the +// enum type enumType, or a nil if not found. +func EnumValueMap(enumType string) map[string]int32 { + return enumValueMaps[enumType] +} + +// A registry of all linked message types. +// The string is a fully-qualified proto name ("pkg.Message"). +var ( + protoTypes = make(map[string]reflect.Type) + revProtoTypes = make(map[reflect.Type]string) +) + +// RegisterType is called from generated code and maps from the fully qualified +// proto name to the type (pointer to struct) of the protocol buffer. +func RegisterType(x Message, name string) { + if _, ok := protoTypes[name]; ok { + // TODO: Some day, make this a panic. + log.Printf("proto: duplicate proto type registered: %s", name) + return + } + t := reflect.TypeOf(x) + protoTypes[name] = t + revProtoTypes[t] = name +} + +// MessageName returns the fully-qualified proto name for the given message type. +func MessageName(x Message) string { + type xname interface { + XXX_MessageName() string + } + if m, ok := x.(xname); ok { + return m.XXX_MessageName() + } + return revProtoTypes[reflect.TypeOf(x)] +} + +// MessageType returns the message type (pointer to struct) for a named message. +func MessageType(name string) reflect.Type { return protoTypes[name] } + +// A registry of all linked proto files. +var ( + protoFiles = make(map[string][]byte) // file name => fileDescriptor +) + +// RegisterFile is called from generated code and maps from the +// full file name of a .proto file to its compressed FileDescriptorProto. +func RegisterFile(filename string, fileDescriptor []byte) { + protoFiles[filename] = fileDescriptor +} + +// FileDescriptor returns the compressed FileDescriptorProto for a .proto file. +func FileDescriptor(filename string) []byte { return protoFiles[filename] } diff --git a/vendor/github.com/golang/protobuf/proto/text.go b/vendor/github.com/golang/protobuf/proto/text.go new file mode 100644 index 0000000..965876b --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/text.go @@ -0,0 +1,854 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +// Functions for writing the text protocol buffer format. + +import ( + "bufio" + "bytes" + "encoding" + "errors" + "fmt" + "io" + "log" + "math" + "reflect" + "sort" + "strings" +) + +var ( + newline = []byte("\n") + spaces = []byte(" ") + gtNewline = []byte(">\n") + endBraceNewline = []byte("}\n") + backslashN = []byte{'\\', 'n'} + backslashR = []byte{'\\', 'r'} + backslashT = []byte{'\\', 't'} + backslashDQ = []byte{'\\', '"'} + backslashBS = []byte{'\\', '\\'} + posInf = []byte("inf") + negInf = []byte("-inf") + nan = []byte("nan") +) + +type writer interface { + io.Writer + WriteByte(byte) error +} + +// textWriter is an io.Writer that tracks its indentation level. +type textWriter struct { + ind int + complete bool // if the current position is a complete line + compact bool // whether to write out as a one-liner + w writer +} + +func (w *textWriter) WriteString(s string) (n int, err error) { + if !strings.Contains(s, "\n") { + if !w.compact && w.complete { + w.writeIndent() + } + w.complete = false + return io.WriteString(w.w, s) + } + // WriteString is typically called without newlines, so this + // codepath and its copy are rare. We copy to avoid + // duplicating all of Write's logic here. + return w.Write([]byte(s)) +} + +func (w *textWriter) Write(p []byte) (n int, err error) { + newlines := bytes.Count(p, newline) + if newlines == 0 { + if !w.compact && w.complete { + w.writeIndent() + } + n, err = w.w.Write(p) + w.complete = false + return n, err + } + + frags := bytes.SplitN(p, newline, newlines+1) + if w.compact { + for i, frag := range frags { + if i > 0 { + if err := w.w.WriteByte(' '); err != nil { + return n, err + } + n++ + } + nn, err := w.w.Write(frag) + n += nn + if err != nil { + return n, err + } + } + return n, nil + } + + for i, frag := range frags { + if w.complete { + w.writeIndent() + } + nn, err := w.w.Write(frag) + n += nn + if err != nil { + return n, err + } + if i+1 < len(frags) { + if err := w.w.WriteByte('\n'); err != nil { + return n, err + } + n++ + } + } + w.complete = len(frags[len(frags)-1]) == 0 + return n, nil +} + +func (w *textWriter) WriteByte(c byte) error { + if w.compact && c == '\n' { + c = ' ' + } + if !w.compact && w.complete { + w.writeIndent() + } + err := w.w.WriteByte(c) + w.complete = c == '\n' + return err +} + +func (w *textWriter) indent() { w.ind++ } + +func (w *textWriter) unindent() { + if w.ind == 0 { + log.Print("proto: textWriter unindented too far") + return + } + w.ind-- +} + +func writeName(w *textWriter, props *Properties) error { + if _, err := w.WriteString(props.OrigName); err != nil { + return err + } + if props.Wire != "group" { + return w.WriteByte(':') + } + return nil +} + +// raw is the interface satisfied by RawMessage. +type raw interface { + Bytes() []byte +} + +func requiresQuotes(u string) bool { + // When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted. + for _, ch := range u { + switch { + case ch == '.' || ch == '/' || ch == '_': + continue + case '0' <= ch && ch <= '9': + continue + case 'A' <= ch && ch <= 'Z': + continue + case 'a' <= ch && ch <= 'z': + continue + default: + return true + } + } + return false +} + +// isAny reports whether sv is a google.protobuf.Any message +func isAny(sv reflect.Value) bool { + type wkt interface { + XXX_WellKnownType() string + } + t, ok := sv.Addr().Interface().(wkt) + return ok && t.XXX_WellKnownType() == "Any" +} + +// writeProto3Any writes an expanded google.protobuf.Any message. +// +// It returns (false, nil) if sv value can't be unmarshaled (e.g. because +// required messages are not linked in). +// +// It returns (true, error) when sv was written in expanded format or an error +// was encountered. +func (tm *TextMarshaler) writeProto3Any(w *textWriter, sv reflect.Value) (bool, error) { + turl := sv.FieldByName("TypeUrl") + val := sv.FieldByName("Value") + if !turl.IsValid() || !val.IsValid() { + return true, errors.New("proto: invalid google.protobuf.Any message") + } + + b, ok := val.Interface().([]byte) + if !ok { + return true, errors.New("proto: invalid google.protobuf.Any message") + } + + parts := strings.Split(turl.String(), "/") + mt := MessageType(parts[len(parts)-1]) + if mt == nil { + return false, nil + } + m := reflect.New(mt.Elem()) + if err := Unmarshal(b, m.Interface().(Message)); err != nil { + return false, nil + } + w.Write([]byte("[")) + u := turl.String() + if requiresQuotes(u) { + writeString(w, u) + } else { + w.Write([]byte(u)) + } + if w.compact { + w.Write([]byte("]:<")) + } else { + w.Write([]byte("]: <\n")) + w.ind++ + } + if err := tm.writeStruct(w, m.Elem()); err != nil { + return true, err + } + if w.compact { + w.Write([]byte("> ")) + } else { + w.ind-- + w.Write([]byte(">\n")) + } + return true, nil +} + +func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { + if tm.ExpandAny && isAny(sv) { + if canExpand, err := tm.writeProto3Any(w, sv); canExpand { + return err + } + } + st := sv.Type() + sprops := GetProperties(st) + for i := 0; i < sv.NumField(); i++ { + fv := sv.Field(i) + props := sprops.Prop[i] + name := st.Field(i).Name + + if strings.HasPrefix(name, "XXX_") { + // There are two XXX_ fields: + // XXX_unrecognized []byte + // XXX_extensions map[int32]proto.Extension + // The first is handled here; + // the second is handled at the bottom of this function. + if name == "XXX_unrecognized" && !fv.IsNil() { + if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil { + return err + } + } + continue + } + if fv.Kind() == reflect.Ptr && fv.IsNil() { + // Field not filled in. This could be an optional field or + // a required field that wasn't filled in. Either way, there + // isn't anything we can show for it. + continue + } + if fv.Kind() == reflect.Slice && fv.IsNil() { + // Repeated field that is empty, or a bytes field that is unused. + continue + } + + if props.Repeated && fv.Kind() == reflect.Slice { + // Repeated field. + for j := 0; j < fv.Len(); j++ { + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + v := fv.Index(j) + if v.Kind() == reflect.Ptr && v.IsNil() { + // A nil message in a repeated field is not valid, + // but we can handle that more gracefully than panicking. + if _, err := w.Write([]byte("\n")); err != nil { + return err + } + continue + } + if err := tm.writeAny(w, v, props); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + continue + } + if fv.Kind() == reflect.Map { + // Map fields are rendered as a repeated struct with key/value fields. + keys := fv.MapKeys() + sort.Sort(mapKeys(keys)) + for _, key := range keys { + val := fv.MapIndex(key) + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + // open struct + if err := w.WriteByte('<'); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + // key + if _, err := w.WriteString("key:"); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := tm.writeAny(w, key, props.mkeyprop); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + // nil values aren't legal, but we can avoid panicking because of them. + if val.Kind() != reflect.Ptr || !val.IsNil() { + // value + if _, err := w.WriteString("value:"); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := tm.writeAny(w, val, props.mvalprop); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + // close struct + w.unindent() + if err := w.WriteByte('>'); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + continue + } + if props.proto3 && fv.Kind() == reflect.Slice && fv.Len() == 0 { + // empty bytes field + continue + } + if fv.Kind() != reflect.Ptr && fv.Kind() != reflect.Slice { + // proto3 non-repeated scalar field; skip if zero value + if isProto3Zero(fv) { + continue + } + } + + if fv.Kind() == reflect.Interface { + // Check if it is a oneof. + if st.Field(i).Tag.Get("protobuf_oneof") != "" { + // fv is nil, or holds a pointer to generated struct. + // That generated struct has exactly one field, + // which has a protobuf struct tag. + if fv.IsNil() { + continue + } + inner := fv.Elem().Elem() // interface -> *T -> T + tag := inner.Type().Field(0).Tag.Get("protobuf") + props = new(Properties) // Overwrite the outer props var, but not its pointee. + props.Parse(tag) + // Write the value in the oneof, not the oneof itself. + fv = inner.Field(0) + + // Special case to cope with malformed messages gracefully: + // If the value in the oneof is a nil pointer, don't panic + // in writeAny. + if fv.Kind() == reflect.Ptr && fv.IsNil() { + // Use errors.New so writeAny won't render quotes. + msg := errors.New("/* nil */") + fv = reflect.ValueOf(&msg).Elem() + } + } + } + + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if b, ok := fv.Interface().(raw); ok { + if err := writeRaw(w, b.Bytes()); err != nil { + return err + } + continue + } + + // Enums have a String method, so writeAny will work fine. + if err := tm.writeAny(w, fv, props); err != nil { + return err + } + + if err := w.WriteByte('\n'); err != nil { + return err + } + } + + // Extensions (the XXX_extensions field). + pv := sv.Addr() + if _, ok := extendable(pv.Interface()); ok { + if err := tm.writeExtensions(w, pv); err != nil { + return err + } + } + + return nil +} + +// writeRaw writes an uninterpreted raw message. +func writeRaw(w *textWriter, b []byte) error { + if err := w.WriteByte('<'); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + if err := writeUnknownStruct(w, b); err != nil { + return err + } + w.unindent() + if err := w.WriteByte('>'); err != nil { + return err + } + return nil +} + +// writeAny writes an arbitrary field. +func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error { + v = reflect.Indirect(v) + + // Floats have special cases. + if v.Kind() == reflect.Float32 || v.Kind() == reflect.Float64 { + x := v.Float() + var b []byte + switch { + case math.IsInf(x, 1): + b = posInf + case math.IsInf(x, -1): + b = negInf + case math.IsNaN(x): + b = nan + } + if b != nil { + _, err := w.Write(b) + return err + } + // Other values are handled below. + } + + // We don't attempt to serialise every possible value type; only those + // that can occur in protocol buffers. + switch v.Kind() { + case reflect.Slice: + // Should only be a []byte; repeated fields are handled in writeStruct. + if err := writeString(w, string(v.Bytes())); err != nil { + return err + } + case reflect.String: + if err := writeString(w, v.String()); err != nil { + return err + } + case reflect.Struct: + // Required/optional group/message. + var bra, ket byte = '<', '>' + if props != nil && props.Wire == "group" { + bra, ket = '{', '}' + } + if err := w.WriteByte(bra); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + if etm, ok := v.Interface().(encoding.TextMarshaler); ok { + text, err := etm.MarshalText() + if err != nil { + return err + } + if _, err = w.Write(text); err != nil { + return err + } + } else if err := tm.writeStruct(w, v); err != nil { + return err + } + w.unindent() + if err := w.WriteByte(ket); err != nil { + return err + } + default: + _, err := fmt.Fprint(w, v.Interface()) + return err + } + return nil +} + +// equivalent to C's isprint. +func isprint(c byte) bool { + return c >= 0x20 && c < 0x7f +} + +// writeString writes a string in the protocol buffer text format. +// It is similar to strconv.Quote except we don't use Go escape sequences, +// we treat the string as a byte sequence, and we use octal escapes. +// These differences are to maintain interoperability with the other +// languages' implementations of the text format. +func writeString(w *textWriter, s string) error { + // use WriteByte here to get any needed indent + if err := w.WriteByte('"'); err != nil { + return err + } + // Loop over the bytes, not the runes. + for i := 0; i < len(s); i++ { + var err error + // Divergence from C++: we don't escape apostrophes. + // There's no need to escape them, and the C++ parser + // copes with a naked apostrophe. + switch c := s[i]; c { + case '\n': + _, err = w.w.Write(backslashN) + case '\r': + _, err = w.w.Write(backslashR) + case '\t': + _, err = w.w.Write(backslashT) + case '"': + _, err = w.w.Write(backslashDQ) + case '\\': + _, err = w.w.Write(backslashBS) + default: + if isprint(c) { + err = w.w.WriteByte(c) + } else { + _, err = fmt.Fprintf(w.w, "\\%03o", c) + } + } + if err != nil { + return err + } + } + return w.WriteByte('"') +} + +func writeUnknownStruct(w *textWriter, data []byte) (err error) { + if !w.compact { + if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil { + return err + } + } + b := NewBuffer(data) + for b.index < len(b.buf) { + x, err := b.DecodeVarint() + if err != nil { + _, err := fmt.Fprintf(w, "/* %v */\n", err) + return err + } + wire, tag := x&7, x>>3 + if wire == WireEndGroup { + w.unindent() + if _, err := w.Write(endBraceNewline); err != nil { + return err + } + continue + } + if _, err := fmt.Fprint(w, tag); err != nil { + return err + } + if wire != WireStartGroup { + if err := w.WriteByte(':'); err != nil { + return err + } + } + if !w.compact || wire == WireStartGroup { + if err := w.WriteByte(' '); err != nil { + return err + } + } + switch wire { + case WireBytes: + buf, e := b.DecodeRawBytes(false) + if e == nil { + _, err = fmt.Fprintf(w, "%q", buf) + } else { + _, err = fmt.Fprintf(w, "/* %v */", e) + } + case WireFixed32: + x, err = b.DecodeFixed32() + err = writeUnknownInt(w, x, err) + case WireFixed64: + x, err = b.DecodeFixed64() + err = writeUnknownInt(w, x, err) + case WireStartGroup: + err = w.WriteByte('{') + w.indent() + case WireVarint: + x, err = b.DecodeVarint() + err = writeUnknownInt(w, x, err) + default: + _, err = fmt.Fprintf(w, "/* unknown wire type %d */", wire) + } + if err != nil { + return err + } + if err = w.WriteByte('\n'); err != nil { + return err + } + } + return nil +} + +func writeUnknownInt(w *textWriter, x uint64, err error) error { + if err == nil { + _, err = fmt.Fprint(w, x) + } else { + _, err = fmt.Fprintf(w, "/* %v */", err) + } + return err +} + +type int32Slice []int32 + +func (s int32Slice) Len() int { return len(s) } +func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] } +func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// writeExtensions writes all the extensions in pv. +// pv is assumed to be a pointer to a protocol message struct that is extendable. +func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error { + emap := extensionMaps[pv.Type().Elem()] + ep, _ := extendable(pv.Interface()) + + // Order the extensions by ID. + // This isn't strictly necessary, but it will give us + // canonical output, which will also make testing easier. + m, mu := ep.extensionsRead() + if m == nil { + return nil + } + mu.Lock() + ids := make([]int32, 0, len(m)) + for id := range m { + ids = append(ids, id) + } + sort.Sort(int32Slice(ids)) + mu.Unlock() + + for _, extNum := range ids { + ext := m[extNum] + var desc *ExtensionDesc + if emap != nil { + desc = emap[extNum] + } + if desc == nil { + // Unknown extension. + if err := writeUnknownStruct(w, ext.enc); err != nil { + return err + } + continue + } + + pb, err := GetExtension(ep, desc) + if err != nil { + return fmt.Errorf("failed getting extension: %v", err) + } + + // Repeated extensions will appear as a slice. + if !desc.repeated() { + if err := tm.writeExtension(w, desc.Name, pb); err != nil { + return err + } + } else { + v := reflect.ValueOf(pb) + for i := 0; i < v.Len(); i++ { + if err := tm.writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil { + return err + } + } + } + } + return nil +} + +func (tm *TextMarshaler) writeExtension(w *textWriter, name string, pb interface{}) error { + if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := tm.writeAny(w, reflect.ValueOf(pb), nil); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + return nil +} + +func (w *textWriter) writeIndent() { + if !w.complete { + return + } + remain := w.ind * 2 + for remain > 0 { + n := remain + if n > len(spaces) { + n = len(spaces) + } + w.w.Write(spaces[:n]) + remain -= n + } + w.complete = false +} + +// TextMarshaler is a configurable text format marshaler. +type TextMarshaler struct { + Compact bool // use compact text format (one line). + ExpandAny bool // expand google.protobuf.Any messages of known types +} + +// Marshal writes a given protocol buffer in text format. +// The only errors returned are from w. +func (tm *TextMarshaler) Marshal(w io.Writer, pb Message) error { + val := reflect.ValueOf(pb) + if pb == nil || val.IsNil() { + w.Write([]byte("")) + return nil + } + var bw *bufio.Writer + ww, ok := w.(writer) + if !ok { + bw = bufio.NewWriter(w) + ww = bw + } + aw := &textWriter{ + w: ww, + complete: true, + compact: tm.Compact, + } + + if etm, ok := pb.(encoding.TextMarshaler); ok { + text, err := etm.MarshalText() + if err != nil { + return err + } + if _, err = aw.Write(text); err != nil { + return err + } + if bw != nil { + return bw.Flush() + } + return nil + } + // Dereference the received pointer so we don't have outer < and >. + v := reflect.Indirect(val) + if err := tm.writeStruct(aw, v); err != nil { + return err + } + if bw != nil { + return bw.Flush() + } + return nil +} + +// Text is the same as Marshal, but returns the string directly. +func (tm *TextMarshaler) Text(pb Message) string { + var buf bytes.Buffer + tm.Marshal(&buf, pb) + return buf.String() +} + +var ( + defaultTextMarshaler = TextMarshaler{} + compactTextMarshaler = TextMarshaler{Compact: true} +) + +// TODO: consider removing some of the Marshal functions below. + +// MarshalText writes a given protocol buffer in text format. +// The only errors returned are from w. +func MarshalText(w io.Writer, pb Message) error { return defaultTextMarshaler.Marshal(w, pb) } + +// MarshalTextString is the same as MarshalText, but returns the string directly. +func MarshalTextString(pb Message) string { return defaultTextMarshaler.Text(pb) } + +// CompactText writes a given protocol buffer in compact text format (one line). +func CompactText(w io.Writer, pb Message) error { return compactTextMarshaler.Marshal(w, pb) } + +// CompactTextString is the same as CompactText, but returns the string directly. +func CompactTextString(pb Message) string { return compactTextMarshaler.Text(pb) } diff --git a/vendor/github.com/golang/protobuf/proto/text_parser.go b/vendor/github.com/golang/protobuf/proto/text_parser.go new file mode 100644 index 0000000..61f83c1 --- /dev/null +++ b/vendor/github.com/golang/protobuf/proto/text_parser.go @@ -0,0 +1,895 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +// Functions for parsing the Text protocol buffer format. +// TODO: message sets. + +import ( + "encoding" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "unicode/utf8" +) + +// Error string emitted when deserializing Any and fields are already set +const anyRepeatedlyUnpacked = "Any message unpacked multiple times, or %q already set" + +type ParseError struct { + Message string + Line int // 1-based line number + Offset int // 0-based byte offset from start of input +} + +func (p *ParseError) Error() string { + if p.Line == 1 { + // show offset only for first line + return fmt.Sprintf("line 1.%d: %v", p.Offset, p.Message) + } + return fmt.Sprintf("line %d: %v", p.Line, p.Message) +} + +type token struct { + value string + err *ParseError + line int // line number + offset int // byte number from start of input, not start of line + unquoted string // the unquoted version of value, if it was a quoted string +} + +func (t *token) String() string { + if t.err == nil { + return fmt.Sprintf("%q (line=%d, offset=%d)", t.value, t.line, t.offset) + } + return fmt.Sprintf("parse error: %v", t.err) +} + +type textParser struct { + s string // remaining input + done bool // whether the parsing is finished (success or error) + backed bool // whether back() was called + offset, line int + cur token +} + +func newTextParser(s string) *textParser { + p := new(textParser) + p.s = s + p.line = 1 + p.cur.line = 1 + return p +} + +func (p *textParser) errorf(format string, a ...interface{}) *ParseError { + pe := &ParseError{fmt.Sprintf(format, a...), p.cur.line, p.cur.offset} + p.cur.err = pe + p.done = true + return pe +} + +// Numbers and identifiers are matched by [-+._A-Za-z0-9] +func isIdentOrNumberChar(c byte) bool { + switch { + case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z': + return true + case '0' <= c && c <= '9': + return true + } + switch c { + case '-', '+', '.', '_': + return true + } + return false +} + +func isWhitespace(c byte) bool { + switch c { + case ' ', '\t', '\n', '\r': + return true + } + return false +} + +func isQuote(c byte) bool { + switch c { + case '"', '\'': + return true + } + return false +} + +func (p *textParser) skipWhitespace() { + i := 0 + for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') { + if p.s[i] == '#' { + // comment; skip to end of line or input + for i < len(p.s) && p.s[i] != '\n' { + i++ + } + if i == len(p.s) { + break + } + } + if p.s[i] == '\n' { + p.line++ + } + i++ + } + p.offset += i + p.s = p.s[i:len(p.s)] + if len(p.s) == 0 { + p.done = true + } +} + +func (p *textParser) advance() { + // Skip whitespace + p.skipWhitespace() + if p.done { + return + } + + // Start of non-whitespace + p.cur.err = nil + p.cur.offset, p.cur.line = p.offset, p.line + p.cur.unquoted = "" + switch p.s[0] { + case '<', '>', '{', '}', ':', '[', ']', ';', ',', '/': + // Single symbol + p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)] + case '"', '\'': + // Quoted string + i := 1 + for i < len(p.s) && p.s[i] != p.s[0] && p.s[i] != '\n' { + if p.s[i] == '\\' && i+1 < len(p.s) { + // skip escaped char + i++ + } + i++ + } + if i >= len(p.s) || p.s[i] != p.s[0] { + p.errorf("unmatched quote") + return + } + unq, err := unquoteC(p.s[1:i], rune(p.s[0])) + if err != nil { + p.errorf("invalid quoted string %s: %v", p.s[0:i+1], err) + return + } + p.cur.value, p.s = p.s[0:i+1], p.s[i+1:len(p.s)] + p.cur.unquoted = unq + default: + i := 0 + for i < len(p.s) && isIdentOrNumberChar(p.s[i]) { + i++ + } + if i == 0 { + p.errorf("unexpected byte %#x", p.s[0]) + return + } + p.cur.value, p.s = p.s[0:i], p.s[i:len(p.s)] + } + p.offset += len(p.cur.value) +} + +var ( + errBadUTF8 = errors.New("proto: bad UTF-8") + errBadHex = errors.New("proto: bad hexadecimal") +) + +func unquoteC(s string, quote rune) (string, error) { + // This is based on C++'s tokenizer.cc. + // Despite its name, this is *not* parsing C syntax. + // For instance, "\0" is an invalid quoted string. + + // Avoid allocation in trivial cases. + simple := true + for _, r := range s { + if r == '\\' || r == quote { + simple = false + break + } + } + if simple { + return s, nil + } + + buf := make([]byte, 0, 3*len(s)/2) + for len(s) > 0 { + r, n := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && n == 1 { + return "", errBadUTF8 + } + s = s[n:] + if r != '\\' { + if r < utf8.RuneSelf { + buf = append(buf, byte(r)) + } else { + buf = append(buf, string(r)...) + } + continue + } + + ch, tail, err := unescape(s) + if err != nil { + return "", err + } + buf = append(buf, ch...) + s = tail + } + return string(buf), nil +} + +func unescape(s string) (ch string, tail string, err error) { + r, n := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && n == 1 { + return "", "", errBadUTF8 + } + s = s[n:] + switch r { + case 'a': + return "\a", s, nil + case 'b': + return "\b", s, nil + case 'f': + return "\f", s, nil + case 'n': + return "\n", s, nil + case 'r': + return "\r", s, nil + case 't': + return "\t", s, nil + case 'v': + return "\v", s, nil + case '?': + return "?", s, nil // trigraph workaround + case '\'', '"', '\\': + return string(r), s, nil + case '0', '1', '2', '3', '4', '5', '6', '7', 'x', 'X': + if len(s) < 2 { + return "", "", fmt.Errorf(`\%c requires 2 following digits`, r) + } + base := 8 + ss := s[:2] + s = s[2:] + if r == 'x' || r == 'X' { + base = 16 + } else { + ss = string(r) + ss + } + i, err := strconv.ParseUint(ss, base, 8) + if err != nil { + return "", "", err + } + return string([]byte{byte(i)}), s, nil + case 'u', 'U': + n := 4 + if r == 'U' { + n = 8 + } + if len(s) < n { + return "", "", fmt.Errorf(`\%c requires %d digits`, r, n) + } + + bs := make([]byte, n/2) + for i := 0; i < n; i += 2 { + a, ok1 := unhex(s[i]) + b, ok2 := unhex(s[i+1]) + if !ok1 || !ok2 { + return "", "", errBadHex + } + bs[i/2] = a<<4 | b + } + s = s[n:] + return string(bs), s, nil + } + return "", "", fmt.Errorf(`unknown escape \%c`, r) +} + +// Adapted from src/pkg/strconv/quote.go. +func unhex(b byte) (v byte, ok bool) { + switch { + case '0' <= b && b <= '9': + return b - '0', true + case 'a' <= b && b <= 'f': + return b - 'a' + 10, true + case 'A' <= b && b <= 'F': + return b - 'A' + 10, true + } + return 0, false +} + +// Back off the parser by one token. Can only be done between calls to next(). +// It makes the next advance() a no-op. +func (p *textParser) back() { p.backed = true } + +// Advances the parser and returns the new current token. +func (p *textParser) next() *token { + if p.backed || p.done { + p.backed = false + return &p.cur + } + p.advance() + if p.done { + p.cur.value = "" + } else if len(p.cur.value) > 0 && isQuote(p.cur.value[0]) { + // Look for multiple quoted strings separated by whitespace, + // and concatenate them. + cat := p.cur + for { + p.skipWhitespace() + if p.done || !isQuote(p.s[0]) { + break + } + p.advance() + if p.cur.err != nil { + return &p.cur + } + cat.value += " " + p.cur.value + cat.unquoted += p.cur.unquoted + } + p.done = false // parser may have seen EOF, but we want to return cat + p.cur = cat + } + return &p.cur +} + +func (p *textParser) consumeToken(s string) error { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != s { + p.back() + return p.errorf("expected %q, found %q", s, tok.value) + } + return nil +} + +// Return a RequiredNotSetError indicating which required field was not set. +func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSetError { + st := sv.Type() + sprops := GetProperties(st) + for i := 0; i < st.NumField(); i++ { + if !isNil(sv.Field(i)) { + continue + } + + props := sprops.Prop[i] + if props.Required { + return &RequiredNotSetError{fmt.Sprintf("%v.%v", st, props.OrigName)} + } + } + return &RequiredNotSetError{fmt.Sprintf("%v.", st)} // should not happen +} + +// Returns the index in the struct for the named field, as well as the parsed tag properties. +func structFieldByName(sprops *StructProperties, name string) (int, *Properties, bool) { + i, ok := sprops.decoderOrigNames[name] + if ok { + return i, sprops.Prop[i], true + } + return -1, nil, false +} + +// Consume a ':' from the input stream (if the next token is a colon), +// returning an error if a colon is needed but not present. +func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseError { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != ":" { + // Colon is optional when the field is a group or message. + needColon := true + switch props.Wire { + case "group": + needColon = false + case "bytes": + // A "bytes" field is either a message, a string, or a repeated field; + // those three become *T, *string and []T respectively, so we can check for + // this field being a pointer to a non-string. + if typ.Kind() == reflect.Ptr { + // *T or *string + if typ.Elem().Kind() == reflect.String { + break + } + } else if typ.Kind() == reflect.Slice { + // []T or []*T + if typ.Elem().Kind() != reflect.Ptr { + break + } + } else if typ.Kind() == reflect.String { + // The proto3 exception is for a string field, + // which requires a colon. + break + } + needColon = false + } + if needColon { + return p.errorf("expected ':', found %q", tok.value) + } + p.back() + } + return nil +} + +func (p *textParser) readStruct(sv reflect.Value, terminator string) error { + st := sv.Type() + sprops := GetProperties(st) + reqCount := sprops.reqCount + var reqFieldErr error + fieldSet := make(map[string]bool) + // A struct is a sequence of "name: value", terminated by one of + // '>' or '}', or the end of the input. A name may also be + // "[extension]" or "[type/url]". + // + // The whole struct can also be an expanded Any message, like: + // [type/url] < ... struct contents ... > + for { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == terminator { + break + } + if tok.value == "[" { + // Looks like an extension or an Any. + // + // TODO: Check whether we need to handle + // namespace rooted names (e.g. ".something.Foo"). + extName, err := p.consumeExtName() + if err != nil { + return err + } + + if s := strings.LastIndex(extName, "/"); s >= 0 { + // If it contains a slash, it's an Any type URL. + messageName := extName[s+1:] + mt := MessageType(messageName) + if mt == nil { + return p.errorf("unrecognized message %q in google.protobuf.Any", messageName) + } + tok = p.next() + if tok.err != nil { + return tok.err + } + // consume an optional colon + if tok.value == ":" { + tok = p.next() + if tok.err != nil { + return tok.err + } + } + var terminator string + switch tok.value { + case "<": + terminator = ">" + case "{": + terminator = "}" + default: + return p.errorf("expected '{' or '<', found %q", tok.value) + } + v := reflect.New(mt.Elem()) + if pe := p.readStruct(v.Elem(), terminator); pe != nil { + return pe + } + b, err := Marshal(v.Interface().(Message)) + if err != nil { + return p.errorf("failed to marshal message of type %q: %v", messageName, err) + } + if fieldSet["type_url"] { + return p.errorf(anyRepeatedlyUnpacked, "type_url") + } + if fieldSet["value"] { + return p.errorf(anyRepeatedlyUnpacked, "value") + } + sv.FieldByName("TypeUrl").SetString(extName) + sv.FieldByName("Value").SetBytes(b) + fieldSet["type_url"] = true + fieldSet["value"] = true + continue + } + + var desc *ExtensionDesc + // This could be faster, but it's functional. + // TODO: Do something smarter than a linear scan. + for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) { + if d.Name == extName { + desc = d + break + } + } + if desc == nil { + return p.errorf("unrecognized extension %q", extName) + } + + props := &Properties{} + props.Parse(desc.Tag) + + typ := reflect.TypeOf(desc.ExtensionType) + if err := p.checkForColon(props, typ); err != nil { + return err + } + + rep := desc.repeated() + + // Read the extension structure, and set it in + // the value we're constructing. + var ext reflect.Value + if !rep { + ext = reflect.New(typ).Elem() + } else { + ext = reflect.New(typ.Elem()).Elem() + } + if err := p.readAny(ext, props); err != nil { + if _, ok := err.(*RequiredNotSetError); !ok { + return err + } + reqFieldErr = err + } + ep := sv.Addr().Interface().(Message) + if !rep { + SetExtension(ep, desc, ext.Interface()) + } else { + old, err := GetExtension(ep, desc) + var sl reflect.Value + if err == nil { + sl = reflect.ValueOf(old) // existing slice + } else { + sl = reflect.MakeSlice(typ, 0, 1) + } + sl = reflect.Append(sl, ext) + SetExtension(ep, desc, sl.Interface()) + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + continue + } + + // This is a normal, non-extension field. + name := tok.value + var dst reflect.Value + fi, props, ok := structFieldByName(sprops, name) + if ok { + dst = sv.Field(fi) + } else if oop, ok := sprops.OneofTypes[name]; ok { + // It is a oneof. + props = oop.Prop + nv := reflect.New(oop.Type.Elem()) + dst = nv.Elem().Field(0) + field := sv.Field(oop.Field) + if !field.IsNil() { + return p.errorf("field '%s' would overwrite already parsed oneof '%s'", name, sv.Type().Field(oop.Field).Name) + } + field.Set(nv) + } + if !dst.IsValid() { + return p.errorf("unknown field name %q in %v", name, st) + } + + if dst.Kind() == reflect.Map { + // Consume any colon. + if err := p.checkForColon(props, dst.Type()); err != nil { + return err + } + + // Construct the map if it doesn't already exist. + if dst.IsNil() { + dst.Set(reflect.MakeMap(dst.Type())) + } + key := reflect.New(dst.Type().Key()).Elem() + val := reflect.New(dst.Type().Elem()).Elem() + + // The map entry should be this sequence of tokens: + // < key : KEY value : VALUE > + // However, implementations may omit key or value, and technically + // we should support them in any order. See b/28924776 for a time + // this went wrong. + + tok := p.next() + var terminator string + switch tok.value { + case "<": + terminator = ">" + case "{": + terminator = "}" + default: + return p.errorf("expected '{' or '<', found %q", tok.value) + } + for { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == terminator { + break + } + switch tok.value { + case "key": + if err := p.consumeToken(":"); err != nil { + return err + } + if err := p.readAny(key, props.mkeyprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + case "value": + if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil { + return err + } + if err := p.readAny(val, props.mvalprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + default: + p.back() + return p.errorf(`expected "key", "value", or %q, found %q`, terminator, tok.value) + } + } + + dst.SetMapIndex(key, val) + continue + } + + // Check that it's not already set if it's not a repeated field. + if !props.Repeated && fieldSet[name] { + return p.errorf("non-repeated field %q was repeated", name) + } + + if err := p.checkForColon(props, dst.Type()); err != nil { + return err + } + + // Parse into the field. + fieldSet[name] = true + if err := p.readAny(dst, props); err != nil { + if _, ok := err.(*RequiredNotSetError); !ok { + return err + } + reqFieldErr = err + } + if props.Required { + reqCount-- + } + + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + + } + + if reqCount > 0 { + return p.missingRequiredFieldError(sv) + } + return reqFieldErr +} + +// consumeExtName consumes extension name or expanded Any type URL and the +// following ']'. It returns the name or URL consumed. +func (p *textParser) consumeExtName() (string, error) { + tok := p.next() + if tok.err != nil { + return "", tok.err + } + + // If extension name or type url is quoted, it's a single token. + if len(tok.value) > 2 && isQuote(tok.value[0]) && tok.value[len(tok.value)-1] == tok.value[0] { + name, err := unquoteC(tok.value[1:len(tok.value)-1], rune(tok.value[0])) + if err != nil { + return "", err + } + return name, p.consumeToken("]") + } + + // Consume everything up to "]" + var parts []string + for tok.value != "]" { + parts = append(parts, tok.value) + tok = p.next() + if tok.err != nil { + return "", p.errorf("unrecognized type_url or extension name: %s", tok.err) + } + } + return strings.Join(parts, ""), nil +} + +// consumeOptionalSeparator consumes an optional semicolon or comma. +// It is used in readStruct to provide backward compatibility. +func (p *textParser) consumeOptionalSeparator() error { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != ";" && tok.value != "," { + p.back() + } + return nil +} + +func (p *textParser) readAny(v reflect.Value, props *Properties) error { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == "" { + return p.errorf("unexpected EOF") + } + + switch fv := v; fv.Kind() { + case reflect.Slice: + at := v.Type() + if at.Elem().Kind() == reflect.Uint8 { + // Special case for []byte + if tok.value[0] != '"' && tok.value[0] != '\'' { + // Deliberately written out here, as the error after + // this switch statement would write "invalid []byte: ...", + // which is not as user-friendly. + return p.errorf("invalid string: %v", tok.value) + } + bytes := []byte(tok.unquoted) + fv.Set(reflect.ValueOf(bytes)) + return nil + } + // Repeated field. + if tok.value == "[" { + // Repeated field with list notation, like [1,2,3]. + for { + fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem())) + err := p.readAny(fv.Index(fv.Len()-1), props) + if err != nil { + return err + } + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == "]" { + break + } + if tok.value != "," { + return p.errorf("Expected ']' or ',' found %q", tok.value) + } + } + return nil + } + // One value of the repeated field. + p.back() + fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem())) + return p.readAny(fv.Index(fv.Len()-1), props) + case reflect.Bool: + // true/1/t/True or false/f/0/False. + switch tok.value { + case "true", "1", "t", "True": + fv.SetBool(true) + return nil + case "false", "0", "f", "False": + fv.SetBool(false) + return nil + } + case reflect.Float32, reflect.Float64: + v := tok.value + // Ignore 'f' for compatibility with output generated by C++, but don't + // remove 'f' when the value is "-inf" or "inf". + if strings.HasSuffix(v, "f") && tok.value != "-inf" && tok.value != "inf" { + v = v[:len(v)-1] + } + if f, err := strconv.ParseFloat(v, fv.Type().Bits()); err == nil { + fv.SetFloat(f) + return nil + } + case reflect.Int32: + if x, err := strconv.ParseInt(tok.value, 0, 32); err == nil { + fv.SetInt(x) + return nil + } + + if len(props.Enum) == 0 { + break + } + m, ok := enumValueMaps[props.Enum] + if !ok { + break + } + x, ok := m[tok.value] + if !ok { + break + } + fv.SetInt(int64(x)) + return nil + case reflect.Int64: + if x, err := strconv.ParseInt(tok.value, 0, 64); err == nil { + fv.SetInt(x) + return nil + } + + case reflect.Ptr: + // A basic field (indirected through pointer), or a repeated message/group + p.back() + fv.Set(reflect.New(fv.Type().Elem())) + return p.readAny(fv.Elem(), props) + case reflect.String: + if tok.value[0] == '"' || tok.value[0] == '\'' { + fv.SetString(tok.unquoted) + return nil + } + case reflect.Struct: + var terminator string + switch tok.value { + case "{": + terminator = "}" + case "<": + terminator = ">" + default: + return p.errorf("expected '{' or '<', found %q", tok.value) + } + // TODO: Handle nested messages which implement encoding.TextUnmarshaler. + return p.readStruct(fv, terminator) + case reflect.Uint32: + if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil { + fv.SetUint(uint64(x)) + return nil + } + case reflect.Uint64: + if x, err := strconv.ParseUint(tok.value, 0, 64); err == nil { + fv.SetUint(x) + return nil + } + } + return p.errorf("invalid %v: %v", v.Type(), tok.value) +} + +// UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb +// before starting to unmarshal, so any existing data in pb is always removed. +// If a required field is not set and no other error occurs, +// UnmarshalText returns *RequiredNotSetError. +func UnmarshalText(s string, pb Message) error { + if um, ok := pb.(encoding.TextUnmarshaler); ok { + err := um.UnmarshalText([]byte(s)) + return err + } + pb.Reset() + v := reflect.ValueOf(pb) + if pe := newTextParser(s).readStruct(v.Elem(), ""); pe != nil { + return pe + } + return nil +} diff --git a/vendor/golang.org/x/net/AUTHORS b/vendor/golang.org/x/net/AUTHORS new file mode 100644 index 0000000..15167cd --- /dev/null +++ b/vendor/golang.org/x/net/AUTHORS @@ -0,0 +1,3 @@ +# This source code refers to The Go Authors for copyright purposes. +# The master list of authors is in the main Go distribution, +# visible at http://tip.golang.org/AUTHORS. diff --git a/vendor/golang.org/x/net/CONTRIBUTORS b/vendor/golang.org/x/net/CONTRIBUTORS new file mode 100644 index 0000000..1c4577e --- /dev/null +++ b/vendor/golang.org/x/net/CONTRIBUTORS @@ -0,0 +1,3 @@ +# This source code was written by the Go contributors. +# The master list of contributors is in the main Go distribution, +# visible at http://tip.golang.org/CONTRIBUTORS. diff --git a/vendor/golang.org/x/net/LICENSE b/vendor/golang.org/x/net/LICENSE new file mode 100644 index 0000000..6a66aea --- /dev/null +++ b/vendor/golang.org/x/net/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/net/PATENTS b/vendor/golang.org/x/net/PATENTS new file mode 100644 index 0000000..7330990 --- /dev/null +++ b/vendor/golang.org/x/net/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/net/proxy/direct.go b/vendor/golang.org/x/net/proxy/direct.go new file mode 100644 index 0000000..4c5ad88 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/direct.go @@ -0,0 +1,18 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "net" +) + +type direct struct{} + +// Direct is a direct proxy: one that makes network connections directly. +var Direct = direct{} + +func (direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} diff --git a/vendor/golang.org/x/net/proxy/per_host.go b/vendor/golang.org/x/net/proxy/per_host.go new file mode 100644 index 0000000..f540b19 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/per_host.go @@ -0,0 +1,140 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "net" + "strings" +) + +// A PerHost directs connections to a default Dialer unless the hostname +// requested matches one of a number of exceptions. +type PerHost struct { + def, bypass Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func NewPerHost(defaultDialer, bypass Dialer) *PerHost { + return &PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +func (p *PerHost) dialerForRequest(host string) Dialer { + if ip := net.ParseIP(host); ip != nil { + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone "example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a hostname +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if ip := net.ParseIP(host); ip != nil { + p.AddIP(ip) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. +func (p *PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *PerHost) AddZone(zone string) { + if strings.HasSuffix(zone, ".") { + zone = zone[:len(zone)-1] + } + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a hostname that will use the bypass proxy. +func (p *PerHost) AddHost(host string) { + if strings.HasSuffix(host, ".") { + host = host[:len(host)-1] + } + p.bypassHosts = append(p.bypassHosts, host) +} diff --git a/vendor/golang.org/x/net/proxy/proxy.go b/vendor/golang.org/x/net/proxy/proxy.go new file mode 100644 index 0000000..8ccb0c5 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/proxy.go @@ -0,0 +1,94 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package proxy provides support for a variety of protocols to proxy network +// data. +package proxy + +import ( + "errors" + "net" + "net/url" + "os" +) + +// A Dialer is a means to establish a connection. +type Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy related variables in +// the environment. +func FromEnvironment() Dialer { + allProxy := os.Getenv("all_proxy") + if len(allProxy) == 0 { + return Direct + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return Direct + } + proxy, err := FromURL(proxyURL, Direct) + if err != nil { + return Direct + } + + noProxy := os.Getenv("no_proxy") + if len(noProxy) == 0 { + return proxy + } + + perHost := NewPerHost(proxy, Direct) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) { + if proxySchemes == nil { + proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error)) + } + proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func FromURL(u *url.URL, forward Dialer) (Dialer, error) { + var auth *Auth + if u.User != nil { + auth = new(Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5": + return SOCKS5("tcp", u.Host, auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxySchemes != nil { + if f, ok := proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} diff --git a/vendor/golang.org/x/net/proxy/socks5.go b/vendor/golang.org/x/net/proxy/socks5.go new file mode 100644 index 0000000..9b96282 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/socks5.go @@ -0,0 +1,210 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "errors" + "io" + "net" + "strconv" +) + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address +// with an optional username and password. See RFC 1928. +func SOCKS5(network, addr string, auth *Auth, forward Dialer) (Dialer, error) { + s := &socks5{ + network: network, + addr: addr, + forward: forward, + } + if auth != nil { + s.user = auth.User + s.password = auth.Password + } + + return s, nil +} + +type socks5 struct { + user, password string + network, addr string + forward Dialer +} + +const socks5Version = 5 + +const ( + socks5AuthNone = 0 + socks5AuthPassword = 2 +) + +const socks5Connect = 1 + +const ( + socks5IP4 = 1 + socks5Domain = 3 + socks5IP6 = 4 +) + +var socks5Errors = []string{ + "", + "general failure", + "connection forbidden", + "network unreachable", + "host unreachable", + "connection refused", + "TTL expired", + "command not supported", + "address type not supported", +} + +// Dial connects to the address addr on the network net via the SOCKS5 proxy. +func (s *socks5) Dial(network, addr string) (net.Conn, error) { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) + } + + conn, err := s.forward.Dial(s.network, s.addr) + if err != nil { + return nil, err + } + closeConn := &conn + defer func() { + if closeConn != nil { + (*closeConn).Close() + } + }() + + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, errors.New("proxy: failed to parse port number: " + portStr) + } + if port < 1 || port > 0xffff { + return nil, errors.New("proxy: port number out of range: " + portStr) + } + + // the size here is just an estimate + buf := make([]byte, 0, 6+len(host)) + + buf = append(buf, socks5Version) + if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { + buf = append(buf, 2 /* num auth methods */, socks5AuthNone, socks5AuthPassword) + } else { + buf = append(buf, 1 /* num auth methods */, socks5AuthNone) + } + + if _, err := conn.Write(buf); err != nil { + return nil, errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return nil, errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + if buf[0] != 5 { + return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) + } + if buf[1] == 0xff { + return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") + } + + if buf[1] == socks5AuthPassword { + buf = buf[:0] + buf = append(buf, 1 /* password protocol version */) + buf = append(buf, uint8(len(s.user))) + buf = append(buf, s.user...) + buf = append(buf, uint8(len(s.password))) + buf = append(buf, s.password...) + + if _, err := conn.Write(buf); err != nil { + return nil, errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return nil, errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if buf[1] != 0 { + return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") + } + } + + buf = buf[:0] + buf = append(buf, socks5Version, socks5Connect, 0 /* reserved */) + + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + buf = append(buf, socks5IP4) + ip = ip4 + } else { + buf = append(buf, socks5IP6) + } + buf = append(buf, ip...) + } else { + if len(host) > 255 { + return nil, errors.New("proxy: destination hostname too long: " + host) + } + buf = append(buf, socks5Domain) + buf = append(buf, byte(len(host))) + buf = append(buf, host...) + } + buf = append(buf, byte(port>>8), byte(port)) + + if _, err := conn.Write(buf); err != nil { + return nil, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:4]); err != nil { + return nil, errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + failure := "unknown error" + if int(buf[1]) < len(socks5Errors) { + failure = socks5Errors[buf[1]] + } + + if len(failure) > 0 { + return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) + } + + bytesToDiscard := 0 + switch buf[3] { + case socks5IP4: + bytesToDiscard = net.IPv4len + case socks5IP6: + bytesToDiscard = net.IPv6len + case socks5Domain: + _, err := io.ReadFull(conn, buf[:1]) + if err != nil { + return nil, errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + bytesToDiscard = int(buf[0]) + default: + return nil, errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) + } + + if cap(buf) < bytesToDiscard { + buf = make([]byte, bytesToDiscard) + } else { + buf = buf[:bytesToDiscard] + } + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + // Also need to discard the port number + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return nil, errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + closeConn = nil + return conn, nil +} diff --git a/auth/auth_message.go b/wire/auth/auth_message.go similarity index 97% rename from auth/auth_message.go rename to wire/auth/auth_message.go index 9f20266..053abfe 100644 --- a/auth/auth_message.go +++ b/wire/auth/auth_message.go @@ -18,7 +18,7 @@ package Protocol_Data_AuthHiddenService import proto "github.com/golang/protobuf/proto" import fmt "fmt" import math "math" -import Protocol_Data_Control "github.com/s-rah/go-ricochet/control" +import Protocol_Data_Control "github.com/s-rah/go-ricochet/wire/control" // Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal diff --git a/chat/chat.go b/wire/chat/chat.go similarity index 100% rename from chat/chat.go rename to wire/chat/chat.go diff --git a/contact/request.go b/wire/contact/request.go similarity index 98% rename from contact/request.go rename to wire/contact/request.go index f1996e8..1479913 100644 --- a/contact/request.go +++ b/wire/contact/request.go @@ -17,7 +17,7 @@ package Protocol_Data_ContactRequest import proto "github.com/golang/protobuf/proto" import fmt "fmt" import math "math" -import Protocol_Data_Control "github.com/s-rah/go-ricochet/control" +import Protocol_Data_Control "github.com/s-rah/go-ricochet/wire/control" // Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal diff --git a/control/control_message.go b/wire/control/control_message.go similarity index 100% rename from control/control_message.go rename to wire/control/control_message.go