Source file src/pkg/net/http/httputil/reverseproxy.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // HTTP reverse proxy handler
6
7 package httputil
8
9 import (
10 "io"
11 "log"
12 "net"
13 "net/http"
14 "net/url"
15 "strings"
16 "sync"
17 "time"
18 )
19
// ReverseProxy is an HTTP Handler that receives an incoming request,
// forwards it to another server, and relays that server's response back
// to the client.
type ReverseProxy struct {
	// Director rewrites the outgoing request (for example its URL) before
	// it is sent through Transport. ServeHTTP calls it unconditionally,
	// so it must be non-nil.
	Director func(req *http.Request)

	// Transport performs the proxied round trip.
	// When nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval is how often the client connection is flushed while
	// the backend's response body is being copied to it.
	// If zero, no periodic flushing is done.
	FlushInterval time.Duration
}
40
// singleJoiningSlash concatenates a and b so that the seam between them
// carries exactly one '/': a duplicate slash is dropped and a missing one
// is inserted. Slashes away from the seam are left untouched.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		// Both sides supply a separator; drop b's.
		return a + b[1:]
	}
	if !trailing && !leading {
		// Neither side supplies one; insert it.
		return a + "/" + b
	}
	// Exactly one side already supplies the separator.
	return a + b
}
52
53 // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
54 // URLs to the scheme, host, and base path provided in target. If the
55 // target's path is "/base" and the incoming request was for "/dir",
56 // the target request will be for /base/dir.
57 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
58 targetQuery := target.RawQuery
59 director := func(req *http.Request) {
60 req.URL.Scheme = target.Scheme
61 req.URL.Host = target.Host
62 req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
63 if targetQuery == "" || req.URL.RawQuery == "" {
64 req.URL.RawQuery = targetQuery + req.URL.RawQuery
65 } else {
66 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
67 }
68 }
69 return &ReverseProxy{Director: director}
70 }
71
// copyHeader appends every value of every field in src to dst, keeping
// whatever values dst already holds. Values are added through
// http.Header.Add, which canonicalizes the field name.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
79
80 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
81 transport := p.Transport
82 if transport == nil {
83 transport = http.DefaultTransport
84 }
85
86 outreq := new(http.Request)
87 *outreq = *req // includes shallow copies of maps, but okay
88
89 p.Director(outreq)
90 outreq.Proto = "HTTP/1.1"
91 outreq.ProtoMajor = 1
92 outreq.ProtoMinor = 1
93 outreq.Close = false
94
95 // Remove the connection header to the backend. We want a
96 // persistent connection, regardless of what the client sent
97 // to us. This is modifying the same underlying map from req
98 // (shallow copied above) so we only copy it if necessary.
99 if outreq.Header.Get("Connection") != "" {
100 outreq.Header = make(http.Header)
101 copyHeader(outreq.Header, req.Header)
102 outreq.Header.Del("Connection")
103 }
104
105 if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
106 outreq.Header.Set("X-Forwarded-For", clientIp)
107 }
108
109 res, err := transport.RoundTrip(outreq)
110 if err != nil {
111 log.Printf("http: proxy error: %v", err)
112 rw.WriteHeader(http.StatusInternalServerError)
113 return
114 }
115
116 copyHeader(rw.Header(), res.Header)
117
118 rw.WriteHeader(res.StatusCode)
119
120 if res.Body != nil {
121 var dst io.Writer = rw
122 if p.FlushInterval != 0 {
123 if wf, ok := rw.(writeFlusher); ok {
124 dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
125 }
126 }
127 io.Copy(dst, res.Body)
128 }
129 }
130
131 type writeFlusher interface {
132 io.Writer
133 http.Flusher
134 }
135
136 type maxLatencyWriter struct {
137 dst writeFlusher
138 latency time.Duration
139
140 lk sync.Mutex // protects init of done, as well Write + Flush
141 done chan bool
142 }
143
144 func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
145 m.lk.Lock()
146 defer m.lk.Unlock()
147 if m.done == nil {
148 m.done = make(chan bool)
149 go m.flushLoop()
150 }
151 n, err = m.dst.Write(p)
152 if err != nil {
153 m.done <- true
154 }
155 return
156 }
157
158 func (m *maxLatencyWriter) flushLoop() {
159 t := time.NewTicker(m.latency)
160 defer t.Stop()
161 for {
162 select {
163 case <-t.C:
164 m.lk.Lock()
165 m.dst.Flush()
166 m.lk.Unlock()
167 case <-m.done:
168 return
169 }
170 }
171 panic("unreached")
172 }