aboutsummaryrefslogtreecommitdiff
path: root/main.go
blob: 8aa9e96eace8c73ec1a667fb569bc44cac9fbcec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
package main

import (
	"fmt"
	"math/rand"
	"net"
	"os"
	"time"

	flag "github.com/spf13/pflag"
)

// version is printed by --version. It is empty unless injected at build time
// (presumably via -ldflags "-X main.version=..." — confirm in the build setup).
var version string

// Process exit codes passed to assert.
const (
	// configErrorCode (1) reports invalid command-line configuration, such as
	// an unknown protocol or an unparsable delay.
	configErrorCode = iota + 1
	// initErrorCode (2) reports a failure while starting the server:
	// opening the listening socket or changing gid/uid.
	initErrorCode
)

// main parses the command-line flags and opens a TCP listener for the
// selected protocol. It then optionally changes gid/uid and hands every
// accepted connection to connHandler in its own goroutine. It never returns
// once serving has started. Configuration and startup failures terminate the
// process through assert with configErrorCode or initErrorCode.
func main() {
	var protocol string
	var bindAddr string
	var delayParam string
	var port uint16
	var uid uint16
	var gid uint16
	var versionFlag bool

	flag.StringVarP(&protocol, "proto", "P", "ssh", "protocol to tarpit")
	flag.StringVarP(&delayParam, "delay", "d", "10s", "delay between the tarpit keep-alive data packets")
	flag.StringVarP(&bindAddr, "bind-address", "b", "", "address to bind the socket to")
	flag.Uint16VarP(&port, "port", "p", 0, "TCP port, leave it 0 for service default")
	flag.Uint16VarP(&uid, "uid", "u", 0, "setuid, after creating a listening socket")
	flag.Uint16VarP(&gid, "gid", "g", 0, "setgid, after creating a listening socket")
	flag.BoolVarP(&versionFlag, "version", "v", false, "show current version")
	flag.Parse()

	if versionFlag {
		fmt.Println("tarpit version", version)
		return
	}

	handler, defaultPort, err := protocolHandler(protocol)
	assert(err, "protocol handler", configErrorCode)
	if port == 0 {
		port = defaultPort
	}

	delay, err := time.ParseDuration(delayParam)
	assert(err, "parse delay", configErrorCode)
	// ParseDuration accepts "0s" and negative values. Neither is a meaningful
	// keep-alive interval: it would either flood the peer or, if the handler
	// uses time.NewTicker (NOTE(review): confirm in connHandler), panic.
	if delay <= 0 {
		assert(fmt.Errorf("delay must be positive, got %s", delay), "parse delay", configErrorCode)
	}

	// Accept an already-bracketed IPv6 literal ("[::1]"), as the old
	// "%s:%d" formatting required. JoinHostPort adds the brackets itself.
	host := bindAddr
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:22"). Plain
	// "host:port" formatting would produce the unparsable "::1:22".
	bind := net.JoinHostPort(host, fmt.Sprint(port))
	ln, err := net.Listen("tcp", bind)
	assert(err, "server listen", initErrorCode)

	// Change uid / gid after creating a socket (required for privileged ports).
	// The group is changed first: setgid is normally no longer permitted once
	// setuid has dropped root (assuming setGID/setUID wrap setgid(2)/setuid(2)).
	err = setGID(gid)
	assert(err, "unable to setgid", initErrorCode)
	err = setUID(uid)
	assert(err, "unable to setuid", initErrorCode)

	// Seed the global math/rand source. Since Go 1.20 it is seeded
	// automatically and rand.Seed is deprecated. The call is kept so older
	// toolchains still get a non-deterministic sequence.
	rand.Seed(time.Now().UnixNano())
	fmt.Printf("** Server listening on %s\n", bind)

	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			// Accept keeps failing for as long as the cause persists, for
			// example when the file-descriptor limit is exhausted, which is
			// easy to hit while deliberately holding connections open.
			// Retrying immediately would spin a CPU core, so back off
			// exponentially (the same scheme net/http.Server.Serve uses).
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			fmt.Fprintf(os.Stderr, "ERR: accept; %s (retrying in %s)\n", err, backoff)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go connHandler(handler, conn, delay)
	}
}

// assert is the program's fatal-error check. When err is nil it does
// nothing. Otherwise it writes "ERR: <msg>; <err>" to stderr and terminates
// the process with exit status code.
func assert(err error, msg string, code int) {
	if err != nil {
		fmt.Fprintf(os.Stderr, "ERR: %s; %s \n", msg, err.Error())
		os.Exit(code)
	}
}