-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
134 lines (115 loc) · 3.18 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
A simple in-memory TFTP server implementing RFC 1350.
Follow-up RFCs that update or extend RFC 1350 are NOT implemented.
TODO:
* Dynamically set a limit for the store based on available memory.
* Consider setting read/write deadlines for ReadFrom and WriteTo, re: https://golang.org/pkg/net/#PacketConn
* Track errors on active transmissions, and add logic to better handle some edge cases.
*/
package main
import (
"encoding/binary"
"flag"
"log"
"net"
"os"
"sync"
"time"
"github.com/paha/go-tftp/wire"
)
// Package-level state shared by main, the packet handlers and flush.
var (
	// Command-line configuration, registered in init and filled by flag.Parse.
	host          string // UDP address the server listens on
	rexmtInterval int    // seconds of inactivity before a retry-flagged transfer is retransmitted
	maxTimeout    int    // seconds of inactivity before a transfer is dropped
	registryFile  string // path of the WRQ/RRQ registry log

	// inFlight holds the active transfers; flush accesses it under lock.
	inFlight map[string]transfer

	// store is the in-memory datastore (presumably filename -> contents; verify).
	store map[string][]byte

	logger         *log.Logger // general server log, created in main on stdout
	registryLogger *log.Logger // transfer registry log, backed by registryFile

	// lock guards inFlight (see flush). Its zero value is ready to use.
	lock sync.RWMutex
)
func init() {
flag.StringVar(&host, "host", "127.0.0.1:6969", "TFTP interface address")
flag.IntVar(&rexmtInterval, "Rexmt-interval", 5, "TFTP Retransmit interval")
flag.IntVar(&maxTimeout, "Max-timeout", 25, "TFTP max timeout")
flag.StringVar(®istryFile, "registryFile", "tftpRegistry.log", "TFTP WRQ/RRQ registry.")
// Active transfers.
inFlight = make(map[string]transfer)
// TFTP datasotore.
store = make(map[string][]byte)
}
// main parses the flags, binds the UDP listener, opens the registry log,
// starts the background flusher, and then hands every received datagram to
// its own goroutine. Read errors are logged and the loop keeps serving.
func main() {
	flag.Parse()

	pc, listenErr := net.ListenPacket("udp", host)
	if listenErr != nil {
		log.Fatal(listenErr)
	}
	defer pc.Close()

	logger = log.New(os.Stdout, "tftp: ", log.Lshortfile)
	logger.Printf("Started on %v", pc.LocalAddr())

	// Each transfer is also recorded in the registry file.
	newRegistryLogger()
	registryLogger.Print("Server started.")

	// Sweep expired transfers and retransmit stalled ones every 500 ms.
	go flush(500)

	for {
		// A fresh buffer per datagram: the handler goroutine keeps using it
		// while the next read is already in progress.
		packet := make([]byte, tftp.MaxPacketSize)
		size, peer, readErr := pc.ReadFrom(packet)
		if readErr == nil {
			go newPacket(pc, peer, packet, size)
			continue
		}
		logger.Printf("Error: %s", readErr)
	}
}
func newPacket(conn net.PacketConn, addr net.Addr, buf []byte, n int) {
p, err := tftp.ParsePacket(buf)
if err != nil {
logger.Printf("Error parsing a packet: %s", err)
sendError(0, addr, conn)
return
}
op := binary.BigEndian.Uint16(buf)
switch op {
case tftp.OpWRQ:
wrqHandler(conn, p, addr)
case tftp.OpRRQ:
rrqHandler(conn, p, addr)
case tftp.OpAck:
ackHandler(conn, p, addr)
case tftp.OpData:
dataHandler(conn, p, addr, n-4)
case tftp.OpError:
errHandler(p, addr)
default:
logger.Printf("Error: Unrecognized OpCode - %d.", op)
sendError(0, addr, conn)
}
}
// flush runs forever and wakes every m milliseconds to sweep inFlight while
// holding lock:
//   - a transfer whose last activity (lastOp) is more than maxTimeout seconds
//     ago is logged as expired and removed;
//   - otherwise, a transfer flagged for retry whose last activity is more
//     than rexmtInterval seconds ago has its last packet retransmitted.
//
// m must be positive; time.NewTicker panics on a non-positive interval.
func flush(m int16) {
	ticker := time.NewTicker(time.Duration(m) * time.Millisecond)
	for range ticker.C {
		lock.Lock()
		// These limits are the same for every entry, so compute them once per tick.
		expireAfter := time.Duration(maxTimeout) * time.Second
		retryAfter := time.Duration(rexmtInterval) * time.Second
		for key, t := range inFlight {
			idle := time.Since(t.lastOp)
			switch {
			case idle > expireAfter:
				logger.Printf("Transfer of %s has expired.", t.filename)
				// Delete by the key being iterated. t.filename is not
				// guaranteed to be the map key, and deleting during range is
				// safe in Go.
				delete(inFlight, key)
			case t.retry && idle > retryAfter:
				logger.Printf("Retransmiting last packet for %s", t.filename)
				// NOTE(review): t is a copy of the map value and transmit runs
				// while the write lock is held. Confirm that transmit neither
				// needs to persist changes to t (e.g. lastOp) nor acquires lock.
				t.transmit()
			}
		}
		lock.Unlock()
	}
}
// newRegistryLogger opens registryFile for appending, creating it with mode
// 0644 if needed, and points registryLogger at it. The chosen path is logged
// before the open result is checked. If the open fails, the process exits via
// log.Fatal. The file handle is never closed here; registryLogger keeps using it.
func newRegistryLogger() {
	const openFlags = os.O_WRONLY | os.O_CREATE | os.O_APPEND
	f, openErr := os.OpenFile(registryFile, openFlags, 0644)
	logger.Printf("Registry file %s", registryFile)
	if openErr != nil {
		log.Fatal(openErr)
	}
	registryLogger = log.New(f, "TFTP registry: ", log.LstdFlags)
}