New Upstream Release - golang-github-pin-tftp
Ready changes
Summary
Merged new upstream version: 3.0.0 (was: 2.2.0).
Resulting package
Built on 2022-12-14T15:54 (took 3m46s)
The resulting binary package can be installed (if you have the apt repository enabled) by running:
apt install -t fresh-releases golang-github-pin-tftp-dev
Lintian Result
Diff
diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml
new file mode 100644
index 0000000..3564876
--- /dev/null
+++ b/.github/workflows/macos.yml
@@ -0,0 +1,25 @@
+name: MacOS test
+
+on:
+ push:
+ branches: [ "master" ]
+ pull_request:
+ branches: [ "master" ]
+
+jobs:
+
+ build:
+ runs-on: macos-latest
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up Go
+ uses: actions/setup-go@v3
+ with:
+ go-version: 1.13
+
+ - name: Build
+ run: go build -v ./...
+
+ - name: Test
+ run: go test -v ./... -race
diff --git a/.github/workflows/ubuntu.yml b/.github/workflows/ubuntu.yml
new file mode 100644
index 0000000..bf4eacb
--- /dev/null
+++ b/.github/workflows/ubuntu.yml
@@ -0,0 +1,25 @@
+name: Linux test
+
+on:
+ push:
+ branches: [ "master" ]
+ pull_request:
+ branches: [ "master" ]
+
+jobs:
+
+ build:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up Go
+ uses: actions/setup-go@v3
+ with:
+ go-version: 1.18
+
+ - name: Build
+ run: go build -v ./...
+
+ - name: Test
+ run: go test -v ./... -race
diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
new file mode 100644
index 0000000..c460e60
--- /dev/null
+++ b/.github/workflows/windows.yml
@@ -0,0 +1,25 @@
+name: Windows test
+
+on:
+ push:
+ branches: [ "master" ]
+ pull_request:
+ branches: [ "master" ]
+
+jobs:
+
+ build:
+ runs-on: windows-latest
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up Go
+ uses: actions/setup-go@v3
+ with:
+ go-version: 1.18
+
+ - name: Build
+ run: go build -v ./...
+
+ - name: Test
+ run: go test -v ./... -race
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index edbec9a..0000000
--- a/.travis.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-language: go
-
-os:
- - linux
- - osx
-
-before_install:
- - ulimit -n 4096
diff --git a/README.md b/README.md
index 4524294..aa1edcd 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,6 @@ TFTP server and client library for Golang
=========================================
[![GoDoc](https://godoc.org/github.com/pin/tftp?status.svg)](https://godoc.org/github.com/pin/tftp)
-[![Build Status](https://travis-ci.org/pin/tftp.svg?branch=master)](https://travis-ci.org/pin/tftp)
Implements:
* [RFC 1350](https://tools.ietf.org/html/rfc1350) - The TFTP Protocol (Revision 2)
@@ -15,7 +14,7 @@ Partially implements (tsize server side only):
Set of features is sufficient for PXE boot support.
``` go
-import "github.com/pin/tftp"
+import "github.com/pin/tftp/v3"
```
The package is cohesive to Golang `io`. Particularly it implements
diff --git a/connection.go b/connection.go
index 84f6b84..47fa813 100644
--- a/connection.go
+++ b/connection.go
@@ -36,7 +36,7 @@ type connConnection struct {
}
type chanConnection struct {
- sendConn net.PacketConn
+ server *Server
channel chan []byte
srcAddr, addr *net.UDPAddr
timeout time.Duration
@@ -45,7 +45,9 @@ type chanConnection struct {
func (c *chanConnection) sendTo(data []byte, addr *net.UDPAddr) error {
var err error
- if conn, ok := c.sendConn.(*net.UDPConn); ok {
+ c.server.Lock()
+ defer c.server.Unlock()
+ if conn, ok := c.server.conn.(*net.UDPConn); ok {
srcAddr := c.srcAddr.IP.To4()
var cmm []byte
if srcAddr != nil {
@@ -57,7 +59,7 @@ func (c *chanConnection) sendTo(data []byte, addr *net.UDPAddr) error {
}
_, _, err = conn.WriteMsgUDP(data, cmm, c.addr)
} else {
- _, err = c.sendConn.WriteTo(data, addr)
+ _, err = c.server.conn.WriteTo(data, addr)
}
return err
}
@@ -80,8 +82,10 @@ func (c *chanConnection) setDeadline(deadline time.Duration) error {
}
func (c *chanConnection) close() {
+ c.server.Lock()
+ defer c.server.Unlock()
close(c.channel)
- c.complete <- c.addr.String()
+ delete(c.server.handlers, c.addr.String())
}
func (c *connConnection) sendTo(data []byte, addr *net.UDPAddr) error {
diff --git a/debian/changelog b/debian/changelog
index 273e51f..20810eb 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+golang-github-pin-tftp (3.0.0-1) UNRELEASED; urgency=low
+
+ * New upstream release.
+
+ -- Debian Janitor <janitor@jelmer.uk> Wed, 14 Dec 2022 15:51:33 -0000
+
golang-github-pin-tftp (2.2.0-3) unstable; urgency=medium
* Team upload.
diff --git a/debian/patches/single_port_test.go.patch b/debian/patches/single_port_test.go.patch
index b9ddc25..6a0992f 100644
--- a/debian/patches/single_port_test.go.patch
+++ b/debian/patches/single_port_test.go.patch
@@ -1,8 +1,10 @@
Subject: remove test that timeout during build
---- golang-github-pin-tftp-2.2.0/single_port_test.go
-+++ golang-github-pin-tftp-2.2.0/single_port_test.go
-@@ -10,14 +10,6 @@
+Index: golang-github-pin-tftp.git/single_port_test.go
+===================================================================
+--- golang-github-pin-tftp.git.orig/single_port_test.go
++++ golang-github-pin-tftp.git/single_port_test.go
+@@ -10,14 +10,6 @@ func TestZeroLengthSinglePort(t *testing
testSendReceive(t, c, 0)
}
@@ -17,7 +19,7 @@ Subject: remove test that timeout during build
func TestSendReceiveSinglePortWithBlockSize(t *testing.T) {
s, c := makeTestServer(true)
defer s.Shutdown()
-@@ -27,11 +19,6 @@
+@@ -27,11 +19,6 @@ func TestSendReceiveSinglePortWithBlockS
}
}
diff --git a/go.mod b/go.mod
index f5cbc82..a22f3d5 100644
--- a/go.mod
+++ b/go.mod
@@ -1,8 +1,5 @@
-module github.com/pin/tftp
+module github.com/pin/tftp/v3
go 1.13
-require (
- github.com/stretchr/testify v1.4.0
- golang.org/x/net v0.0.0-20200202094626-16171245cfb2
-)
+require golang.org/x/net v0.0.0-20200202094626-16171245cfb2
diff --git a/go.sum b/go.sum
index 4d1a89f..e58bf9c 100644
--- a/go.sum
+++ b/go.sum
@@ -1,18 +1,6 @@
-github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
diff --git a/receiver.go b/receiver.go
index 0e0f787..7389d5f 100644
--- a/receiver.go
+++ b/receiver.go
@@ -8,7 +8,7 @@ import (
"strconv"
"time"
- "github.com/pin/tftp/netascii"
+ "github.com/pin/tftp/v3/netascii"
)
// IncomingTransfer provides methods that expose information associated with
diff --git a/sender.go b/sender.go
index a350b68..cee468a 100644
--- a/sender.go
+++ b/sender.go
@@ -8,7 +8,7 @@ import (
"strconv"
"time"
- "github.com/pin/tftp/netascii"
+ "github.com/pin/tftp/v3/netascii"
)
// OutgoingTransfer provides methods to set the outgoing transfer size and
@@ -67,6 +67,12 @@ func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
if s.mode == "netascii" {
r = netascii.ToReader(r)
}
+ defer func() {
+ if s.conn != nil {
+ s.conn.close()
+ s.conn = nil
+ }
+ }()
if s.opts != nil {
// check that tsize is set
if ts, ok := s.opts["tsize"]; ok {
@@ -115,7 +121,6 @@ func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
if s.hook != nil {
s.hook.OnSuccess(s.buildTransferStats())
}
- s.conn.close()
return n, nil
}
s.abort(err)
@@ -131,7 +136,6 @@ func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
if s.hook != nil {
s.hook.OnSuccess(s.buildTransferStats())
}
- s.conn.close()
return n, nil
}
s.block++
@@ -258,15 +262,15 @@ func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) {
func (s *sender) buildTransferStats() TransferStats {
return TransferStats{
- RemoteAddr: s.addr.IP,
- Filename: s.filename,
- Tid: s.tid,
+ RemoteAddr: s.addr.IP,
+ Filename: s.filename,
+ Tid: s.tid,
SenderAnticipateEnabled: s.sendA.enabled,
- Mode: s.mode,
- Opts: s.opts,
- Duration: time.Now().Sub(s.startTime),
- DatagramsSent: s.datagramsSent,
- DatagramsAcked: s.datagramsAcked,
+ Mode: s.mode,
+ Opts: s.opts,
+ Duration: time.Now().Sub(s.startTime),
+ DatagramsSent: s.datagramsSent,
+ DatagramsAcked: s.datagramsAcked,
}
}
diff --git a/server.go b/server.go
index c9703ff..8e06b55 100644
--- a/server.go
+++ b/server.go
@@ -1,6 +1,7 @@
package tftp
import (
+ "context"
"fmt"
"io"
"net"
@@ -18,14 +19,15 @@ import (
func NewServer(readHandler func(filename string, rf io.ReaderFrom) error,
writeHandler func(filename string, wt io.WriterTo) error) *Server {
s := &Server{
+ Mutex: &sync.Mutex{},
timeout: defaultTimeout,
retries: defaultRetries,
- runGC: make(chan []string),
- gcThreshold: 100,
packetReadTimeout: 100 * time.Millisecond,
readHandler: readHandler,
writeHandler: writeHandler,
+ wg: &sync.WaitGroup{},
}
+ s.cancel, s.cancelFn = context.WithCancel(context.Background())
return s
}
@@ -42,6 +44,7 @@ type RequestPacketInfo interface {
// Server is an instance of a TFTP server
type Server struct {
+ *sync.Mutex
readHandler func(filename string, rf io.ReaderFrom) error
writeHandler func(filename string, wt io.WriterTo) error
hook Hook
@@ -49,8 +52,7 @@ type Server struct {
conn net.PacketConn
conn6 *ipv6.PacketConn
conn4 *ipv4.PacketConn
- quit chan chan struct{}
- wg sync.WaitGroup
+ wg *sync.WaitGroup
timeout time.Duration
retries int
maxBlockLen int
@@ -58,12 +60,10 @@ type Server struct {
sendAWinSz uint
// Single port fields
singlePort bool
- bufPool sync.Pool
handlers map[string]chan []byte
- runGC chan []string
- gcCollect chan string
- gcThreshold int
packetReadTimeout time.Duration
+ cancel context.Context
+ cancelFn context.CancelFunc
}
// TransferStats contains details about a single TFTP transfer
@@ -95,6 +95,8 @@ type Hook interface {
// runs through a different experimental code path. When winsz is 0 or 1,
// the feature is disabled.
func (s *Server) SetAnticipate(winsz uint) {
+ s.Lock()
+ defer s.Unlock()
if winsz > 1 {
s.sendAEnable = true
s.sendAWinSz = winsz
@@ -106,6 +108,8 @@ func (s *Server) SetAnticipate(winsz uint) {
// SetHook sets the Hook for success and failure of transfers
func (s *Server) SetHook(hook Hook) {
+ s.Lock()
+ defer s.Unlock()
s.hook = hook
}
@@ -115,24 +119,21 @@ func (s *Server) SetHook(hook Hook) {
//
// Enabling this will negatively impact performance
func (s *Server) EnableSinglePort() {
+ s.Lock()
+ defer s.Unlock()
s.singlePort = true
s.handlers = make(map[string]chan []byte)
- s.gcCollect = make(chan string)
if s.maxBlockLen == 0 {
s.maxBlockLen = blockLength
}
- s.bufPool = sync.Pool{
- New: func() interface{} {
- return make([]byte, s.maxBlockLen+4)
- },
- }
- go s.internalGC()
}
// SetTimeout sets maximum time server waits for single network
// round-trip to succeed.
// Default is 5 seconds.
func (s *Server) SetTimeout(t time.Duration) {
+ s.Lock()
+ defer s.Unlock()
if t <= 0 {
s.timeout = defaultTimeout
} else {
@@ -148,6 +149,8 @@ func (s *Server) SetTimeout(t time.Duration) {
// the block size the client wants and the MTU of the interface being
// communicated over munis overhead.
func (s *Server) SetBlockSize(i int) {
+ s.Lock()
+ defer s.Unlock()
if i > 512 && i < 65465 {
s.maxBlockLen = i
}
@@ -157,6 +160,8 @@ func (s *Server) SetBlockSize(i int) {
// packet.
// Default is 5 attempts.
func (s *Server) SetRetries(count int) {
+ s.Lock()
+ defer s.Unlock()
if count < 1 {
s.retries = defaultRetries
} else {
@@ -167,6 +172,8 @@ func (s *Server) SetRetries(count int) {
// SetBackoff sets a user provided function that is called to provide a
// backoff duration prior to retransmitting an unacknowledged packet.
func (s *Server) SetBackoff(h backoffFunc) {
+ s.Lock()
+ defer s.Unlock()
s.backoff = h
}
@@ -184,18 +191,20 @@ func (s *Server) ListenAndServe(addr string) error {
return s.Serve(conn)
}
-// Serve starts server provided already opened UDP connecton. It is
+// Serve starts server provided already opened UDP connection. It is
// useful for the case when you want to run server in separate goroutine
// but still want to be able to handle any errors opening connection.
// Serve returns when Shutdown is called or connection is closed.
func (s *Server) Serve(conn net.PacketConn) error {
- defer conn.Close()
+ // defer conn.Close()
laddr := conn.LocalAddr()
host, _, err := net.SplitHostPort(laddr.String())
if err != nil {
return err
}
+ s.Lock()
s.conn = conn
+ s.Unlock()
// Having seperate control paths for IP4 and IP6 is annoying,
// but necessary at this point.
addr := net.ParseIP(host)
@@ -203,7 +212,7 @@ func (s *Server) Serve(conn net.PacketConn) error {
return fmt.Errorf("Failed to determine IP class of listening address")
}
- if conn, ok := conn.(*net.UDPConn); ok {
+ if conn, ok := s.conn.(*net.UDPConn); ok {
if addr.To4() != nil {
s.conn4 = ipv4.NewPacketConn(conn)
if err := s.conn4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
@@ -217,36 +226,34 @@ func (s *Server) Serve(conn net.PacketConn) error {
}
}
- s.quit = make(chan chan struct{})
if s.singlePort {
- s.singlePortProcessRequests()
- } else {
- for {
- select {
- case q := <-s.quit:
- q <- struct{}{}
- return nil
- default:
- var err error
- if s.conn4 != nil {
- err = s.processRequest4()
- } else if s.conn6 != nil {
- err = s.processRequest6()
- } else {
- err = s.processRequest()
- }
- if err != nil && s.hook != nil {
- s.hook.OnFailure(TransferStats{
- SenderAnticipateEnabled: s.sendAEnable,
- }, err)
- }
+ return s.singlePortProcessRequests()
+ }
+ for {
+ select {
+ case <-s.cancel.Done():
+ s.wg.Wait()
+ return nil
+ default:
+ var err error
+ if s.conn4 != nil {
+ err = s.processRequest4()
+ } else if s.conn6 != nil {
+ err = s.processRequest6()
+ } else {
+ err = s.processRequest()
+ }
+ if err != nil && s.hook != nil {
+ s.hook.OnFailure(TransferStats{
+ SenderAnticipateEnabled: s.sendAEnable,
+ }, err)
}
}
}
return nil
}
-// Yes, I don't really like having seperate IPv4 and IPv6 variants,
+// Yes, I don't really like having separate IPv4 and IPv6 variants,
// bit we are relying on the low-level packet control channel info to
// get a reliable source address, and those have different types and
// the struct itself is not easily interface-ized or embedded.
@@ -258,7 +265,7 @@ func (s *Server) processRequest4() error {
buf := make([]byte, datagramLength)
cnt, control, srcAddr, err := s.conn4.ReadFrom(buf)
if err != nil {
- return fmt.Errorf("reading UDP: %v", err)
+ return nil
}
maxSz := blockLength
var localAddr net.IP
@@ -276,7 +283,7 @@ func (s *Server) processRequest6() error {
buf := make([]byte, datagramLength)
cnt, control, srcAddr, err := s.conn6.ReadFrom(buf)
if err != nil {
- return fmt.Errorf("reading UDP: %v", err)
+ return nil
}
maxSz := blockLength
var localAddr net.IP
@@ -304,15 +311,21 @@ func (s *Server) processRequest() error {
// server to finish outstanding transfers and stops server.
func (s *Server) Shutdown() {
if !s.singlePort {
+ s.Lock()
s.conn.Close()
+ s.Unlock()
}
- q := make(chan struct{})
- s.quit <- q
- <-q
- s.wg.Wait()
+ s.cancelFn()
}
func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer []byte, n, maxBlockLen int, listener chan []byte) error {
+ s.Lock()
+ defer s.Unlock()
+ // handlePacket is always called with maxBlockLen = blockLength (above, in processRequest).
+ // As a result, the block size would always be capped at 512 bytes, even when the tftp
+ // client indicated to use a larger value. So override that value. And make sure to
+ // use that value below, when allocating buffers. (Happening on Windows Server 2016.)
+ // if s.maxBlockLen > 0 {
if s.maxBlockLen > 0 && s.maxBlockLen < maxBlockLen {
maxBlockLen = s.maxBlockLen
}
@@ -351,12 +364,11 @@ func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer
}
if s.singlePort {
wt.conn = &chanConnection{
- srcAddr: listenAddr,
- addr: remoteAddr,
- channel: listener,
- timeout: s.timeout,
- sendConn: s.conn,
- complete: s.gcCollect,
+ server: s,
+ srcAddr: listenAddr,
+ addr: remoteAddr,
+ channel: listener,
+ timeout: s.timeout,
}
wt.singlePort = true
} else {
@@ -405,12 +417,11 @@ func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer
}
if s.singlePort {
rf.conn = &chanConnection{
- srcAddr: listenAddr,
- addr: remoteAddr,
- channel: listener,
- timeout: s.timeout,
- sendConn: s.conn,
- complete: s.gcCollect,
+ server: s,
+ srcAddr: listenAddr,
+ addr: remoteAddr,
+ channel: listener,
+ timeout: s.timeout,
}
} else {
conn, err := net.ListenUDP("udp", listenAddr)
@@ -424,7 +435,7 @@ func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer
sendAInit(&rf.sendA, datagramLength, s.sendAWinSz)
}
s.wg.Add(1)
- go func() {
+ go func(rh func(string, io.ReaderFrom) error, rf *sender, wg *sync.WaitGroup) {
if s.readHandler != nil {
err := s.readHandler(filename, rf)
if err != nil {
@@ -434,7 +445,7 @@ func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer
rf.abort(fmt.Errorf("server does not support read requests"))
}
s.wg.Done()
- }()
+ }(s.readHandler, rf, s.wg)
default:
return fmt.Errorf("unexpected %T", p)
}
diff --git a/single_port.go b/single_port.go
index 048d6e9..248866b 100644
--- a/single_port.go
+++ b/single_port.go
@@ -5,68 +5,42 @@ import (
)
func (s *Server) singlePortProcessRequests() error {
- var (
- localAddr net.IP
- cnt, maxSz int
- srcAddr net.Addr
- err error
- buf []byte
- )
- defer func() {
- if r := recover(); r != nil {
- // We've received a new connection on the same IP+Port tuple
- // as a previous connection before garbage collection has occured
- s.handlers[srcAddr.String()] = make(chan []byte, 1)
- go func(localAddr net.IP, remoteAddr *net.UDPAddr, buffer []byte, n, maxBlockLen int, listener chan []byte) {
- err := s.handlePacket(localAddr, remoteAddr, buffer, n, maxBlockLen, listener)
- if err != nil && s.hook != nil {
- s.hook.OnFailure(TransferStats{
- SenderAnticipateEnabled: s.sendAEnable,
- }, err)
- }
-
- }(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, maxSz, s.handlers[srcAddr.String()])
- s.singlePortProcessRequests()
- }
- }()
for {
select {
- case q := <-s.quit:
- q <- struct{}{}
+ case <-s.cancel.Done():
+ s.wg.Wait()
return nil
- case handlersToFree := <-s.runGC:
- for _, handler := range handlersToFree {
- delete(s.handlers, handler)
- }
default:
- buf = s.bufPool.Get().([]byte)
- cnt, localAddr, srcAddr, maxSz, err = s.getPacket(buf)
+ buf := make([]byte, s.maxBlockLen+4)
+ cnt, localAddr, srcAddr, maxSz, err := s.getPacket(buf)
if err != nil || cnt == 0 {
if s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}
- s.bufPool.Put(buf)
continue
}
+ s.Lock()
if receiverChannel, ok := s.handlers[srcAddr.String()]; ok {
+ s.Unlock()
select {
case receiverChannel <- buf[:cnt]:
default:
// We don't want to block the main loop if a channel is full
}
} else {
- s.handlers[srcAddr.String()] = make(chan []byte, 1)
- go func(localAddr net.IP, remoteAddr *net.UDPAddr, buffer []byte, n, maxBlockLen int, listener chan []byte) {
- err := s.handlePacket(localAddr, remoteAddr, buffer, n, maxBlockLen, listener)
+ lc := make(chan []byte, 1)
+ s.handlers[srcAddr.String()] = lc
+ s.Unlock()
+ go func() {
+ err := s.handlePacket(localAddr, srcAddr, buf, cnt, maxSz, lc)
if err != nil && s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}
-
- }(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, maxSz, s.handlers[srcAddr.String()])
+ }()
}
}
}
@@ -111,19 +85,3 @@ func (s *Server) getPacket(buf []byte) (int, net.IP, *net.UDPAddr, int, error) {
return cnt, nil, srcAddr.(*net.UDPAddr), blockLength, nil
}
}
-
-// internalGC collects all the finished signals from each connection's goroutine
-// The main loop is sent the key to be nil'ed after the gcInterval has passed
-func (s *Server) internalGC() {
- var completedHandlers []string
- for {
- select {
- case newHandler := <-s.gcCollect:
- completedHandlers = append(completedHandlers, newHandler)
- if len(completedHandlers) > s.gcThreshold {
- s.runGC <- completedHandlers
- completedHandlers = nil
- }
- }
- }
-}
diff --git a/tftp_test.go b/tftp_test.go
index 647ee34..12bb33b 100644
--- a/tftp_test.go
+++ b/tftp_test.go
@@ -14,8 +14,6 @@ import (
"testing"
"testing/iotest"
"time"
-
- "github.com/stretchr/testify/mock"
)
var localhost = determineLocalhost()
@@ -153,26 +151,34 @@ func Test1810(t *testing.T) {
testSendReceive(t, c, 9000+1810)
}
-type fakeHook struct {
- mock.Mock
+type testHook struct {
+ *sync.Mutex
+ transfersCompleted int
+ transfersFailed int
}
-func (f *fakeHook) OnSuccess(result TransferStats) {
- f.Called(result)
- return
+func newTestHook() *testHook {
+ return &testHook{
+ Mutex: &sync.Mutex{},
+ }
}
-func (f *fakeHook) OnFailure(result TransferStats, err error) {
- f.Called(result)
- return
+
+func (h *testHook) OnSuccess(result TransferStats) {
+ h.Lock()
+ defer h.Unlock()
+ h.transfersCompleted++
+}
+
+func (h *testHook) OnFailure(result TransferStats, err error) {
+ h.Lock()
+ defer h.Unlock()
+ h.transfersFailed++
}
func TestHookSuccess(t *testing.T) {
s, c := makeTestServer(false)
- fakeHookTemp := new(fakeHook)
- // Due to the way test are run there will always be some failures
- fakeHookTemp.On("OnFailure", mock.AnythingOfType("TransferStats")).Return()
- fakeHookTemp.On("OnSuccess", mock.AnythingOfType("TransferStats")).Return()
- s.SetHook(fakeHookTemp)
+ th := newTestHook()
+ s.SetHook(th)
c.SetBlockSize(1810)
length := int64(9000)
filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano())
@@ -189,14 +195,17 @@ func TestHookSuccess(t *testing.T) {
t.Errorf("%s length mismatch: %d != %d", filename, n, length)
}
s.Shutdown()
- fakeHookTemp.AssertNumberOfCalls(t, "OnSuccess", 1)
+ th.Lock()
+ defer th.Unlock()
+ if th.transfersCompleted != 1 {
+ t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted)
+ }
}
func TestHookFailure(t *testing.T) {
s, c := makeTestServer(false)
- fakeHookTemp := new(fakeHook)
- fakeHookTemp.On("OnFailure", mock.AnythingOfType("TransferStats")).Return()
- s.SetHook(fakeHookTemp)
+ th := newTestHook()
+ s.SetHook(th)
filename := "test-not-exists"
mode := "octet"
_, err := c.Receive(filename, mode)
@@ -205,7 +214,11 @@ func TestHookFailure(t *testing.T) {
}
t.Logf("receiving file that does not exist: %v", err)
s.Shutdown()
- fakeHookTemp.AssertExpectations(t)
+ th.Lock()
+ defer th.Unlock()
+ if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows?
+ t.Errorf("unexpected failed transfers count: %d", th.transfersFailed)
+ }
}
func TestTSize(t *testing.T) {
@@ -427,7 +440,6 @@ func makeTestServer(singlePort bool) (*Server, *Client) {
if singlePort {
s.SetBlockSize(2000)
- s.gcThreshold = 100000
s.EnableSinglePort()
}
@@ -546,12 +558,15 @@ func (r *randReader) Read(p []byte) (n int, err error) {
func serverTimeoutSendTest(s *Server, c *Client, t *testing.T) {
s.SetTimeout(time.Second)
s.SetRetries(2)
- var serverErr error
+ sec := make(chan error, 1)
+ s.Lock()
s.readHandler = func(filename string, rf io.ReaderFrom) error {
r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
- _, serverErr = rf.ReadFrom(r)
- return serverErr
+ _, err := rf.ReadFrom(r)
+ sec <- err
+ return err
}
+ s.Unlock()
defer s.Shutdown()
filename := "test-server-send-timeout"
mode := "octet"
@@ -564,12 +579,13 @@ func serverTimeoutSendTest(s *Server, c *Client, t *testing.T) {
delay: 8 * time.Second,
}
_, _ = readTransfer.WriteTo(w)
- netErr, ok := serverErr.(net.Error)
+ servErr := <-sec
+ netErr, ok := servErr.(net.Error)
if !ok {
- t.Fatalf("network error expected: %T", serverErr)
+ t.Fatalf("network error expected: %T", servErr)
}
if !netErr.Timeout() {
- t.Fatalf("timout is expected: %v", serverErr)
+ t.Fatalf("timout is expected: %v", servErr)
}
}
@@ -582,12 +598,15 @@ func TestServerSendTimeout(t *testing.T) {
func serverReceiveTimeoutTest(s *Server, c *Client, t *testing.T) {
s.SetTimeout(time.Second)
s.SetRetries(2)
- var serverErr error
+ sec := make(chan error, 1)
+ s.Lock()
s.writeHandler = func(filename string, wt io.WriterTo) error {
buf := &bytes.Buffer{}
- _, serverErr = wt.WriteTo(buf)
- return serverErr
+ _, err := wt.WriteTo(buf)
+ sec <- err
+ return err
}
+ s.Unlock()
defer s.Shutdown()
filename := "test-server-receive-timeout"
mode := "octet"
@@ -601,12 +620,13 @@ func serverReceiveTimeoutTest(s *Server, c *Client, t *testing.T) {
delay: 8 * time.Second,
}
_, _ = writeTransfer.ReadFrom(r)
- netErr, ok := serverErr.(net.Error)
+ servErr := <-sec
+ netErr, ok := servErr.(net.Error)
if !ok {
- t.Fatalf("network error expected: %T", serverErr)
+ t.Fatalf("network error expected: %T", servErr)
}
if !netErr.Timeout() {
- t.Fatalf("timout is expected: %v", serverErr)
+ t.Fatalf("timout is expected: %v", servErr)
}
}
@@ -619,6 +639,7 @@ func TestClientReceiveTimeout(t *testing.T) {
s, c := makeTestServer(false)
c.SetTimeout(time.Second)
c.SetRetries(2)
+ s.Lock()
s.readHandler = func(filename string, rf io.ReaderFrom) error {
r := &slowReader{
r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
@@ -628,6 +649,7 @@ func TestClientReceiveTimeout(t *testing.T) {
_, err := rf.ReadFrom(r)
return err
}
+ s.Unlock()
defer s.Shutdown()
filename := "test-client-receive-timeout"
mode := "octet"
@@ -650,6 +672,7 @@ func TestClientSendTimeout(t *testing.T) {
s, c := makeTestServer(false)
c.SetTimeout(time.Second)
c.SetRetries(2)
+ s.Lock()
s.writeHandler = func(filename string, wt io.WriterTo) error {
w := &slowWriter{
n: 3,
@@ -658,6 +681,7 @@ func TestClientSendTimeout(t *testing.T) {
_, err := wt.WriteTo(w)
return err
}
+ s.Unlock()
defer s.Shutdown()
filename := "test-client-send-timeout"
mode := "octet"
Debdiff
File lists identical (after any substitutions)
No differences were encountered in the control files