// Source file: src/pkg/net/http/httputil/reverseproxy.go
1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // HTTP reverse proxy handler 6 7 package httputil 8 9 import ( 10 "io" 11 "log" 12 "net" 13 "net/http" 14 "net/url" 15 "strings" 16 "sync" 17 "time" 18 ) 19 20 // ReverseProxy is an HTTP Handler that takes an incoming request and 21 // sends it to another server, proxying the response back to the 22 // client. 23 type ReverseProxy struct { 24 // Director must be a function which modifies 25 // the request into a new request to be sent 26 // using Transport. Its response is then copied 27 // back to the original client unmodified. 28 Director func(*http.Request) 29 30 // The transport used to perform proxy requests. 31 // If nil, http.DefaultTransport is used. 32 Transport http.RoundTripper 33 34 // FlushInterval specifies the flush interval 35 // to flush to the client while copying the 36 // response body. 37 // If zero, no periodic flushing is done. 38 FlushInterval time.Duration 39 } 40 41 func singleJoiningSlash(a, b string) string { 42 aslash := strings.HasSuffix(a, "/") 43 bslash := strings.HasPrefix(b, "/") 44 switch { 45 case aslash && bslash: 46 return a + b[1:] 47 case !aslash && !bslash: 48 return a + "/" + b 49 } 50 return a + b 51 } 52 53 // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites 54 // URLs to the scheme, host, and base path provided in target. If the 55 // target's path is "/base" and the incoming request was for "/dir", 56 // the target request will be for /base/dir. 
57 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { 58 targetQuery := target.RawQuery 59 director := func(req *http.Request) { 60 req.URL.Scheme = target.Scheme 61 req.URL.Host = target.Host 62 req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) 63 if targetQuery == "" || req.URL.RawQuery == "" { 64 req.URL.RawQuery = targetQuery + req.URL.RawQuery 65 } else { 66 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 67 } 68 } 69 return &ReverseProxy{Director: director} 70 } 71 72 func copyHeader(dst, src http.Header) { 73 for k, vv := range src { 74 for _, v := range vv { 75 dst.Add(k, v) 76 } 77 } 78 } 79 80 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 81 transport := p.Transport 82 if transport == nil { 83 transport = http.DefaultTransport 84 } 85 86 outreq := new(http.Request) 87 *outreq = *req // includes shallow copies of maps, but okay 88 89 p.Director(outreq) 90 outreq.Proto = "HTTP/1.1" 91 outreq.ProtoMajor = 1 92 outreq.ProtoMinor = 1 93 outreq.Close = false 94 95 // Remove the connection header to the backend. We want a 96 // persistent connection, regardless of what the client sent 97 // to us. This is modifying the same underlying map from req 98 // (shallow copied above) so we only copy it if necessary. 
99 if outreq.Header.Get("Connection") != "" { 100 outreq.Header = make(http.Header) 101 copyHeader(outreq.Header, req.Header) 102 outreq.Header.Del("Connection") 103 } 104 105 if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 106 outreq.Header.Set("X-Forwarded-For", clientIp) 107 } 108 109 res, err := transport.RoundTrip(outreq) 110 if err != nil { 111 log.Printf("http: proxy error: %v", err) 112 rw.WriteHeader(http.StatusInternalServerError) 113 return 114 } 115 116 copyHeader(rw.Header(), res.Header) 117 118 rw.WriteHeader(res.StatusCode) 119 120 if res.Body != nil { 121 var dst io.Writer = rw 122 if p.FlushInterval != 0 { 123 if wf, ok := rw.(writeFlusher); ok { 124 dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} 125 } 126 } 127 io.Copy(dst, res.Body) 128 } 129 } 130 131 type writeFlusher interface { 132 io.Writer 133 http.Flusher 134 } 135 136 type maxLatencyWriter struct { 137 dst writeFlusher 138 latency time.Duration 139 140 lk sync.Mutex // protects init of done, as well Write + Flush 141 done chan bool 142 } 143 144 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { 145 m.lk.Lock() 146 defer m.lk.Unlock() 147 if m.done == nil { 148 m.done = make(chan bool) 149 go m.flushLoop() 150 } 151 n, err = m.dst.Write(p) 152 if err != nil { 153 m.done <- true 154 } 155 return 156 } 157 158 func (m *maxLatencyWriter) flushLoop() { 159 t := time.NewTicker(m.latency) 160 defer t.Stop() 161 for { 162 select { 163 case <-t.C: 164 m.lk.Lock() 165 m.dst.Flush() 166 m.lk.Unlock() 167 case <-m.done: 168 return 169 } 170 } 171 panic("unreached") 172 }