// broker.go — committed by Bengfort.

package main

import (
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"
)

type (
	// RequestMsg is a job handed to the castellum client over the SSE stream:
	// id identifies the job and Data carries the scheduler's raw request body.
	RequestMsg struct {
		id   int
		Data []byte
	}

	// ResponseMsg is the castellum client's reply for a job: success is true
	// when the reply was posted with a "success" query key, and Data is the
	// raw reply body relayed back to the scheduler.
	ResponseMsg struct {
		success bool
		Data    []byte
	}
)

// Broker state shared by every handler.
var (
	// mux guards jobs and lastId.
	mux = new(sync.RWMutex)
	// verbose turns on per-request logging in handler; main sets it from
	// the presence of BROKER_VERBOSE.
	verbose bool
	// connected is true while a castellum client holds the SSE stream open.
	connected bool
	// lastId is the most recently issued job id; createJob increments it
	// before handing it out.
	lastId = 1
	// jobs maps each pending job id to the channel its scheduler request
	// waits on for the castellum reply.
	jobs = make(map[int]chan ResponseMsg)
	// sse carries new jobs from schedulerPost to the castellumGet stream.
	sse = make(chan RequestMsg)
)

func releaseConnection() {
	connected = false
}

// createJob registers a new pending job and returns its id together with the
// channel on which castellumPost delivers the castellum client's reply.
//
// The channel has room for one value. castellumPost may look the channel up
// just before the waiting schedulerPost gives up (timeout or client
// disconnect) and deletes the job. With an unbuffered channel, that late send
// would then block forever. With the buffer, the send completes and the unread
// value is simply garbage collected.
func createJob() (int, chan ResponseMsg) {
	mux.Lock()
	defer mux.Unlock()

	lastId++
	ch := make(chan ResponseMsg, 1)
	jobs[lastId] = ch
	return lastId, ch
}

// deleteJob removes the pending job with the given id from the jobs table.
// Removing an id that is not present is a no-op.
func deleteJob(id int) {
	mux.Lock()
	defer mux.Unlock()
	delete(jobs, id)
}

// schedulerPost submits the request body as a new job to the castellum client
// over the SSE stream, then relays the client's reply back to the caller.
//
// The whole exchange (handing the job to the stream plus waiting for the
// reply) is bounded by a single 10-second deadline:
//   - deadline reached: 504 Gateway Timeout;
//   - reply without the success flag: 500 with the reply body;
//   - successful reply: the reply body with the default 200 status;
//   - caller disconnects: return without writing anything.
func schedulerPost(w http.ResponseWriter, r *http.Request) {
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		log.Println("error reading request body:", err)
		http.Error(w, "", http.StatusInternalServerError)
		return
	}

	id, ch := createJob()
	defer deleteJob(id)

	ctx := r.Context()
	timeout := time.NewTimer(10 * time.Second)
	defer timeout.Stop()

	// The hand-off blocks until castellumGet receives it, which never happens
	// while no castellum client is attached; bound it by the same deadline and
	// by caller cancellation instead of hanging forever.
	select {
	case sse <- RequestMsg{id, body}:
	case <-ctx.Done():
		return
	case <-timeout.C:
		http.Error(w, "", http.StatusGatewayTimeout)
		return
	}

	select {
	case <-ctx.Done():
		return
	case <-timeout.C:
		http.Error(w, "", http.StatusGatewayTimeout)
	case msg := <-ch:
		if !msg.success {
			// Set the status directly: http.Error would write a "\n" body of
			// its own in front of the relayed reply.
			w.WriteHeader(http.StatusInternalServerError)
		}
		// A write error here means the caller went away; nothing to recover.
		w.Write(msg.Data)
	}
}

