81 lines
2.2 KiB
Go
81 lines
2.2 KiB
Go
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
// +build go1.6
|
||
|
|
||
|
package http2
|
||
|
|
||
|
import (
|
||
|
"crypto/tls"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
)
|
||
|
|
||
|
func configureTransport(t1 *http.Transport) (*Transport, error) {
|
||
|
connPool := new(clientConnPool)
|
||
|
t2 := &Transport{
|
||
|
ConnPool: noDialClientConnPool{connPool},
|
||
|
t1: t1,
|
||
|
}
|
||
|
connPool.t = t2
|
||
|
if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if t1.TLSClientConfig == nil {
|
||
|
t1.TLSClientConfig = new(tls.Config)
|
||
|
}
|
||
|
if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
|
||
|
t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
|
||
|
}
|
||
|
if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
|
||
|
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
|
||
|
}
|
||
|
upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
|
||
|
addr := authorityAddr("https", authority)
|
||
|
if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
|
||
|
go c.Close()
|
||
|
return erringRoundTripper{err}
|
||
|
} else if !used {
|
||
|
// Turns out we don't need this c.
|
||
|
// For example, two goroutines made requests to the same host
|
||
|
// at the same time, both kicking off TCP dials. (since protocol
|
||
|
// was unknown)
|
||
|
go c.Close()
|
||
|
}
|
||
|
return t2
|
||
|
}
|
||
|
if m := t1.TLSNextProto; len(m) == 0 {
|
||
|
t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
|
||
|
"h2": upgradeFn,
|
||
|
}
|
||
|
} else {
|
||
|
m["h2"] = upgradeFn
|
||
|
}
|
||
|
return t2, nil
|
||
|
}
|
||
|
|
||
|
// registerHTTPSProtocol registers rt as the RoundTripper for the "https"
// scheme on t by calling t.RegisterProtocol. A panic during that call is
// recovered and returned as an error (the panic value formatted with %v)
// instead of reaching the caller.
func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// Turn the recovered panic into the function's error result.
		err = fmt.Errorf("%v", r)
	}()
	t.RegisterProtocol("https", rt)
	return nil
}
|
||
|
|
||
|
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
|
||
|
// if there's already has a cached connection to the host.
|
||
|
type noDialH2RoundTripper struct{ t *Transport }
|
||
|
|
||
|
func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||
|
res, err := rt.t.RoundTrip(req)
|
||
|
if err == ErrNoCachedConn {
|
||
|
return nil, http.ErrSkipAltProtocol
|
||
|
}
|
||
|
return res, err
|
||
|
}
|