src/pkg/net/http/httptest/server.go - The Go Programming Language

Golang

Source file src/pkg/net/http/httptest/server.go

     1	// Copyright 2011 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// Implementation of Server
     6	
     7	package httptest
     8	
     9	import (
    10		"crypto/tls"
    11		"flag"
    12		"fmt"
    13		"net"
    14		"net/http"
    15		"os"
    16		"sync"
    17	)
    18	
// A Server is an HTTP server listening on a system-chosen port on the
// local loopback interface, for use in end-to-end HTTP tests.
type Server struct {
	// URL is the base URL of the form http://ipaddr:port (https:// when
	// started with StartTLS), with no trailing slash. It is empty until
	// Start or StartTLS is called; both use it to detect a second start.
	URL string

	// Listener accepts the server's connections. Start and StartTLS
	// replace it with a wrapper that records every accepted connection
	// so CloseClientConnections can close them later.
	Listener net.Listener

	// TLS is the TLS configuration set by StartTLS; nil if not using TLS.
	TLS *tls.Config

	// Config may be changed after calling NewUnstartedServer and
	// before Start or StartTLS.
	Config *http.Server

	// wg counts the number of outstanding HTTP requests on this server.
	// Close blocks until all requests are finished.
	//
	// NOTE(review): requests on already-open connections may call
	// wg.Add(1) concurrently with Close's wg.Wait(). sync.WaitGroup
	// requires Adds that start from a zero count to happen before Wait.
	// Confirm this ordering is safe.
	wg sync.WaitGroup
}
    34	
    35	// historyListener keeps track of all connections that it's ever
    36	// accepted.
    37	type historyListener struct {
    38		net.Listener
    39		history []net.Conn
    40	}
    41	
    42	func (hs *historyListener) Accept() (c net.Conn, err error) {
    43		c, err = hs.Listener.Accept()
    44		if err == nil {
    45			hs.history = append(hs.history, c)
    46		}
    47		return
    48	}
    49	
    50	func newLocalListener() net.Listener {
    51		if *serve != "" {
    52			l, err := net.Listen("tcp", *serve)
    53			if err != nil {
    54				panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
    55			}
    56			return l
    57		}
    58		l, err := net.Listen("tcp", "127.0.0.1:0")
    59		if err != nil {
    60			if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
    61				panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
    62			}
    63		}
    64		return l
    65	}
    66	
    67	// When debugging a particular http server-based test,
    68	// this flag lets you run
    69	//	go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
    70	// to start the broken server so you can interact with it manually.
    71	var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
    72	
    73	// NewServer starts and returns a new Server.
    74	// The caller should call Close when finished, to shut it down.
    75	func NewServer(handler http.Handler) *Server {
    76		ts := NewUnstartedServer(handler)
    77		ts.Start()
    78		return ts
    79	}
    80	
    81	// NewUnstartedServer returns a new Server but doesn't start it.
    82	//
    83	// After changing its configuration, the caller should call Start or
    84	// StartTLS.
    85	//
    86	// The caller should call Close when finished, to shut it down.
    87	func NewUnstartedServer(handler http.Handler) *Server {
    88		return &Server{
    89			Listener: newLocalListener(),
    90			Config:   &http.Server{Handler: handler},
    91		}
    92	}
    93	
    94	// Start starts a server from NewUnstartedServer.
    95	func (s *Server) Start() {
    96		if s.URL != "" {
    97			panic("Server already started")
    98		}
    99		s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)}
   100		s.URL = "http://" + s.Listener.Addr().String()
   101		s.wrapHandler()
   102		go s.Config.Serve(s.Listener)
   103		if *serve != "" {
   104			fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
   105			select {}
   106		}
   107	}
   108	
   109	// StartTLS starts TLS on a server from NewUnstartedServer.
   110	func (s *Server) StartTLS() {
   111		if s.URL != "" {
   112			panic("Server already started")
   113		}
   114		cert, err := tls.X509KeyPair(localhostCert, localhostKey)
   115		if err != nil {
   116			panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
   117		}
   118	
   119		s.TLS = &tls.Config{
   120			NextProtos:   []string{"http/1.1"},
   121			Certificates: []tls.Certificate{cert},
   122		}
   123		tlsListener := tls.NewListener(s.Listener, s.TLS)
   124	
   125		s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)}
   126		s.URL = "https://" + s.Listener.Addr().String()
   127		s.wrapHandler()
   128		go s.Config.Serve(s.Listener)
   129	}
   130	
   131	func (s *Server) wrapHandler() {
   132		h := s.Config.Handler
   133		if h == nil {
   134			h = http.DefaultServeMux
   135		}
   136		s.Config.Handler = &waitGroupHandler{
   137			s: s,
   138			h: h,
   139		}
   140	}
   141	
   142	// NewTLSServer starts and returns a new Server using TLS.
   143	// The caller should call Close when finished, to shut it down.
   144	func NewTLSServer(handler http.Handler) *Server {
   145		ts := NewUnstartedServer(handler)
   146		ts.StartTLS()
   147		return ts
   148	}
   149	
   150	// Close shuts down the server and blocks until all outstanding
   151	// requests on this server have completed.
   152	func (s *Server) Close() {
   153		s.Listener.Close()
   154		s.wg.Wait()
   155	}
   156	
   157	// CloseClientConnections closes any currently open HTTP connections
   158	// to the test Server.
   159	func (s *Server) CloseClientConnections() {
   160		hl, ok := s.Listener.(*historyListener)
   161		if !ok {
   162			return
   163		}
   164		for _, conn := range hl.history {
   165			conn.Close()
   166		}
   167	}
   168	
   169	// waitGroupHandler wraps a handler, incrementing and decrementing a
   170	// sync.WaitGroup on each request, to enable Server.Close to block
   171	// until outstanding requests are finished.
   172	type waitGroupHandler struct {
   173		s *Server
   174		h http.Handler // non-nil
   175	}
   176	
   177	func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   178		h.s.wg.Add(1)
   179		defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
   180		h.h.ServeHTTP(w, r)
   181	}
   182	
// localhostCert is a PEM-encoded TLS cert with SAN DNS names
// "127.0.0.1" and "[::1]", valid from the start of 1970 and expiring at
// the last second of 2049 (the end of ASN.1 time). StartTLS pairs it
// with localhostKey.
//
// NOTE(review): decoding the PEM shows a SHA-1-with-RSA signature over a
// 512-bit RSA public key. Newer Go releases may reject RSA keys this small
// or SHA-1 signatures. Confirm StartTLS still works with the toolchain in
// use, and regenerate the pair if not.
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
-----END CERTIFICATE-----
`)
   196	
// localhostKey is the private key for localhostCert, in PEM-encoded
// PKCS#1 ("RSA PRIVATE KEY") form. StartTLS loads it with
// tls.X509KeyPair.
//
// NOTE(review): the encoded modulus is 512 bits. Recent Go releases may
// refuse RSA keys below 1024 bits. Verify against the toolchain in use.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v
tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi
SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0
3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ
/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN
poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh
AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93
-----END RSA PRIVATE KEY-----
`)