diff options
| -rw-r--r-- | Makefile | 11 | ||||
| -rw-r--r-- | main.go | 55 | ||||
| -rw-r--r-- | setuid_unix.go | 25 | ||||
| -rw-r--r-- | setuid_windows.go | 19 |
4 files changed, 90 insertions, 20 deletions
@@ -1,8 +1,11 @@ -linux-amd64: build/linux-amd64 -windows-amd64: build/windows-amd64.exe +SRC = $(wildcard *.go) -build/linux-amd64: $(SRC) +linux-amd64: build/tarpit-linux-amd64 + +windows-amd64: build/tarpit-windows-amd64.exe + +build/tarpit-linux-amd64: $(SRC) GOOS=linux GOARCH=amd64 go build -o "$@" -build/windows-amd64.exe: $(SRC) +build/tarpit-windows-amd64.exe: $(SRC) GOOS=windows GOARCH=amd64 go build -o "$@" @@ -10,44 +10,67 @@ import ( flag "github.com/spf13/pflag" ) +var version = "0.1.0" + +const ( + configErrorCode = 1 + initErrorCode = 2 +) + func main() { var protocol string var bindAddr string var delayParam string - var port int + var port uint16 + var uid uint16 + var gid uint16 + var versionFlag bool flag.StringVarP(&protocol, "proto", "P", "ssh", "protocol to tarpit") flag.StringVarP(&delayParam, "delay", "d", "10s", "delay between the tarpit keep-alive data packets") flag.StringVarP(&bindAddr, "bind-address", "b", "", "address to bind the socket to") - flag.IntVarP(&port, "port", "p", 22, "TCP port") + flag.Uint16VarP(&port, "port", "p", 22, "TCP port") + flag.Uint16VarP(&uid, "uid", "u", 0, "setuid, after creating a listening socket") + flag.Uint16VarP(&gid, "gid", "g", 0, "setgid, after creating a listening socket") + flag.BoolVarP(&versionFlag, "version", "v", false, "show current version") flag.Parse() - handler, err := protocolHandler(protocol) - if err != nil { - fmt.Fprintln(os.Stderr, "Error: protocol handler;", err.Error()) - os.Exit(1) + if versionFlag { + fmt.Println("Tarpit version", version) + return } + + handler, err := protocolHandler(protocol) + assert(err, "protocol handler", configErrorCode) + delay, err := time.ParseDuration(delayParam) - if err != nil { - fmt.Fprintln(os.Stderr, "Error: parse delay;", err.Error()) - os.Exit(1) - } + assert(err, "parse delay", configErrorCode) bind := fmt.Sprintf("%s:%d", bindAddr, port) ln, err := net.Listen("tcp", bind) - if err != nil { - fmt.Fprintln(os.Stderr, "Error: server listen;", err.Error()) - os.Exit(1) - } + assert(err, "server listen", initErrorCode) + + // Change uid / gid after creating a socket (required for privileged ports) + err = setGID(gid) + assert(err, "unable to setgid", initErrorCode) + err = setUID(uid) + assert(err, "unable to setuid", initErrorCode) rand.Seed(time.Now().UnixNano()) - fmt.Fprintf(os.Stderr, "** Server listening on %s\n", bind) + fmt.Printf("** Server listening on %s\n", bind) + for { conn, err := ln.Accept() if err != nil { - // handle error continue } go connHandler(handler, conn, delay) } } + +func assert(err error, msg string, code int) { + if err != nil { + fmt.Fprintf(os.Stderr, "ERR: %s; %s \n", msg, err.Error()) + os.Exit(code) + } +} diff --git a/setuid_unix.go b/setuid_unix.go new file mode 100644 index 0000000..be2db7b --- /dev/null +++ b/setuid_unix.go @@ -0,0 +1,25 @@ +package main + +import ( + "errors" + "syscall" +) + +func setUidGid(syscallID uint, uidgid uint16) error { + if uidgid == 0 { + return nil + } + _, _, errno := syscall.Syscall(uintptr(syscallID), uintptr(uidgid), 0, 0) + if errno != 0 { + return errors.New(errno.Error()) + } + return nil +} + +func setUID(uid uint16) error { + return setUidGid(syscall.SYS_SETUID, uid) +} + +func setGID(gid uint16) error { + return setUidGid(syscall.SYS_SETGID, gid) +} diff --git a/setuid_windows.go b/setuid_windows.go new file mode 100644 index 0000000..63ccc73 --- /dev/null +++ b/setuid_windows.go @@ -0,0 +1,19 @@ +package main + +import ( + "errors" +) + +func setUID(uid uint16) error { + if uid != 0 { + return errors.New("unable to setuid on Windows") + } + return nil +} + +func setGID(uid uint16) error { + if uid != 0 { + return errors.New("unable to setgid on Windows") + } + return nil +} |
