package main

import (
	"errors"
	"io"
	"log"
	"math/rand"
	"net"
	"os"
	"time"

	"github.com/google/gopacket"
	"github.com/google/gopacket/afpacket"
	"github.com/google/gopacket/layers"
)

const (
	// remotePort is the TCP destination port on remoteIP that the handshake targets.
	remotePort = 6363

	// ifname is the network interface the AF_PACKET handle is opened on.
	ifname = "eno1"
)

var (
	// Layer-2 addresses: every frame is sent from localMAC to routerMAC
	// (presumably the next-hop gateway for remoteIP — confirm for the network).
	localMAC  = mustParseMAC("4e:66:f0:c4:90:3c")
	routerMAC = mustParseMAC("16:0d:2e:cb:c4:69")

	// Layer-3 endpoints of the handshake.
	localIP  = mustParseIP("192.0.2.1")
	remoteIP = mustParseIP("198.51.100.1")

	// localPort is the TCP source port of every segment we send.
	// localInitSeq is the sequence number carried by the SYN; the ACK and RST
	// use localInitSeq+1 because the SYN consumes one sequence number.
	localPort    layers.TCPPort
	localInitSeq uint32
)

// mustParseMAC parses a hard-coded MAC address literal. It panics on a
// malformed literal so a typo fails at startup instead of silently yielding a
// nil address.
func mustParseMAC(s string) net.HardwareAddr {
	mac, err := net.ParseMAC(s)
	if err != nil {
		log.Panicln(err)
	}
	return mac
}

// mustParseIP parses a hard-coded IP address literal. net.ParseIP signals
// failure by returning nil; that would make the SYN-ACK filter never match,
// so panic at startup instead.
func mustParseIP(s string) net.IP {
	ip := net.ParseIP(s)
	if ip == nil {
		log.Panicln("invalid IP address:", s)
	}
	return ip
}

var (
	// h is the AF_PACKET handle opened in main; frames are both written to
	// and read from it.
	h *afpacket.TPacket

	// serializeOpts tells gopacket to fill in length fields and checksums, so
	// the header literals in the send functions can leave those zero.
	serializeOpts = gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}

	// sbuf is the single serialization buffer reused for every outgoing frame.
	sbuf = gopacket.NewSerializeBuffer()
)

// main performs a hand-rolled TCP three-way handshake with remoteIP:remotePort
// over a raw AF_PACKET socket on ifname: it sends a SYN, waits for the
// matching SYN-ACK, acknowledges it, and then immediately aborts the
// connection with an RST. Any failure panics via log.Panicln.
func main() {
	rand.Seed(time.Now().UnixNano())

	// Draw the source port from the dynamic/private range (49152-65535).
	// Truncating a random uint32 could yield port 0 or a well-known port.
	localPort = layers.TCPPort(49152 + rand.Intn(65536-49152))

	// Use an unpredictable initial sequence number rather than a fixed zero.
	localInitSeq = rand.Uint32()

	var e error
	h, e = afpacket.NewTPacket(afpacket.OptInterface(ifname))
	if e != nil {
		log.Panicln(e)
	}

	// Start the listener before sending the SYN so the reply cannot be missed.
	remoteSeqChan := make(chan uint32, 1)
	go waitSYNACK(remoteSeqChan)

	sendSYN()

	// Don't hang forever if the peer never answers (no SYN retransmission here).
	const synAckTimeout = 5 * time.Second
	var remoteSeq uint32
	select {
	case remoteSeq = <-remoteSeqChan:
	case <-time.After(synAckTimeout):
		log.Panicln("timed out waiting for SYN-ACK")
	}

	sendACK(remoteSeq)
	sendRST(remoteSeq)
}

// sendSYN emits the opening segment of the handshake. The SYN also has ECE
// and CWR set, making it an ECN-setup SYN. It goes from localIP:localPort to
// remoteIP:remotePort with sequence number localInitSeq, and is framed at
// layer 2 from localMAC to routerMAC. Serialization or write failures panic
// via log.Panicln.
func sendSYN() {
	ipHdr := layers.IPv4{
		Version:  4,
		TTL:      128,
		Flags:    layers.IPv4DontFragment,
		Protocol: layers.IPProtocolTCP,
		SrcIP:    localIP,
		DstIP:    remoteIP,
	}
	tcpHdr := layers.TCP{
		SrcPort: localPort,
		DstPort: remotePort,
		Seq:     localInitSeq,
		Window:  8192,
		SYN:     true,
		ECE:     true,
		CWR:     true,
	}
	// The TCP checksum covers the IPv4 pseudo-header, so link the two layers.
	tcpHdr.SetNetworkLayerForChecksum(&ipHdr)
	ethHdr := layers.Ethernet{
		SrcMAC:       localMAC,
		DstMAC:       routerMAC,
		EthernetType: layers.EthernetTypeIPv4,
	}

	err := gopacket.SerializeLayers(sbuf, serializeOpts, &ethHdr, &ipHdr, &tcpHdr)
	if err != nil {
		log.Panicln(err)
	}
	err = h.WritePacketData(sbuf.Bytes())
	if err != nil {
		log.Panicln(err)
	}
}