// parseUrl reads the job id and success flag from the query string of a
// castellum reply. It returns (id, success, ok):
//   - ok is false, with id 0 and success false, unless the query holds exactly
//     one "id" value and that value parses as an int;
//   - success is true whenever a "success" key is present, whatever its value.
func parseUrl(r *http.Request) (int, bool, bool) {
	params := r.URL.Query()

	// A missing key yields a nil slice, so the length test covers both the
	// "absent" and "repeated" cases.
	values := params["id"]
	if len(values) != 1 {
		return 0, false, false
	}

	jobID, convErr := strconv.Atoi(values[0])
	if convErr != nil {
		return 0, false, false
	}

	_, flagged := params["success"]
	return jobID, flagged, true
}

// castellumPost accepts the castellum client's reply for a job and hands it
// to the scheduler request waiting on that job.
//
// The job id comes from the "id" query parameter. The presence of a "success"
// key marks the reply as successful. The request body is forwarded verbatim.
// Responses:
//   - malformed or unknown id: 404;
//   - body read failure: 500;
//   - no receiver takes the reply within 10 seconds: 404, since the job is no
//     longer pending.
func castellumPost(w http.ResponseWriter, r *http.Request) {
	id, success, ok := parseUrl(r)
	if !ok {
		http.Error(w, "", http.StatusNotFound)
		return
	}

	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		log.Println("error reading request body:", err)
		http.Error(w, "", http.StatusInternalServerError)
		return
	}

	mux.RLock()
	ch, ok := jobs[id]
	mux.RUnlock()

	if !ok {
		http.Error(w, "", http.StatusNotFound)
		return
	}

	// The waiting scheduler request may already have given up after the
	// lookup above, leaving nobody to receive. Bound the send so this handler
	// cannot block forever.
	deliver := time.NewTimer(10 * time.Second)
	defer deliver.Stop()

	select {
	case ch <- ResponseMsg{success, body}:
	case <-r.Context().Done():
	case <-deliver.C:
		http.Error(w, "", http.StatusNotFound)
	}
}

func castellumGet(w http.ResponseWriter, r *http.Request) {
	if connected {
		http.Error(w, "", http.StatusInternalServerError)
		return
	} else {
		connected = true
		defer releaseConnection()
	}

	ctx := r.Context()

	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("X-Accel-Buffering", "no")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, ": ping\n\n")
	flusher.Flush()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			fmt.Fprintf(w, ": ping\n\n")
			flusher.Flush()
		case msg := <-sse:
			fmt.Fprintf(w, "id: %d\ndata: %s\n\n", msg.id, msg.Data)
			flusher.Flush()
		}
	}
}

// handler dispatches on exact path and method:
//
//	GET  /castellum/  -> castellumGet
//	POST /castellum/  -> castellumPost
//	POST /scheduler/  -> schedulerPost
//
// Other methods on those two paths get 405; any other path gets 404.
// Each request is logged first when verbose is set.
func handler(w http.ResponseWriter, r *http.Request) {
	if verbose {
		log.Println(r.Method, r.URL)
	}

	switch r.URL.Path {
	case "/castellum/":
		switch r.Method {
		case http.MethodGet:
			castellumGet(w, r)
		case http.MethodPost:
			castellumPost(w, r)
		default:
			http.Error(w, "", http.StatusMethodNotAllowed)
		}
	case "/scheduler/":
		if r.Method != http.MethodPost {
			http.Error(w, "", http.StatusMethodNotAllowed)
			return
		}
		schedulerPost(w, r)
	default:
		http.Error(w, "", http.StatusNotFound)
	}
}

func main() {
	addr := "localhost:8001"

	port, ok := os.LookupEnv("BROKER_PORT")
	if ok {
		addr = fmt.Sprintf("localhost:%s", port)
	_, verbose = os.LookupEnv("BROKER_VERBOSE")

Bengfort's avatar
Bengfort committed
	http.HandleFunc("/", handler)

	log.Printf("Serving on http://%s", addr)
	log.Fatal(http.ListenAndServe(addr, nil))
}