New upstream version 1.2.1~git20200514.4baec98
Pirate Praveen
3 years ago
0 | 0 | package tableflip |
1 | 1 | |
2 | 2 | import ( |
3 | "context" | |
3 | 4 | "fmt" |
4 | 5 | "net" |
5 | 6 | "os" |
77 | 78 | // NB: Files in these maps may be in blocking mode. |
78 | 79 | inherited map[fileName]*file |
79 | 80 | used map[fileName]*file |
80 | } | |
81 | ||
82 | func newFds(inherited map[fileName]*file) *Fds { | |
81 | lc *net.ListenConfig | |
82 | } | |
83 | ||
84 | func newFds(inherited map[fileName]*file, lc *net.ListenConfig) *Fds { | |
83 | 85 | if inherited == nil { |
84 | 86 | inherited = make(map[fileName]*file) |
85 | 87 | } |
88 | ||
89 | if lc == nil { | |
90 | lc = &net.ListenConfig{} | |
91 | } | |
92 | ||
86 | 93 | return &Fds{ |
87 | 94 | inherited: inherited, |
88 | 95 | used: make(map[fileName]*file), |
96 | lc: lc, | |
89 | 97 | } |
90 | 98 | } |
91 | 99 | |
103 | 111 | return ln, nil |
104 | 112 | } |
105 | 113 | |
106 | ln, err = net.Listen(network, addr) | |
114 | ln, err = f.lc.Listen(context.Background(), network, addr) | |
107 | 115 | if err != nil { |
108 | 116 | return nil, fmt.Errorf("can't create new listener: %s", err) |
109 | 117 | } |
186 | 194 | return conn, nil |
187 | 195 | } |
188 | 196 | |
189 | conn, err = net.ListenPacket(network, addr) | |
197 | conn, err = f.lc.ListenPacket(context.Background(), network, addr) | |
190 | 198 | if err != nil { |
191 | 199 | return nil, fmt.Errorf("can't create new listener: %s", err) |
192 | 200 | } |
18 | 18 | {"tcp", "localhost:0"}, |
19 | 19 | } |
20 | 20 | |
21 | fds := newFds(nil) | |
21 | fds := newFds(nil, nil) | |
22 | ||
22 | 23 | for _, addr := range addrs { |
23 | 24 | ln, err := net.Listen(addr[0], addr[1]) |
24 | 25 | if err != nil { |
40 | 41 | {"udp", "localhost:0"}, |
41 | 42 | } |
42 | 43 | |
43 | fds := newFds(nil) | |
44 | fds := newFds(nil, nil) | |
44 | 45 | for _, addr := range addrs { |
45 | 46 | conn, err := net.ListenPacket(addr[0], addr[1]) |
46 | 47 | if err != nil { |
90 | 91 | err error |
91 | 92 | ) |
92 | 93 | |
93 | parent := newFds(nil) | |
94 | parent := newFds(nil, nil) | |
94 | 95 | for _, addr := range addrs { |
95 | 96 | switch addr[0] { |
96 | 97 | case "udp", "unixgram": |
107 | 108 | ln.Close() |
108 | 109 | } |
109 | 110 | |
110 | child := newFds(parent.copy()) | |
111 | child := newFds(parent.copy(), nil) | |
111 | 112 | for _, addr := range addrs { |
112 | 113 | switch addr[0] { |
113 | 114 | case "udp", "unixgram": |
141 | 142 | } |
142 | 143 | |
143 | 144 | makeFds := func(t *testing.T) *Fds { |
144 | fds := newFds(nil) | |
145 | fds := newFds(nil, nil) | |
145 | 146 | for _, addr := range addrs { |
146 | 147 | var c io.Closer |
147 | 148 | var err error |
173 | 174 | |
174 | 175 | t.Run("closeInherited", func(t *testing.T) { |
175 | 176 | parent := makeFds(t) |
176 | child := newFds(parent.copy()) | |
177 | child := newFds(parent.copy(), nil) | |
177 | 178 | child.closeInherited() |
178 | 179 | for _, addr := range addrs { |
179 | 180 | if _, err := os.Stat(addr[1]); err == nil { |
204 | 205 | t.Fatal(err) |
205 | 206 | } |
206 | 207 | |
207 | parent := newFds(nil) | |
208 | parent := newFds(nil, nil) | |
208 | 209 | if err := parent.AddConn("unixgram", "", unix); err != nil { |
209 | 210 | t.Fatal("Can't add conn:", err) |
210 | 211 | } |
211 | 212 | unix.Close() |
212 | 213 | |
213 | child := newFds(parent.copy()) | |
214 | child := newFds(parent.copy(), nil) | |
214 | 215 | conn, err := child.Conn("unixgram", "") |
215 | 216 | if err != nil { |
216 | 217 | t.Fatal("Can't get conn:", err) |
228 | 229 | } |
229 | 230 | defer r.Close() |
230 | 231 | |
231 | parent := newFds(nil) | |
232 | parent := newFds(nil, nil) | |
232 | 233 | if err := parent.AddFile("test", w); err != nil { |
233 | 234 | t.Fatal("Can't add file:", err) |
234 | 235 | } |
235 | 236 | w.Close() |
236 | 237 | |
237 | child := newFds(parent.copy()) | |
238 | child := newFds(parent.copy(), nil) | |
238 | 239 | file, err := child.File("test") |
239 | 240 | if err != nil { |
240 | 241 | t.Fatal("Can't get file:", err) |
4 | 4 | "errors" |
5 | 5 | "fmt" |
6 | 6 | "io/ioutil" |
7 | "net" | |
7 | 8 | "os" |
8 | 9 | "path/filepath" |
9 | 10 | "runtime" |
23 | 24 | UpgradeTimeout time.Duration |
24 | 25 | // The PID of a ready process is written to this file. |
25 | 26 | PIDFile string |
27 | // ListenConfig is a custom ListenConfig. Defaults to an empty ListenConfig | |
28 | ListenConfig *net.ListenConfig | |
26 | 29 | } |
27 | 30 | |
28 | 31 | // Upgrader handles zero downtime upgrades and passing files between processes. |
95 | 98 | upgradeC: make(chan chan<- error), |
96 | 99 | exitC: make(chan struct{}), |
97 | 100 | exitFd: make(chan neverCloseThisFile, 1), |
98 | Fds: newFds(files), | |
101 | Fds: newFds(files, opts.ListenConfig), | |
99 | 102 | } |
100 | 103 | |
101 | 104 | go u.run() |
7 | 7 | "fmt" |
8 | 8 | "io" |
9 | 9 | "io/ioutil" |
10 | "net" | |
10 | 11 | "os" |
11 | 12 | "strconv" |
13 | "syscall" | |
12 | 14 | "testing" |
13 | 15 | "time" |
14 | 16 | ) |
279 | 281 | } |
280 | 282 | } |
281 | 283 | |
284 | func TestUpgraderListenConfig(t *testing.T) { | |
285 | t.Parallel() | |
286 | ||
287 | var listenConfigUsed bool | |
288 | u := newTestUpgrader(Options{ | |
289 | ListenConfig: &net.ListenConfig{ | |
290 | Control: func(network, address string, c syscall.RawConn) error { | |
291 | listenConfigUsed = true | |
292 | return nil | |
293 | }, | |
294 | }, | |
295 | }) | |
296 | defer u.Stop() | |
297 | ||
298 | new, _ := u.upgradeProc(t) | |
299 | ||
300 | go new.recvSignal(nil) | |
301 | ||
302 | _, err := u.Listen("tcp", ":0") | |
303 | if err != nil { | |
304 | t.Errorf("Unexpected error from listen: %v", err) | |
305 | } | |
306 | ||
307 | if !listenConfigUsed { | |
308 | t.Error("Expected ListenConfig to be called during Listen") | |
309 | } | |
310 | ||
311 | new.exit(nil) | |
312 | } | |
313 | ||
282 | 314 | func TestUpgraderConcurrentUpgrade(t *testing.T) { |
283 | 315 | t.Parallel() |
284 | 316 | |
472 | 504 | func BenchmarkUpgrade(b *testing.B) { |
473 | 505 | for _, n := range []int{4, 400, 4000} { |
474 | 506 | b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { |
475 | fds := newFds(nil) | |
507 | fds := newFds(nil, nil) | |
476 | 508 | for i := 0; i < n; i += 2 { |
477 | 509 | r, w, err := os.Pipe() |
478 | 510 | if err != nil { |