// waitSYNACK reads frames from the AF_PACKET handle h until it sees a TCP
// SYN-ACK from remoteIP:remotePort addressed to localIP:localPort. It then
// sends that segment's sequence number on remoteSeqChan and returns, so the
// goroutine running it ends once its job is done.
//
// Frames that fail to read or decode are skipped. If the handle reports
// io.EOF, the whole process exits with status 0.
func waitSYNACK(remoteSeqChan chan<- uint32) {
	var eth layers.Ethernet
	var ip4 layers.IPv4
	var tcp layers.TCP
	parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &eth, &ip4, &tcp)
	// When a segment carries data, the parser looks for a Payload decoder,
	// which is not registered here. Without this flag such frames would be
	// rejected with UnsupportedLayerType even though Ethernet/IPv4/TCP decoded
	// fine. Non-IPv4/non-TCP frames are still filtered out by the
	// len(decoded) check below.
	parser.IgnoreUnsupported = true
	decoded := make([]gopacket.LayerType, 0, 3)

	for {
		pkt, _, e := h.ZeroCopyReadPacketData()
		if errors.Is(e, io.EOF) {
			os.Exit(0)
		}
		if e != nil {
			// Other read errors (e.g. poll timeouts) are tolerated; keep listening.
			continue
		}

		// pkt is only valid until the next read, so everything needed from
		// the decoded layers is consumed within this iteration.
		if e = parser.DecodeLayers(pkt, &decoded); e != nil {
			continue
		}

		if len(decoded) == 3 && decoded[2] == layers.LayerTypeTCP &&
			ip4.SrcIP.Equal(remoteIP) && ip4.DstIP.Equal(localIP) &&
			tcp.SrcPort == remotePort && tcp.DstPort == localPort &&
			tcp.SYN && tcp.ACK {
			remoteSeqChan <- tcp.Seq
			return
		}
	}
}

// sendACK completes the three-way handshake by acknowledging the peer's
// SYN-ACK. It emits a bare ACK with sequence number localInitSeq+1 and
// acknowledgement number remoteSeq+1, advertising a receive window of 256.
// Serialization or write failures panic via log.Panicln.
func sendACK(remoteSeq uint32) {
	must := func(err error) {
		if err != nil {
			log.Panicln(err)
		}
	}

	network := layers.IPv4{
		Version:  4,
		TTL:      128,
		Flags:    layers.IPv4DontFragment,
		Protocol: layers.IPProtocolTCP,
		SrcIP:    localIP,
		DstIP:    remoteIP,
	}
	transport := layers.TCP{
		SrcPort: localPort,
		DstPort: remotePort,
		Seq:     localInitSeq + 1,
		Ack:     remoteSeq + 1,
		Window:  256,
		ACK:     true,
	}
	transport.SetNetworkLayerForChecksum(&network)
	link := layers.Ethernet{
		SrcMAC:       localMAC,
		DstMAC:       routerMAC,
		EthernetType: layers.EthernetTypeIPv4,
	}

	must(gopacket.SerializeLayers(sbuf, serializeOpts, &link, &network, &transport))
	must(h.WritePacketData(sbuf.Bytes()))
}

// sendRST aborts the just-established connection. It emits an RST+ACK with
// sequence number localInitSeq+1, acknowledgement number remoteSeq+1, and a
// zero window. Serialization or write failures panic via log.Panicln.
func sendRST(remoteSeq uint32) {
	var ethLayer layers.Ethernet
	ethLayer.SrcMAC = localMAC
	ethLayer.DstMAC = routerMAC
	ethLayer.EthernetType = layers.EthernetTypeIPv4

	var ipLayer layers.IPv4
	ipLayer.Version = 4
	ipLayer.Flags = layers.IPv4DontFragment
	ipLayer.TTL = 128
	ipLayer.Protocol = layers.IPProtocolTCP
	ipLayer.SrcIP = localIP
	ipLayer.DstIP = remoteIP

	var tcpLayer layers.TCP
	tcpLayer.SrcPort = localPort
	tcpLayer.DstPort = remotePort
	tcpLayer.Seq = localInitSeq + 1
	tcpLayer.Ack = remoteSeq + 1
	tcpLayer.RST = true
	tcpLayer.ACK = true
	tcpLayer.Window = 0
	tcpLayer.SetNetworkLayerForChecksum(&ipLayer)

	if err := gopacket.SerializeLayers(sbuf, serializeOpts, &ethLayer, &ipLayer, &tcpLayer); err != nil {
		log.Panicln(err)
	}
	if err := h.WritePacketData(sbuf.Bytes()); err != nil {
		log.Panicln(err)
	}
}
