package main

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"strconv"
	"time"
	"unicode/utf16"
	"unicode/utf8"
)

const (
	crlf      = "\r\n"
	rdgConnectionIdKey = "Rdg-Connection-Id"
	HANDSHAKE = 1
)

const (
	PKT_TYPE_HANDSHAKE_REQUEST      = 0x1
	PKT_TYPE_HANDSHAKE_RESPONSE     = 0x2
	PKT_TYPE_EXTENDED_AUTH_MSG      = 0x3
	PKT_TYPE_TUNNEL_CREATE          = 0x4
	PKT_TYPE_TUNNEL_RESPONSE        = 0x5
	PKT_TYPE_TUNNEL_AUTH            = 0x6
	PKT_TYPE_TUNNEL_AUTH_RESPONSE   = 0x7
	PKT_TYPE_CHANNEL_CREATE         = 0x8
	PKT_TYPE_CHANNEL_RESPONSE       = 0x9
	PKT_TYPE_DATA                   = 0xA
	PKT_TYPE_SERVICE_MESSAGE        = 0xB
	PKT_TYPE_REAUTH_MESSAGE         = 0xC
	PKT_TYPE_KEEPALIVE              = 0xD
	PKT_TYPE_CLOSE_CHANNEL          = 0x10
	PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
)

const (
	HTTP_EXTENDED_AUTH_NONE      = 0x0
	HTTP_EXTENDED_AUTH_SC        = 0x1  /* Smart card authentication. */
	HTTP_EXTENDED_AUTH_PAA       = 0x02 /* Pluggable authentication. */
	HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
)

const (
	HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
)

// HandshakeHeader is the interface that writes both upgrade request or
// response headers into a given io.Writer.
type HandshakeHeader interface {
	io.WriterTo
}

type RdgSession struct {
	ConnId        string
	CorrelationId string
	UserId        string
	ConnIn        net.Conn
	ConnOut       net.Conn
	BufOut        *bufio.Writer
	BufIn         *bufio.Reader
	State         int
	Remote 		  net.Conn
}

// ErrNotHijacker is an error returned when http.ResponseWriter does not
// implement http.Hijacker interface.
var ErrNotHijacker = RejectConnectionError(
	RejectionStatus(http.StatusInternalServerError),
	RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
)

var DefaultSession RdgSession

func Upgrade(next http.Handler) http.Handler {
	return DefaultSession.RdgHandshake(next)
}

func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
		log.Print("Accept connection")
		hj, ok := w.(http.Hijacker)
		if ok {
			return hj.Hijack()
		} else {
			err = ErrNotHijacker
		}
		if err != nil {
			httpError(w, err.Error(), http.StatusInternalServerError)
			return nil, nil, err
		}
		return
}

func (s RdgSession) RdgHandshake(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		/*_, _, ok := r.BasicAuth()

		if !ok && s.ConnIn == nil {
			w.Header().Set("WWW-Authenticate", `Basic realm="rdpgw"`)
			w.WriteHeader(401)
			w.Write([]byte("Unauthorized.\n"))
			fmt.Println("Unauthorized")
			return
		}*/

		conn, rw, _ := Accept(w)
		if r.Method == MethodRDGOUT {
			log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
			s.ConnId = r.Header.Get(rdgConnectionIdKey)
			s.ConnOut = conn
			s.BufOut = rw.Writer
			WriteAcceptSeed(rw.Writer)
			rw.Writer.Flush()
		} else if r.Method == MethodRDGIN {
			if s.ConnIn == nil {
				defer conn.Close()
				s.ConnIn = conn
				s.BufIn = rw.Reader
				log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
				WriteAcceptSeed(rw.Writer)
				rw.Writer.Flush()
				p := make([]byte, 4096)
				rw.Reader.Read(p)
				//log.Printf("Read %q", p)

				log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
				scanner := bufio.NewScanner(rw.Reader)
				scanner.Split(ReadPacket)
				for scanner.Scan() {
					packet := scanner.Bytes()
					packetType, size, _, packet := readHeader(packet)
					log.Printf("Scanned packet got packet type %x size %d", packetType, size)
					switch packetType {
					case PKT_TYPE_HANDSHAKE_REQUEST:
						major, minor, _, auth := readHandshake(packet)
						sendHandshakeResponse(s.BufOut, major, minor, auth)
					case PKT_TYPE_TUNNEL_CREATE:
						readCreateTunnelRequest(packet)
						sendCreateTunnelResponse(s.BufOut)
					case PKT_TYPE_TUNNEL_AUTH:
						readTunnelAuthRequest(packet)
						sendTunnelAuthResponse(s.BufOut)
					case PKT_TYPE_CHANNEL_CREATE:
						server, port := readChannelCreateRequest(packet)
						var err error
						s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
						if err != nil {
							log.Printf("Error connecting to %s, %d, %s", server, port, err)
							return
						}
						sendChannelCreateResponse(s.BufOut)
						go sendDataPacket(s.Remote, s.BufOut)
					case PKT_TYPE_DATA:
						receiveDataPacket(s.Remote, packet)
					}
				}
			}
		}
	})
}

// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
// not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
// connection after the last byte is sent.
func WriteAcceptSeed(bw *bufio.Writer) {
	bw.WriteString(HttpOK)
	bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
	bw.WriteString("Content-Type: application/octet-stream\r\n")
	bw.WriteString("Content-Length: 0\r\n")
	bw.WriteString(crlf)
	seed := make([]byte, 10)
	rand.Read(seed)
	bw.Write(seed)
}

func ReadPacket(data []byte, atEOF bool) (advance int, packet []byte, err error) {
	log.Printf("Reading data len = %d", len(data))
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}

	if i := bytes.Index(data, []byte{'\r', '\n'}); i >= 0 {
		//log.Printf("Got rn at %d ", i)
		chunkSize, err := strconv.ParseInt(string(data[0:i]), 16, 0)
		log.Printf("chunkSize %d", chunkSize)
		if err != nil {
			return i + 2, data[0:i], err
		}
		//log.Printf("Return %d", i+2+int(chunkSize)+2)
		return i + 2 + int(chunkSize) + 2, data[i+2 : i+2+int(chunkSize)+2], nil
	}

	if atEOF {
		return len(data), data, nil
	}

	return 0, nil, nil
}

func readHeader(data []byte) (packetType uint16, size uint32, advance int, remain []byte) {
	r := bytes.NewReader(data)
	binary.Read(r, binary.LittleEndian, &packetType)
	r.Seek(4, io.SeekStart)
	binary.Read(r, binary.LittleEndian, &size)
	return packetType, size, 8, data[8:]
}

func sendHandshakeResponse(w *bufio.Writer, major byte, minor byte, auth uint16) {
	buf := new(bytes.Buffer)
	binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
	buf.Write([]byte{major, minor})
	binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
	binary.Write(buf, binary.LittleEndian, uint16(2)) // PAA

	w.Write(createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()))
	w.Flush()
}

func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
	r := bytes.NewReader(data)
	binary.Read(r, binary.LittleEndian, &major)
	binary.Read(r, binary.LittleEndian, &minor)
	binary.Read(r, binary.LittleEndian, &version)
	binary.Read(r, binary.LittleEndian, &extAuth)

	log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
	return
}

func readCreateTunnelRequest(data []byte) (caps uint32, cookie string){
	var fields uint16

	r := bytes.NewReader(data)

	binary.Read(r, binary.LittleEndian, &caps)
	binary.Read(r, binary.LittleEndian, &fields)
	r.Seek(2, io.SeekCurrent)

	if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
		var size uint16
		binary.Read(r, binary.LittleEndian, &size)
		cookieB := make([]byte, size)
		r.Read(cookieB)
		cookie, _ = DecodeUTF16(cookieB)
	}
	log.Printf("Create tunnel caps: %d, cookie: %s", caps, cookie)
	return
}

func sendCreateTunnelResponse(w *bufio.Writer) {
	buf := new(bytes.Buffer)

	binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved

	w.Write(createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()))
	w.Flush()
}

func readTunnelAuthRequest(data []byte) {
	buf := bytes.NewReader(data)

	var size uint16
	binary.Read(buf, binary.LittleEndian, &size)
	clData := make([]byte, size)
	binary.Read(buf, binary.LittleEndian, &clData)
	clientName, _ := DecodeUTF16(clData)
	log.Printf("Client: %s", clientName)
}

func sendTunnelAuthResponse(w *bufio.Writer) {
	buf := new(bytes.Buffer)

	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved

	w.Write(createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()))
	w.Flush()
}

func readChannelCreateRequest(data []byte) (server string, port uint16){
	buf := bytes.NewReader(data)

	var resourcesSize byte
	var alternative byte
	var protocol uint16
	var nameSize uint16

	binary.Read(buf, binary.LittleEndian, &resourcesSize)
	binary.Read(buf, binary.LittleEndian, &alternative)
	binary.Read(buf, binary.LittleEndian, &port)
	binary.Read(buf, binary.LittleEndian, &protocol)
	binary.Read(buf, binary.LittleEndian, &nameSize)

	nameData := make([]byte, nameSize)
	binary.Read(buf, binary.LittleEndian, &nameData)

	log.Printf("Name data %q", nameData)
	server, _ = DecodeUTF16(nameData)

	log.Printf("Should connect to %s on port %d", server, port)
	return
}

func sendChannelCreateResponse(w *bufio.Writer) {
	buf := new(bytes.Buffer)

	binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
	binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
	binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved

	w.Write(createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()))
	w.Flush()
}

func createPacket(pktType uint16, data []byte) (packet []byte){
	size := len(data) + 8
	buf := new(bytes.Buffer)

	log.Printf("Data sent Size: %d", size)
	// http chunk size in hex string
	// fmt.Fprintf(buf,"%x\r\n", size)

	binary.Write(buf, binary.LittleEndian, uint16(pktType))
	binary.Write(buf, binary.LittleEndian, uint16(0))  // reserved
	binary.Write(buf, binary.LittleEndian, uint32(size))
	buf.Write(data)

	// http close crlf
	// buf.Write([]byte(crlf))
	// log.Printf("data sent: %q", buf.Bytes())
	return buf.Bytes()
}

func receiveDataPacket(conn net.Conn, data []byte) {
	buf := bytes.NewReader(data)

	var cblen uint16
	binary.Read(buf, binary.LittleEndian, &cblen)
	log.Printf("Received PKT_DATA %d", cblen)
	pkt := make([]byte, cblen)
	//binary.Read(buf, binary.LittleEndian, &pkt)
	buf.Read(pkt)
	//log.Printf("DATA FROM CLIENT %q", pkt)
	conn.Write(pkt)
}

func sendDataPacket(conn net.Conn, w *bufio.Writer) {
	defer conn.Close()
	b1 := new(bytes.Buffer)
	buf := make([]byte, 32767)
	for {
		n, err := conn.Read(buf)
		binary.Write(b1, binary.LittleEndian, uint16(n))
		log.Printf("RDP SIZE: %d", n)
		if err != nil {
			log.Printf("Error reading from conn %s", err)
			break
		}
		b1.Write(buf[:n])
		w.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
		w.Flush()
		b1.Reset()
	}
}

func DecodeUTF16(b []byte) (string, error) {
	if len(b)%2 != 0 {
		log.Printf("Error decoding utf16")
		return "", fmt.Errorf("Must have even length byte slice")
	}

	u16s := make([]uint16, 1)
	ret := &bytes.Buffer{}
	b8buf := make([]byte, 4)

	lb := len(b)
	for i := 0; i < lb; i += 2 {
		u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8)
		r := utf16.Decode(u16s)
		n := utf8.EncodeRune(b8buf, r[0])
		ret.Write(b8buf[:n])
	}

	return ret.String(), nil
}